From bca872b468f70631bddd50fa44598ee14e1f1c0f Mon Sep 17 00:00:00 2001 From: Andrew Choi Date: Thu, 19 Sep 2024 17:10:56 -0700 Subject: [PATCH] Change override to override_sole_init in alf.config --- alf/config_helpers.py | 6 +- alf/config_util.py | 122 +++++++++++++++++++--------------------- alf/config_util_test.py | 69 +++++++++++++++++++++-- 3 files changed, 126 insertions(+), 71 deletions(-) diff --git a/alf/config_helpers.py b/alf/config_helpers.py index 38965e48d..03c151e8c 100644 --- a/alf/config_helpers.py +++ b/alf/config_helpers.py @@ -285,7 +285,11 @@ def get_env(): # A 'None' random seed won't set a deterministic torch behavior. for _ in range(PerProcessContext().ddp_rank): random_seed = random.randint(0, 2**32) - config1("TrainerConfig.random_seed", random_seed, raise_if_used=False) + config1( + "TrainerConfig.random_seed", + random_seed, + raise_if_used=False, + override_sole_init=True) # We have to call set_random_seed() here because we need the actual # random seed to call create_environment. diff --git a/alf/config_util.py b/alf/config_util.py index dcc8dbd69..1b2183dd1 100644 --- a/alf/config_util.py +++ b/alf/config_util.py @@ -60,7 +60,7 @@ def config(prefix_or_dict, mutable=True, raise_if_used=True, sole_init=False, - override=False, + override_sole_init=False, **kwargs): """Set the values for the configs with given name as suffix. @@ -112,19 +112,18 @@ def func(self, a, b): immutable value to an existing immutable value. raise_if_used (bool): If True, ValueError will be raised if trying to config a value which has already been used. - sole_init (bool): If True, the config value can only be set once. Any - previous or future calls will raise a ValueError. This is helpful - in enforcing a singular point of initialization, thus eliminating - any potential side effects from possible prior or future overrides. - This flag overrides the mutable flag if True. For users wanting this - to be the default behavior, the ALF_SOLE_CONFIG env variable can be - set to 1. - override (bool): If True, the value of the config will be set regardless - of any pre-existing ``mutable`` or ``sole_init`` settings. This should - be used only when absolutely necessary (e.g., a teacher-student training - loop, where the student must override certain configs inherited from the - teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override - is True, the config's ``mutable`` or ``sole_init`` values are not changed. + sole_init (bool): If True, the config value can no longer be set again + after this config call. Any future calls will raise a RuntimeError. + This is helpful in enforcing a singular point of initialization, + thus eliminating any potential side effects from possible future + overrides. For users wanting this to be the default behavior, the + ALF_SOLE_CONFIG env variable can be set to 1. + override_sole_init (bool): If True, the value of the config will be set + regardless of any previous ``sole_init`` setting. This should be used + only when absolutely necessary (e.g., a teacher-student training loop, + where the student must override certain configs inherited from the + teacher). If the config is immutable, a warning will be declared with + no changes made. **kwargs: only used if ``prefix_or_dict`` is a str. """ if isinstance(prefix_or_dict, str): @@ -144,18 +143,18 @@ def func(self, a, b): sole_init = sole_init or GET_ALF_SOLE_CONFIG() for key, value in configs.items(): - config1(key, value, mutable, raise_if_used, sole_init, override) + config1(key, value, mutable, raise_if_used, sole_init, + override_sole_init) def override_config(prefix_or_dict, **kwargs): - """Wrapper function for configuring a config with override=True. + """Wrapper function for configuring a config with override_sole_init=True. - This call will change a config's value, ignoring any previous protections - placed upon a config by the ``mutable`` and ``sole_init`` flags. - Therefore, it is highly recommended that this be used only when absolutely - necessary (e.g., a teacher-student training loop, where the student must - override certain configs inherited from the teacher). Otherwise, it is best - to use alf.config with the ``mutable`` and ``sole_init`` flags instead. + This call allows a user to attempt to overwrite a config's value even + if it is protected by ``sole_init``. It is highly recommended that this be + used only when absolutely necessary (e.g., a teacher-student training loop, + where the student must override certain configs inherited from the teacher). + This call has no effect for configs who are immutable. Args: prefix_or_dict (str|dict): if a dict, each (key, value) pair in it @@ -164,7 +163,7 @@ def override_config(prefix_or_dict, **kwargs): value for config with name ``prefix + '.' + key`` **kwargs: only used if ``prefix_or_dict`` is a str. """ - config(prefix_or_dict, override=True, **kwargs) + config(prefix_or_dict, override_sole_init=True, **kwargs) def get_all_config_names(): @@ -367,7 +366,7 @@ def config1(config_name, mutable=True, raise_if_used=True, sole_init=False, - override=False): + override_sole_init=False): """Set one configurable value. Args: @@ -380,17 +379,18 @@ def config1(config_name, immutable value to an existing immutable value. raise_if_used (bool): If True, ValueError will be raised if trying to config a value which has already been used. - sole_init (bool): If True, the config value can only be set once. Any - previous or future calls will raise a ValueError. This is helpful - in enforcing a singular point of initialization, thus eliminating - any potential side effects from possible prior or future overrides. - This flag overrides the mutable flag if True. - override (bool): If True, the value of the config will be set regardless - of any pre-existing ``mutable`` or ``sole_init`` settings. This should - be used only when absolutely necessary (e.g., a teacher-student training - loop, where the student must override certain configs inherited from the - teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override - is True, the config's ``mutable`` or ``sole_init`` values are not changed. + sole_init (bool): If True, the config value can no longer be set again + after this config call. Any future calls will raise a RuntimeError. + This is helpful in enforcing a singular point of initialization, + thus eliminating any potential side effects from possible future + overrides. For users wanting this to be the default behavior, the + ALF_SOLE_CONFIG env variable can be set to 1. + override_sole_init (bool): If True, the value of the config will be set + regardless of any previous ``sole_init`` setting. This should be used + only when absolutely necessary (e.g., a teacher-student training loop, + where the student must override certain configs inherited from the + teacher). If the config is immutable, a warning will be declared with + no changes made. """ config_node = _get_config_node(config_name) @@ -399,49 +399,43 @@ def config1(config_name, "Config '%s' has already been used. You should config " "its value before using it." % config_name) - if override: + if override_sole_init: + if config_node.is_configured(): + if not config_node.is_mutable(): + logging.warning( + "The value of config '%s' (%s) is immutable. " + "Override flag with new value %s is ignored. " % + (config_name, config_node.get_value(), value)) + return + elif config_node.get_sole_init(): + logging.warning( + "The value of config '%s' (%s) is protected by sole_init. " + "It is now being overridden by the override flag to a new value %s. " + "Use at your own risk." % (config_name, + config_node.get_value(), value)) + elif config_node.is_configured(): if config_node.get_sole_init(): - logging.warning( - "The value of config '%s' (%s) is protected by sole_init. " - "It is now being overridden by the overide_all flag to a new value %s. " - "Use at your own risk." % (config_name, - config_node.get_value(), value)) - if not config_node.is_mutable(): - logging.warning( - "The value of config '%s' (%s) is immutable. " - "It is now being overridden by the overide_all flag to a new value %s. " - "Use at your own risk." % (config_name, - config_node.get_value(), value)) - config_node.set_value(value) - return - - if config_node.is_configured(): - if config_node.get_sole_init(): - raise ValueError( + raise RuntimeError( "Config '%s' is protected by sole_init and cannot be reconfigured. " "If you wish to set this config value, do so the location of the " "previous call." % config_name) - if sole_init: - raise ValueError( - "Config '%s' has already been configured. If you wish to protect " - "this config with sole_init, the previous alf.config call must be " - "removed." % config_name) if config_node.is_mutable(): logging.warning( "The value of config '%s' has been configured to %s. It is " "replaced by the new value %s" % (config_name, config_node.get_value(), value)) - config_node.set_value(value) - config_node.set_mutable(mutable) else: logging.warning( "The config '%s' has been configured to an immutable value " "of %s. The new value %s will be ignored" % (config_name, config_node.get_value(), value)) - else: - config_node.set_value(value) - config_node.set_mutable(mutable) + config_node.set_sole_init(sole_init) + return + + config_node.set_value(value) + config_node.set_mutable(mutable) + if not override_sole_init: config_node.set_sole_init(sole_init) @@ -462,7 +456,7 @@ def pre_config(configs): """ for name, value in configs.items(): try: - config1(name, value, mutable=False) + config1(name, value, mutable=False, sole_init=False) _HANDLED_PRE_CONFIGS.append((name, value)) except ValueError: _PRE_CONFIGS.append((name, value)) diff --git a/alf/config_util_test.py b/alf/config_util_test.py index c1b18e80c..44accd916 100644 --- a/alf/config_util_test.py +++ b/alf/config_util_test.py @@ -133,23 +133,23 @@ def test_config1(self): pprint.pformat(inoperative_configs)) self.assertTrue('A.B.C.D.test.arg' in dict(inoperative_configs)) + def test_sole_config(self): # Test sole_init protection against future config calls @alf.configurable def sole_init_test_prior(x): pass alf.config("sole_init_test_prior", x=0, sole_init=True) - with self.assertRaises(ValueError) as context: + with self.assertRaises(RuntimeError) as context: alf.config("sole_init_test_prior", x=0) - # Test sole_init protection against previous config calls + # Test sole_init against previous config calls. Should not trigger an error. @alf.configurable def sole_init_test_after(x): pass alf.config("sole_init_test_after", x=0) - with self.assertRaises(ValueError) as context: - alf.config("sole_init_test_after", x=0, sole_init=True) + alf.config("sole_init_test_after", x=0, sole_init=True) # Test sole_init protection against other sole_init config calls @alf.configurable @@ -157,7 +157,7 @@ def sole_init_test_twice(x): pass alf.config("sole_init_test_twice", x=0, sole_init=True) - with self.assertRaises(ValueError) as context: + with self.assertRaises(RuntimeError) as context: alf.config("sole_init_test_twice", x=0, sole_init=True) # Test sole_init protection works as expected when the ALF_SOLE_CONFIG @@ -168,7 +168,7 @@ def sole_init_test_env(x): os.environ["ALF_SOLE_CONFIG"] = "1" alf.config("sole_init_test_env", x=0) - with self.assertRaises(ValueError) as context: + with self.assertRaises(RuntimeError) as context: alf.config("sole_init_test_env", x=0) os.environ["ALF_SOLE_CONFIG"] = "0" @@ -182,6 +182,63 @@ def sole_init_test_env(x): self.assertEqual(alf.get_config_value("sole_init_test_twice.x"), 1) self.assertEqual(alf.get_config_value("sole_init_test_env.x"), 1) + # Test override_config doesn't doesn't overwrite for immutable values. + @alf.configurable + def override_on_immutable(x): + pass + + alf.config("override_on_immutable", x=0, mutable=False) + alf.override_config("override_on_immutable", x=1) + self.assertEqual(alf.get_config_value("override_on_immutable.x"), 0) + + # Test override_config doesn't doesn't overwrite for immutable values. + @alf.configurable + def override_on_immutable_and_sole_init(x): + pass + + alf.config( + "override_on_immutable_and_sole_init", + x=0, + sole_init=True, + mutable=False) + alf.override_config("override_on_immutable_and_sole_init", x=1) + self.assertEqual( + alf.get_config_value("override_on_immutable_and_sole_init.x"), 0) + + # Test that pre_config calls work before a sole_init config call + @alf.configurable + def pre_config_before(x): + pass + + alf.pre_config({"pre_config_before.x": 0}) + alf.config("pre_config_before", x=1) + alf.override_config("pre_config_before", x=1) + # sole_init starts to take effect for all calls AFTER the first call. + alf.config("pre_config_before", x=1, sole_init=True) + with self.assertRaises(RuntimeError) as context: + alf.config("pre_config_before", x=2) + # If truly immutable, the value should never have been changed + self.assertEqual(alf.get_config_value("pre_config_before.x"), 0) + + # Test that pre_config calls after a sole_init config call raises an error + @alf.configurable + def pre_config_after(x): + pass + + alf.config("pre_config_after", x=1, sole_init=True) + with self.assertRaises(RuntimeError) as context: + alf.pre_config({"pre_config_after.x": 0}) + + # Test that calling override_config doesn't affect previous sole_init calls. + @alf.configurable + def override_no_affect_sole_init(x): + pass + + alf.config("override_no_affect_sole_init", x=1, sole_init=True) + alf.override_config("override_no_affect_sole_init", x=2) + with self.assertRaises(RuntimeError) as context: + alf.config("override_no_affect_sole_init", x=3) + def test_repr_wrapper(self): a = MyClass(1, 2) self.assertEqual(repr(a), "MyClass(1, 2)")