diff --git a/models/models.py b/models/models.py index 7fee5a1..8b3bd2c 100644 --- a/models/models.py +++ b/models/models.py @@ -102,6 +102,12 @@ def create_modules(module_defs, img_size, cfg): filters = output_filters[-1] modules = Silence() + elif mdef['type'] == 'scale_channels': # nn.Sequential() placeholder for 'shortcut' layer + layers = mdef['from'] + filters = output_filters[-1] + routs.extend([i + l if l < 0 else l for l in layers]) + modules = ScaleChannel(layers=layers) + elif mdef['type'] == 'sam': # nn.Sequential() placeholder for 'shortcut' layer layers = mdef['from'] filters = output_filters[-1] @@ -501,7 +507,7 @@ def forward_once(self, x, augment=False, verbose=False): for i, module in enumerate(self.module_list): name = module.__class__.__name__ #print(name) - if name in ['WeightedFeatureFusion', 'FeatureConcat', 'FeatureConcat2', 'FeatureConcat3', 'FeatureConcat_l', 'ScaleSpatial']: # sum, concat + if name in ['WeightedFeatureFusion', 'FeatureConcat', 'FeatureConcat2', 'FeatureConcat3', 'FeatureConcat_l', 'ScaleChannel', 'ScaleSpatial']: # sum, concat if verbose: l = [i - 1] + module.layers # layers sh = [list(x.shape)] + [list(out[i].shape) for i in module.layers] # shapes