From 9bf8050c51d6095d3f7eb394de4495ef697af32e Mon Sep 17 00:00:00 2001 From: Justin Wang <73374902+justinwangx@users.noreply.github.com> Date: Tue, 13 Aug 2024 21:56:42 -0400 Subject: [PATCH] bfloat16 support (#51) --- repe/rep_reading_pipeline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/repe/rep_reading_pipeline.py b/repe/rep_reading_pipeline.py index 437ff0f..16d5f47 100644 --- a/repe/rep_reading_pipeline.py +++ b/repe/rep_reading_pipeline.py @@ -22,8 +22,9 @@ def _get_hidden_states( hidden_states_layers = {} for layer in hidden_layers: hidden_states = outputs['hidden_states'][layer] - hidden_states = hidden_states[:, rep_token, :] - # hidden_states_layers[layer] = hidden_states.cpu().to(dtype=torch.float32).detach().numpy() + hidden_states = hidden_states[:, rep_token, :].detach() + if hidden_states.dtype == torch.bfloat16: + hidden_states = hidden_states.float() hidden_states_layers[layer] = hidden_states.detach() return hidden_states_layers