Feature/specify tuning dir (#28)
Add a mandatory output-directory argument when running the hyperparameter tuning and when making the plots
---------

Co-authored-by: Richard Lane <[email protected]>
richard-lane and Richard Lane authored Nov 13, 2024
1 parent 8f6ac33 commit 3f3263c
Showing 2 changed files with 32 additions and 18 deletions.
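
In both scripts, the previously hard-coded "tuning_output" directory is replaced by a caller-supplied out_dir resolved relative to the script output directory. The sketch below is illustrative only, not project code: script_out_dir() is a stand-in for files.script_out_dir() from the fishjaw package, and the "script_output", "tuning_output" and "coarse" names are example values.

    import pathlib


    def script_out_dir() -> pathlib.Path:
        """Stand-in for files.script_out_dir(); assumed to return the base output directory."""
        return pathlib.Path("script_output")


    # explore_hyperparams.py: tuning runs now land under a user-chosen directory
    out_dir = pathlib.Path("tuning_output")  # previously hard-coded, now a CLI argument
    mode = "coarse"
    tuning_parent = script_out_dir() / out_dir / mode   # e.g. script_output/tuning_output/coarse
    run_dir = tuning_parent / "0"                        # one numbered directory per run

    # plot_hyperparams.py: plots mirror the same out_dir under "tuning_plots"
    input_dir = script_out_dir() / out_dir / mode
    plot_dir = script_out_dir() / "tuning_plots" / out_dir / mode
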
31 changes: 18 additions & 13 deletions scripts/explore_hyperparams.py
@@ -21,37 +21,35 @@
 from fishjaw.visualisation import images_3d, training


-def _output_parent(mode: str) -> pathlib.Path:
+def _output_parent(mode: str, out_dir: pathlib.Path) -> pathlib.Path:
     """Parent dir for output"""
-    retval = files.script_out_dir() / "tuning_output" / mode
+    retval = files.script_out_dir() / out_dir / mode
     if not retval.is_dir():
         retval.mkdir(parents=True)
     return retval


-def _output_dir(n: int, mode: str) -> pathlib.Path:
+def _output_dir(n: int, mode: str, out_dir: pathlib.Path) -> pathlib.Path:
     """
     Get the output directory for this run
     """
-    out_dir = _output_parent(mode) / str(n)
+    out_dir = _output_parent(mode, out_dir) / str(n)
     if not out_dir.is_dir():
         out_dir.mkdir(parents=True)

     return out_dir


-def _n_existing_dirs(mode: str) -> int:
+def _n_existing_dirs(mode: str, out_dir: pathlib.Path) -> int:
     """
     How many runs we've done in this mode so far
     Doesn't check the integrity of the directories, just counts them
     """
-    out_dir = _output_parent(mode)
-
     # The directories are named by number, so we can just count them
-    return sum(1 for file_ in out_dir.iterdir() if file_.is_dir())
+    return sum(1 for file_ in _output_parent(mode, out_dir).iterdir() if file_.is_dir())


 def _lr(rng: np.random.Generator, mode: str) -> float:
@@ -265,13 +263,15 @@ def step(
     np.save(out_dir / "val_losses.npy", val_losses)


-def main(*, mode: str, n_steps: int, continue_run: bool):
+def main(*, mode: str, n_steps: int, continue_run: bool, out_dir: str):
     """
     Set up the configuration and run the training
     """
+    out_dir = pathlib.Path(out_dir)
+
     # Check if we have existing directories
-    n_existing_dirs = _n_existing_dirs(mode)
+    n_existing_dirs = _n_existing_dirs(mode, out_dir)
     if continue_run:
         if n_existing_dirs == 0:
             raise ValueError("No existing directories to continue from")
@@ -308,18 +308,18 @@ def main(*, mode: str, n_steps: int, continue_run: bool):
     )

     for i in range(start, start + n_steps):
-        out_dir = _output_dir(i, mode)
+        run_dir = _output_dir(i, mode, out_dir)
         config = _config(rng, mode)

         # Since the dataloader picks random patches, the training data is slightly different
         # between runs. Hopefully this doesn't matter though
         data_config = data.DataConfig(config, train_subjects, val_subjects)

-        with open(out_dir / "config.yaml", "w", encoding="utf-8") as cfg_file:
+        with open(run_dir / "config.yaml", "w", encoding="utf-8") as cfg_file:
             yaml.dump(config, cfg_file)

         try:
-            step(config, data_config, out_dir, full_validation_subjects)
+            step(config, data_config, run_dir, full_validation_subjects)
         except torch.cuda.OutOfMemoryError as e:
             print(config)
             print(e)
@@ -338,6 +338,11 @@ def main(*, mode: str, n_steps: int, continue_run: bool):
         Determines the range of hyperparameters, and which are searched""",
     )
     parser.add_argument("n_steps", type=int, help="Number of models to train")
+    parser.add_argument(
+        "out_dir",
+        type=str,
+        help="Directory to save outputs in, relative to the script output directory",
+    )
     parser.add_argument(
         "--continue_run",
         action="store_true",
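
The run-numbering behaviour in explore_hyperparams.py is unchanged by this commit; the helpers just take the extra out_dir parameter. A self-contained sketch of that behaviour is shown below. It is not the project code: base_dir is a hypothetical stand-in for files.script_out_dir(), and exist_ok replaces the is_dir() checks used in the script.

    import pathlib


    def n_existing_dirs(parent: pathlib.Path) -> int:
        """Count run directories; no integrity check, mirroring _n_existing_dirs."""
        return sum(1 for child in parent.iterdir() if child.is_dir())


    def next_run_dir(base_dir: pathlib.Path, out_dir: pathlib.Path, mode: str) -> pathlib.Path:
        """Create and return the next numbered run directory under base_dir/out_dir/mode."""
        parent = base_dir / out_dir / mode
        parent.mkdir(parents=True, exist_ok=True)
        run_dir = parent / str(n_existing_dirs(parent))
        run_dir.mkdir(parents=True, exist_ok=True)
        return run_dir


    if __name__ == "__main__":
        # Hypothetical base directory standing in for files.script_out_dir()
        base = pathlib.Path("example_script_output")
        for _ in range(3):
            print(next_run_dir(base, pathlib.Path("tuning_output"), "coarse"))
        # -> example_script_output/tuning_output/coarse/0, .../1, .../2

Per the diff, each numbered run directory then receives a config.yaml and the saved validation losses.
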
19 changes: 14 additions & 5 deletions scripts/plot_hyperparams.py
@@ -235,10 +235,13 @@ def _plot_scatters(data_dir: pathlib.Path, metric: str) -> plt.Figure:
     return _plot_scores(runs)


-def main(mode: str):
+def main(mode: str, out_dir: str):
     """Choose the granularity of the search to plot"""
-    input_dir = files.script_out_dir() / "tuning_output" / mode
-    output_dir = files.script_out_dir() / "tuning_plots" / mode
+    input_dir = files.script_out_dir() / out_dir / mode
+    if not input_dir.exists():
+        raise FileNotFoundError(f"Directory {input_dir} not found")
+
+    output_dir = files.script_out_dir() / "tuning_plots" / out_dir / mode
     if not output_dir.exists():
         output_dir.mkdir(parents=True)

@@ -272,16 +275,22 @@ def main(mode: str):
         choices={"coarse", "med", "fine"},
         help="Granularity of the search.",
     )
+    parser.add_argument(
+        "out_dir",
+        type=str,
+        help="Directory to read the tuning outputs from,"
+        "relative to the script output directory",
+    )

     args = parser.parse_args()

     # We need to write the files holding the table of metrics
     if args.mode == "fine":
         _write_all_metrics_files(
             sorted(
-                list((files.script_out_dir() / "tuning_output" / "fine").glob("*")),
+                list((files.script_out_dir() / args.out_dir / args.mode).glob("*")),
                 key=lambda x: int(x.name),
             )
         )

-    main(args.mode)
+    main(args.mode, args.out_dir)
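
For plot_hyperparams.py, the same out_dir now drives both the input and output locations, and fine-mode runs are still ordered by their numeric directory names. A hedged sketch of that flow, with base_dir again a hypothetical stand-in for files.script_out_dir() and all other names illustrative:

    import pathlib


    def resolve_plot_dirs(base_dir: pathlib.Path, out_dir: str, mode: str) -> tuple[pathlib.Path, pathlib.Path]:
        """Return (input_dir, output_dir) as the updated main() resolves them."""
        input_dir = base_dir / out_dir / mode
        if not input_dir.exists():
            raise FileNotFoundError(f"Directory {input_dir} not found")
        output_dir = base_dir / "tuning_plots" / out_dir / mode
        output_dir.mkdir(parents=True, exist_ok=True)
        return input_dir, output_dir


    def runs_in_numeric_order(input_dir: pathlib.Path) -> list[pathlib.Path]:
        """Sort run directories by their integer names, as the fine-mode metrics step does."""
        return sorted(input_dir.glob("*"), key=lambda run: int(run.name))
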
