
Commit

Merge branch 'divya/add-thresh-mode-option' into divya/add-path-bin-files
gitttt-1234 committed Oct 28, 2024
2 parents c75399e + 162e949 commit 466ae62
Showing 5 changed files with 123 additions and 67 deletions.
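In short: this merge moves input resizing out of the four StreamingDataset `__getitem__` implementations in sleap_nn/data/streaming_datasets.py and into the `*_data_chunks` functions in sleap_nn/data/get_data_chunks.py, so each frame is resized once when the `.bin` chunks are written instead of on every read. A new `--scale` argument is threaded through sleap_nn/training/get_bin_files.py, and `CenteredInstanceStreamingDataset` gains an `input_scale` attribute used only to rescale its configured crop size.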
33 changes: 32 additions & 1 deletion sleap_nn/data/get_data_chunks.py
@@ -14,7 +14,7 @@
     convert_to_rgb,
 )
 from sleap_nn.data.providers import process_lf
-from sleap_nn.data.resizing import apply_sizematcher
+from sleap_nn.data.resizing import apply_sizematcher, apply_resizer
 
 
 def bottomup_data_chunks(
@@ -23,6 +23,7 @@ def bottomup_data_chunks(
     max_instances: int,
     max_hw: Tuple[int, int],
     user_instances_only: bool = True,
+    scale: float = 1.0,
 ) -> Dict[str, torch.Tensor]:
     """Generate dict from `sio.LabeledFrame`.
 
@@ -42,6 +43,7 @@
             are used.
         user_instances_only: True if filter labels only to user instances else False.
             Default: True.
+        scale: Factor to resize the image dimensions by. Default: 1.0.
 
     Returns:
         Dict with image, instances, frame index, video index, original image size and
@@ -69,6 +71,13 @@
     )
     sample["instances"] = sample["instances"] * eff_scale
 
+    # resize the image
+    sample["image"], sample["instances"] = apply_resizer(
+        sample["image"],
+        sample["instances"],
+        scale=scale,
+    )
+
     transform = T.ToPILImage()
     sample["image"] = transform(sample["image"].squeeze(dim=0))
 
@@ -83,6 +92,7 @@ def centered_instance_data_chunks(
     anchor_ind: Optional[int],
     max_hw: Tuple[int, int],
     user_instances_only: bool = True,
+    scale: float = 1.0,
 ) -> Iterator[Dict[str, torch.Tensor]]:
     """Generate dict from `sio.LabeledFrame`.
 
@@ -106,6 +116,7 @@
             are used.
         user_instances_only: True if filter labels only to user instances else False.
             Default: True.
+        scale: Factor to resize the image dimensions by. Default: 1.0.
 
     Returns:
         Dict with image, instances, frame index, video index, original image size and
@@ -155,6 +166,12 @@
         res["video_idx"] = sample["video_idx"]
         res["num_instances"] = sample["num_instances"]
         res["orig_size"] = sample["orig_size"]
+
+        # resize image
+        res["instance_image"], res["instance"] = apply_resizer(
+            res["instance_image"], res["instance"], scale=scale
+        )
+
         res["instance_image"] = transform(res["instance_image"].squeeze(dim=0))
 
         yield res
@@ -167,6 +184,7 @@ def centroid_data_chunks(
     anchor_ind: Optional[int],
     max_hw: Tuple[int, int],
     user_instances_only: bool = True,
+    scale: float = 1.0,
 ) -> Dict[str, torch.Tensor]:
     """Generate dict from `sio.LabeledFrame`.
 
@@ -189,6 +207,7 @@
             are used.
         user_instances_only: True if filter labels only to user instances else False.
             Default: True.
+        scale: Factor to resize the image dimensions by. Default: 1.0.
 
     Returns:
         Dict with image, instances, frame index, video index, original image size and
@@ -222,6 +241,11 @@
 
     sample["centroids"] = centroids
 
+    # resize image
+    sample["image"], sample["centroids"] = apply_resizer(
+        sample["image"], sample["centroids"], scale=scale
+    )
+
     transform = T.ToPILImage()
     sample["image"] = transform(sample["image"].squeeze(dim=0))
 
@@ -233,6 +257,7 @@ def single_instance_data_chunks(
     data_config: DictConfig,
     max_hw: Tuple[int, int],
     user_instances_only: bool = True,
+    scale: float = 1.0,
 ) -> Dict[str, torch.Tensor]:
     """Generate dict from `sio.LabeledFrame`.
 
@@ -251,6 +276,7 @@
             are used.
         user_instances_only: True if filter labels only to user instances else False.
             Default: True.
+        scale: Factor to resize the image dimensions by. Default: 1.0.
 
     Returns:
         Dict with image, instances, frame index, video index, original image size and
@@ -280,6 +306,11 @@
     )
     sample["instances"] = sample["instances"] * eff_scale
 
+    # resize image
+    sample["image"], sample["instances"] = apply_resizer(
+        sample["image"], sample["instances"], scale=scale
+    )
+
     transform = T.ToPILImage()
     sample["image"] = transform(sample["image"].squeeze(dim=0))
 
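Resizing now happens inside each `*_data_chunks` function, after `apply_sizematcher` has already multiplied the keypoints by `eff_scale` and before the tensor is converted back to a PIL image for serialization. The sketch below illustrates the contract these call sites assume of `apply_resizer`; it is an illustrative stand-in, not the actual `sleap_nn.data.resizing` implementation.

import torch
import torch.nn.functional as F

def apply_resizer_sketch(
    image: torch.Tensor, instances: torch.Tensor, scale: float = 1.0
):
    """Resize a (1, C, H, W) image by `scale`; keep keypoints in the new pixel space."""
    if scale != 1.0:
        image = F.interpolate(
            image, scale_factor=scale, mode="bilinear", align_corners=False
        )
        # Keypoints are (x, y) pixel coordinates, so they scale with the image.
        instances = instances * scale
    return image, instances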
53 changes: 4 additions & 49 deletions sleap_nn/data/streaming_datasets.py
@@ -32,9 +32,6 @@ class BottomUpStreamingDataset(ld.StreamingDataset):
             for PAFs.
         max_stride: Scalar integer specifying the maximum stride that the image must be
             divisible by.
-        scale: Factor to resize the image dimensions by, specified as either a float scalar
-            or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions
-            are resized by the same factor. Default: 1.0.
         apply_aug: `True` if augmentations should be applied to the data pipeline,
             else `False`. Default: `False`.
         augmentation_config: Augmentation parameters. (`data_config.preprocessing.augmentation_config`
@@ -47,7 +44,6 @@ def __init__(
         pafs_head: DictConfig,
         edge_inds: list,
         max_stride: int,
-        scale: float = 1.0,
         apply_aug: bool = False,
         augmentation_config: DictConfig = None,
         *args,
@@ -60,7 +56,6 @@ def __init__(
         self.pafs_head = pafs_head
         self.edge_inds = edge_inds
         self.max_stride = max_stride
-        self.scale = scale
         self.apply_aug = apply_aug
         self.aug_config = augmentation_config
 
@@ -85,13 +80,6 @@ def __getitem__(self, index):
 
         ex["image"] = apply_normalization(ex["image"])
 
-        # resize the image
-        ex["image"], ex["instances"] = apply_resizer(
-            ex["image"],
-            ex["instances"],
-            scale=self.scale,
-        )
-
         # Pad the image (if needed) according max stride
         ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride)
 
@@ -136,23 +124,21 @@ class CenteredInstanceStreamingDataset(ld.StreamingDataset):
         crop_hw: Height and width of the crop in pixels.
         max_stride: Scalar integer specifying the maximum stride that the image must be
             divisible by.
-        scale: Factor to resize the image dimensions by, specified as either a float scalar
-            or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions
-            are resized by the same factor. Default: 1.0.
         apply_aug: `True` if augmentations should be applied to the data pipeline,
             else `False`. Default: `False`.
         augmentation_config: Augmentation parameters. (`data_config.preprocessing.augmentation_config`
             section in the config file.)
+        input_scale: Resize factor applied to the image. Default: 1.0.
     """
 
     def __init__(
         self,
         confmap_head: DictConfig,
         crop_hw: Tuple[int],
         max_stride: int,
-        scale: float = 1.0,
         apply_aug: bool = False,
         augmentation_config: DictConfig = None,
+        input_scale: float = 1.0,
         *args,
         **kwargs,
     ):
@@ -161,10 +147,10 @@ def __init__(
 
         self.confmap_head = confmap_head
         self.crop_hw = crop_hw
-        self.scale = scale
         self.max_stride = max_stride
         self.apply_aug = apply_aug
         self.aug_config = augmentation_config
+        self.input_scale = input_scale
 
     def __getitem__(self, index):
         """Apply augmentation and generate confidence maps."""
@@ -185,7 +171,7 @@ def __getitem__(self, index):
             )
 
         # Re-crop to original crop size
-        self.crop_hw = list(self.crop_hw)
+        self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
        ex["instance_bbox"] = torch.unsqueeze(
             make_centered_bboxes(ex["centroid"][0], self.crop_hw[0], self.crop_hw[1]), 0
         )
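Because the instance crops stored in the `.bin` chunks are already resized by the chunking scale, the configured crop size must be shrunk or grown by the same factor before re-cropping. A worked example of the replacement line, with hypothetical numbers:

# A (160, 160) crop configured at full resolution maps to an (80, 80)
# crop on chunks that were written with input_scale = 0.5.
crop_hw = (160, 160)
input_scale = 0.5
crop_hw = [int(x * input_scale) for x in crop_hw]
assert crop_hw == [80, 80]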
@@ -201,13 +187,6 @@
 
         ex["instance_image"] = apply_normalization(ex["instance_image"])
 
-        # resize the image
-        ex["instance_image"], ex["instance"] = apply_resizer(
-            ex["instance_image"],
-            ex["instance"],
-            scale=self.scale,
-        )
-
         # Pad the image (if needed) according max stride
         ex["instance_image"] = apply_pad_to_stride(
             ex["instance_image"], max_stride=self.max_stride
@@ -239,9 +218,6 @@ class CentroidStreamingDataset(ld.StreamingDataset):
         (required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ).
         max_stride: Scalar integer specifying the maximum stride that the image must be
             divisible by.
-        scale: Factor to resize the image dimensions by, specified as either a float scalar
-            or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions
-            are resized by the same factor. Default: 1.0.
         apply_aug: `True` if augmentations should be applied to the data pipeline,
             else `False`. Default: `False`.
         augmentation_config: Augmentation parameters. (`data_config.preprocessing.augmentation_config`
@@ -252,7 +228,6 @@ def __init__(
         self,
         confmap_head: DictConfig,
         max_stride: int,
-        scale: float = 1.0,
         apply_aug: bool = False,
         augmentation_config: DictConfig = None,
         *args,
@@ -263,7 +238,6 @@
 
         self.confmap_head = confmap_head
         self.max_stride = max_stride
-        self.scale = scale
         self.apply_aug = apply_aug
         self.aug_config = augmentation_config
 
@@ -288,13 +262,6 @@ def __getitem__(self, index):
 
         ex["image"] = apply_normalization(ex["image"])
 
-        # resize the image
-        ex["image"], ex["centroids"] = apply_resizer(
-            ex["image"],
-            ex["centroids"],
-            scale=self.scale,
-        )
-
         # Pad the image (if needed) according max stride
         ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride)
 
@@ -326,9 +293,6 @@ class SingleInstanceStreamingDataset(ld.StreamingDataset):
         (required keys: `sigma`, `output_stride` and `anchor_part` depending on the model type ).
         max_stride: Scalar integer specifying the maximum stride that the image must be
             divisible by.
-        scale: Factor to resize the image dimensions by, specified as either a float scalar
-            or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both dimensions
-            are resized by the same factor. Default: 1.0.
         apply_aug: `True` if augmentations should be applied to the data pipeline,
             else `False`. Default: `False`.
         augmentation_config: Augmentation parameters. (`data_config.preprocessing.augmentation_config`
@@ -339,7 +303,6 @@ def __init__(
         self,
         confmap_head: DictConfig,
         max_stride: int,
-        scale: float = 1.0,
         apply_aug: bool = False,
         augmentation_config: DictConfig = None,
         *args,
@@ -350,7 +313,6 @@
 
         self.confmap_head = confmap_head
         self.max_stride = max_stride
-        self.scale = scale
         self.apply_aug = apply_aug
         self.aug_config = augmentation_config
 
@@ -375,13 +337,6 @@
 
         ex["image"] = apply_normalization(ex["image"])
 
-        # resize the image
-        ex["image"], ex["instances"] = apply_resizer(
-            ex["image"],
-            ex["instances"],
-            scale=self.scale,
-        )
-
         # Pad the image (if needed) according max stride
         ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride)
 
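With resizing done at chunk-generation time, each `__getitem__` is reduced to augmentation, normalization, and padding to the model's maximum stride. The sketch below shows the behavior these call sites assume of `apply_pad_to_stride` (zero-pad bottom/right until height and width are divisible by `max_stride`); the real implementation lives in `sleap_nn.data.resizing`.

import torch
import torch.nn.functional as F

def apply_pad_to_stride_sketch(image: torch.Tensor, max_stride: int) -> torch.Tensor:
    """Zero-pad a (1, C, H, W) image so H and W become multiples of max_stride."""
    _, _, height, width = image.shape
    pad_h = (max_stride - height % max_stride) % max_stride
    pad_w = (max_stride - width % max_stride) % max_stride
    # F.pad pads the last two dims in the order (left, right, top, bottom).
    return F.pad(image, (0, pad_w, 0, pad_h))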
5 changes: 5 additions & 0 deletions sleap_nn/training/get_bin_files.py
@@ -23,6 +23,7 @@
 parser.add_argument("--model_type", type=str)
 parser.add_argument("--num_workers", type=int)
 parser.add_argument("--chunk_size", type=int)
+parser.add_argument("--scale", type=float)
 parser.add_argument("--crop_hw", type=int, default=None)
 args = parser.parse_args()
 
@@ -45,6 +46,7 @@
         data_config=config.data_config,
         user_instances_only=user_instances_only,
         max_hw=(max_height, max_width),
+        scale=args.scale,
     )
 
     ld.optimize(
@@ -73,6 +75,7 @@
         anchor_ind=config.model_config.head_configs.centered_instance.confmaps.anchor_part,
         user_instances_only=user_instances_only,
         max_hw=(max_height, max_width),
+        scale=args.scale,
     )
 
     ld.optimize(
@@ -99,6 +102,7 @@
         anchor_ind=config.model_config.head_configs.centroid.confmaps.anchor_part,
         user_instances_only=user_instances_only,
         max_hw=(max_height, max_width),
+        scale=args.scale,
     )
 
     ld.optimize(
@@ -124,6 +128,7 @@
         max_instances=max_instances,
         user_instances_only=user_instances_only,
         max_hw=(max_height, max_width),
+        scale=args.scale,
     )
 
     ld.optimize(
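Each hunk above adds `scale=args.scale` to the keyword arguments bound into the corresponding `*_data_chunks` function before it is handed to litdata. A sketch of the assumed wiring for the bottom-up branch, under the assumption that the binding uses `functools.partial` and that `train_labels`/`train_chunks_dir` are hypothetical stand-ins for values defined outside this diff:

from functools import partial

import litdata as ld

get_chunks_fn = partial(
    bottomup_data_chunks,
    data_config=config.data_config,
    max_instances=max_instances,
    user_instances_only=user_instances_only,
    max_hw=(max_height, max_width),
    scale=args.scale,  # frames and keypoints are resized before chunking
)

ld.optimize(
    fn=get_chunks_fn,
    inputs=list(train_labels),    # hypothetical: one sio.LabeledFrame per item
    output_dir=train_chunks_dir,  # hypothetical output path
    num_workers=args.num_workers,
    chunk_size=args.chunk_size,
)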
