
Commit

delete test script and format files
stolzenp committed May 2, 2024
1 parent ca3c22d commit 6960917
Showing 10 changed files with 195 additions and 152 deletions.
125 changes: 87 additions & 38 deletions src/fabricator/dataset_generator.py
@@ -17,9 +17,11 @@
 from .samplers import single_label_stratified_sample
 from .utils import log_dir, create_timestamp_path
 
+
 class DatasetGenerator:
     """The DatasetGenerator class is the main class of the fabricator package.
-    It generates datasets based on a prompt template. The main function is generate()."""
+    It generates datasets based on a prompt template. The main function is generate().
+    """
 
     def __init__(self, prompt_node: PromptNode, max_tries: int = 10):
         """Initialize the DatasetGenerator with a prompt node.
@@ -63,7 +65,7 @@ def generate(
         train_small_model_every_X_generations: Optional[int] = None,
         timeout_per_prompt: Optional[int] = None,
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
         """Generate a dataset based on a prompt template and support examples.
         Optionally, unlabeled examples can be provided to annotate unlabeled data.
@@ -96,16 +98,23 @@
             dataset.
         """
         if fewshot_dataset:
-            self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset)
+            self._assert_fewshot_dataset_matches_prompt(
+                prompt_template, fewshot_dataset
+            )
 
-        assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \
-            "Sampling strategy must be 'uniform' or 'stratified'"
+        assert fewshot_sampling_strategy in [
+            None,
+            "uniform",
+            "stratified",
+        ], "Sampling strategy must be 'uniform' or 'stratified'"
 
         if fewshot_dataset and not fewshot_sampling_column:
             fewshot_sampling_column = prompt_template.generate_data_for_column[0]
 
-        assert small_model_training in [None, "text_classification"], \
-            "Task for small model training must be available in 'src/small_model_training' e.g. 'text_classification'"
+        assert small_model_training in [
+            None,
+            "text_classification",
+        ], "Task for small model training must be available in 'src/small_model_training' e.g. 'text_classification'"
 
         generated_dataset, original_dataset = self._inner_generate_loop(
             prompt_template,
@@ -121,7 +130,7 @@
             train_small_model_every_X_generations,
             timeout_per_prompt,
             log_every_n_api_calls,
-            dummy_response
+            dummy_response,
         )
 
         if return_unlabeled_dataset:
@@ -130,7 +139,10 @@
         return generated_dataset
 
     def _try_generate(
-        self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]]
+        self,
+        prompt_text: str,
+        invocation_context: Dict,
+        dummy_response: Optional[Union[str, Callable]],
     ) -> Optional[str]:
         """Tries to generate a single example. Restrict the time spent on this.
@@ -184,7 +196,7 @@ def _inner_generate_loop(
         train_small_model_every_x_generations: Optional[int],
         timeout_per_prompt: Optional[int],
         log_every_n_api_calls: int = 25,
-        dummy_response: Optional[Union[str, Callable]] = None
+        dummy_response: Optional[Union[str, Callable]] = None,
     ):
         current_tries_left = self._max_tries
         current_log_file = self._setup_log(prompt_template)
@@ -198,7 +210,9 @@
         api_calls = range(min(max_prompt_calls, num_samples_to_generate))
 
         for prompt_call_idx, unlabeled_example_idx in tqdm(
-            enumerate(api_calls, start=1), desc="Generating dataset", total=len(api_calls)
+            enumerate(api_calls, start=1),
+            desc="Generating dataset",
+            total=len(api_calls),
         ):
             fewshot_examples = None
             unlabeled_example = None
@@ -211,16 +225,25 @@
             prompt_labels = prompt_template.label_options
 
             # label-conditioned generation with label options
-            if fewshot_dataset is None and isinstance(prompt_labels,list) or prompt_call_idx > train_small_model_every_x_generations:
+            if (
+                fewshot_dataset is None
+                and isinstance(prompt_labels, list)
+                or prompt_call_idx > train_small_model_every_x_generations
+            ):
                 prompt_labels = choice(prompt_labels, 1)[0]
 
-            if fewshot_dataset and prompt_labels in fewshot_dataset['label']:
+            if fewshot_dataset and prompt_labels in fewshot_dataset["label"]:
                 prompt_labels, fewshot_examples = self._sample_fewshot_examples(
-                    prompt_labels, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
-                    fewshot_sampling_column
+                    prompt_labels,
+                    fewshot_dataset,
+                    fewshot_sampling_strategy,
+                    fewshot_examples_per_class,
+                    fewshot_sampling_column,
                 )
 
-            prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples)
+            prompt_text = prompt_template.get_prompt_text(
+                prompt_labels, fewshot_examples
+            )
 
             if unlabeled_dataset:
                 unlabeled_example = unlabeled_dataset[unlabeled_example_idx]
@@ -236,7 +259,9 @@
                     f"Invocation context: {invocation_context} \n"
                 )
 
-            prediction = self._try_generate(prompt_text, invocation_context, dummy_response)
+            prediction = self._try_generate(
+                prompt_text, invocation_context, dummy_response
+            )
 
             if prediction is None:
                 current_tries_left -= 1
@@ -267,20 +292,28 @@
                     prediction, type(prompt_template.generate_data_for_column[0])
                 )
 
-                generated_dataset[prompt_template.generate_data_for_column[0]].append(prediction)
+                generated_dataset[prompt_template.generate_data_for_column[0]].append(
+                    prediction
+                )
 
             else:
-                generated_dataset[prompt_template.DEFAULT_TEXT_COLUMN[0]].append(prediction)
+                generated_dataset[prompt_template.DEFAULT_TEXT_COLUMN[0]].append(
+                    prediction
+                )
                 if prompt_labels and isinstance(prompt_labels, str):
-                    generated_dataset[prompt_template.DEFAULT_LABEL_COLUMN[0]].append(prompt_labels)
+                    generated_dataset[prompt_template.DEFAULT_LABEL_COLUMN[0]].append(
+                        prompt_labels
+                    )
 
             log_entry = {
                 "prompt": prompt_text,
                 "invocation_context": invocation_context,
                 "prediction": prediction,
-                "target": prompt_template.generate_data_for_column[0]
-                if prompt_template.generate_data_for_column
-                else prompt_template.DEFAULT_TEXT_COLUMN[0],
+                "target": (
+                    prompt_template.generate_data_for_column[0]
+                    if prompt_template.generate_data_for_column
+                    else prompt_template.DEFAULT_TEXT_COLUMN[0]
+                ),
             }
             with open(current_log_file, "a", encoding="utf-8") as log_file:
                 log_file.write(f"{json.dumps(log_entry)}\n")
@@ -290,7 +323,9 @@
                     original_dataset[key].append(value)
 
             if prompt_call_idx >= max_prompt_calls:
-                logger.info("Reached maximum number of prompt calls ({}).", max_prompt_calls)
+                logger.info(
+                    "Reached maximum number of prompt calls ({}).", max_prompt_calls
+                )
                 break
 
             if len(generated_dataset) >= num_samples_to_generate:
@@ -300,10 +335,15 @@
             if timeout_per_prompt is not None:
                 time.sleep(timeout_per_prompt)
 
-            if train_small_model_every_x_generations is not None and train_small_model_every_x_generations > 0:
+            if (
+                train_small_model_every_x_generations is not None
+                and train_small_model_every_x_generations > 0
+            ):
                 if prompt_call_idx % train_small_model_every_x_generations == 0:
                     logger.info("Commencing small model training.")
-                    small_model = import_module("src.small_model_training." + small_model_training, __package__)
+                    small_model = import_module(
+                        "src.small_model_training." + small_model_training, __package__
+                    )
                     inf_subset = small_model.get_influential_subset(generated_dataset)
                     fewshot_dataset = inf_subset
                     logger.info("Continuing generation.")
@@ -335,7 +375,9 @@ def _convert_prediction(self, prediction: str, target_type: type) -> Any:
         except ValueError:
             logger.warning(
                 "Could not convert prediction {} to type {}. "
-                "Returning original prediction.", repr(prediction), target_type
+                "Returning original prediction.",
+                repr(prediction),
+                target_type,
             )
             return prediction
 
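The fallback in this hunk is small enough to mirror in isolation. A standalone sketch of the same best-effort conversion (a plain function for illustration, not the class method itself):

def convert_prediction(prediction: str, target_type: type):
    """Best-effort cast that mirrors the fallback behaviour shown above."""
    try:
        return target_type(prediction)
    except ValueError:
        # As in the diff: warn and return the raw string rather than fail.
        return prediction


assert convert_prediction("3", int) == 3
assert convert_prediction("n/a", int) == "n/a"
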
@@ -345,7 +387,7 @@ def _sample_fewshot_examples(
         fewshot_dataset: Dataset,
         fewshot_sampling_strategy: str,
         fewshot_examples_per_class: int,
-        fewshot_sampling_column: str
+        fewshot_sampling_column: str,
     ) -> Tuple[Union[List[str], str], Dataset]:
 
         if fewshot_sampling_strategy == "uniform":
@@ -354,31 +396,38 @@ def _sample_fewshot_examples(
                 lambda example: example[fewshot_sampling_column] == prompt_labels
             )
             fewshot_examples = fewshot_examples.shuffle().select(
-                range(fewshot_examples_per_class) if fewshot_examples_per_class is not None else range(len(fewshot_examples))
+                range(fewshot_examples_per_class)
+                if fewshot_examples_per_class is not None
+                else range(len(fewshot_examples))
             )
 
         elif fewshot_sampling_strategy == "stratified":
             fewshot_examples = single_label_stratified_sample(
-                fewshot_dataset,
-                fewshot_sampling_column,
-                fewshot_examples_per_class
+                fewshot_dataset, fewshot_sampling_column, fewshot_examples_per_class
             )
 
         else:
             if fewshot_examples_per_class:
-                fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class))
+                fewshot_examples = fewshot_dataset.shuffle().select(
+                    range(fewshot_examples_per_class)
+                )
             else:
                 fewshot_examples = fewshot_dataset.shuffle()
 
-        assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}." \
-            f"Ensure that labels of fewshot examples match the label_options " \
-            f"from the prompt."
+        assert len(fewshot_examples) > 0, (
+            f"Could not find any fewshot examples for label(s) {prompt_labels}."
+            f"Ensure that labels of fewshot examples match the label_options "
+            f"from the prompt."
+        )
 
         return prompt_labels, fewshot_examples
 
     @staticmethod
-    def _assert_fewshot_dataset_matches_prompt(prompt_template: BasePrompt, fewshot_dataset: Dataset) -> None:
+    def _assert_fewshot_dataset_matches_prompt(
+        prompt_template: BasePrompt, fewshot_dataset: Dataset
+    ) -> None:
         """Asserts that the prompt template is valid and all columns are present in the fewshot dataset."""
         assert all(
-            field in fewshot_dataset.column_names for field in prompt_template.relevant_columns_for_fewshot_examples
+            field in fewshot_dataset.column_names
+            for field in prompt_template.relevant_columns_for_fewshot_examples
         ), "Not all required variables of the prompt template occur in the support examples."