diff --git a/alf/config_helpers.py b/alf/config_helpers.py index f31867b37..ce71f88a9 100644 --- a/alf/config_helpers.py +++ b/alf/config_helpers.py @@ -154,7 +154,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int, tag, math.ceil(num_parallel_environments / multi_process_divider), raise_if_used=False, - override_sole_init=True) + override_all=True) # Adjust the mini_batch_size. If the original configured value is 64 and # there are 4 processes, it should mean that "jointly the 4 processes have @@ -167,7 +167,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int, tag, math.ceil(mini_batch_size / multi_process_divider), raise_if_used=False, - override_sole_init=True) + override_all=True) # If the termination condition is num_env_steps instead of num_iterations, # we need to adjust it as well since each process only sees env steps taking @@ -179,7 +179,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int, tag, math.ceil(num_env_steps / multi_process_divider), raise_if_used=False, - override_sole_init=True) + override_all=True) tag = 'TrainerConfig.initial_collect_steps' init_collect_steps = get_config_value(tag) @@ -187,7 +187,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int, tag, math.ceil(init_collect_steps / multi_process_divider), raise_if_used=False, - override_sole_init=True) + override_all=True) # Only allow process with rank 0 to have evaluate. Enabling evaluation for # other parallel processes is a waste as such evaluation does not offer more @@ -197,7 +197,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int, 'TrainerConfig.evaluate', False, raise_if_used=False, - override_sole_init=True) + override_all=True) def parse_config(conf_file, conf_params, create_env=True): diff --git a/alf/config_util.py b/alf/config_util.py index f43c29d1b..ae7885c20 100644 --- a/alf/config_util.py +++ b/alf/config_util.py @@ -366,7 +366,8 @@ def config1(config_name, mutable=True, raise_if_used=True, sole_init=False, - override_sole_init=False): + override_sole_init=False, + override_all=False): """Set one configurable value. Args: @@ -391,7 +392,10 @@ def config1(config_name, where the student must override certain configs inherited from the teacher). If the config is immutable, a warning will be declared with no changes made. - """ + override_all (bool): If True, the value of the config will be set regardless + of any pre-existing ``mutable`` or ``sole_init`` settings. This should + be used only when absolutely necessary (e.g., adjusting certain configs + such as mini_batch_size for DDP workers.).""" config_node = _get_config_node(config_name) if raise_if_used and config_node.is_used(): @@ -399,7 +403,22 @@ def config1(config_name, "Config '%s' has already been used. You should config " "its value before using it." % config_name) - if override_sole_init: + if override_all: + if config_node.get_sole_init(): + logging.warning( + "The value of config '%s' (%s) is protected by sole_init. " + "It is now being overridden by the overide_all flag to a new value %s. " + "Use at your own risk." % (config_name, + config_node.get_value(), value)) + if not config_node.is_mutable(): + logging.warning( + "The value of config '%s' (%s) is immutable. " + "It is now being overridden by the overide_all flag to a new value %s. " + "Use at your own risk." % (config_name, + config_node.get_value(), value)) + config_node.set_value(value) + return + elif override_sole_init: if config_node.is_configured(): if not config_node.is_mutable(): logging.warning(