-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(chain): Initial prototype for agentic chain feature * feat: Add chain error handling * feat: Add parallel ops and rename `chain` to `pipeline` * docs: Update example * feat: Add extraction pipeline op * docs: Add extraction pipeline example * feat: Add `try_parallel!` pipeline op macro * misc: Remove unused module * style: cargo fmt * test: fix typo in test * test: fix typo in test #2 * test: Fix broken lookup op test * feat: Add `Op::batch_call` and `TryOp::try_batch_call` * test: Fix tests * docs: Add more docstrings to agent pipeline ops * docs: Add pipeline module level docs * docs: improve pipeline docs * style: clippy+fmt * fix(pipelines): Type errors * fix: Missing trait import in macros * feat(pipeline): Add id and score to `lookup` op result * docs(pipelines): Add more docstrings * docs(pipelines): Update example * test(mongodb): Fix flaky test again * style: fmt * test(mongodb): fix
- Loading branch information
Showing
10 changed files
with
2,097 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
use std::env; | ||
|
||
use rig::{ | ||
embeddings::EmbeddingsBuilder, | ||
parallel, | ||
pipeline::{self, agent_ops::lookup, passthrough, Op}, | ||
providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, | ||
vector_store::in_memory_store::InMemoryVectorStore, | ||
}; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Create OpenAI client | ||
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); | ||
let openai_client = Client::new(&openai_api_key); | ||
|
||
let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); | ||
|
||
// Create embeddings for our documents | ||
let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) | ||
.document("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")? | ||
.document("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")? | ||
.document("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")? | ||
.build() | ||
.await?; | ||
|
||
// Create vector store with the embeddings | ||
let vector_store = InMemoryVectorStore::from_documents(embeddings); | ||
|
||
// Create vector store index | ||
let index = vector_store.index(embedding_model); | ||
|
||
let agent = openai_client.agent("gpt-4") | ||
.preamble(" | ||
You are a dictionary assistant here to assist the user in understanding the meaning of words. | ||
") | ||
.build(); | ||
|
||
let chain = pipeline::new() | ||
// Chain a parallel operation to the current chain. The parallel operation will | ||
// perform a lookup operation to retrieve additional context from the user prompt | ||
// while simultaneously applying a passthrough operation. The latter will allow | ||
// us to forward the initial prompt to the next operation in the chain. | ||
.chain(parallel!( | ||
passthrough(), | ||
lookup::<_, _, String>(index, 1), // Required to specify document type | ||
)) | ||
// Chain a "map" operation to the current chain, which will combine the user | ||
// prompt with the retrieved context documents to create the final prompt. | ||
// If an error occurs during the lookup operation, we will log the error and | ||
// simply return the initial prompt. | ||
.map(|(prompt, maybe_docs)| match maybe_docs { | ||
Ok(docs) => format!( | ||
"Non standard word definitions:\n{}\n\n{}", | ||
docs.into_iter() | ||
.map(|(_, _, doc)| doc) | ||
.collect::<Vec<_>>() | ||
.join("\n"), | ||
prompt, | ||
), | ||
Err(err) => { | ||
println!("Error: {}! Prompting without additional context", err); | ||
format!("{prompt}") | ||
} | ||
}) | ||
// Chain a "prompt" operation which will prompt out agent with the final prompt | ||
.prompt(agent); | ||
|
||
// Prompt the agent and print the response | ||
let response = chain.call("What does \"glarb-glarb\" mean?").await?; | ||
|
||
println!("{:?}", response); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
use rig::{ | ||
pipeline::{self, agent_ops, TryOp}, | ||
providers::openai, | ||
try_parallel, | ||
}; | ||
use schemars::JsonSchema; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted names.
// NOTE(review): derives `JsonSchema` — presumably so the extractor can hand the
// schema to the model as its structured-output contract; confirm against rig's
// extractor API.
pub struct Names {
    /// The names extracted from the text
    pub names: Vec<String>,
}
|
||
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted topics.
// NOTE(review): derives `JsonSchema` — presumably the extractor's structured-output
// schema, mirroring `Names`; confirm against rig's extractor API.
pub struct Topics {
    /// The topics extracted from the text
    pub topics: Vec<String>,
}
|
||
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted sentiment.
// NOTE(review): the -1..1 range below is what the preamble asks the model for;
// nothing here enforces it — values outside that range are representable.
pub struct Sentiment {
    /// The sentiment of the text (-1 being negative, 1 being positive)
    pub sentiment: f64,
    /// The confidence of the sentiment
    pub confidence: f64,
}
|
||
#[tokio::main] | ||
async fn main() -> anyhow::Result<()> { | ||
let openai = openai::Client::from_env(); | ||
|
||
let names_extractor = openai | ||
.extractor::<Names>("gpt-4") | ||
.preamble("Extract names (e.g.: of people, places) from the given text.") | ||
.build(); | ||
|
||
let topics_extractor = openai | ||
.extractor::<Topics>("gpt-4") | ||
.preamble("Extract topics from the given text.") | ||
.build(); | ||
|
||
let sentiment_extractor = openai | ||
.extractor::<Sentiment>("gpt-4") | ||
.preamble( | ||
"Extract sentiment (and how confident you are of the sentiment) from the given text.", | ||
) | ||
.build(); | ||
|
||
// Create a chain that extracts names, topics, and sentiment from a given text | ||
// using three different GPT-4 based extractors. | ||
// The chain will output a formatted string containing the extracted information. | ||
let chain = pipeline::new() | ||
.chain(try_parallel!( | ||
agent_ops::extract(names_extractor), | ||
agent_ops::extract(topics_extractor), | ||
agent_ops::extract(sentiment_extractor), | ||
)) | ||
.map_ok(|(names, topics, sentiment)| { | ||
format!( | ||
"Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}", | ||
names = names.names.join(", "), | ||
topics = topics.topics.join(", "), | ||
sentiment = sentiment.sentiment, | ||
) | ||
}); | ||
|
||
// Batch call the chain with up to 4 inputs concurrently | ||
let response = chain | ||
.try_batch_call( | ||
4, | ||
vec![ | ||
"Screw you Putin!", | ||
"I love my dog, but I hate my cat.", | ||
"I'm going to the store to buy some milk.", | ||
], | ||
) | ||
.await?; | ||
|
||
for response in response { | ||
println!("Text analysis:\n{response}"); | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.