Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change override to override_sole_init in alf.config #1705

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion alf/config_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ def get_env():
# A 'None' random seed won't set a deterministic torch behavior.
for _ in range(PerProcessContext().ddp_rank):
random_seed = random.randint(0, 2**32)
config1("TrainerConfig.random_seed", random_seed, raise_if_used=False)
config1(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are several other config1 calls in config_helpers.py that may need fix too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch. This has been fixed for the other calls.

"TrainerConfig.random_seed",
random_seed,
raise_if_used=False,
override_sole_init=True)

# We have to call set_random_seed() here because we need the actual
# random seed to call create_environment.
Expand Down
122 changes: 58 additions & 64 deletions alf/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def config(prefix_or_dict,
mutable=True,
raise_if_used=True,
sole_init=False,
override=False,
override_sole_init=False,
le-horizon marked this conversation as resolved.
Show resolved Hide resolved
**kwargs):
"""Set the values for the configs with given name as suffix.

Expand Down Expand Up @@ -112,19 +112,18 @@ def func(self, a, b):
immutable value to an existing immutable value.
raise_if_used (bool): If True, ValueError will be raised if trying to
config a value which has already been used.
sole_init (bool): If True, the config value can only be set once. Any
previous or future calls will raise a ValueError. This is helpful
in enforcing a singular point of initialization, thus eliminating
any potential side effects from possible prior or future overrides.
This flag overrides the mutable flag if True. For users wanting this
to be the default behavior, the ALF_SOLE_CONFIG env variable can be
set to 1.
override (bool): If True, the value of the config will be set regardless
of any pre-existing ``mutable`` or ``sole_init`` settings. This should
be used only when absolutely necessary (e.g., a teacher-student training
loop, where the student must override certain configs inherited from the
teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override
is True, the config's ``mutable`` or ``sole_init`` values are not changed.
sole_init (bool): If True, the config value can no longer be set again
after this config call. Any future calls will raise a RuntimeError.
This is helpful in enforcing a singular point of initialization,
thus eliminating any potential side effects from possible future
overrides. For users wanting this to be the default behavior, the
ALF_SOLE_CONFIG env variable can be set to 1.
override_sole_init (bool): If True, the value of the config will be set
regardless of any previous ``sole_init`` setting. This should be used
only when absolutely necessary (e.g., a teacher-student training loop,
where the student must override certain configs inherited from the
teacher). If the config is immutable, a warning will be declared with
no changes made.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of quietly ignoring the config, should we just throw error for this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the reason we cannot do this is because alf.pre_config and --conf_param set values with mutable=False (that is how those values avoid getting overwritten). Therefore, there may be times we want to override our conf files using one of the two options and raising an error in this situation is problematic. A workaround for this was to possibly introduce a pre_configged state variable for configs, but I thought that may be unnecessary. Let me know what you think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. You are right. Let's go with current way for the PR.

Separately, are there some examples of these cases? It would be good to understand a bit more, and see if we can design better solutions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to examples of the teacher-student setup?

One is our own agent training. Let's say you train expert with ALF_SOLE_CONFIG=1. When we train the agent with ALF_SOLE_CONFIG=1 as well, it will inevitably raise errors as agent often overwrites some of the expert's configs. Here, we can use the alf.override_config function. Not only does this minimize config calls, but it also explicitly tells the coder which inherited calls are being overwritten.

**kwargs: only used if ``prefix_or_dict`` is a str.
"""
if isinstance(prefix_or_dict, str):
Expand All @@ -144,18 +143,18 @@ def func(self, a, b):
sole_init = sole_init or GET_ALF_SOLE_CONFIG()

for key, value in configs.items():
config1(key, value, mutable, raise_if_used, sole_init, override)
config1(key, value, mutable, raise_if_used, sole_init,
override_sole_init)


def override_config(prefix_or_dict, **kwargs):
"""Wrapper function for configuring a config with override=True.
"""Wrapper function for configuring a config with override_sole_init=True.

This call will change a config's value, ignoring any previous protections
placed upon a config by the ``mutable`` and ``sole_init`` flags.
Therefore, it is highly recommended that this be used only when absolutely
necessary (e.g., a teacher-student training loop, where the student must
override certain configs inherited from the teacher). Otherwise, it is best
to use alf.config with the ``mutable`` and ``sole_init`` flags instead.
This call allows a user to attempt to overwrite a config's value even
if it is protected by ``sole_init``. It is highly recommended that this be
used only when absolutely necessary (e.g., a teacher-student training loop,
where the student must override certain configs inherited from the teacher).
This call has no effect for configs who are immutable.

