Commit

fix: flake8

helderlopes97 committed Feb 9, 2024
1 parent 4a66b28 commit e93c97a
Showing 3 changed files with 143 additions and 140 deletions.
onmt/bin/train_profile.py (3 changes: 2 additions & 1 deletion)
@@ -19,7 +19,8 @@
 )
 
 from itertools import cycle
-import torch.cuda.profiler as profiler
+
+# import torch.cuda.profiler as profiler
 import pyprof2
 
 pyprof2.init()
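Note: commenting out the unused profiler import is one way to satisfy flake8's F401 check (module imported but unused). An alternative sketch, if the import should stay available for ad-hoc profiling runs, is a noqa marker behind an import guard (hypothetical, not part of this commit):

# Keep the unused import but mark it as intentional for flake8 (F401),
# and guard it so CPU-only PyTorch builds still import cleanly.
try:
    import torch.cuda.profiler as profiler  # noqa: F401
except ImportError:
    profiler = None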
onmt/train_single.py (78 changes: 39 additions & 39 deletions)
@@ -203,39 +203,39 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
         opt, device_id, model, vocabs, optim, model_saver=model_saver
     )
 
-    if batch_queue is None:
-        if len(opt.data_ids) > 1:
-            train_shards = []
-            for train_id in opt.data_ids:
-                shard_base = "train_" + train_id
-                train_shards.append(shard_base)
-            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
-        else:
-            # if opt.data_ids[0] is not None:
-            if opt.data_ids[0] is not None and opt.data_ids[0] != "None":
-                shard_base = "train_" + opt.data_ids[0]
-            else:
-                shard_base = "train"
-            train_iter = build_dataset_iter(shard_base, fields, opt)
-    else:
-        assert semaphore is not None, "Using batch_queue requires semaphore as well"
-
-        def _train_iter():
-            while True:
-                batch = batch_queue.get()
-                semaphore.release()
-                yield batch
-
-        train_iter = _train_iter()
-    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
-    if len(opt.gpu_ranks):
-        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
-    else:
-        logger.info("Starting training on CPU, could be very slow")
-    train_steps = opt.train_steps
-    if opt.single_pass and train_steps > 0:
-        logger.warning("Option single_pass is enabled, ignoring train_steps.")
-        train_steps = 0
+    # if batch_queue is None:
+    #     if len(opt.data_ids) > 1:
+    #         train_shards = []
+    #         for train_id in opt.data_ids:
+    #             shard_base = "train_" + train_id
+    #             train_shards.append(shard_base)
+    #         train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
+    #     else:
+    #         # if opt.data_ids[0] is not None:
+    #         if opt.data_ids[0] is not None and opt.data_ids[0] != "None":
+    #             shard_base = "train_" + opt.data_ids[0]
+    #         else:
+    #             shard_base = "train"
+    #         train_iter = build_dataset_iter(shard_base, fields, opt)
+    # else:
+    #     assert semaphore is not None, "Using batch_queue requires semaphore as well"
+    #
+    #     def _train_iter():
+    #         while True:
+    #             batch = batch_queue.get()
+    #             semaphore.release()
+    #             yield batch
+    #
+    #     train_iter = _train_iter()
+    # valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
+    # if len(opt.gpu_ranks):
+    #     logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
+    # else:
+    #     logger.info("Starting training on CPU, could be very slow")
+    # train_steps = opt.train_steps
+    # if opt.single_pass and train_steps > 0:
+    #     logger.warning("Option single_pass is enabled, ignoring train_steps.")
+    #     train_steps = 0
 
     offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
     stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1
@@ -281,9 +281,9 @@ def _train_iter():
     mlflow.start_run()
     for k, v in vars(opt).items():
         mlflow.log_param(k, v)
-    mlflow.log_param("n_enc_parameters", enc)
-    mlflow.log_param("n_dec_parameters", dec)
-    mlflow.log_param("n_total_parameters", n_params)
+    # mlflow.log_param("n_enc_parameters", enc)
+    # mlflow.log_param("n_dec_parameters", dec)
+    # mlflow.log_param("n_total_parameters", n_params)
     import onmt
 
     mlflow.log_param("onmt_version", onmt.__version__)
@@ -302,9 +302,9 @@ def _train_iter():

     wandb.config.update(
         {
-            "n_enc_parameters": enc,
-            "n_dec_parameters": dec,
-            "n_total_parameters": n_params,
+            # "n_enc_parameters": enc,
+            # "n_dec_parameters": dec,
+            # "n_total_parameters": n_params,
             "onmt_version": onmt.__version__,
         }
     )
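The commented logging lines reference enc, dec, and n_params, which are otherwise undefined in this scope, so flake8 reports them as F821 (undefined name). If the counts are wanted again, a minimal sketch of how they are typically computed for an encoder-decoder model follows; the .encoder/.decoder attribute layout is an assumption, not confirmed by this diff:

import torch.nn as nn


def count_parameters(model: nn.Module):
    # Hypothetical helper: returns (enc, dec, n_params) so the commented
    # mlflow/wandb lines above could be restored without F821 errors.
    # Assumes the model exposes .encoder and .decoder submodules.
    enc = sum(p.numel() for p in model.encoder.parameters())
    dec = sum(p.numel() for p in model.decoder.parameters())
    return enc, dec, enc + dec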
onmt/translate/translator.py (202 changes: 102 additions & 100 deletions)
@@ -10,6 +10,8 @@
 from copy import deepcopy
 import onmt.model_builder
 import onmt.decoders.ensemble
+
+# import onmt.inputters as inputters
 from onmt.constants import DefaultTokens
 from onmt.translate.beam_search import BeamSearch, BeamSearchLM
 from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
@@ -324,85 +326,85 @@ def _gold_score(
             glp = None
         return gs, glp
 
-    def likelihood(
-        self,
-        src,
-        tgt=None,
-        src_dir=None,
-        batch_size=None,
-        batch_type="sents",
-        attn_debug=False,
-        phrase_table="",
-    ):
-        """Translate content of ``src`` and get gold scores from ``tgt``.
-        Args:
-            src: See :func:`self.src_reader.read()`.
-            tgt: See :func:`self.tgt_reader.read()`.
-            src_dir: See :func:`self.src_reader.read()` (only relevant
-                for certain types of data).
-            batch_size (int): size of examples per mini-batch
-            attn_debug (bool): enables the attention logging
-        Returns:
-            (`list`, `list`)
-            * all_scores is a list of `batch_size` lists of `n_best` scores
-            * all_predictions is a list of `batch_size` lists
-                of `n_best` predictions
-        """
-
-        if batch_size is None:
-            raise ValueError("batch_size must be set")
-
-        data = inputters.Dataset(
-            self.fields,
-            readers=([self.src_reader, self.tgt_reader] if tgt else [self.src_reader]),
-            data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
-            dirs=[src_dir, None] if tgt else [src_dir],
-            sort_key=inputters.str2sortkey[self.data_type],
-            filter_pred=self._filter_pred,
-        )
-
-        data_iter = inputters.OrderedIterator(
-            dataset=data,
-            device=self._dev,
-            batch_size=batch_size,
-            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
-            train=False,
-            sort=False,
-            sort_within_batch=True,
-            shuffle=False,
-        )
-
-        all_gold_scores = []
-
-        use_src_map = self.copy_attn
-        beam_size = self.beam_size
-
-        for batch in data_iter:
-            # import pdb; pdb.set_trace()
-            # (0) Prep the components of the search.
-
-            # (1) Run the encoder on the src.
-            src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
-            self.model.decoder.init_state(src, memory_bank, enc_states)
-
-            gold_scores = self._gold_score(
-                batch,
-                memory_bank,
-                src_lengths,
-                data.src_vocabs,
-                use_src_map,
-                enc_states,
-                batch_size,
-                src,
-            )
-            gold_scores = gold_scores.detach().numpy().tolist()
-
-            all_gold_scores += [
-                score
-                for _, score in sorted(zip(batch.indices.numpy().tolist(), gold_scores))
-            ]
-
-        return all_gold_scores
+    # def likelihood(
+    #     self,
+    #     src,
+    #     tgt=None,
+    #     src_dir=None,
+    #     batch_size=None,
+    #     batch_type="sents",
+    #     attn_debug=False,
+    #     phrase_table="",
+    # ):
+    #     """Translate content of ``src`` and get gold scores from ``tgt``.
+    #     Args:
+    #         src: See :func:`self.src_reader.read()`.
+    #         tgt: See :func:`self.tgt_reader.read()`.
+    #         src_dir: See :func:`self.src_reader.read()` (only relevant
+    #             for certain types of data).
+    #         batch_size (int): size of examples per mini-batch
+    #         attn_debug (bool): enables the attention logging
+    #     Returns:
+    #         (`list`, `list`)
+    #         * all_scores is a list of `batch_size` lists of `n_best` scores
+    #         * all_predictions is a list of `batch_size` lists
+    #             of `n_best` predictions
+    #     """
+    #
+    #     if batch_size is None:
+    #         raise ValueError("batch_size must be set")
+    #
+    #     data = inputters.Dataset(
+    #         self.fields,
+    #         readers=([self.src_reader, self.tgt_reader] if tgt else [self.src_reader]),
+    #         data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
+    #         dirs=[src_dir, None] if tgt else [src_dir],
+    #         sort_key=inputters.str2sortkey[self.data_type],
+    #         filter_pred=self._filter_pred,
+    #     )
+    #
+    #     data_iter = inputters.OrderedIterator(
+    #         dataset=data,
+    #         device=self._dev,
+    #         batch_size=batch_size,
+    #         batch_size_fn=max_tok_len if batch_type == "tokens" else None,
+    #         train=False,
+    #         sort=False,
+    #         sort_within_batch=True,
+    #         shuffle=False,
+    #     )
+    #
+    #     all_gold_scores = []
+    #
+    #     use_src_map = self.copy_attn
+    #     beam_size = self.beam_size
+    #
+    #     for batch in data_iter:
+    #         # import pdb; pdb.set_trace()
+    #         # (0) Prep the components of the search.
+    #
+    #         # (1) Run the encoder on the src.
+    #         src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
+    #         self.model.decoder.init_state(src, memory_bank, enc_states)
+    #
+    #         gold_scores = self._gold_score(
+    #             batch,
+    #             memory_bank,
+    #             src_lengths,
+    #             data.src_vocabs,
+    #             use_src_map,
+    #             enc_states,
+    #             batch_size,
+    #             src,
+    #         )
+    #         gold_scores = gold_scores.detach().numpy().tolist()
+    #
+    #         all_gold_scores += [
+    #             score
+    #             for _, score in sorted(zip(batch.indices.numpy().tolist(), gold_scores))
+    #         ]
+    #
+    #     return all_gold_scores
 
     def _translate(
         self,
@@ -445,9 +447,9 @@ def _translate(

         all_scores = []
         all_predictions = []
-        if self.is_ibmrxn:
-            all_attentions = []  # added phs
-            attn_debug = True
+        # if self.is_ibmrxn:
+        #     all_attentions = []  # added phs
+        #     attn_debug = True
 
         start_time = time()
 
@@ -602,6 +604,24 @@ def _process_bucket(bucket_translations):
                 output = report_matrix(srcs, tgts, align)
                 self._log(output)
 
+            # if self.is_ibmrxn:
+            #     all_attentions.append(trans.attns[0])  # added phs
+            #
+            # # phs: added to log gold scores to file
+            # if self.target_score_out_file is not None:
+            #     self.target_score_out_file.write(str(trans.gold_score.item()) + "\n")
+            #     self.target_score_out_file.flush()
+            # #
+            #
+            # if self.is_ibmrxn:  # added phs
+            #     return {
+            #         "score": all_scores
+            #         if batch_size > 1
+            #         else all_scores[0],  # return more scores when batch_size > 1
+            #         "prediction": all_predictions,
+            #         "context_attns": all_attentions,
+            #     }
+
             return (
                 bucket_scores,
                 bucket_predictions,
@@ -653,9 +673,6 @@ def _process_bucket(bucket_translations):
                 gold_words_total += bucket_gold_words
                 bucket_translations = []
 
-            if self.is_ibmrxn:
-                all_attentions.append(trans.attns[0])  # added phs
-
         if len(bucket_translations) > 0:
             (
                 bucket_scores,
@@ -672,21 +689,6 @@ def _process_bucket(bucket_translations):
             gold_score_total += bucket_gold_score
             gold_words_total += bucket_gold_words
 
-        # phs: added to log gold scores to file
-        if self.target_score_out_file is not None:
-            self.target_score_out_file.write(str(trans.gold_score.item()) + "\n")
-            self.target_score_out_file.flush()
-        #
-
-        if self.is_ibmrxn:  # added phs
-            return {
-                "score": all_scores
-                if batch_size > 1
-                else all_scores[0],  # return more scores when batch_size > 1
-                "prediction": all_predictions,
-                "context_attns": all_attentions,
-            }
-
         end_time = time()
 
         if self.report_score:
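A quick way to confirm the cleanup achieves the commit's goal is to run flake8 over the touched files. A minimal sketch, assuming flake8 is installed and on PATH; the file list simply mirrors this commit:

import subprocess

# Run flake8 on the three files touched by this commit and report the result.
result = subprocess.run(
    [
        "flake8",
        "onmt/bin/train_profile.py",
        "onmt/train_single.py",
        "onmt/translate/translator.py",
    ],
    capture_output=True,
    text=True,
)
print(result.stdout if result.returncode else "flake8: no issues found")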
