Skip to content

Commit

Permalink
Change override to override_sole_init in alf.config
Browse files Browse the repository at this point in the history
  • Loading branch information
QuantuMope committed Sep 20, 2024
1 parent 827532a commit 192a34c
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 65 deletions.
6 changes: 5 additions & 1 deletion alf/config_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,11 @@ def get_env():
# A 'None' random seed won't set a deterministic torch behavior.
for _ in range(PerProcessContext().ddp_rank):
random_seed = random.randint(0, 2**32)
config1("TrainerConfig.random_seed", random_seed, raise_if_used=False)
config1(
"TrainerConfig.random_seed",
random_seed,
raise_if_used=False,
override_sole_init=True)

# We have to call set_random_seed() here because we need the actual
# random seed to call create_environment.
Expand Down
108 changes: 50 additions & 58 deletions alf/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def config(prefix_or_dict,
mutable=True,
raise_if_used=True,
sole_init=False,
override=False,
override_sole_init=False,
**kwargs):
"""Set the values for the configs with given name as suffix.
Expand Down Expand Up @@ -112,19 +112,17 @@ def func(self, a, b):
immutable value to an existing immutable value.
raise_if_used (bool): If True, ValueError will be raised if trying to
config a value which has already been used.
sole_init (bool): If True, the config value can only be set once. Any
previous or future calls will raise a ValueError. This is helpful
in enforcing a singular point of initialization, thus eliminating
any potential side effects from possible prior or future overrides.
This flag overrides the mutable flag if True. For users wanting this
to be the default behavior, the ALF_SOLE_CONFIG env variable can be
set to 1.
override (bool): If True, the value of the config will be set regardless
of any pre-existing ``mutable`` or ``sole_init`` settings. This should
be used only when absolutely necessary (e.g., a teacher-student training
loop, where the student must override certain configs inherited from the
teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override
is True, the config's ``mutable`` or ``sole_init`` values are not changed.
sole_init (bool): If True, the config value can no longer be set after
this config call.once. Any future calls will raise a ValueError.
This is helpful in enforcing a singular point of initialization,
thus eliminating any potential side effects from possible future
overrides.
override_sole_init (bool): If True, the value of the config will be set
regardless of any previous ``sole_init`` setting. This should be used
only when absolutely necessary (e.g., a teacher-student training loop,
where the student must override certain configs inherited from the
teacher). If override is True and the config is immutable, a warning
will be declared with no changes made.
**kwargs: only used if ``prefix_or_dict`` is a str.
"""
if isinstance(prefix_or_dict, str):
Expand All @@ -144,7 +142,8 @@ def func(self, a, b):
sole_init = sole_init or GET_ALF_SOLE_CONFIG()

for key, value in configs.items():
config1(key, value, mutable, raise_if_used, sole_init, override)
config1(key, value, mutable, raise_if_used, sole_init,
override_sole_init)


