forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: try and replicate changes from fork to the latest OpenNMT-py
- Loading branch information
1 parent
a02330a
commit 083a0e1
Showing
10 changed files
with
538 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Build the rxn-onmt-py distribution and upload it to PyPI whenever a
# version tag (v*) is pushed.
name: Build and publish rxn-onmt-py on PyPI

on:
  push:
    tags:
      - 'v*'

jobs:
  build-and-publish:
    name: Build and publish rxn-onmt-py on PyPI
    runs-on: ubuntu-latest

    steps:
      # Use a released major version instead of the mutable `master` branch.
      - uses: actions/checkout@v4
      - name: Python setup 3.8
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps the version a string (e.g. 3.10 != 3.1).
          python-version: '3.8'
      - name: Install build package (for packaging)
        run: pip install --upgrade build
      - name: Build dist
        run: python -m build
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_TOKEN }}
          # Do not fail when this version was already uploaded.
          skip-existing: true
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,4 +21,4 @@ | |
onmt.modules, | ||
] | ||
|
||
__version__ = "3.4.3" | ||
__version__ = "2.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
#!/usr/bin/env python | ||
"""Train models.""" | ||
import os | ||
import signal | ||
import torch | ||
|
||
import onmt.opts as opts | ||
import onmt.utils.distributed | ||
|
||
from onmt.utils.misc import set_random_seed | ||
from onmt.utils.logging import init_logger, logger | ||
from onmt.train_single import main as single_main | ||
from onmt.utils.parse import ArgumentParser | ||
from onmt.inputters.inputter import build_dataset_iter, \ | ||
load_old_vocab, old_style_vocab, build_dataset_iter_multiple | ||
|
||
from itertools import cycle | ||
import torch.cuda.profiler as profiler | ||
import pyprof2 | ||
pyprof2.init() | ||
|
||
def train(opt):
    """Validate ``opt``, load the fields and dispatch training.

    When ``opt.world_size > 1``, one daemon worker process is spawned per
    entry of ``opt.gpu_ranks`` together with one batch-producer process;
    the call returns once every worker has been joined. Otherwise
    ``single_main`` runs in this process, on device 0 when exactly one GPU
    rank is configured and on device -1 in every other case.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # The vocabulary comes from the checkpoint when resuming a previous
    # training, and from the preprocessed data files otherwise.
    if not opt.train_from:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(
            opt.train_from, map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Some files store a plain vocab instead of fields; convert those.
    fields = (
        load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        if old_style_vocab(vocab)
        else vocab
    )

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        only_id = opt.data_ids[0]
        if only_id is None or only_id == 'None':
            shard_base = "train"
        else:
            shard_base = "train_" + only_id
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    # Single-process training: one GPU (device 0) or CPU only (device -1).
    if not opt.world_size > 1:
        single_main(opt, 0 if nb_gpu == 1 else -1)
        return

    # Multi-process training.
    queues = []
    mp = torch.multiprocessing.get_context('spawn')
    semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
    # The error handler starts a thread listening for child failures.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    procs = []
    for device_id in range(nb_gpu):
        batch_queue = mp.Queue(opt.queue_size)
        queues.append(batch_queue)
        worker = mp.Process(
            target=run,
            args=(opt, device_id, error_queue, batch_queue, semaphore),
            daemon=True)
        procs.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d " % worker.pid)
        error_handler.add_child(worker.pid)

    producer = mp.Process(
        target=batch_producer,
        args=(train_iter, queues, semaphore, opt,),
        daemon=True)
    producer.start()
    error_handler.add_child(producer.pid)

    for worker in procs:
        worker.join()
    producer.terminate()
|
||
|
||
def batch_producer(generator_to_serve, queues, semaphore, opt):
    """Serve batches from ``generator_to_serve`` to ``queues`` round-robin.

    Only batches whose running index modulo ``opt.world_size`` matches one
    of ``opt.gpu_ranks`` are served. One ``semaphore`` permit is acquired
    for every batch pulled from the generator. Each batch is moved to the
    ``torch.device`` built from the position of its target queue before
    being put on that queue.
    """
    init_logger(opt.log_file)
    set_random_seed(opt.seed, False)

    def pred(x):
        """Tell whether the ``(index, batch)`` pair ``x`` belongs to one of
        the gpu_ranks of the current node."""
        return any(
            x[0] % opt.world_size == rank for rank in opt.gpu_ranks)

    generator_to_serve = filter(pred, enumerate(generator_to_serve))

    def next_batch(device_id):
        # `device_id` is accepted for call-site symmetry but is not used.
        _, batch = next(generator_to_serve)
        semaphore.acquire()
        return batch

    b = next_batch(0)

    for device_id, q in cycle(enumerate(queues)):
        b.dataset = None
        device = torch.device(device_id)

        if isinstance(b.src, tuple):
            b.src = tuple(part.to(device) for part in b.src)
        else:
            b.src = b.src.to(device)
        b.tgt = b.tgt.to(device)
        b.indices = b.indices.to(device)

        # Optional attributes are reset to None when the batch lacks them.
        if hasattr(b, 'alignment'):
            b.alignment = b.alignment.to(device)
        else:
            b.alignment = None
        if hasattr(b, 'src_map'):
            b.src_map = b.src_map.to(device)
        else:
            b.src_map = None

        # hack to dodge unpicklable `dict_keys`
        b.fields = list(b.fields)
        q.put(b)
        b = next_batch(device_id)
|
||
|
||
def run(opt, device_id, error_queue, batch_queue, semaphore):
    """Worker-process entry point: initialise distributed mode, then train.

    Args:
        opt: parsed training options; ``opt.gpu_ranks[device_id]`` is the
            rank this worker is expected to obtain.
        device_id: index of this worker in ``opt.gpu_ranks``.
        error_queue: queue receiving ``(gpu_rank, formatted_traceback)``
            when the worker fails.
        batch_queue: queue passed through to ``single_main``.
        semaphore: semaphore passed through to ``single_main``.

    Exceptions are not re-raised here. ``KeyboardInterrupt`` is swallowed;
    any other ``Exception`` is reported through ``error_queue``.
    """
    try:
        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            # Single clean literal: a backslash continuation inside the
            # string would leak source indentation into the message.
            raise AssertionError(
                "An error occurred in Distributed initialization")
        single_main(opt, device_id, batch_queue, semaphore)
    except KeyboardInterrupt:
        pass  # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
|
||
|
||
class ErrorHandler(object):
    """Listen for exceptions in child processes and re-raise them here.

    A daemon thread blocks on ``error_queue``. When a child reports a
    ``(rank, traceback)`` pair, the thread puts it back on the queue and
    sends ``SIGUSR1`` to the current process. The ``SIGUSR1`` handler then
    interrupts the registered children and raises an ``Exception`` carrying
    the child's traceback.
    """

    def __init__(self, error_queue):
        """Start the listener thread and install the SIGUSR1 handler.

        Args:
            error_queue: queue on which children put
                ``(rank, formatted_traceback)`` tuples.
        """
        import signal
        import threading
        self.error_queue = error_queue
        self.children_pids = []
        self.error_thread = threading.Thread(
            target=self.error_listener, daemon=True)
        self.error_thread.start()
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Register a child process id to interrupt when an error occurs."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for one child error, then signal the current process."""
        (rank, original_trace) = self.error_queue.get()
        # Put the item back so that `signal_handler` can read it.
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt all children and raise the reported child traceback.

        Raises:
            Exception: always; the message ends with the child's traceback.
        """
        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)  # kill children processes
            except ProcessLookupError:
                # The child is already gone (for instance the one that
                # failed); this must not hide the traceback raised below.
                pass
        (rank, original_trace) = self.error_queue.get()
        msg = ("\n\n-- Tracebacks above this line can probably "
               "be ignored --\n\n")
        msg += original_trace
        raise Exception(msg)
|
||
|
||
def _get_parser():
    """Return the ``train.py`` argument parser.

    The parser is populated, in order, with the config, model and training
    option groups from ``onmt.opts``.
    """
    parser = ArgumentParser(description='train.py')
    for register in (opts.config_opts, opts.model_opts, opts.train_opts):
        register(parser)
    return parser
|
||
|
||
def main():
    """Parse the command-line options and run training with them."""
    opt = _get_parser().parse_args()
    train(opt)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.