Skip to content

Commit

Permalink
improved fewshot sampling naming convention + label options / no fews…
Browse files Browse the repository at this point in the history
…hot examples fix. added new tests.
  • Loading branch information
whoisjones committed Aug 4, 2023
1 parent dd4739d commit 18dcd6f
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 37 deletions.
2 changes: 1 addition & 1 deletion paper_experiments/mrpc_annotate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run():
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=2,
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=annotation_dataset,
max_prompt_calls=len(annotation_dataset),
return_unlabeled_dataset=True
Expand Down
2 changes: 1 addition & 1 deletion paper_experiments/snli_annotate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run():
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=2,
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=annotation_dataset,
max_prompt_calls=len(annotation_dataset),
return_unlabeled_dataset=True
Expand Down
2 changes: 1 addition & 1 deletion paper_experiments/trec_annotate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run():
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=2,
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=annotation_dataset,
max_prompt_calls=len(annotation_dataset),
return_unlabeled_dataset=True
Expand Down
2 changes: 1 addition & 1 deletion paper_experiments/trec_generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def run():
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=3,
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_strategy="uniform",
fewshot_sampling_column="coarse_label",
max_prompt_calls=10000,
num_samples_to_generate=10000,
Expand Down
2 changes: 1 addition & 1 deletion paper_experiments/trec_hyperparameter_generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run():
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=fewshot_example_per_class,
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_strategy="uniform",
fewshot_sampling_column="coarse_label",
max_prompt_calls=500,
num_samples_to_generate=500,
Expand Down
46 changes: 27 additions & 19 deletions src/ai_dataset_generator/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def generate(
self,
prompt_template: BasePrompt,
fewshot_dataset: Optional[Dataset] = None,
fewshot_examples_per_class: int = 1,
fewshot_label_sampling_strategy: Optional[str] = None,
fewshot_sampling_strategy: Optional[str] = None,
fewshot_examples_per_class: int = None,
fewshot_sampling_column: Optional[str] = None,
unlabeled_dataset: Optional[Dataset] = None,
return_unlabeled_dataset: bool = False,
Expand All @@ -69,11 +69,13 @@ def generate(
Args:
prompt_template (BasePrompt): Prompt template to generate the dataset with.
fewshot_dataset (Dataset): Support examples to generate the dataset from. Defaults to None.
fewshot_sampling_strategy (str, optional): Sampling strategy for support examples.
Defaults to None and means all fewshot examples are used or limited by number of
fewshot_examples_per_class.
"uniform" sampling strategy means that fewshot examples for a uniformly sampled label are used.
"stratified" sampling strategy means that fewshot examples uniformly selected from each label.
fewshot_examples_per_class (int, optional): Number of support examples for a certain class per prompt.
Defaults to 1.
fewshot_label_sampling_strategy (str, optional): Sampling strategy for support examples. Defaults to
None. "uniform" sampling strategy means that one label is sampled and only fewshot examples for that
class are used. "stratified" sampling strategy means that all classes in support examples are used.
Defaults to None.
fewshot_sampling_column (str, optional): Column to sample from. Defaults to None and function will try
to sample from the generate_data_for_column attribute of the prompt template.
unlabeled_dataset (Optional[Dataset], optional): Unlabeled examples to annotate. Defaults to None.
Expand All @@ -91,7 +93,7 @@ class are used. "stratified" sampling strategy means that all classes in support
if fewshot_dataset:
self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset)

assert fewshot_label_sampling_strategy in [None, "uniform", "stratified"], \
assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \
"Sampling strategy must be 'uniform' or 'stratified'"

if fewshot_dataset and not fewshot_sampling_column:
Expand All @@ -101,7 +103,7 @@ class are used. "stratified" sampling strategy means that all classes in support
prompt_template,
fewshot_dataset,
fewshot_examples_per_class,
fewshot_label_sampling_strategy,
fewshot_sampling_strategy,
fewshot_sampling_column,
unlabeled_dataset,
return_unlabeled_dataset,
Expand Down Expand Up @@ -162,7 +164,7 @@ def _inner_generate_loop(
prompt_template: BasePrompt,
fewshot_dataset: Dataset,
fewshot_examples_per_class: int,
fewshot_label_sampling_strategy: str,
fewshot_sampling_strategy: str,
fewshot_sampling_column: str,
unlabeled_dataset: Dataset,
return_unlabeled_dataset: bool,
Expand Down Expand Up @@ -192,11 +194,13 @@ def _inner_generate_loop(
prompt_labels = None

if prompt_template.label_options:
prompt_labels = choice(prompt_template.label_options, 1)[0]
# At some point: how can we do label-conditioned generation without fewshot examples? Currently it
# require a second parameter for sample from label options and not from fewshot examples
prompt_labels = prompt_template.label_options

if fewshot_dataset:
prompt_labels, fewshot_examples = self._sample_fewshot_examples(
prompt_template, fewshot_dataset, fewshot_examples_per_class, fewshot_label_sampling_strategy,
prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
fewshot_sampling_column
)

Expand Down Expand Up @@ -313,19 +317,20 @@ def _convert_prediction(self, prediction: str, target_type: type) -> Any:

@staticmethod
def _sample_fewshot_examples(
prompt_template: BasePrompt,
fewshot_dataset: Dataset,
fewshot_examples_per_class: int,
fewshot_label_sampling_strategy: str,
fewshot_sampling_column: str
prompt_template: BasePrompt,
fewshot_dataset: Dataset,
fewshot_sampling_strategy: str,
fewshot_examples_per_class: int,
fewshot_sampling_column: str
) -> Tuple[Union[List[str], str], Dataset]:
if fewshot_label_sampling_strategy == "uniform":

if fewshot_sampling_strategy == "uniform":
prompt_labels = choice(prompt_template.label_options, 1)[0]
fewshot_examples = fewshot_dataset.filter(
lambda example: example[fewshot_sampling_column] == prompt_labels
).shuffle().select(range(fewshot_examples_per_class))

elif fewshot_label_sampling_strategy == "stratified":
elif fewshot_sampling_strategy == "stratified":
prompt_labels = prompt_template.label_options
fewshot_examples = single_label_stratified_sample(
fewshot_dataset,
Expand All @@ -335,7 +340,10 @@ def _sample_fewshot_examples(

else:
prompt_labels = prompt_template.label_options if prompt_template.label_options else None
fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class))
if fewshot_examples_per_class:
fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class))
else:
fewshot_examples = fewshot_dataset.shuffle()

assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}." \
f"Ensure that labels of fewshot examples match the label_options " \
Expand Down
120 changes: 114 additions & 6 deletions tests/test_dataset_generator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import unittest