def override_config(prefix_or_dict, **kwargs):
Expand All @@ -164,7 +163,7 @@ def override_config(prefix_or_dict, **kwargs):
value for config with name ``prefix + '.' + key``
**kwargs: only used if ``prefix_or_dict`` is a str.
"""
config(prefix_or_dict, override=True, **kwargs)
config(prefix_or_dict, override_sole_init=True, **kwargs)


def get_all_config_names():
Expand Down Expand Up @@ -359,7 +358,7 @@ def config1(config_name,
mutable=True,
raise_if_used=True,
sole_init=False,
override=False):
override_sole_init=False):
"""Set one configurable value.
Args:
Expand All @@ -372,17 +371,17 @@ def config1(config_name,
immutable value to an existing immutable value.
raise_if_used (bool): If True, ValueError will be raised if trying to
config a value which has already been used.
sole_init (bool): If True, the config value can only be set once. Any
previous or future calls will raise a ValueError. This is helpful
in enforcing a singular point of initialization, thus eliminating
any potential side effects from possible prior or future overrides.
This flag overrides the mutable flag if True.
override (bool): If True, the value of the config will be set regardless
of any pre-existing ``mutable`` or ``sole_init`` settings. This should
be used only when absolutely necessary (e.g., a teacher-student training
loop, where the student must override certain configs inherited from the
teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override
is True, the config's ``mutable`` or ``sole_init`` values are not changed.
sole_init (bool): If True, the config value can no longer be set after
this config call.once. Any future calls will raise a ValueError.
This is helpful in enforcing a singular point of initialization,
thus eliminating any potential side effects from possible future
overrides.
override_sole_init (bool): If True, the value of the config will be set
regardless of any previous ``sole_init`` setting. This should be used
only when absolutely necessary (e.g., a teacher-student training loop,
where the student must override certain configs inherited from the
teacher). If override is True and the config is immutable, a warning
will be declared with no changes made.
"""
config_node = _get_config_node(config_name)

Expand All @@ -391,50 +390,43 @@ def config1(config_name,
"Config '%s' has already been used. You should config "
"its value before using it." % config_name)

if override:
if override_sole_init:
if config_node.is_configured():
if not config_node.is_mutable():
logging.warning(
"The value of config '%s' (%s) is immutable. "
"Override flag with new value %s is ignored. " %
(config_name, config_node.get_value(), value))
return
elif config_node.get_sole_init():
logging.warning(
"The value of config '%s' (%s) is protected by sole_init. "
"It is now being overriden by the override flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
elif config_node.is_configured():
if config_node.get_sole_init():
logging.warning(
"The value of config '%s' (%s) is protected by sole_init. "
"It is now being overriden by the overide_all flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
if not config_node.is_mutable():
logging.warning(
"The value of config '%s' (%s) is immutable. "
"It is now being overriden by the overide_all flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
config_node.set_value(value)
return

if config_node.is_configured():
if config_node.get_sole_init():
raise ValueError(
raise RuntimeError(
"Config '%s' is protected by sole_init and cannot be reconfigured. "
"If you wish to set this config value, do so the location of the "
"previous call." % config_name)
if sole_init:
raise ValueError(
"Config '%s' has already been configured. If you wish to protect "
"this config with sole_init, the previous alf.config call must be "
"removed." % config_name)

if config_node.is_mutable():
logging.warning(
"The value of config '%s' has been configured to %s. It is "
"replaced by the new value %s" %
(config_name, config_node.get_value(), value))
config_node.set_value(value)
config_node.set_mutable(mutable)
else:
logging.warning(
"The config '%s' has been configured to an immutable value "
"of %s. The new value %s will be ignored" %
(config_name, config_node.get_value(), value))
else:
config_node.set_value(value)
config_node.set_mutable(mutable)
config_node.set_sole_init(sole_init)
config_node.set_sole_init(sole_init)
return

config_node.set_value(value)
config_node.set_mutable(mutable)
config_node.set_sole_init(sole_init)


@logging.skip_log_prefix
Expand All @@ -454,7 +446,7 @@ def pre_config(configs):
"""
for name, value in configs.items():
try:
config1(name, value, mutable=False)
config1(name, value, mutable=False, sole_init=False)
_HANDLED_PRE_CONFIGS.append((name, value))
except ValueError:
_PRE_CONFIGS.append((name, value))
Expand Down
59 changes: 53 additions & 6 deletions alf/config_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,31 +133,31 @@ def test_config1(self):
pprint.pformat(inoperative_configs))
self.assertTrue('A.B.C.D.test.arg' in dict(inoperative_configs))

def test_sole_config(self):
# Test sole_init protection against future config calls
@alf.configurable
def sole_init_test_prior(x):
pass

alf.config("sole_init_test_prior", x=0, sole_init=True)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_prior", x=0)

# Test sole_init protection against previous config calls
# Test sole_init against previous config calls. Should not trigger an error.
@alf.configurable
def sole_init_test_after(x):
pass

alf.config("sole_init_test_after", x=0)
with self.assertRaises(ValueError) as context:
alf.config("sole_init_test_after", x=0, sole_init=True)
alf.config("sole_init_test_after", x=0, sole_init=True)

# Test sole_init protection against other sole_init config calls
@alf.configurable
def sole_init_test_twice(x):
pass

alf.config("sole_init_test_twice", x=0, sole_init=True)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_twice", x=0, sole_init=True)

# Test sole_init protection works as expected when the ALF_SOLE_CONFIG
Expand All @@ -168,7 +168,7 @@ def sole_init_test_env(x):

os.environ["ALF_SOLE_CONFIG"] = "1"
alf.config("sole_init_test_env", x=0)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_env", x=0)
os.environ["ALF_SOLE_CONFIG"] = "0"

Expand All @@ -182,6 +182,53 @@ def sole_init_test_env(x):
self.assertEqual(alf.get_config_value("sole_init_test_twice.x"), 1)
self.assertEqual(alf.get_config_value("sole_init_test_env.x"), 1)

# Test override_config doesn't doesn't overwrite for immutable values.
@alf.configurable
def override_on_immutable(x):
pass

alf.config("override_on_immutable", x=0, mutable=False)
alf.override_config("override_on_immutable", x=1)
self.assertEqual(alf.get_config_value("override_on_immutable.x"), 0)

# Test override_config doesn't doesn't overwrite for immutable values.
@alf.configurable
def override_on_immutable_and_sole_init(x):
pass

alf.config(
"override_on_immutable_and_sole_init",
x=0,
sole_init=True,
mutable=False)
alf.override_config("override_on_immutable_and_sole_init", x=1)
self.assertEqual(
alf.get_config_value("override_on_immutable_and_sole_init.x"), 0)

# Test that pre_config calls work before a sole_init config call
@alf.configurable
def pre_config_before(x):
pass

alf.pre_config({"pre_config_before.x": 0})
alf.config("pre_config_before", x=1)
alf.override_config("pre_config_before", x=1)
# sole_init starts to take effect for all calls AFTER the first call.
alf.config("pre_config_before", x=1, sole_init=True)
with self.assertRaises(RuntimeError) as context:
alf.config("pre_config_before", x=2)
# If truly immutable, the value should never have been changed
self.assertEqual(alf.get_config_value("pre_config_before.x"), 0)

# Test that pre_config calls after a sole_init config call raises an error
@alf.configurable
def pre_config_after(x):
pass

alf.config("pre_config_after", x=1, sole_init=True)
with self.assertRaises(RuntimeError) as context:
alf.pre_config({"pre_config_after.x": 0})

def test_repr_wrapper(self):
a = MyClass(1, 2)
self.assertEqual(repr(a), "MyClass(1, 2)")
Expand Down

0 comments on commit 192a34c

Please sign in to comment.