diff --git a/onmt/bin/train_profile.py b/onmt/bin/train_profile.py
index b0c3b1dc64..c388d84bd1 100644
--- a/onmt/bin/train_profile.py
+++ b/onmt/bin/train_profile.py
@@ -19,7 +19,8 @@
 )
 from itertools import cycle
 
-import torch.cuda.profiler as profiler
+
+# import torch.cuda.profiler as profiler
 import pyprof2
 
 pyprof2.init()
diff --git a/onmt/train_single.py b/onmt/train_single.py
index 236dee1933..59b91ad562 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -203,39 +203,39 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
         opt, device_id, model, vocabs, optim, model_saver=model_saver
     )
 
-    if batch_queue is None:
-        if len(opt.data_ids) > 1:
-            train_shards = []
-            for train_id in opt.data_ids:
-                shard_base = "train_" + train_id
-                train_shards.append(shard_base)
-            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
-        else:
-            if opt.data_ids[0] is not None:
-                if opt.data_ids[0] is not None and opt.data_ids[0] != "None":
-                    shard_base = "train_" + opt.data_ids[0]
-                else:
-                    shard_base = "train"
-            train_iter = build_dataset_iter(shard_base, fields, opt)
-    else:
-        assert semaphore is not None, "Using batch_queue requires semaphore as well"
-
-        def _train_iter():
-            while True:
-                batch = batch_queue.get()
-                semaphore.release()
-                yield batch
-
-        train_iter = _train_iter()
-    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
-    if len(opt.gpu_ranks):
-        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
-    else:
-        logger.info("Starting training on CPU, could be very slow")
-    train_steps = opt.train_steps
-    if opt.single_pass and train_steps > 0:
-        logger.warning("Option single_pass is enabled, ignoring train_steps.")
-        train_steps = 0
+    # if batch_queue is None:
+    #     if len(opt.data_ids) > 1:
+    #         train_shards = []
+    #         for train_id in opt.data_ids:
+    #             shard_base = "train_" + train_id
+    #             train_shards.append(shard_base)
+    #         train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
+    #     else:
+    #         if opt.data_ids[0] is not None:
+    #             if opt.data_ids[0] is not None and opt.data_ids[0] != "None":
+    #                 shard_base = "train_" + opt.data_ids[0]
+    #             else:
+    #                 shard_base = "train"
+    #         train_iter = build_dataset_iter(shard_base, fields, opt)
+    # else:
+    #     assert semaphore is not None, "Using batch_queue requires semaphore as well"
+    #
+    #     def _train_iter():
+    #         while True:
+    #             batch = batch_queue.get()
+    #             semaphore.release()
+    #             yield batch
+    #
+    #     train_iter = _train_iter()
+    # valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
+    # if len(opt.gpu_ranks):
+    #     logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
+    # else:
+    #     logger.info("Starting training on CPU, could be very slow")
+    # train_steps = opt.train_steps
+    # if opt.single_pass and train_steps > 0:
+    #     logger.warning("Option single_pass is enabled, ignoring train_steps.")
+    #     train_steps = 0
 
     offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
     stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1
@@ -281,9 +281,9 @@ def _train_iter():
         mlflow.start_run()
         for k, v in vars(opt).items():
             mlflow.log_param(k, v)
-        mlflow.log_param("n_enc_parameters", enc)
-        mlflow.log_param("n_dec_parameters", dec)
-        mlflow.log_param("n_total_parameters", n_params)
+        # mlflow.log_param("n_enc_parameters", enc)
+        # mlflow.log_param("n_dec_parameters", dec)
+        # mlflow.log_param("n_total_parameters", n_params)
         import onmt
 
         mlflow.log_param("onmt_version", onmt.__version__)
@@ -302,9 +302,9 @@ def _train_iter():
         wandb.config.update(
             {
-                "n_enc_parameters": enc,
-                "n_dec_parameters": dec,
-                "n_total_parameters": n_params,
+                # "n_enc_parameters": enc,
+                # "n_dec_parameters": dec,
+                # "n_total_parameters": n_params,
                 "onmt_version": onmt.__version__,
             }
         )
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 7e3acf57fb..e0ef87aeca 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -10,6 +10,8 @@
 from copy import deepcopy
 import onmt.model_builder
 import onmt.decoders.ensemble
