From 462a0815f9c6e9d067f08767a2e7d9e6c48623dd Mon Sep 17 00:00:00 2001
From: Wei Xu
Date: Fri, 8 Nov 2024 12:29:45 -0800
Subject: [PATCH] Use override_all=True for config1 in
 adjust_config_by_multi_process_divider()

Sometimes the configurations adjusted by
adjust_config_by_multi_process_divider() are provided through the command
line. In general, configurations provided through the command line will not
be changed by config1(). However, for DDP training we still need to change
them even when they come from the command line.
---
 alf/config_helpers.py | 10 +++++-----
 alf/config_util.py    | 25 ++++++++++++++++++++++---
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/alf/config_helpers.py b/alf/config_helpers.py
index f31867b37..ce71f88a9 100644
--- a/alf/config_helpers.py
+++ b/alf/config_helpers.py
@@ -154,7 +154,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
         tag,
         math.ceil(num_parallel_environments / multi_process_divider),
         raise_if_used=False,
-        override_sole_init=True)
+        override_all=True)
 
     # Adjust the mini_batch_size. If the original configured value is 64 and
     # there are 4 processes, it should mean that "jointly the 4 processes have
@@ -167,7 +167,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
         tag,
         math.ceil(mini_batch_size / multi_process_divider),
         raise_if_used=False,
-        override_sole_init=True)
+        override_all=True)
 
     # If the termination condition is num_env_steps instead of num_iterations,
     # we need to adjust it as well since each process only sees env steps taking
@@ -179,7 +179,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
         tag,
         math.ceil(num_env_steps / multi_process_divider),
         raise_if_used=False,
-        override_sole_init=True)
+        override_all=True)
 
     tag = 'TrainerConfig.initial_collect_steps'
     init_collect_steps = get_config_value(tag)
@@ -187,7 +187,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
         tag,
         math.ceil(init_collect_steps / multi_process_divider),
         raise_if_used=False,
-        override_sole_init=True)
+        override_all=True)
 
     # Only allow process with rank 0 to have evaluate. Enabling evaluation for
     # other parallel processes is a waste as such evaluation does not offer more
@@ -197,7 +197,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
             'TrainerConfig.evaluate',
             False,
             raise_if_used=False,
-            override_sole_init=True)
+            override_all=True)
 
 
 def parse_config(conf_file, conf_params, create_env=True):
diff --git a/alf/config_util.py b/alf/config_util.py
index f43c29d1b..ae7885c20 100644
--- a/alf/config_util.py
+++ b/alf/config_util.py
@@ -366,7 +366,8 @@ def config1(config_name,
             mutable=True,
             raise_if_used=True,
             sole_init=False,
-            override_sole_init=False):
+            override_sole_init=False,
+            override_all=False):
     """Set one configurable value.
 
     Args:
@@ -391,7 +392,10 @@ def config1(config_name,
             where the student must override certain configs inherited from
             the teacher). If the config is immutable, a warning will be
             declared with no changes made.
-    """
+        override_all (bool): If True, the value of the config will be set regardless
+            of any pre-existing ``mutable`` or ``sole_init`` settings. This should
+            be used only when absolutely necessary (e.g., adjusting certain configs
+            such as mini_batch_size for DDP workers)."""
     config_node = _get_config_node(config_name)
 
     if raise_if_used and config_node.is_used():
@@ -398,7 +402,22 @@ def config1(config_name,
         raise ValueError(
             "Config '%s' has already been used. You should config "
             "its value before using it." % config_name)
-    if override_sole_init:
+    if override_all:
+        if config_node.get_sole_init():
+            logging.warning(
+                "The value of config '%s' (%s) is protected by sole_init. "
+                "It is now being overridden by the override_all flag to a new value %s. "
+                "Use at your own risk." % (config_name,
+                                           config_node.get_value(), value))
+        if not config_node.is_mutable():
+            logging.warning(
+                "The value of config '%s' (%s) is immutable. "
+                "It is now being overridden by the override_all flag to a new value %s. "
+                "Use at your own risk." % (config_name,
+                                           config_node.get_value(), value))
+        config_node.set_value(value)
+        return
+    elif override_sole_init:
         if config_node.is_configured():
             if not config_node.is_mutable():
                 logging.warning(
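
Note (not part of the patch): a minimal sketch of how the new flag is meant to
be used. The config1/get_config_value calls and the ceil-division mirror
adjust_config_by_multi_process_divider() above; the command-line value (64),
the worker count (4), and the TrainerConfig import path are assumptions made
only for illustration.

    import math

    # Assumed import path; importing TrainerConfig registers the
    # 'TrainerConfig.*' config names used below.
    from alf.trainers.policy_trainer import TrainerConfig  # noqa: F401
    from alf.config_util import config1, get_config_value

    # The user pins the job-wide batch size on the command line; sole_init
    # normally protects it from being changed again later.
    config1('TrainerConfig.mini_batch_size', 64, sole_init=True)

    # Each of the 4 DDP workers should only process its share of the batch,
    # so the divider logic must override the protected value.  With
    # override_all=True, config1 logs a warning and sets it anyway.
    multi_process_divider = 4
    mini_batch_size = get_config_value('TrainerConfig.mini_batch_size')
    config1(
        'TrainerConfig.mini_batch_size',
        math.ceil(mini_batch_size / multi_process_divider),
        raise_if_used=False,
        override_all=True)

    assert get_config_value('TrainerConfig.mini_batch_size') == 16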