
Commit

separate functions for reading and creating metric table files
Richard Lane committed Nov 12, 2024
1 parent 740dd4e commit d4a62f7
Showing 1 changed file with 51 additions and 21 deletions.
scripts/plot_hyperparams.py (72 changes: 51 additions & 21 deletions)
@@ -159,8 +159,16 @@ def _plot_scatters(data_dir: pathlib.Path, metric: str) -> plt.Figure:
"""
data_dirs = list(data_dir.glob("*"))
runs = []

# Create metrics files for all the runs
# If the run is still going, there might be some missing files
for dir_ in tqdm(data_dirs, desc="Creating metric tables"):
# We might still be running, in which case the last dir will be incomplete
try:
_write_metrics_files(dir_)
except FileNotFoundError:
continue

for dir_ in data_dirs:
if metric == "dice":
try:
score = _dicescore(dir_)
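
The first hunk turns the top of _plot_scatters into a two-pass loop: one pass writes a metrics table for every finished run (skipping runs that are still training and have no validation outputs yet), and a second pass reads the cached tables back. A condensed sketch of that pattern as it would sit inside this script; the "hyperparam_results" path is a made-up example, and _write_metrics_files / _dicescore are the helpers defined further down in scripts/plot_hyperparams.py:

```python
import pathlib

from tqdm import tqdm

# Hypothetical results root; each subdirectory holds the outputs of one run
data_dir = pathlib.Path("hyperparam_results")
run_dirs = list(data_dir.glob("*"))

# Pass 1: create metrics.txt for every run that has validation outputs
for run_dir in tqdm(run_dirs, desc="Creating metric tables"):
    try:
        _write_metrics_files(run_dir)
    except FileNotFoundError:
        # Run still in progress, no val_pred_*.npy files yet
        continue

# Pass 2: read the cached tables and collect the mean Dice score per run
scores = {
    run_dir.name: _dicescore(run_dir)
    for run_dir in run_dirs
    if (run_dir / "metrics.txt").exists()
}
```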
@@ -250,38 +258,42 @@ def _plot_med(input_dir: pathlib.Path, output_dir: pathlib.Path):
fig.savefig(output_dir / "scores.png")


def _dicescore(results_dir: pathlib.Path) -> float:
def _write_metrics_files(results_dir: pathlib.Path) -> None:
"""
Get the DICE score from the i-th run
Write files containing the validation metrics for each run, as markdown
:param results_dir: The directory containing the results from a run
"""
# We'll write the score to a file
metrics_file = results_dir / "metrics.txt"

# If it doesn't exist, write the metrics file
if not metrics_file.exists():
n_val_imgs = len(list(results_dir.glob("*val_pred_*.npy")))
assert n_val_imgs == len(list(results_dir.glob("val_truth_*.npy")))

if n_val_imgs == 0:
raise FileNotFoundError(f"No validation images found in {results_dir}")
if metrics_file.exists():
return

pred = [np.load(results_dir / f"val_pred_{i}.npy") for i in range(n_val_imgs)]
truth = [np.load(results_dir / f"val_truth_{i}.npy") for i in range(n_val_imgs)]
n_val_imgs = len(list(results_dir.glob("*val_pred_*.npy")))
assert n_val_imgs == len(list(results_dir.glob("val_truth_*.npy")))
if n_val_imgs == 0:
raise FileNotFoundError(f"No validation images found in {results_dir}")

for p, t in zip(pred, truth):
assert p.shape == t.shape, f"{p.shape=} != {t.shape=}"
# Load the arrays
pred = [np.load(results_dir / f"val_pred_{i}.npy") for i in range(n_val_imgs)]
truth = [np.load(results_dir / f"val_truth_{i}.npy") for i in range(n_val_imgs)]
for p, t in zip(pred, truth):
assert p.shape == t.shape, f"{p.shape=} != {t.shape=}"
if not (p.min() >= 0 and p.max() <= 1):
raise ValueError("Prediction should be scaled to between 0 and 1")

# The prediction should already be scaled to be between 0 and 1
if not (p.min() >= 0 and p.max() <= 1):
raise ValueError("Prediction should be scaled to between 0 and 1")
table = metrics.table(truth, pred)
with open(metrics_file, "w", encoding="utf-8") as f:
f.write(table.to_markdown())

table = metrics.table(truth, pred)

with open(metrics_file, "w", encoding="utf-8") as f:
f.write(table.to_markdown())
def _metrics_df(metrics_file: pathlib.Path) -> pd.DataFrame:
"""
Read the metrics file written by _write_metrics_files and return it as a DataFrame
"""
df = (
pd.read_table(
metrics_file,
@@ -295,6 +307,24 @@ def _dicescore(results_dir: pathlib.Path) -> float:
.astype(float)
)
df.columns = df.columns.str.strip()
return df
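
The arguments to pd.read_table are collapsed in the view above, so the exact parsing is not visible here. As a rough, self-contained sketch of how a table written with DataFrame.to_markdown() can be read back into numbers (the column names and values below are invented for illustration): split on the pipe character, drop the markdown separator row and the empty edge columns, convert to float, and strip the padded column names.

```python
import io

import pandas as pd

# Invented metrics table in the same spirit as metrics.table(); the real
# columns produced by the project's metrics module may differ
table = pd.DataFrame({"Dice": [0.91, 0.87], "Precision": [0.88, 0.90]})
markdown = table.to_markdown()  # pandas needs the `tabulate` package for this

df = (
    pd.read_table(io.StringIO(markdown), sep="|")
    .iloc[1:]                   # drop the |---:|----:| separator row
    .dropna(axis=1, how="all")  # drop the empty columns from the outer pipes
    .astype(float)
)
df.columns = df.columns.str.strip()  # "   Dice " -> "Dice"

print(df["Dice"].mean())  # ~0.89
```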


def _dicescore(results_dir: pathlib.Path) -> float:
"""
Get the average DICE score for a run
:param results_dir: The directory containing the results from a run
"""
# The metrics table should already have been written to this file by _write_metrics_files
metrics_file = results_dir / "metrics.txt"

if not metrics_file.exists():
raise FileNotFoundError(f"No metrics file found in {results_dir}")

# get the average DICE score from the metrics file
df = _metrics_df(metrics_file)

return df["Dice"].mean()
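
For context, the Dice score averaged by _dicescore is the standard overlap metric 2|A ∩ B| / (|A| + |B|) between a predicted mask and a ground-truth mask. The project computes it inside metrics.table(), which is not shown in this diff; the function below is a generic NumPy illustration of the metric for a single prediction/truth pair, not the project's implementation.

```python
import numpy as np


def dice_coefficient(pred: np.ndarray, truth: np.ndarray, threshold: float = 0.5) -> float:
    """Generic Dice coefficient for one prediction/truth pair (illustration only)."""
    # Binarise the soft prediction; the truth is assumed to already be a 0/1 mask
    pred_mask = pred >= threshold
    truth_mask = truth.astype(bool)

    intersection = np.logical_and(pred_mask, truth_mask).sum()
    denominator = pred_mask.sum() + truth_mask.sum()

    # Convention: two empty masks count as a perfect match
    return 1.0 if denominator == 0 else 2.0 * intersection / denominator


# Quick check on synthetic data
rng = np.random.default_rng(0)
truth = rng.integers(0, 2, size=(64, 64))
pred = np.clip(truth + rng.normal(0.0, 0.3, size=(64, 64)), 0.0, 1.0)
print(dice_coefficient(pred, truth))
```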

