diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py
index 720f3d14bd..23768cf109 100644
--- a/ml-agents/mlagents/trainers/agent_processor.py
+++ b/ml-agents/mlagents/trainers/agent_processor.py
@@ -27,7 +27,7 @@
     GlobalAgentId,
     GlobalGroupId,
 )
-from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
+from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple
 from mlagents.trainers.torch_entities.utils import ModelUtils

 T = TypeVar("T")
@@ -251,6 +251,28 @@ def _process_step(
         except KeyError:
             log_probs_tuple = LogProbsTuple.empty_log_probs()

+        try:
+            stored_action_mus = stored_take_action_outputs["mus"]
+            if not isinstance(stored_action_mus, MusTuple):
+                stored_action_mus = stored_action_mus.to_mus_tuple()
+            mus_tuple = MusTuple(
+                continuous=stored_action_mus.continuous[idx],
+                discrete=stored_action_mus.discrete[idx],
+            )
+        except KeyError:
+            mus_tuple = MusTuple.empty_mus()
+
+        try:
+            stored_action_sigmas = stored_take_action_outputs["sigmas"]
+            if not isinstance(stored_action_sigmas, SigmasTuple):
+                stored_action_sigmas = stored_action_sigmas.to_sigmas_tuple()
+            sigmas_tuple = SigmasTuple(
+                continuous=stored_action_sigmas.continuous[idx],
+                discrete=stored_action_sigmas.discrete[idx],
+            )
+        except KeyError:
+            sigmas_tuple = SigmasTuple.empty_sigmas()
+
         action_mask = stored_decision_step.action_mask
         prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]
@@ -266,6 +288,8 @@ def _process_step(
             done=done,
             action=action_tuple,
             action_probs=log_probs_tuple,
+            action_mus=mus_tuple,
+            action_sigmas=sigmas_tuple,
             action_mask=action_mask,
             prev_action=prev_action,
             interrupted=interrupted,
diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py
index ea6a2d5111..dd60a14203 100644
--- a/ml-agents/mlagents/trainers/buffer.py
+++ b/ml-agents/mlagents/trainers/buffer.py
@@ -27,6 +27,10 @@ class BufferKey(enum.Enum):
     CONTINUOUS_ACTION = "continuous_action"
     NEXT_CONT_ACTION = "next_continuous_action"
     CONTINUOUS_LOG_PROBS = "continuous_log_probs"
+    CONTINUOUS_MUS = "continuous_mus"
+    DISCRETE_MUS = "discrete_mus"
+    CONTINUOUS_SIGMAS = "continuous_sigmas"
+    DISCRETE_SIGMAS = "discrete_sigmas"
     DISCRETE_ACTION = "discrete_action"
     NEXT_DISC_ACTION = "next_discrete_action"
     DISCRETE_LOG_PROBS = "discrete_log_probs"
diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
index 41a452c65c..6e465cf563 100644
--- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -15,7 +15,7 @@
 )
 from mlagents.trainers.torch_entities.networks import ValueNetwork
 from mlagents.trainers.torch_entities.agent_action import AgentAction
-from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
+from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
 from mlagents.trainers.torch_entities.utils import ModelUtils
 from mlagents.trainers.trajectory import ObsUtil

@@ -66,8 +66,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         self.decay_learning_rate = ModelUtils.DecayedValue(
             self.hyperparameters.learning_rate_schedule,
             self.hyperparameters.learning_rate,
-            1e-10,
+            self.hyperparameters.lr_min,
             self.trainer_settings.max_steps,
+            self.hyperparameters.desired_lr_kl,
+            self.hyperparameters.lr_max,
         )
         self.decay_epsilon = ModelUtils.DecayedValue(
             self.hyperparameters.epsilon_schedule,
@@ -92,6 +94,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):

         self.stream_names = list(self.reward_signals.keys())

+        self.loss = torch.zeros(1, device=default_device())
+
+        self.last_actions = None
+
     @property
     def critic(self):
         return self._critic
@@ -153,6 +159,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

         log_probs = run_out["log_probs"]
         entropy = run_out["entropy"]
+        mus = run_out["mus"]
+        sigmas = run_out["sigmas"]

         values, _ = self.critic.critic_pass(
             current_obs,
@@ -160,6 +168,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             sequence_length=self.policy.sequence_length,
         )
         old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
+        old_mus = ActionMus.from_buffer(batch).flatten()
+        old_sigmas = ActionSigmas.from_buffer(batch).flatten()
         log_probs = log_probs.flatten()
         loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
         value_loss = ModelUtils.trust_region_value_loss(
@@ -172,16 +182,22 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             loss_masks,
             decay_eps,
         )
-        loss = (
+        self.loss = (
             policy_loss
             + 0.5 * value_loss
             - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
         )

+        # Adaptive learning rate: recompute the LR from the KL between old and new policy
+        if self.hyperparameters.learning_rate_schedule == ScheduleType.ADAPTIVE:
+            decay_lr = self.decay_learning_rate.get_value(
+                self.policy.get_current_step(), mus, old_mus, sigmas, old_sigmas
+            )
+
         # Set optimizer learning rate
         ModelUtils.update_learning_rate(self.optimizer, decay_lr)
         self.optimizer.zero_grad()
-        loss.backward()
+        self.loss.backward()
         self.optimizer.step()

         update_stats = {
@@ -194,6 +210,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             "Policy/Beta": decay_bet,
         }

+        self.loss = torch.zeros(1, device=default_device())
+
         return update_stats

     # TODO move module update into TorchOptimizer for reward_provider
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 9cb9a1f291..4c97ba9fcf 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -105,6 +105,7 @@ class EncoderType(Enum):
 class ScheduleType(Enum):
     CONSTANT = "constant"
     LINEAR = "linear"
+    ADAPTIVE = "adaptive"
     # TODO add support for lesson based scheduling
     # LESSON = "lesson"
@@ -158,6 +159,9 @@ class HyperparamSettings:
     batch_size: int = 1024
     buffer_size: int = 10240
     learning_rate: float = 3.0e-4
+    desired_lr_kl: float = 0.008
+    lr_min: float = 1.0e-10
+    lr_max: float = 1.0e-2
     learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT
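For orientation only, a minimal sketch (not the trainer's actual wiring) of how the three new HyperparamSettings fields flow into ModelUtils.DecayedValue via the constructor call patched above. The step count and literal values are placeholders copied from the new defaults; in a trainer config they would presumably surface as `learning_rate_schedule: adaptive` together with `desired_lr_kl`, `lr_min` and `lr_max`.

```python
from mlagents.trainers.settings import ScheduleType
from mlagents.trainers.torch_entities.utils import ModelUtils

# Mirror of the PPO optimizer's constructor call, with the HyperparamSettings
# defaults added in this diff spelled out as literals.
decay_lr = ModelUtils.DecayedValue(
    ScheduleType.ADAPTIVE,  # learning_rate_schedule
    3.0e-4,                 # learning_rate: starting point for the adaptive LR
    1.0e-10,                # lr_min: floor for the adaptive LR
    int(5.0e5),             # max_steps: only consulted by the LINEAR schedule
    desired_kl=0.008,       # desired_lr_kl: target KL between old and new policy
    max_value=1.0e-2,       # lr_max: ceiling for the adaptive LR
)

# Without mus/sigmas the adaptive branch simply returns the current value,
# so the very first call still yields the initial learning rate.
lr = decay_lr.get_value(global_step=0)
assert lr == 3.0e-4
```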
diff --git a/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py b/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
index b72e7bb223..788365d14c 100644
--- a/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
+++ b/ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
@@ -7,6 +7,26 @@
 from mlagents_envs.base_env import _ActionTupleBase


+class MusTuple(_ActionTupleBase):
+    @property
+    def discrete_dtype(self) -> np.dtype:
+        return np.float32
+
+    @staticmethod
+    def empty_mus() -> "MusTuple":
+        return MusTuple()
+
+
+class SigmasTuple(_ActionTupleBase):
+    @property
+    def discrete_dtype(self) -> np.dtype:
+        return np.float32
+
+    @staticmethod
+    def empty_sigmas() -> "SigmasTuple":
+        return SigmasTuple()
+
+
 class LogProbsTuple(_ActionTupleBase):
     """
     An object whose fields correspond to the log probs of actions of
     different types.
@@ -116,3 +136,141 @@ def from_buffer(buff: AgentBuffer) -> "ActionLogProbs":
                 discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
             ]
         return ActionLogProbs(continuous, discrete, None)
+
+
+class ActionMus(NamedTuple):
+    continuous_tensor: torch.Tensor
+    discrete_list: Optional[List[torch.Tensor]]
+    all_discrete_list: Optional[List[torch.Tensor]]
+
+    @property
+    def discrete_tensor(self):
+        """
+        Returns the discrete mus list as a stacked tensor
+        """
+        return torch.stack(self.discrete_list, dim=-1)
+
+    @property
+    def all_discrete_tensor(self):
+        """
+        Returns the discrete mus of each branch as a tensor
+        """
+        return torch.cat(self.all_discrete_list, dim=1)
+
+    def to_mus_tuple(self) -> MusTuple:
+        mus_tuple = MusTuple()
+        if self.continuous_tensor is not None:
+            continuous = ModelUtils.to_numpy(self.continuous_tensor)
+            mus_tuple.add_continuous(continuous)
+        if self.discrete_list is not None:
+            discrete = ModelUtils.to_numpy(self.discrete_tensor)
+            mus_tuple.add_discrete(discrete)
+        return mus_tuple
+
+    def _to_tensor_list(self) -> List[torch.Tensor]:
+        """
+        Returns the tensors in the ActionMus as a flat List of torch Tensors. This
+        is private and serves as a utility for self.flatten()
+        """
+        tensor_list: List[torch.Tensor] = []
+        if self.continuous_tensor is not None:
+            tensor_list.append(self.continuous_tensor)
+        if self.discrete_list is not None:
+            tensor_list.append(self.discrete_tensor)
+        return tensor_list
+
+    def flatten(self) -> torch.Tensor:
+        """
+        A utility method that returns all mus in ActionMus as a flattened tensor.
+        This is useful for algorithms like PPO which can treat all mus in the same way.
+        """
+        return torch.cat(self._to_tensor_list(), dim=1)
+
+    @staticmethod
+    def from_buffer(buff: AgentBuffer) -> "ActionMus":
+        """
+        A static method that accesses continuous and discrete mus fields in an AgentBuffer
+        and constructs the corresponding ActionMus from the retrieved np arrays.
+        """
+        continuous: torch.Tensor = None
+        discrete: List[torch.Tensor] = None  # type: ignore
+
+        if BufferKey.CONTINUOUS_MUS in buff:
+            continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_MUS])
+        if BufferKey.DISCRETE_MUS in buff:
+            discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_MUS])
+            # This will keep discrete_list = None which enables flatten()
+            if discrete_tensor.shape[1] > 0:
+                discrete = [
+                    discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
+                ]
+        return ActionMus(continuous, discrete, None)
+ """ + continuous: torch.Tensor = None + discrete: List[torch.Tensor] = None # type: ignore + + if BufferKey.CONTINUOUS_MUS in buff: + continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_MUS]) + if BufferKey.DISCRETE_MUS in buff: + discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_MUS]) + # This will keep discrete_list = None which enables flatten() + if discrete_tensor.shape[1] > 0: + discrete = [ + discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1]) + ] + return ActionMus(continuous, discrete, None) + + +class ActionSigmas(NamedTuple): + continuous_tensor: torch.Tensor + discrete_list: Optional[List[torch.Tensor]] + all_discrete_list: Optional[List[torch.Tensor]] + + @property + def discrete_tensor(self): + """ + Returns the discrete log probs list as a stacked tensor + """ + return torch.stack(self.discrete_list, dim=-1) + + @property + def all_discrete_tensor(self): + """ + Returns the discrete log probs of each branch as a tensor + """ + return torch.cat(self.all_discrete_list, dim=1) + + def to_sigmas_tuple(self) -> SigmasTuple: + sigmas_tuple = SigmasTuple() + if self.continuous_tensor is not None: + continuous = ModelUtils.to_numpy(self.continuous_tensor) + sigmas_tuple.add_continuous(continuous) + if self.discrete_list is not None: + discrete = ModelUtils.to_numpy(self.discrete_tensor) + sigmas_tuple.add_discrete(discrete) + return sigmas_tuple + + def _to_tensor_list(self) -> List[torch.Tensor]: + """ + Returns the tensors in the ActionLogProbs as a flat List of torch Tensors. This + is private and serves as a utility for self.flatten() + """ + tensor_list: List[torch.Tensor] = [] + if self.continuous_tensor is not None: + tensor_list.append(self.continuous_tensor) + if self.discrete_list is not None: + tensor_list.append(self.discrete_tensor) + return tensor_list + + def flatten(self) -> torch.Tensor: + """ + A utility method that returns all log probs in ActionLogProbs as a flattened tensor. + This is useful for algorithms like PPO which can treat all log probs in the same way. + """ + return torch.cat(self._to_tensor_list(), dim=1) + + @staticmethod + def from_buffer(buff: AgentBuffer) -> "ActionSigmas": + """ + A static method that accesses continuous and discrete log probs fields in an AgentBuffer + and constructs the corresponding ActionLogProbs from the retrieved np arrays. 
+ """ + continuous: torch.Tensor = None + discrete: List[torch.Tensor] = None # type: ignore + + if BufferKey.CONTINUOUS_SIGMAS in buff: + continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_SIGMAS]) + if BufferKey.DISCRETE_SIGMAS in buff: + discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_SIGMAS]) + # This will keep discrete_list = None which enables flatten() + if discrete_tensor.shape[1] > 0: + discrete = [ + discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1]) + ] + return ActionSigmas(continuous, discrete, None) diff --git a/ml-agents/mlagents/trainers/torch_entities/action_model.py b/ml-agents/mlagents/trainers/torch_entities/action_model.py index 7b88c0262d..15551ba3ca 100644 --- a/ml-agents/mlagents/trainers/torch_entities/action_model.py +++ b/ml-agents/mlagents/trainers/torch_entities/action_model.py @@ -7,7 +7,7 @@ MultiCategoricalDistribution, ) from mlagents.trainers.torch_entities.agent_action import AgentAction -from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs +from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas from mlagents_envs.base_env import ActionSpec @@ -146,9 +146,25 @@ def _get_probs_and_entropy( entropies = torch.cat(entropies_list, dim=1) return action_log_probs, entropies + def _get_mus_and_sigmas(self, actions, dists): + continuous_mus: Optional[torch.Tensor] = None + continuous_sigmas: Optional[torch.Tensor] = None + discrete_mus: Optional[torch.Tensor] = None + discrete_sigmas: Optional[torch.Tensor] = None + all_discrete_mus: Optional[List[torch.Tensor]] = None + all_discrete_sigmas: Optional[List[torch.Tensor]] = None + if dists.continuous is not None: + continuous_mus = dists.continuous.mu() + continuous_sigmas = dists.continuous.sigma() + action_mus = ActionMus(continuous_mus, discrete_mus, all_discrete_mus) + action_sigmas = ActionSigmas( + continuous_sigmas, discrete_sigmas, all_discrete_sigmas + ) + return action_mus, action_sigmas + def evaluate( self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction - ) -> Tuple[ActionLogProbs, torch.Tensor]: + ) -> Tuple[ActionLogProbs, torch.Tensor, torch.Tensor, torch.Tensor]: """ Given actions and encoding from the network body, gets the distributions and computes the log probabilites and entropies. 
diff --git a/ml-agents/mlagents/trainers/torch_entities/distributions.py b/ml-agents/mlagents/trainers/torch_entities/distributions.py
index c324e172f4..9b7f87e292 100644
--- a/ml-agents/mlagents/trainers/torch_entities/distributions.py
+++ b/ml-agents/mlagents/trainers/torch_entities/distributions.py
@@ -23,6 +23,14 @@ def deterministic_sample(self) -> torch.Tensor:
         """
         pass

+    @abc.abstractmethod
+    def mu(self):
+        pass
+
+    @abc.abstractmethod
+    def sigma(self):
+        pass
+
     @abc.abstractmethod
     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         """
@@ -69,6 +77,12 @@ def sample(self):
     def deterministic_sample(self):
         return self.mean

+    def mu(self):
+        return self.mean
+
+    def sigma(self):
+        return self.std
+
     def log_prob(self, value):
         var = self.std**2
         log_scale = torch.log(self.std + EPSILON)
diff --git a/ml-agents/mlagents/trainers/torch_entities/networks.py b/ml-agents/mlagents/trainers/torch_entities/networks.py
index 555268075c..e5168abd75 100644
--- a/ml-agents/mlagents/trainers/torch_entities/networks.py
+++ b/ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -531,6 +531,15 @@ def get_action_and_stats(
         """
         pass

+    def get_mus(
+        self,
+        inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Dict[str, Any]:
+        pass
+
     def get_stats(
         self,
         inputs: List[torch.Tensor],
@@ -637,7 +646,7 @@ def get_action_and_stats(
         encoding, memories = self.network_body(
             inputs, memories=memories, sequence_length=sequence_length
         )
-        action, log_probs, entropies = self.action_model(encoding, masks)
+        action, log_probs, entropies, mus, sigmas = self.action_model(encoding, masks)
         run_out = {}
         # This is the clipped action which is not saved to the buffer
         # but is exclusively sent to the environment.
@@ -646,9 +655,32 @@ def get_action_and_stats(
         )
         run_out["log_probs"] = log_probs
         run_out["entropy"] = entropies
+        run_out["mus"] = mus
+        run_out["sigmas"] = sigmas

         return action, run_out, memories

+    def get_mus(
+        self,
+        inputs: List[torch.Tensor],
+        masks: Optional[torch.Tensor] = None,
+        memories: Optional[torch.Tensor] = None,
+        sequence_length: int = 1,
+    ) -> Dict[str, Any]:
+        encoding, actor_mem_outs = self.network_body(
+            inputs, memories=memories, sequence_length=sequence_length
+        )
+
+        (
+            continuous_out,
+            discrete_out,
+            action_out_deprecated,
+            deterministic_continuous_out,
+            deterministic_discrete_out,
+        ) = self.action_model.get_action_out(encoding, masks)
+        run_out = {"mus": deterministic_continuous_out}
+        return run_out
+
     def get_stats(
         self,
         inputs: List[torch.Tensor],
@@ -661,10 +693,13 @@ def get_stats(
             inputs, memories=memories, sequence_length=sequence_length
         )

-        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
+        log_probs, entropies, mus, sigmas = self.action_model.evaluate(encoding, masks, actions)
         run_out = {}
         run_out["log_probs"] = log_probs
         run_out["entropy"] = entropies
+        run_out["mus"] = mus
+        run_out["sigmas"] = sigmas
+
         return run_out

     def forward(
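A quick sanity sketch of the new accessors on the Gaussian distribution instance (shapes and values are illustrative). Worth noting: only GaussianDistInstance implements the new abstract mu()/sigma() methods in this diff, which is why _get_mus_and_sigmas only fills the continuous slots; purely discrete policies would presumably need corresponding implementations (or non-abstract defaults) for their categorical instances.

```python
import torch

from mlagents.trainers.torch_entities.distributions import GaussianDistInstance

# For the Gaussian instance, mu()/sigma() simply expose the distribution
# parameters, and mu() coincides with deterministic_sample().
dist = GaussianDistInstance(torch.zeros(1, 2), torch.ones(1, 2))
assert torch.equal(dist.mu(), dist.deterministic_sample())
assert torch.equal(dist.sigma(), torch.ones(1, 2))
```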
diff --git a/ml-agents/mlagents/trainers/torch_entities/utils.py b/ml-agents/mlagents/trainers/torch_entities/utils.py
index d5381cbecb..200f9cbf81 100644
--- a/ml-agents/mlagents/trainers/torch_entities/utils.py
+++ b/ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Dict
+from typing import List, Optional, Tuple, Dict, Any
 from mlagents.torch_utils import torch, nn
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np
@@ -67,26 +67,43 @@ def __init__(
             initial_value: float,
             min_value: float,
             max_step: int,
+            desired_kl: Optional[float] = None,
+            max_value: Optional[float] = None,
         ):
             """
-            Object that represnets value of a parameter that should be decayed, assuming it is a function of
+            Object that represents the value of a parameter that should be decayed, assuming it is a function of
             global_step.
             :param schedule: Type of learning rate schedule.
             :param initial_value: Initial value before decay.
             :param min_value: Decay value to this value by max_step.
            :param max_step: The final step count where the return value should equal min_value.
-            :param global_step: The current step count.
+            :param desired_kl: Target KL divergence between updates, used by the adaptive schedule.
+            :param max_value: Upper bound on the value, used by the adaptive schedule.
             :return: The value.
             """
             self.schedule = schedule
             self.initial_value = initial_value
+            self.current_value = initial_value
             self.min_value = min_value
             self.max_step = max_step
+            self.desired_kl = desired_kl
+            self.max_value = max_value

-        def get_value(self, global_step: int) -> float:
+        def get_value(
+            self,
+            global_step: int,
+            mus: Any = None,
+            old_mus: Any = None,
+            sigmas: Any = None,
+            old_sigmas: Any = None,
+        ) -> float:
             """
             Get the value at a given global step.
             :param global_step: Step count.
+            :param mus: Current policy means (adaptive schedule only).
+            :param old_mus: Behaviour-policy means from the buffer (adaptive schedule only).
+            :param sigmas: Current policy standard deviations (adaptive schedule only).
+            :param old_sigmas: Behaviour-policy standard deviations (adaptive schedule only).
             :returns: Decayed value at this global step.
             """
             if self.schedule == ScheduleType.CONSTANT:
@@ -95,6 +112,18 @@ def get_value(self, global_step: int) -> float:
                 return self.initial_value
             elif self.schedule == ScheduleType.LINEAR:
                 return ModelUtils.polynomial_decay(
                     self.initial_value, self.min_value, self.max_step, global_step
                 )
+            elif self.schedule == ScheduleType.ADAPTIVE:
+                self.current_value = ModelUtils.adaptive_decay(
+                    self.current_value,
+                    self.desired_kl,
+                    self.max_value,
+                    self.min_value,
+                    mus,
+                    old_mus,
+                    sigmas,
+                    old_sigmas,
+                )
+                return self.current_value
             else:
                 raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")
@@ -121,6 +150,37 @@ def polynomial_decay(
         ) ** (power) + min_value
         return decayed_value

+    @staticmethod
+    def adaptive_decay(
+        current_value: float,
+        desired_kl: float,
+        max_value: float,
+        min_value: float,
+        mus: Any = None,
+        old_mus: Any = None,
+        sigmas: Any = None,
+        old_sigmas: Any = None,
+    ) -> float:
+        if mus is None or old_mus is None or sigmas is None or old_sigmas is None:
+            return current_value
+        decayed_value = current_value
+        kl_star = desired_kl
+        with torch.no_grad():
+            # Closed-form KL(old || new) for diagonal Gaussians, summed over action dims.
+            kl = torch.sum(
+                torch.log(sigmas / old_sigmas + 1.0e-5)
+                + (torch.square(old_sigmas) + torch.square(old_mus - mus))
+                / (2.0 * torch.square(sigmas))
+                - 0.5,
+                dim=-1,
+            )
+            kl_mean = kl.mean()
+        if kl_mean > kl_star * 2.0:
+            decayed_value = max(min_value, decayed_value / 1.5)
+        elif kl_star / 2.0 > kl_mean > 0.0:
+            decayed_value = min(max_value, 1.5 * decayed_value)
+        return decayed_value
+
     @staticmethod
     def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
         ENCODER_FUNCTION_BY_TYPE = {
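The sum inside adaptive_decay is the closed-form KL divergence KL(old || new) between the old and new diagonal-Gaussian policies, taken per action dimension and then summed (the 1e-5 inside the log is just a numerical fudge). A self-contained sketch, with illustrative values only, that cross-checks the expression against torch.distributions and applies the same 2x-band / 1.5-factor rule:

```python
import torch
from torch.distributions import Normal, kl_divergence

# Illustrative batch: 4 samples of a 2-D continuous action's old/new Gaussian parameters.
old_mu, old_sigma = torch.zeros(4, 2), torch.ones(4, 2)
new_mu, new_sigma = 0.1 * torch.ones(4, 2), 1.2 * torch.ones(4, 2)

# Closed-form KL(old || new) summed over action dimensions -- the same expression
# adaptive_decay evaluates, minus its 1e-5 fudge term inside the log.
kl = torch.sum(
    torch.log(new_sigma / old_sigma)
    + (old_sigma**2 + (old_mu - new_mu) ** 2) / (2.0 * new_sigma**2)
    - 0.5,
    dim=-1,
)
kl_ref = kl_divergence(Normal(old_mu, old_sigma), Normal(new_mu, new_sigma)).sum(dim=-1)
assert torch.allclose(kl, kl_ref)

# The same halve/grow rule as adaptive_decay, using the HyperparamSettings defaults.
lr, kl_target, lr_min, lr_max = 3.0e-4, 0.008, 1.0e-10, 1.0e-2
kl_mean = kl.mean().item()
if kl_mean > 2.0 * kl_target:
    lr = max(lr_min, lr / 1.5)      # policy moved too far: shrink the LR
elif 0.0 < kl_mean < kl_target / 2.0:
    lr = min(lr_max, 1.5 * lr)      # policy barely moved: grow the LR
```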
""" if self.schedule == ScheduleType.CONSTANT: @@ -95,6 +112,18 @@ def get_value(self, global_step: int) -> float: return ModelUtils.polynomial_decay( self.initial_value, self.min_value, self.max_step, global_step ) + elif self.schedule == ScheduleType.ADAPTIVE: + self.current_value = ModelUtils.adaptive_decay( + self.current_value, + self.desired_kl, + self.max_value, + self.min_value, + mus, + old_mus, + sigmas, + old_sigmas, + ) + return self.current_value else: raise UnityTrainerException(f"The schedule {self.schedule} is invalid.") @@ -121,6 +150,37 @@ def polynomial_decay( ) ** (power) + min_value return decayed_value + @staticmethod + def adaptive_decay( + current_value: float, + desired_kl: float, + max_value: float, + min_value: float, + mus: Any = None, + old_mus: Any = None, + sigmas: Any = None, + old_sigmas: Any = None, + ) -> float: + if mus is None or old_mus is None or sigmas is None or old_sigmas is None: + return current_value + decayed_value = current_value + kl_star = desired_kl + with torch.no_grad(): + kl = torch.sum( + torch.log(sigmas / old_sigmas + 1.0e-5) + + (torch.square(old_sigmas) + torch.square(old_mus - mus)) + / (2.0 * torch.square(sigmas)) + - 0.5, + dim=-1, + ) + kl_mean = kl.mean() + # print(f"KL: {kl_mean}") + if kl_mean > kl_star * 2.0: + decayed_value = max(min_value, decayed_value / 1.5) + elif kl_star / 2.0 > kl_mean > 0.0: + decayed_value = min(max_value, 1.5 * decayed_value) + return decayed_value + @staticmethod def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module: ENCODER_FUNCTION_BY_TYPE = { diff --git a/ml-agents/mlagents/trainers/trajectory.py b/ml-agents/mlagents/trainers/trajectory.py index 0a08bc24b4..6d1fc7a68c 100644 --- a/ml-agents/mlagents/trainers/trajectory.py +++ b/ml-agents/mlagents/trainers/trajectory.py @@ -8,7 +8,7 @@ BufferKey, ) from mlagents_envs.base_env import ActionTuple -from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple +from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple class AgentStatus(NamedTuple): @@ -35,6 +35,8 @@ class AgentExperience(NamedTuple): done: bool action: ActionTuple action_probs: LogProbsTuple + action_mus: MusTuple + action_sigmas: SigmasTuple action_mask: np.ndarray prev_action: np.ndarray interrupted: bool @@ -267,6 +269,13 @@ def to_agentbuffer(self) -> AgentBuffer: agent_buffer_trajectory[BufferKey.DISCRETE_LOG_PROBS].append( exp.action_probs.discrete ) + agent_buffer_trajectory[BufferKey.CONTINUOUS_MUS].append( + exp.action_mus.continuous + ) + + agent_buffer_trajectory[BufferKey.CONTINUOUS_SIGMAS].append( + exp.action_sigmas.continuous + ) # Store action masks if necessary. Note that 1 means active, while # in AgentExperience False means active.