Args:
prefix_or_dict (str|dict): if a dict, each (key, value) pair in it
Expand All @@ -164,7 +163,7 @@ def override_config(prefix_or_dict, **kwargs):
value for config with name ``prefix + '.' + key``
**kwargs: only used if ``prefix_or_dict`` is a str.
"""
config(prefix_or_dict, override=True, **kwargs)
config(prefix_or_dict, override_sole_init=True, **kwargs)


def get_all_config_names():
Expand Down Expand Up @@ -367,7 +366,7 @@ def config1(config_name,
mutable=True,
raise_if_used=True,
sole_init=False,
override=False):
override_sole_init=False):
"""Set one configurable value.

Args:
Expand All @@ -380,17 +379,18 @@ def config1(config_name,
immutable value to an existing immutable value.
raise_if_used (bool): If True, ValueError will be raised if trying to
config a value which has already been used.
sole_init (bool): If True, the config value can only be set once. Any
previous or future calls will raise a ValueError. This is helpful
in enforcing a singular point of initialization, thus eliminating
any potential side effects from possible prior or future overrides.
This flag overrides the mutable flag if True.
override (bool): If True, the value of the config will be set regardless
of any pre-existing ``mutable`` or ``sole_init`` settings. This should
be used only when absolutely necessary (e.g., a teacher-student training
loop, where the student must override certain configs inherited from the
teacher). Otherwise, use ``mutable`` or ``sole_init`` instead. If override
is True, the config's ``mutable`` or ``sole_init`` values are not changed.
sole_init (bool): If True, the config value can no longer be set again
after this config call. Any future calls will raise a RuntimeError.
This is helpful in enforcing a singular point of initialization,
thus eliminating any potential side effects from possible future
overrides. For users wanting this to be the default behavior, the
ALF_SOLE_CONFIG env variable can be set to 1.
override_sole_init (bool): If True, the value of the config will be set
regardless of any previous ``sole_init`` setting. This should be used
only when absolutely necessary (e.g., a teacher-student training loop,
where the student must override certain configs inherited from the
teacher). If the config is immutable, a warning will be declared with
no changes made.
"""
config_node = _get_config_node(config_name)

Expand All @@ -399,49 +399,43 @@ def config1(config_name,
"Config '%s' has already been used. You should config "
"its value before using it." % config_name)

if override:
if override_sole_init:
if config_node.is_configured():
if not config_node.is_mutable():
logging.warning(
"The value of config '%s' (%s) is immutable. "
"Override flag with new value %s is ignored. " %
(config_name, config_node.get_value(), value))
return
elif config_node.get_sole_init():
logging.warning(
"The value of config '%s' (%s) is protected by sole_init. "
"It is now being overridden by the override flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
elif config_node.is_configured():
if config_node.get_sole_init():
logging.warning(
"The value of config '%s' (%s) is protected by sole_init. "
"It is now being overridden by the overide_all flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
if not config_node.is_mutable():
logging.warning(
"The value of config '%s' (%s) is immutable. "
"It is now being overridden by the overide_all flag to a new value %s. "
"Use at your own risk." % (config_name,
config_node.get_value(), value))
config_node.set_value(value)
return

if config_node.is_configured():
if config_node.get_sole_init():
raise ValueError(
raise RuntimeError(
"Config '%s' is protected by sole_init and cannot be reconfigured. "
"If you wish to set this config value, do so the location of the "
"previous call." % config_name)
if sole_init:
raise ValueError(
"Config '%s' has already been configured. If you wish to protect "
"this config with sole_init, the previous alf.config call must be "
"removed." % config_name)

if config_node.is_mutable():
logging.warning(
"The value of config '%s' has been configured to %s. It is "
"replaced by the new value %s" %
(config_name, config_node.get_value(), value))
config_node.set_value(value)
config_node.set_mutable(mutable)
else:
logging.warning(
"The config '%s' has been configured to an immutable value "
"of %s. The new value %s will be ignored" %
(config_name, config_node.get_value(), value))
else:
config_node.set_value(value)
config_node.set_mutable(mutable)
config_node.set_sole_init(sole_init)
return

config_node.set_value(value)
config_node.set_mutable(mutable)
if not override_sole_init:
config_node.set_sole_init(sole_init)


Expand All @@ -462,7 +456,7 @@ def pre_config(configs):
"""
for name, value in configs.items():
try:
config1(name, value, mutable=False)
config1(name, value, mutable=False, sole_init=False)
_HANDLED_PRE_CONFIGS.append((name, value))
except ValueError:
_PRE_CONFIGS.append((name, value))
Expand Down
69 changes: 63 additions & 6 deletions alf/config_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,31 +133,31 @@ def test_config1(self):
pprint.pformat(inoperative_configs))
self.assertTrue('A.B.C.D.test.arg' in dict(inoperative_configs))

