Commit
updated with pull requests
gorkemgoknar committed Dec 21, 2020
1 parent 35ab904 commit 21c90f1
Showing 3 changed files with 174 additions and 55 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -1,4 +1,3 @@
-## changed
# 🦄 Building a State-of-the-Art Conversational AI with Transfer Learning

The present repo contains the code accompanying the blog post [🦄 How to build a State-of-the-Art Conversational AI with Transfer Learning](https://medium.com/@Thomwolf/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313).
@@ -64,7 +63,7 @@ Argument | Type | Default value | Description
---------|------|---------------|------------
dataset_path | `str` | `""` | Path or url of the dataset. If empty download from S3.
dataset_cache | `str` | `'./dataset_cache.bin'` | Path or url of the dataset cache
-model | `str` | `"openai-gpt"` | Path, url or short name of the model
+model_checkpoint | `str` | `"openai-gpt"` | Path, url or short name of the model
num_candidates | `int` | `2` | Number of candidates for training
max_history | `int` | `2` | Number of previous exchanges to keep in history
train_batch_size | `int` | `4` | Batch size for training
@@ -102,7 +101,7 @@ You can then use the interactive script to interact with the model simply by poi
Here is an example command line to run the interactive script:

```bash
-python ./interact.py --model_checkpoint ./data/Apr17_13-31-38_thunder/ # run the interactive script with a training checkpoint
+python ./interact.py --model_checkpoint ./runs/Apr17_13-31-38_thunder/ # run the interactive script with a training checkpoint
python ./interact.py # run the interactive script with the finetuned model on our S3
```
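
With the `model`/`model_checkpoint` split documented in the table below, a GPT-2 run might look like this (a sketch; the checkpoint path is hypothetical):

```bash
python ./interact.py --model gpt2 --model_checkpoint ./runs/my_gpt2_run/  # --model picks the architecture, --model_checkpoint the weights
```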

@@ -114,7 +113,8 @@ Argument | Type | Default value | Description
---------|------|---------------|------------
dataset_path | `str` | `""` | Path or url of the dataset. If empty download from S3.
dataset_cache | `str` | `'./dataset_cache.bin'` | Path or url of the dataset cache
-model | `str` | `"openai-gpt"` | Path, url or short name of the model
+model | `str` | `"openai-gpt"` | Model type (openai-gpt or gpt2)
+model_checkpoint | `str` | `""` | Path, url or short name of the model
max_history | `int` | `2` | Number of previous utterances to keep in history
device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
no_sample | action `store_true` | Set to use greedy decoding instead of sampling
@@ -148,7 +148,7 @@ The evaluation script accepts a few arguments to select the evaluation metric and
Argument | Type | Default value | Description
---------|------|---------------|------------
eval_type | `str` | `"hits@1"` | Evaluate the model on `hits@1`, `ppl` or `f1` metric on the ConvAI2 validation dataset
-model | `str` | `"openai-gpt"` | Path, url or short name of the model
+model_checkpoint | `str` | `"openai-gpt"` | Path, url or short name of the model. Must be OpenAIGPT.
max_history | `int` | `2` | Number of previous utterances to keep in history
device | `str` | `cuda` if `torch.cuda.is_available()` else `cpu` | Device (cuda or cpu)
no_sample | action `store_true` | Set to use greedy decoding instead of sampling
@@ -184,4 +184,4 @@ If you use this code in your research, you can cite our NeurIPS CAI workshop [pa
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1901-08149},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
-```
+```
109 changes: 70 additions & 39 deletions train.py
@@ -10,20 +10,21 @@

import torch
from torch.nn.parallel import DistributedDataParallel
+from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
-from ignite.handlers import ModelCheckpoint, global_step_from_engine
+from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
-from transformers import (AdamW, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
+from pytorch_transformers import (AdamW, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
GPT2DoubleHeadsModel, GPT2Tokenizer, WEIGHTS_NAME, CONFIG_NAME)

from utils import get_dataset, make_logdir

SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'eos_token': '<eos>', 'pad_token': '<pad>',
-                         'additional_special_tokens': ['<speaker1>', '<speaker2>']}
+                         'additional_special_tokens': ('<speaker1>', '<speaker2>')}
MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids"]
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]
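
These constants drive the tokenizer/model setup elsewhere in the file. As a minimal sketch of that registration step (assuming the `add_special_tokens` / `resize_token_embeddings` methods of the transformers-style API; the exact calls depend on which library version this commit pins):

```python
def add_special_tokens_(model, tokenizer):
    """ Sketch: register the special tokens and grow the embedding matrix to match. """
    num_added = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)  # no-op for tokens already present
    if num_added > 0:
        model.resize_token_embeddings(new_num_tokens=len(tokenizer))
```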

@@ -42,7 +43,7 @@ def pad_dataset(dataset, padding=0):
""" Pad the dataset. This could be optimized by defining a Dataset class and padding at the batch level, but this is simpler. """
max_l = max(len(x) for x in dataset["input_ids"])
for name in PADDED_INPUTS:
-        dataset[name] = [x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name]]
+        dataset[name] = [x + [padding if name != "lm_labels" else -1] * (max_l - len(x)) for x in dataset[name]]
return dataset
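
The `-1` fill for `lm_labels` matches the `ignore_index=-1` passed to the NLL metric further down, so padded positions drop out of the language-modeling loss entirely. A toy illustration (shapes and values made up):

```python
import torch

logits = torch.randn(1, 4, 10)           # (batch, seq_len, vocab)
labels = torch.tensor([[3, 7, -1, -1]])  # last two positions are padding
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fn(logits.view(-1, 10), labels.view(-1))  # only the first two positions contribute
```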


@@ -62,19 +63,70 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=Fals
instance["input_ids"] = list(chain(*sequence))
instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s]
instance["mc_token_ids"] = len(instance["input_ids"]) - 1
instance["lm_labels"] = [-100] * len(instance["input_ids"])
instance["lm_labels"] = [-1] * len(instance["input_ids"])
if lm_labels:
instance["lm_labels"] = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]
instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:]
return instance
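
For reference, the word-level layout this function builds, following the accompanying blog post (a sketch with made-up strings; the real function operates on token ids):

```python
# persona = ["i like football ."], history = ["hello how are you ?"], reply = "i am fine thanks ."
#
# input_ids:      <bos> i like football . <speaker2> hello how are you ? <speaker1> i am fine thanks . <eos>
# token_type_ids: one speaker id per position, alternating segment by segment
# mc_token_ids:   index of the last token, where the multiple-choice head reads its hidden state
# lm_labels:      -1 everywhere except the gold reply tokens (and only when lm_labels=True)
```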


+def pad_and_tensorize(batch_dict, padding):
+    """ Pad the batch_dict."""
+    tensors = []
+    for name in MODEL_INPUTS:
+        if name not in PADDED_INPUTS:
+            tensors.append(torch.tensor(batch_dict[name]))
+            continue
+        entry = batch_dict[name]
+        pad_id = padding if name != "lm_labels" else -1
+        padded = pad_sequence([torch.tensor(seq) for x in entry for seq in x], batch_first=True,
+                              padding_value=pad_id)
+        bs, n_candidates = len(entry), len(entry[0])
+        tensors.append(padded.view(bs, n_candidates, -1))
+    return tensors
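
A quick shape check for the helper above, with toy values (pad id 0 for illustration):

```python
batch_dict = {
    "input_ids":      [[[1, 2, 3], [4, 5]], [[6], [7, 8]]],  # 2 examples x 2 candidates, ragged lengths
    "mc_token_ids":   [[2, 1], [0, 1]],
    "lm_labels":      [[[-1, -1, 3], [-1, 5]], [[-1], [-1, 8]]],
    "mc_labels":      [1, 1],
    "token_type_ids": [[[9, 9, 9], [9, 9]], [[9], [9, 9]]],
}
tensors = pad_and_tensorize(batch_dict, padding=0)
# tensors[0] (input_ids) has shape (2, 2, 3): (batch, n_candidates, padded seq len);
# lm_labels are padded with -1, everything else with the pad id.
```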

+class ChatDataset(torch.utils.data.Dataset):
+
+    def __init__(self, fields, pad_id):
+        self.fields = fields
+        self.pad_id = pad_id
+
+    def __getitem__(self, item) -> dict:
+        return {f: self.fields[f][item] for f in MODEL_INPUTS}
+
+    def collate_fn(self, examples):
+        batch_dict = defaultdict(list)
+        for input_name in MODEL_INPUTS:
+            for e in examples:
+                batch_dict[input_name].append(e[input_name])
+        tensors = pad_and_tensorize(batch_dict, padding=self.pad_id)
+        return tensors
+
+    def __len__(self):
+        return len(self.fields['input_ids'])


def get_data_loaders(args, tokenizer):
""" Prepare the dataset for training and evaluation """
personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

logger.info("Build inputs and labels")
+    datasets: dict = make_data_lists(args, personachat, tokenizer)
+    pad_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])
+    train_dataset = ChatDataset(datasets['train'], pad_id)
+    valid_dataset = ChatDataset(datasets['valid'], pad_id)

logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed),
collate_fn=train_dataset.collate_fn)
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False,
collate_fn=valid_dataset.collate_fn)
return train_loader, valid_loader, train_sampler, valid_sampler


+def make_data_lists(args, personachat, tokenizer):
datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
+    print(personachat.keys())
for dataset_name, dataset in personachat.items():
num_candidates = len(dataset[0]["utterances"][0]["candidates"])
if args.num_candidates > 0 and dataset_name == 'train':
@@ -83,38 +135,19 @@ def get_data_loaders(args, tokenizer):
persona = dialog["personality"].copy()
for _ in range(args.personality_permutations):
for utterance in dialog["utterances"]:
-                    history = utterance["history"][-(2*args.max_history+1):]
+                    candidate_instances = defaultdict(list)
+                    history = utterance["history"][-(2 * args.max_history + 1):]
for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
lm_labels = bool(j == num_candidates-1)
instance = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
for input_name, input_array in instance.items():
+                            #print(input_name)
+                            #print(dataset_name)
-                            datasets[dataset_name][input_name].append(input_array)
+                            candidate_instances[input_name].append(input_array)
+                    for k in candidate_instances.keys():
+                        datasets[dataset_name][k].append(candidate_instances[k])
datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
datasets[dataset_name]["n_candidates"] = num_candidates
persona = [persona[-1]] + persona[:-1] # permuted personalities

logger.info("Pad inputs and convert to Tensor")
tensor_datasets = {"train": [], "valid": []}
for dataset_name, dataset in datasets.items():
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]))
for input_name in MODEL_INPUTS:
tensor = torch.tensor(dataset[input_name])
if input_name != "mc_labels":
tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
tensor_datasets[dataset_name].append(tensor)

logger.info("Build train and validation dataloaders")
train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)

logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
return train_loader, valid_loader, train_sampler, valid_sampler
+    return datasets
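
The net effect of the refactor: each padded field is now nested per dialog turn with one inner list per candidate, which is what `ChatDataset` and `pad_and_tensorize` above consume, and padding happens per batch in `collate_fn` rather than once over the whole corpus (schematic, not repo code):

```python
# datasets["train"]["input_ids"][turn][candidate] -> list of token ids
# datasets["train"]["mc_labels"][turn]            -> index of the gold candidate (num_candidates - 1)
```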


def train():
@@ -199,12 +232,10 @@ def update(engine, batch):
# Evaluation function and evaluator (evaluator output is the input of the metrics)
def inference(engine, batch):
model.eval()
logger.info("Passing inference")
with torch.no_grad():
batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
-            ##messes around comment printing token
-            ##logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
+            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
# if we dont send labels to model, it doesnt return losses
lm_logits, mc_logits, *_ = model(
input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
@@ -232,7 +263,7 @@

# Prepare metrics - note how we compute distributed metrics
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])),
metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
"accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
"average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
@@ -251,7 +282,7 @@ def inference(engine, batch):

tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
-    tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer)), event_name=Events.EPOCH_COMPLETED)
+    tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" takes care of distributed encapsulation
Expand All @@ -265,8 +296,8 @@ def inference(engine, batch):

# On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
if args.local_rank in [-1, 0] and args.n_epochs > 0:
-        os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
+        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
tb_logger.close()

if __name__ == "__main__":
-    train()
+    train()