From aa4a0bbc327e81498a39adbaf751878fe73a9c9c Mon Sep 17 00:00:00 2001 From: Volker Austrup Date: Thu, 7 Dec 2023 21:53:01 +0000 Subject: [PATCH] add mode logical_and to pruning --- src/pyhf/workspace.py | 83 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/src/pyhf/workspace.py b/src/pyhf/workspace.py index 5e90ffa621..f032b7cf3e 100644 --- a/src/pyhf/workspace.py +++ b/src/pyhf/workspace.py @@ -591,16 +591,40 @@ def _prune_and_rename( ), ) for modifier in sample['modifiers'] - if modifier['name'] not in prune_modifiers - and modifier['type'] not in prune_modifier_types + if ( + channel['name'] not in prune_channels + and prune_channels != [] + ) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this modifier for every channel + or ( + sample['name'] not in prune_samples + and prune_samples != [] + ) # want to remove only if sample is in prune_samples or if prune_samples is empty, i.e. we want to prune this modifier for every sample + or ( + modifier['name'] not in prune_modifiers + and modifier['type'] not in prune_modifier_types + ) + or prune_measurements + != [] # need to keep the modifier in case it is used in another measurement ], } for sample in channel['samples'] - if sample['name'] not in prune_samples + if ( + channel['name'] not in prune_channels + and prune_channels != [] + ) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this sample for every channel + or sample['name'] not in prune_samples + or prune_modifiers + != [] # we only want to remove this sample if we did not specify modifiers to prune + or prune_modifier_types != [] ], } for channel in self['channels'] if channel['name'] not in prune_channels + or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune + prune_samples != [] + or prune_modifiers != [] + or prune_modifier_types != [] + ) ], 'measurements': [ { @@ -615,8 +639,14 @@ def _prune_and_rename( parameter['name'], parameter['name'] ), ) - for parameter in measurement['config']['parameters'] - if parameter['name'] not in prune_modifiers + for parameter in measurement['config'][ + 'parameters' + ] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty + if ( + measurement['name'] not in prune_measurements + and prune_measurements != [] + ) + or parameter['name'] not in prune_modifiers ], 'poi': rename_modifiers.get( measurement['config']['poi'], @@ -626,6 +656,8 @@ def _prune_and_rename( } for measurement in self['measurements'] if measurement['name'] not in prune_measurements + or prune_modifiers + != [] # we only want to remove this measurement if we did not specify parameters to remove ], 'observations': [ dict( @@ -634,6 +666,11 @@ def _prune_and_rename( ) for observation in self['observations'] if observation['name'] not in prune_channels + or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune + prune_samples != [] + or prune_modifiers != [] + or prune_modifier_types != [] + ) ], 'version': self['version'], } @@ -646,6 +683,7 @@ def prune( samples=None, channels=None, measurements=None, + mode="logical_or", ): """ Return a new, pruned workspace specification. This will not modify the original workspace. @@ -658,6 +696,7 @@ def prune( samples: A :obj:`list` of samples to prune. channels: A :obj:`list` of channels to prune. measurements: A :obj:`list` of measurements to prune. + mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`. Returns: ~pyhf.workspace.Workspace: A new workspace object with the specified components removed @@ -666,6 +705,12 @@ def prune( ~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace. """ + + if mode not in ["logical_and", "logical_or"]: + raise ValueError( + "Pruning mode must be either `logical_and` or `logical_or`." + ) + # avoid mutable defaults modifiers = [] if modifiers is None else modifiers modifier_types = [] if modifier_types is None else modifier_types @@ -673,12 +718,28 @@ def prune( channels = [] if channels is None else channels measurements = [] if measurements is None else measurements - return self._prune_and_rename( - prune_modifiers=modifiers, - prune_modifier_types=modifier_types, - prune_samples=samples, - prune_channels=channels, - prune_measurements=measurements, + if mode == "logical_and": + if samples != [] and measurements != []: + raise ValueError( + "Pruning of measurements and samples cannot be run with mode `logical_and`." + ) + if modifier_types != [] and measurements != []: + raise ValueError( + "Pruning of measurements and modifier_types cannot be run with mode `logical_and`." + ) + return self._prune_and_rename( + prune_modifiers=modifiers, + prune_modifier_types=modifier_types, + prune_samples=samples, + prune_channels=channels, + prune_measurements=measurements, + ) + return ( + self._prune_and_rename(prune_modifiers=modifiers) + ._prune_and_rename(prune_modifier_types=modifier_types) + ._prune_and_rename(prune_samples=samples) + ._prune_and_rename(prune_channels=channels) + ._prune_and_rename(prune_measurements=measurements) ) def rename(self, modifiers=None, samples=None, channels=None, measurements=None):