From 9123fa30eb3fb4b32239826fd84eed721c65e92a Mon Sep 17 00:00:00 2001 From: "jorge.decorte@pttrns.ai" Date: Wed, 15 Apr 2020 09:38:46 +0200 Subject: [PATCH] Add map_location to load_state function --- deepmatcher/models/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmatcher/models/core.py b/deepmatcher/models/core.py index c762166..c36e8a8 100644 --- a/deepmatcher/models/core.py +++ b/deepmatcher/models/core.py @@ -458,13 +458,13 @@ def save_state(self, path, include_meta=True): state[k] = getattr(self, k) torch.save(state, path, pickle_module=dill) - def load_state(self, path): + def load_state(self, path, map_location=None): r"""Load the model state from a file in a certain path. Args: path (string): The path to load the model state from. """ - state = torch.load(path, pickle_module=dill) + state = torch.load(path, pickle_module=dill, map_location=map_location) for k, v in six.iteritems(state): if k != 'model': self._train_buffers.add(k)