+
+# import onmt.inputters as inputters
 from onmt.constants import DefaultTokens
 from onmt.translate.beam_search import BeamSearch, BeamSearchLM
 from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
@@ -324,85 +326,85 @@ def _gold_score(
             glp = None
         return gs, glp
 
-    def likelihood(
-        self,
-        src,
-        tgt=None,
-        src_dir=None,
-        batch_size=None,
-        batch_type="sents",
-        attn_debug=False,
-        phrase_table="",
-    ):
-        """Translate content of ``src`` and get gold scores from ``tgt``.
-        Args:
-            src: See :func:`self.src_reader.read()`.
-            tgt: See :func:`self.tgt_reader.read()`.
-            src_dir: See :func:`self.src_reader.read()` (only relevant
-                for certain types of data).
-            batch_size (int): size of examples per mini-batch
-            attn_debug (bool): enables the attention logging
-        Returns:
-            (`list`, `list`)
-            * all_scores is a list of `batch_size` lists of `n_best` scores
-            * all_predictions is a list of `batch_size` lists
-                of `n_best` predictions
-        """
-
-        if batch_size is None:
-            raise ValueError("batch_size must be set")
-
-        data = inputters.Dataset(
-            self.fields,
-            readers=([self.src_reader, self.tgt_reader] if tgt else [self.src_reader]),
-            data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
-            dirs=[src_dir, None] if tgt else [src_dir],
-            sort_key=inputters.str2sortkey[self.data_type],
-            filter_pred=self._filter_pred,
-        )
-
-        data_iter = inputters.OrderedIterator(
-            dataset=data,
-            device=self._dev,
-            batch_size=batch_size,
-            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
-            train=False,
-            sort=False,
-            sort_within_batch=True,
-            shuffle=False,
-        )
-
-        all_gold_scores = []
-
-        use_src_map = self.copy_attn
-        beam_size = self.beam_size
-
-        for batch in data_iter:
-            # import pdb; pdb.set_trace()
-            # (0) Prep the components of the search.
-
-            # (1) Run the encoder on the src.
-            src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
-            self.model.decoder.init_state(src, memory_bank, enc_states)
-
-            gold_scores = self._gold_score(
-                batch,
-                memory_bank,
-                src_lengths,
-                data.src_vocabs,
-                use_src_map,
-                enc_states,
-                batch_size,
-                src,
-            )
-            gold_scores = gold_scores.detach().numpy().tolist()
-
-            all_gold_scores += [
-                score
-                for _, score in sorted(zip(batch.indices.numpy().tolist(), gold_scores))
-            ]
-
-        return all_gold_scores
+    # def likelihood(
+    #     self,
+    #     src,
+    #     tgt=None,
+    #     src_dir=None,
+    #     batch_size=None,
+    #     batch_type="sents",
+    #     attn_debug=False,
+    #     phrase_table="",
+    # ):
+    #     """Translate content of ``src`` and get gold scores from ``tgt``.
+    #     Args:
+    #         src: See :func:`self.src_reader.read()`.
+    #         tgt: See :func:`self.tgt_reader.read()`.
+    #         src_dir: See :func:`self.src_reader.read()` (only relevant
+    #             for certain types of data).
+    #         batch_size (int): size of examples per mini-batch
+    #         attn_debug (bool): enables the attention logging
+    #     Returns:
+    #         (`list`, `list`)
+    #         * all_scores is a list of `batch_size` lists of `n_best` scores
+    #         * all_predictions is a list of `batch_size` lists
+    #             of `n_best` predictions
+    #     """
+    #
+    #     if batch_size is None:
+    #         raise ValueError("batch_size must be set")
+    #
+    #     data = inputters.Dataset(
+    #         self.fields,
+    #         readers=([self.src_reader, self.tgt_reader] if tgt else [self.src_reader]),
+    #         data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
+    #         dirs=[src_dir, None] if tgt else [src_dir],
+    #         sort_key=inputters.str2sortkey[self.data_type],
+    #         filter_pred=self._filter_pred,
+    #     )
+    #
+    #     data_iter = inputters.OrderedIterator(
+    #         dataset=data,
+    #         device=self._dev,
+    #         batch_size=batch_size,
+    #         batch_size_fn=max_tok_len if batch_type == "tokens" else None,
+    #         train=False,
+    #         sort=False,
+    #         sort_within_batch=True,
+    #         shuffle=False,
+    #     )
+    #
+    #     all_gold_scores = []
+    #
+    #     use_src_map = self.copy_attn
+    #     beam_size = self.beam_size
+    #
+    #     for batch in data_iter:
+    #         # import pdb; pdb.set_trace()
+    #         # (0) Prep the components of the search.
+    #
+    #         # (1) Run the encoder on the src.
+    #         src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
+    #         self.model.decoder.init_state(src, memory_bank, enc_states)
+    #
+    #         gold_scores = self._gold_score(
+    #             batch,
+    #             memory_bank,
+    #             src_lengths,
+    #             data.src_vocabs,
+    #             use_src_map,
+    #             enc_states,
+    #             batch_size,
+    #             src,
+    #         )
+    #         gold_scores = gold_scores.detach().numpy().tolist()
+    #
+    #         all_gold_scores += [
+    #             score
+    #             for _, score in sorted(zip(batch.indices.numpy().tolist(), gold_scores))
+    #         ]
+    #
+    #     return all_gold_scores
 
     def _translate(
         self,
@@ -445,9 +447,9 @@ def _translate(
         all_scores = []
         all_predictions = []
 
-        if self.is_ibmrxn:
-            all_attentions = []  # added phs
-            attn_debug = True
+        # if self.is_ibmrxn:
+        #     all_attentions = []  # added phs
+        #     attn_debug = True
 
         start_time = time()
 
@@ -602,6 +604,24 @@ def _process_bucket(bucket_translations):
                     output = report_matrix(srcs, tgts, align)
                     self._log(output)
 
+        # if self.is_ibmrxn:
+        #     all_attentions.append(trans.attns[0])  # added phs
+        #
+        # # phs: added to log gold scores to file
+        # if self.target_score_out_file is not None:
+        #     self.target_score_out_file.write(str(trans.gold_score.item()) + "\n")
+        #     self.target_score_out_file.flush()
+        # #
+        #
+        # if self.is_ibmrxn:  # added phs
+        #     return {
+        #         "score": all_scores
+        #         if batch_size > 1
+        #         else all_scores[0],  # return more scores when batch_size > 1
+        #         "prediction": all_predictions,
+        #         "context_attns": all_attentions,
+        #     }
+
         return (
             bucket_scores,
             bucket_predictions,
@@ -653,9 +673,6 @@ def _process_bucket(bucket_translations):
                 gold_words_total += bucket_gold_words
                 bucket_translations = []
 
-        if self.is_ibmrxn:
-            all_attentions.append(trans.attns[0])  # added phs
-
         if len(bucket_translations) > 0:
             (
                 bucket_scores,
@@ -672,21 +689,6 @@ def _process_bucket(bucket_translations):
             gold_score_total += bucket_gold_score
             gold_words_total += bucket_gold_words
 
-        # phs: added to log gold scores to file
-        if self.target_score_out_file is not None:
-            self.target_score_out_file.write(str(trans.gold_score.item()) + "\n")
-            self.target_score_out_file.flush()
-        #
-
-        if self.is_ibmrxn:  # added phs
-            return {
-                "score": all_scores
-                if batch_size > 1
-                else all_scores[0],  # return more scores when batch_size > 1
-                "prediction": all_predictions,
-                "context_attns": all_attentions,
-            }
-
         end_time = time()
 
         if self.report_score: