Skip to content

Commit

Permalink
remove loading model on CPU by default
Browse files Browse the repository at this point in the history
  • Loading branch information
d-f authored Mar 18, 2024
1 parent 71e5d10 commit a3bc1ba
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def load_model(weight_path: Path, model: models.efficientnet.EfficientNet) -> mo
"""
loads all parameters of a model
"""
checkpoint = torch.load(weight_path, map_location='cpu')
checkpoint = torch.load(weight_path)
model.load_state_dict(checkpoint['state_dict'])
return model

Expand Down

0 comments on commit a3bc1ba

Please sign in to comment.