This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5c0d89e: Merge pull request #620 from rsepassi/push

v1.5.3

lukaszkaiser authored Feb 26, 2018
2 parents: b929e30 + 3ab6147
Showing 36 changed files with 2,292 additions and 1,148 deletions.
docs/cloud_mlengine.md (4 changes: 2 additions & 2 deletions)

@@ -41,8 +41,8 @@ principle work just fine. Contributions/testers welcome.
Launching on Cloud ML Engine works with `--t2t_usr_dir` as well as long as the
directory is fully self-contained (i.e. the imports only refer to other modules
in the directory). If there are additional PyPI dependencies that you need, you
- can include a `setup.py` file in your directory (ensure that it uses
- `setuptools.find_packages`).
+ can include a `requirements.txt` file in the directory specified by
+ `t2t_usr_dir`.

# Hyperparameter Tuning

docs/new_problem.md (10 changes: 5 additions & 5 deletions)

@@ -65,10 +65,10 @@ class PoetryLines(text_problems.Text2TextProblem):
# 10% evaluation data
return [{
"split": problem.DatasetSplit.TRAIN,
- "shards": 90,
+ "shards": 9,
}, {
"split": problem.DatasetSplit.EVAL,
- "shards": 10,
+ "shards": 1,
}]

def generate_samples(self, data_dir, tmp_dir, dataset_split):
@@ -133,7 +133,7 @@ pre-existing "training" and "evaluation" sets. If we did, we'd set
split.

The `dataset_splits` method determines the fraction that goes to each split. The
- training data will be generated into 90 files and the evaluation data into 10.
+ training data will be generated into 9 files and the evaluation data into 1.
90% of the data will be for training. 10% of the data will be for evaluation.

```python
@@ -148,10 +148,10 @@ training data will be generated into 90 files and the evaluation data into 10.
# 10% evaluation data
return [{
"split": problem.DatasetSplit.TRAIN,
- "shards": 90,
+ "shards": 9,
}, {
"split": problem.DatasetSplit.EVAL,
- "shards": 10,
+ "shards": 1,
}]
```

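As an aside (not part of this commit), the arithmetic behind the new 9/1 split can be sketched in a few lines; the shard filename pattern below is an assumption for illustration, not necessarily the exact pattern Tensor2Tensor emits.

```python
# Hedged sketch: map shard counts to split fractions and hypothetical
# shard filenames for a problem named "poetry_lines".
splits = [{"split": "train", "shards": 9}, {"split": "eval", "shards": 1}]
total_shards = sum(s["shards"] for s in splits)

for s in splits:
  fraction = 100.0 * s["shards"] / total_shards
  filenames = ["poetry_lines-%s-%05d-of-%05d" % (s["split"], i, s["shards"])
               for i in range(s["shards"])]
  print("%s: %.0f%% of the data across %d files" %
        (s["split"], fraction, len(filenames)))
```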
setup.py (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
- version='1.5.2',
+ version='1.5.3',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
tensor2tensor/bin/t2t_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -65,7 +65,7 @@
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
- flags.DEFINE_integer("eval_steps", 10000,
+ flags.DEFINE_integer("eval_steps", 100,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
tensor2tensor/data_generators/generator_utils.py (21 changes: 13 additions & 8 deletions)

@@ -165,6 +165,8 @@ def generate_files(generator, output_filenames, max_cases=None):
for writer in writers:
writer.close()

+ tf.logging.info("Generated %s Examples", counter)


def download_report_hook(count, block_size, total_size):
"""Report hook for download progress.
@@ -198,19 +200,22 @@ def maybe_download(directory, filename, uri):
"""
if not tf.gfile.Exists(directory):
tf.logging.info("Creating directory %s" % directory)
- os.mkdir(directory)
+ tf.gfile.MakeDirs(directory)
filepath = os.path.join(directory, filename)
if not tf.gfile.Exists(filepath):
tf.logging.info("Downloading %s to %s" % (uri, filepath))
try:
tf.gfile.Copy(uri, filepath)
except tf.errors.UnimplementedError:
- inprogress_filepath = filepath + ".incomplete"
- inprogress_filepath, _ = urllib.urlretrieve(
- uri, inprogress_filepath, reporthook=download_report_hook)
- # Print newline to clear the carriage return from the download progress
- print()
- tf.gfile.Rename(inprogress_filepath, filepath)
+ if uri.startswith("http"):
+ inprogress_filepath = filepath + ".incomplete"
+ inprogress_filepath, _ = urllib.urlretrieve(
+ uri, inprogress_filepath, reporthook=download_report_hook)
+ # Print newline to clear the carriage return from the download progress
+ print()
+ tf.gfile.Rename(inprogress_filepath, filepath)
+ else:
+ raise ValueError("Unrecognized URI: " + filepath)
statinfo = os.stat(filepath)
tf.logging.info("Successfully downloaded %s, %s bytes." %
(filename, statinfo.st_size))
@@ -232,7 +237,7 @@ def maybe_download_from_drive(directory, filename, url):
"""
if not tf.gfile.Exists(directory):
tf.logging.info("Creating directory %s" % directory)
- os.mkdir(directory)
+ tf.gfile.MakeDirs(directory)
filepath = os.path.join(directory, filename)
confirm_token = None
if tf.gfile.Exists(filepath):
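As a usage note (not part of the commit), the signature shown above suggests a call along these lines; the directory, filename, and URL are placeholders, and the commented behavior summarizes the new code path.

```python
from tensor2tensor.data_generators import generator_utils

# Hedged sketch of calling maybe_download(directory, filename, uri).
# With this commit: the directory is created via tf.gfile.MakeDirs,
# tf.gfile.Copy is tried first, an http(s) URI falls back to urllib with a
# progress hook on UnimplementedError, and any other URI raises ValueError.
local_path = generator_utils.maybe_download(
    "/tmp/t2t_tmp",                            # placeholder directory
    "example_corpus.txt",                      # local filename
    "http://example.com/example_corpus.txt")   # placeholder URI
```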
tensor2tensor/data_generators/ice_parsing.py (20 changes: 9 additions & 11 deletions)

@@ -32,14 +32,10 @@
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
- from tensor2tensor.data_generators import translate
+ from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


- # End-of-sentence marker.
- EOS = text_encoder.EOS_ID


def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
source_vocab_size, target_vocab_size):
"""Generate source and target data from a single file."""
@@ -51,17 +47,18 @@ def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
data_dir, tmp_dir, filename, 1,
prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
pair_filepath = os.path.join(tmp_dir, filename)
- return translate.tabbed_generator(pair_filepath, source_vocab, target_vocab,
- EOS)
+ return text_problems.text2text_generate_encoded(
+ text_problems.text2text_txt_tab_iterator(pair_filepath), source_vocab,
+ target_vocab)


def tabbed_parsing_character_generator(tmp_dir, train):
"""Generate source and target data from a single file."""
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_{0}.pairs".format("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename)
- return translate.tabbed_generator(pair_filepath, character_vocab,
- character_vocab, EOS)
+ return text_problems.text2text_generate_encoded(
+ text_problems.text2text_txt_tab_iterator(pair_filepath), character_vocab)


@registry.register_problem
@@ -114,8 +111,9 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
def hparams(self, defaults, unused_model_hparams):
p = defaults
source_vocab_size = self._encoders["inputs"].vocab_size
- p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
- source_vocab_size)}
+ p.input_modality = {
+ "inputs": (registry.Modalities.SYMBOL, source_vocab_size)
+ }
p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size)
p.input_space_id = self.input_space_id
p.target_space_id = self.target_space_id
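For context (not part of the commit), the tab-separated pairs file consumed here holds one `source<TAB>target` pair per line, and the two new helpers compose roughly as below; the file path is a placeholder and the byte-level encoder mirrors the character generator above.

```python
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems

# Hedged sketch: ByteTextEncoder needs no vocab file, and the path stands in
# for os.path.join(tmp_dir, "parsing_train.pairs").
vocab = text_encoder.ByteTextEncoder()
samples = text_problems.text2text_txt_tab_iterator("/tmp/parsing_train.pairs")
for example in text_problems.text2text_generate_encoded(samples, vocab):
  # Each example is {"inputs": [...ids..., EOS], "targets": [...ids..., EOS]}.
  print(example)
```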
tensor2tensor/data_generators/lm1b.py (2 changes: 1 addition & 1 deletion)

@@ -151,7 +151,7 @@ class LanguagemodelLm1b32k(text_problems.Text2TextProblem):
"""A language model on the 1B words corpus."""

@property
- def vocab_name(self):
+ def vocab_filename(self):
return "vocab.lm1b.en.%d" % self.approx_vocab_size

@property
tensor2tensor/data_generators/problem.py (7 changes: 6 additions & 1 deletion)

@@ -162,7 +162,7 @@ class Problem(object):
data_dir. Vocab files are newline-separated files with each line
containing a token. The standard convention for the filename is to
set it to be
- ${Problem.vocab_name}.${Problem.targeted_vocab_size}
+ ${Problem.vocab_filename}.${Problem.targeted_vocab_size}
- Downloads and other files can be written to tmp_dir
- If you have a training and dev generator, you can generate the
training and dev datasets with
@@ -721,6 +721,11 @@ def define_shapes(example):
dataset = dataset.repeat()
data_files = tf.contrib.slim.parallel_reader.get_data_files(
self.filepattern(data_dir, mode))
+ # In continuous_train_and_eval when switching between train and
+ # eval, this input_fn method gets called multiple times and it
+ # would give you the exact same samples from the last call
+ # (because the Graph seed is set). So this skip gives you some
+ # shuffling.
dataset = skip_random_fraction(dataset, data_files[0])

dataset = dataset.map(
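As an aside (not from the commit), the idea behind the added comment can be shown with a conceptual sketch; `skip_some_records` is a hypothetical stand-in, not Tensor2Tensor's actual `skip_random_fraction`.

```python
import random

import tensorflow as tf


def skip_some_records(dataset, max_skip=1000):
  # Conceptual sketch: the offset is drawn in Python each time the input_fn
  # rebuilds the dataset, so repeated calls (even with a fixed graph seed)
  # start from different records, giving the light shuffling described above.
  return dataset.skip(random.randint(0, max_skip))


# Hypothetical usage on any tf.data.Dataset:
dataset = skip_some_records(tf.data.TFRecordDataset(["/tmp/data.tfrecord"]))
```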
tensor2tensor/data_generators/text_encoder.py (19 changes: 9 additions & 10 deletions)

@@ -282,11 +282,12 @@ def _init_vocab_from_file(self, filename):
Args:
filename: The file to load vocabulary from.
"""
+ with tf.gfile.Open(filename) as f:
+ tokens = [token.strip() for token in f.readlines()]

def token_gen():
- with tf.gfile.Open(filename) as f:
- for line in f:
- token = line.strip()
- yield token
+ for token in tokens:
+ yield token

self._init_vocab(token_gen(), add_reserved_tokens=False)

@@ -379,7 +380,7 @@ def match(m):
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
- return u"\u3013"
+ return u"\u3013"  # Unicode for undefined character.

trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token
return _UNESCAPE_REGEX.sub(match, trimmed)
@@ -827,11 +828,9 @@ def _load_from_file_object(self, f):
self._init_alphabet_from_tokens(subtoken_strings)

def _load_from_file(self, filename):
- """Load from a file.
- Args:
- filename: Filename to load vocabulary from
- """
+ """Load from a vocab file."""
if not tf.gfile.Exists(filename):
raise ValueError("File %s not found" % filename)
with tf.gfile.Open(filename) as f:
self._load_from_file_object(f)

tensor2tensor/data_generators/text_problems.py (39 changes: 23 additions & 16 deletions)

@@ -222,15 +222,8 @@ def _maybe_pack_examples(self, generator):
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)
- for sample in generator:
- targets = encoder.encode(sample["targets"])
- targets.append(text_encoder.EOS_ID)
- encoded_sample = {"targets": targets}
- if self.has_inputs:
- inputs = encoder.encode(sample["inputs"])
- inputs.append(text_encoder.EOS_ID)
- encoded_sample["inputs"] = inputs
- yield encoded_sample
+ return text2text_generate_encoded(generator, encoder,
+ has_inputs=self.has_inputs)

@property
def batch_size_means_tokens(self):
@@ -244,15 +237,15 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
problem.DatasetSplit.TEST: self.test_filepaths,
}

- split_paths = dict([(split["split"], filepath_fns[split["split"]](
+ split_paths = [(split["split"], filepath_fns[split["split"]](
data_dir, split["shards"], shuffled=False))
- for split in self.dataset_splits])
+ for split in self.dataset_splits]
all_paths = []
- for paths in split_paths.values():
+ for _, paths in split_paths:
all_paths.extend(paths)

if self.is_generate_per_split:
- for split, paths in split_paths.items():
+ for split, paths in split_paths:
generator_utils.generate_files(
self._maybe_pack_examples(
self.generate_encoded_samples(data_dir, tmp_dir, split)), paths)
@@ -418,8 +411,7 @@ def example_reading_spec(self):
def txt_line_iterator(txt_path):
"""Iterate through lines of file."""
with tf.gfile.Open(txt_path) as f:
- readline = lambda: f.readline()
- for line in iter(readline, ""):
+ for line in f:
yield line.strip()


@@ -472,11 +464,26 @@ def text2text_txt_tab_iterator(txt_path):
"""
for line in txt_line_iterator(txt_path):
if line and "\t" in line:
- parts = line.split("\t")
+ parts = line.split("\t", 1)
inputs, targets = parts[:2]
yield {"inputs": inputs.strip(), "targets": targets.strip()}


+ def text2text_generate_encoded(sample_generator,
+ vocab,
+ targets_vocab=None,
+ has_inputs=True):
+ """Encode Text2Text samples from the generator with the vocab."""
+ targets_vocab = targets_vocab or vocab
+ for sample in sample_generator:
+ if has_inputs:
+ sample["inputs"] = vocab.encode(sample["inputs"])
+ sample["inputs"].append(text_encoder.EOS_ID)
+ sample["targets"] = targets_vocab.encode(sample["targets"])
+ sample["targets"].append(text_encoder.EOS_ID)
+ yield sample


@registry.register_problem
class Text2textTmpdir(Text2TextProblem):
"""Allows training a Text2TextProblem without defining a subclass.