from datasets import Dataset, load_dataset
from haystack.nodes import PromptNode

from ai_dataset_generator import DatasetGenerator
from ai_dataset_generator.prompts import BasePrompt
from ai_dataset_generator.dataset_transformations.text_classification import convert_label_ids_to_texts
Expand Down Expand Up @@ -67,7 +65,7 @@ def test_generation_with_fewshot_examples(self):
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_examples_per_class=1,
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_strategy="uniform",
fewshot_sampling_column="label",
max_prompt_calls=2,
dummy_response="A dummy movie review."
Expand Down Expand Up @@ -95,7 +93,7 @@ def test_annotation_with_fewshot_and_unlabeled_examples(self):
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_examples_per_class=1,
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
dummy_response="A dummy movie review."
Expand Down Expand Up @@ -130,7 +128,7 @@ def test_tranlation(self):
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=2, # Take both fewshot examples per prompt
fewshot_label_sampling_strategy=None,
fewshot_sampling_strategy=None,
# Since we do not have a class label column, we can just set this to None
# (default)
unlabeled_dataset=unlabeled_dataset,
Expand Down Expand Up @@ -164,7 +162,7 @@ def test_textual_similarity(self):
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=1,
fewshot_sampling_column="label",
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
return_unlabeled_dataset=True,
Expand All @@ -176,6 +174,116 @@ def test_textual_similarity(self):
self.assertEqual(generated_dataset.features["sentence2"].dtype, "string")
self.assertEqual(generated_dataset.features["label"].dtype, "string")

def test_sampling_all_fewshot_examples(self):
"""Test sampling all fewshot examples"""
prompt = BasePrompt(
task_description="Generate a short movie review.",
)

prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_sampling_strategy=None,
fewshot_examples_per_class=None,
fewshot_sampling_column=None,
)

self.assertEqual(prompt_labels, None)
self.assertEqual(len(fewshot_examples), 2)

def test_sampling_all_fewshot_examples_with_label_options(self):
"""Test sampling all fewshot examples with label options"""
prompt = BasePrompt(
task_description="Generate a short movie review: {}.",
label_options=["positive", "negative"],
)

prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_sampling_strategy=None,
fewshot_examples_per_class=None,
fewshot_sampling_column=None,
)

prompt_text = prompt.task_description.format(prompt_labels)
self.assertIn("positive", prompt_text)
self.assertIn("negative", prompt_text)

def test_sampling_uniform_fewshot_examples(self):
"""Test uniform sampling fewshot examples"""
prompt = BasePrompt(
task_description="Generate a short movie review: {}.",
label_options=["positive", "negative"],
)

prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_sampling_strategy="uniform",
fewshot_examples_per_class=1,
fewshot_sampling_column="label",
)

self.assertEqual(len(fewshot_examples), 1)
self.assertIn(prompt_labels, ["positive", "negative"])

def test_sampling_stratified_fewshot_examples(self):
"""Test stratified sampling fewshot examples"""
larger_fewshot_dataset = Dataset.from_dict({
"text": ["This movie is great!", "This movie is bad!", "This movie is great!", "This movie is bad!"],
"label": ["positive", "negative", "positive", "negative"]
})

