Commit 4a66b28
fix: black
helderlopes97 committed Feb 9, 2024
1 parent 083a0e1 commit 4a66b28
Showing 5 changed files with 210 additions and 136 deletions.
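This commit applies the black formatter to the repository's Python sources. The page does not record the black version or the exact invocation used, so the following is only a minimal sketch, assuming black is installed and restricting the check to the two files whose diffs are shown below:

import subprocess

# Hypothetical check (not part of this commit): --check exits 0 if the files
# are already black-formatted and 1 if black would reformat them; --diff
# prints what would change without touching the files.
files = ["onmt/bin/train_profile.py", "onmt/opts.py"]
result = subprocess.run(["black", "--check", "--diff", *files])
print("clean" if result.returncode == 0 else "would be reformatted")

The same --check call is commonly used as a CI or pre-commit gate to keep the formatting from drifting again.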
95 changes: 58 additions & 37 deletions onmt/bin/train_profile.py
@@ -11,34 +11,40 @@
from onmt.utils.logging import init_logger, logger
from onmt.train_single import main as single_main
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import build_dataset_iter, \
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from onmt.inputters.inputter import (
build_dataset_iter,
load_old_vocab,
old_style_vocab,
build_dataset_iter_multiple,
)

from itertools import cycle
import torch.cuda.profiler as profiler
import pyprof2

pyprof2.init()


def train(opt):
ArgumentParser.validate_train_opts(opt)
ArgumentParser.update_model_opts(opt)
ArgumentParser.validate_model_opts(opt)

# Load checkpoint if we resume from a previous training.
if opt.train_from:
logger.info('Loading checkpoint from %s' % opt.train_from)
checkpoint = torch.load(opt.train_from,
map_location=lambda storage, loc: storage)
logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
vocab = checkpoint['vocab']
logger.info("Loading checkpoint from %s" % opt.train_from)
checkpoint = torch.load(
opt.train_from, map_location=lambda storage, loc: storage
)
logger.info("Loading vocab from checkpoint at %s." % opt.train_from)
vocab = checkpoint["vocab"]
else:
vocab = torch.load(opt.data + '.vocab.pt')
vocab = torch.load(opt.data + ".vocab.pt")

# check for code where vocab is saved instead of fields
# (in the future this will be done in a smarter way)
if old_style_vocab(vocab):
fields = load_old_vocab(
vocab, opt.model_type, dynamic_dict=opt.copy_attn)
fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
else:
fields = vocab

@@ -49,7 +55,7 @@ def train(opt):
train_shards.append(shard_base)
train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
else:
if opt.data_ids[0] is not None and opt.data_ids[0] != 'None':
if opt.data_ids[0] is not None and opt.data_ids[0] != "None":
shard_base = "train_" + opt.data_ids[0]
else:
shard_base = "train"
@@ -59,7 +65,7 @@

if opt.world_size > 1:
queues = []
mp = torch.multiprocessing.get_context('spawn')
mp = torch.multiprocessing.get_context("spawn")
semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
@@ -69,14 +75,26 @@
for device_id in range(nb_gpu):
q = mp.Queue(opt.queue_size)
queues += [q]
procs.append(mp.Process(target=run, args=(
opt, device_id, error_queue, q, semaphore), daemon=True))
procs.append(
mp.Process(
target=run,
args=(opt, device_id, error_queue, q, semaphore),
daemon=True,
)
)
procs[device_id].start()
logger.info(" Starting process pid: %d " % procs[device_id].pid)
error_handler.add_child(procs[device_id].pid)
producer = mp.Process(target=batch_producer,
args=(train_iter, queues, semaphore, opt,),
daemon=True)
producer = mp.Process(
target=batch_producer,
args=(
train_iter,
queues,
semaphore,
opt,
),
daemon=True,
)
producer.start()
error_handler.add_child(producer.pid)

@@ -86,7 +104,7 @@

elif nb_gpu == 1: # case 1 GPU only
single_main(opt, 0)
else: # case only CPU
else: # case only CPU
single_main(opt, -1)


@@ -104,8 +122,7 @@ def pred(x):
if x[0] % opt.world_size == rank:
return True

generator_to_serve = filter(
pred, enumerate(generator_to_serve))
generator_to_serve = filter(pred, enumerate(generator_to_serve))

def next_batch(device_id):
new_batch = next(generator_to_serve)
@@ -117,16 +134,17 @@ def next_batch(device_id):
for device_id, q in cycle(enumerate(queues)):
b.dataset = None
if isinstance(b.src, tuple):
b.src = tuple([_.to(torch.device(device_id))
for _ in b.src])
b.src = tuple([_.to(torch.device(device_id)) for _ in b.src])
else:
b.src = b.src.to(torch.device(device_id))
b.tgt = b.tgt.to(torch.device(device_id))
b.indices = b.indices.to(torch.device(device_id))
b.alignment = b.alignment.to(torch.device(device_id)) \
if hasattr(b, 'alignment') else None
b.src_map = b.src_map.to(torch.device(device_id)) \
if hasattr(b, 'src_map') else None
b.alignment = (
b.alignment.to(torch.device(device_id)) if hasattr(b, "alignment") else None
)
b.src_map = (
b.src_map.to(torch.device(device_id)) if hasattr(b, "src_map") else None
)

# hack to dodge unpicklable `dict_keys`
b.fields = list(b.fields)
@@ -135,18 +153,21 @@ def next_batch(device_id):


def run(opt, device_id, error_queue, batch_queue, semaphore):
""" run process """
"""run process"""
try:
gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
if gpu_rank != opt.gpu_ranks[device_id]:
raise AssertionError("An error occurred in \
Distributed initialization")
raise AssertionError(
"An error occurred in \
Distributed initialization"
)
single_main(opt, device_id, batch_queue, semaphore)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback

error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))


@@ -155,28 +176,28 @@ class ErrorHandler(object):
the tracebacks to the parent process."""

def __init__(self, error_queue):
""" init error handler """
"""init error handler"""
import signal
import threading

self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(
target=self.error_listener, daemon=True)
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)

def add_child(self, pid):
""" error handler """
"""error handler"""
self.children_pids.append(pid)

def error_listener(self):
""" error listener """
"""error listener"""
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)

def signal_handler(self, signalnum, stackframe):
""" signal handler """
"""signal handler"""
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
@@ -187,7 +208,7 @@ def signal_handler(self, signalnum, stackframe):


def _get_parser():
parser = ArgumentParser(description='train.py')
parser = ArgumentParser(description="train.py")

opts.config_opts(parser)
opts.model_opts(parser)
@@ -203,4 +224,4 @@ def main():


if __name__ == "__main__":
main()
main()
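train_profile.py imports pyprof2 and torch.cuda.profiler, but the diff above only shows formatting changes, not how the script is launched. As a hypothetical sketch (the wrapper script, the nvprof command, and the emit_nvtx block are assumptions, not code from this commit), the usual PyProf-style capture around the train() call defined above looks roughly like:

# profile_sketch.py -- hypothetical wrapper, not part of this repository.
# Typical launch: nvprof -o trace.sql --profile-from-start off python profile_sketch.py -data ... -save_model ...
import torch
import torch.cuda.profiler as profiler

# Importing train_profile already runs pyprof2.init() at module level (see the diff above).
from onmt.bin.train_profile import train, _get_parser


def main():
    parser = _get_parser()
    opt = parser.parse_args()
    with torch.autograd.profiler.emit_nvtx():  # emit NVTX ranges for PyTorch ops
        profiler.start()  # begin CUDA profiler capture
        train(opt)
        profiler.stop()  # end capture


if __name__ == "__main__":
    main()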
83 changes: 59 additions & 24 deletions onmt/opts.py
@@ -112,27 +112,51 @@ def _add_logging_opts(parser, is_train=True):
help="Log directory for Tensorboard. " "This is also the name of the run.",
)
# Use MLflow for logging training runs and config parameters
group.add('--mlflow', '-mlflow', action="store_true",
help="Use mlflow to log training runs and parameters. "
"Must have the library mlflow >= 1.3.0")
group.add("--mlflow_experiment_name", "-mlflow_experiment_name",
type=str, default=None,
help="MLflow experiment name")
group.add("--mlflow_run_name", "-mlflow_run_name",
type=str, default=None,
help="MLflow run name")
group.add(
"--mlflow",
"-mlflow",
action="store_true",
help="Use mlflow to log training runs and parameters. "
"Must have the library mlflow >= 1.3.0",
)
group.add(
"--mlflow_experiment_name",
"-mlflow_experiment_name",
type=str,
default=None,
help="MLflow experiment name",
)
group.add(
"--mlflow_run_name",
"-mlflow_run_name",
type=str,
default=None,
help="MLflow run name",
)

# Use MLflow for logging training runs and config parameters
# https://docs.wandb.ai/guides/track/advanced/environment-variables
# should be set: WANDB_API_KEY / WANDB_BASE_URL
group.add('--wandb', '-wandb', action="store_true",
help="Use wandb to log training runs and parameters. ")
group.add("--wandb_project_name", "-wandb_project_name",
type=str, default=None,
help="wandb experiment name")
group.add("--wandb_run_name", "-wandb_run_name",
type=str, default=None,
help="wandb run name")
group.add(
"--wandb",
"-wandb",
action="store_true",
help="Use wandb to log training runs and parameters. ",
)
group.add(
"--wandb_project_name",
"-wandb_project_name",
type=str,
default=None,
help="wandb experiment name",
)
group.add(
"--wandb_run_name",
"-wandb_run_name",
type=str,
default=None,
help="wandb run name",
)
group.add(
"--override_opts",
"-override-opts",
@@ -1904,8 +1928,12 @@ def translate_opts(parser):
help="Path to output the predictions (each line will "
"be the decoded sequence",
)
group.add('--log_probs', '-log_probs', action='store_true',
help="Output file with log_probs and gold_score ")
group.add(
"--log_probs",
"-log_probs",
action="store_true",
help="Output file with log_probs and gold_score ",
)
group.add(
"--report_align",
"-report_align",
@@ -1954,8 +1982,12 @@ def translate_opts(parser):
)
group.add("--gpu", "-gpu", type=int, default=-1, help="Device to run on")

group.add('--num_threads', '-num_threads', type=int,
help="Number of CPUs to use for translation")
group.add(
"--num_threads",
"-num_threads",
type=int,
help="Number of CPUs to use for translation",
)

group.add(
"-transforms",
@@ -1966,9 +1998,12 @@ def translate_opts(parser):
help="Default transform pipeline to apply to data.",
)

group = parser.add_argument_group('ibmrxn')
group.add_argument('--is_ibmrxn', action='store_true',
help='Translate returns in a format that is compatible with the api')
group = parser.add_argument_group("ibmrxn")
group.add_argument(
"--is_ibmrxn",
action="store_true",
help="Translate returns in a format that is compatible with the api",
)

# Adding options related to Transforms
_add_transform_opts(parser)
(Diff continues; the remaining 3 changed files are not shown here.)

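For orientation only, here is a hedged sketch of how the reformatted translate options in onmt/opts.py end up on a parser. It assumes this fork keeps the standard OpenNMT-py required flags (-model, -src); the file paths and values are placeholders, not artifacts from this repository.

from onmt.utils.parse import ArgumentParser
import onmt.opts as opts

# Build a translation option parser, mirroring the usual OpenNMT-py
# translate.py pattern (assumed here, not shown in this diff).
parser = ArgumentParser(description="translate.py")
opts.config_opts(parser)
opts.translate_opts(parser)

# Placeholder arguments, purely illustrative.
opt = parser.parse_args(
    [
        "-model", "model_step_100000.pt",
        "-src", "src-test.txt",
        "-output", "pred.txt",
        "-log_probs",  # option reformatted in this commit
        "-num_threads", "2",
    ]
)
print(opt.log_probs, opt.num_threads, opt.is_ibmrxn)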