diff --git a/requirements.txt b/requirements.txt
index d2ca408..66e5988 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 datasets
 farm-haystack>=1.18.0
-farm-haystack[inference]
 spacy
 loguru
diff --git a/src/ai_dataset_generator/dataset_generator.py b/src/ai_dataset_generator/dataset_generator.py
index 2e02482..0504c26 100644
--- a/src/ai_dataset_generator/dataset_generator.py
+++ b/src/ai_dataset_generator/dataset_generator.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 from collections import defaultdict
 
-from typing import Any, Dict, Optional, Union, Tuple, List
+from typing import Any, Callable, Dict, Optional, Union, Tuple, List
 
 from tqdm import tqdm
 from loguru import logger
@@ -53,14 +53,15 @@ def generate(
         prompt_template: BasePrompt,
         fewshot_dataset: Optional[Dataset] = None,
         fewshot_examples_per_class: int = 1,
-        fewshot_label_sampling_strategy: str = None,
-        fewshot_sampling_column: str = None,
+        fewshot_label_sampling_strategy: Optional[str] = None,
+        fewshot_sampling_column: Optional[str] = None,
         unlabeled_dataset: Optional[Dataset] = None,
         return_unlabeled_dataset: bool = False,
         max_prompt_calls: int = 10,
         num_samples_to_generate: int = 10,
         timeout_per_prompt: Optional[int] = None,
         log_every_n_api_calls: int = 25,
+        dummy_response: Optional[Union[str, Callable]] = None,
     ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
         """Generate a dataset based on a prompt template and support examples.
         Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -81,6 +82,7 @@ class are used. "stratified" sampling strategy means that all classes in support
             num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
             timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
             log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
+            dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
 
         Returns:
             Union[Dataset, Tuple[Dataset, Dataset]]: Generated dataset or tuple of generated dataset and original
@@ -107,6 +109,7 @@ class are used. "stratified" sampling strategy means that all classes in support
             num_samples_to_generate,
             timeout_per_prompt,
             log_every_n_api_calls,
+            dummy_response,
         )
 
         if return_unlabeled_dataset:
@@ -115,7 +118,7 @@ class are used. "stratified" sampling strategy means that all classes in support
         return generated_dataset
 
     def _try_generate(
-        self, prompt_text: str, invocation_context: Dict,
+        self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]]
     ) -> Optional[str]:
         """Tries to generate a single example. Restrict the time spent on this.
 
@@ -128,6 +131,19 @@ def _try_generate(
 
             Generated example
         """
+        if dummy_response:
+
+            if isinstance(dummy_response, str):
+                logger.info(f"Returning dummy response: {dummy_response}")
+                return dummy_response
+
+            if callable(dummy_response):
+                dummy_value = dummy_response(prompt_text)
+                logger.info(f"Returning dummy response: {dummy_value}")
+                return dummy_value
+
+            raise ValueError("Dummy response must be a string or a callable")
+
         # Haystack internally uses timeouts and retries, so we dont have to do it
         # We dont catch authentification errors here, because we want to fail fast
         try:
@@ -154,6 +170,7 @@ def _inner_generate_loop(
         num_samples_to_generate: int,
         timeout_per_prompt: Optional[int],
         log_every_n_api_calls: int = 25,
+        dummy_response: Optional[Union[str, Callable]] = None,
     ):
         current_tries_left = self._max_tries
         current_log_file = self._setup_log(prompt_template)
@@ -199,7 +216,7 @@ def _inner_generate_loop(
                     f"Invocation context: {invocation_context} \n"
                 )
 
-            prediction = self._try_generate(prompt_text, invocation_context)
+            prediction = self._try_generate(prompt_text, invocation_context, dummy_response)
 
             if prediction is None:
                 current_tries_left -= 1
diff --git a/tests/test_dataset_generator.py b/tests/test_dataset_generator.py
index d6764de..63c5e28 100644
--- a/tests/test_dataset_generator.py
+++ b/tests/test_dataset_generator.py
@@ -17,7 +17,9 @@ def setUp(self) -> None:
             "text": ["This movie is great!", "This movie is bad!"],
             "label": ["positive", "negative"]
         })
-        self.generator = DatasetGenerator(PromptNode("google/flan-t5-small"))
+
+        # We are using dummy responses here because we are not testing the LLM itself.
+        self.generator = DatasetGenerator(None)
 
     def test_simple_generation(self):
         """Test simple generation without fewshot examples."""
@@ -28,6 +30,7 @@ def test_simple_generation(self):
         generated_dataset = self.generator.generate(
             prompt_template=prompt,
             max_prompt_calls=2,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
@@ -44,6 +47,7 @@ def test_simple_generation_with_label_options(self):
         generated_dataset = self.generator.generate(
             prompt_template=prompt,
             max_prompt_calls=2,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
@@ -66,6 +70,7 @@ def test_generation_with_fewshot_examples(self):
             fewshot_label_sampling_strategy="uniform",
             fewshot_sampling_column="label",
             max_prompt_calls=2,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
@@ -93,6 +98,7 @@ def test_annotation_with_fewshot_and_unlabeled_examples(self):
             fewshot_label_sampling_strategy="stratified",
             unlabeled_dataset=unlabeled_dataset,
             max_prompt_calls=2,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
@@ -129,6 +135,7 @@ def test_tranlation(self):
             # (default)
             unlabeled_dataset=unlabeled_dataset,
             max_prompt_calls=2,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
@@ -161,9 +168,40 @@ def test_textual_similarity(self):
             unlabeled_dataset=unlabeled_dataset,
             max_prompt_calls=2,
             return_unlabeled_dataset=True,
+            dummy_response="A dummy movie review.",
         )
 
         self.assertEqual(len(generated_dataset), 2)
         self.assertEqual(generated_dataset.features["sentence1"].dtype, "string")
         self.assertEqual(generated_dataset.features["sentence2"].dtype, "string")
         self.assertEqual(generated_dataset.features["label"].dtype, "string")
+
+    def test_dummy_response(self):
+        """Test that dummy responses, passed as a callable or a plain string, bypass the LLM."""
+        prompt = BasePrompt(
+            task_description="Generate a short movie review.",
+        )
+        generated_dataset = self.generator.generate(
+            prompt_template=prompt,
+            max_prompt_calls=2,
+            dummy_response=lambda _: "This is a dummy movie review.",
+        )
+
+        self.assertEqual(len(generated_dataset), 2)
+        self.assertEqual(generated_dataset.features["text"].dtype, "string")
+        self.assertIn("text", generated_dataset.features)
+        self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review.")
+        self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review.")
+
+        # A plain string works as well and is returned verbatim for every prompt call.
+        generated_dataset = self.generator.generate(
+            prompt_template=prompt,
+            max_prompt_calls=2,
+            dummy_response="This is a dummy movie review as a string.",
+        )
+
+        self.assertEqual(len(generated_dataset), 2)
+        self.assertEqual(generated_dataset.features["text"].dtype, "string")
+        self.assertIn("text", generated_dataset.features)
+        self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review as a string.")
+        self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review as a string.")
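
Usage note for reviewers: the sketch below shows how the new `dummy_response` parameter enables dry runs without an LLM. It is a minimal sketch based only on the API visible in this diff; the top-level import path and the prompt/dummy texts are assumptions, not taken from the repository.

```python
# Minimal dry-run sketch for the new dummy_response parameter.
# Assumption: DatasetGenerator and BasePrompt are importable from the
# package's top level; adjust the import to the actual repo layout.
from ai_dataset_generator import BasePrompt, DatasetGenerator

prompt = BasePrompt(task_description="Generate a short movie review.")

# No PromptNode is needed for a dry run: every prompt call is answered
# by the dummy response instead of an LLM.
generator = DatasetGenerator(None)

# A plain string is echoed back for every prompt call ...
dataset = generator.generate(
    prompt_template=prompt,
    max_prompt_calls=2,
    dummy_response="A dummy movie review.",
)

# ... while a callable receives the rendered prompt text and may vary its output.
dataset = generator.generate(
    prompt_template=prompt,
    max_prompt_calls=2,
    dummy_response=lambda prompt_text: f"Dummy review for a {len(prompt_text)}-character prompt.",
)
```

Because `_try_generate` returns the dummy value before any Haystack call is made, unit tests stay fast and need no API keys, which is why the tests above construct `DatasetGenerator(None)`.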