Fix warning from torch.load starting in torch 2.4 #1064
Conversation
See discussion in #1063.

Starting from PyTorch 2.4, there is a warning when `torch.load` is called without setting the `weights_only` argument. This is because in the future, the default will switch from `False` to `True`, which can result in a lot of errors when trying to load torch files (which are pickle files and thus insecure).

In this PR, we give the user the ability to influence the kwargs passed to `torch.load` so that they can control that behavior. If not indicated otherwise by the user, we use the same defaults as the installed torch version. Therefore, users will only encounter this issue via skorch if they would have encountered it via torch anyway.

Since it's not 100% certain that the default will switch in torch 2.6.0, we may have to adjust the version check in the future.

Besides directly testing the kwargs being passed on, a test was also added that `net.load_params` does not give any warnings. This is already indirectly tested through some accelerate tests that are currently failing with torch 2.4, but it's better to have an explicit test.

After this is merged, the CI should pass when using torch 2.4.0.
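For reference, a minimal sketch of the version-dependent default described above (assuming the `get_default_torch_load_kwargs` name suggested in the review below and the tentative 2.6.0 switch; the merged implementation may differ in detail):

```python
# Sketch only: mirrors the behavior described in this PR, not necessarily
# the exact implementation that was merged.
from packaging.version import parse as parse_version

import torch


def get_default_torch_load_kwargs():
    """Return the torch.load kwargs matching the installed torch's defaults."""
    version_torch = parse_version(torch.__version__)
    # 2.6.0 is the tentatively announced switch version; may need adjusting
    version_default_switch = parse_version('2.6.0')
    if version_torch >= version_default_switch:
        return {'weights_only': True}
    return {'weights_only': False}
```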
skorch/utils.py (Outdated)

```diff
@@ -768,3 +769,16 @@ def _check_f_arguments(caller_name, **kwargs):
         key = 'module_' if key == 'f_params' else key[2:] + '_'
         kwargs_module[key] = val
     return kwargs_module, kwargs_other
+
+
+def check_torch_weights_only_default_true():
```
Given how specific this function is to `torch.load`, can this return `torch_load_kwargs` itself?
Good point, I made the suggested change.
skorch/utils.py (Outdated)

```python
def get_torch_load_kwargs():
    """Returns the kwargs passed to torch.load the correspond to the current
```
"""Returns the kwargs passed to torch.load the correspond to the current | |
"""Returns the kwargs passed to torch.load that correspond to the current |
skorch/utils.py (Outdated)

```diff
@@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
         key = 'module_' if key == 'f_params' else key[2:] + '_'
         kwargs_module[key] = val
     return kwargs_module, kwargs_other
+
+
+def get_torch_load_kwargs():
```
Suggested change:

```diff
-def get_torch_load_kwargs():
+def get_default_torch_load_kwargs():
```
skorch/net.py (Outdated)

```diff
@@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):

             return state_dict
         else:
             torch_load_kwargs = self.torch_load_kwargs
             if torch_load_kwargs is None:
                 torch_load_kwargs = get_torch_load_kwargs()
```
Suggested change:

```diff
-                torch_load_kwargs = get_torch_load_kwargs()
+                torch_load_kwargs = get_default_torch_load_kwargs()
```
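To illustrate the user-facing side, a hypothetical usage sketch (the module and file name are placeholders; `torch_load_kwargs` is the parameter discussed in this PR):

```python
# Hypothetical usage sketch; MyModule and 'model.pkl' are made up.
import torch
from torch import nn
from skorch import NeuralNetClassifier


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)


net = NeuralNetClassifier(
    MyModule,
    # explicitly opt in to weights_only=True regardless of the installed
    # torch version; omitting torch_load_kwargs falls back to torch's default
    torch_load_kwargs={'weights_only': True},
)
net.initialize()
net.save_params(f_params='model.pkl')  # saves the module's state dict
net.load_params(f_params='model.pkl')  # kwargs are forwarded to torch.load
```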
Instead, rely on the installed torch version and skip if it doesn't fit.
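A minimal sketch of that suggestion, assuming pytest and the torch 2.4 warning threshold described above (the fixture and test names are made up):

```python
# Sketch only: skip the test unless the installed torch version is one
# where torch.load warns about weights_only (2.4 and later).
import warnings

import pytest
import torch
from packaging.version import parse as parse_version


@pytest.mark.skipif(
    parse_version(torch.__version__) < parse_version('2.4.0'),
    reason="torch.load only warns about weights_only from torch 2.4 on",
)
def test_load_params_gives_no_warning(net_fitted, tmp_path):
    # net_fitted is an assumed fixture providing a fitted skorch net
    path = str(tmp_path / 'params.pkl')
    net_fitted.save_params(f_params=path)
    with warnings.catch_warnings():
        warnings.simplefilter('error')  # any warning fails the test
        net_fitted.load_params(f_params=path)
```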
CI is failing for unrelated reasons since the latest accelerate release; I opened an issue about it.
Quick question about the (unrelated) failing CI: are the CI and integration tests run on multi-GPU environments at all?
No, we're only using the free runners from GitHub on this repo. Is there anything that we should check specifically on GPU?
Not sure. I think the only way GPU training would affect pickling is on distributed setups. I'm actually not sure how reliable pickling a running distributed accelerator is (e.g. there are a LOT of stackoverflow or forum posts about running into issues with pickling generators or in a multiprocessing context).
If such a setting causes trouble, it's probably not just because of …
@ottonemo have your points been addressed?