Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 18, 2024
1 parent 5ba6cb1 commit 2115104
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 60 deletions.
86 changes: 50 additions & 36 deletions sheeprl/algos/dreamer_v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2},
activation=activation,
norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None,
norm_args=[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] if layer_norm else None
),
),
nn.Flatten(-3, -1),
)
Expand Down Expand Up @@ -172,12 +172,12 @@ def __init__(
],
activation=[activation, activation, activation, None],
norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None,
norm_args=[
{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])
]
+ [None]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])]
+ [None]
if layer_norm
else None
),
),
)

Expand Down Expand Up @@ -943,9 +943,11 @@ def build_agent(
activation=eval(world_model_cfg.representation_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None
),
)
transition_model = MLP(
input_dims=world_model_cfg.recurrent_model.recurrent_state_size,
Expand All @@ -954,9 +956,11 @@ def build_agent(
activation=eval(world_model_cfg.transition_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None
),
)
rssm = RSSM(
recurrent_model=recurrent_model.apply(init_weights),
Expand Down Expand Up @@ -999,15 +1003,19 @@ def build_agent(
hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers,
activation=eval(world_model_cfg.reward_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None
),
)
if world_model_cfg.use_continues:
continue_model = MLP(
Expand All @@ -1016,15 +1024,19 @@ def build_agent(
hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers,
activation=eval(world_model_cfg.discount_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None
),
)
world_model = WorldModel(
encoder.apply(init_weights),
Expand Down Expand Up @@ -1053,9 +1065,11 @@ def build_agent(
activation=eval(critic_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor.apply(init_weights)
critic.apply(init_weights)
Expand Down
8 changes: 5 additions & 3 deletions sheeprl/algos/p2e_dv2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ def build_agent(
activation=eval(critic_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor_task.apply(init_weights)
critic_task.apply(init_weights)
Expand Down
8 changes: 5 additions & 3 deletions sheeprl/algos/p2e_dv3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ def build_agent(
flatten_dim=None,
layer_args={"bias": not critic_cfg.layer_norm},
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
),
}
critics_exploration[k]["module"].apply(init_weights)
Expand Down
16 changes: 10 additions & 6 deletions sheeprl/algos/ppo_recurrent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def __init__(
activation=eval(pre_rnn_mlp_cfg.activation),
layer_args={"bias": pre_rnn_mlp_cfg.bias},
norm_layer=[nn.LayerNorm] if pre_rnn_mlp_cfg.layer_norm else None,
norm_args=[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if pre_rnn_mlp_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if pre_rnn_mlp_cfg.layer_norm
else None
),
)
else:
self._pre_mlp = nn.Identity()
Expand All @@ -45,9 +47,11 @@ def __init__(
activation=eval(post_rnn_mlp_cfg.activation),
layer_args={"bias": post_rnn_mlp_cfg.bias},
norm_layer=[nn.LayerNorm] if post_rnn_mlp_cfg.layer_norm else None,
norm_args=[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if post_rnn_mlp_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if post_rnn_mlp_cfg.layer_norm
else None
),
)
self._output_dim = post_rnn_mlp_cfg.dense_units
else:
Expand Down
20 changes: 8 additions & 12 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,10 @@ def to_tensor(
return buf

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
Expand Down Expand Up @@ -614,12 +612,10 @@ def __len__(self) -> int:
return self.buffer_size

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(
self,
Expand Down Expand Up @@ -857,17 +853,17 @@ def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
...
def add(
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
) -> None: ...

@typing.overload
def add(
self,
data: Dict[str, np.ndarray],
env_idxes: Sequence[int] | None = None,
validate_args: bool = False,
) -> None:
...
) -> None: ...

def add(
self,
Expand Down
1 change: 1 addition & 0 deletions sheeprl/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py
"""

import warnings
from math import prod
from typing import Dict, Optional, Sequence, Union, no_type_check
Expand Down
2 changes: 2 additions & 0 deletions sheeprl/utils/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution):
probs (Tensor): event probabilities
logits (Tensor): event log probabilities (unnormalized)
"""

arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True
Expand Down Expand Up @@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
(Bengio et al, 2013)
"""

has_rsample = True

def rsample(self, sample_shape=torch.Size()):
Expand Down
1 change: 1 addition & 0 deletions sheeprl/utils/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py
"""

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
Expand Down

0 comments on commit 2115104

Please sign in to comment.