Skip to content

Commit

Permalink
add shared_codebooks
Browse files — browse the repository at this point in the history
  • Loading branch information
a-kore committed Jun 22, 2023
1 parent 479ae5a commit bca4dee
Show file tree
Hide file tree
Showing 7 changed files with 1,517 additions and 399 deletions.
4 changes: 1 addition & 3 deletions rpq/models/rpqopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,6 @@ def forward(
hidden_states = self.self_attn_layer_norm(hidden_states)

# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states

# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
Expand All @@ -355,7 +353,7 @@ def forward(
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

hidden_states = (residual + hidden_states).view(hidden_states_shape)
hidden_states = (residual + hidden_states)

# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
Expand Down
Loading

0 comments on commit bca4dee

Please sign in to comment.