diff --git a/benchmark.py b/benchmark.py index c991bb5..83a8ec6 100644 --- a/benchmark.py +++ b/benchmark.py @@ -125,8 +125,8 @@ def meta_forward( # Define parameters wq, wk, wv, wo = attention.q_projector, attention.k_projector, attention.v_projector, attention.output_projector - cache_k = attention.k_cache.sequence_cache.transpose(1, 2).detach().clone() - cache_v = attention.v_cache.sequence_cache.transpose(1, 2).detach().clone() + cache_k = attention.k_cache.sequence_cache.detach().clone() + cache_v = attention.v_cache.sequence_cache.detach().clone() n_local_heads = attention.number_of_heads n_local_kv_heads = attention.number_of_kv_heads