From 1a577f8515e5b1dd857465870d7c12ae69b27a3f Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Fri, 22 Nov 2024 11:39:18 -0800 Subject: [PATCH] Introduce Embeddings index, CompleteAndGrounded metric, Unbatchify utils (#1843) * Introduce Embeddings (faiss NN index), CompleteAndGrounded metric, and Unbatchify utils * adjust faiss import * adjust tests * adjust to dspy.Embedder --- dspy/__init__.py | 2 + dspy/clients/__init__.py | 2 +- dspy/clients/embedding.py | 76 ++++++++++++++---- dspy/evaluate/auto_evaluation.py | 129 ++++++++++++++++++++++++------- dspy/predict/knn.py | 4 +- dspy/retrievers/__init__.py | 1 + dspy/retrievers/embeddings.py | 83 ++++++++++++++++++++ dspy/utils/unbatchify.py | 111 ++++++++++++++++++++++++++ tests/clients/test_embedding.py | 8 +- 9 files changed, 364 insertions(+), 52 deletions(-) create mode 100644 dspy/retrievers/__init__.py create mode 100644 dspy/retrievers/embeddings.py create mode 100644 dspy/utils/unbatchify.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 28e0a352b..9e3e85fd2 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -6,6 +6,8 @@ from .retrieve import * from .signatures import * +import dspy.retrievers + # Functional must be imported after primitives, predict and signatures from .functional import * # isort: skip from dspy.evaluate import Evaluate # isort: skip diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index dc10f865f..2fc0e2543 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,7 +1,7 @@ from .lm import LM from .provider import Provider, TrainingJob from .base_lm import BaseLM, inspect_history -from .embedding import Embedding +from .embedding import Embedder import litellm import os from pathlib import Path diff --git a/dspy/clients/embedding.py b/dspy/clients/embedding.py index eec41c32b..ec7c1174e 100644 --- a/dspy/clients/embedding.py +++ b/dspy/clients/embedding.py @@ -2,7 +2,7 @@ import numpy as np -class Embedding: +class Embedder: """DSPy embedding class. The class for computing embeddings for text inputs. This class provides a unified interface for both: @@ -10,7 +10,7 @@ class Embedding: 1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration 2. Custom embedding functions that you provide - For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use + For hosted models, simply pass the model name as a string (e.g., "openai/text-embedding-3-small"). The class will use litellm to handle the API calls and caching. For custom embedding models, pass a callable function that: @@ -24,6 +24,9 @@ class Embedding: model: The embedding model to use. This can be either a string (representing the name of the hosted embedding model, must be an embedding model supported by litellm) or a callable that represents a custom embedding model. + batch_size (int, optional): The default batch size for processing inputs in batches. Defaults to 200. + caching (bool, optional): Whether to cache the embedding response when using a hosted model. Defaults to True. + **kwargs: Additional default keyword arguments to pass to the embedding model. Examples: Example 1: Using a hosted model. @@ -31,7 +34,7 @@ class Embedding: ```python import dspy - embedder = dspy.Embedding("openai/text-embedding-3-small") + embedder = dspy.Embedder("openai/text-embedding-3-small", batch_size=100) embeddings = embedder(["hello", "world"]) assert embeddings.shape == (2, 1536) @@ -41,37 +44,78 @@ class Embedding: ```python import dspy + import numpy as np def my_embedder(texts): return np.random.rand(len(texts), 10) - embedder = dspy.Embedding(my_embedder) - embeddings = embedder(["hello", "world"]) + embedder = dspy.Embedder(my_embedder) + embeddings = embedder(["hello", "world"], batch_size=1) assert embeddings.shape == (2, 10) ``` """ - def __init__(self, model): + def __init__(self, model, batch_size=200, caching=True, **kwargs): self.model = model + self.batch_size = batch_size + self.caching = caching + self.default_kwargs = kwargs - def __call__(self, inputs, caching=True, **kwargs): + def __call__(self, inputs, batch_size=None, caching=None, **kwargs): """Compute embeddings for the given inputs. Args: inputs: The inputs to compute embeddings for, can be a single string or a list of strings. - caching: Whether to cache the embedding response, only valid when using a hosted embedding model. - kwargs: Additional keyword arguments to pass to the embedding model. + batch_size (int, optional): The batch size for processing inputs. If None, defaults to the batch_size set during initialization. + caching (bool, optional): Whether to cache the embedding response when using a hosted model. If None, defaults to the caching setting from initialization. + **kwargs: Additional keyword arguments to pass to the embedding model. These will override the default kwargs provided during initialization. Returns: - A 2-D numpy array of embeddings, one embedding per row. + numpy.ndarray: If the input is a single string, returns a 1D numpy array representing the embedding. + If the input is a list of strings, returns a 2D numpy array of embeddings, one embedding per row. """ + if isinstance(inputs, str): + is_single_input = True inputs = [inputs] - if isinstance(self.model, str): - embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs) - return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32) - elif callable(self.model): - return np.array(self.model(inputs, **kwargs), dtype=np.float32) else: - raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.") + is_single_input = False + + assert all(isinstance(inp, str) for inp in inputs), "All inputs must be strings." + + if batch_size is None: + batch_size = self.batch_size + if caching is None: + caching = self.caching + + merged_kwargs = self.default_kwargs.copy() + merged_kwargs.update(kwargs) + + embeddings_list = [] + + def chunk(inputs_list, size): + for i in range(0, len(inputs_list), size): + yield inputs_list[i : i + size] + + for batch_inputs in chunk(inputs, batch_size): + if isinstance(self.model, str): + embedding_response = litellm.embedding( + model=self.model, input=batch_inputs, caching=caching, **merged_kwargs + ) + batch_embeddings = [data["embedding"] for data in embedding_response.data] + elif callable(self.model): + batch_embeddings = self.model(batch_inputs, **merged_kwargs) + else: + raise ValueError( + f"`model` in `dspy.Embedder` must be a string or a callable, but got {type(self.model)}." + ) + + embeddings_list.extend(batch_embeddings) + + embeddings = np.array(embeddings_list, dtype=np.float32) + + if is_single_input: + return embeddings[0] + else: + return embeddings diff --git a/dspy/evaluate/auto_evaluation.py b/dspy/evaluate/auto_evaluation.py index 38b02fe35..d96d58f21 100644 --- a/dspy/evaluate/auto_evaluation.py +++ b/dspy/evaluate/auto_evaluation.py @@ -14,14 +14,35 @@ class SemanticRecallPrecision(dspy.Signature): precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth") +class DecompositionalSemanticRecallPrecision(dspy.Signature): + """ + Compare a system's response to the ground truth to compute recall and precision of key ideas. + You will first enumerate key ideas in each response, discuss their overlap, and then report recall and precision. + """ + + question: str = dspy.InputField() + ground_truth: str = dspy.InputField() + system_response: str = dspy.InputField() + ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth") + system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response") + discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response") + recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response") + precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth") + + def f1_score(precision, recall): + precision, recall = max(0.0, min(1.0, precision)), max(0.0, min(1.0, recall)) return 0.0 if precision + recall == 0 else 2 * (precision * recall) / (precision + recall) class SemanticF1(dspy.Module): - def __init__(self, threshold=0.66): + def __init__(self, threshold=0.66, decompositional=False): self.threshold = threshold - self.module = dspy.ChainOfThought(SemanticRecallPrecision) + + if decompositional: + self.module = dspy.ChainOfThought(DecompositionalSemanticRecallPrecision) + else: + self.module = dspy.ChainOfThought(SemanticRecallPrecision) def forward(self, example, pred, trace=None): scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response) @@ -30,42 +51,92 @@ def forward(self, example, pred, trace=None): return score if trace is None else score >= self.threshold -""" -Soon-to-be deprecated Signatures & Modules Below. -""" + +########### + + +class DecompositionalSemanticRecall(dspy.Signature): + """ + Estimate the completeness of a system's responses, against the ground truth. + You will first enumerate key ideas in each response, discuss their overlap, and then report completeness. + """ + + question: str = dspy.InputField() + ground_truth: str = dspy.InputField() + system_response: str = dspy.InputField() + ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth") + system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response") + discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response") + completeness: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response") + + + +class DecompositionalGroundedness(dspy.Signature): + """ + Estimate the groundedness of a system's responses, against real retrieved documents written by people. + You will first enumerate whatever non-trivial or check-worthy claims are made in the system response, and then + discuss the extent to which some or all of them can be deduced from the retrieved context and basic commonsense. + """ + + question: str = dspy.InputField() + retrieved_context: str = dspy.InputField() + system_response: str = dspy.InputField() + system_response_claims: str = dspy.OutputField(desc="enumeration of non-trivial or check-worthy claims in the system response") + discussion: str = dspy.OutputField(desc="discussion of how supported the claims are by the retrieved context") + groundedness: float = dspy.OutputField(desc="fraction (out of 1.0) of system response supported by the retrieved context") + + +class CompleteAndGrounded(dspy.Module): + def __init__(self, threshold=0.66): + self.threshold = threshold + self.completeness_module = dspy.ChainOfThought(DecompositionalSemanticRecall) + self.groundedness_module = dspy.ChainOfThought(DecompositionalGroundedness) + + def forward(self, example, pred, trace=None): + completeness = self.completeness_module(question=example.question, ground_truth=example.response, system_response=pred.response) + groundedness = self.groundedness_module(question=example.question, retrieved_context=pred.context, system_response=pred.response) + score = f1_score(groundedness.groundedness, completeness.completeness) + + return score if trace is None else score >= self.threshold + + + +# """ +# Soon-to-be deprecated Signatures & Modules Below. +# """ -class AnswerCorrectnessSignature(dspy.Signature): - """Verify that the predicted answer matches the gold answer.""" +# class AnswerCorrectnessSignature(dspy.Signature): +# """Verify that the predicted answer matches the gold answer.""" - question = dspy.InputField() - gold_answer = dspy.InputField(desc="correct answer for question") - predicted_answer = dspy.InputField(desc="predicted answer for question") - is_correct = dspy.OutputField(desc="True or False") +# question = dspy.InputField() +# gold_answer = dspy.InputField(desc="correct answer for question") +# predicted_answer = dspy.InputField(desc="predicted answer for question") +# is_correct = dspy.OutputField(desc="True or False") -class AnswerCorrectness(dspy.Module): - def __init__(self): - super().__init__() - self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature) +# class AnswerCorrectness(dspy.Module): +# def __init__(self): +# super().__init__() +# self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature) - def forward(self, question, gold_answer, predicted_answer): - return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer) +# def forward(self, question, gold_answer, predicted_answer): +# return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer) -class AnswerFaithfulnessSignature(dspy.Signature): - """Verify that the predicted answer is based on the provided context.""" +# class AnswerFaithfulnessSignature(dspy.Signature): +# """Verify that the predicted answer is based on the provided context.""" - context = dspy.InputField(desc="relevant facts for producing answer") - question = dspy.InputField() - answer = dspy.InputField(desc="often between 1 and 5 words") - is_faithful = dspy.OutputField(desc="True or False") +# context = dspy.InputField(desc="relevant facts for producing answer") +# question = dspy.InputField() +# answer = dspy.InputField(desc="often between 1 and 5 words") +# is_faithful = dspy.OutputField(desc="True or False") -class AnswerFaithfulness(dspy.Module): - def __init__(self): - super().__init__() - self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature) +# class AnswerFaithfulness(dspy.Module): +# def __init__(self): +# super().__init__() +# self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature) - def forward(self, context, question, answer): - return self.evaluate_faithfulness(context=context, question=question, answer=answer) +# def forward(self, context, question, answer): +# return self.evaluate_faithfulness(context=context, question=question, answer=answer) diff --git a/dspy/predict/knn.py b/dspy/predict/knn.py index 434a07aaa..17a5a3fb7 100644 --- a/dspy/predict/knn.py +++ b/dspy/predict/knn.py @@ -13,7 +13,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None): Args: k: Number of nearest neighbors to retrieve trainset: List of training examples to search through - vectorizer: Optional dspy.Embedding for computing embeddings. If None, uses sentence-transformers. + vectorizer: Optional dspy.Embedder for computing embeddings. If None, uses sentence-transformers. Example: >>> trainset = [dsp.Example(input="hello", output="world"), ...] @@ -24,7 +24,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None): self.k = k self.trainset = trainset - self.embedding = vectorizer or dspy.Embedding(dsp.SentenceTransformersVectorizer()) + self.embedding = vectorizer or dspy.Embedder(dsp.SentenceTransformersVectorizer()) trainset_casted_to_vectorize = [ " | ".join([f"{key}: {value}" for key, value in example.items() if key in example._input_keys]) for example in self.trainset diff --git a/dspy/retrievers/__init__.py b/dspy/retrievers/__init__.py new file mode 100644 index 000000000..3fdc977bb --- /dev/null +++ b/dspy/retrievers/__init__.py @@ -0,0 +1 @@ +from .embeddings import Embeddings \ No newline at end of file diff --git a/dspy/retrievers/embeddings.py b/dspy/retrievers/embeddings.py new file mode 100644 index 000000000..75e1ff1fb --- /dev/null +++ b/dspy/retrievers/embeddings.py @@ -0,0 +1,83 @@ +import numpy as np +from typing import Any, List, Optional +from dspy.utils.unbatchify import Unbatchify + +# TODO: Add .save and .load methods! + + +class Embeddings: + def __init__( + self, + corpus: List[str], + embedder, + k: int = 5, + callbacks: Optional[List[Any]] = None, + cache: bool = False, + brute_force_threshold: int = 20_000, + normalize: bool = True + ): + assert cache is False, "Caching is not supported for embeddings-based retrievers" + + self.embedder = embedder + self.k = k + self.corpus = corpus + self.normalize = normalize + + self.corpus_embeddings = self.embedder(self.corpus) + self.corpus_embeddings = self._normalize(self.corpus_embeddings) if self.normalize else self.corpus_embeddings + + self.index = self._build_faiss() if len(corpus) >= brute_force_threshold else None + self.search_fn = Unbatchify(self._batch_forward) + + def __call__(self, query: str): + return self.forward(query) + + def forward(self, query: str): + import dspy + return dspy.Prediction(passages=self.search_fn(query)) + + def _batch_forward(self, queries: List[str]): + q_embeds = self.embedder(queries) + q_embeds = self._normalize(q_embeds) if self.normalize else q_embeds + + pids = self._faiss_search(q_embeds, self.k * 10) if self.index else None + pids = np.tile(np.arange(len(self.corpus)), (len(queries), 1)) if pids is None else pids + + return self._rerank_and_predict(q_embeds, pids) + + def _build_faiss(self): + nbytes = 32 + partitions = int(2 * np.sqrt(len(self.corpus))) + dim = self.corpus_embeddings.shape[1] + + try: + import faiss + except ImportError: + raise ImportError("Please `pip install faiss-cpu` or increase `brute_force_threshold` to avoid FAISS.") + + quantizer = faiss.IndexFlatL2(dim) + index = faiss.IndexIVFPQ(quantizer, dim, partitions, nbytes, 8) + + print(f"Training a {nbytes}-byte FAISS index with {partitions} partitions, based on " + f"{len(self.corpus)} x {dim}-dim embeddings") + index.train(self.corpus_embeddings) + index.add(self.corpus_embeddings) + index.nprobe = min(16, partitions) + + return index + + def _faiss_search(self, query_embeddings: np.ndarray, num_candidates: int): + return self.index.search(query_embeddings, num_candidates)[1] + + def _rerank_and_predict(self, q_embeds: np.ndarray, candidate_indices: np.ndarray): + candidate_embeddings = self.corpus_embeddings[candidate_indices] + scores = np.einsum('qd,qkd->qk', q_embeds, candidate_embeddings) + + top_k_indices = np.argsort(-scores, axis=1)[:, :self.k] + top_indices = candidate_indices[np.arange(len(q_embeds))[:, None], top_k_indices] + + return [[self.corpus[idx] for idx in indices] for indices in top_indices] + + def _normalize(self, embeddings: np.ndarray): + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + return embeddings / np.maximum(norms, 1e-10) diff --git a/dspy/utils/unbatchify.py b/dspy/utils/unbatchify.py new file mode 100644 index 000000000..bafdc8cb3 --- /dev/null +++ b/dspy/utils/unbatchify.py @@ -0,0 +1,111 @@ +import time +import queue +import threading +from typing import Any, Callable, List +from concurrent.futures import Future + +class Unbatchify: + def __init__( + self, + batch_fn: Callable[[List[Any]], List[Any]], + max_batch_size: int = 32, + max_wait_time: float = 0.1 + ): + """ + Initializes the Unbatchify. + + Args: + batch_fn: The batch-processing function that accepts a list of inputs and returns a list of outputs. + max_batch_size: The maximum number of items to include in a batch. + max_wait_time: The maximum time (in seconds) to wait for batch to fill before processing. + """ + + self.batch_fn = batch_fn + self.max_batch_size = max_batch_size + self.max_wait_time = max_wait_time + self.input_queue = queue.Queue() + self.stop_event = threading.Event() + self.worker_thread = threading.Thread(target=self._worker) + self.worker_thread.daemon = True # Ensures thread exits when main program exits + self.worker_thread.start() + + def __call__(self, input_item: Any) -> Any: + """ + Thread-safe function that accepts a single input and returns the corresponding output. + + Args: + input_item: The single input item to process. + + Returns: + The output corresponding to the input_item after processing through batch_fn. + """ + future = Future() + self.input_queue.put((input_item, future)) + try: + result = future.result() + except Exception as e: + raise e + return result + + def _worker(self): + """ + Worker thread that batches inputs and processes them using batch_fn. + """ + while not self.stop_event.is_set(): + batch = [] + futures = [] + start_time = time.time() + while len(batch) < self.max_batch_size and (time.time() - start_time) < self.max_wait_time: + try: + input_item, future = self.input_queue.get(timeout=self.max_wait_time) + batch.append(input_item) + futures.append(future) + except queue.Empty: + break + + if batch: + try: + outputs = self.batch_fn(batch) + for output, future in zip(outputs, futures): + future.set_result(output) + except Exception as e: + for future in futures: + future.set_exception(e) + else: + time.sleep(0.01) + + # Clean up remaining items when stopping + while True: + try: + _, future = self.input_queue.get_nowait() + future.set_exception(RuntimeError("Unbatchify is closed")) + except queue.Empty: + break + + print("Worker thread has been terminated.") + + def close(self): + """ + Stops the worker thread and cleans up resources. + """ + if not self.stop_event.is_set(): + self.stop_event.set() + self.worker_thread.join() + + def __enter__(self): + """ + Enables use as a context manager. + """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ + Ensures resources are cleaned up when exiting context. + """ + self.close() + + def __del__(self): + """ + Ensures the worker thread is terminated when the object is garbage collected. + """ + self.close() diff --git a/tests/clients/test_embedding.py b/tests/clients/test_embedding.py index d12850e52..0ac9e24ba 100644 --- a/tests/clients/test_embedding.py +++ b/tests/clients/test_embedding.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch import numpy as np -from dspy.clients.embedding import Embedding +from dspy.clients.embedding import Embedder # Mock response format similar to litellm's embedding response. @@ -27,7 +27,7 @@ def test_litellm_embedding(): mock_litellm.return_value = MockEmbeddingResponse(mock_embeddings) # Create embedding instance and call it. - embedding = Embedding(model) + embedding = Embedder(model) result = embedding(inputs) # Verify litellm was called with correct parameters. @@ -51,7 +51,7 @@ def mock_embedding_fn(texts): return expected_embeddings # Create embedding instance with callable - embedding = Embedding(mock_embedding_fn) + embedding = Embedder(mock_embedding_fn) result = embedding(inputs) np.testing.assert_allclose(result, expected_embeddings) @@ -60,5 +60,5 @@ def mock_embedding_fn(texts): def test_invalid_model_type(): # Test that invalid model type raises ValueError with pytest.raises(ValueError): - embedding = Embedding(123) # Invalid model type + embedding = Embedder(123) # Invalid model type embedding(["test"])