diff --git a/deepmatcher/models/core.py b/deepmatcher/models/core.py index c762166..c36e8a8 100644 --- a/deepmatcher/models/core.py +++ b/deepmatcher/models/core.py @@ -458,13 +458,13 @@ def save_state(self, path, include_meta=True): state[k] = getattr(self, k) torch.save(state, path, pickle_module=dill) - def load_state(self, path): + def load_state(self, path, map_location=None): r"""Load the model state from a file in a certain path. Args: path (string): The path to load the model state from. """ - state = torch.load(path, pickle_module=dill) + state = torch.load(path, pickle_module=dill, map_location=map_location) for k, v in six.iteritems(state): if k != 'model': self._train_buffers.add(k)