Merge pull request #57 from flairNLP/dry_run
Adds a dummy_response option, which allows for dry runs without calling an LLM.
HallerPatrick authored Aug 2, 2023
2 parents ec27a4a + fd5decf commit 59a0fa6
Showing 3 changed files with 61 additions and 7 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,5 +1,4 @@
datasets
-farm-haystack>=1.18.0
farm-haystack[inference]
spacy
loguru
27 changes: 22 additions & 5 deletions src/ai_dataset_generator/dataset_generator.py
@@ -3,7 +3,7 @@

from pathlib import Path
from collections import defaultdict
-from typing import Any, Dict, Optional, Union, Tuple, List
+from typing import Any, Callable, Dict, Optional, Union, Tuple, List
from tqdm import tqdm
from loguru import logger

@@ -53,14 +53,15 @@ def generate(
prompt_template: BasePrompt,
fewshot_dataset: Optional[Dataset] = None,
fewshot_examples_per_class: int = 1,
-fewshot_label_sampling_strategy: str = None,
-fewshot_sampling_column: str = None,
+fewshot_label_sampling_strategy: Optional[str] = None,
+fewshot_sampling_column: Optional[str] = None,
unlabeled_dataset: Optional[Dataset] = None,
return_unlabeled_dataset: bool = False,
max_prompt_calls: int = 10,
num_samples_to_generate: int = 10,
timeout_per_prompt: Optional[int] = None,
log_every_n_api_calls: int = 25,
+dummy_response: Optional[Union[str, Callable]] = None
) -> Union[Dataset, Tuple[Dataset, Dataset]]:
"""Generate a dataset based on a prompt template and support examples.
Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -81,6 +82,7 @@ class are used. "stratified" sampling strategy means that all classes in support
num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
+dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
Returns:
Union[Dataset, Tuple[Dataset, Dataset]]: Generated dataset or tuple of generated dataset and original
@@ -107,6 +109,7 @@ class are used. "stratified" sampling strategy means that all classes in support
num_samples_to_generate,
timeout_per_prompt,
log_every_n_api_calls,
+dummy_response
)

if return_unlabeled_dataset:
@@ -115,7 +118,7 @@ class are used. "stratified" sampling strategy means that all classes in support
return generated_dataset

def _try_generate(
-self, prompt_text: str, invocation_context: Dict,
+self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]]
) -> Optional[str]:
"""Tries to generate a single example. Restrict the time spent on this.
@@ -128,6 +131,19 @@ def _try_generate(
Generated example
"""

+if dummy_response:
+
+    if isinstance(dummy_response, str):
+        logger.info(f"Returning dummy response: {dummy_response}")
+        return dummy_response
+
+    if callable(dummy_response):
+        dummy_value = dummy_response(prompt_text)
+        logger.info(f"Returning dummy response: {dummy_value}")
+        return dummy_value
+
+    raise ValueError("Dummy response must be a string or a callable")
+
# Haystack internally uses timeouts and retries, so we don't have to do it
# We don't catch authentication errors here, because we want to fail fast
try:
@@ -154,6 +170,7 @@ def _inner_generate_loop(
num_samples_to_generate: int,
timeout_per_prompt: Optional[int],
log_every_n_api_calls: int = 25,
+dummy_response: Optional[Union[str, Callable]] = None
):
current_tries_left = self._max_tries
current_log_file = self._setup_log(prompt_template)
@@ -199,7 +216,7 @@ def _inner_generate_loop(
f"Invocation context: {invocation_context} \n"
)

-prediction = self._try_generate(prompt_text, invocation_context)
+prediction = self._try_generate(prompt_text, invocation_context, dummy_response)

if prediction is None:
current_tries_left -= 1
40 changes: 39 additions & 1 deletion tests/test_dataset_generator.py
@@ -17,7 +17,9 @@ def setUp(self) -> None:
"text": ["This movie is great!", "This movie is bad!"],
"label": ["positive", "negative"]
})
-self.generator = DatasetGenerator(PromptNode("google/flan-t5-small"))
+
+# We are using dummy responses here, because we are not testing the LLM itself.
+self.generator = DatasetGenerator(None)

def test_simple_generation(self):
"""Test simple generation without fewshot examples."""
@@ -28,6 +30,7 @@ def test_simple_generation(self):
generated_dataset = self.generator.generate(
prompt_template=prompt,
max_prompt_calls=2,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -44,6 +47,7 @@ def test_simple_generation_with_label_options(self):
generated_dataset = self.generator.generate(
prompt_template=prompt,
max_prompt_calls=2,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -66,6 +70,7 @@ def test_generation_with_fewshot_examples(self):
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_column="label",
max_prompt_calls=2,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -93,6 +98,7 @@ def test_annotation_with_fewshot_and_unlabeled_examples(self):
fewshot_label_sampling_strategy="stratified",
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -129,6 +135,7 @@ def test_tranlation(self):
# (default)
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
@@ -161,9 +168,40 @@ def test_textual_similarity(self):
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
return_unlabeled_dataset=True,
dummy_response="A dummy movie review."
)

self.assertEqual(len(generated_dataset), 2)
self.assertEqual(generated_dataset.features["sentence1"].dtype, "string")
self.assertEqual(generated_dataset.features["sentence2"].dtype, "string")
self.assertEqual(generated_dataset.features["label"].dtype, "string")

+def test_dummy_response(self):
+
+    prompt = BasePrompt(
+        task_description="Generate a short movie review.",
+    )
+    generated_dataset = self.generator.generate(
+        prompt_template=prompt,
+        max_prompt_calls=2,
+        dummy_response=lambda _: "This is a dummy movie review."
+    )
+
+    self.assertEqual(len(generated_dataset), 2)
+    self.assertEqual(generated_dataset.features["text"].dtype, "string")
+    self.assertIn("text", generated_dataset.features)
+    self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review.")
+    self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review.")
+
+    generated_dataset = self.generator.generate(
+        prompt_template=prompt,
+        max_prompt_calls=2,
+        dummy_response="This is a dummy movie review as a string."
+    )
+
+    self.assertEqual(len(generated_dataset), 2)
+    self.assertEqual(generated_dataset.features["text"].dtype, "string")
+    self.assertIn("text", generated_dataset.features)
+    self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review as a string.")
+    self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review as a string.")
