diff --git a/scripts/explore_hyperparams.py b/scripts/explore_hyperparams.py index 56c4a51..94f3ece 100644 --- a/scripts/explore_hyperparams.py +++ b/scripts/explore_hyperparams.py @@ -21,37 +21,35 @@ from fishjaw.visualisation import images_3d, training -def _output_parent(mode: str) -> pathlib.Path: +def _output_parent(mode: str, out_dir: pathlib.Path) -> pathlib.Path: """Parent dir for output""" - retval = files.script_out_dir() / "tuning_output" / mode + retval = files.script_out_dir() / out_dir / mode if not retval.is_dir(): retval.mkdir(parents=True) return retval -def _output_dir(n: int, mode: str) -> pathlib.Path: +def _output_dir(n: int, mode: str, out_dir: pathlib.Path) -> pathlib.Path: """ Get the output directory for this run """ - out_dir = _output_parent(mode) / str(n) + out_dir = _output_parent(mode, out_dir) / str(n) if not out_dir.is_dir(): out_dir.mkdir(parents=True) return out_dir -def _n_existing_dirs(mode: str) -> int: +def _n_existing_dirs(mode: str, out_dir: pathlib.Path) -> int: """ How many runs we've done in this mode so far Doesn't check the integrity of the directories, just counts them """ - out_dir = _output_parent(mode) - # The directories are named by number, so we can just count them - return sum(1 for file_ in out_dir.iterdir() if file_.is_dir()) + return sum(1 for file_ in _output_parent(mode, out_dir).iterdir() if file_.is_dir()) def _lr(rng: np.random.Generator, mode: str) -> float: @@ -265,13 +263,15 @@ def step( np.save(out_dir / "val_losses.npy", val_losses) -def main(*, mode: str, n_steps: int, continue_run: bool): +def main(*, mode: str, n_steps: int, continue_run: bool, out_dir: str): """ Set up the configuration and run the training """ + out_dir = pathlib.Path(out_dir) + # Check if we have existing directories - n_existing_dirs = _n_existing_dirs(mode) + n_existing_dirs = _n_existing_dirs(mode, out_dir) if continue_run: if n_existing_dirs == 0: raise ValueError("No existing directories to continue from") @@ -308,18 
+308,18 @@ def main(*, mode: str, n_steps: int, continue_run: bool): ) for i in range(start, start + n_steps): - out_dir = _output_dir(i, mode) + run_dir = _output_dir(i, mode, out_dir) config = _config(rng, mode) # Since the dataloader picks random patches, the training data is slightly different # between runs. Hopefully this doesn't matter though data_config = data.DataConfig(config, train_subjects, val_subjects) - with open(out_dir / "config.yaml", "w", encoding="utf-8") as cfg_file: + with open(run_dir / "config.yaml", "w", encoding="utf-8") as cfg_file: yaml.dump(config, cfg_file) try: - step(config, data_config, out_dir, full_validation_subjects) + step(config, data_config, run_dir, full_validation_subjects) except torch.cuda.OutOfMemoryError as e: print(config) print(e) @@ -338,6 +338,11 @@ def main(*, mode: str, n_steps: int, continue_run: bool): Determines the range of hyperparameters, and which are searched""", ) parser.add_argument("n_steps", type=int, help="Number of models to train") + parser.add_argument( + "out_dir", + type=str, + help="Directory to save outputs in, relative to the script output directory", + ) parser.add_argument( "--continue_run", action="store_true", diff --git a/scripts/plot_hyperparams.py b/scripts/plot_hyperparams.py index e8d55c6..655412d 100644 --- a/scripts/plot_hyperparams.py +++ b/scripts/plot_hyperparams.py @@ -235,10 +235,13 @@ def _plot_scatters(data_dir: pathlib.Path, metric: str) -> plt.Figure: return _plot_scores(runs) -def main(mode: str): +def main(mode: str, out_dir: str): """Choose the granularity of the search to plot""" - input_dir = files.script_out_dir() / "tuning_output" / mode - output_dir = files.script_out_dir() / "tuning_plots" / mode + input_dir = files.script_out_dir() / out_dir / mode + if not input_dir.exists(): + raise FileNotFoundError(f"Directory {input_dir} not found") + + output_dir = files.script_out_dir() / "tuning_plots" / out_dir / mode if not output_dir.exists(): 
output_dir.mkdir(parents=True) @@ -272,6 +275,12 @@ def main(mode: str): choices={"coarse", "med", "fine"}, help="Granularity of the search.", ) + parser.add_argument( + "out_dir", + type=str, + help="Directory to read the tuning outputs from, " + "relative to the script output directory", + ) args = parser.parse_args() @@ -279,9 +288,9 @@ def main(mode: str): if args.mode == "fine": _write_all_metrics_files( sorted( - list((files.script_out_dir() / "tuning_output" / "fine").glob("*")), + list((files.script_out_dir() / args.out_dir / args.mode).glob("*")), key=lambda x: int(x.name), ) ) - main(args.mode) + main(args.mode, args.out_dir)