From dca27f52c47091fa31bb46671a7f3961300bff6a Mon Sep 17 00:00:00 2001
From: Benjamin-Walker
Date: Sat, 9 Nov 2024 09:43:16 +0000
Subject: [PATCH] Updated toy model configuration

---
 torch_sequence_models.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/torch_sequence_models.py b/torch_sequence_models.py
index 8ad652f..5e1ea71 100644
--- a/torch_sequence_models.py
+++ b/torch_sequence_models.py
@@ -427,10 +427,11 @@ def __init__(
         self.nonlinear = nonlinear
         if activation == "GLU":
             self.activation = nn.GLU()
-        elif activation == "ReLU":
-            self.activation = nn.ReLU()
+        elif activation == "GELU":
+            self.activation = nn.GELU()
         else:
             raise ValueError("Invalid activation function")
+        self.act_name = activation
 
         self.ssms = nn.ModuleList()
         self.linear_mixes = nn.ModuleList()
@@ -480,9 +481,14 @@ def forward(self, x):
             x = ssm(x)
             if self.dropout is not None:
                 x = self.dropout(x)
-            x = linear_mix(x)
-            if self.nonlinear:
-                x = self.activation(x)
+            if self.act_name == "GELU":
+                if self.nonlinear:
+                    x = self.activation(x)
+                x = linear_mix(x)
+            if self.act_name == "GLU":
+                x = linear_mix(x)
+                if self.nonlinear:
+                    x = self.activation(x)
             if self.dropout is not None:
                 x = self.dropout(x)
             x = x + residual
@@ -490,9 +496,9 @@
                 x = self.layernorms[i](x)
             residual = x
 
-        x = self.linear_out(x)
         if not self.continuous_output:
             x = x.mean(dim=1)
+        x = self.linear_out(x)
 
         return x
 
@@ -662,7 +668,7 @@ def run_sm_toy_experiment(
         depth,
         model_name,
         nonlinear,
-        activation="ReLU",
+        activation="GELU",
     ).to(device)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)