feat: agent pipelines (#131)
* feat(chain): Initial prototype for agentic chain feature

* feat: Add chain error handling

* feat: Add parallel ops and rename `chain` to `pipeline`

* docs: Update example

* feat: Add extraction pipeline op

* docs: Add extraction pipeline example

* feat: Add `try_parallel!` pipeline op macro

* misc: Remove unused module

* style: cargo fmt

* test: fix typo in test

* test: fix typo in test #2

* test: Fix broken lookup op test

* feat: Add `Op::batch_call` and `TryOp::try_batch_call`

* test: Fix tests

* docs: Add more docstrings to agent pipeline ops

* docs: Add pipeline module level docs

* docs: improve pipeline docs

* style: clippy+fmt

* fix(pipelines): Type errors

* fix: Missing trait import in macros

* feat(pipeline): Add id and score to `lookup` op result

* docs(pipelines): Add more docstrings

* docs(pipelines): Update example

* test(mongodb): Fix flaky test again

* style: fmt

* test(mongodb): fix
cvauclair authored Dec 17, 2024
1 parent 5dfa93b commit 9e132ac
Showing 10 changed files with 2,097 additions and 2 deletions.
75 changes: 75 additions & 0 deletions rig-core/examples/chain.rs
@@ -0,0 +1,75 @@
use std::env;

use rig::{
embeddings::EmbeddingsBuilder,
parallel,
pipeline::{self, agent_ops::lookup, passthrough, Op},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Create embeddings for our documents
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.document("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")?
.document("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")?
.document("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")?
.build()
.await?;

// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents(embeddings);

// Create vector store index
let index = vector_store.index(embedding_model);

let agent = openai_client.agent("gpt-4")
.preamble("
You are a dictionary assistant here to assist the user in understanding the meaning of words.
")
.build();

let chain = pipeline::new()
// Chain a parallel operation to the current chain. The parallel operation will
// perform a lookup operation to retrieve additional context from the user prompt
// while simultaneously applying a passthrough operation. The latter will allow
// us to forward the initial prompt to the next operation in the chain.
.chain(parallel!(
passthrough(),
lookup::<_, _, String>(index, 1), // Required to specify document type
))
// Chain a "map" operation to the current chain, which will combine the user
// prompt with the retrieved context documents to create the final prompt.
// If an error occurs during the lookup operation, we will log the error and
// simply return the initial prompt.
.map(|(prompt, maybe_docs)| match maybe_docs {
Ok(docs) => format!(
"Non standard word definitions:\n{}\n\n{}",
docs.into_iter()
.map(|(_, _, doc)| doc)
.collect::<Vec<_>>()
.join("\n"),
prompt,
),
Err(err) => {
println!("Error: {}! Prompting without additional context", err);
format!("{prompt}")
}
})
// Chain a "prompt" operation which will prompt out agent with the final prompt
.prompt(agent);

// Prompt the agent and print the response
let response = chain.call("What does \"glarb-glarb\" mean?").await?;

println!("{:?}", response);

Ok(())
}
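
The commit log above also adds `Op::batch_call`. As a rough sketch, assuming `batch_call` takes a concurrency limit and a `Vec` of inputs, returns one output per input in order, and that ops take `&self` so the chain is reusable after the single `call` above (the prompts and the limit of 2 are illustrative, not from the commit):

// Hypothetical usage, not part of this diff: assumes `Op::batch_call(n, inputs)`
// runs the pipeline over each input with at most `n` calls in flight at once.
let responses = chain
    .batch_call(
        2,
        vec![
            "What does \"flurbo\" mean?",
            "What does \"linglingdong\" mean?",
        ],
    )
    .await;

// Each element is the pipeline's output for the matching input; with the final
// `.prompt(agent)` op, that is a `Result<String, PromptError>`.
for response in responses {
    println!("{:?}", response);
}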
88 changes: 88 additions & 0 deletions rig-core/examples/multi_extract.rs
@@ -0,0 +1,88 @@
use rig::{
pipeline::{self, agent_ops, TryOp},
providers::openai,
try_parallel,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted names
pub struct Names {
/// The names extracted from the text
pub names: Vec<String>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted topics
pub struct Topics {
/// The topics extracted from the text
pub topics: Vec<String>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted sentiment
pub struct Sentiment {
/// The sentiment of the text (-1 being negative, 1 being positive)
pub sentiment: f64,
/// The confidence of the sentiment
pub confidence: f64,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let openai = openai::Client::from_env();

let names_extractor = openai
.extractor::<Names>("gpt-4")
.preamble("Extract names (e.g.: of people, places) from the given text.")
.build();

let topics_extractor = openai
.extractor::<Topics>("gpt-4")
.preamble("Extract topics from the given text.")
.build();

let sentiment_extractor = openai
.extractor::<Sentiment>("gpt-4")
.preamble(
"Extract sentiment (and how confident you are of the sentiment) from the given text.",
)
.build();

// Create a chain that extracts names, topics, and sentiment from a given text
// using three different GPT-4 based extractors.
// The chain will output a formatted string containing the extracted information.
let chain = pipeline::new()
.chain(try_parallel!(
agent_ops::extract(names_extractor),
agent_ops::extract(topics_extractor),
agent_ops::extract(sentiment_extractor),
))
.map_ok(|(names, topics, sentiment)| {
format!(
"Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}",
names = names.names.join(", "),
topics = topics.topics.join(", "),
sentiment = sentiment.sentiment,
)
});

// Batch call the chain over the inputs, running up to 4 calls concurrently
let responses = chain
.try_batch_call(
4,
vec![
"Screw you Putin!",
"I love my dog, but I hate my cat.",
"I'm going to the store to buy some milk.",
],
)
.await?;

for response in responses {
println!("Text analysis:\n{response}");
}

Ok(())
}
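
Note that `try_batch_call` as used above fails the whole batch on the first error (the trailing `?`). For per-input recovery, the `Result` can be handled inside the pipeline with a plain `map`, as in chain.rs, or inputs can be run one at a time. A minimal sketch of the latter, replacing the batch call above and assuming `TryOp` exposes a single-input `try_call` analogous to `Op::call` (the input text is made up):

// Hypothetical usage, not part of this diff: assumes `TryOp::try_call` runs
// the fallible pipeline on one input and returns a single `Result`.
match chain
    .try_call("I met Marie in Lisbon and we talked about sailing.")
    .await
{
    Ok(report) => println!("Text analysis:\n{report}"),
    Err(err) => eprintln!("Extraction failed: {err}"),
}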
6 changes: 6 additions & 0 deletions rig-core/src/agent.rs
@@ -251,6 +251,12 @@ impl<M: CompletionModel> Prompt for Agent<M> {
}
}

impl<M: CompletionModel> Prompt for &Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}

impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
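The new impl above makes `&Agent<M>` satisfy `Prompt`, so pipeline ops can borrow an agent instead of consuming it. A hypothetical sketch of what this enables (names and usage are illustrative, not part of this diff):

// Hypothetical usage: because `Prompt` is implemented for `&Agent<M>`, one
// agent can back several pipelines without being moved.
let define = pipeline::new().prompt(&agent);
let summarize = pipeline::new().prompt(&agent); // `agent` was not moved above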
1 change: 1 addition & 0 deletions rig-core/src/lib.rs
@@ -86,6 +86,7 @@ pub mod extractor;
pub(crate) mod json_utils;
pub mod loaders;
pub mod one_or_many;
pub mod pipeline;
pub mod providers;
pub mod tool;
pub mod vector_store;