From c19e8d0b619e6c928ea469ca86236fbae51b9aa2 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 18 Apr 2024 15:24:45 -0400 Subject: [PATCH 1/3] Make help length more equal. Change hdf5 in track_from_X. logging -> Verbose --- dwi_ml/io_utils.py | 36 +++++----- dwi_ml/testing/projects/tt_visu_argparser.py | 4 +- dwi_ml/testing/projects/tt_visu_main.py | 2 +- dwi_ml/testing/utils.py | 52 ++++++++++++-- dwi_ml/testing/visu_loss_utils.py | 4 +- dwi_ml/tracking/io_utils.py | 67 ++++++++++--------- .../dwiml_compute_loss_copy_previous.py | 2 +- scripts_python/dwiml_create_hdf5_dataset.py | 6 +- .../l2t_resume_training_from_checkpoint.py | 10 ++- scripts_python/l2t_track_from_model.py | 25 ++++--- scripts_python/l2t_train_model.py | 10 ++- scripts_python/l2t_visualize_loss.py | 8 +-- .../tt_resume_training_from_checkpoint.py | 10 ++- scripts_python/tt_track_from_model.py | 23 ++++--- scripts_python/tt_train_model.py | 15 ++--- scripts_python/tt_visualize_loss.py | 8 +-- 16 files changed, 159 insertions(+), 123 deletions(-) diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py index 84b89564..0a244352 100644 --- a/dwi_ml/io_utils.py +++ b/dwi_ml/io_utils.py @@ -5,13 +5,13 @@ from scilpy.io.utils import add_processes_arg -def add_logging_arg(p): - p.add_argument( - '--logging', default='WARNING', metavar='level', - choices=['ERROR', 'WARNING', 'INFO', 'DEBUG'], - help="Logging level. Note that, for readability, not all debug logs \n" - "are printed in DEBUG mode, only the main ones. \n" - "Default: WARNING.") +def add_verbose_arg(p): + # Can eventually become scilpy.io.utils.add_verbose_arg + p.add_argument('-v', default="WARNING", const='INFO', nargs='?', + choices=['DEBUG', 'INFO', 'WARNING'], dest='verbose', + help='Produces verbose output depending on ' + 'the provided level. \nDefault level is warning, ' + 'default when using -v is info.') def add_resample_or_compress_arg(p: ArgumentParser): @@ -28,8 +28,8 @@ def add_resample_or_compress_arg(p: ArgumentParser): def add_arg_existing_experiment_path(p: ArgumentParser): p.add_argument('experiment_path', - help='Path to the directory containing the experiment.\n' - '(Should contain a model subdir with a file \n' + help='Path to the directory containing the experiment. ' + '(Should contain a model subdir \nwith a file ' 'parameters.json and a file best_model_state.pkl.)') p.add_argument('--use_latest_epoch', action='store_true', help="If true, use model at latest epoch rather than " @@ -44,11 +44,15 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False, if add_multiprocessing_option: ram_options = g.add_mutually_exclusive_group() # Parallel processing or GPU processing - add_processes_arg(ram_options) + ram_options.add_argument( + '--processes', dest='nbr_processes', metavar='nb', type=int, + default=1, + help='Number of sub-processes to start for parallel processing. ' + 'Default: [%(default)s]') ram_options.add_argument( '--use_gpu', action='store_true', help="If set, use GPU for processing. Cannot be used together " - "with \noption --processes.") + "with option --processes.") else: p.add_argument('--use_gpu', action='store_true', help="If set, use GPU for processing.") @@ -63,14 +67,14 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False, g.add_argument( '--cache_size', type=int, metavar='s', default=1, help="Relevant only if lazy data is used. Size of the cache in " - "terms\n of length of the queue (i.e. number of volumes). 
\n" - "NOTE: Real cache size will actually be larger depending on " - "use;\nthe training, validation and testing sets each have " - "their cache. [1]") + "terms of length of the \nqueue (i.e. number of volumes). " + "NOTE: Real cache size will actually be larger \ndepending " + "on usage; the training, validation and testing sets each " + "have their \ncache. [1]") g.add_argument( '--lazy', action='store_true', help="If set, do not load all the dataset in memory at once. " - "Load \nonly what is needed for a batch.") + "Load only what is needed \nfor a batch.") return g diff --git a/dwi_ml/testing/projects/tt_visu_argparser.py b/dwi_ml/testing/projects/tt_visu_argparser.py index 72b7099d..cde26c57 100644 --- a/dwi_ml/testing/projects/tt_visu_argparser.py +++ b/dwi_ml/testing/projects/tt_visu_argparser.py @@ -54,7 +54,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_reference_arg) from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_logging_arg, add_memory_args) + add_verbose_arg, add_memory_args) from dwi_ml.testing.utils import add_args_testing_subj_hdf5 @@ -172,7 +172,7 @@ def build_argparser_transformer_visu(): help="Batch size in number of streamlines. If not set, " "uses all streamlines \nin one batch.") add_reference_arg(p) - add_logging_arg(p) + add_verbose_arg(p) add_overwrite_arg(p) return p diff --git a/dwi_ml/testing/projects/tt_visu_main.py b/dwi_ml/testing/projects/tt_visu_main.py index c01a5484..40dc4bd6 100644 --- a/dwi_ml/testing/projects/tt_visu_main.py +++ b/dwi_ml/testing/projects/tt_visu_main.py @@ -75,7 +75,7 @@ def tt_visualize_weights_main(args, parser): os.remove(f) sub_logger_level = 'WARNING' - logging.getLogger().setLevel(level=args.logging) + logging.getLogger().setLevel(level=args.verbose) if args.use_gpu: if torch.cuda.is_available(): diff --git a/dwi_ml/testing/utils.py b/dwi_ml/testing/utils.py index 7521e97f..c34d589d 100644 --- a/dwi_ml/testing/utils.py +++ b/dwi_ml/testing/utils.py @@ -1,14 +1,28 @@ # -*- coding: utf-8 -*- +import json +import logging +import os +from argparse import ArgumentParser from typing import List +import torch + from dwi_ml.data.dataset.multi_subject_containers import (MultiSubjectDataset, MultisubjectSubset) -def add_args_testing_subj_hdf5(p, ask_input_group=False, +def add_args_testing_subj_hdf5(p: ArgumentParser, optional_hdf5=False, + ask_input_group=False, ask_streamlines_group=False): - p.add_argument('hdf5_file', - help="Path to the hdf5 file.") + g = p.add_argument_group("Inputs options") + if optional_hdf5: + g.add_argument('--hdf5_file', metavar='file', + help="Path to the hdf5 file. If not given, will use " + "the file from the experiment's \nparameters. 
" + "(in parameters_latest.json)") + else: + p.add_argument('hdf5_file', + help="Path to the hdf5 file.") p.add_argument('subj_id', help="Subject id to use in the hdf5.") if ask_input_group: @@ -17,13 +31,41 @@ def add_args_testing_subj_hdf5(p, ask_input_group=False, if ask_streamlines_group: p.add_argument('streamlines_group', help="Model's streamlines group in the hdf5.") - p.add_argument('--subset', default='testing', + g.add_argument('--subset', default='testing', choices=['training', 'validation', 'testing'], help="Subject id should probably come come the " - "'testing' set but you can \nmodify this to " + "'testing' set but you can modify this \nto " "'training' or 'validation'.") +def find_hdf5_associated_to_experiment(experiment_path): + parameters_json = os.path.join(experiment_path, 'parameters_latest.json') + hdf5_file = None + if os.path.isfile(parameters_json): + with open(parameters_json, 'r') as json_file: + params = json.load(json_file) + if 'hdf5 file' in params: + hdf5_file = params['hdf5 file'] + + if hdf5_file is None: + logging.warning("Did not find the hdf5 file associated to your " + "exeperiment in the parameters file {}.\n" + "Will try to find it in the latest checkpoint." + .format(parameters_json)) + checkpoint_path = os.path.join( + experiment_path, "checkpoint", "checkpoint_state.pkl") + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError( + 'Checkpoint was not found! ({}). Could not find the hdf5 ' + 'associated to your experiment. Please specify it yourself.' + .format(checkpoint_path)) + else: + checkpoint_state = torch.load(checkpoint_path) + hdf5_file = checkpoint_state['dataset_params']['hdf5_file'] + + return hdf5_file + + def prepare_dataset_one_subj( hdf5_file: str, subj_id: str, lazy: bool = False, cache_size: int = 1, subset_name: str = 'testing', volume_groups: List[str] = None, diff --git a/dwi_ml/testing/visu_loss_utils.py b/dwi_ml/testing/visu_loss_utils.py index 0b91f3f7..9fd1bc53 100644 --- a/dwi_ml/testing/visu_loss_utils.py +++ b/dwi_ml/testing/visu_loss_utils.py @@ -6,7 +6,7 @@ assert_inputs_exist, assert_outputs_exist, add_reference_arg) -from dwi_ml.io_utils import add_memory_args, add_logging_arg +from dwi_ml.io_utils import add_memory_args, add_verbose_arg def prepare_args_visu_loss(p: ArgumentParser): @@ -80,7 +80,7 @@ def prepare_args_visu_loss(p: ArgumentParser): "(base on loss).") add_overwrite_arg(p) - add_logging_arg(p) + add_verbose_arg(p) add_reference_arg(p) diff --git a/dwi_ml/tracking/io_utils.py b/dwi_ml/tracking/io_utils.py index 956af16c..3667068d 100644 --- a/dwi_ml/tracking/io_utils.py +++ b/dwi_ml/tracking/io_utils.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging import os +from argparse import ArgumentParser from dipy.io.stateful_tractogram import (Space, Origin, set_sft_logger_level, StatefulTractogram) @@ -16,32 +17,30 @@ from dwi_ml.tracking.tracking_mask import TrackingMask from dwi_ml.tracking.tracker import DWIMLAbstractTracker - ALWAYS_VOX_SPACE = Space.VOX ALWAYS_CORNER = Origin('corner') -def add_tracking_options(p): - +def add_tracking_options(p: ArgumentParser): add_arg_existing_experiment_path(p) - add_args_testing_subj_hdf5(p, ask_input_group=True) + add_args_testing_subj_hdf5(p, optional_hdf5=True, + ask_input_group=True) p.add_argument('out_tractogram', help='Tractogram output file (must be .trk or .tck).') p.add_argument('seeding_mask_group', help="Seeding mask's volume group in the hdf5.") - track_g = p.add_argument_group(' Tracking options') + track_g = 
p.add_argument_group('Tracking options') track_g.add_argument('--algo', choices=['det', 'prob'], default='det', help="Tracking algorithm (det or prob). Must be " - "implemented in the chosen model. [det]") + "implemented in the chosen model. \n[det]") track_g.add_argument('--step_size', type=float, help='Step size in mm. Default: using the step size ' 'saved in the model parameters.') track_g.add_argument('--track_forward_only', action='store_true', help="If set, tracks in one direction only (forward) " - "given the initial \nseed. The direction is " - "randomly drawn from the ODF.") + "given the initial seed.") track_g.add_argument('--mask_interp', default='nearest', choices=['nearest', 'trilinear'], help="Mask interpolation: nearest-neighbor or " @@ -60,48 +59,51 @@ def add_tracking_options(p): metavar='M', help='Maximum length of a streamline in mm. ' '[%(default)s]') - stop_g.add_argument('--tracking_mask_group', + stop_g.add_argument('--tracking_mask_group', metavar='key', help="Tracking mask's volume group in the hdf5.") - stop_g.add_argument('--theta', metavar='t', type=float, - default=90, + stop_g.add_argument('--theta', metavar='t', type=float, default=90, help="Stopping criterion during propagation: " - "tracking \nis stopped when a direction is more " - "than an angle t from \npreceding direction." + "tracking is stopped when a direction is \nmore " + "than an angle theta from preceding direction. " "[%(default)s]") stop_g.add_argument('--eos_stop', metavar='prob', help="Stopping criterion if a EOS value was learned " - "during training. \nCan either be a probability " - "(default 0.5) or the string 'max', which will " - "\nstop the propagation if the EOS class's " - "probability is the class with maximal " - "probability, no mather its value.") + "during training. For all models, \ncan be a " + "probability (default 0.5). For classification " + "models, can also be the \nkeyword 'max', which " + "will stop the propagation if the EOS class is " + "the class \nwith maximal probability, no matter " + "its value.") stop_g.add_argument( '--discard_last_point', action='store_true', - help="If set, discard the last point (once out of the tracking mask)\n" - "of the streamline. Default: append them. This is the default in\n" - "Dipy too. Note that points obtained after an invalid direction\n" - "(based on the propagator's definition of invalid; ex when \n" - "angle is too sharp of sh_threshold not reached) are never added.") + help="If set, discard the last point (once out of the tracking mask) " + "of the \nstreamline. Default: do not discard them; append them. " + "This is the default in \nDipy too. Note that points obtained " + "after an invalid direction (based on the \npropagator's " + "definition of invalid; ex when angle is too sharp or " + "sh_threshold \nis not reached) are never added.") r_g = p.add_argument_group(' Random seeding options') r_g.add_argument('--rng_seed', type=int, help='Initial value for the random number generator. ' '[%(default)s]') - r_g.add_argument('--skip', type=int, default=0, - help="Skip the first N random numbers. \n" - "Useful if you want to create new streamlines to " - "add to \na previously created tractogram with a " - "fixed --rng_seed.\nEx: If tractogram_1 was created " - "with -nt 1,000,000, \nyou can create tractogram_2 " - "with \n--skip 1,000,000.") + r_g.add_argument( + '--skip', type=int, default=0, + help="Skip the first N random numbers. 
Useful if you want to create " + "new streamlines to \nadd to a tractogram previously created " + "with a fixed --rng_seed. Ex: If \ntractogram_1 was created " + "with -nt 1,000,000, you can create tractogram_2 with \n" + "--skip 1,000,000.") # Memory options: m_g = add_memory_args(p, add_lazy_options=True, add_multiprocessing_option=True, add_rng=True) m_g.add_argument('--simultaneous_tracking', type=int, default=1, + metavar='nb', help='Track n streamlines at the same time. Intended for ' - 'GPU usage. Default = 1 (no simultaneous tracking).') + 'GPU usage. Default = 1 \n(no simultaneous ' + 'tracking).') return track_g @@ -144,7 +146,8 @@ def prepare_seed_generator(parser, args, hdf_handle): return seed_generator, nbr_seeds, seed_header, ref -def prepare_tracking_mask(hdf_handle, tracking_mask_group, subj_id, mask_interp): +def prepare_tracking_mask(hdf_handle, tracking_mask_group, subj_id, + mask_interp): """ Prepare the tracking mask as a DataVolume from scilpy's library. Returns also some header information to allow verifications. diff --git a/scripts_python/dwiml_compute_loss_copy_previous.py b/scripts_python/dwiml_compute_loss_copy_previous.py index d0caf322..205ac157 100644 --- a/scripts_python/dwiml_compute_loss_copy_previous.py +++ b/scripts_python/dwiml_compute_loss_copy_previous.py @@ -48,7 +48,7 @@ def prepare_arg_parser(): def main(): p = prepare_arg_parser() args = p.parse_args() - logging.getLogger().setLevel(level=args.logging) + logging.getLogger().setLevel(level=args.verbose) # Checks if args.out_dir is None: diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 3377bc2f..5dd11cf7 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -34,7 +34,7 @@ from dwi_ml.data.hdf5.utils import ( add_hdf5_creation_args, add_streamline_processing_args) from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg +from dwi_ml.io_utils import add_verbose_arg def _initialize_intermediate_subdir(hdf5_file, save_intermediate): @@ -102,7 +102,7 @@ def _parse_args(): add_hdf5_creation_args(p) add_streamline_processing_args(p) add_overwrite_arg(p) - add_logging_arg(p) + add_verbose_arg(p) return p @@ -113,7 +113,7 @@ def main(): args = p.parse_args() # Initialize logger - logging.getLogger().setLevel(level=args.logging) + logging.getLogger().setLevel(level=args.verbose) # Silencing SFT's logger if our logging is in DEBUG mode, because it # typically produces a lot of outputs! 
diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py index 26e6cb2d..754d6a26 100644 --- a/scripts_python/l2t_resume_training_from_checkpoint.py +++ b/scripts_python/l2t_resume_training_from_checkpoint.py @@ -11,7 +11,7 @@ from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg +from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer @@ -26,7 +26,7 @@ def prepare_arg_parser(): formatter_class=argparse.RawTextHelpFormatter) add_args_resuming_experiment(p) - add_logging_arg(p) + add_verbose_arg(p) return p @@ -49,9 +49,7 @@ def init_from_checkpoint(args, checkpoint_path): dataset = prepare_multisubjectdataset(argparse.Namespace(**args_data)) # Setting log level to INFO maximum for sub-loggers, else it become ugly - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' # Load model from checkpoint directory model = Learn2TrackModel.load_model_from_params_and_state( @@ -72,7 +70,7 @@ def init_from_checkpoint(args, checkpoint_path): model, args.experiments_path, args.experiment_name, batch_sampler, batch_loader, checkpoint_state, args.new_patience, args.new_max_epochs, - args.logging) + args.verbose) return trainer diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py index 9bda1254..b290eafe 100644 --- a/scripts_python/l2t_track_from_model.py +++ b/scripts_python/l2t_track_from_model.py @@ -23,9 +23,10 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg +from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.testing.utils import prepare_dataset_one_subj +from dwi_ml.testing.utils import prepare_dataset_one_subj, \ + find_hdf5_associated_to_experiment from dwi_ml.tracking.projects.learn2track_tracker import RecurrentTracker from dwi_ml.tracking.tracking_mask import TrackingMask from dwi_ml.tracking.io_utils import (add_tracking_options, @@ -43,20 +44,19 @@ def build_argparser(): # As in scilpy: add_seeding_options(p) - add_out_options(p) + add_out_options(p) # Formatting a bit ugly compared to us, but ok. 
- add_logging_arg(p) + add_verbose_arg(p) return p def prepare_tracker(parser, args): - hdf_handle = h5py.File(args.hdf5_file, 'r') + hdf5_file = args.hdf5_file or find_hdf5_associated_to_experiment( + args.experiment_path) + hdf_handle = h5py.File(hdf5_file, 'r') - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - # make them info max - sub_logger_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' with Timer("\nLoading data and preparing tracker...", newline=True, color='green'): @@ -89,7 +89,7 @@ def prepare_tracker(parser, args): else: model_dir = os.path.join(args.experiment_path, 'checkpoint/model') model = Learn2TrackModel.load_model_from_params_and_state( - model_dir, log_level=sub_logger_level) + model_dir, log_level=sub_loggers_level) logging.info("* Formatted model: " + format_dict_to_str(model.params_for_checkpoint)) @@ -108,7 +108,7 @@ def prepare_tracker(parser, args): use_gpu=args.use_gpu, eos_stopping_thresh=args.eos_stop, simultaneous_tracking=args.simultaneous_tracking, append_last_point=append_last_point, - log_level=args.logging.upper()) + log_level=args.verbose()) return tracker, ref @@ -116,11 +116,10 @@ def prepare_tracker(parser, args): def main(): parser = build_argparser() args = parser.parse_args() - torch.cuda.empty_cache() # Setting root logger to high level to max info, not debug, prints way too # much stuff. (but we can set our tracker's logger to debug) - root_level = args.logging.upper() + root_level = args.verbose() if root_level == 'DEBUG': root_level = 'INFO' logging.getLogger().setLevel(level=root_level) diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index ff4d16d9..126bd9c3 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -19,7 +19,7 @@ from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.io_utils import add_verbose_arg, add_memory_args from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.models.projects.learn2track_utils import add_model_args from dwi_ml.models.utils.direction_getters import check_args_direction_getter @@ -42,7 +42,7 @@ def prepare_arg_parser(): add_args_batch_loader(p) training_group = add_training_args(p, add_a_tracking_validation_phase=True) add_memory_args(p, add_lazy_options=True, add_rng=True) - add_logging_arg(p) + add_verbose_arg(p) # Additional arg for projects training_group.add_argument( @@ -132,7 +132,7 @@ def init_from_args(args, sub_loggers_level): tracking_phase_mask_group=args.tracking_mask, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=args.logging) + log_level=args.verbose) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) @@ -145,9 +145,7 @@ def main(): # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. 
- sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' # General logging (ex, scilpy: Warning) logging.getLogger().setLevel(level=logging.WARNING) diff --git a/scripts_python/l2t_visualize_loss.py b/scripts_python/l2t_visualize_loss.py index 8a7980b1..3d6a720a 100644 --- a/scripts_python/l2t_visualize_loss.py +++ b/scripts_python/l2t_visualize_loss.py @@ -38,10 +38,8 @@ def main(): names = visu_checks(args, p) # Loggers - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - sub_logger_level = 'INFO' - logging.getLogger().setLevel(level=args.logging) + logging.getLogger().setLevel(level=args.verbose) + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' # Device device = (torch.device('cuda') if torch.cuda.is_available() and @@ -54,7 +52,7 @@ def main(): else: model_dir = os.path.join(args.experiment_path, 'checkpoint/model') model = Learn2TrackModel.load_model_from_params_and_state( - model_dir, log_level=sub_logger_level) + model_dir, log_level=sub_loggers_level) model.set_context('visu') # 2. Load data through the tester diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py index db20e605..b78e5926 100644 --- a/scripts_python/tt_resume_training_from_checkpoint.py +++ b/scripts_python/tt_resume_training_from_checkpoint.py @@ -11,7 +11,7 @@ from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path +from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.projects.transformer_trainer import TransformerTrainer @@ -25,7 +25,7 @@ def prepare_arg_parser(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) add_args_resuming_experiment(p) - add_logging_arg(p) + add_verbose_arg(p) return p @@ -48,9 +48,7 @@ def init_from_checkpoint(args, checkpoint_path): dataset = prepare_multisubjectdataset(argparse.Namespace(**args_data)) # Setting log level to INFO maximum for sub-loggers, else it become ugly - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' # Prepare model model_dir = os.path.join(checkpoint_path, 'model') @@ -74,7 +72,7 @@ def init_from_checkpoint(args, checkpoint_path): model, args.experiments_path, args.experiment_name, batch_sampler, batch_loader, checkpoint_state, args.new_patience, args.new_max_epochs, - args.logging) + args.verbose) return trainer diff --git a/scripts_python/tt_track_from_model.py b/scripts_python/tt_track_from_model.py index 5876905b..59724284 100644 --- a/scripts_python/tt_track_from_model.py +++ b/scripts_python/tt_track_from_model.py @@ -23,9 +23,10 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, verify_which_model_in_path +from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class -from dwi_ml.testing.utils import prepare_dataset_one_subj +from dwi_ml.testing.utils import 
prepare_dataset_one_subj, \ + find_hdf5_associated_to_experiment from dwi_ml.tracking.projects.transformer_tracker import \ TransformerTracker from dwi_ml.tracking.tracking_mask import TrackingMask @@ -47,18 +48,17 @@ def build_argparser(): add_seeding_options(p) add_out_options(p) - add_logging_arg(p) + add_verbose_arg(p) return p def prepare_tracker(parser, args): - hdf_handle = h5py.File(args.hdf5_file, 'r') + hdf5_file = args.hdf5_file or find_hdf5_associated_to_experiment( + args.experiment_path) + hdf_handle = h5py.File(hdf5_file, 'r') - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - # make them info max - sub_logger_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' with Timer("\nLoading data and preparing tracker...", newline=True, color='green'): @@ -93,7 +93,8 @@ def prepare_tracker(parser, args): model_type = verify_which_model_in_path(model_dir) print("Model's class: {}".format(model_type)) cls = find_transformer_class(model_type) - model = cls.load_model_from_params_and_state(model_dir, sub_logger_level) + model = cls.load_model_from_params_and_state(model_dir, + sub_loggers_level) logging.info("* Formatted model: " + format_dict_to_str(model.params_for_checkpoint)) @@ -112,7 +113,7 @@ def prepare_tracker(parser, args): use_gpu=args.use_gpu, eos_stopping_thresh=args.eos_stop, simultaneous_tracking=args.simultaneous_tracking, append_last_point=append_last_point, - log_level=args.logging) + log_level=args.verbose) return tracker, ref @@ -123,7 +124,7 @@ def main(): # Setting root logger to high level to max info, not debug, prints way too # much stuff. (but we can set our tracker's logger to debug) - root_level = args.logging + root_level = args.verbose if root_level == 'DEBUG': root_level = 'INFO' logging.getLogger().setLevel(level=root_level) diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py index 912645fa..47f7f653 100755 --- a/scripts_python/tt_train_model.py +++ b/scripts_python/tt_train_model.py @@ -19,7 +19,7 @@ from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_memory_args, add_logging_arg +from dwi_ml.io_utils import add_memory_args, add_verbose_arg from dwi_ml.models.projects.transformer_models import \ OriginalTransformerModel, TransformerSrcAndTgtModel, TransformerSrcOnlyModel from dwi_ml.models.projects.transformers_utils import ( @@ -40,7 +40,7 @@ def prepare_arg_parser(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) add_mandatory_args_experiment_and_hdf5_path(p) - add_logging_arg(p) + add_verbose_arg(p) add_args_batch_sampler(p) add_args_batch_loader(p) add_training_args(p, add_a_tracking_validation_phase=True) @@ -63,7 +63,8 @@ def init_from_args(args, sub_loggers_level): cls = TransformerSrcAndTgtModel if args.target_embedded_size is None and \ args.target_embedding_key != 'no_embedding': - raise ValueError("target_embedded_size must be given for this model.") + raise ValueError("target_embedded_size must be given for this " + "model.") specific_args['target_embedded_size'] = args.target_embedded_size else: # Model TTO: input_embedding = target_embedding = d_model @@ -150,7 +151,7 @@ def init_from_args(args, sub_loggers_level): tracking_phase_mask_group=args.tracking_mask, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - 
log_level=args.logging) + log_level=args.verbose) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) @@ -163,12 +164,8 @@ def main(): # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' - - # General logging (ex, scilpy: Warning) logging.getLogger().setLevel(level=logging.WARNING) + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' # Check that all files exist assert_inputs_exist(p, [args.hdf5_file]) diff --git a/scripts_python/tt_visualize_loss.py b/scripts_python/tt_visualize_loss.py index 111d6f60..6708a98a 100644 --- a/scripts_python/tt_visualize_loss.py +++ b/scripts_python/tt_visualize_loss.py @@ -41,10 +41,8 @@ def main(): names = visu_checks(args, p) # Loggers - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - sub_logger_level = 'INFO' - logging.getLogger().setLevel(level=args.logging) + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + logging.getLogger().setLevel(level=args.verbose) # Device device = (torch.device('cuda') if torch.cuda.is_available() and @@ -58,7 +56,7 @@ def main(): model_dir = os.path.join(args.experiment_path, 'checkpoint/model') model_type = verify_which_model_in_path(model_dir) cls = find_transformer_class(model_type) - model = cls.load_model_from_params_and_state(model_dir, sub_logger_level) + model = cls.load_model_from_params_and_state(model_dir, sub_loggers_level) model.set_context('visu') # 2. Load data through the tester From adf28c8b1ad2996bf39708ec63fd6052d7f5f772 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 18 Apr 2024 15:43:37 -0400 Subject: [PATCH 2/3] Fix tests --- dwi_ml/testing/utils.py | 5 ++--- scripts_python/l2t_track_from_model.py | 11 ++++------- scripts_python/tests/test_all_steps_l2t.py | 11 ++++++----- scripts_python/tests/test_all_steps_tto.py | 11 ++++++----- scripts_python/tests/test_all_steps_tts.py | 9 +++++---- scripts_python/tests/test_all_steps_ttst.py | 9 +++++---- scripts_python/tt_track_from_model.py | 4 ++-- 7 files changed, 30 insertions(+), 30 deletions(-) diff --git a/dwi_ml/testing/utils.py b/dwi_ml/testing/utils.py index c34d589d..a718817b 100644 --- a/dwi_ml/testing/utils.py +++ b/dwi_ml/testing/utils.py @@ -59,9 +59,8 @@ def find_hdf5_associated_to_experiment(experiment_path): 'Checkpoint was not found! ({}). Could not find the hdf5 ' 'associated to your experiment. Please specify it yourself.' 
.format(checkpoint_path)) - else: - checkpoint_state = torch.load(checkpoint_path) - hdf5_file = checkpoint_state['dataset_params']['hdf5_file'] + checkpoint_state = torch.load(checkpoint_path) + hdf5_file = checkpoint_state['dataset_params']['hdf5_file'] return hdf5_file diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py index b290eafe..b2021cd6 100644 --- a/scripts_python/l2t_track_from_model.py +++ b/scripts_python/l2t_track_from_model.py @@ -79,7 +79,7 @@ def prepare_tracker(parser, args): logging.info("Loading subject's data.") subset = prepare_dataset_one_subj( - args.hdf5_file, args.subj_id, lazy=False, + hdf5_file, args.subj_id, lazy=False, cache_size=args.cache_size, subset_name=args.subset, volume_groups=[args.input_group], streamline_groups=[]) @@ -108,7 +108,7 @@ def prepare_tracker(parser, args): use_gpu=args.use_gpu, eos_stopping_thresh=args.eos_stop, simultaneous_tracking=args.simultaneous_tracking, append_last_point=append_last_point, - log_level=args.verbose()) + log_level=args.verbose) return tracker, ref @@ -119,17 +119,14 @@ def main(): # Setting root logger to high level to max info, not debug, prints way too # much stuff. (but we can set our tracker's logger to debug) - root_level = args.verbose() - if root_level == 'DEBUG': - root_level = 'INFO' - logging.getLogger().setLevel(level=root_level) + logging.getLogger().setLevel(level=args.verbose) # ----- Checks if not nib.streamlines.is_supported(args.out_tractogram): parser.error('Invalid output streamline file format (must be trk or ' 'tck): {0}'.format(args.out_tractogram)) - assert_inputs_exist(parser, args.hdf5_file) + assert_inputs_exist(parser, [], args.hdf5_file) assert_outputs_exist(parser, args, args.out_tractogram) verify_streamline_length_options(parser, args) diff --git a/scripts_python/tests/test_all_steps_l2t.py b/scripts_python/tests/test_all_steps_l2t.py index 78449992..5589dd32 100644 --- a/scripts_python/tests/test_all_steps_l2t.py +++ b/scripts_python/tests/test_all_steps_l2t.py @@ -51,7 +51,7 @@ def test_training(script_runner, experiments_path): '--batch_size_units', 'nb_streamlines', '--max_batches_per_epoch_training', '2', '--max_batches_per_epoch_validation', '1', - '--logging', 'INFO', '--step_size', '0.5', + '-v', 'INFO', '--step_size', '0.5', '--nb_previous_dirs', '1') assert ret.success @@ -69,7 +69,7 @@ def test_training(script_runner, experiments_path): '--max_batches_per_epoch_training', '2', '--max_batches_per_epoch_validation', '1', '--dg_key', 'cosine-regression', '--add_eos', - '--logging', 'INFO', '--use_gpu') + '-v', 'INFO', '--use_gpu') assert ret.success @@ -93,8 +93,9 @@ def test_tracking(script_runner, experiments_path): # Testing HDF5 data does not contain a testing set to keep it light. Using # subjectX from training set. 
ret = script_runner.run( - 'l2t_track_from_model.py', whole_experiment_path, hdf5_file, subj_id, + 'l2t_track_from_model.py', whole_experiment_path, subj_id, input_group_name, out_tractogram, seeding_mask_group, + '--hdf5_file', hdf5_file, '--algo', 'det', '--nt', '2', '--rng_seed', '0', '--min_length', '0', '--subset', 'training', '--tracking_mask_group', tracking_mask_group) @@ -105,7 +106,7 @@ def test_tracking(script_runner, experiments_path): logging.info("********** TESTING GPU TRACKING FROM MODEL ************") out_tractogram = os.path.join(experiments_path, 'test_tractogram2.trk') ret = script_runner.run( - 'l2t_track_from_model.py', whole_experiment_path, hdf5_file, + 'l2t_track_from_model.py', whole_experiment_path, subj_id, input_group_name, out_tractogram, seeding_mask_group, '--algo', 'det', '--nt', '20', '--rng_seed', '0', '--min_length', '0', '--subset', 'training', @@ -155,7 +156,7 @@ def future_test_training_with_generation_validation(script_runner, experiments_p '--batch_size_units', 'nb_streamlines', '--max_batches_per_epoch_training', '2', '--max_batches_per_epoch_validation', '1', - '--logging', 'INFO', '--step_size', '0.5', + '-v', 'INFO', '--step_size', '0.5', '--add_a_tracking_validation_phase', '--tracking_phase_frequency', '1', option) assert ret.success diff --git a/scripts_python/tests/test_all_steps_tto.py b/scripts_python/tests/test_all_steps_tto.py index d92a18a0..3ecd8ec5 100644 --- a/scripts_python/tests/test_all_steps_tto.py +++ b/scripts_python/tests/test_all_steps_tto.py @@ -62,7 +62,7 @@ def test_execution(script_runner, experiments_path): '--input_embedding_key', 'nn_embedding', '--input_embedded_size', '6', '--n_layers_e', '1', '--ffnn_hidden_size', '3', '--step_size', '1', - '--logging', 'INFO') + '-v', 'INFO') assert ret.success logging.info("************ TESTING RESUMING FROM CHECKPOINT ************") @@ -87,7 +87,7 @@ def test_execution(script_runner, experiments_path): '--input_embedding_key', 'nn_embedding', '--input_embedded_size', '6', '--n_layers_e', '1', '--ffnn_hidden_size', '3', - '--logging', 'INFO', '--use_gpu') + '-v', 'INFO', '--use_gpu') assert ret.success logging.info("************ TESTING TRACKING FROM MODEL ************") @@ -100,10 +100,11 @@ def test_execution(script_runner, experiments_path): subj_id = TEST_EXPECTED_SUBJ_NAMES[0] ret = script_runner.run( - 'tt_track_from_model.py', whole_experiment_path, hdf5_file, subj_id, + 'tt_track_from_model.py', whole_experiment_path, subj_id, input_group, out_tractogram, seeding_mask_group, + '--hdf5_file', hdf5_file, '--algo', 'det', '--nt', '2', '--rng_seed', '0', - '--min_length', '0', '--subset', 'training', '--logging', 'DEBUG', + '--min_length', '0', '--subset', 'training', '-v', 'DEBUG', '--max_length', str(MAX_LEN * 0.5), '--step', '0.5', '--tracking_mask_group', tracking_mask_group) @@ -133,7 +134,7 @@ def test_execution(script_runner, experiments_path): input_group, in_sft, '--out_prefix', prefix, '--as_matrices', '--color_multi_length', '--color_x_y_summary', '--bertviz_locally', - '--subset', 'training', '--logging', 'INFO', + '--subset', 'training', '-v', 'INFO', '--resample_plots', '15', '--rescale_non_lin') assert ret.success diff --git a/scripts_python/tests/test_all_steps_tts.py b/scripts_python/tests/test_all_steps_tts.py index fe13d6ba..d3adb2a8 100644 --- a/scripts_python/tests/test_all_steps_tts.py +++ b/scripts_python/tests/test_all_steps_tts.py @@ -50,7 +50,7 @@ def test_execution(script_runner, experiments_path): '--input_embedding_key', 'nn_embedding', 
'--input_embedded_size', '6', '--n_layers_e', '1', - '--ffnn_hidden_size', '3', '--logging', 'INFO') + '--ffnn_hidden_size', '3', '-v', 'INFO') assert ret.success logging.info("************ TESTING RESUMING FROM CHECKPOINT ************") @@ -69,10 +69,11 @@ def test_execution(script_runner, experiments_path): subj_id = TEST_EXPECTED_SUBJ_NAMES[0] ret = script_runner.run( - 'tt_track_from_model.py', whole_experiment_path, hdf5_file, subj_id, + 'tt_track_from_model.py', whole_experiment_path, subj_id, input_group, out_tractogram, seeding_mask_group, + '--hdf5_file', hdf5_file, '--algo', 'det', '--nt', '2', '--rng_seed', '0', - '--min_length', '0', '--subset', 'training', '--logging', 'DEBUG', + '--min_length', '0', '--subset', 'training', '-v', 'DEBUG', '--max_length', str(MAX_LEN * 0.5), '--step', '0.5', '--tracking_mask_group', tracking_mask_group) @@ -102,6 +103,6 @@ def test_execution(script_runner, experiments_path): input_group, in_sft, '--out_prefix', prefix, '--as_matrices', '--color_multi_length', '--color_x_y_summary', '--bertviz_locally', - '--subset', 'training', '--logging', 'INFO', + '--subset', 'training', '-v', 'INFO', '--resample_plots', '15', '--rescale_0') assert ret.success diff --git a/scripts_python/tests/test_all_steps_ttst.py b/scripts_python/tests/test_all_steps_ttst.py index 85403c9c..4d98d5cc 100644 --- a/scripts_python/tests/test_all_steps_ttst.py +++ b/scripts_python/tests/test_all_steps_ttst.py @@ -51,7 +51,7 @@ def test_execution(script_runner, experiments_path): '--input_embedded_size', '6', '--target_embedded_size', '2', '--n_layers_e', '1', - '--ffnn_hidden_size', '3', '--logging', 'INFO') + '--ffnn_hidden_size', '3', '-v', 'INFO') assert ret.success logging.info("************ TESTING RESUMING FROM CHECKPOINT ************") @@ -70,10 +70,11 @@ def test_execution(script_runner, experiments_path): subj_id = TEST_EXPECTED_SUBJ_NAMES[0] ret = script_runner.run( - 'tt_track_from_model.py', whole_experiment_path, hdf5_file, subj_id, + 'tt_track_from_model.py', whole_experiment_path, subj_id, input_group, out_tractogram, seeding_mask_group, + '--hdf5_file', hdf5_file, '--algo', 'det', '--nt', '2', '--rng_seed', '0', - '--min_length', '0', '--subset', 'training', '--logging', 'DEBUG', + '--min_length', '0', '--subset', 'training', '-v', 'DEBUG', '--max_length', str(MAX_LEN * 0.5), '--step', '0.5', '--tracking_mask_group', tracking_mask_group) @@ -103,6 +104,6 @@ def test_execution(script_runner, experiments_path): input_group, in_sft, '--out_prefix', prefix, '--as_matrices', '--color_multi_length', '--color_x_y_summary', '--bertviz_locally', - '--subset', 'training', '--logging', 'INFO', + '--subset', 'training', '-v', 'INFO', '--resample_plots', '15', '--rescale_z') assert ret.success diff --git a/scripts_python/tt_track_from_model.py b/scripts_python/tt_track_from_model.py index 59724284..296d6e68 100644 --- a/scripts_python/tt_track_from_model.py +++ b/scripts_python/tt_track_from_model.py @@ -81,7 +81,7 @@ def prepare_tracker(parser, args): logging.info("Loading subject's data.") subset = prepare_dataset_one_subj( - args.hdf5_file, args.subj_id, lazy=False, + hdf5_file, args.subj_id, lazy=False, cache_size=args.cache_size, subset_name=args.subset, volume_groups=[args.input_group], streamline_groups=[]) @@ -134,7 +134,7 @@ def main(): parser.error('Invalid output streamline file format (must be trk or ' 'tck): {0}'.format(args.out_tractogram)) - assert_inputs_exist(parser, args.hdf5_file) + assert_inputs_exist(parser, [], args.hdf5_file) 
assert_outputs_exist(parser, args, args.out_tractogram) verify_streamline_length_options(parser, args) From 718e1c5ce91488e0da0a8421436f8365d31ed109 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 18 Apr 2024 16:17:25 -0400 Subject: [PATCH 3/3] Improve aspect of _train_ scripts. To be continued --- dwi_ml/training/utils/batch_loaders.py | 29 +++++++------- dwi_ml/training/utils/batch_samplers.py | 29 +++++++------- dwi_ml/training/utils/experiment.py | 2 +- dwi_ml/training/utils/trainer.py | 50 +++++++++++++------------ 4 files changed, 57 insertions(+), 53 deletions(-) diff --git a/dwi_ml/training/utils/batch_loaders.py b/dwi_ml/training/utils/batch_loaders.py index af4694dd..ff8b4166 100644 --- a/dwi_ml/training/utils/batch_loaders.py +++ b/dwi_ml/training/utils/batch_loaders.py @@ -13,27 +13,26 @@ def add_args_batch_loader(p: argparse.ArgumentParser): bl_g = p.add_argument_group("Batch loader") bl_g.add_argument( '--noise_gaussian_size_forward', type=float, metavar='s', default=0., - help="If set, add random Gaussian noise to streamline coordinates \n" - "with given variance. Noise is added AFTER interpolation of " - "underlying data. \nExample of use: when concatenating previous " - "direction to input.\n" - "This corresponds to the std of the Gaussian. [0]\n" - "**Make sure noise is smaller than your step size " - "to avoid \nflipping direction! (We can't verify if --step_size " - "is not \nspecified here, but if it is, we limit noise to \n" - "+/- 0.5 * step-size.).\n" - "** We also limit noise to +/- 2 * noise_gaussian_size.\n" - "Suggestion: 0.1 * step-size.") + help="If set, we will add random Gaussian noise to the streamline " + "coordinates. Noise \nis added AFTER interpolation of " + "the DWI inputs, so this is only useful if your \nforward method " + "uses the streamlines; when they also serve as inputs. See also\n" + "noise_gaussian_size_loss for an alternate option. The value " + "corresponds to the \nstd of the Gaussian. We limit noise to " + "+/- 2 * noise_gaussian_size. Suggestion: \n0.1 * step-size.\n" + "**Make sure that this noise is smaller than your step size, " + "to avoid flipping \ndirection! (If --step_size is set, we limit " + "noise to +/- 0.5 * step-size). ") bl_g.add_argument( '--noise_gaussian_size_loss', type=float, metavar='s', default=0., help='Idem, but loss is added to targets instead (during training ' 'only).') bl_g.add_argument( '--split_ratio', type=float, metavar='r', default=0., - help="Percentage of streamlines to randomly split into 2, in each \n" - "batch (keeping both segments as two independent streamlines). \n" - "The reason for cutting is to help the ML algorithm to track " - "from \nthe middle of WM by having already seen half-streamlines." + help="Percentage of streamlines to randomly split into 2, in each " + "batch (keeping both \nsegments as two independent streamlines). " + "The reason for cutting is to help \ntracking from the middle of " + "white matter by having already seen half-streamlines." "\nIf you are using interface seeding, this is not necessary. " "[0]") bl_g.add_argument( diff --git a/dwi_ml/training/utils/batch_samplers.py b/dwi_ml/training/utils/batch_samplers.py index 3ac130fd..fc60b94f 100644 --- a/dwi_ml/training/utils/batch_samplers.py +++ b/dwi_ml/training/utils/batch_samplers.py @@ -11,10 +11,10 @@ def add_args_batch_sampler(p: argparse.ArgumentParser): g_batch_size.add_argument( '--batch_size_training', type=int, default=100, metavar='s', - help="Batch size. 
Unit must be spectified (through batch_size_units)." - "\nThe total size your computer will accept depends on the " - "type of \ninput data. You will need to test this value. [100]\n" - "Suggestion: in nb_streamlines: 100. In length_mm: 10000. \n") + help="Batch size. Unit must be spectified (through batch_size_units). " + "The total size \nyour computer will accept depends on the " + "type of input data. You will need to \ntest this value. " + "Suggestion: in nb_streamlines: 100. In length_mm: 10,000. [100]") g_batch_size.add_argument( '--batch_size_validation', type=int, default=100, metavar='s', help="Idem; batch size during validation.") @@ -22,24 +22,25 @@ def add_args_batch_sampler(p: argparse.ArgumentParser): '--batch_size_units', type=str, metavar='u', default='nb_streamlines', choices={'nb_streamlines', 'length_mm'}, help="One of 'nb_streamlines' or 'length_mm' (which should hopefully " - "\nbe correlated to the number of input data points).") + "be correlated \nto the number of input data points).") g_batch_size.add_argument( '--nb_streamlines_per_chunk', type=int, default=None, metavar='n', - help="Only used with batch_size_units='length_mm'. \nChunks of " - "n streamlines are sampled at once, their size is \nchecked, " - "and then number of streamlines is ajusted until below \n" + help="Only used with batch_size_units='length_mm'. Chunks of " + "n streamlines are sampled \nat once, their size is checked, " + "and then number of streamlines is ajusted until \nbelow " "batch_size.") g_batch_size.add_argument( '--nb_subjects_per_batch', type=int, metavar='n', - help="Maximum number of different subjects from which to load data \n" - "in each batch. This should help avoid loading too many inputs \n" - "in memory, particularly for lazy data. If not set, we will " - "use \ntrue random sampling. Suggestion, 5. \n" + help="Maximum number of different subjects from which to load data " + "in each batch. This \nshould help avoid loading too many inputs " + "in memory, particularly for lazy data. \nIf not set, we will " + "use true random sampling. Suggestion, 5.\n" "**Note: Will influence the cache if the cache_manager is used.") g_batch_size.add_argument( '--cycles', type=int, metavar='c', - help="Relevant only if nb_subject_per_batch is set. Number of cycles\n" - "before changing to new subjects (and thus loading new volumes).") + help="Relevant only if nb_subject_per_batch is set. Number of cycles " + "before changing \nto new subjects (and thus loading new " + "volumes).") def prepare_batch_sampler(dataset, args, sub_loggers_level): diff --git a/dwi_ml/training/utils/experiment.py b/dwi_ml/training/utils/experiment.py index bb3be4a5..a138bee5 100644 --- a/dwi_ml/training/utils/experiment.py +++ b/dwi_ml/training/utils/experiment.py @@ -4,7 +4,7 @@ def add_mandatory_args_experiment_and_hdf5_path(p): p.add_argument( 'experiments_path', - help='Path where to save your experiment. \nComplete path will be ' + help='Path where to save your experiment. Complete path will be \n' 'experiments_path/experiment_name.') p.add_argument( 'experiment_name', diff --git a/dwi_ml/training/utils/trainer.py b/dwi_ml/training/utils/trainer.py index 8b90a70c..b9c9ce89 100644 --- a/dwi_ml/training/utils/trainer.py +++ b/dwi_ml/training/utils/trainer.py @@ -14,35 +14,34 @@ def add_training_args(p: argparse.ArgumentParser, training_group.add_argument( '--learning_rate', metavar='r', nargs='+', help="Learning rate. Can be set as a single float, or as a list of " - "[lr*step]. 
\n" - "Ex: '--learning_rate 0.001*3 0.0001' would set the lr to 0.001 " - "for the first \n3 epochs, and 0.0001 for the remaining epochs.\n" + "[lr*step]. For instance, \n" + "--learning_rate 0.001*3 0.0001 would set the lr to 0.001 " + "for the 3 first epochs, \nand 0.0001 for the remaining epochs." "(torch's default = 0.001)") training_group.add_argument( '--weight_decay', type=float, default=0.01, metavar='v', - help="Add a weight decay penalty on the parameters (regularization " - "parameter)\n[0.01] (torch's default).") + help="Add a weight decay penalty on the parameters (regularization ) " + "(torch's default: \n0.01).") training_group.add_argument( '--optimizer', choices=['Adam', 'RAdam', 'SGD'], default='Adam', - help="Choice of torch optimizer amongst ['Adam', 'RAdam', 'SGD'].\n" + help="Choice of torch optimizer amongst ['Adam', 'RAdam', 'SGD']. " "Default: Adam.") training_group.add_argument( '--max_epochs', type=int, default=100, metavar='n', help="Maximum number of epochs. [100]") training_group.add_argument( - '--patience', type=int, default=20, metavar='n', - help="Use early stopping. Defines the number of epochs after which \n" - "the model should stop if the loss hasn't improved. \n" - "Default: same as max_epochs.") + '--patience', type=int, metavar='n', + help="If set, uses early stopping. Defines the number of epochs after " + "which the model \nshould stop if the loss hasn't improved.") training_group.add_argument( '--patience_delta', type=float, default=1e-6, metavar='eps', help="Limit difference between two validation losses to consider that " - "\nthe model improved between the two epochs.") + "the model has \nimproved between two epochs. [1e-6]") training_group.add_argument( '--max_batches_per_epoch_training', type=int, default=1000, metavar='n', - help="Maximum number of batches per epoch. This will help avoid long\n" - "epochs, to ensure that we save checkpoints regularly. [1000]") + help="Maximum number of batches per epoch. This will help avoid long " + "epochs, to ensure \nthat we save checkpoints regularly. [1000]") training_group.add_argument( '--max_batches_per_epoch_validation', type=int, default=1000, metavar='n', @@ -50,28 +49,33 @@ def add_training_args(p: argparse.ArgumentParser, if add_a_tracking_validation_phase: training_group.add_argument( - '--add_a_tracking_validation_phase', action='store_true') + '--add_a_tracking_validation_phase', action='store_true', + help="If set, a generation validation phase (GV) will be added.") training_group.add_argument( - '--tracking_phase_frequency', type=int, default=5) + '--tracking_phase_frequency', type=int, default=1, metavar='N', + help="The GV phase can be computed at every epoch (default), or " + "once every N epochs.") training_group.add_argument( '--tracking_mask', help="Volume group to use as tracking mask during the generation " "phase.") training_group.add_argument( - '--tracking_phase_nb_segments_init', type=int, default=5, - help="Number of segments copied from the 'real' streamlines " - "before starting propagation during generation phases.") + '--tracking_phase_nb_segments_init', type=int, default=1, + metavar='N', + help="Number of segments copied from the 'real' validation " + "streamlines before starting \npropagation during GV phases " + "[1].") comet_g = p.add_argument_group("Comet") comet_g.add_argument( '--comet_workspace', metavar='w', - help='Your comet workspace. 
If not set, comet.ml will not be used.\n' - 'See our docs/Getting Started for more information on comet \n' - 'and its API key.') + help='Your comet workspace. If not set, comet.ml will not be used. ' + 'See our doc for more \ninformation on comet and its API key: \n' + 'https://dwi-ml.readthedocs.io/en/latest/getting_started.html') comet_g.add_argument( '--comet_project', metavar='p', - help='Send your experiment to a specific comet.ml project. If not \n' - 'set, it will be sent to Uncategorized Experiments.') + help='Send your experiment to a specific comet.ml project. If not ' + 'set, it will be sent \nto Uncategorized Experiments.') return training_group
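
Editor's note, appended after the patch series: the sketch below is an illustrative summary (not part of any commit) of the verbosity pattern these patches converge on. add_verbose_arg() mirrors the helper added to dwi_ml/io_utils.py in PATCH 1/3; the __main__ block and its logging calls are hypothetical usage written for this note only, not code from the repository.

import argparse
import logging


def add_verbose_arg(p):
    # Same pattern as the new dwi_ml.io_utils.add_verbose_arg:
    # no "-v" -> WARNING, bare "-v" -> INFO, "-v DEBUG" -> DEBUG.
    p.add_argument('-v', default="WARNING", const='INFO', nargs='?',
                   choices=['DEBUG', 'INFO', 'WARNING'], dest='verbose',
                   help='Produces verbose output depending on the provided '
                        'level. Default level is warning, default when '
                        'using -v is info.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    add_verbose_arg(parser)
    args = parser.parse_args()

    # Root logger follows the user's choice; sub-loggers are capped at INFO,
    # as done in the l2t_* and tt_* scripts touched by these patches.
    logging.getLogger().setLevel(args.verbose)
    sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO'
    logging.info("Root logger level: %s; sub-logger level: %s",
                 args.verbose, sub_loggers_level)

Usage: invoking such a script with "-v" yields INFO output and with "-v DEBUG" yields full debug logs, matching the help text introduced in io_utils.py.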