prompt = BasePrompt(
task_description="Generate a short movie review: {}.",
label_options=["positive", "negative"],
)

prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=larger_fewshot_dataset,
fewshot_sampling_strategy="stratified",
fewshot_examples_per_class=1,
fewshot_sampling_column="label",
)

self.assertEqual(len(fewshot_examples), 2)
self.assertEqual(len(set(fewshot_examples["label"])), 2)
self.assertEqual(len(prompt_labels), 2)

def test_sampling_uniform_fewshot_examples_without_number_of_examples(self):
"""Test failure of uniform sampling fewshot examples if attributes are missing"""
prompt = BasePrompt(
task_description="Generate a short movie review: {}.",
label_options=["positive", "negative"],
)

with self.assertRaises(KeyError):
prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_sampling_strategy="uniform",
fewshot_examples_per_class=None,
fewshot_sampling_column=None,
)

def test_sampling_uniform_fewshot_examples_with_number_of_examples_without_sampling_column(self):
"""Test failure of uniform sampling fewshot examples if attributes are missing"""
prompt = BasePrompt(
task_description="Generate a short movie review: {}.",
label_options=["positive", "negative"],
)

with self.assertRaises(KeyError):
prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
prompt_template=prompt,
fewshot_dataset=self.text_classification_dataset,
fewshot_sampling_strategy="uniform",
fewshot_examples_per_class=1,
fewshot_sampling_column=None,
)

def test_dummy_response(self):

prompt = BasePrompt(
Expand Down
8 changes: 4 additions & 4 deletions tutorials/TUTORIAL-2_SIMPLE-GENERATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ reviews and provide the LLM with examples, we need to specify the `generate_data

Since we are using fewshot examples, we can control the prompt generation through different sampling strategies and
number of examples per class. We pass the `fewshot_dataset` to the generate function and specify to use one fewshot
example per class per prompt. The `fewshot_label_sampling_strategy` argument specifies how to sample from the fewshot
example per class per prompt. The `fewshot_sampling_strategy` argument specifies how to sample from the fewshot
dataset. In this case, we use a uniform sampling strategy which means that the generator will uniformly sample one
example per class from the fewshot dataset. The `fewshot_sampling_column` argument specifies which column to use for
sampling. In this case, we use the `label` column.
Expand Down Expand Up @@ -114,7 +114,7 @@ generated_dataset = generator.generate(
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=1,
fewshot_label_sampling_strategy="uniform",
fewshot_sampling_strategy="uniform",
fewshot_sampling_column="label",
max_prompt_calls=10,
)
Expand All @@ -129,7 +129,7 @@ two columns: `text` and `label`. We also have an unlabeled dataset with only a `
unlabeled dataset with the fewshot dataset. In order to do this, we need to specify the `unlabeled_dataset` argument
to the `DatasetGenerator.generate()` method. We also need to specify the `fewshot_examples_per_class` argument to
specify how many fewshot examples to use per class. In this case, we use one example per class. The
`fewshot_label_sampling_strategy` argument specifies how to sample from the fewshot dataset.
`fewshot_sampling_strategy` argument specifies how to sample from the fewshot dataset.
In this case, we use a stratfied sampling strategy which means that the generator will sample exactly one example from
each class from the fewshot dataset. In this case, we do not need to explicitly specify the `fewshot_sampling_column`
argument since the generator will use the column specified in `generate_data_for_column` by default.
Expand Down Expand Up @@ -170,7 +170,7 @@ generated_dataset = generator.generate(
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=1,
fewshot_label_sampling_strategy="stratified",
fewshot_sampling_strategy="stratified",
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=10,
)
Expand Down
6 changes: 3 additions & 3 deletions tutorials/TUTORIAL-3_ADVANCED-GENERATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ generated_dataset = generator.generate(
prompt_template=prompt,
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=2, # Take both fewshot examples per prompt
fewshot_label_sampling_strategy=None, # Since we do not have a class label column, we can just set this to None
fewshot_sampling_strategy=None, # Since we do not have a class label column, we can just set this to None
# (default)
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
Expand Down Expand Up @@ -475,7 +475,7 @@ generated_dataset, original_dataset = generator.generate(
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=1, # Take 1 fewshot examples per class per prompt
fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column
fewshot_label_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this
Expand Down Expand Up @@ -524,7 +524,7 @@ generated_dataset, original_dataset = generator.generate(
fewshot_dataset=fewshot_dataset,
fewshot_examples_per_class=1, # Take 1 fewshot examples per class per prompt
fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column
fewshot_label_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
unlabeled_dataset=unlabeled_dataset,
max_prompt_calls=2,
return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this
Expand Down

0 comments on commit 18dcd6f

Please sign in to comment.