# utils.py (forked from huggingface/transfer-learning-conv-ai)
# Copyright (c) 2019-present, HuggingFace Inc.
# All rights reserved. This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from datetime import datetime
import json
import logging
import multiprocessing as mp
import os
import tarfile
import tempfile
import socket
import torch
from pytorch_transformers import cached_path

PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"

logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
mp.log_to_stderr(level=logging.INFO)
mp_logger = mp.get_logger()
mp_logger.setLevel(level=logging.INFO)

def download_pretrained_model():
""" Download and extract finetuned model from S3 """
resolved_archive_file = cached_path(HF_FINETUNED_MODEL)
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
return tempdir
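
# A minimal usage sketch for download_pretrained_model(). Assumption: the archive
# contains a checkpoint loadable via pytorch_transformers' from_pretrained(); the
# exact model class (OpenAIGPTLMHeadModel) is an assumption, not confirmed by this file.
#
#     from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
#     checkpoint_dir = download_pretrained_model()
#     model = OpenAIGPTLMHeadModel.from_pretrained(checkpoint_dir)
#     tokenizer = OpenAIGPTTokenizer.from_pretrained(checkpoint_dir)
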
def worker_tokenize(args_list):
    """Target function for multiprocessing text encoding.

    The object and the tokenizer are packed into a single list because
    worker_tokenize() calls itself recursively and has to pass the constant
    tokenizer along with each sub-object.
    IMPORTANT: This function must be defined at module level (outside of
    get_dataset()) to avoid the multiprocessing error
    "AttributeError: Can't pickle local object 'get_dataset.<locals>.worker_tokenize'".

    Args:
        args_list: [obj, tokenizer], packed into one list so both values can be
            handed on to the recursive calls.
    """
    obj = args_list[0]
    tokenizer = args_list[1]
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        worker_tokenize._dict_key_calls += 1
        mp_logger.debug('Encoding {}. obj.keys() = {}, len(obj) = {}'.format(
            worker_tokenize._dict_key_calls, obj.keys(), len(obj)))
        return dict((n, worker_tokenize([o, tokenizer])) for n, o in obj.items())
    return list(worker_tokenize([o, tokenizer]) for o in obj)


# Per-process debug counter; with multiprocessing, each worker holds its own copy.
worker_tokenize._dict_key_calls = 0
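
# Illustration of the [obj, tokenizer] packing that worker_tokenize() expects.
# Here `tokenizer` stands for any object with tokenize() and convert_tokens_to_ids(),
# e.g. a pytorch_transformers tokenizer; the sample inputs are hypothetical.
#
#     ids = worker_tokenize(["hello world", tokenizer])     # flat list of token ids
#     nested = worker_tokenize([{"history": ["hi", "there"]}, tokenizer])
#     # -> {"history": [[...ids...], [...ids...]]}
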
def get_dataset(tokenizer, dataset_path, dataset_cache=None):
    """ Get PERSONACHAT from S3 """
    dataset_path = dataset_path or PERSONACHAT_URL
    if dataset_cache:
        dataset_cache = dataset_cache + '_' + type(tokenizer).__name__  # To avoid using GPT cache for GPT-2 and vice versa
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        dataset = torch.load(dataset_cache)
    else:
        logger.info("Download dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.loads(f.read())
        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                tokenize.dict_key_calls += 1
                logger.debug('Encoding {}. obj.keys() = {}, len(obj) = {}'.format(
                    tokenize.dict_key_calls, obj.keys(), len(obj)))
                return dict((n, tokenize(o)) for n, o in obj.items())
            min_samples_for_multiprocessing = 100
            if len(obj) > min_samples_for_multiprocessing:
                logger.debug('Encoding very long list of len(obj) = {}'.format(len(obj)))
                logger.debug('Encoding list with multiprocessing...')
                # functools.partial does not work here because the tokenizer has to be
                # handed recursively together with obj to worker_tokenize again. Packing
                # both values into a single list avoids having to unpack dict outputs
                # and **kwargs inputs with the star operator.
                with mp.Pool(processes=max(1, mp.cpu_count() - 1)) as pool:
                    results = pool.map(func=worker_tokenize,
                                       iterable=[[o, tokenizer] for o in obj])
                return results
            else:
                logger.debug('Encoding list of len(obj) = {}'.format(len(obj)))
                return list(tokenize(o) for o in obj)

        tokenize.dict_key_calls = 0
        dataset = tokenize(dataset)
        if dataset_cache:
            torch.save(dataset, dataset_cache)
    return dataset
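
# Example call (a sketch; the 'gpt2' weights are an assumption about the intended setup):
#
#     from pytorch_transformers import GPT2Tokenizer
#     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#     dataset = get_dataset(tokenizer, None, dataset_cache='./dataset_cache')
#
# The tokenizer class name is appended to the cache file name, so this example
# would write './dataset_cache_GPT2Tokenizer'.
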
def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
    """ Get personalities from PERSONACHAT """
    dataset_path = dataset_path or PERSONACHAT_URL
    if dataset_cache:
        dataset_cache = dataset_cache + '_' + type(tokenizer).__name__  # To avoid using GPT cache for GPT-2 and vice versa
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        personachat = torch.load(dataset_cache)
    else:
        logger.info("Download PERSONACHAT dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            personachat = json.loads(f.read())
        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, tokenize(o)) for n, o in obj.items())
            return list(tokenize(o) for o in obj)

        personachat = tokenize(personachat)
        # torch.save(personachat, dataset_cache)

    logger.info("Filter personalities")
    personalities = []
    for dataset in personachat.values():
        for dialog in dataset:
            personalities.append(dialog["personality"])
    logger.info("Gathered {} personalities".format(len(personalities)))
    return personalities
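
# Shape of the PERSONACHAT JSON these functions walk (keys as in the public
# dataset; listed from memory, so treat this as a sketch rather than a schema):
#
#     {"train": [{"personality": [...sentences...],
#                 "utterances": [{"candidates": [...], "history": [...]}, ...]},
#                ...],
#      "valid": [...]}
#
# get_dataset_personalities() therefore returns one tokenized personality list per dialog.
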
class AttrDict(dict):
    """Dictionary whose entries are also accessible as attributes (d.key == d['key'])."""
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
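
# AttrDict in action (hypothetical values):
#
#     args = AttrDict({'model_checkpoint': 'gpt2', 'lr': 6.25e-5})
#     args.lr                 # 6.25e-5, same as args['lr']
#     args.device = 'cuda'    # attribute writes land in the dict too
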
def make_logdir(model_name: str):
"""Create unique path to save results and checkpoints, e.g. runs/Sep22_19-45-59_gpu-7_gpt2"""
# Code copied from ignite repo
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
logdir = os.path.join(
'runs', current_time + '_' + socket.gethostname() + '_' + model_name)
return logdir
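
# Example (hostname and timestamp vary by machine), matching the docstring above:
#
#     make_logdir('gpt2')   # -> 'runs/Sep22_19-45-59_gpu-7_gpt2'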