diff --git a/iris/algorithms/ars_algorithm.py b/iris/algorithms/ars_algorithm.py index b45fc22..d1da025 100644 --- a/iris/algorithms/ars_algorithm.py +++ b/iris/algorithms/ars_algorithm.py @@ -14,59 +14,33 @@ """Algorithm class for Augmented Random Search Blackbox algorithm.""" -import collections import math -import pathlib -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence -from absl import logging -from flax import linen as nn -from iris import checkpoint_util from iris import normalizer from iris import worker_util from iris.algorithms import algorithm from iris.algorithms import stateless_perturbation_generators -import jax -import jax.numpy as jnp import numpy as np -PRNGKey = jax.Array - _DUMMY_REWARD = -1_000_000_000.0 -class MLP(nn.Module): - """Defines an MLP model for learned hyper-params.""" - - hidden_sizes: Sequence[int] = (32, 16) - output_size: int = 2 - - @nn.compact - def __call__(self, x: jnp.ndarray, state: Any): - for feat in self.hidden_sizes: - x = nn.Dense(feat)(x) - x = nn.tanh(x) - x = nn.Dense(self.output_size)(x) - return nn.sigmoid(x), state - - def initialize_carry(self, rng: PRNGKey, params: jnp.ndarray) -> Any: - del rng, params - return None - - class AugmentedRandomSearch(algorithm.BlackboxAlgorithm): """Augmented random search algorithm for blackbox optimization.""" - def __init__(self, - std: float | Callable[[int], float], - step_size: float | Callable[[int], float], - top_percentage: float = 1.0, - orthogonal_suggestions: bool = False, - quasirandom_suggestions: bool = False, - top_sort_type: str = "max", - obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, - **kwargs) -> None: + def __init__( + self, + std: float | Callable[[int], float], + step_size: float | Callable[[int], float], + top_percentage: float = 1.0, + orthogonal_suggestions: bool = False, + quasirandom_suggestions: bool = False, + top_sort_type: str = "max", + obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, + **kwargs, + ) -> None: """Initializes the augmented random search algorithm. Args: @@ -107,7 +81,8 @@ def initialize(self, state: Dict[str, Any]) -> None: self._obs_norm_data_buffer.data = state["obs_norm_buffer_data"] def process_evaluations( - self, eval_results: Sequence[worker_util.EvaluationResult]) -> None: + self, eval_results: Sequence[worker_util.EvaluationResult] + ) -> None: """Processes the list of Blackbox function evaluations return from workers. Gradient is computed by taking a weighted sum of directions and @@ -124,13 +99,12 @@ def process_evaluations( """ # Retrieve delta direction from the param suggestion sent for evaluation. 
- pos_eval_results = eval_results[:self._num_suggestions] - neg_eval_results = eval_results[self._num_suggestions:] + pos_eval_results = eval_results[: self._num_suggestions] + neg_eval_results = eval_results[self._num_suggestions :] filtered_pos_eval_results = [] filtered_neg_eval_results = [] - for (peval, neval) in zip(pos_eval_results, neg_eval_results): - if (peval.params_evaluated.size) and ( - neval.params_evaluated.size): + for peval, neval in zip(pos_eval_results, neg_eval_results): + if (peval.params_evaluated.size) and (neval.params_evaluated.size): filtered_pos_eval_results.append(peval) filtered_neg_eval_results.append(neval) params = np.array([r.params_evaluated for r in filtered_pos_eval_results]) @@ -145,7 +119,7 @@ def process_evaluations( max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0) else: max_evals = np.abs(pos_evals - neg_evals) - idx = (-max_evals).argsort()[:self._num_top] + idx = (-max_evals).argsort()[: self._num_top] pos_evals = pos_evals[idx] neg_evals = neg_evals[idx] all_top_evals = np.hstack([pos_evals, neg_evals]) @@ -170,8 +144,9 @@ def process_evaluations( for r in eval_results: self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data) - def get_param_suggestions(self, - evaluate: bool = False) -> Sequence[Dict[str, Any]]: + def get_param_suggestions( + self, evaluate: bool = False + ) -> Sequence[Dict[str, Any]]: """Suggests a list of inputs to evaluate the Blackbox function on. Suggestions are sampled from a gaussian distribution around the current @@ -204,20 +179,22 @@ def get_param_suggestions(self, ortho_pert_blocks = [] for _ in range(math.ceil(float(self._num_suggestions / dimensions))): perturbations = self._np_random_state.normal( - 0, 1, (self._num_suggestions, dimensions)) + 0, 1, (self._num_suggestions, dimensions) + ) ortho_matrix, _ = np.linalg.qr(perturbations.T) ortho_pert_blocks.append(np.sqrt(dimensions) * ortho_matrix.T) param_suggestions = np.vstack(ortho_pert_blocks) - param_suggestions = param_suggestions[:self._num_suggestions, :] + param_suggestions = param_suggestions[: self._num_suggestions, :] else: param_suggestions = self._np_random_state.normal( - 0, 1, (self._num_suggestions, dimensions)) + 0, 1, (self._num_suggestions, dimensions) + ) self._last_std_used = self._std if callable(self._std): self._last_std_used = self._std(self._iteration) param_suggestions = np.vstack([ self._opt_params + self._last_std_used * param_suggestions, - self._opt_params - self._last_std_used * param_suggestions + self._opt_params - self._last_std_used * param_suggestions, ]) suggestions = [] @@ -250,386 +227,3 @@ def _set_state(self, new_state: Dict[str, Any]) -> None: def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: self.state = new_state - - -class LearnableAugmentedRandomSearch(AugmentedRandomSearch): - """Learnable augmented random search algorithm for blackbox optimization.""" - - def __init__( - self, - model: Callable[[], nn.Module] = MLP, - model_path: Optional[str] = None, - top_percentage: float = 1.0, - orthogonal_suggestions: bool = False, - quasirandom_suggestions: bool = False, - top_sort_type: str = "max", - obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, - seed: int = 42, - reward_buffer_size: int = 10, - **kwargs, - ) -> None: - """Initializes the learnable augmented random search algorithm. - - Args: - model: The model class to use when loading the meta-policy. - model_path: The checkpoint path to load the meta-policy from. 
- top_percentage: Fraction of top performing perturbations to use for - gradient estimation. - orthogonal_suggestions: Whether to orthogonalize the perturbations. - quasirandom_suggestions: Whether quasirandom perturbations should be used; - valid only if orthogonal_suggestions = True. - top_sort_type: How to sort evaluation results for selecting top - directions. Valid options are: "max" and "diff". - obs_norm_data_buffer: Buffer to sync statistics from all workers for - online mean std observation normalizer. - seed: The seed to use. - reward_buffer_size: the size of the reward buffer that stores a history of - rewards. - **kwargs: Other keyword arguments for base class. - """ - super().__init__(**kwargs) - super().__init__(**kwargs) - self._iteration = 0 - self._seed = seed - self._model_path = model_path - self._model = model() - self._last_std_used = 1.0 - self._num_top = int(top_percentage * self._num_suggestions) - self._num_top = max(1, self._num_top) - self._orthogonal_suggestions = orthogonal_suggestions - self._quasirandom_suggestions = quasirandom_suggestions - self._top_sort_type = top_sort_type - self._obs_norm_data_buffer = obs_norm_data_buffer - self._tree_weights = None - self._reward_buffer_size = reward_buffer_size - self._reward_buffer = collections.deque(maxlen=self._reward_buffer_size) - self._populate_reward_buffer() - self._step_size = 0.02 - self._std = 1.0 - - def _populate_reward_buffer(self): - """Populate reward buffer with very negative values.""" - self._reward_buffer.extend([_DUMMY_REWARD] * self._reward_buffer_size) - - def _restore_state_from_checkpoint(self, logdir: str): - try: - state = checkpoint_util.load_checkpoint_state(logdir) - iteration = 0 # No iteration information is extracted - return state, iteration - except ValueError: - logging.warning( - "Failed to load directly as a checkpoint, try searching subfolders" - " with checkpoints." - ) - return None, 0 - - def get_param_suggestions( - self, evaluate: bool = False - ) -> Sequence[Dict[str, Any]]: - """Suggests a list of inputs to evaluate the Blackbox function on. - - Suggestions are sampled from a gaussian distribution around the current - parameter vector. For each suggestion, a dict containing keyword arguments - for the worker is sent. - - Args: - evaluate: Whether to evaluate current optimization variables for reporting - training progress. - - Returns: - A list of suggested inputs for the workers to evaluate. - """ - if evaluate: - param_suggestions = [self._opt_params] * self._num_evals - else: - dimensions = self._opt_params.shape[0] - if self._orthogonal_suggestions: - if self._quasirandom_suggestions: - param_suggestions = ( - stateless_perturbation_generators.RandomHadamardMatrixGenerator( - self._num_suggestions, dimensions - ).generate_matrix() - ) - else: - # We generate random iid perturbations and orthogonalize them. In the - # case when the number of suggestions to be generated is greater than - # param dimensionality, we generate multiple orthogonal perturbation - # blocks. Rows are othogonal within a block but not across blocks. 
- ortho_pert_blocks = [] - for _ in range(math.ceil(float(self._num_suggestions / dimensions))): - perturbations = self._np_random_state.normal( - 0, 1, (self._num_suggestions, dimensions) - ) - ortho_matrix, _ = np.linalg.qr(perturbations.T) - ortho_pert_blocks.append(np.sqrt(dimensions) * ortho_matrix.T) - param_suggestions = np.vstack(ortho_pert_blocks) - param_suggestions = param_suggestions[: self._num_suggestions, :] - else: - param_suggestions = self._np_random_state.normal( - 0, 1, (self._num_suggestions, dimensions) - ) - self._last_std_used = self._std - param_suggestions = np.vstack([ - self._opt_params, - self._opt_params + self._last_std_used * param_suggestions, - self._opt_params - self._last_std_used * param_suggestions, - ]) - - suggestions = [] - for params in param_suggestions: - suggestion = {"params_to_eval": params} - if self._obs_norm_data_buffer is not None: - suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state - suggestion["update_obs_norm_buffer"] = not evaluate - suggestions.append(suggestion) - return suggestions - - def process_evaluations( - self, eval_results: Sequence[worker_util.EvaluationResult] - ) -> None: - - self._reward_buffer.append(eval_results[0].value) - rewards = np.asarray(self._reward_buffer) - model_input = np.concatenate([[self._iteration], rewards]) - - if self._tree_weights is None: - self._state = self._restore_state_from_checkpoint(self._model_path) - self._tree_weights = self._model.init( - jax.random.PRNGKey(seed=self._seed), model_input, self._state - ) - - hyper_params, self._state = self._model.apply( - self._tree_weights, model_input, self._state - ) - step_size, std = hyper_params - self._step_size = step_size - self._std = std - super().process_evaluations(eval_results) - - -class MultiAgentAugmentedRandomSearch(AugmentedRandomSearch): - """Augmented random search algorithm for blackbox optimization.""" - - def __init__(self, - std: float, - step_size: float, - top_percentage: float = 1.0, - orthogonal_suggestions: bool = False, - quasirandom_suggestions: bool = False, - top_sort_type: str = "max", - obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, - agent_keys: Optional[List[str]] = None, - restore_state_from_single_agent: bool = False, - **kwargs) -> None: - """Initializes the augmented random search algorithm for multi-agent training. - - Args: - std: Standard deviation for normal perturbations around current - optimization parameter vector. - step_size: Step size for gradient ascent. - top_percentage: Fraction of top performing perturbations to use for - gradient estimation. - orthogonal_suggestions: Whether to orthogonalize the perturbations. - quasirandom_suggestions: Whether quasirandom perturbations should be used; - valid only if orthogonal_suggestions = True. - top_sort_type: How to sort evaluation results for selecting top - directions. Valid options are: "max" and "diff". - obs_norm_data_buffer: Buffer to sync statistics from all workers for - online mean std observation normalizer. - agent_keys: List of keys which uniquely identify the agents. The ordering - needs to be consistent across the algorithm, policy, and worker. - restore_state_from_single_agent: if True then when - restore_state_from_checkpoint is called the state is duplicated - self._num_agents times. - **kwargs: Other keyword arguments for base class. 
- """ - super().__init__(std=std, - step_size=step_size, - top_percentage=top_percentage, - orthogonal_suggestions=orthogonal_suggestions, - quasirandom_suggestions=quasirandom_suggestions, - top_sort_type=top_sort_type, - obs_norm_data_buffer=obs_norm_data_buffer, - **kwargs) - if agent_keys is None: - self._agent_keys = ["arm", "opp"] - else: - self._agent_keys = agent_keys - self._num_agents = len(self._agent_keys) - self._restore_state_from_single_agent = restore_state_from_single_agent - - def _split_params(self, params: np.ndarray) -> List[np.ndarray]: - return np.array_split(params, self._num_agents) - - def _combine_params(self, params_per_agents: List[np.ndarray]) -> np.ndarray: - return np.concatenate(params_per_agents, axis=0) - - def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: - logging.info("Restore: restore from 1 agent: %d", - self._restore_state_from_single_agent) - logging.info("Restore: num_agents: %d", self._num_agents) - logging.info("Restore: new state keys: %s", list(new_state.keys())) - logging.info("Restore: new_state params shape: %s", - new_state["params_to_eval"].shape) - - # Initialize multiple agents from a single agent. - if self._restore_state_from_single_agent: - if new_state["params_to_eval"].ndim != 1: - raise ValueError( - f"Params to eval has {new_state['params_to_eval'].ndim} dims, " - "should only have 1." - ) - duplicated_state = { - "params_to_eval": - np.tile(new_state["params_to_eval"], self._num_agents) - } - if self._obs_norm_data_buffer is not None: - duplicated_state["obs_norm_state"] = {} - duplicated_state["obs_norm_state"]["mean"] = np.tile( - new_state["obs_norm_state"]["mean"], self._num_agents) - duplicated_state["obs_norm_state"]["std"] = np.tile( - new_state["obs_norm_state"]["std"], self._num_agents) - duplicated_state["obs_norm_state"]["n"] = ( - new_state["obs_norm_state"]["n"]) - - self.state = duplicated_state - logging.info("Restore: duplicated states params shape: %s", - duplicated_state["params_to_eval"].shape) - - # Initialize one agent from a single agent. 
- else: - self.state = new_state - - logging.info("Restored state: params shape: %s, opt params shape: %s, " - "obs norm state: %s", - self.state["params_to_eval"].shape, - self._opt_params.shape, - self.state.get("obs_norm_state", None)) - if self._obs_norm_data_buffer is not None: - logging.info("Restored state: obs norm mean shape: %s, std shape: %s", - self.state["obs_norm_state"]["mean"].shape, - self.state["obs_norm_state"]["std"].shape) - - def maybe_save_custom_checkpoint(self, - state: Dict[str, Any], - checkpoint_path: Union[pathlib.Path, str] - ) -> None: - """Saves a checkpoint per agent with prefix checkpoint_path.""" - agent_params = self._split_params(state["params_to_eval"]) - for i in range(self._num_agents): - per_agent_state = {} - per_agent_state["params_to_eval"] = agent_params[i] - if self._obs_norm_data_buffer is not None: - obs_norm_state = state["obs_norm_state"] - elems_per_agent = int( - obs_norm_state["mean"].shape[-1] / self._num_agents) - per_agent_state["obs_norm_state"] = {} - start_idx = i * elems_per_agent - end_idx = (i + 1) * elems_per_agent - if obs_norm_state["mean"].ndim == 1: - per_agent_state["obs_norm_state"]["mean"] = ( - obs_norm_state["mean"][start_idx: end_idx]) - per_agent_state["obs_norm_state"]["std"] = ( - obs_norm_state["std"][start_idx: end_idx]) - else: - per_agent_state["obs_norm_state"]["mean"] = ( - obs_norm_state["mean"][:, start_idx: end_idx]) - per_agent_state["obs_norm_state"]["std"] = ( - obs_norm_state["std"][:, start_idx: end_idx]) - per_agent_state["obs_norm_state"]["n"] = obs_norm_state["n"] - agent_checkpoint_path = f"{checkpoint_path}_agent_{i}" - logging.info("Saving agent checkpoints to %s...", agent_checkpoint_path) - checkpoint_util.save_checkpoint(agent_checkpoint_path, per_agent_state) - - def split_and_save_checkpoint(self, checkpoint_path: str) -> None: - state = checkpoint_util.load_checkpoint_state(checkpoint_path) - self.maybe_save_custom_checkpoint(state=state, - checkpoint_path=checkpoint_path) - - def _get_top_evaluation_results( - self, - agent_key: str, - pos_eval_results: Sequence[worker_util.EvaluationResult], - neg_eval_results: Sequence[worker_util.EvaluationResult] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - pos_evals = np.array( - [r.metrics[f"reward_{agent_key}"] for r in pos_eval_results]) - neg_evals = np.array( - [r.metrics[f"reward_{agent_key}"] for r in neg_eval_results]) - if self._top_sort_type == "max": - max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0) - elif self._top_sort_type == "diff": - max_evals = np.abs(pos_evals - neg_evals) - idx = (-max_evals).argsort()[:self._num_top] - pos_evals = pos_evals[idx] - neg_evals = neg_evals[idx] - return pos_evals, neg_evals, idx - - def process_evaluations( - self, eval_results: Sequence[worker_util.EvaluationResult]) -> None: - """Processes the list of Blackbox function evaluations return from workers. - - Gradient is computed by taking a weighted sum of directions and - difference of their value from the current value. The current parameter - vector is then updated in the gradient direction with specified step size. - - Args: - eval_results: List containing Blackbox function evaluations based on the - order in which the suggestions were sent. ARS performs antithetic - gradient estimation. The suggestions are sent for evaluation in pairs. 
- The eval_results list should contain an even number of entries with the - first half entries corresponding to evaluation result of positive - perturbations and the last half corresponding to negative perturbations. - """ - - # Retrieve delta direction from the param suggestion sent for evaluation. - pos_eval_results = eval_results[:self._num_suggestions] - neg_eval_results = eval_results[self._num_suggestions:] - filtered_pos_eval_results = [] - filtered_neg_eval_results = [] - for i in range(len(pos_eval_results)): - if (pos_eval_results[i].params_evaluated.size) and ( - neg_eval_results[i].params_evaluated.size): - filtered_pos_eval_results.append(pos_eval_results[i]) - filtered_neg_eval_results.append(neg_eval_results[i]) - - params = np.array([r.params_evaluated for r in filtered_pos_eval_results]) - eval_results = filtered_pos_eval_results + filtered_neg_eval_results - - # This is length num pos results with splits per agent - eval_params_per_agent = [self._split_params(p) for p in params] - eval_params_per_agent = list(zip(*eval_params_per_agent)) - # This has length num agents with a 2d array with shape - # (num_pos_results, agent_params_dim). - eval_params_per_agent = [np.array(a) for a in eval_params_per_agent] - - current_params_per_agent = self._split_params(self._opt_params) - updated_params_per_agent = [] - for (agent_eval_params, agent_params, agent_key) in zip( - eval_params_per_agent, current_params_per_agent, self._agent_keys): - pos_evals, neg_evals, idx = self._get_top_evaluation_results( - agent_key=agent_key, - pos_eval_results=filtered_pos_eval_results, - neg_eval_results=filtered_neg_eval_results) - all_top_evals = np.hstack([pos_evals, neg_evals]) - evals = pos_evals - neg_evals - - # Get delta directions corresponding to top evals - directions = (agent_eval_params - agent_params) / self._std - directions = directions[idx, :] - - # Estimate gradients - gradient = np.dot(evals, directions) / evals.shape[0] - if not np.isclose(np.std(all_top_evals), 0.0): - gradient /= np.std(all_top_evals) - - # Apply gradients - updated_agent_params = agent_params + self._step_size * gradient - updated_params_per_agent.append(updated_agent_params) - - self._opt_params = self._combine_params(updated_params_per_agent) - - # Update the observation buffer - if self._obs_norm_data_buffer is not None: - for r in eval_results: - self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data) diff --git a/iris/algorithms/learnable_ars_algorithm.py b/iris/algorithms/learnable_ars_algorithm.py new file mode 100644 index 0000000..98e9378 --- /dev/null +++ b/iris/algorithms/learnable_ars_algorithm.py @@ -0,0 +1,208 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Algorithm class for Learnable ARS.""" + +import collections +import math +from typing import Any, Callable, Dict, Optional, Sequence + +from absl import logging +from flax import linen as nn +from iris import checkpoint_util +from iris import normalizer +from iris import worker_util +from iris.algorithms import ars_algorithm +from iris.algorithms import stateless_perturbation_generators +import jax +import jax.numpy as jnp +import numpy as np + + +_DUMMY_REWARD = -1_000_000_000.0 + + +class MLP(nn.Module): + """Defines an MLP model for learned hyper-params.""" + + hidden_sizes: Sequence[int] = (32, 16) + output_size: int = 2 + + @nn.compact + def __call__(self, x: jnp.ndarray, state: Any): + for feat in self.hidden_sizes: + x = nn.Dense(feat)(x) + x = nn.tanh(x) + x = nn.Dense(self.output_size)(x) + return nn.sigmoid(x), state + + def initialize_carry(self, rng: jax.Array, params: jnp.ndarray) -> Any: + del rng, params + return None + + +class LearnableAugmentedRandomSearch(ars_algorithm.AugmentedRandomSearch): + """Learnable augmented random search algorithm for blackbox optimization.""" + + def __init__( + self, + model: Callable[[], nn.Module] = MLP, + model_path: Optional[str] = None, + top_percentage: float = 1.0, + orthogonal_suggestions: bool = False, + quasirandom_suggestions: bool = False, + top_sort_type: str = "max", + obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, + seed: int = 42, + reward_buffer_size: int = 10, + **kwargs, + ) -> None: + """Initializes the learnable augmented random search algorithm. + + Args: + model: The model class to use when loading the meta-policy. + model_path: The checkpoint path to load the meta-policy from. + top_percentage: Fraction of top performing perturbations to use for + gradient estimation. + orthogonal_suggestions: Whether to orthogonalize the perturbations. + quasirandom_suggestions: Whether quasirandom perturbations should be used; + valid only if orthogonal_suggestions = True. + top_sort_type: How to sort evaluation results for selecting top + directions. Valid options are: "max" and "diff". + obs_norm_data_buffer: Buffer to sync statistics from all workers for + online mean std observation normalizer. + seed: The seed to use. + reward_buffer_size: the size of the reward buffer that stores a history of + rewards. + **kwargs: Other keyword arguments for base class. 
+ """ + super().__init__(**kwargs) + super().__init__(**kwargs) + self._iteration = 0 + self._seed = seed + self._model_path = model_path + self._model = model() + self._last_std_used = 1.0 + self._num_top = int(top_percentage * self._num_suggestions) + self._num_top = max(1, self._num_top) + self._orthogonal_suggestions = orthogonal_suggestions + self._quasirandom_suggestions = quasirandom_suggestions + self._top_sort_type = top_sort_type + self._obs_norm_data_buffer = obs_norm_data_buffer + self._tree_weights = None + self._model_state = None + self._reward_buffer_size = reward_buffer_size + self._reward_buffer = collections.deque(maxlen=self._reward_buffer_size) + self._populate_reward_buffer() + self._step_size = 0.02 + self._std = 1.0 + + def _populate_reward_buffer(self): + """Populate reward buffer with very negative values.""" + self._reward_buffer.extend([_DUMMY_REWARD] * self._reward_buffer_size) + + def _restore_state_from_checkpoint(self, logdir: str): + try: + state = checkpoint_util.load_checkpoint_state(logdir) + iteration = 0 # No iteration information is extracted + return state, iteration + except ValueError: + logging.warning( + "Failed to load directly as a checkpoint, try searching subfolders" + " with checkpoints." + ) + return None, 0 + + def get_param_suggestions( + self, evaluate: bool = False + ) -> Sequence[Dict[str, Any]]: + """Suggests a list of inputs to evaluate the Blackbox function on. + + Suggestions are sampled from a gaussian distribution around the current + parameter vector. For each suggestion, a dict containing keyword arguments + for the worker is sent. + + Args: + evaluate: Whether to evaluate current optimization variables for reporting + training progress. + + Returns: + A list of suggested inputs for the workers to evaluate. + """ + if evaluate: + param_suggestions = [self._opt_params] * self._num_evals + else: + dimensions = self._opt_params.shape[0] + if self._orthogonal_suggestions: + if self._quasirandom_suggestions: + param_suggestions = ( + stateless_perturbation_generators.RandomHadamardMatrixGenerator( + self._num_suggestions, dimensions + ).generate_matrix() + ) + else: + # We generate random iid perturbations and orthogonalize them. In the + # case when the number of suggestions to be generated is greater than + # param dimensionality, we generate multiple orthogonal perturbation + # blocks. Rows are othogonal within a block but not across blocks. 
+          ortho_pert_blocks = []
+          for _ in range(math.ceil(float(self._num_suggestions / dimensions))):
+            perturbations = self._np_random_state.normal(
+                0, 1, (self._num_suggestions, dimensions)
+            )
+            ortho_matrix, _ = np.linalg.qr(perturbations.T)
+            ortho_pert_blocks.append(np.sqrt(dimensions) * ortho_matrix.T)
+          param_suggestions = np.vstack(ortho_pert_blocks)
+          param_suggestions = param_suggestions[: self._num_suggestions, :]
+      else:
+        param_suggestions = self._np_random_state.normal(
+            0, 1, (self._num_suggestions, dimensions)
+        )
+      self._last_std_used = self._std
+      param_suggestions = np.vstack([
+          self._opt_params,
+          self._opt_params + self._last_std_used * param_suggestions,
+          self._opt_params - self._last_std_used * param_suggestions,
+      ])
+
+    suggestions = []
+    for params in param_suggestions:
+      suggestion = {"params_to_eval": params}
+      if self._obs_norm_data_buffer is not None:
+        suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state
+        suggestion["update_obs_norm_buffer"] = not evaluate
+      suggestions.append(suggestion)
+    return suggestions
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]
+  ) -> None:
+    """Updates meta-learned step size and std, then applies the ARS update."""
+    self._reward_buffer.append(eval_results[0].value)
+    rewards = np.asarray(self._reward_buffer)
+    model_input = np.concatenate([[self._iteration], rewards])
+
+    if self._tree_weights is None:
+      self._model_state, _ = self._restore_state_from_checkpoint(
+          self._model_path
+      )
+      self._tree_weights = self._model.init(
+          jax.random.PRNGKey(seed=self._seed), model_input, self._model_state
+      )
+
+    hyper_params, self._model_state = self._model.apply(
+        self._tree_weights, model_input, self._model_state
+    )
+    step_size, std = hyper_params
+    self._step_size = step_size
+    self._std = std
+    super().process_evaluations(eval_results)
diff --git a/iris/algorithms/multi_agent_ars_algorithm.py b/iris/algorithms/multi_agent_ars_algorithm.py
new file mode 100644
index 0000000..b160096
--- /dev/null
+++ b/iris/algorithms/multi_agent_ars_algorithm.py
@@ -0,0 +1,283 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multiagent ARS algorithm."""
+
+import pathlib
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+from absl import logging
+from iris import checkpoint_util
+from iris import normalizer
+from iris import worker_util
+from iris.algorithms import ars_algorithm
+import numpy as np
+
+
+class MultiAgentAugmentedRandomSearch(ars_algorithm.AugmentedRandomSearch):
+  """Multi-agent augmented random search algorithm for blackbox optimization."""
+
+  def __init__(
+      self,
+      std: float,
+      step_size: float,
+      top_percentage: float = 1.0,
+      orthogonal_suggestions: bool = False,
+      quasirandom_suggestions: bool = False,
+      top_sort_type: str = "max",
+      obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None,
+      agent_keys: Optional[List[str]] = None,
+      restore_state_from_single_agent: bool = False,
+      **kwargs,
+  ) -> None:
+    """Initializes the augmented random search algorithm for multi-agent training.
+ + Args: + std: Standard deviation for normal perturbations around current + optimization parameter vector. + step_size: Step size for gradient ascent. + top_percentage: Fraction of top performing perturbations to use for + gradient estimation. + orthogonal_suggestions: Whether to orthogonalize the perturbations. + quasirandom_suggestions: Whether quasirandom perturbations should be used; + valid only if orthogonal_suggestions = True. + top_sort_type: How to sort evaluation results for selecting top + directions. Valid options are: "max" and "diff". + obs_norm_data_buffer: Buffer to sync statistics from all workers for + online mean std observation normalizer. + agent_keys: List of keys which uniquely identify the agents. The ordering + needs to be consistent across the algorithm, policy, and worker. + restore_state_from_single_agent: if True then when + restore_state_from_checkpoint is called the state is duplicated + self._num_agents times. + **kwargs: Other keyword arguments for base class. + """ + super().__init__( + std=std, + step_size=step_size, + top_percentage=top_percentage, + orthogonal_suggestions=orthogonal_suggestions, + quasirandom_suggestions=quasirandom_suggestions, + top_sort_type=top_sort_type, + obs_norm_data_buffer=obs_norm_data_buffer, + **kwargs, + ) + if agent_keys is None: + self._agent_keys = ["arm", "opp"] + else: + self._agent_keys = agent_keys + self._num_agents = len(self._agent_keys) + self._restore_state_from_single_agent = restore_state_from_single_agent + + def _split_params(self, params: np.ndarray) -> List[np.ndarray]: + return np.array_split(params, self._num_agents) + + def _combine_params(self, params_per_agents: List[np.ndarray]) -> np.ndarray: + return np.concatenate(params_per_agents, axis=0) + + def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: + logging.info( + "Restore: restore from 1 agent: %d", + self._restore_state_from_single_agent, + ) + logging.info("Restore: num_agents: %d", self._num_agents) + logging.info("Restore: new state keys: %s", list(new_state.keys())) + logging.info( + "Restore: new_state params shape: %s", new_state["params_to_eval"].shape + ) + + # Initialize multiple agents from a single agent. + if self._restore_state_from_single_agent: + if new_state["params_to_eval"].ndim != 1: + raise ValueError( + f"Params to eval has {new_state['params_to_eval'].ndim} dims, " + "should only have 1." + ) + duplicated_state = { + "params_to_eval": np.tile( + new_state["params_to_eval"], self._num_agents + ) + } + if self._obs_norm_data_buffer is not None: + duplicated_state["obs_norm_state"] = {} + duplicated_state["obs_norm_state"]["mean"] = np.tile( + new_state["obs_norm_state"]["mean"], self._num_agents + ) + duplicated_state["obs_norm_state"]["std"] = np.tile( + new_state["obs_norm_state"]["std"], self._num_agents + ) + duplicated_state["obs_norm_state"]["n"] = new_state["obs_norm_state"][ + "n" + ] + + self.state = duplicated_state + logging.info( + "Restore: duplicated states params shape: %s", + duplicated_state["params_to_eval"].shape, + ) + + # Initialize one agent from a single agent. 
+ else: + self.state = new_state + + logging.info( + "Restored state: params shape: %s, opt params shape: %s, " + "obs norm state: %s", + self.state["params_to_eval"].shape, + self._opt_params.shape, + self.state.get("obs_norm_state", None), + ) + if self._obs_norm_data_buffer is not None: + logging.info( + "Restored state: obs norm mean shape: %s, std shape: %s", + self.state["obs_norm_state"]["mean"].shape, + self.state["obs_norm_state"]["std"].shape, + ) + + def maybe_save_custom_checkpoint( + self, state: Dict[str, Any], checkpoint_path: Union[pathlib.Path, str] + ) -> None: + """Saves a checkpoint per agent with prefix checkpoint_path.""" + agent_params = self._split_params(state["params_to_eval"]) + for i in range(self._num_agents): + per_agent_state = {} + per_agent_state["params_to_eval"] = agent_params[i] + if self._obs_norm_data_buffer is not None: + obs_norm_state = state["obs_norm_state"] + elems_per_agent = int( + obs_norm_state["mean"].shape[-1] / self._num_agents + ) + per_agent_state["obs_norm_state"] = {} + start_idx = i * elems_per_agent + end_idx = (i + 1) * elems_per_agent + if obs_norm_state["mean"].ndim == 1: + per_agent_state["obs_norm_state"]["mean"] = obs_norm_state["mean"][ + start_idx:end_idx + ] + per_agent_state["obs_norm_state"]["std"] = obs_norm_state["std"][ + start_idx:end_idx + ] + else: + per_agent_state["obs_norm_state"]["mean"] = obs_norm_state["mean"][ + :, start_idx:end_idx + ] + per_agent_state["obs_norm_state"]["std"] = obs_norm_state["std"][ + :, start_idx:end_idx + ] + per_agent_state["obs_norm_state"]["n"] = obs_norm_state["n"] + agent_checkpoint_path = f"{checkpoint_path}_agent_{i}" + logging.info("Saving agent checkpoints to %s...", agent_checkpoint_path) + checkpoint_util.save_checkpoint(agent_checkpoint_path, per_agent_state) + + def split_and_save_checkpoint(self, checkpoint_path: str) -> None: + state = checkpoint_util.load_checkpoint_state(checkpoint_path) + self.maybe_save_custom_checkpoint( + state=state, checkpoint_path=checkpoint_path + ) + + def _get_top_evaluation_results( + self, + agent_key: str, + pos_eval_results: Sequence[worker_util.EvaluationResult], + neg_eval_results: Sequence[worker_util.EvaluationResult], + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + pos_evals = np.array( + [r.metrics[f"reward_{agent_key}"] for r in pos_eval_results] + ) + neg_evals = np.array( + [r.metrics[f"reward_{agent_key}"] for r in neg_eval_results] + ) + if self._top_sort_type == "max": + max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0) + elif self._top_sort_type == "diff": + max_evals = np.abs(pos_evals - neg_evals) + else: + raise ValueError(f"Unknown top sort type: {self._top_sort_type}") + idx = (-max_evals).argsort()[: self._num_top] + pos_evals = pos_evals[idx] + neg_evals = neg_evals[idx] + return pos_evals, neg_evals, idx + + def process_evaluations( + self, eval_results: Sequence[worker_util.EvaluationResult] + ) -> None: + """Processes the list of Blackbox function evaluations return from workers. + + Gradient is computed by taking a weighted sum of directions and + difference of their value from the current value. The current parameter + vector is then updated in the gradient direction with specified step size. + + Args: + eval_results: List containing Blackbox function evaluations based on the + order in which the suggestions were sent. ARS performs antithetic + gradient estimation. The suggestions are sent for evaluation in pairs. 
+ The eval_results list should contain an even number of entries with the + first half entries corresponding to evaluation result of positive + perturbations and the last half corresponding to negative perturbations. + """ + + # Retrieve delta direction from the param suggestion sent for evaluation. + pos_eval_results = eval_results[: self._num_suggestions] + neg_eval_results = eval_results[self._num_suggestions :] + filtered_pos_eval_results = [] + filtered_neg_eval_results = [] + for i in range(len(pos_eval_results)): + if (pos_eval_results[i].params_evaluated.size) and ( + neg_eval_results[i].params_evaluated.size + ): + filtered_pos_eval_results.append(pos_eval_results[i]) + filtered_neg_eval_results.append(neg_eval_results[i]) + + params = np.array([r.params_evaluated for r in filtered_pos_eval_results]) + eval_results = filtered_pos_eval_results + filtered_neg_eval_results + + # This is length num pos results with splits per agent + eval_params_per_agent = [self._split_params(p) for p in params] + eval_params_per_agent = list(zip(*eval_params_per_agent)) + # This has length num agents with a 2d array with shape + # (num_pos_results, agent_params_dim). + eval_params_per_agent = [np.array(a) for a in eval_params_per_agent] + + current_params_per_agent = self._split_params(self._opt_params) + updated_params_per_agent = [] + for agent_eval_params, agent_params, agent_key in zip( + eval_params_per_agent, current_params_per_agent, self._agent_keys + ): + pos_evals, neg_evals, idx = self._get_top_evaluation_results( + agent_key=agent_key, + pos_eval_results=filtered_pos_eval_results, + neg_eval_results=filtered_neg_eval_results, + ) + all_top_evals = np.hstack([pos_evals, neg_evals]) + evals = pos_evals - neg_evals + + # Get delta directions corresponding to top evals + directions = (agent_eval_params - agent_params) / self._std + directions = directions[idx, :] + + # Estimate gradients + gradient = np.dot(evals, directions) / evals.shape[0] + if not np.isclose(np.std(all_top_evals), 0.0): + gradient /= np.std(all_top_evals) + + # Apply gradients + updated_agent_params = agent_params + self._step_size * gradient + updated_params_per_agent.append(updated_agent_params) + + self._opt_params = self._combine_params(updated_params_per_agent) + + # Update the observation buffer + if self._obs_norm_data_buffer is not None: + for r in eval_results: + self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data) diff --git a/iris/algorithms/multi_agent_ars_algorithm_test.py b/iris/algorithms/multi_agent_ars_algorithm_test.py index 634ff0e..dfc78f4 100644 --- a/iris/algorithms/multi_agent_ars_algorithm_test.py +++ b/iris/algorithms/multi_agent_ars_algorithm_test.py @@ -16,7 +16,7 @@ from iris import checkpoint_util from iris import normalizer from iris import worker_util -from iris.algorithms import ars_algorithm +from iris.algorithms import multi_agent_ars_algorithm import numpy as np import tensorflow as tf from absl.testing import absltest @@ -26,16 +26,17 @@ class AlgorithmTest(tf.test.TestCase, parameterized.TestCase): def _init_algo(self, agent_keys=None): - return ars_algorithm.MultiAgentAugmentedRandomSearch( + return multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch( num_suggestions=4, step_size=0.5, - std=1., + std=1.0, top_percentage=1, orthogonal_suggestions=True, quasirandom_suggestions=False, top_sort_type='diff', random_seed=7, - agent_keys=agent_keys) + agent_keys=agent_keys, + ) @parameterized.parameters( (None, ['arm', 'opp'], 2), @@ -49,82 +50,74 @@ def test_init(self, 
agent_keys, expected_agent_keys, expected_num_agents): def _build_evaluation_results(self) -> list[worker_util.EvaluationResult]: eval_results = [ worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([10., 11., 12., 13.]), + params_evaluated=np.array([10.0, 11.0, 12.0, 13.0]), value=10, - metrics={ - 'reward_arm': 10, - 'reward_opp': -5 - }), + metrics={'reward_arm': 10, 'reward_opp': -5}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([10., 11., 14., 15.]), + params_evaluated=np.array([10.0, 11.0, 14.0, 15.0]), value=10, - metrics={ - 'reward_arm': 10, - 'reward_opp': -10 - }), + metrics={'reward_arm': 10, 'reward_opp': -10}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars params_evaluated=np.empty(0), value=0, - metrics={ - 'reward_arm': 0, - 'reward_opp': 0 - }), + metrics={'reward_arm': 0, 'reward_opp': 0}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([1., 2., 3., 4.]), + params_evaluated=np.array([1.0, 2.0, 3.0, 4.0]), value=10, - metrics={ - 'reward_arm': 10, - 'reward_opp': -10 - }), + metrics={'reward_arm': 10, 'reward_opp': -10}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([10., 11., 12., 13.]), + params_evaluated=np.array([10.0, 11.0, 12.0, 13.0]), value=-10, - metrics={ - 'reward_arm': -10, - 'reward_opp': 5 - }), + metrics={'reward_arm': -10, 'reward_opp': 5}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([10., 11., 14., 15.]), + params_evaluated=np.array([10.0, 11.0, 14.0, 15.0]), value=-10, - metrics={ - 'reward_arm': -10, - 'reward_opp': 10 - }), + metrics={'reward_arm': -10, 'reward_opp': 10}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars - params_evaluated=np.array([5., 6., 7., 8.]), + params_evaluated=np.array([5.0, 6.0, 7.0, 8.0]), value=-10, - metrics={ - 'reward_arm': -10, - 'reward_opp': 10 - }), + metrics={'reward_arm': -10, 'reward_opp': 10}, + ), worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars params_evaluated=np.empty(0), value=0, - metrics={ - 'reward_arm': 0, - 'reward_opp': 0 - }), + metrics={'reward_arm': 0, 'reward_opp': 0}, + ), ] return eval_results @parameterized.parameters( - (['arm', 'opp'], [[10., 11.], [12., 13.]], 2), - (['1', '2', '3', '4'], [[10.], [11.], [12.], [13.]], 4), + (['arm', 'opp'], [[10.0, 11.0], [12.0, 13.0]], 2), + (['1', '2', '3', '4'], [[10.0], [11.0], [12.0], [13.0]], 4), ) def test_split_params(self, agent_keys, expected_split_params, num_agents): algo = self._init_algo(agent_keys=agent_keys) - params = np.array([10., 11., 12., 13.]) + params = np.array([10.0, 11.0, 12.0, 13.0]) split_params = algo._split_params(params) self.assertLen(split_params, num_agents) for p, exp_p in zip(split_params, expected_split_params): np.testing.assert_array_equal(p, np.asarray(exp_p)) @parameterized.parameters( - ([np.asarray([10., 11.]), np.asarray([12., 13.])], - np.asarray([10., 11., 12., 13.])), - ([np.asarray([10.]), np.asarray([11.]), - np.asarray([12.]), np.asarray([13.])], - np.asarray([10., 11., 12., 13.])), + ( + [np.asarray([10.0, 11.0]), np.asarray([12.0, 13.0])], + np.asarray([10.0, 11.0, 12.0, 13.0]), + ), + ( + [ + np.asarray([10.0]), + np.asarray([11.0]), + np.asarray([12.0]), + np.asarray([13.0]), + ], + 
np.asarray([10.0, 11.0, 12.0, 13.0]), + ), ) def test_combine_params(self, split_params, expected_combined_params): algo = self._init_algo() @@ -135,11 +128,9 @@ def test_combine_params(self, split_params, expected_combined_params): ('arm', np.asarray([10, 10]), np.asarray([-10, -10]), np.asarray([0, 1])), ('opp', np.asarray([-10, -5]), np.asarray([10, 5]), np.asarray([1, 0])), ) - def test_get_top_evaluation_results(self, - agent_key, - expected_pos_evals, - expected_neg_evals, - expected_idx): + def test_get_top_evaluation_results( + self, agent_key, expected_pos_evals, expected_neg_evals, expected_idx + ): algo = self._init_algo() eval_results = self._build_evaluation_results() filtered_pos_eval_results = eval_results[:2] @@ -147,22 +138,23 @@ def test_get_top_evaluation_results(self, pos_evals, neg_evals, idx = algo._get_top_evaluation_results( agent_key=agent_key, pos_eval_results=filtered_pos_eval_results, - neg_eval_results=filtered_neg_eval_results) + neg_eval_results=filtered_neg_eval_results, + ) np.testing.assert_array_equal(pos_evals, expected_pos_evals) np.testing.assert_array_equal(neg_evals, expected_neg_evals) np.testing.assert_array_equal(idx, expected_idx) def test_multi_agent_ars_gradient(self): algo = self._init_algo() - init_state = {'init_params': np.array([10., 10., 10., 10.])} + init_state = {'init_params': np.array([10.0, 10.0, 10.0, 10.0])} algo.initialize(init_state) suggestions = algo.get_param_suggestions() self.assertLen(suggestions, 8) eval_results = self._build_evaluation_results() algo.process_evaluations(eval_results) np.testing.assert_array_almost_equal( - algo._opt_params, - np.array([10., 11., 6.83772234, 5.88903904])) + algo._opt_params, np.array([10.0, 11.0, 6.83772234, 5.88903904]) + ) @parameterized.parameters( ({'params_to_eval': np.asarray([1, 2]), @@ -212,12 +204,10 @@ def test_multi_agent_ars_gradient(self): 'n': 5}}, True, 1), ) - def test_restore_state_from_checkpoint(self, - state, - expected_state, - restore_state_from_single_agent, - num_agents): - algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + def test_restore_state_from_checkpoint( + self, state, expected_state, restore_state_from_single_agent, num_agents + ): + algo = multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch( num_suggestions=3, step_size=0.5, std=1.0, @@ -225,40 +215,44 @@ def test_restore_state_from_checkpoint(self, orthogonal_suggestions=True, quasirandom_suggestions=True, obs_norm_data_buffer=normalizer.MeanStdBuffer() - if state['obs_norm_state'] is not None else None, + if state['obs_norm_state'] is not None + else None, agent_keys=[str(i) for i in range(num_agents)], restore_state_from_single_agent=restore_state_from_single_agent, random_seed=7, ) self.assertEqual(algo._num_agents, num_agents) - init_state = {'init_params': np.array([10., 10.])} + init_state = {'init_params': np.array([10.0, 10.0])} if state['obs_norm_state'] is not None: - init_state['obs_norm_buffer_data'] = {'mean': np.asarray([0., 0.]), - 'std': np.asarray([1., 1.]), - 'n': 0} + init_state['obs_norm_buffer_data'] = { + 'mean': np.asarray([0.0, 0.0]), + 'std': np.asarray([1.0, 1.0]), + 'n': 0, + } algo.initialize(init_state) with self.subTest('init-mean'): - self.assertAllClose(np.array(algo._opt_params), - init_state['init_params']) + self.assertAllClose(np.array(algo._opt_params), init_state['init_params']) if state['obs_norm_state'] is not None: with self.subTest('init-obs-mean'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['mean']), - 
np.asarray(init_state['obs_norm_buffer_data']['mean'])) + np.asarray(init_state['obs_norm_buffer_data']['mean']), + ) with self.subTest('init-obs-n'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['n']), - np.asarray(init_state['obs_norm_buffer_data']['n'])) + np.asarray(init_state['obs_norm_buffer_data']['n']), + ) with self.subTest('init-obs-std'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['std']), - init_state['obs_norm_buffer_data']['std']) + init_state['obs_norm_buffer_data']['std'], + ) algo.restore_state_from_checkpoint(state) - self.assertAllClose(algo._opt_params, - expected_state['params_to_eval']) + self.assertAllClose(algo._opt_params, expected_state['params_to_eval']) if expected_state['obs_norm_state'] is not None: std = expected_state['obs_norm_state']['std'] var = np.square(std) @@ -266,15 +260,18 @@ def test_restore_state_from_checkpoint(self, with self.subTest('restore-obs-mean'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['mean']), - np.asarray(expected_state['obs_norm_state']['mean'])) + np.asarray(expected_state['obs_norm_state']['mean']), + ) with self.subTest('restore-obs-n'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['n']), - np.asarray(expected_state['obs_norm_state']['n'])) + np.asarray(expected_state['obs_norm_state']['n']), + ) with self.subTest('restore-obs-std'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['unnorm_var']), - expected_unnorm_var) + expected_unnorm_var, + ) @parameterized.parameters( ( @@ -299,50 +296,62 @@ def test_restore_state_from_checkpoint(self, ( { 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), - 'obs_norm_state': {'mean': np.asarray([6., 7., 8., 9.]), - 'std': np.asarray([10., 11., 12., 13.]), - 'n': 5 - }, + 'obs_norm_state': { + 'mean': np.asarray([6.0, 7.0, 8.0, 9.0]), + 'std': np.asarray([10.0, 11.0, 12.0, 13.0]), + 'n': 5, + }, }, [ - {'params_to_eval': np.asarray([1, 2, 3]), - 'obs_norm_state': {'mean': np.asarray([6., 7.]), - 'std': np.asarray([10., 11.]), - 'n': 5 - }, - }, - {'params_to_eval': np.asarray([4, 5, 6]), - 'obs_norm_state': {'mean': np.asarray([8., 9.]), - 'std': np.asarray([12., 13.]), - 'n': 5 - }, - }, + { + 'params_to_eval': np.asarray([1, 2, 3]), + 'obs_norm_state': { + 'mean': np.asarray([6.0, 7.0]), + 'std': np.asarray([10.0, 11.0]), + 'n': 5, + }, + }, + { + 'params_to_eval': np.asarray([4, 5, 6]), + 'obs_norm_state': { + 'mean': np.asarray([8.0, 9.0]), + 'std': np.asarray([12.0, 13.0]), + 'n': 5, + }, + }, ], 2, ), ( { 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), - 'obs_norm_state': {'mean': np.asarray([[6., 7., 8., 9.], - [10., 11., 12., 13.]]), - 'std': np.asarray([[14., 15., 16., 17.], - [18., 19., 20., 21.]]), - 'n': 5 - }, + 'obs_norm_state': { + 'mean': np.asarray( + [[6.0, 7.0, 8.0, 9.0], [10.0, 11.0, 12.0, 13.0]] + ), + 'std': np.asarray( + [[14.0, 15.0, 16.0, 17.0], [18.0, 19.0, 20.0, 21.0]] + ), + 'n': 5, + }, }, [ - {'params_to_eval': np.asarray([1, 2, 3]), - 'obs_norm_state': {'mean': np.asarray([[6., 7.,], [10., 11.,]]), - 'std': np.asarray([[14., 15.,], [18., 19.,]]), - 'n': 5 - }, - }, - {'params_to_eval': np.asarray([4, 5, 6]), - 'obs_norm_state': {'mean': np.asarray([[8., 9.], [12., 13.]]), - 'std': np.asarray([[16., 17.], [20., 21.]]), - 'n': 5 - }, - }, + { + 'params_to_eval': np.asarray([1, 2, 3]), + 'obs_norm_state': { + 'mean': np.asarray([[6.0, 7.0], [10.0, 11.0]]), + 'std': np.asarray([[14.0, 15.0], [18.0, 19.0]]), + 'n': 5, + }, + }, + { + 'params_to_eval': np.asarray([4, 5, 
6]), + 'obs_norm_state': { + 'mean': np.asarray([[8.0, 9.0], [12.0, 13.0]]), + 'std': np.asarray([[16.0, 17.0], [20.0, 21.0]]), + 'n': 5, + }, + }, ], 2, ), @@ -361,46 +370,54 @@ def test_restore_state_from_checkpoint(self, ( { 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), - 'obs_norm_state': {'mean': np.asarray( - [[6., 7., 8., 9., 10., 11.], - [12., 13., 14., 15., 16., 17.]]), - 'std': np.asarray( - [[14., 15., 16., 17., 18., 19.], - [20., 21., 22., 23., 24., 25.]]), - 'n': 5 - }, + 'obs_norm_state': { + 'mean': np.asarray([ + [6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0, 16.0, 17.0], + ]), + 'std': np.asarray([ + [14.0, 15.0, 16.0, 17.0, 18.0, 19.0], + [20.0, 21.0, 22.0, 23.0, 24.0, 25.0], + ]), + 'n': 5, + }, }, [ - {'params_to_eval': np.asarray([1, 2]), - 'obs_norm_state': {'mean': np.asarray([[6., 7.,], [12., 13.,]]), - 'std': np.asarray([[14., 15.,], [20., 21.,]]), - 'n': 5 - }, - }, - {'params_to_eval': np.asarray([3, 4]), - 'obs_norm_state': {'mean': np.asarray([[8., 9.], [14., 15.]]), - 'std': np.asarray([[16., 17.], [22., 23.]]), - 'n': 5 - }, - }, - {'params_to_eval': np.asarray([5, 6]), - 'obs_norm_state': {'mean': np.asarray([[10., 11.], [16., 17.]]), - 'std': np.asarray([[18., 19.], [24., 25.]]), - 'n': 5 - }, - }, + { + 'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': { + 'mean': np.asarray([[6.0, 7.0], [12.0, 13.0]]), + 'std': np.asarray([[14.0, 15.0], [20.0, 21.0]]), + 'n': 5, + }, + }, + { + 'params_to_eval': np.asarray([3, 4]), + 'obs_norm_state': { + 'mean': np.asarray([[8.0, 9.0], [14.0, 15.0]]), + 'std': np.asarray([[16.0, 17.0], [22.0, 23.0]]), + 'n': 5, + }, + }, + { + 'params_to_eval': np.asarray([5, 6]), + 'obs_norm_state': { + 'mean': np.asarray([[10.0, 11.0], [16.0, 17.0]]), + 'std': np.asarray([[18.0, 19.0], [24.0, 25.0]]), + 'n': 5, + }, + }, ], 3, ), ) - def test_maybe_save_custom_checkpoint(self, - state, - expected_states, - num_agents): + def test_maybe_save_custom_checkpoint( + self, state, expected_states, num_agents + ): tempdir = self.create_tempdir() path = 'checkpoint_iteration_0' full_path = os.path.join(tempdir, path) - algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + algo = multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch( num_suggestions=3, step_size=0.5, std=1.0, @@ -408,28 +425,37 @@ def test_maybe_save_custom_checkpoint(self, orthogonal_suggestions=True, quasirandom_suggestions=True, obs_norm_data_buffer=normalizer.MeanStdBuffer() - if state['obs_norm_state'] is not None else None, + if state['obs_norm_state'] is not None + else None, agent_keys=[str(i) for i in range(num_agents)], - random_seed=7) + random_seed=7, + ) algo.maybe_save_custom_checkpoint(state, full_path) for i in range(num_agents): agent_checkpoint_path = f'{full_path}_agent_{i}' agent_state = checkpoint_util.load_checkpoint_state(agent_checkpoint_path) - self.assertAllClose(agent_state['params_to_eval'], - expected_states[i]['params_to_eval']) + self.assertAllClose( + agent_state['params_to_eval'], expected_states[i]['params_to_eval'] + ) if expected_states[i]['obs_norm_state'] is not None: - self.assertAllClose(agent_state['obs_norm_state']['mean'], - expected_states[i]['obs_norm_state']['mean']) - self.assertAllClose(agent_state['obs_norm_state']['std'], - expected_states[i]['obs_norm_state']['std']) - self.assertAllClose(agent_state['obs_norm_state']['n'], - expected_states[i]['obs_norm_state']['n']) + self.assertAllClose( + agent_state['obs_norm_state']['mean'], + expected_states[i]['obs_norm_state']['mean'], + ) + 
self.assertAllClose( + agent_state['obs_norm_state']['std'], + expected_states[i]['obs_norm_state']['std'], + ) + self.assertAllClose( + agent_state['obs_norm_state']['n'], + expected_states[i]['obs_norm_state']['n'], + ) def test_split_checkpoint(self): tempdir = self.create_tempdir() path = 'checkpoint_iteration_0' full_path = os.path.join(tempdir, path) - algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + algo = multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch( num_suggestions=3, step_size=0.5, std=1.0, @@ -438,7 +464,8 @@ def test_split_checkpoint(self): quasirandom_suggestions=True, obs_norm_data_buffer=normalizer.MeanStdBuffer(), agent_keys=[str(i) for i in range(3)], - random_seed=7) + random_seed=7, + ) state = { 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), 'obs_norm_state': { @@ -477,21 +504,29 @@ def test_split_checkpoint(self): 'std': np.asarray([[18.0, 19.0], [24.0, 25.0]]), 'n': 5, }, - }] + }, + ] checkpoint_util.save_checkpoint(full_path, state) algo.split_and_save_checkpoint(checkpoint_path=full_path) for i in range(3): agent_checkpoint_path = f'{full_path}_agent_{i}' agent_state = checkpoint_util.load_checkpoint_state(agent_checkpoint_path) - self.assertAllClose(agent_state['params_to_eval'], - expected_states[i]['params_to_eval']) + self.assertAllClose( + agent_state['params_to_eval'], expected_states[i]['params_to_eval'] + ) if expected_states[i]['obs_norm_state'] is not None: - self.assertAllClose(agent_state['obs_norm_state']['mean'], - expected_states[i]['obs_norm_state']['mean']) - self.assertAllClose(agent_state['obs_norm_state']['std'], - expected_states[i]['obs_norm_state']['std']) - self.assertAllClose(agent_state['obs_norm_state']['n'], - expected_states[i]['obs_norm_state']['n']) + self.assertAllClose( + agent_state['obs_norm_state']['mean'], + expected_states[i]['obs_norm_state']['mean'], + ) + self.assertAllClose( + agent_state['obs_norm_state']['std'], + expected_states[i]['obs_norm_state']['std'], + ) + self.assertAllClose( + agent_state['obs_norm_state']['n'], + expected_states[i]['obs_norm_state']['n'], + ) if __name__ == '__main__': diff --git a/iris/algorithms/run_split_checkpoint.py b/iris/algorithms/run_split_checkpoint.py index a7d0b1a..a2968a1 100644 --- a/iris/algorithms/run_split_checkpoint.py +++ b/iris/algorithms/run_split_checkpoint.py @@ -19,21 +19,24 @@ from absl import app from absl import flags from iris import normalizer -from iris.algorithms import ars_algorithm +from iris.algorithms import multi_agent_ars_algorithm _NUM_AGENTS = flags.DEFINE_integer('num_agents', 2, 'Number of agents.') _HAS_OBS_NORM = flags.DEFINE_boolean( - 'has_obs_norm', True, - 'Whether the checkpoint has observation normalization') + 'has_obs_norm', True, 'Whether the checkpoint has observation normalization' +) _CHECKPOINT_PATH = flags.DEFINE_string( - 'checkpoint_path', None, 'Path to checkpoint.', required=True) + 'checkpoint_path', None, 'Path to checkpoint.', required=True +) -def split_and_save_checkpoint(checkpoint_path: str, - num_agents: int = 2, - has_obs_norm_data_buffer: bool = False) -> None: +def split_and_save_checkpoint( + checkpoint_path: str, + num_agents: int = 2, + has_obs_norm_data_buffer: bool = False, +) -> None: """Splits the checkpoint at checkpoint_path into num_agents checkpoints.""" - algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + algo = multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch( num_suggestions=3, step_size=0.5, std=1.0, @@ -41,18 +44,22 @@ def split_and_save_checkpoint(checkpoint_path: str, 
orthogonal_suggestions=True, quasirandom_suggestions=True, obs_norm_data_buffer=normalizer.MeanStdBuffer() - if has_obs_norm_data_buffer else None, + if has_obs_norm_data_buffer + else None, agent_keys=[str(i) for i in range(num_agents)], - random_seed=7) + random_seed=7, + ) algo.split_and_save_checkpoint(checkpoint_path=checkpoint_path) def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') - split_and_save_checkpoint(checkpoint_path=_CHECKPOINT_PATH.value, - num_agents=_NUM_AGENTS.value, - has_obs_norm_data_buffer=_HAS_OBS_NORM.value) + split_and_save_checkpoint( + checkpoint_path=_CHECKPOINT_PATH.value, + num_agents=_NUM_AGENTS.value, + has_obs_norm_data_buffer=_HAS_OBS_NORM.value, + ) if __name__ == '__main__':
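
Note for downstream callers: after this refactor the multi-agent and learnable ARS variants live in their own modules, so imports move from iris.algorithms.ars_algorithm to iris.algorithms.multi_agent_ars_algorithm and iris.algorithms.learnable_ars_algorithm, as the test and script hunks above show. A minimal usage sketch of the new module path follows (illustrative only; it assumes the iris package and its dependencies are installed, uses constructor arguments similar to those in run_split_checkpoint.py above, and the checkpoint path is a placeholder):

from iris import normalizer
from iris.algorithms import multi_agent_ars_algorithm

# Construct the refactored multi-agent ARS algorithm, then split a combined
# two-agent checkpoint into one checkpoint per agent.
algo = multi_agent_ars_algorithm.MultiAgentAugmentedRandomSearch(
    num_suggestions=3,
    step_size=0.5,
    std=1.0,
    orthogonal_suggestions=True,
    quasirandom_suggestions=True,
    obs_norm_data_buffer=normalizer.MeanStdBuffer(),
    agent_keys=[str(i) for i in range(2)],
    random_seed=7,
)
# Placeholder path; point this at a real combined checkpoint.
algo.split_and_save_checkpoint(checkpoint_path='/tmp/checkpoint_iteration_0')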