Add option to reuse .bin files #116

Open · wants to merge 7 commits into base: divya/load-head-ckpt-inference
18 changes: 8 additions & 10 deletions sleap_nn/inference/utils.py
@@ -20,15 +20,13 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
     for name in skeleton_config.keys():
 
         # create `sio.Node` objects.
-        nodes = [
-            sio.model.skeleton.Node(n["name"]) for n in skeleton_config[name].nodes
-        ]
+        nodes = [sio.Node(n["name"]) for n in skeleton_config[name].nodes]
 
         # create `sio.Edge` objects.
         edges = [
-            sio.model.skeleton.Edge(
-                sio.model.skeleton.Node(e["source"]["name"]),
-                sio.model.skeleton.Node(e["destination"]["name"]),
+            sio.Edge(
+                sio.Node(e["source"]["name"]),
+                sio.Node(e["destination"]["name"]),
             )
             for e in skeleton_config[name].edges
         ]
@@ -38,17 +36,17 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
             list_args = [
                 set(
                     [
-                        sio.model.skeleton.Node(s[0]["name"]),
-                        sio.model.skeleton.Node(s[1]["name"]),
+                        sio.Node(s[0]["name"]),
+                        sio.Node(s[1]["name"]),
                     ]
                 )
                 for s in skeleton_config[name].symmetries
             ]
-            symmetries = [sio.model.skeleton.Symmetry(x) for x in list_args]
+            symmetries = [sio.Symmetry(x) for x in list_args]
         else:
             symmetries = []
 
-        skeletons.append(sio.model.skeleton.Skeleton(nodes, edges, symmetries, name))
+        skeletons.append(sio.Skeleton(nodes, edges, symmetries, name))
 
     return skeletons

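Note: the refactor above leans on `sleap-io` exposing `Node`, `Edge`, `Symmetry`, and `Skeleton` at the package top level, so the fully qualified `sio.model.skeleton.*` paths are redundant. For reference, a minimal sketch of the objects `get_skeleton_from_config` now builds; the node names and skeleton name here are illustrative, not taken from the PR:

    import sleap_io as sio

    # Toy two-node skeleton, mirroring what the refactored helper constructs
    # from one entry of the skeleton config.
    nodes = [sio.Node("head"), sio.Node("thorax")]
    edges = [sio.Edge(nodes[0], nodes[1])]
    symmetries = []  # this toy skeleton has no symmetric node pairs
    skeleton = sio.Skeleton(nodes, edges, symmetries, "skeleton-0")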
159 changes: 100 additions & 59 deletions sleap_nn/training/model_trainer.py
@@ -90,6 +90,8 @@ def __init__(self, config: OmegaConf):
         self.model = None
         self.train_data_loader = None
         self.val_data_loader = None
+        self.bin_files_path = None
+        self.trainer = None
         self.crop_hw = -1

# check which head type to choose the model
@@ -110,25 +112,8 @@ def __init__(self, config: OmegaConf):
f"Cannot create a new folder in {self.dir_path}. Check the permissions to the given Checkpoint directory. \n {e}"
)

self.bin_files_path = self.config.trainer_config.bin_files_path
if self.bin_files_path is None:
self.bin_files_path = self.dir_path

self.bin_files_path = f"{self.bin_files_path}/chunks_{datetime.strftime(datetime.now(), '%Y%m%d_%H-%M-%S-%f')}"
print(f"`.bin` files are saved in {self.bin_files_path}")

if not Path(self.bin_files_path).exists():
try:
Path(self.bin_files_path).mkdir(parents=True, exist_ok=True)
except OSError as e:
raise OSError(
f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
)

OmegaConf.save(config=self.config, f=f"{self.dir_path}/initial_config.yaml")

self.config.trainer_config.saved_bin_files_path = self.bin_files_path

# set seed
torch.manual_seed(self.seed)

@@ -140,6 +125,20 @@ def __init__(self, config: OmegaConf):
             else True
         )  # TODO: defaults should be handled in config validation.
         self.skeletons = train_labels.skeletons
+        # save the skeleton in the config
+        self.config["data_config"]["skeletons"] = {}
+        for skl in self.skeletons:
+            if skl.symmetries:
+                symm = [list(s.nodes) for s in skl.symmetries]
+            else:
+                symm = None
+            skl_name = skl.name if skl.name is not None else "skeleton-0"
+            self.config["data_config"]["skeletons"][skl_name] = {
+                "nodes": skl.nodes,
+                "edges": skl.edges,
+                "symmetries": symm,
+            }
+
         self.max_stride = self.config.model_config.backbone_config.max_stride
         self.edge_inds = train_labels.skeletons[0].edge_inds
         self.chunk_size = (
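Aside: this skeleton-saving loop used to live in `train()` (it is deleted there further down in this diff); moving it into `__init__` populates `data_config.skeletons` as soon as the trainer is constructed, before any `.bin` files exist. The saved section is the same one that `get_skeleton_from_config` in `sleap_nn/inference/utils.py` (first file above) parses back into `sio.Skeleton` objects. A round-trip sketch, assuming the enclosing trainer class has already populated and saved the config, and using a hypothetical config path:

    from omegaconf import OmegaConf

    from sleap_nn.inference.utils import get_skeleton_from_config

    # Load a saved training config, then rebuild the sio.Skeleton objects
    # from its data_config.skeletons section.
    config = OmegaConf.load("training_config.yaml")  # hypothetical path
    skeletons = get_skeleton_from_config(config.data_config.skeletons)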
@@ -180,7 +179,10 @@ def __init__(self, config: OmegaConf):
         else:
             self.crop_hw = self.crop_hw[0]
 
-    def _create_data_loaders(self):
+    def _create_data_loaders(
+        self,
+        chunks_dir_path: Optional[str] = None,
+    ):
         """Create a DataLoader for train, validation and test sets using the data_config."""
 
         def run_subprocess():
@@ -218,16 +220,49 @@ def run_subprocess():
print("Standard Output:\n", stdout)
print("Standard Error:\n", stderr)

try:
run_subprocess()
if chunks_dir_path is None:
try:
self.bin_files_path = self.config.trainer_config.bin_files_path
if self.bin_files_path is None:
self.bin_files_path = self.dir_path

self.bin_files_path = f"{self.bin_files_path}/chunks_{datetime.strftime(datetime.now(), '%Y%m%d_%H-%M-%S-%f')}"
print(
f"New dir is created and `.bin` files are saved in {self.bin_files_path}"
)

if not Path(self.bin_files_path).exists():
try:
Path(self.bin_files_path).mkdir(parents=True, exist_ok=True)
except OSError as e:
raise OSError(
f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
)

self.config.trainer_config.saved_bin_files_path = self.bin_files_path

except Exception as e:
raise Exception(f"Error while creating the `.bin` files... {e}")
self.train_input_dir = (
Path(self.bin_files_path) / "train_chunks"
).as_posix()
self.val_input_dir = (
Path(self.bin_files_path) / "val_chunks"
).as_posix()

run_subprocess()

except Exception as e:
raise Exception(f"Error while creating the `.bin` files... {e}")

else:
print(f"Using `.bin` files from {chunks_dir_path}.")
self.train_input_dir = (Path(chunks_dir_path) / "train_chunks").as_posix()
self.val_input_dir = (Path(chunks_dir_path) / "val_chunks").as_posix()
self.config.trainer_config.saved_bin_files_path = chunks_dir_path

         if self.model_type == "single_instance":
 
             train_dataset = SingleInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
@@ -236,7 +271,7 @@ def run_subprocess():
             )
 
             val_dataset = SingleInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.single_instance.confmaps,
@@ -246,7 +281,7 @@ def run_subprocess():
         elif self.model_type == "centered_instance":
 
             train_dataset = CenteredInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
@@ -257,7 +292,7 @@ def run_subprocess():
             )
 
             val_dataset = CenteredInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.centered_instance.confmaps,
@@ -268,7 +303,7 @@ def run_subprocess():

         elif self.model_type == "centroid":
             train_dataset = CentroidStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
@@ -277,7 +312,7 @@ def run_subprocess():
             )
 
             val_dataset = CentroidStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.centroid.confmaps,
@@ -286,7 +321,7 @@ def run_subprocess():

         elif self.model_type == "bottomup":
             train_dataset = BottomUpStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
@@ -297,7 +332,7 @@ def run_subprocess():
             )
 
             val_dataset = BottomUpStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.bottomup.confmaps,
@@ -355,8 +390,22 @@ def train(
         self,
         backbone_trained_ckpts_path: Optional[str] = None,
         head_trained_ckpts_path: Optional[str] = None,
+        delete_bin_files_after_training: bool = True,
+        chunks_dir_path: Optional[str] = None,
     ):
-        """Initiate the training by calling the fit method of Trainer."""
+        """Initiate the training by calling the fit method of Trainer.
+
+        Args:
+            backbone_trained_ckpts_path: Path to the `.ckpt` file used to initialize
+                the backbone. If `None`, the backbone is randomly initialized.
+            head_trained_ckpts_path: Path to the `.ckpt` file used to initialize
+                the head layers. If `None`, the head layers are randomly initialized.
+            delete_bin_files_after_training: If `True`, the `.bin` files are deleted
+                after training; if `False`, they are retained.
+            chunks_dir_path: Path to the chunks directory, which must contain
+                `train_chunks` and `val_chunks` subfolders. If `None`, new `.bin`
+                files are generated.
+
+        """
         logger = []
 
         if self.config.trainer_config.save_ckpt:
@@ -416,23 +465,9 @@ def train(
         # save the configs as yaml in the checkpoint dir
         OmegaConf.save(config=self.config, f=f"{self.dir_path}/training_config.yaml")
 
-        self._create_data_loaders()
-
-        # save the skeleton in the config
-        self.config["data_config"]["skeletons"] = {}
-        for skl in self.skeletons:
-            if skl.symmetries:
-                symm = [list(s.nodes) for s in skl.symmetries]
-            else:
-                symm = None
-            skl_name = skl.name if skl.name is not None else "skeleton-0"
-            self.config["data_config"]["skeletons"][skl_name] = {
-                "nodes": skl.nodes,
-                "edges": skl.edges,
-                "symmetries": symm,
-            }
+        self._create_data_loaders(chunks_dir_path)
 
-        trainer = L.Trainer(
+        self.trainer = L.Trainer(
             callbacks=callbacks,
             logger=logger,
             enable_checkpointing=self.config.trainer_config.save_ckpt,
@@ -444,7 +479,7 @@ def train(
         )
 
         try:
-            trainer.fit(
+            self.trainer.fit(
                 self.model,
                 self.train_data_loader,
                 self.val_data_loader,
@@ -468,17 +503,23 @@ def train(
                     config=self.config, f=f"{self.dir_path}/training_config.yaml"
                 )
-        # TODO: (ubuntu test failing (running for > 6hrs) with the below lines)
-        # print("Deleting training and validation files...")
-        # if (Path(self.dir_path) / "train_chunks").exists():
-        #     shutil.rmtree(
-        #         (Path(self.dir_path) / "train_chunks").as_posix(),
-        #         ignore_errors=True,
-        #     )
-        # if (Path(self.dir_path) / "val_chunks").exists():
-        #     shutil.rmtree(
-        #         (Path(self.dir_path) / "val_chunks").as_posix(),
-        #         ignore_errors=True,
-        #     )
+        if delete_bin_files_after_training:
+            print("Deleting training and validation files...")
+            if (Path(self.train_input_dir)).exists():
+                shutil.rmtree(
+                    (Path(self.train_input_dir)).as_posix(),
+                    ignore_errors=True,
+                )
+            if (Path(self.val_input_dir)).exists():
+                shutil.rmtree(
+                    (Path(self.val_input_dir)).as_posix(),
+                    ignore_errors=True,
+                )
+            if self.bin_files_path is not None:
+                shutil.rmtree(
+                    (Path(self.bin_files_path)).as_posix(),
+                    ignore_errors=True,
+                )


class TrainingModel(L.LightningModule):
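Taken together, the `model_trainer.py` changes let a later run reuse previously generated chunks instead of regenerating them. A usage sketch, assuming the enclosing trainer class is `ModelTrainer` and a config file at the path shown (both are assumptions, not spelled out in this diff):

    from omegaconf import OmegaConf

    from sleap_nn.training.model_trainer import ModelTrainer

    config = OmegaConf.load("config.yaml")  # hypothetical config path

    # First run: `.bin` chunks are written under trainer_config.bin_files_path
    # (or the checkpoint dir if that is None) and retained afterwards because
    # deletion is disabled.
    trainer = ModelTrainer(config)
    trainer.train(delete_bin_files_after_training=False)

    # Later run: point at the saved chunks dir, which must contain the
    # `train_chunks` and `val_chunks` subfolders, to skip chunk generation.
    trainer_2 = ModelTrainer(config)
    trainer_2.train(
        chunks_dir_path=config.trainer_config.saved_bin_files_path,
        delete_bin_files_after_training=False,
    )

One behavior worth noting: if `chunks_dir_path` is passed but `delete_bin_files_after_training` is left at its default of `True`, the cleanup at the end of `train()` removes the reused `train_chunks` and `val_chunks` directories as well (only the `self.bin_files_path` removal is skipped, since it stays `None` in the reuse branch), so disabling deletion is the safer choice for chunks meant to be shared across runs.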
34 changes: 30 additions & 4 deletions tests/inference/test_predictors.py
@@ -34,7 +34,15 @@ def test_topdown_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
+
     assert lf.frame_idx == gt_lf.frame_idx
     assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape
     assert lf.instances[1].numpy().shape == gt_lf.instances[1].numpy().shape
@@ -339,7 +347,13 @@ def test_single_instance_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
     assert lf.frame_idx == gt_lf.frame_idx
     assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape

@@ -391,7 +405,13 @@ def test_single_instance_predictor(

     # check if the predicted labels have same skeleton as the GT labels
     gt_labels = sio.load_slp(minimal_instance)
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
assert lf.frame_idx == 0

# check if dictionaries are created when make labels is set to False
@@ -567,7 +587,13 @@ def test_bottomup_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
     assert lf.frame_idx == gt_lf.frame_idx
    assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape

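The same skeleton assertions now appear in four places across these tests, since predicted and ground-truth `Skeleton` objects no longer compare equal directly (presumably because the predictor rebuilds skeletons from the saved config rather than sharing instances), so the tests compare node names, edge endpoints, and symmetries attribute by attribute. One possible factoring into a shared helper; a sketch only, and the `assert_skeletons_match` name is ours, not part of this PR:

    import sleap_io as sio

    def assert_skeletons_match(skl: sio.Skeleton, gt_skl: sio.Skeleton) -> None:
        """Compare two skeletons by node names, edge endpoints, and symmetries."""
        assert [n.name for n in skl.nodes] == [n.name for n in gt_skl.nodes]
        assert len(skl.edges) == len(gt_skl.edges)
        for edge, gt_edge in zip(skl.edges, gt_skl.edges):
            assert edge[0].name == gt_edge[0].name
            assert edge[1].name == gt_edge[1].name
        assert skl.symmetries == gt_skl.symmetries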