Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adam monai #38

Draft
wants to merge 64 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
962bf26
tv-detection first commit
adam-peaston-SC Sep 5, 2023
41e4d05
tv-detections working, pending epoch completion
adam-peaston-SC Sep 6, 2023
a09c24b
all detection and segmentation working pending full epoch
adam-peaston-SC Sep 6, 2023
7a3cb0c
removed WIP from isc-demos README
adam-peaston-SC Sep 6, 2023
ccbc0f6
Updated tv-detection and tv-segmentation to checkpoint (and report) a…
adam-peaston-SC Sep 6, 2023
9896974
Updated tv-detection to fix epoch roll-over error and added sophistic…
adam-peaston-SC Sep 7, 2023
cea7e8b
Updates
adam-peaston-SC Sep 7, 2023
7ce8cb4
updates
adam-peaston-SC Sep 7, 2023
ba815e8
updates
adam-peaston-SC Sep 7, 2023
3539a17
Added Timer to cycling_utils/saving
adam-peaston-SC Sep 7, 2023
b61aec5
moved Timer to its own file, updated training scripts to time imports
adam-peaston-SC Sep 7, 2023
228d6e3
Updated timer to only report if the global rank of the process is == …
adam-peaston-SC Sep 7, 2023
a887392
Updated tv-detection to use new InterruptableDistributedGroupedBatchS…
adam-peaston-SC Sep 10, 2023
1595277
detection and segmentation working all but detection eval
adam-peaston-SC Sep 10, 2023
d44cd5b
minor updates
adam-peaston-SC Sep 10, 2023
62df2d4
updates thinking about checkpointing evaluation better
adam-peaston-SC Sep 10, 2023
60db14f
minor updates and full training kickoffs
adam-peaston-SC Sep 10, 2023
0b7c932
monai autoencoder traininig interruptably
adam-peaston-SC Sep 13, 2023
1ec3b1c
updateds
adam-peaston-SC Sep 13, 2023
bd04c29
updates
adam-peaston-SC Sep 13, 2023
fd82382
updates
adam-peaston-SC Sep 13, 2023
7259776
potential fix for cudnnbatchnorm error
adam-peaston-SC Sep 13, 2023
91fc314
minor update
adam-peaston-SC Sep 13, 2023
70cba9d
still failing when discriminator included in loss
adam-peaston-SC Sep 13, 2023
e2863d1
autoencoder issue fixed
adam-peaston-SC Sep 14, 2023
43b3bd2
diffusion model training, finished 3 epochs
adam-peaston-SC Sep 15, 2023
29c6bfb
updates to args to fix issue resuming
adam-peaston-SC Sep 15, 2023
288f0a9
Latest updates
adam-peaston-SC Sep 18, 2023
a560f59
Integrated tensorboard with maskrcnn and monai
adam-peaston-SC Sep 19, 2023
6859b18
backup
adam-peaston-SC Sep 19, 2023
5d7a69b
updates, updates, updates
adam-peaston-SC Sep 19, 2023
89bf099
code tidier
adam-peaston-SC Sep 19, 2023
fcd1387
Updates and experiments aligned with literature benchmarks
adam-peaston-SC Sep 20, 2023
953cdc4
Fixed reporting issue with mask/retina
adam-peaston-SC Sep 20, 2023
580d095
Merge branch 'main' into adam-tv-detection
StrongFennecs Sep 21, 2023
ec3760a
tidy up and update of timer etc.
adam-peaston-SC Sep 21, 2023
62e5360
just in case
adam-peaston-SC Sep 21, 2023
dc55f1b
changes
adam-peaston-SC Sep 21, 2023
4683fc0
changes
adam-peaston-SC Sep 21, 2023
43e7bd1
fixed mess thank you Calvin!
adam-peaston-SC Sep 21, 2023
37d878b
tidy up, linting, deleting log files
adam-peaston-SC Sep 22, 2023
4ef575e
removed local pyproject.toml files
adam-peaston-SC Sep 22, 2023
d09abfb
removed monai
adam-peaston-SC Sep 22, 2023
addc304
updated readmes
adam-peaston-SC Sep 22, 2023
73531f4
updated readme with ref to requirements
adam-peaston-SC Sep 22, 2023
0d09904
removed resuming dir arg from fcn_resnet101.isc
adam-peaston-SC Sep 22, 2023
e043663
tackling instabilities
adam-peaston-SC Sep 25, 2023
2837068
first commit to monai branch
adam-peaston-SC Sep 25, 2023
d0cf9a3
monai updated with bones of pancreas
adam-peaston-SC Sep 27, 2023
7ed1204
updates to pancreas
adam-peaston-SC Sep 27, 2023
66099cf
monai pancreas search ready to test
adam-peaston-SC Sep 27, 2023
f3f7e0c
added saving for search space
adam-peaston-SC Sep 27, 2023
11bbb8a
updates to pancreas mostly
adam-peaston-SC Sep 28, 2023
1ccffe5
progress on pancreas
adam-peaston-SC Sep 28, 2023
b1ffdb5
latest updates to monai pancreas, ready for isc testing
adam-peaston-SC Sep 29, 2023
46903ef
ok now done
adam-peaston-SC Sep 29, 2023
b15ceff
updates to pancreas with timing and imports
adam-peaston-SC Sep 29, 2023
6b1a13a
pancreas search cycling
adam-peaston-SC Sep 29, 2023
35cdc2b
slimmed down imports for pancreas
adam-peaston-SC Sep 29, 2023
2c289cb
updates including monai and cycling utils enhancement
adam-peaston-SC Oct 6, 2023
b2cd243
Updated monai with setup readme
adam-peaston-SC Oct 6, 2023
1dabdb8
fixed tv-detection tensorboard logging of val metrics
adam-peaston-SC Oct 9, 2023
ee328a8
removed data subsetting from tv-detection
adam-peaston-SC Oct 9, 2023
52d5441
updates
adam-peaston-SC Oct 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ isc clusters # view the status of the clusters

