-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
additional hyperparameter experiment for annotation + renaming of gen…
…erate train script for grid search
- Loading branch information
1 parent
03ebefd
commit 30d7aad
Showing
3 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
from datasets import load_dataset, concatenate_datasets | ||
from haystack.nodes import PromptNode | ||
from fabricator import DatasetGenerator, BasePrompt | ||
from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts | ||
|
||
def run(): | ||
for possible_examples_per_class, fewshot_example_per_class in [(0,0), (2,1), (2,2), (4,1), (4,2), (4,3), (4,4), (8,1), (8,2), (8,3), | ||
(8,4), (16,1), (16,2), (16,3), (16,4)]: | ||
dataset = load_dataset("trec", split="train").shuffle(seed=42).train_test_split(500, stratify_by_column="coarse_label") | ||
fewshot_dataset = dataset["train"] | ||
annotation_dataset = dataset["test"] | ||
fewshot_datasets = [] | ||
for label in range(6): | ||
filtered_ds = fewshot_dataset.filter(lambda x: x["coarse_label"] == label) | ||
fewshot_datasets.append(filtered_ds.select(range(possible_examples_per_class))) | ||
fewshot_dataset = concatenate_datasets(fewshot_datasets).shuffle(seed=42) | ||
|
||
extended_mapping = { | ||
0: "abbreviation", | ||
1: "entity", | ||
2: "description", | ||
3: "human", | ||
4: "location", | ||
5: "number" | ||
} | ||
|
||
if possible_examples_per_class > 0: | ||
fewshot_dataset = convert_label_ids_to_texts( | ||
fewshot_dataset, | ||
"coarse_label", | ||
expanded_label_mapping=extended_mapping, | ||
) | ||
|
||
annotation_dataset, label_options = convert_label_ids_to_texts( | ||
annotation_dataset, | ||
"coarse_label", | ||
expanded_label_mapping=extended_mapping, | ||
return_label_options=True, | ||
) | ||
|
||
prompt = BasePrompt( | ||
task_description="Classify the question into exactly one of the following classes: {}.", | ||
label_options=label_options, | ||
generate_data_for_column="coarse_label", | ||
fewshot_example_columns="text", | ||
fewshot_formatting_template="Question: {text}\nClass: {coarse_label}", | ||
target_formatting_template="Question: {text}\nClass: ", | ||
) | ||
|
||
prompt_node = PromptNode( | ||
model_name_or_path="gpt-3.5-turbo", | ||
api_key=os.environ.get("OPENAI_API_KEY"), | ||
max_length=100, | ||
) | ||
|
||
generator = DatasetGenerator(prompt_node) | ||
generated_dataset = generator.generate( | ||
prompt_template=prompt, | ||
fewshot_dataset=fewshot_dataset if possible_examples_per_class > 0 else None, | ||
fewshot_examples_per_class=fewshot_example_per_class if possible_examples_per_class > 0 else 0, | ||
fewshot_sampling_strategy="stratified" if possible_examples_per_class > 0 else None, | ||
fewshot_sampling_column="coarse_label" if possible_examples_per_class > 0 else None, | ||
unlabeled_dataset=annotation_dataset, | ||
max_prompt_calls=len(annotation_dataset), | ||
) | ||
|
||
model_name = f"trec_hyperparameter_annotated_{possible_examples_per_class}_possible_examples_{fewshot_example_per_class}_used" | ||
generated_dataset.push_to_hub(model_name, private=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
121 changes: 121 additions & 0 deletions
121
paper_experiments/trec_hyperparameter_annotate_train_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import numpy as np | ||
from datasets import load_dataset, ClassLabel | ||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding | ||
import evaluate | ||
import shutil | ||
|
||
|
||
def run(possible_examples_per_class, fewshot_example_per_class, seed): | ||
corpus_name = f"whoisjones/trec_hyperparameter_annotated_{possible_examples_per_class}_possible_examples_{fewshot_example_per_class}_used" | ||
|
||
if "corpus_name" not in locals(): | ||
raise Exception("Please insert the generated corpora before running this script.") | ||
|
||
label_alignment = { | ||
"NUM": "number", | ||
"ENTY": "entity", | ||
"DESC": "description", | ||
"ABBR": "abbreviation", | ||
"HUM": "human", | ||
"LOC": "location", | ||
} | ||
# Load the dataset | ||
dataset = load_dataset(corpus_name, split="train").shuffle(seed=seed) | ||
test_split = load_dataset("trec", split="test") | ||
original_labels = test_split.features["coarse_label"].names | ||
|
||
def clean_labels(examples): | ||
label = examples["coarse_label"].replace("Class: ", "") | ||
if label not in list(label_alignment.values()): | ||
label = "remove" | ||
examples["coarse_label"] = label | ||
return examples | ||
|
||
dataset = dataset.map(clean_labels) | ||
dataset = dataset.filter(lambda x: x["coarse_label"] != "remove") | ||
|
||
dst_feat = ClassLabel(names=[label_alignment[k] for k in original_labels]) | ||
dataset = dataset.map(lambda batch: { | ||
"coarse_label": dst_feat.str2int(batch)}, input_columns="coarse_label", batched=True) | ||
new_features = dataset.features.copy() | ||
new_features["coarse_label"] = dst_feat | ||
dataset = dataset.cast(new_features) | ||
|
||
dataset = dataset.train_test_split(test_size=0.1) | ||
dataset["validation"] = dataset["test"] | ||
dataset["test"] = test_split | ||
dataset = dataset.rename_column("coarse_label", "label") | ||
num_labels = dataset["train"].features["label"].num_classes | ||
|
||
# Load the BERT tokenizer and model | ||
model_name = "bert-base-uncased" | ||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) | ||
|
||
def preprocess_function(examples): | ||
return tokenizer(examples["text"], padding=True, truncation=True, return_tensors="pt") | ||
|
||
tokenized_dataset = dataset.map(preprocess_function, batched=True) | ||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | ||
accuracy = evaluate.load("accuracy") | ||
|
||
def compute_metrics(eval_pred): | ||
predictions, labels = eval_pred | ||
predictions = np.argmax(predictions, axis=1) | ||
return accuracy.compute(predictions=predictions, references=labels) | ||
|
||
id2label = dict(enumerate(dataset["train"].features["label"].names)) | ||
label2id = {v: k for k, v in id2label.items()} | ||
|
||
model = AutoModelForSequenceClassification.from_pretrained( | ||
model_name, | ||
num_labels=num_labels, | ||
id2label=id2label, | ||
label2id=label2id | ||
).to("cuda") | ||
|
||
num_train_epochs = 20 | ||
|
||
# Training arguments | ||
training_args = TrainingArguments( | ||
output_dir="output_model", | ||
learning_rate=2e-5, | ||
per_device_train_batch_size=16, | ||
per_device_eval_batch_size=16, | ||
num_train_epochs=num_train_epochs, | ||
weight_decay=0.01, | ||
save_total_limit=1, | ||
evaluation_strategy="epoch", | ||
push_to_hub=False, | ||
) | ||
|
||
trainer = Trainer( | ||
model=model, | ||
args=training_args, | ||
train_dataset=tokenized_dataset["train"], | ||
eval_dataset=tokenized_dataset["validation"], | ||
tokenizer=tokenizer, | ||
data_collator=data_collator, | ||
compute_metrics=compute_metrics, | ||
) | ||
|
||
trainer.train() | ||
|
||
return trainer.predict(tokenized_dataset["test"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
# for every combination of possible fewshot examples and fewshot examples used | ||
for possible_examples_per_class, fewshot_example_per_class in [(0, 0), (2, 1), (2, 2), (4, 1), (4, 2), (4, 3), | ||
(4, 4), (8, 1), (8, 2), (8, 3), (8, 4), (16, 1), | ||
(16, 2), (16, 3), (16, 4)]: | ||
result_avg = [] | ||
# iterate over seeds | ||
for seed in [41, 42, 43, 44, 45]: | ||
results = run(possible_examples_per_class, fewshot_example_per_class, seed) | ||
result_avg.append(results.metrics["test_accuracy"] * 100) | ||
|
||
# log for hyperparameter run | ||
file = f"hyperparameter-trec-annotation-{possible_examples_per_class}-possible-{fewshot_example_per_class}-used" | ||
with open(f"results/{file}.log", "w") as f: | ||
f.write(f"Accuracy: {np.mean(result_avg)}\n") | ||
f.write(f"Standard deviation: {np.std(result_avg)}\n") |
File renamed without changes.