Merge pull request OpenNMT#679 from OpenNMT/prep
Prep
srush authored Apr 11, 2018
2 parents 1817ed1 + bad9e8e commit fb82df7
Showing 7 changed files with 65 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/source/options/preprocess.md
@@ -33,10 +33,10 @@ Optimal value should be multiples of 64 bytes.

### **Vocab**:
* **-src_vocab []**
- Path to an existing source vocabulary
+ Path to an existing source vocabulary. Format: one word per line.

* **-tgt_vocab []**
- Path to an existing target vocabulary
+ Path to an existing target vocabulary. Format: one word per line.

* **-features_vocabs_prefix []**
Path prefix to existing features vocabularies
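On the format mentioned above for -src_vocab and -tgt_vocab: these are plain text files with one token per line, and the loader added in this commit keeps only the first whitespace-separated field of each line, so a trailing count column would simply be ignored. A minimal sketch of writing and reading such a file (the file name is only an example, not part of the codebase):

    import codecs

    # Write a tiny vocabulary file: one token per line.
    with codecs.open('example_vocab.txt', 'w', 'utf-8') as f:
        f.write('the\nof\nand\nto\n')

    # Read it back the way build_vocab in onmt/io/IO.py does below:
    # strip each line and keep its first whitespace-separated field.
    vocab = set()
    with codecs.open('example_vocab.txt', 'r', 'utf-8') as f:
        for line in f:
            if line.strip():
                vocab.add(line.strip().split()[0])

    print(sorted(vocab))  # ['and', 'of', 'the', 'to']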
3 changes: 2 additions & 1 deletion docs/source/options/train.md
@@ -19,7 +19,8 @@ Use a shared weight matrix for the input and output word embeddings in the
decoder.

* **-share_embeddings []**
- Share the word embeddings between encoder and decoder. It requires using `-share_vocab` during pre-processing.
+ Share the word embeddings between encoder and decoder. Need to use shared
+ dictionary for this option.

* **-position_encoding []**
Use a sinusoid to mark relative word positions. Necessary for non-RNN style models.
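For context on the -position_encoding option above: this is the sinusoidal position encoding described for the Transformer, where even dimensions use a sine and odd dimensions a cosine of the position, with wavelengths forming a geometric progression. A rough standalone sketch of that table follows; the function name and code are illustrative, not taken from the repository:

    import math

    def sinusoidal_position_encoding(max_len, dim):
        """Build a max_len x dim table: sin on even columns, cos on odd
        columns, with geometrically decreasing frequencies."""
        table = [[0.0] * dim for _ in range(max_len)]
        for pos in range(max_len):
            for i in range(0, dim, 2):
                angle = pos / (10000.0 ** (float(i) / dim))
                table[pos][i] = math.sin(angle)
                if i + 1 < dim:
                    table[pos][i + 1] = math.cos(angle)
        return table

    # Example: encodings for 4 positions of an 8-dimensional model.
    pe = sinusoidal_position_encoding(4, 8)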
7 changes: 7 additions & 0 deletions docs/source/options/translate.md
@@ -65,6 +65,13 @@ Google NMT length penalty parameter (higher = longer generation)
* **-beta []**
Coverage penalty parameter

* **-block_ngram_repeat []**
Block repetition of ngrams during decoding.

* **-ignore_when_blocking []**
Ignore these strings when blocking repeats. You want to block sentence
delimiters.

* **-replace_unk []**
Replace the generated UNK tokens with the source token that had highest
attention weight. If phrase_table is provided, it will lookup the identified
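One plausible reading of the two new blocking options above: during decoding a hypothesis is penalized when its newest n-gram already occurred earlier in that same hypothesis, unless the n-gram contains one of the tokens listed in -ignore_when_blocking (typically sentence delimiters). A minimal standalone sketch of that check; the helper name is hypothetical and this is not the actual Beam code:

    def repeats_blocked_ngram(tokens, n, exclusion=frozenset()):
        """True if the last n tokens form an n-gram that already occurred
        earlier in tokens; n-grams containing an excluded token are never
        blocked."""
        if n <= 0 or len(tokens) < n + 1:
            return False
        last = tuple(tokens[-n:])
        if exclusion & set(last):
            return False
        earlier = set(tuple(tokens[i:i + n]) for i in range(len(tokens) - n))
        return last in earlier

    # Example with 2-gram blocking: "the cat" repeats, so the hypothesis
    # would be penalized; with 3-gram blocking it would not.
    hyp = ['the', 'cat', 'sat', 'on', 'the', 'cat']
    assert repeats_blocked_ngram(hyp, 2)
    assert not repeats_blocked_ngram(hyp, 3)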
34 changes: 32 additions & 2 deletions onmt/io/IO.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import os
from collections import Counter, defaultdict, OrderedDict
from itertools import count

@@ -226,17 +227,19 @@ def _build_field_vocab(field, counter, **kwargs):


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
-                src_vocab_size, src_words_min_frequency,
-                tgt_vocab_size, tgt_words_min_frequency):
+                src_vocab_path, src_vocab_size, src_words_min_frequency,
+                tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
"""
Args:
train_dataset_files: a list of train dataset pt file.
fields (dict): fields to build vocab for.
data_type: "text", "img" or "audio"?
share_vocab(bool): share source and target vocabulary?
src_vocab_path(string): Path to src vocabulary file.
src_vocab_size(int): size of the source vocabulary.
src_words_min_frequency(int): the minimum frequency needed to
include a source word in the vocabulary.
tgt_vocab_path(string): Path to tgt vocabulary file.
tgt_vocab_size(int): size of the target vocabulary.
tgt_words_min_frequency(int): the minimum frequency needed to
include a target word in the vocabulary.
@@ -248,6 +251,29 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
for k in fields:
counter[k] = Counter()

# Load vocabulary
src_vocab = None
if len(src_vocab_path) > 0:
src_vocab = set([])
print('Loading source vocab from %s' % src_vocab_path)
assert os.path.exists(src_vocab_path), \
'src vocab %s not found!' % src_vocab_path
with open(src_vocab_path) as f:
for line in f:
word = line.strip().split()[0]
src_vocab.add(word)

tgt_vocab = None
if len(tgt_vocab_path) > 0:
tgt_vocab = set([])
print('Loading target vocab from %s' % tgt_vocab_path)
assert os.path.exists(tgt_vocab_path), \
'tgt vocab %s not found!' % tgt_vocab_path
with open(tgt_vocab_path) as f:
for line in f:
word = line.strip().split()[0]
tgt_vocab.add(word)

for path in train_dataset_files:
dataset = torch.load(path)
print(" * reloading %s." % path)
@@ -256,6 +282,10 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
val = getattr(ex, k, None)
if val is not None and not fields[k].sequential:
val = [val]
elif k == 'src' and src_vocab:
val = [item for item in val if item in src_vocab]
elif k == 'tgt' and tgt_vocab:
val = [item for item in val if item in tgt_vocab]
counter[k].update(val)

_build_field_vocab(fields["tgt"], counter["tgt"],
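The net effect of the src_vocab / tgt_vocab sets loaded above is that out-of-list tokens never reach the frequency counter, so they cannot end up in the final field vocabulary. A small self-contained sketch of that filtering step, using made-up example sentences rather than the dataset objects:

    from collections import Counter

    src_vocab = {'the', 'cat', 'sat'}   # as if loaded from -src_vocab
    counter = Counter()

    examples = [['the', 'cat', 'sat'],
                ['the', 'dog', 'sat']]

    for tokens in examples:
        if src_vocab:
            tokens = [item for item in tokens if item in src_vocab]
        counter.update(tokens)

    # 'dog' was filtered out, so it can never enter the vocabulary.
    print(counter)   # Counter({'the': 2, 'sat': 2, 'cat': 1}) (order may vary)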
10 changes: 6 additions & 4 deletions onmt/opts.py
@@ -152,10 +152,12 @@ def preprocess_opts(parser):
# Dictionary options, for text corpus

group = parser.add_argument_group('Vocab')
-    group.add_argument('-src_vocab',
-                       help="Path to an existing source vocabulary")
-    group.add_argument('-tgt_vocab',
-                       help="Path to an existing target vocabulary")
+    group.add_argument('-src_vocab', default="",
+                       help="""Path to an existing source vocabulary. Format:
+                       one word per line.""")
+    group.add_argument('-tgt_vocab', default="",
+                       help="""Path to an existing target vocabulary. Format:
+                       one word per line.""")
group.add_argument('-features_vocabs_prefix', type=str, default='',
help="Path prefix to existing features vocabularies")
group.add_argument('-src_vocab_size', type=int, default=50000,
2 changes: 2 additions & 0 deletions preprocess.py
@@ -157,8 +157,10 @@ def build_save_dataset(corpus_type, fields, opt):
def build_save_vocab(train_dataset, fields, opt):
fields = onmt.io.build_vocab(train_dataset, fields, opt.data_type,
opt.share_vocab,
opt.src_vocab,
opt.src_vocab_size,
opt.src_words_min_frequency,
opt.tgt_vocab,
opt.tgt_vocab_size,
opt.tgt_words_min_frequency)

14 changes: 14 additions & 0 deletions test/test_preprocess.py
@@ -3,6 +3,7 @@
import unittest
import glob
import os
import codecs
from collections import Counter

import torchtext
@@ -39,6 +40,13 @@ def __init__(self, *args, **kwargs):
def dataset_build(self, opt):
fields = onmt.io.get_fields("text", 0, 0)

if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0:
with codecs.open(opt.src_vocab, 'w', 'utf-8') as f:
f.write('a\nb\nc\nd\ne\nf\n')
if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0:
with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f:
f.write('a\nb\nc\nd\ne\nf\n')

train_data_files = preprocess.build_save_dataset('train', fields, opt)

preprocess.build_save_vocab(train_data_files, fields, opt)
@@ -48,6 +56,10 @@ def dataset_build(self, opt):
# Remove the generated *pt files.
for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
os.remove(pt)
if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab):
os.remove(opt.src_vocab)
if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab):
os.remove(opt.tgt_vocab)

def test_merge_vocab(self):
va = torchtext.vocab.Vocab(Counter('abbccc'))
@@ -109,6 +121,8 @@ def test_method(self):
('share_vocab', True)],
[('dynamic_dict', True),
('max_shard_size', 500000)],
[('src_vocab', '/tmp/src_vocab.txt'),
('tgt_vocab', '/tmp/tgt_vocab.txt')],
]

for p in test_databuild:
