Added adaptive learning rate feature.
miguelalonsojr committed Dec 9, 2024
1 parent cfb26e3 commit 1996bcd
Showing 10 changed files with 362 additions and 16 deletions.
26 changes: 25 additions & 1 deletion ml-agents/mlagents/trainers/agent_processor.py
@@ -27,7 +27,7 @@
GlobalAgentId,
GlobalGroupId,
)
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple
from mlagents.trainers.torch_entities.utils import ModelUtils

T = TypeVar("T")
@@ -251,6 +251,28 @@ def _process_step(
except KeyError:
log_probs_tuple = LogProbsTuple.empty_log_probs()

try:
stored_action_mus = stored_take_action_outputs["mus"]
if not isinstance(stored_action_mus, MusTuple):
stored_action_mus = stored_action_mus.to_mus_tuple()
mus_tuple = MusTuple(
continuous=stored_action_mus.continuous[idx],
discrete=stored_action_mus.discrete[idx],
)
except KeyError:
mus_tuple = MusTuple.empty_mus()

try:
stored_action_sigmas = stored_take_action_outputs["sigmas"]
if not isinstance(stored_action_sigmas, SigmasTuple):
stored_action_sigmas = stored_action_sigmas.to_sigmas_tuple()
sigmas_tuple = SigmasTuple(
continuous=stored_action_sigmas.continuous[idx],
discrete=stored_action_sigmas.discrete[idx],
)
except KeyError:
sigmas_tuple = SigmasTuple.empty_sigmas()

action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]

@@ -266,6 +288,8 @@ def _process_step(
done=done,
action=action_tuple,
action_probs=log_probs_tuple,
action_mus=mus_tuple,
action_sigmas=sigmas_tuple,
action_mask=action_mask,
prev_action=prev_action,
interrupted=interrupted,
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/buffer.py
@@ -27,6 +27,10 @@ class BufferKey(enum.Enum):
CONTINUOUS_ACTION = "continuous_action"
NEXT_CONT_ACTION = "next_continuous_action"
CONTINUOUS_LOG_PROBS = "continuous_log_probs"
CONTINUOUS_MUS = "continuous_mus"
DISCRETE_MUS = "discrete_mus"
CONTINUOUS_SIGMAS = "continuous_sigmas"
DISCRETE_SIGMAS = "discrete_sigmas"
DISCRETE_ACTION = "discrete_action"
NEXT_DISC_ACTION = "next_discrete_action"
DISCRETE_LOG_PROBS = "discrete_log_probs"
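The four new BufferKey members give the Gaussian parameters their own columns in the AgentBuffer, mirroring how the log probs are already stored. Below is a minimal sketch of the read side, assuming the same indexing pattern the PPO optimizer already uses for keys such as BufferKey.MASKS; the helper name is illustrative.

from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.utils import ModelUtils

def read_old_gaussian_params(batch: AgentBuffer):
    # Assumes the trajectory-to-buffer conversion populated the new keys,
    # just as it already does for the continuous log probs.
    old_mus = ModelUtils.list_to_tensor(batch[BufferKey.CONTINUOUS_MUS])
    old_sigmas = ModelUtils.list_to_tensor(batch[BufferKey.CONTINUOUS_SIGMAS])
    return old_mus, old_sigmas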
26 changes: 22 additions & 4 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -15,7 +15,7 @@
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil

@@ -66,8 +66,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.hyperparameters.lr_min,
self.trainer_settings.max_steps,
self.hyperparameters.desired_lr_kl,
self.hyperparameters.lr_max
)
self.decay_epsilon = ModelUtils.DecayedValue(
self.hyperparameters.epsilon_schedule,
@@ -92,6 +94,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):

self.stream_names = list(self.reward_signals.keys())

self.loss = torch.zeros(1, device=default_device())

self.last_actions = None

@property
def critic(self):
return self._critic
@@ -153,13 +159,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

log_probs = run_out["log_probs"]
entropy = run_out["entropy"]
mus = run_out["mus"]
sigmas = run_out["sigmas"]

values, _ = self.critic.critic_pass(
current_obs,
memories=value_memories,
sequence_length=self.policy.sequence_length,
)
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
old_mus = ActionMus.from_buffer(batch).flatten()
old_sigmas = ActionSigmas.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
value_loss = ModelUtils.trust_region_value_loss(
@@ -172,16 +182,22 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
loss_masks,
decay_eps,
)
loss = (
self.loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)

# adaptive learning rate
if self.hyperparameters.learning_rate_schedule == ScheduleType.ADAPTIVE:
decay_lr = self.decay_learning_rate.get_value(
self.policy.get_current_step(), mus, old_mus, sigmas, old_sigmas
)

# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.loss.backward()

self.optimizer.step()
update_stats = {
Expand All @@ -194,6 +210,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"Policy/Beta": decay_bet,
}

self.loss = torch.zeros(1, device=default_device())

return update_stats

# TODO move module update into TorchOptimizer for reward_provider
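The adaptive branch above hands the current and previous Gaussian parameters to self.decay_learning_rate.get_value, but the extended DecayedValue logic lives in a file that is not part of the hunks shown here. Below is a minimal sketch, assuming the schedule follows the common adaptive-KL heuristic: measure how far the policy moved via the KL divergence between the old and new diagonal Gaussians, then adjust the learning rate toward the desired_lr_kl target, clamped to [lr_min, lr_max]. The helper names and the 1.5/2.0 adjustment factors are illustrative assumptions, not code from this commit.

import torch

def gaussian_kl(mu, sigma, old_mu, old_sigma, eps=1e-7):
    # Mean KL(old || new) between diagonal Gaussians described by their mus/sigmas.
    old_var, var = old_sigma.pow(2) + eps, sigma.pow(2) + eps
    kl = (
        torch.log(sigma / (old_sigma + eps) + eps)
        + (old_var + (old_mu - mu).pow(2)) / (2.0 * var)
        - 0.5
    )
    return kl.sum(dim=-1).mean()

def adaptive_lr_step(lr, kl, desired_lr_kl, lr_min, lr_max):
    # Shrink the learning rate when the policy moved too far in one update,
    # grow it when the update barely moved the policy.
    if kl > 2.0 * desired_lr_kl:
        lr = max(lr / 1.5, lr_min)
    elif kl < desired_lr_kl / 2.0:
        lr = min(lr * 1.5, lr_max)
    return lr

Whatever rule DecayedValue actually applies, the resulting value is pushed into the optimizer the same way the existing schedules are, via ModelUtils.update_learning_rate.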
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -105,6 +105,7 @@ class EncoderType(Enum):
class ScheduleType(Enum):
CONSTANT = "constant"
LINEAR = "linear"
ADAPTIVE = "adaptive"
# TODO add support for lesson based scheduling
# LESSON = "lesson"

@@ -158,6 +159,9 @@ class HyperparamSettings:
batch_size: int = 1024
buffer_size: int = 10240
learning_rate: float = 3.0e-4
desired_lr_kl: float = 0.008
lr_min: float = 1.0e-10
lr_max: float = 1.0e-2
learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT


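Switching a behavior to the new schedule is then a configuration change: set learning_rate_schedule to adaptive and, optionally, override the three new knobs. A minimal sketch in Python using the defaults from the diff; in practice these values normally come from the trainer YAML config, where the same field names are assumed to apply.

from mlagents.trainers.settings import HyperparamSettings, ScheduleType

# Sketch only: the field names and defaults come from the diff above; whether a
# given trainer subclass exposes them unchanged in YAML is an assumption.
hyperparameters = HyperparamSettings(
    learning_rate=3.0e-4,
    learning_rate_schedule=ScheduleType.ADAPTIVE,  # new schedule type
    desired_lr_kl=0.008,  # target KL the adaptive schedule steers toward
    lr_min=1.0e-10,       # floor for the adapted learning rate
    lr_max=1.0e-2,        # ceiling for the adapted learning rate
)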
158 changes: 158 additions & 0 deletions ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
@@ -7,6 +7,26 @@
from mlagents_envs.base_env import _ActionTupleBase


class MusTuple(_ActionTupleBase):
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_mus() -> "MusTuple":
return MusTuple()


class SigmasTuple(_ActionTupleBase):
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_sigmas() -> "SigmasTuple":
return SigmasTuple()


class LogProbsTuple(_ActionTupleBase):
"""
An object whose fields correspond to the log probs of actions of different types.
@@ -116,3 +136,141 @@ def from_buffer(buff: AgentBuffer) -> "ActionLogProbs":
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionLogProbs(continuous, discrete, None)


class ActionMus(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete mus list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete mus of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_mus_tuple(self) -> MusTuple:
mus_tuple = MusTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
mus_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
mus_tuple.add_discrete(discrete)
return mus_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionMus as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all mus in ActionMus as a flattened tensor.
This is useful for algorithms like PPO which can treat all mus in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionMus":
"""
A static method that accesses continuous and discrete mus fields in an AgentBuffer
and constructs the corresponding ActionMus from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_MUS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_MUS])
if BufferKey.DISCRETE_MUS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_MUS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionMus(continuous, discrete, None)


class ActionSigmas(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete sigmas list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete sigmas of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_sigmas_tuple(self) -> SigmasTuple:
sigmas_tuple = SigmasTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
sigmas_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
sigmas_tuple.add_discrete(discrete)
return sigmas_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionSigmas as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all sigmas in ActionSigmas as a flattened tensor.
This is useful for algorithms like PPO which can treat all sigmas in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionSigmas":
"""
A static method that accesses continuous and discrete sigmas fields in an AgentBuffer
and constructs the corresponding ActionSigmas from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_SIGMAS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_SIGMAS])
if BufferKey.DISCRETE_SIGMAS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_SIGMAS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionSigmas(continuous, discrete, None)
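A minimal usage sketch for the new containers, following the same pattern ActionLogProbs already supports; the tensor shape is illustrative.

import torch
from mlagents.trainers.torch_entities.action_log_probs import ActionMus

# Continuous-only case: a batch of 5 agents with 3 continuous actions.
mus = ActionMus(
    continuous_tensor=torch.zeros(5, 3),
    discrete_list=None,
    all_discrete_list=None,
)
flat = mus.flatten()            # (5, 3) tensor, ready for the update step
mus_tuple = mus.to_mus_tuple()  # numpy-backed MusTuple for the trajectory/buffer path

ActionSigmas works the same way via to_sigmas_tuple.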
28 changes: 24 additions & 4 deletions ml-agents/mlagents/trainers/torch_entities/action_model.py
@@ -7,7 +7,7 @@
MultiCategoricalDistribution,
)
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents_envs.base_env import ActionSpec


@@ -146,9 +146,25 @@ def _get_probs_and_entropy(
entropies = torch.cat(entropies_list, dim=1)
return action_log_probs, entropies

def _get_mus_and_sigmas(self, actions, dists):
continuous_mus: Optional[torch.Tensor] = None
continuous_sigmas: Optional[torch.Tensor] = None
discrete_mus: Optional[torch.Tensor] = None
discrete_sigmas: Optional[torch.Tensor] = None
all_discrete_mus: Optional[List[torch.Tensor]] = None
all_discrete_sigmas: Optional[List[torch.Tensor]] = None
if dists.continuous is not None:
continuous_mus = dists.continuous.mu()
continuous_sigmas = dists.continuous.sigma()
action_mus = ActionMus(continuous_mus, discrete_mus, all_discrete_mus)
action_sigmas = ActionSigmas(
continuous_sigmas, discrete_sigmas, all_discrete_sigmas
)
return action_mus, action_sigmas

def evaluate(
self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
) -> Tuple[ActionLogProbs, torch.Tensor]:
) -> Tuple[ActionLogProbs, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given actions and encoding from the network body, gets the distributions and
computes the log probabilities and entropies, along with the mus and sigmas of the continuous distribution.
@@ -159,9 +175,12 @@ def evaluate(
"""
dists = self._get_dists(inputs, masks)
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# mus = dists.continuous.deterministic_sample()
mus = dists.continuous.mu()
sigmas = dists.continuous.sigma()
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return log_probs, entropy_sum
return log_probs, entropy_sum, mus, sigmas

def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
@@ -228,4 +247,5 @@ def forward(
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return (actions, log_probs, entropy_sum)
mus, sigmas = self._get_mus_and_sigmas(actions, dists)
return (actions, log_probs, entropy_sum, mus, sigmas)
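Because evaluate and forward now return the Gaussian parameters as well, callers unpack two extra values. A sketch of the new calling convention follows; the action_model instance and its inputs are assumed to exist already, and this is not one of the callers touched by this commit.

import torch
from mlagents.trainers.torch_entities.action_model import ActionModel

def sample_with_gaussian_params(action_model: ActionModel, inputs: torch.Tensor, masks: torch.Tensor):
    # forward() used to return (actions, log_probs, entropy_sum); it now also
    # yields the per-action mus and sigmas consumed by the adaptive schedule.
    actions, log_probs, entropy_sum, mus, sigmas = action_model.forward(inputs, masks)
    return actions, log_probs, entropy_sum, mus, sigmas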