def test_sole_config(self):
# Test sole_init protection against future config calls
@alf.configurable
def sole_init_test_prior(x):
pass

alf.config("sole_init_test_prior", x=0, sole_init=True)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_prior", x=0)

# Test sole_init protection against previous config calls
# Test sole_init against previous config calls. Should not trigger an error.
@alf.configurable
def sole_init_test_after(x):
pass

alf.config("sole_init_test_after", x=0)
with self.assertRaises(ValueError) as context:
alf.config("sole_init_test_after", x=0, sole_init=True)
alf.config("sole_init_test_after", x=0, sole_init=True)

# Test sole_init protection against other sole_init config calls
@alf.configurable
def sole_init_test_twice(x):
pass

alf.config("sole_init_test_twice", x=0, sole_init=True)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_twice", x=0, sole_init=True)

# Test sole_init protection works as expected when the ALF_SOLE_CONFIG
Expand All @@ -168,7 +168,7 @@ def sole_init_test_env(x):

os.environ["ALF_SOLE_CONFIG"] = "1"
alf.config("sole_init_test_env", x=0)
with self.assertRaises(ValueError) as context:
with self.assertRaises(RuntimeError) as context:
alf.config("sole_init_test_env", x=0)
os.environ["ALF_SOLE_CONFIG"] = "0"

Expand All @@ -182,6 +182,63 @@ def sole_init_test_env(x):
self.assertEqual(alf.get_config_value("sole_init_test_twice.x"), 1)
self.assertEqual(alf.get_config_value("sole_init_test_env.x"), 1)

# Test override_config doesn't doesn't overwrite for immutable values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove one "doesn't"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Done.

@alf.configurable
def override_on_immutable(x):
pass

alf.config("override_on_immutable", x=0, mutable=False)
alf.override_config("override_on_immutable", x=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to raise an error on overriding immutable config? Silently ignoring override can be hard to debug?

Copy link
Contributor Author

@QuantuMope QuantuMope Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refer to my reply above for why this cannot be done currently without adding more state parameters. There is a logger warning that gets printed in this scenario, though it's easy for users to not see it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe let's color code the warning message a bright yellow, so it's harder to miss?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now it prints as red via absl.logging.warning. Do you think yellow is better?

self.assertEqual(alf.get_config_value("override_on_immutable.x"), 0)

# Test override_config doesn't doesn't overwrite for immutable values.
@alf.configurable
def override_on_immutable_and_sole_init(x):
pass

alf.config(
"override_on_immutable_and_sole_init",
x=0,
sole_init=True,
mutable=False)
alf.override_config("override_on_immutable_and_sole_init", x=1)
self.assertEqual(
alf.get_config_value("override_on_immutable_and_sole_init.x"), 0)

# Test that pre_config calls work before a sole_init config call
@alf.configurable
def pre_config_before(x):
pass

alf.pre_config({"pre_config_before.x": 0})
alf.config("pre_config_before", x=1)
alf.override_config("pre_config_before", x=1)
le-horizon marked this conversation as resolved.
Show resolved Hide resolved
# sole_init starts to take effect for all calls AFTER the first call.
alf.config("pre_config_before", x=1, sole_init=True)
with self.assertRaises(RuntimeError) as context:
alf.config("pre_config_before", x=2)
# If truly immutable, the value should never have been changed
self.assertEqual(alf.get_config_value("pre_config_before.x"), 0)

# Test that pre_config calls after a sole_init config call raises an error
@alf.configurable
def pre_config_after(x):
pass

alf.config("pre_config_after", x=1, sole_init=True)
with self.assertRaises(RuntimeError) as context:
alf.pre_config({"pre_config_after.x": 0})

# Test that calling override_config doesn't affect previous sole_init calls.
@alf.configurable
def override_no_affect_sole_init(x):
pass

alf.config("override_no_affect_sole_init", x=1, sole_init=True)
alf.override_config("override_no_affect_sole_init", x=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this also throw error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't throw an error since override_config allows us to set a config's value despite an earlier sole_init=True call.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Maybe rename the function to alf.override_sole_config()? Just as how you renamed override to override_sole_init.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Done.

with self.assertRaises(RuntimeError) as context:
alf.config("override_no_affect_sole_init", x=3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test that x's value is 2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.


def test_repr_wrapper(self):
a = MyClass(1, 2)
self.assertEqual(repr(a), "MyClass(1, 2)")
Expand Down
Loading