(from https://github.com/pytorch/vision/tree/main/references/segmentation)

- WIP [fcn_resnet101.isc](./tv-segmentation/fcn_resnet101.isc)
- WIP [deeplabv3_mobilenet_v3_large.isc](./tv-segmentation/deeplabv3_mobilenet_v3_large.isc)
- [fcn_resnet101.isc](./tv-segmentation/fcn_resnet101.isc)
- [deeplabv3_mobilenet_v3_large.isc](./tv-segmentation/deeplabv3_mobilenet_v3_large.isc)

### tv-detection

(from https://github.com/pytorch/vision/tree/main/references/detection)

- WIP [maskrcnn_resnet50_fpn.isc](./tv-detection/fasterrcnn_resnet50_fpn.isc)
- WIP [retinanet_resnet50_fpn.isc](./tv-detection/retinanet_resnet50_fpn.isc)
- [maskrcnn_resnet50_fpn.isc](./tv-detection/fasterrcnn_resnet50_fpn.isc)
- [retinanet_resnet50_fpn.isc](./tv-detection/retinanet_resnet50_fpn.isc)
8 changes: 4 additions & 4 deletions cycling_utils/cycling_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .saving import atomic_torch_save
from .sampler import InterruptableDistributedSampler
from .lightning_utils import EpochHandler
from .timer import Timer, TimestampedTimer
from .saving import atomic_torch_save, MetricsTracker
from .sampler import InterruptableDistributedSampler, InterruptableDistributedGroupedBatchSampler

__all__ = ["InterruptableDistributedSampler", "atomic_torch_save", "EpochHandler"]
__all__ = ["InterruptableDistributedSampler", "InterruptableDistributedGroupedBatchSampler", "atomic_torch_save", "Timer", "TimestampedTimer"]
166 changes: 166 additions & 0 deletions cycling_utils/cycling_utils/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
from torch.utils.data import Dataset, DistributedSampler
from contextlib import contextmanager
from collections import defaultdict
from itertools import chain, repeat

class HasNotResetProgressError(Exception):
pass
Expand Down Expand Up @@ -113,3 +115,167 @@ def in_epoch(self, epoch):
self.set_epoch(epoch)
yield
self._reset_progress()

def _repeat_to_at_least(iterable, n):
repeat_times = math.ceil(n / len(iterable))
repeated = chain.from_iterable(repeat(iterable, repeat_times))
return list(repeated)

class InterruptableDistributedGroupedBatchSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
group_ids: list,
batch_size: int,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
"""
This is a DistributedSampler that can be suspended and resumed.

This works by keeping track of the sample batches that have already been
dispatched. This InterruptableDistributedGroupedBatchSampler also
reproduces the sampling strategy exhibited in the torch vision detection
reference wherein batches are created from images from within the same
'group', defined in the torchvision example by similarity of image
aspect ratio.

https://github.com/pytorch/vision/tree/main/references/detection

For this reason, InterruptableDistributedGroupedBatchSampler progress is
tracked in units of batches, not samples. This is an important
distinction from the InterruptableDistributedSampler which tracks progress
in units of samples. The progress is reset to 0 at the end of each epoch.

The epoch is set to 0 at initialization and incremented at the start
of each epoch.

Suspending and resuming the sampler is done by saving and loading the
state dict. The state dict contains the epoch and progress.
"""
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)

# OVERALL STATUS INDICATOR
self.progress = 0
self._has_reset_progress = True
self.batch_size = batch_size
self.group_ids = group_ids
self.batches = self._create_batches()

def _create_batches(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]

if not self.drop_last:
# add extra samples to make dataset evenly divisible accross ranks
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make dataset evenly divisible accross ranks
indices = indices[: self.total_size]
assert len(indices) == self.total_size

# subsample indices to use on this rank
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples

# PRE-COMPUTE GROUPED BATCHES
buffer_per_group = defaultdict(list)
samples_per_group = defaultdict(list)
self.num_batches = math.ceil(len(indices)/ self.batch_size)

batches = [] # pre-computed so progress refers to batches, not samples.
for idx in indices:
group_id = self.group_ids[idx]
buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size:
batches.append(buffer_per_group[group_id])
del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size

# now we have run out of elements that satisfy
# the group criteria, let's return the remaining
# elements so that the size of the sampler is
# deterministic
num_remaining = self.num_batches - len(batches)
if num_remaining > 0:
# for the remaining batches, take first the buffers with the largest number
# of elements
for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id])
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
assert len(buffer_per_group[group_id]) == self.batch_size
batches.append(buffer_per_group[group_id])
num_remaining -= 1
if num_remaining == 0:
break

# Check that the batches are all good to go
assert len(batches) == self.num_batches
return batches

def _reset_progress(self):
self.progress = 0
self._has_reset_progress = True

def set_epoch(self, epoch: int) -> None:
raise NotImplementedError("Use `with sampler.in_epoch(epoch)` instead of `sampler.set_epoch(epoch)`")

def _set_epoch(self, epoch):
if not self._has_reset_progress:
raise HasNotResetProgressError("You must reset progress before setting epoch e.g. `sampler.reset_progress()`\nor use `with sampler.in_epoch(epoch)` instead of `sampler.set_epoch(epoch)`")
self.epoch = epoch

def state_dict(self):
return {"progress": self.progress, "epoch": self.epoch}

def load_state_dict(self, state_dict):
self.progress = state_dict["progress"]
if not self.progress <= self.num_batches:
raise AdvancedTooFarError(f"progress should be less than or equal to the number of batches. progress: {self.progress}, num_batches: {self.num_batches}")
self.epoch = state_dict["epoch"]

def advance(self):
"""
Record that one batch has been consumed.
"""
self.progress += 1
if self.progress > self.num_batches:
raise AdvancedTooFarError(f"You have advanced too far. You can only advance up to the total number of batches: {self.num_batches}.")

def __iter__(self):

# slice from progress to pick up where we left off
for batch in self.batches[self.progress:]:
yield batch

def __len__(self):
return self.num_batches

@contextmanager
def in_epoch(self, epoch):
"""
This context manager is used to set the epoch. It is used like this:
```
for epoch in range(0, 10):
with sampler.in_epoch(epoch):
for step, (x, ) in enumerate(dataloader):
# work would be done here...
```
"""
self._set_epoch(epoch)
yield
self._reset_progress()
82 changes: 80 additions & 2 deletions cycling_utils/cycling_utils/saving.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,87 @@
from pathlib import Path
import os
import torch
import torch.distributed as dist
from collections import defaultdict

def atomic_torch_save(obj, f: str | Path, **kwargs):
def atomic_torch_save(obj, f: str | Path, timer=None, **kwargs):
f = str(f)
temp_f = f + ".temp"
torch.save(obj, temp_f, **kwargs)
os.replace(temp_f, f)
if timer is not None:
timer.report(f'saving temp checkpoint')
os.replace(temp_f, f)
if timer is not None:
timer.report(f'replacing temp checkpoint with checkpoint')
return timer
else:
return

class MetricsTracker:
'''
This is a general purpose MetricsTracker to assist with recording metrics from
a disributed cluster.

The MetricsTracker is initialised without any prior knowledge of the metrics
to be tracked.

>>> metrics = MetricsTracker()

Metrics can be accumulated as required, for example after each batch is procesed
by the model, by passing a dictionary with metrics to be updated, then reduced
accross all nodes. Metric values are stored in a defaultdict.

>>> preds = model(input)
>>> loss = loss_fn(preds, targs)
>>> metrics.update({"images_seen": len(images), "loss": loss.item()})
>>> metrics.reduce()

Metrics are assumed to be summable scalar values. After calling reduce(), the
metrics.local object contains the sum of corresponding metrics from all nodes
which can be used for intermediate reporting or logging.

>>> writer = SummaryWriter()
>>> for metric,val in metrics.local.items():
>>> writer.add_scalar(metric, val, step)
>>> writer.flush()
>>> writer.close()

Once all processing of the current batch has been completed, the MetricsTracker
can be prepared for the next batch using reset_local().

>>> metrics.reset_loca()

Metrics are also accumulated for consecutive batches in the metrics.agg object.
At the end of an epoch the MetricsTracker can be reset using end_epoch().

>>> metrics.end_epoch()

The MetricsTracker saves a copy of the accumulated metrics (metrics.agg) for
each epoch which can be stored within a checkpoint.
'''
def __init__(self):
self.local = defaultdict(float)
self.agg = defaultdict(float)
self.epoch_reports = []

def update(self, metrics: dict):
for m,v in metrics.items():
self.local[m] += v

def reduce(self):
names, local = zip(*self.local.items())
local = torch.tensor(local, dtype=torch.float16, requires_grad=False, device='cuda')
dist.all_reduce(local, op=dist.ReduceOp.SUM)
self.local = defaultdict(float, zip(names, local.cpu().numpy()))
for k in self.local:
self.agg[k] += self.local[k]

def reset_local(self):
self.local = defaultdict(float)

def end_epoch(self):
self.epoch_reports.append(dict(self.agg))
self.local = defaultdict(float)
self.agg = defaultdict(float)


64 changes: 64 additions & 0 deletions cycling_utils/cycling_utils/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os, time
from datetime import datetime

