diff --git a/train_torchvision.py b/train_torchvision.py index afdbb9a..5362bda 100644 --- a/train_torchvision.py +++ b/train_torchvision.py @@ -104,7 +104,7 @@ def load_model(weight_path: Path, model: models.efficientnet.EfficientNet) -> mo """ loads all parameters of a model """ - checkpoint = torch.load(weight_path, map_location='cpu') + checkpoint = torch.load(weight_path) model.load_state_dict(checkpoint['state_dict']) return model