diff --git a/pycbc/transforms.py b/pycbc/transforms.py index c8fa44a9a9b..487de5ab103 100644 --- a/pycbc/transforms.py +++ b/pycbc/transforms.py @@ -318,6 +318,98 @@ def from_config(cls, cp, section, outputs): return cls(inputs, outputs, transform_functions, jacobian=jacobian) +class CustomTransformMultiOutputs(CustomTransform): + """Allows for any transform to be defined. Based on CustomTransform, + but also supports multi-returning value functions. + + Parameters + ---------- + input_args : (list of) str + The names of the input parameters. + output_args : (list of) str + The names of the output parameters. + transform_functions : dict + Dictionary mapping input args to a string giving a function call; + e.g., ``{'q': 'q_from_mass1_mass2(mass1, mass2)'}``. + jacobian : str, optional + String giving a jacobian function. The function must be in terms of + the input arguments. + """ + + name = "custom_multi" + + def __init__(self, input_args, output_args, transform_functions, + jacobian=None): + super(CustomTransformMultiOutputs, self).__init__( + input_args, output_args, transform_functions, jacobian) + + def transform(self, maps): + """Applies the transform functions to the given maps object. + Parameters + ---------- + maps : dict, or FieldArray + Returns + ------- + dict or FieldArray + A map object containing the transformed variables, along with the + original variables. The type of the output will be the same as the + input. + """ + if self.transform_functions is None: + raise NotImplementedError("no transform function(s) provided") + # copy values to scratch + self._copytoscratch(maps) + # ensure that we return the same data type in each dict + getslice = self._getslice(maps) + # evaluate the functions + # func[0] is the function itself, func[1] is the index, + # this supports multiple returning values function + out = { + p: self._scratch[func[0]][func[1]][getslice] if + len(self._scratch[func[0]]) > 1 else + self._scratch[func[0]][getslice] + for p, func in self.transform_functions.items() + } + return self.format_output(maps, out) + + @classmethod + def from_config(cls, cp, section, outputs): + """Loads a CustomTransformMultiOutputs from the given config file. + + Example section: + + .. code-block:: ini + + [{section}-outvar1+outvar2] + name = custom_multi + inputs = inputvar1, inputvar2 + outvar1, outvar2 = func1(inputs) + jacobian = func2(inputs) + """ + tag = outputs + outputs = list(outputs.split(VARARGS_DELIM)) + all_vars = ", ".join(outputs) + inputs = map(str.strip, + cp.get_opt_tag(section, "inputs", tag).split(",")) + # get the functions for each output + transform_functions = {} + output_index = slice(None, None, None) + for var in outputs: + # check if option can be cast as a float + try: + func = cp.get_opt_tag(section, var, tag) + except Exception: + func = cp.get_opt_tag(section, all_vars, tag) + output_index = slice(outputs.index(var), outputs.index(var)+1) + transform_functions[var] = [func, output_index] + s = "-".join([section, tag]) + if cp.has_option(s, "jacobian"): + jacobian = cp.get_opt_tag(section, "jacobian", tag) + else: + jacobian = None + return cls(inputs, outputs, transform_functions, jacobian=jacobian) + + # # ============================================================================= # @@ -2725,6 +2817,7 @@ def from_config(cls, cp, section, outputs, # dictionary of all transforms transforms = { CustomTransform.name: CustomTransform, + CustomTransformMultiOutputs.name: CustomTransformMultiOutputs, MchirpQToMass1Mass2.name: MchirpQToMass1Mass2, Mass1Mass2ToMchirpQ.name: Mass1Mass2ToMchirpQ, MchirpEtaToMass1Mass2.name: MchirpEtaToMass1Mass2,