diff --git a/tests/test_model.py b/tests/test_model.py index 2f7957a..a1397f2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -75,7 +75,6 @@ def test_attention(fairscale_init): iattention.v_projector.weight.copy_(attention.wv.weight) iattention.output_projector.weight.copy_(attention.wo.weight) - iattention.cuda() iattention.eval() output2 = iattention(x, freqs_cis, 0) @@ -163,7 +162,6 @@ def test_decoder(): transformer_block.attention.cache_k.random_() transformer_block.attention.cache_v.random_() - transformer_block transformer_block.eval() itransformer_block.attention.k_cache.sequence_cache.copy_(transformer_block.attention.cache_k.transpose(1, 2)) @@ -177,8 +175,6 @@ def test_decoder(): itransformer_block.ffn.output_layer.weight.copy_(transformer_block.feed_forward.w2.weight) itransformer_block.ffn.gate_layer.weight.copy_(transformer_block.feed_forward.w3.weight) - itransformer_block - x = torch.randn(2, 2048, 4096) freqs_cis = precompute_freqs_cis(4096, 2048 * 2, 500000)