Skip to content

Commit

Permalink
bfloat16 support (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinwangx authored Aug 14, 2024
1 parent c6394f8 commit 9bf8050
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions repe/rep_reading_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def _get_hidden_states(
hidden_states_layers = {}
for layer in hidden_layers:
hidden_states = outputs['hidden_states'][layer]
hidden_states = hidden_states[:, rep_token, :]
# hidden_states_layers[layer] = hidden_states.cpu().to(dtype=torch.float32).detach().numpy()
hidden_states = hidden_states[:, rep_token, :].detach()
if hidden_states.dtype == torch.bfloat16:
hidden_states = hidden_states.float()
hidden_states_layers[layer] = hidden_states.detach()

return hidden_states_layers
Expand Down

0 comments on commit 9bf8050

Please sign in to comment.