forked from huggingface/transfer-learning-conv-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
68 lines (56 loc) · 2.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) 2019-present, HuggingFace Inc.
# All rights reserved. This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from datetime import datetime
import json
import logging
import os
import tarfile
import tempfile
import socket
import torch
from transformers import cached_path
# Location of the raw PERSONACHAT JSON file; get_dataset() falls back to it
# when its dataset_path argument is falsy.
PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
# Gzipped tar archive that download_pretrained_model() fetches and extracts.
HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"
# Module-level logger. It is named after this file's path (__file__), not the
# dotted module name, so any logging configuration has to target that name.
logger = logging.getLogger(__file__)
def _extract_archive_safely(archive, destination):
    """Extract an open ``tarfile.TarFile`` into ``destination`` without letting members escape it.

    Args:
        archive: An already opened ``tarfile.TarFile``.
        destination: Directory that must contain every extracted member.

    Raises:
        tarfile.TarError: If a member would be written outside ``destination``.
    """
    if hasattr(tarfile, "data_filter"):
        # Interpreters that ship extraction filters can reject absolute paths,
        # ".." components and links pointing outside the destination for us.
        archive.extractall(destination, filter="data")
        return

    # Older interpreters: check every member path ourselves before extracting anything.
    root = os.path.realpath(destination)
    for member in archive.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if target != root and not target.startswith(root + os.sep):
            raise tarfile.TarError(
                "refusing to extract {!r} outside of {}".format(member.name, root))
    archive.extractall(destination)


def download_pretrained_model():
    """Download the finetuned model archive and extract it into a fresh temporary directory.

    The archive at ``HF_FINETUNED_MODEL`` is resolved through ``cached_path`` and
    unpacked as a gzipped tarball. Because the archive comes from a remote
    location, its members are validated so that none of them can be written
    outside of the temporary directory.

    Returns:
        Path of the temporary directory holding the extracted files. The
        directory is created with ``tempfile.mkdtemp`` and is not removed by
        this function; cleaning it up is left to the caller.

    Raises:
        tarfile.TarError: If the archive is invalid or contains a member that
            would be extracted outside of the temporary directory.
    """
    resolved_archive_file = cached_path(HF_FINETUNED_MODEL)
    tempdir = tempfile.mkdtemp()
    logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
    with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        _extract_archive_safely(archive, tempdir)
    return tempdir
def get_dataset(tokenizer, dataset_path, dataset_cache):
    """Get the tokenized PERSONACHAT dataset from a local cache, or download and tokenize it.

    Args:
        tokenizer: Object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        dataset_path: Path or URL of the raw JSON dataset. ``PERSONACHAT_URL``
            is used when this is falsy.
        dataset_cache: Base path of the tokenized cache file, or ``None`` to
            disable caching entirely. The tokenizer class name is appended to
            it so caches built by different tokenizer classes never collide.

    Returns:
        The dataset with the nested dict/list layout of the JSON file, where
        every string has been replaced by its list of token ids.
    """
    dataset_path = dataset_path or PERSONACHAT_URL
    if dataset_cache is not None:
        # To avoid using GPT cache for GPT-2 and vice-versa
        dataset_cache = dataset_cache + '_' + type(tokenizer).__name__

    if dataset_cache is not None and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        # NOTE(review): torch.load is pickle-based; only point dataset_cache at
        # files you trust.
        return torch.load(dataset_cache)

    logger.info("Download dataset from %s", dataset_path)
    personachat_file = cached_path(dataset_path)
    with open(personachat_file, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    logger.info("Tokenize and encode the dataset")

    def tokenize(obj):
        # Recursively walk the JSON structure: strings become token-id lists,
        # dicts keep their keys, any other iterable becomes a list.
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            return {n: tokenize(o) for n, o in obj.items()}
        return [tokenize(o) for o in obj]

    dataset = tokenize(dataset)
    if dataset_cache is not None:
        # Persist the tokenized data so later runs can skip download + tokenization.
        torch.save(dataset, dataset_cache)
    return dataset
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        # Populate the mapping exactly as ``dict(*args, **kwargs)`` would.
        super().__init__(*args, **kwargs)
        # Make the instance namespace *be* the mapping, so ``d.key`` and
        # ``d["key"]`` always address the same storage.
        self.__dict__ = self
def make_logdir(model_name: str):
    """Build a unique directory path for results and checkpoints.

    The path has the form ``runs/<timestamp>_<hostname>_<model_name>``,
    e.g. ``runs/Sep22_19-45-59_gpu-7_gpt2``. Only the path string is
    produced here; nothing is created on disk.
    """
    # Naming scheme borrowed from the ignite repo.
    timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
    hostname = socket.gethostname()
    run_name = '_'.join([timestamp, hostname, model_name])
    return os.path.join('runs', run_name)