Skip to content

Commit

Permalink
fix cuda bug in unit test part 3 oops
Browse files Browse the repository at this point in the history
  • Loading branch information
kristian-georgiev committed Nov 2, 2023
1 parent be57466 commit d9b292c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_custom_model_output(tmp_path, cpu_proj):


def test_grad_wrt_last_layer(tmp_path):
model = resnet18().cuda().eval()
model = resnet18().eval()
N = 5
batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,))
traker = TRAKer(
Expand All @@ -400,7 +400,7 @@ def test_grad_wrt_last_layer(tmp_path):

@pytest.mark.cuda
def test_grad_wrt_last_layer_cuda(tmp_path):
model = resnet18().eval()
model = resnet18().cuda().eval()
N = 5
batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda()
traker = TRAKer(
Expand Down

0 comments on commit d9b292c

Please sign in to comment.