From 4441dbf4bcd7ae0daee47d35fb0660bc1fe8bd4b Mon Sep 17 00:00:00 2001
From: Federico Belotti
Date: Fri, 5 Apr 2024 10:17:41 +0200
Subject: [PATCH] Fix/dv3 layer norm (#257)

* Preserve input dtype after LayerNorm (https://github.com/pytorch/pytorch/issues/66707#issuecomment-2028904230)

* Fix imports
---
 sheeprl/algos/dreamer_v2/agent.py    |  4 +--
 sheeprl/algos/dreamer_v3/agent.py    | 39 +++++++++++++++++-----------
 sheeprl/configs/algo/dreamer_v3.yaml |  4 +--
 sheeprl/models/models.py             | 21 +++++++++++++++
 sheeprl/utils/model.py               | 29 ---------------------
 tests/test_algos/test_algos.py       | 12 ++++-----
 6 files changed, 55 insertions(+), 54 deletions(-)

diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py
index 4b693395..120b9d7f 100644
--- a/sheeprl/algos/dreamer_v2/agent.py
+++ b/sheeprl/algos/dreamer_v2/agent.py
@@ -22,10 +22,10 @@
 )

 from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormChannelLast, LayerNormGRUCell, MultiDecoder, MultiEncoder
 from sheeprl.utils.distribution import TruncatedNormal
 from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward


 class CNNEncoder(nn.Module):
diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index adcbf195..e61d1472 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -24,9 +24,18 @@
 from sheeprl.algos.dreamer_v2.agent import WorldModel
 from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
 from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import (
+    CNN,
+    MLP,
+    DeCNN,
+    LayerNorm,
+    LayerNormChannelLast,
+    LayerNormGRUCell,
+    MultiDecoder,
+    MultiEncoder,
+)
 from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLastFP32, LayerNormFP32, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward
 from sheeprl.utils.utils import symlog


@@ -44,7 +53,7 @@ class CNNEncoder(nn.Module):
         channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels
             will be [1, 2, 4, 8] * `channels_multiplier`.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormChannelLastFP32.
+            Defaults to LayerNormChannelLast.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function.
@@ -58,7 +67,7 @@ def __init__(
         input_channels: Sequence[int],
         image_size: Tuple[int, int],
         channels_multiplier: int,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         activation: ModuleType = nn.SiLU,
         stages: int = 4,
@@ -102,7 +111,7 @@ class MLPEncoder(nn.Module):
         dense_units (int, optional): the dimension of every mlp.
             Defaults to 512.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function after every layer.
@@ -117,7 +126,7 @@ def __init__(
         input_dims: Sequence[int],
         mlp_layers: int = 4,
         dense_units: int = 512,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         activation: ModuleType = nn.SiLU,
         symlog_inputs: bool = True,
@@ -162,7 +171,7 @@ class CNNDecoder(nn.Module):
         activation (nn.Module, optional): the activation function.
             Defaults to nn.SiLU.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormChannelLastFP32.
+            Defaults to LayerNormChannelLast.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         stages (int): how many stages in the CNN decoder.
@@ -177,7 +186,7 @@ def __init__(
         cnn_encoder_output_dim: int,
         image_size: Tuple[int, int],
         activation: nn.Module = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         stages: int = 4,
     ) -> None:
         super().__init__()
@@ -232,7 +241,7 @@ class MLPDecoder(nn.Module):
         dense_units (int, optional): the dimension of every mlp.
             Defaults to 512.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function after every layer.
@@ -247,7 +256,7 @@ def __init__(
         mlp_layers: int = 4,
         dense_units: int = 512,
         activation: ModuleType = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
     ) -> None:
         super().__init__()
@@ -282,7 +291,7 @@ class RecurrentModel(nn.Module):
         activation_fn (nn.Module): the activation function.
             Default to SiLU.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
     """
@@ -293,7 +302,7 @@ def __init__(
         recurrent_state_size: int,
         dense_units: int,
         activation_fn: nn.Module = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
     ) -> None:
         super().__init__()
@@ -710,7 +719,7 @@ class Actor(nn.Module):
         mlp_layers (int): the number of dense layers.
             Default to 5.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         unimix: (float, optional): the percentage of uniform distribution to inject into the categorical
@@ -734,7 +743,7 @@ def __init__(
         dense_units: int = 1024,
         activation: nn.Module = nn.SiLU,
         mlp_layers: int = 5,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         unimix: float = 0.01,
         action_clip: float = 1.0,
@@ -853,7 +862,7 @@ def __init__(
         dense_units: int = 1024,
         activation: nn.Module = nn.SiLU,
         mlp_layers: int = 5,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         unimix: float = 0.01,
         action_clip: float = 1.0,
diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml
index 9b6e85fd..11955d84 100644
--- a/sheeprl/configs/algo/dreamer_v3.yaml
+++ b/sheeprl/configs/algo/dreamer_v3.yaml
@@ -26,11 +26,11 @@ mlp_keys:

 # Model related parameters
 cnn_layer_norm:
-  cls: sheeprl.utils.model.LayerNormChannelLastFP32
+  cls: sheeprl.models.models.LayerNormChannelLast
   kw:
     eps: 1e-3
 mlp_layer_norm:
-  cls: sheeprl.utils.model.LayerNormFP32
+  cls: sheeprl.models.models.LayerNorm
   kw:
     eps: 1e-3
 dense_units: 1024
diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py
index dbc810ad..a65dc381 100644
--- a/sheeprl/models/models.py
+++ b/sheeprl/models/models.py
@@ -502,3 +502,24 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         if self.mlp_decoder is not None:
             reconstructed_obs.update(self.mlp_decoder(x))
         return reconstructed_obs
+
+
+class LayerNormChannelLast(nn.LayerNorm):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.dim() != 4:
+            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
+        input_dtype = x.dtype
+        x = x.permute(0, 2, 3, 1)
+        x = super().forward(x)
+        x = x.permute(0, 3, 1, 2)
+        return x.to(input_dtype)
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        input_dtype = x.dtype
+        out = super().forward(x)
+        return out.to(input_dtype)
diff --git a/sheeprl/utils/model.py b/sheeprl/utils/model.py
index 1552020e..cc25897c 100644
--- a/sheeprl/utils/model.py
+++ b/sheeprl/utils/model.py
@@ -221,32 +221,3 @@ def cnn_forward(
     flatten_input = input.reshape(-1, *input_dim)
     model_out = model(flatten_input)
     return model_out.reshape(*batch_shapes, *output_dim)
-
-
-class LayerNormChannelLast(nn.LayerNorm):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if x.dim() != 4:
-            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
-        x = x.permute(0, 2, 3, 1)
-        x = super().forward(x)
-        x = x.permute(0, 3, 1, 2)
-        return x
-
-
-class LayerNormChannelLastFP32(LayerNormChannelLast):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
-
-
-class LayerNormFP32(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py
index a15b6834..511a4691 100644
--- a/tests/test_algos/test_algos.py
+++ b/tests/test_algos/test_algos.py
@@ -474,8 +474,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
         "algo.cnn_keys.decoder=[rgb]",
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]

     with mock.patch.object(sys, "argv", args):
@@ -513,8 +513,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]

     with mock.patch.object(sys, "argv", args):
@@ -557,8 +557,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.cnn_keys.decoder=[rgb]",
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]

     with mock.patch.object(sys, "argv", args):
         run()