Fix/dv3 layer norm (#257)
* Preserve input dtype after LayerNorm (pytorch/pytorch#66707 (comment))

* Fix imports
belerico authored Apr 5, 2024
1 parent e25da73 commit 4441dbf
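
Why this change: under mixed-precision autocast, `nn.LayerNorm` runs in float32 and can return a float32 tensor even for half-precision inputs (the behavior discussed in pytorch/pytorch#66707). The fix casts the normalized output back to the input dtype. A minimal illustrative sketch of the pattern (the commit's actual classes live in sheeprl/models/models.py, shown in the diff below):

```python
import torch
import torch.nn as nn

class DtypePreservingLayerNorm(nn.LayerNorm):
    # Illustrative name only; the commit's real class is `LayerNorm` below.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype        # remember the caller's (possibly autocast) dtype
        out = super().forward(x)     # under autocast this may be computed and returned in float32
        return out.to(input_dtype)   # hand back a tensor in the original dtype
```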
Showing 6 changed files with 55 additions and 54 deletions.
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -22,10 +22,10 @@
)

from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormChannelLast, LayerNormGRUCell, MultiDecoder, MultiEncoder
from sheeprl.utils.distribution import TruncatedNormal
from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward


class CNNEncoder(nn.Module):
39 changes: 24 additions & 15 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -24,9 +24,18 @@
from sheeprl.algos.dreamer_v2.agent import WorldModel
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import (
+    CNN,
+    MLP,
+    DeCNN,
+    LayerNorm,
+    LayerNormChannelLast,
+    LayerNormGRUCell,
+    MultiDecoder,
+    MultiEncoder,
+)
from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLastFP32, LayerNormFP32, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward
from sheeprl.utils.utils import symlog


@@ -44,7 +53,7 @@ class CNNEncoder(nn.Module):
channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels
will be [1, 2, 4, 8] * `channels_multiplier`.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormChannelLastFP32.
+Defaults to LayerNormChannelLast.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function.
@@ -58,7 +67,7 @@ def __init__(
input_channels: Sequence[int],
image_size: Tuple[int, int],
channels_multiplier: int,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
activation: ModuleType = nn.SiLU,
stages: int = 4,
@@ -102,7 +111,7 @@ class MLPEncoder(nn.Module):
dense_units (int, optional): the dimension of every mlp.
Defaults to 512.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function after every layer.
@@ -117,7 +126,7 @@ def __init__(
input_dims: Sequence[int],
mlp_layers: int = 4,
dense_units: int = 512,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
activation: ModuleType = nn.SiLU,
symlog_inputs: bool = True,
@@ -162,7 +171,7 @@ class CNNDecoder(nn.Module):
activation (nn.Module, optional): the activation function.
Defaults to nn.SiLU.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormChannelLastFP32.
+Defaults to LayerNormChannelLast.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
stages (int): how many stages in the CNN decoder.
@@ -177,7 +186,7 @@ def __init__(
cnn_encoder_output_dim: int,
image_size: Tuple[int, int],
activation: nn.Module = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
stages: int = 4,
) -> None:
@@ -232,7 +241,7 @@ class MLPDecoder(nn.Module):
dense_units (int, optional): the dimension of every mlp.
Defaults to 512.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function after every layer.
@@ -247,7 +256,7 @@ def __init__(
mlp_layers: int = 4,
dense_units: int = 512,
activation: ModuleType = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
) -> None:
super().__init__()
@@ -282,7 +291,7 @@ class RecurrentModel(nn.Module):
activation_fn (nn.Module): the activation function.
Default to SiLU.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
"""
@@ -293,7 +302,7 @@ def __init__(
recurrent_state_size: int,
dense_units: int,
activation_fn: nn.Module = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
) -> None:
super().__init__()
@@ -710,7 +719,7 @@ class Actor(nn.Module):
mlp_layers (int): the number of dense layers.
Default to 5.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
unimix: (float, optional): the percentage of uniform distribution to inject into the categorical
@@ -734,7 +743,7 @@ def __init__(
dense_units: int = 1024,
activation: nn.Module = nn.SiLU,
mlp_layers: int = 5,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
unimix: float = 0.01,
action_clip: float = 1.0,
@@ -853,7 +862,7 @@ def __init__(
dense_units: int = 1024,
activation: nn.Module = nn.SiLU,
mlp_layers: int = 5,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
unimix: float = 0.01,
action_clip: float = 1.0,
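
Every signature above threads a `layer_norm_cls` / `layer_norm_kw` pair down to the layers it builds. A sketch of how such a pair is typically consumed (`dense_block` is a hypothetical helper for illustration, not a function from this diff):

```python
import torch.nn as nn
from sheeprl.models.models import LayerNorm

def dense_block(
    in_dim: int,
    out_dim: int,
    layer_norm_cls=LayerNorm,     # the configured class
    layer_norm_kw={"eps": 1e-3},  # its kwargs from the config
) -> nn.Sequential:
    # The class is instantiated with the layer width plus the configured kwargs.
    return nn.Sequential(
        nn.Linear(in_dim, out_dim, bias=False),
        layer_norm_cls(out_dim, **layer_norm_kw),
        nn.SiLU(),
    )
```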
4 changes: 2 additions & 2 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -26,11 +26,11 @@ mlp_keys:

# Model related parameters
cnn_layer_norm:
-  cls: sheeprl.utils.model.LayerNormChannelLastFP32
+  cls: sheeprl.models.models.LayerNormChannelLast
  kw:
    eps: 1e-3
mlp_layer_norm:
-  cls: sheeprl.utils.model.LayerNormFP32
+  cls: sheeprl.models.models.LayerNorm
  kw:
    eps: 1e-3
dense_units: 1024
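
The dotted `cls` paths suggest Hydra-style resolution: the string is imported as a class and then called with the `kw` mapping. A sketch under that assumption (the feature dimension `64` is a placeholder):

```python
import hydra.utils

cfg = {"cls": "sheeprl.models.models.LayerNormChannelLast", "kw": {"eps": 1e-3}}
layer_norm_cls = hydra.utils.get_class(cfg["cls"])  # resolve the dotted path
norm = layer_norm_cls(64, **cfg["kw"])              # 64 = normalized channel dim
```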
21 changes: 21 additions & 0 deletions sheeprl/models/models.py
@@ -502,3 +502,24 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
        if self.mlp_decoder is not None:
            reconstructed_obs.update(self.mlp_decoder(x))
        return reconstructed_obs
+
+
+class LayerNormChannelLast(nn.LayerNorm):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.dim() != 4:
+            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
+        input_dtype = x.dtype
+        x = x.permute(0, 2, 3, 1)
+        x = super().forward(x)
+        x = x.permute(0, 3, 1, 2)
+        return x.to(input_dtype)
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        input_dtype = x.dtype
+        out = super().forward(x)
+        return out.to(input_dtype)
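
A quick usage sketch of the two classes added above (shapes are placeholders): `LayerNormChannelLast` permutes an NCHW tensor to NHWC, normalizes over the channel dimension, permutes back, and returns the input dtype; `LayerNorm` only adds the dtype cast.

```python
import torch
from sheeprl.models.models import LayerNormChannelLast

x = torch.randn(8, 64, 32, 32)             # NCHW feature map
norm = LayerNormChannelLast(64, eps=1e-3)  # normalized_shape = number of channels
y = norm(x)
assert y.shape == x.shape and y.dtype == x.dtype
```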
29 changes: 0 additions & 29 deletions sheeprl/utils/model.py
@@ -221,32 +221,3 @@ def cnn_forward(
    flatten_input = input.reshape(-1, *input_dim)
    model_out = model(flatten_input)
    return model_out.reshape(*batch_shapes, *output_dim)
-
-
-class LayerNormChannelLast(nn.LayerNorm):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if x.dim() != 4:
-            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
-        x = x.permute(0, 2, 3, 1)
-        x = super().forward(x)
-        x = x.permute(0, 3, 1, 2)
-        return x
-
-
-class LayerNormChannelLastFP32(LayerNormChannelLast):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
-
-
-class LayerNormFP32(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
12 changes: 6 additions & 6 deletions tests/test_algos/test_algos.py
@@ -474,8 +474,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
"algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
]

with mock.patch.object(sys, "argv", args):
@@ -513,8 +513,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"checkpoint.save_last=True",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
"algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
]

with mock.patch.object(sys, "argv", args):
@@ -557,8 +557,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
"algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
]
with mock.patch.object(sys, "argv", args):
run()
