Commit

Adding dummy_response option: allows for dry-running with expected results and helps with testing; added test
HallerPatrick committed Aug 2, 2023
1 parent 52116d9 commit 05838d4
Showing 2 changed files with 54 additions and 6 deletions.
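
For orientation before the diff, here is a minimal sketch of how the new dummy_response option is meant to be used, pieced together from the tests below. The import paths are assumptions (the diff only shows the class names); everything else mirrors the test code. A string dummy is returned verbatim on every prompt call, while a callable is invoked with the rendered prompt text.

    # Import paths are assumptions; the diff only shows the class names.
    from ai_dataset_generator import DatasetGenerator
    from ai_dataset_generator.prompts import BasePrompt

    # With dummy_response set, no model is ever called, so no PromptNode is needed.
    generator = DatasetGenerator(None)
    prompt = BasePrompt(task_description="Generate a short movie review.")

    # String form: every prompt call returns the same canned text.
    dataset = generator.generate(
        prompt_template=prompt,
        max_prompt_calls=2,
        dummy_response="A dummy movie review.",
    )

    # Callable form: the callable receives the rendered prompt text.
    dataset = generator.generate(
        prompt_template=prompt,
        max_prompt_calls=2,
        dummy_response=lambda prompt_text: "This is a dummy movie review.",
    )

Note that the tests construct DatasetGenerator(None): with a dummy response set, _try_generate returns before the underlying PromptNode is ever used, so no model has to be loaded.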
22 changes: 17 additions & 5 deletions src/ai_dataset_generator/dataset_generator.py
@@ -3,7 +3,7 @@

from pathlib import Path
from collections import defaultdict
-from typing import Any, Dict, Optional, Union, Tuple, List
+from typing import Any, Callable, Dict, Optional, Union, Tuple, List
from tqdm import tqdm
from loguru import logger

@@ -53,14 +53,15 @@ def generate(
prompt_template: BasePrompt,
fewshot_dataset: Optional[Dataset] = None,
fewshot_examples_per_class: int = 1,
-        fewshot_label_sampling_strategy: str = None,
-        fewshot_sampling_column: str = None,
+        fewshot_label_sampling_strategy: Optional[str] = None,
+        fewshot_sampling_column: Optional[str] = None,
unlabeled_dataset: Optional[Dataset] = None,
return_unlabeled_dataset: bool = False,
max_prompt_calls: int = 10,
num_samples_to_generate: int = 10,
timeout_per_prompt: Optional[int] = None,
log_every_n_api_calls: int = 25,
+        dummy_response: Optional[Union[str, Callable]] = None
) -> Union[Dataset, Tuple[Dataset, Dataset]]:
"""Generate a dataset based on a prompt template and support examples.
Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -81,6 +82,7 @@ class are used. "stratified" sampling strategy means that all classes in support
num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
+            dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
Returns:
Union[Dataset, Tuple[Dataset, Dataset]]: Generated dataset or tuple of generated dataset and original
@@ -107,6 +109,7 @@ class are used. "stratified" sampling strategy means that all classes in support
num_samples_to_generate,
timeout_per_prompt,
log_every_n_api_calls,
+            dummy_response
)

if return_unlabeled_dataset:
@@ -115,7 +118,7 @@ class are used. "stratified" sampling strategy means that all classes in support
return generated_dataset

def _try_generate(
-        self, prompt_text: str, invocation_context: Dict,
+        self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]]
) -> Optional[str]:
"""Tries to generate a single example. Restrict the time spent on this.
@@ -128,6 +131,14 @@ def _try_generate(
Generated example
"""

+        if dummy_response is not None:
+            if isinstance(dummy_response, str):
+                return dummy_response
+            if callable(dummy_response):
+                return dummy_response(prompt_text)
+            else:
+                raise ValueError("Dummy response must be a string or a callable")
+
        # Haystack internally uses timeouts and retries, so we don't have to do it ourselves
        # We don't catch authentication errors here, because we want to fail fast
try:
@@ -154,6 +165,7 @@ def _inner_generate_loop(
num_samples_to_generate: int,
timeout_per_prompt: Optional[int],
log_every_n_api_calls: int = 25,
+        dummy_response: Optional[Union[str, Callable]] = None
):
current_tries_left = self._max_tries
current_log_file = self._setup_log(prompt_template)
@@ -199,7 +211,7 @@ def _inner_generate_loop(
f"Invocation context: {invocation_context} \n"
)

-            prediction = self._try_generate(prompt_text, invocation_context)
+            prediction = self._try_generate(prompt_text, invocation_context, dummy_response)

if prediction is None:
current_tries_left -= 1
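A note on the callable contract implied by _try_generate above: a callable dummy_response is invoked with the rendered prompt text as its single argument and must return a string. A hypothetical dummy (illustrative, not part of this commit) that varies its canned output by prompt, handy for exercising downstream parsing in tests:

    def labeled_dummy(prompt_text: str) -> str:
        # Inspect the rendered prompt to pick a canned response.
        if "negative" in prompt_text:
            return "This movie is bad!"
        return "This movie is great!"

    generated = generator.generate(
        prompt_template=prompt,  # generator/prompt as in the sketch above
        max_prompt_calls=2,
        dummy_response=labeled_dummy,
    )
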
38 changes: 37 additions & 1 deletion tests/test_dataset_generator.py
@@ -17,7 +17,7 @@ def setUp(self) -> None:
"text": ["This movie is great!", "This movie is bad!"],
"label": ["positive", "negative"]
})
-        self.generator = DatasetGenerator(PromptNode("google/flan-t5-small"))
+        self.generator = DatasetGenerator(None)

def test_simple_generation(self):
"""Test simple generation without fewshot examples."""
@@ -28,6 +28,7 @@ def test_simple_generation(self):
generated_dataset = self.generator.generate(
prompt_template=prompt,
max_prompt_calls=2,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -44,6 +45,7 @@ def test_simple_generation_with_label_options(self):
generated_dataset = self.generator.generate(
prompt_template=prompt,
max_prompt_calls=2,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -66,6 +68,7 @@ def test_generation_with_fewshot_examples(self):
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_column="label",
max_prompt_calls=2,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -93,6 +96,7 @@ def test_annotation_with_fewshot_and_unlabeled_examples(self):
fewshot_label_sampling_strategy="stratified",
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -129,6 +133,7 @@ def test_tranlation(self):
# (default)
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -161,9 +166,40 @@ def test_textual_similarity(self):
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
return_unlabeled_dataset=True,
+            dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
self.assertEqual(generated_dataset.features["sentence1"].dtype, "string")
self.assertEqual(generated_dataset.features["sentence2"].dtype, "string")
self.assertEqual(generated_dataset.features["label"].dtype, "string")

+    def test_dummy_response(self):
+
+        prompt = BasePrompt(
+            task_description="Generate a short movie review.",
+        )
+        generated_dataset = self.generator.generate(
+            prompt_template=prompt,
+            max_prompt_calls=2,
+            dummy_response=lambda _: "This is a dummy movie review."
+        )
+
+        self.assertEqual(len(generated_dataset), 2)
+        self.assertEqual(generated_dataset.features["text"].dtype, "string")
+        self.assertIn("text", generated_dataset.features)
+        self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review.")
+        self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review.")
+
+        generated_dataset = self.generator.generate(
+            prompt_template=prompt,
+            max_prompt_calls=2,
+            dummy_response="This is a dummy movie review as a string."
+        )
+
+        self.assertEqual(len(generated_dataset), 2)
+        self.assertEqual(generated_dataset.features["text"].dtype, "string")
+        self.assertIn("text", generated_dataset.features)
+        self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review as a string.")
+        self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review as a string.")
