Introduce Embeddings index, CompleteAndGrounded metric, Unbatchify utils (stanfordnlp#1843)

* Introduce Embeddings (faiss NN index), CompleteAndGrounded metric, and Unbatchify utils

* adjust faiss import

* adjust tests

* adjust to dspy.Embedder
okhat authored Nov 22, 2024
1 parent 44b3331 commit 1a577f8
Showing 9 changed files with 364 additions and 52 deletions.
2 changes: 2 additions & 0 deletions dspy/__init__.py
@@ -6,6 +6,8 @@
from .retrieve import *
from .signatures import *

import dspy.retrievers

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip
from dspy.evaluate import Evaluate # isort: skip
2 changes: 1 addition & 1 deletion dspy/clients/__init__.py
@@ -1,7 +1,7 @@
from .lm import LM
from .provider import Provider, TrainingJob
from .base_lm import BaseLM, inspect_history
from .embedding import Embedding
from .embedding import Embedder
import litellm
import os
from pathlib import Path
76 changes: 60 additions & 16 deletions dspy/clients/embedding.py
@@ -2,15 +2,15 @@
import numpy as np


class Embedding:
class Embedder:
"""DSPy embedding class.
The class for computing embeddings for text inputs. This class provides a unified interface for both:
1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration
2. Custom embedding functions that you provide
For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use
For hosted models, simply pass the model name as a string (e.g., "openai/text-embedding-3-small"). The class will use
litellm to handle the API calls and caching.
For custom embedding models, pass a callable function that:
@@ -24,14 +24,17 @@ class Embedding:
model: The embedding model to use. This can be either a string (representing the name of the hosted embedding
model, must be an embedding model supported by litellm) or a callable that represents a custom embedding
model.
batch_size (int, optional): The default batch size for processing inputs in batches. Defaults to 200.
caching (bool, optional): Whether to cache the embedding response when using a hosted model. Defaults to True.
**kwargs: Additional default keyword arguments to pass to the embedding model.
Examples:
Example 1: Using a hosted model.
```python
import dspy
embedder = dspy.Embedding("openai/text-embedding-3-small")
embedder = dspy.Embedder("openai/text-embedding-3-small", batch_size=100)
embeddings = embedder(["hello", "world"])
assert embeddings.shape == (2, 1536)
@@ -41,37 +44,78 @@ class Embedding:
```python
import dspy
import numpy as np
def my_embedder(texts):
return np.random.rand(len(texts), 10)
embedder = dspy.Embedding(my_embedder)
embeddings = embedder(["hello", "world"])
embedder = dspy.Embedder(my_embedder)
embeddings = embedder(["hello", "world"], batch_size=1)
assert embeddings.shape == (2, 10)
```
"""

def __init__(self, model):
def __init__(self, model, batch_size=200, caching=True, **kwargs):
self.model = model
self.batch_size = batch_size
self.caching = caching
self.default_kwargs = kwargs

def __call__(self, inputs, caching=True, **kwargs):
def __call__(self, inputs, batch_size=None, caching=None, **kwargs):
"""Compute embeddings for the given inputs.
Args:
inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
caching: Whether to cache the embedding response, only valid when using a hosted embedding model.
kwargs: Additional keyword arguments to pass to the embedding model.
batch_size (int, optional): The batch size for processing inputs. If None, defaults to the batch_size set during initialization.
caching (bool, optional): Whether to cache the embedding response when using a hosted model. If None, defaults to the caching setting from initialization.
**kwargs: Additional keyword arguments to pass to the embedding model. These will override the default kwargs provided during initialization.
Returns:
A 2-D numpy array of embeddings, one embedding per row.
numpy.ndarray: If the input is a single string, returns a 1D numpy array representing the embedding.
If the input is a list of strings, returns a 2D numpy array of embeddings, one embedding per row.
"""

if isinstance(inputs, str):
is_single_input = True
inputs = [inputs]
if isinstance(self.model, str):
embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32)
elif callable(self.model):
return np.array(self.model(inputs, **kwargs), dtype=np.float32)
else:
raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
is_single_input = False

assert all(isinstance(inp, str) for inp in inputs), "All inputs must be strings."

if batch_size is None:
batch_size = self.batch_size
if caching is None:
caching = self.caching

merged_kwargs = self.default_kwargs.copy()
merged_kwargs.update(kwargs)

embeddings_list = []

def chunk(inputs_list, size):
for i in range(0, len(inputs_list), size):
yield inputs_list[i : i + size]

for batch_inputs in chunk(inputs, batch_size):
if isinstance(self.model, str):
embedding_response = litellm.embedding(
model=self.model, input=batch_inputs, caching=caching, **merged_kwargs
)
batch_embeddings = [data["embedding"] for data in embedding_response.data]
elif callable(self.model):
batch_embeddings = self.model(batch_inputs, **merged_kwargs)
else:
raise ValueError(
f"`model` in `dspy.Embedder` must be a string or a callable, but got {type(self.model)}."
)

embeddings_list.extend(batch_embeddings)

embeddings = np.array(embeddings_list, dtype=np.float32)

if is_single_input:
return embeddings[0]
else:
return embeddings
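To illustrate the new call semantics, here is a minimal sketch using a local callable model, so no API key is needed; the helper function and vector sizes are illustrative, not part of the commit:

```python
import numpy as np
import dspy

def my_embedder(texts):
    # Toy embedding function: one random 10-dim vector per input text.
    return np.random.rand(len(texts), 10)

# Constructor defaults: batch_size=200, caching=True.
embedder = dspy.Embedder(my_embedder, batch_size=100)

single = embedder("hello")            # single string -> 1-D array, shape (10,)
batch = embedder(["hello", "world"])  # list of strings -> 2-D array, shape (2, 10)

# Per-call arguments override the constructor defaults.
batch = embedder(["hello", "world"], batch_size=1)

assert single.shape == (10,)
assert batch.shape == (2, 10)
```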
129 changes: 100 additions & 29 deletions dspy/evaluate/auto_evaluation.py
@@ -14,14 +14,35 @@ class SemanticRecallPrecision(dspy.Signature):
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")


class DecompositionalSemanticRecallPrecision(dspy.Signature):
"""
Compare a system's response to the ground truth to compute recall and precision of key ideas.
You will first enumerate key ideas in each response, discuss their overlap, and then report recall and precision.
"""

question: str = dspy.InputField()
ground_truth: str = dspy.InputField()
system_response: str = dspy.InputField()
ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth")
system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response")
discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response")
recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")


def f1_score(precision, recall):
precision, recall = max(0.0, min(1.0, precision)), max(0.0, min(1.0, recall))
return 0.0 if precision + recall == 0 else 2 * (precision * recall) / (precision + recall)
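For reference, this helper clamps both inputs to [0, 1] and returns their harmonic mean; a quick sanity check:

```python
assert f1_score(0.0, 0.0) == 0.0               # zero-sum case is guarded, no division by zero
assert abs(f1_score(0.5, 1.0) - 2 / 3) < 1e-9  # harmonic mean of 0.5 and 1.0
assert f1_score(1.5, 1.0) == 1.0               # out-of-range scores are clamped first
```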


class SemanticF1(dspy.Module):
def __init__(self, threshold=0.66):
def __init__(self, threshold=0.66, decompositional=False):
self.threshold = threshold
self.module = dspy.ChainOfThought(SemanticRecallPrecision)

if decompositional:
self.module = dspy.ChainOfThought(DecompositionalSemanticRecallPrecision)
else:
self.module = dspy.ChainOfThought(SemanticRecallPrecision)

def forward(self, example, pred, trace=None):
scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
@@ -30,42 +51,92 @@ def forward(self, example, pred, trace=None):
return score if trace is None else score >= self.threshold
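As a usage sketch for the new flag (assuming an LM has been configured; the model name and texts below are placeholders):

```python
import dspy
from dspy.evaluate.auto_evaluation import SemanticF1

# Any litellm-supported chat model works here; this one is just an example.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

metric = SemanticF1(decompositional=True)  # enumerates key ideas before scoring

example = dspy.Example(question="What is DSPy?",
                       response="A framework for programming language models.")
pred = dspy.Prediction(response="DSPy is a framework for programming LMs rather than prompting them.")

score = metric(example, pred)  # float in [0, 1]; with trace set, a bool against the 0.66 threshold
```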


"""
Soon-to-be deprecated Signatures & Modules Below.
"""

###########


class DecompositionalSemanticRecall(dspy.Signature):
"""
Estimate the completeness of a system's responses, against the ground truth.
You will first enumerate key ideas in each response, discuss their overlap, and then report completeness.
"""

question: str = dspy.InputField()
ground_truth: str = dspy.InputField()
system_response: str = dspy.InputField()
ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth")
system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response")
discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response")
completeness: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")



class DecompositionalGroundedness(dspy.Signature):
"""
Estimate the groundedness of a system's responses, against real retrieved documents written by people.
You will first enumerate whatever non-trivial or check-worthy claims are made in the system response, and then
discuss the extent to which some or all of them can be deduced from the retrieved context and basic commonsense.
"""

question: str = dspy.InputField()
retrieved_context: str = dspy.InputField()
system_response: str = dspy.InputField()
system_response_claims: str = dspy.OutputField(desc="enumeration of non-trivial or check-worthy claims in the system response")
discussion: str = dspy.OutputField(desc="discussion of how supported the claims are by the retrieved context")
groundedness: float = dspy.OutputField(desc="fraction (out of 1.0) of system response supported by the retrieved context")


class CompleteAndGrounded(dspy.Module):
def __init__(self, threshold=0.66):
self.threshold = threshold
self.completeness_module = dspy.ChainOfThought(DecompositionalSemanticRecall)
self.groundedness_module = dspy.ChainOfThought(DecompositionalGroundedness)

def forward(self, example, pred, trace=None):
completeness = self.completeness_module(question=example.question, ground_truth=example.response, system_response=pred.response)
groundedness = self.groundedness_module(question=example.question, retrieved_context=pred.context, system_response=pred.response)
score = f1_score(groundedness.groundedness, completeness.completeness)

return score if trace is None else score >= self.threshold
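A parallel sketch for the new CompleteAndGrounded metric (same placeholder assumptions as above; note that pred must also carry the retrieved context):

```python
import dspy
from dspy.evaluate.auto_evaluation import CompleteAndGrounded

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # placeholder model, as above

metric = CompleteAndGrounded()

example = dspy.Example(question="Who wrote Hamlet?", response="William Shakespeare.")
pred = dspy.Prediction(
    context="Hamlet is a tragedy written by William Shakespeare around 1600.",
    response="Hamlet was written by William Shakespeare.",
)

score = metric(example, pred)  # harmonic mean (F1) of groundedness and completeness
```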



# """
# Soon-to-be deprecated Signatures & Modules Below.
# """


class AnswerCorrectnessSignature(dspy.Signature):
"""Verify that the predicted answer matches the gold answer."""
# class AnswerCorrectnessSignature(dspy.Signature):
# """Verify that the predicted answer matches the gold answer."""

question = dspy.InputField()
gold_answer = dspy.InputField(desc="correct answer for question")
predicted_answer = dspy.InputField(desc="predicted answer for question")
is_correct = dspy.OutputField(desc="True or False")
# question = dspy.InputField()
# gold_answer = dspy.InputField(desc="correct answer for question")
# predicted_answer = dspy.InputField(desc="predicted answer for question")
# is_correct = dspy.OutputField(desc="True or False")


class AnswerCorrectness(dspy.Module):
def __init__(self):
super().__init__()
self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)
# class AnswerCorrectness(dspy.Module):
# def __init__(self):
# super().__init__()
# self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)

def forward(self, question, gold_answer, predicted_answer):
return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer)
# def forward(self, question, gold_answer, predicted_answer):
# return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer)


class AnswerFaithfulnessSignature(dspy.Signature):
"""Verify that the predicted answer is based on the provided context."""
# class AnswerFaithfulnessSignature(dspy.Signature):
# """Verify that the predicted answer is based on the provided context."""

context = dspy.InputField(desc="relevant facts for producing answer")
question = dspy.InputField()
answer = dspy.InputField(desc="often between 1 and 5 words")
is_faithful = dspy.OutputField(desc="True or False")
# context = dspy.InputField(desc="relevant facts for producing answer")
# question = dspy.InputField()
# answer = dspy.InputField(desc="often between 1 and 5 words")
# is_faithful = dspy.OutputField(desc="True or False")


class AnswerFaithfulness(dspy.Module):
def __init__(self):
super().__init__()
self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)
# class AnswerFaithfulness(dspy.Module):
# def __init__(self):
# super().__init__()
# self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)

def forward(self, context, question, answer):
return self.evaluate_faithfulness(context=context, question=question, answer=answer)
# def forward(self, context, question, answer):
# return self.evaluate_faithfulness(context=context, question=question, answer=answer)
4 changes: 2 additions & 2 deletions dspy/predict/knn.py
@@ -13,7 +13,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None):
Args:
k: Number of nearest neighbors to retrieve
trainset: List of training examples to search through
vectorizer: Optional dspy.Embedding for computing embeddings. If None, uses sentence-transformers.
vectorizer: Optional dspy.Embedder for computing embeddings. If None, uses sentence-transformers.
Example:
>>> trainset = [dsp.Example(input="hello", output="world"), ...]
Expand All @@ -24,7 +24,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None):

self.k = k
self.trainset = trainset
self.embedding = vectorizer or dspy.Embedding(dsp.SentenceTransformersVectorizer())
self.embedding = vectorizer or dspy.Embedder(dsp.SentenceTransformersVectorizer())
trainset_casted_to_vectorize = [
" | ".join([f"{key}: {value}" for key, value in example.items() if key in example._input_keys])
for example in self.trainset
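For completeness, a sketch of the updated KNN call site (the toy embedder and one-example trainset are illustrative):

```python
import numpy as np
import dspy
from dspy.predict.knn import KNN

def my_embedder(texts):
    # Toy vectors; a real setup would use a hosted or sentence-transformers model.
    return np.random.rand(len(texts), 8)

trainset = [dspy.Example(input="hello", output="world").with_inputs("input")]
knn = KNN(k=1, trainset=trainset, vectorizer=dspy.Embedder(my_embedder))
```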
1 change: 1 addition & 0 deletions dspy/retrievers/__init__.py
@@ -0,0 +1 @@
from .embeddings import Embeddings
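The Embeddings retriever itself is truncated from this view; based on how it is used elsewhere in DSPy, a usage sketch might look like the following (the embedder/corpus/k parameters and the .passages attribute are assumptions about its interface, not confirmed by this excerpt):

```python
import dspy

embedder = dspy.Embedder("openai/text-embedding-3-small")
corpus = [
    "DSPy is a framework for programming language models.",
    "FAISS builds fast nearest-neighbor indexes over dense vectors.",
]

# Assumed interface: build a FAISS-backed index over the corpus, then query it.
search = dspy.retrievers.Embeddings(embedder=embedder, corpus=corpus, k=2)
results = search("What is DSPy?")  # assumed to return a prediction with .passages
```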