adjusted to refactorings and move squad annotation example to paper_experiments folder. some adjustments on the processing function for squad + one test fixed.
whoisjones committed Aug 3, 2023
1 parent d4a5219 commit 9cdc2a3
Showing 3 changed files with 20 additions and 25 deletions.
File 1 of 3 — squad annotation example (moved to the paper_experiments folder, per the commit message)
@@ -4,19 +4,17 @@

 from datasets import Sequence, Value, load_dataset, concatenate_datasets
 from haystack.nodes import PromptNode
-from ai_dataset_generator import DatasetGenerator
+from ai_dataset_generator import DatasetGenerator, BasePrompt
 from ai_dataset_generator.dataset_transformations.question_answering import (
     preprocess_squad_format,
     postprocess_squad_format,
 )
-from ai_dataset_generator.prompts import TextLabelPrompt


 def run(arguments):
     """Generate answers based on a few-shot example from context and question."""
     org = load_dataset(arguments.dataset, split=arguments.split)
-
     dataset_name = f"qa-dataset"
     dataset_answerable_questions = org.filter(lambda sample: sample['answers']['text']).shuffle()
     dataset_unanswerable_questions = org.filter(lambda sample: not sample['answers']['text']).shuffle()
@@ -32,19 +30,19 @@ def run(arguments):
     filtered_original_datasets = []

     def merge_columns(example):
-        if example["answer"] == "":
+        if example["answers"] == "":
             example["question"] = example["question"]
             return example
-        example["question"] = f"{example['question']}\nAnswer: {example['answer']}"
+        example["question"] = f"{example['question']}\nAnswer: {example['answers']}"
         return example

     def split_columns(example):
         entries = example["question"].split("\nAnswer:")
         example["question"] = entries[0]
         if len(entries) == 1:
-            example["answer"] = ""
+            example["answers"] = ""
             return example
-        example["answer"] = entries[1].strip()
+        example["answers"] = entries[1].strip()
         return example

     for index, dataset in enumerate([dataset_answerable_questions, dataset_unanswerable_questions]):
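
For readers tracking the answer → answers rename: merge_columns folds the gold answer into the question string so a single column carries both through few-shot prompting, and split_columns undoes it on the generated output. A minimal round-trip sketch, assuming the two helpers above are in scope (sample values are made up):

    example = {"question": "Who founded Acme?", "answers": "John Doe"}  # hypothetical sample
    merged = merge_columns(dict(example))
    assert merged["question"] == "Who founded Acme?\nAnswer: John Doe"

    restored = split_columns(dict(merged))
    assert restored["question"] == "Who founded Acme?"
    assert restored["answers"] == "John Doe"

    # Unanswerable samples come back with an empty answers string.
    assert split_columns({"question": "Who founded Acme?"})["answers"] == ""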
@@ -59,23 +57,20 @@ def split_columns(example):
             "Given a text, first create a difficult question that can be answered using the text. The question must describe the context of the text. Second, extract the answer to this question from the text. The answer must be word for word exactly as it appears in the text.",
             f"You are a student and a teacher is teaching you about a new topic. Ask a short follow-up question about something the teacher hasn't mentioned yet at all. You must not ask something you already know the answer to from the teacher's explanations. You must not ask for further clarification if the teacher already mentioned something in passing. The question should be self-contained. It must not contain the word \"other\" as in \"which other\" or \"what other\". The question should start with one of {random.sample(question_words, 3)}"]

-        prompt = TextLabelPrompt(
-            input_variables=arguments.input_variables,
-            target_variable=arguments.target_variable,
+        prompt = BasePrompt(
             task_description=task_descriptions[index],
+            fewshot_example_columns=arguments.input_variables,
+            generate_data_for_column=arguments.target_variable,
         )

-        raw_prompt = prompt.get_prompt_text(fewshot_examples)
-        print(raw_prompt)
-
         current_unlabeled_examples = unlabeled_examples.select(range(i, min(len(unlabeled_examples), i + arguments.save_steps)))
         generated_dataset, original_dataset = generator.generate(
-            support_examples=fewshot_examples,
-            unlabeled_examples=current_unlabeled_examples,
+            fewshot_dataset=fewshot_examples,
+            fewshot_examples_per_class=arguments.support_examples_per_prompt,
+            unlabeled_dataset=current_unlabeled_examples,
             prompt_template=prompt,
             max_prompt_calls=arguments.max_prompt_calls,
-            support_examples_per_prompt=arguments.support_examples_per_prompt,
-            return_original_dataset=True,
+            return_unlabeled_dataset=True,
         )

         generated_dataset = generated_dataset.map(split_columns)
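
Read together rather than line by line, the hunk above maps the old TextLabelPrompt keywords onto the new BasePrompt API. A consolidated sketch of the refactored call pattern, using only names that appear in this diff (the surrounding loop and generator setup are assumed from the script):

    prompt = BasePrompt(
        task_description=task_descriptions[index],
        fewshot_example_columns=arguments.input_variables,   # was input_variables on TextLabelPrompt
        generate_data_for_column=arguments.target_variable,  # was target_variable
    )

    generated_dataset, original_dataset = generator.generate(
        fewshot_dataset=fewshot_examples,                    # was support_examples
        fewshot_examples_per_class=arguments.support_examples_per_prompt,  # was support_examples_per_prompt
        unlabeled_dataset=current_unlabeled_examples,        # was unlabeled_examples
        prompt_template=prompt,
        max_prompt_calls=arguments.max_prompt_calls,
        return_unlabeled_dataset=True,                       # was return_original_dataset
    )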
@@ -106,7 +101,7 @@ def split_columns(example):
         original_dataset = dataset.filter(lambda example: example['id'] in ids_to_keep)

         features = generated_dataset.features
-        features["answers"] = Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int64', id=None)}, length=-1, id=None)
+        features["answers"] = Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)
         generated_dataset = generated_dataset.cast(features)

         filtered_generated_datasets.append(generated_dataset)
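
The int64 → int32 switch matters because SQuAD-style datasets on the Hugging Face Hub declare answer_start as int32, and concatenate_datasets only merges datasets whose features match exactly. A quick way to check the reference schema (assumes Hub access; squad_v2 chosen because this script distinguishes answerable from unanswerable questions):

    from datasets import load_dataset

    squad = load_dataset("squad_v2", split="validation")
    print(squad.features["answers"])
    # Sequence(feature={'text': Value(dtype='string', id=None),
    #                   'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)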
@@ -116,14 +111,15 @@ def split_columns(example):
     filtered_generated_concatenated_dataset.save_to_disk(f"{dataset_name}-generated-{index}-{i}")
     filtered_original_concatenated_dataset = concatenate_datasets(filtered_original_datasets)
     filtered_original_concatenated_dataset.save_to_disk(f"{dataset_name}-original-{index}-{i}")
+
     if arguments.push_to_hub:
         filtered_generated_concatenated_dataset.push_to_hub(f"{dataset_name}-generated-{len(filtered_generated_concatenated_dataset)}", private=False)
         filtered_original_concatenated_dataset.push_to_hub(f"{dataset_name}-original-{len(filtered_original_concatenated_dataset)}", private=False)


 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--llm", type=str, default="text-davinci-003")
+    parser.add_argument("--llm", type=str, default="gpt-3.5-turbo")
     parser.add_argument("--max_generation_length", type=int, default=100)
     parser.add_argument(
         "--task_description",
File 2 of 3 — ai_dataset_generator/dataset_transformations/question_answering.py (path inferred from the import in file 1)
@@ -37,7 +37,7 @@ def postprocess_squad_format(dataset: Dataset, add_answer_start: bool = True) ->
     """
     # remove punctuation and whitespace from the start and end of the answer
     def remove_punctuation(example):
-        example["answer"] = example["answer"].strip(".,;!? ")
+        example["answers"] = example["answers"].strip(".,;!? ")
         return example

     dataset = dataset.map(remove_punctuation)
@@ -46,9 +46,8 @@ def remove_punctuation(example):
     dataset = dataset.map(calculate_answer_start)

     def unify_answers(example):
-        example["answers"] = {"text": example["answers"], "start": example["answer_start"]}
-        is_unanswerable = "answer_start" in example
-        if is_unanswerable:
+        is_answerable = "answer_start" in example
+        if is_answerable:
             example["answers"] = {"text": [example["answers"]], "answer_start": [example["answer_start"]]}
         else:
             example["answers"] = {"text": [], "answer_start": []}
@@ -85,6 +84,6 @@ def calculate_answer_start(example):
         answer_start = -1
     else:
         # correct potential wrong capitalization of the answer compared to the context
-        example["answer"] = example["context"][answer_start : answer_start + len(example["answer"])]
+        example["answers"] = example["context"][answer_start : answer_start + len(example["answers"])]
         example["answer_start"] = answer_start
     return example
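
The re-slice on the corrected line is what repairs casing: once answer_start is known, the generated answer string is replaced by the exact characters of the context at that span. A toy illustration with hypothetical values (the lookup line only stands in for however calculate_answer_start actually finds the span):

    context = "The Apache Software Foundation hosts many projects."
    answer = "apache software foundation"  # generated answer, wrong casing
    answer_start = context.lower().find(answer.lower())  # assumed case-insensitive lookup
    fixed = context[answer_start : answer_start + len(answer)]
    print(fixed)  # Apache Software Foundation -- matches the context verbatim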
File 3 of 3 — tests/test_dataset_transformations.py (2 changes: 1 addition & 1 deletion)
@@ -199,4 +199,4 @@ def test_squad_postprocessing(self):
         dataset = preprocess_squad_format(self.dataset.select(range(50)))
         dataset = postprocess_squad_format(dataset)
         self.assertEqual(type(dataset[0]["answers"]), dict)
-        self.assertIn("start", dataset[0]["answers"])
+        self.assertIn("answer_start", dataset[0]["answers"])
