diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 43c821dd5680..9e6fb1b9c82b 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -44,6 +44,39 @@ func ValidWorkloadType(val string) bool { } } +func UnmarshalConfigPolicies[T any](errMsg string, constraints, + invariantConfig *string) (*model.Constraints, *T, + error, +) { + var globalConstraints *model.Constraints + var globalConfig *T + + if constraints != nil { + unmarshaledConstraints, err := UnmarshalConfigPolicy[model.Constraints]( + *constraints, + errMsg, + ) + if err != nil { + ConfigPolicyWarning(err.Error()) + return nil, nil, err + } + globalConstraints = unmarshaledConstraints + } + + if invariantConfig != nil { + unmarshaledConfig, err := UnmarshalConfigPolicy[T]( + *invariantConfig, + errMsg, + ) + if err != nil { + ConfigPolicyWarning(err.Error()) + return nil, nil, err + } + globalConfig = unmarshaledConfig + } + return globalConstraints, globalConfig, nil +} + // UnmarshalConfigPolicy is a generic helper function to unmarshal both JSON and YAML strings. func UnmarshalConfigPolicy[T any](str string, errString string) (*T, error) { var configPolicy T @@ -87,30 +120,16 @@ func ValidateExperimentConfig( var globalConstraints *model.Constraints var globalConfig *expconf.ExperimentConfig if globalConfigPolicies != nil { - if globalConfigPolicies.Constraints != nil { - globalConstraints, err = UnmarshalConfigPolicy[model.Constraints]( - *globalConfigPolicies.Constraints, - InvalidExperimentConfigPolicyErr, - ) - if err != nil { - ConfigPolicyWarning(err.Error()) - return err - } - } - - if globalConfigPolicies.InvariantConfig != nil { - globalConfig, err = UnmarshalConfigPolicy[expconf.ExperimentConfig]( - *globalConfigPolicies.InvariantConfig, - InvalidExperimentConfigPolicyErr, - ) - if err != nil { - ConfigPolicyWarning(err.Error()) - return err - } + globalConstraints, globalConfig, err = UnmarshalConfigPolicies[expconf.ExperimentConfig]( + InvalidExperimentConfigPolicyErr, + globalConfigPolicies.Constraints, + globalConfigPolicies.InvariantConfig) + if err != nil { + return err } - warnConfigPolicyOverlap(globalConstraints, cp.Constraints) - warnConfigPolicyOverlap(globalConfig, cp.InvariantConfig) + configPolicyOverlap(globalConstraints, cp.Constraints) + configPolicyOverlap(globalConfig, cp.InvariantConfig) } if cp.Constraints != nil { @@ -161,24 +180,17 @@ func ValidateNTSCConfig( var globalConfig *model.CommandConfig if globalConfigPolicies != nil { if globalConfigPolicies.Constraints != nil { - globalConstraints, err = UnmarshalConfigPolicy[model.Constraints](*globalConfigPolicies.Constraints, - InvalidNTSCConfigPolicyErr) - if err != nil { - ConfigPolicyWarning(err.Error()) - return err - } - } - if globalConfigPolicies.InvariantConfig != nil { - globalConfig, err = UnmarshalConfigPolicy[model.CommandConfig](*globalConfigPolicies.InvariantConfig, - InvalidNTSCConfigPolicyErr) + globalConstraints, globalConfig, err = UnmarshalConfigPolicies[model.CommandConfig]( + InvalidNTSCConfigPolicyErr, + globalConfigPolicies.Constraints, + globalConfigPolicies.InvariantConfig) if err != nil { - ConfigPolicyWarning(err.Error()) return err } } - warnConfigPolicyOverlap(globalConstraints, cp.Constraints) - warnConfigPolicyOverlap(globalConfig, cp.InvariantConfig) + configPolicyOverlap(globalConstraints, cp.Constraints) + configPolicyOverlap(globalConfig, cp.InvariantConfig) } if cp.Constraints != nil { @@ -242,9 +254,9 @@ func checkConstraintConflicts(constraints *model.Constraints, maxSlots, slots, p return nil } -// warnConfigPolicyOverlap compares two different configurations and -// warns the user when both configurations define the same field. -func warnConfigPolicyOverlap(config1, config2 interface{}) { +// configPolicyOverlap compares two different configurations and warns the user when both +// configurations define the same field. +func configPolicyOverlap(config1, config2 interface{}) { if reflect.ValueOf(config1).Type() != reflect.ValueOf(config2).Type() && reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.Constraints{}).Type() && reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.CommandConfig{}).Type() &&