Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 9, 2024
1 parent f3a589b commit df9c1f6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 16 deletions.
1 change: 0 additions & 1 deletion notebooks/dreamer_v3_imagination.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
"import torchvision\n",
"from lightning.fabric import Fabric\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"\n",
"from sheeprl.algos.dreamer_v3.agent import build_agent\n",
"from sheeprl.data.buffers import SequentialReplayBuffer\n",
Expand Down
4 changes: 1 addition & 3 deletions sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity},
activation=activation,
norm_layer=[layer_norm_cls] * stages,
norm_args=[
{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
],
norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)],
),
nn.Flatten(-3, -1),
)
Expand Down
20 changes: 8 additions & 12 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,10 @@ def to_tensor(
return buf

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
Expand Down Expand Up @@ -617,12 +615,10 @@ def __len__(self) -> int:
return self.buffer_size

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(
self,
Expand Down Expand Up @@ -860,17 +856,17 @@ def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
...
def add(
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
) -> None: ...

@typing.overload
def add(
self,
data: Dict[str, np.ndarray],
env_idxes: Sequence[int] | None = None,
validate_args: bool = False,
) -> None:
...
) -> None: ...

def add(
self,
Expand Down

0 comments on commit df9c1f6

Please sign in to comment.