class Timer:
'''
This Timer can be integrated within a training routine to provide point-to-point
script timing and reporting.

def main():
timer = Timer()
time.sleep(2)
timer.report("sleeping for 2 seconds")
time.sleep(3)
timer.report("sleeping for 3 seconds")

>>> main()
Start 0.000 ms 0.000 s total
Completed sleeping for 2 seconds 2,000.000 ms 2.000 s total
Completed sleeping for 3 seconds 3,000.000 ms 5.000 s total
'''
def __init__(self, report=None, start_time=None, running=0):
self.start_time = start_time if start_time is not None else time.time()
self.running = running
if str(os.environ["RANK"]) == "0":
report = report if report else "Start"
print("[{:<80}] {:>12} ms, {:>12} s total".format(report, f'{0.0:,.3f}', f'{0.0:,.2f}'))
def report(self, annot):
if str(os.environ["RANK"]) == "0":
now = time.time()
duration = now - self.start_time
self.running += duration
print("Completed {:<70}{:>12} ms, {:>12} s total".format(annot, f'{1000*duration:,.3f}', f'{self.running:,.2f}'))
self.start_time = now

class TimestampedTimer:
'''
This TimestampedTimer can be integrated within a training routine to provide
point-to-point script timing and reporting.

def main():
timer = TimestampedTimer()
time.sleep(2)
timer.report("sleeping for 2 seconds")
time.sleep(3)
timer.report("sleeping for 3 seconds")

>>> main()
[TIME] Start 0.000 ms 0.000 s total
[TIME] Completed sleeping for 2 seconds 2,000.000 ms 2.000 s total
[TIME] Completed sleeping for 3 seconds 3,000.000 ms 5.000 s total
'''
def __init__(self, report=None, start_time=None, running=0):
if str(os.environ.get("RANK","NONE")) in ["0", "NONE"]:
self.start_time = start_time if start_time is not None else time.time()
self.running = running
report = report if report else "Start"
print("[ {} ] Completed {:<70}{:>12} ms, {:>12} s total".format(time.strftime("%Y-%m-%d %H:%M:%S"), report, f'{0.0:,.3f}', f'{0.0:,.2f}'))
def report(self, annot):
if str(os.environ.get("RANK","NONE")) in ["0", "NONE"]:
now = time.time()
duration = now - self.start_time
self.running += duration
print("[ {} ] Completed {:<70}{:>12} ms, {:>12} s total".format(time.strftime("%Y-%m-%d %H:%M:%S"), annot, f'{1000*duration:,.3f}', f'{self.running:,.2f}'))
self.start_time = now
3 changes: 3 additions & 0 deletions monai_brats_mri_2d/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# MONAI Generative Models Installation
For this demonstration, you will need to clone the MONAI GenerativeModels GitHub repository and follow the instructions for installation. This will install the `generative` package from MONAI.
You will then need to run `pip install -r requirements-dev.txt` to install other necessary dependencies. You may then also need to ensure that monai version 1.2.0 is installed using the command `pip install monai==1.2.0` as later versions of monai do not support all of the transforms used in this example.
6 changes: 6 additions & 0 deletions monai_brats_mri_2d/brats_mri_2d_diff.isc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
experiment_name="brats_mri_2d_diff"
gpu_type="24GB VRAM GPU"
nnodes = 11
venv_path = "~/.venv/bin/activate"
output_path = "~outputs/brats_mri_2d_diff"
command="train_cycling_diff.py --data-path=/mnt/.node1/Open-Datsets/MONAI --resume $OUTPUT_PATH/checkpoint.isc, --gen-load-path ~/output_brats_mri_2d_gen/exp_1855/checkpoint.isc --tboard-path $OUTPUT_PATH/tb"
6 changes: 6 additions & 0 deletions monai_brats_mri_2d/brats_mri_2d_gen.isc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
experiment_name="brats_mri_2d_gen"
gpu_type="24GB VRAM GPU"
nnodes = 11
venv_path = "~/.venv/bin/activate"
output_path = "~/outputs/brats_mri_2d_gen"
command="train_cycling_gen.py --lr 1e-5 --data-path=/mnt/.node1/Open-Datasets/MONAI --resume $OUTPUT_PATH/checkpoint.isc --tboard-path $OUTPUT_PATH/tb --prev-resume /mnt/Client/StrongHumans/strong_adam/outputs/brats_mri_2d_gen/301e7ac7-0c9a-4daa-920e-57ea5ea983b9/checkpoint.isc"
Loading