diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index 5341da8c6..08cbe7888 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -18,6 +18,7 @@ def __init__(self, task_config): affs_weight_clipmax=task_config.affs_weight_clipmax, lsd_weight_clipmin=task_config.lsd_weight_clipmin, lsd_weight_clipmax=task_config.lsd_weight_clipmax, + background_as_object=task_config.background_as_object, ) self.loss = AffinitiesLoss( len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index 913a28187..0bbb8f4bc 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -52,3 +52,13 @@ class AffinitiesTaskConfig(TaskConfig): default=0.95, metadata={"help_text": "The maximum value for lsds weights."}, ) + background_as_object: bool = attr.ib( + default=False, + metadata={ + "help_text": ( + "Whether to treat the background as a separate object. " + "If set to false background should get an affinity near 0. If " + "set to true, the background should also have high affinity with other background." + ) + }, + ) diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 92915384a..d68541349 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -28,6 +28,7 @@ def __init__( affs_weight_clipmax: float = 0.95, lsd_weight_clipmin: float = 0.05, lsd_weight_clipmax: float = 0.95, + background_as_object: bool = False, ): self.neighborhood = neighborhood self.lsds = lsds @@ -51,6 +52,8 @@ def __init__( self.lsd_weight_clipmin = lsd_weight_clipmin self.lsd_weight_clipmax = lsd_weight_clipmax + self.background_as_object = background_as_object + def extractor(self, voxel_size): if self._extractor is None: self._extractor = LsdExtractor( @@ -105,10 +108,12 @@ def create_target(self, gt): label_data = label_data[0] else: axes = ["c"] + axes - affinities = seg_to_affgraph(label_data, self.neighborhood).astype(np.float32) + affinities = seg_to_affgraph( + label_data + int(self.background_as_object), self.neighborhood + ).astype(np.float32) if self.lsds: descriptors = self.extractor(gt.voxel_size).get_descriptors( - segmentation=label_data, + segmentation=label_data + int(self.background_as_object), voxel_size=gt.voxel_size, ) return NumpyArray.from_np_array( @@ -208,7 +213,9 @@ def gt_region_for_roi(self, target_spec): for a, b in zip(pad_pos, self.lsd_pad(target_spec.voxel_size)) ] ) - gt_spec.roi = gt_spec.roi.grow(pad_neg, pad_pos) + gt_spec.roi = gt_spec.roi.grow(pad_neg, pad_pos).snap_to_grid( + target_spec.voxel_size + ) gt_spec.dtype = None return gt_spec