Skip to content

Commit

Permalink
Updated toy model configuration
Browse files: browse the repository at this point in the history
  • Loading branch information
Benjamin-Walker committed Nov 9, 2024
1 parent bfec069 commit dca27f5
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions torch_sequence_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -427,10 +427,11 @@ def __init__(
self.nonlinear = nonlinear
if activation == "GLU":
self.activation = nn.GLU()
elif activation == "ReLU":
self.activation = nn.ReLU()
elif activation == "GELU":
self.activation = nn.GELU()
else:
raise ValueError("Invalid activation function")
self.act_name = activation

self.ssms = nn.ModuleList()
self.linear_mixes = nn.ModuleList()
Expand Down Expand Up @@ -480,19 +481,24 @@ def forward(self, x):
x = ssm(x)
if self.dropout is not None:
x = self.dropout(x)
x = linear_mix(x)
if self.nonlinear:
x = self.activation(x)
if self.act_name == "GELU":
if self.nonlinear:
x = self.activation(x)
x = linear_mix(x)
if self.act_name == "GLU":
x = linear_mix(x)
if self.nonlinear:
x = self.activation(x)
if self.dropout is not None:
x = self.dropout(x)
x = x + residual
if self.use_layernorm:
x = self.layernorms[i](x)
residual = x

x = self.linear_out(x)
if not self.continuous_output:
x = x.mean(dim=1)
x = self.linear_out(x)
return x


Expand Down Expand Up @@ -662,7 +668,7 @@ def run_sm_toy_experiment(
depth,
model_name,
nonlinear,
activation="ReLU",
activation="GELU",
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Expand Down

0 comments on commit dca27f5

Please sign in to comment.