refactor external model loading for better compatibility w/ DDP
AdeelH committed Jan 4, 2024
1 parent 080428d commit ddac685
Showing 5 changed files with 57 additions and 34 deletions.
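For context, the failure mode being fixed: when a training job runs under DDP (distributed data parallel), every rank used to fetch the external module definition via torch.hub and move it into the same destination directory, so concurrent workers could clobber each other's downloads. The refactor gives each rank its own torch.hub cache and lets only rank 0 copy the module definition to the shared location. A minimal sketch of the idea (the helper name and structure are illustrative, not part of the commit):

import os

import torch


def torch_hub_load_ddp_safe(repo: str, entrypoint: str, ddp_rank: int,
                            **kwargs):
    """Illustrative helper: give each DDP rank its own torch.hub cache so
    concurrent downloads of the same repo cannot conflict.

    torch.hub expands the '~' in TORCH_HOME via os.path.expanduser().
    """
    os.environ['TORCH_HOME'] = f'~/.cache/torch/{ddp_rank}'  # per-rank cache
    return torch.hub.load(repo, entrypoint, skip_validation=True, **kwargs)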
==== File 1 of 5 ====
@@ -1,5 +1,6 @@
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
                     Literal, Optional, Sequence, Tuple, Union)
+import os
 from os.path import join, isdir
 from enum import Enum
 import random
@@ -164,7 +165,10 @@ def check_either_uri_or_repo(cls, values: dict) -> dict:
                 'Must specify one (and only one) of github_repo and uri.')
         return values
 
-    def build(self, save_dir: str, hubconf_dir: Optional[str] = None) -> Any:
+    def build(self,
+              save_dir: str,
+              hubconf_dir: Optional[str] = None,
+              ddp_rank: Optional[int] = None) -> Any:
         """Load an external module via torch.hub.
 
         Note: Loading a PyTorch module is the typical use case, but there are
@@ -188,22 +192,28 @@ def build(self, save_dir: str, hubconf_dir: Optional[str] = None) -> Any:
                 **self.entrypoint_kwargs)
             return module
 
-        hubconf_dir = get_hubconf_dir_from_cfg(self, parent=save_dir)
+        dst_dir = get_hubconf_dir_from_cfg(self, parent=save_dir)
+        if ddp_rank is not None:
+            # avoid conflicts when downloading
+            os.environ['TORCH_HOME'] = f'~/.cache/torch/{ddp_rank}'
+            if ddp_rank != 0:
+                dst_dir = None
 
         if self.github_repo is not None:
             log.info(f'Fetching module definition from: {self.github_repo}')
             module = torch_hub_load_github(
                 repo=self.github_repo,
-                hubconf_dir=hubconf_dir,
                 entrypoint=self.entrypoint,
                 *self.entrypoint_args,
+                dst_dir=dst_dir,
                 **self.entrypoint_kwargs)
         else:
             log.info(f'Fetching module definition from: {self.uri}')
             module = torch_hub_load_uri(
                 uri=self.uri,
-                hubconf_dir=hubconf_dir,
                 entrypoint=self.entrypoint,
                 *self.entrypoint_args,
+                dst_dir=dst_dir,
                 **self.entrypoint_kwargs)
         return module

@@ -253,6 +263,7 @@ def build(self,
               in_channels: int,
               save_dir: Optional[str] = None,
               hubconf_dir: Optional[str] = None,
+              ddp_rank: Optional[int] = None,
               **kwargs) -> nn.Module:
         """Build and return a model based on the config.
@@ -271,7 +282,7 @@ def build(self,
         """
         if self.external_def is not None:
             return self.build_external_model(
-                save_dir=save_dir, hubconf_dir=hubconf_dir)
+                save_dir=save_dir, hubconf_dir=hubconf_dir, ddp_rank=ddp_rank)
         return self.build_default_model(num_classes, in_channels, **kwargs)
 
     def build_default_model(self, num_classes: int, in_channels: int,
@@ -290,7 +301,8 @@ def build_default_model(self, num_classes: int, in_channels: int,
 
     def build_external_model(self,
                              save_dir: str,
-                             hubconf_dir: Optional[str] = None) -> nn.Module:
+                             hubconf_dir: Optional[str] = None,
+                             ddp_rank: Optional[int] = None) -> nn.Module:
         """Build and return an external model.
 
         Args:
@@ -301,7 +313,8 @@ def build_external_model(self,
         Returns:
             A PyTorch nn.Module.
         """
-        return self.external_def.build(save_dir, hubconf_dir=hubconf_dir)
+        return self.external_def.build(
+            save_dir, hubconf_dir=hubconf_dir, ddp_rank=ddp_rank)
 
 
 def solver_config_upgrader(cfg_dict: dict, version: int) -> dict:
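At a call site, the new parameter threads through like this. A hedged usage sketch: the config fields come from the diff above and the repo's tests, but the values and the local_rank variable are illustrative.

cfg = ExternalModuleConfig(
    github_repo='AdeelH/pytorch-multi-class-focal-loss:1.1',
    entrypoint='focal_loss',
    entrypoint_kwargs=dict(alpha=[.75, .25], gamma=2))

# On rank 0, the module definition is fetched and moved into save_dir.
# On other ranks, dst_dir is None, so nothing is copied; each rank loads
# from its own per-rank TORCH_HOME cache instead.
module = cfg.build(save_dir='/opt/data/modules', ddp_rank=local_rank)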
==== File 2 of 5 ====
@@ -34,7 +34,8 @@ def build_model(self, model_def_path: Optional[str] = None) -> 'nn.Module':
             in_channels=cfg.data.img_channels,
             save_dir=self.modules_dir,
             hubconf_dir=model_def_path,
-            img_sz=cfg.data.img_sz)
+            img_sz=cfg.data.img_sz,
+            ddp_rank=self.ddp_local_rank)
         return model
 
     def setup_model(self,
==== File 3 of 5 ====
@@ -39,7 +39,8 @@ def build_model(self, model_def_path: Optional[str] = None) -> 'nn.Module':
             hubconf_dir=model_def_path,
             class_names=class_names,
             pos_class_names=pos_class_names,
-            prob_class_names=prob_class_names)
+            prob_class_names=prob_class_names,
+            ddp_rank=self.ddp_local_rank)
         return model
 
     def on_overfit_start(self):
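Both learner call sites above forward self.ddp_local_rank, i.e. the rank within a single node. Presumably (my reading, not stated in the commit) this means local rank 0 on each node materializes the module definition while every process still gets its own TORCH_HOME, which is what matters when nodes do not share a filesystem. Under torchrun, a local rank is conventionally read from the environment:

import os

# Illustrative only; the Learner tracks this as self.ddp_local_rank.
ddp_local_rank = int(os.environ.get('LOCAL_RANK', 0))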
==== File 4 of 5 ====
@@ -62,16 +62,19 @@ def get_hubconf_dir_from_cfg(cfg, parent: Optional[str] = '') -> str:
     return path
 
 
-def torch_hub_load_github(repo: str, hubconf_dir: str, entrypoint: str, *args,
+def torch_hub_load_github(repo: str,
+                          entrypoint: str,
+                          *args,
+                          dst_dir: Optional[str] = None,
                           **kwargs) -> Any:
     """Load an entrypoint from a github repo using :func:`torch.hub.load`.
 
     Args:
         repo (str): <repo-owner>/<repo-name>[:tag]
-        hubconf_dir (str): Where the contents from the uri will finally
-            be saved to.
         entrypoint (str): Name of a callable present in hubconf.py.
         *args: Args to be passed to the entrypoint.
+        dst_dir: If provided, the contents of the repo are copied there.
+            Defaults to None.
         **kwargs: Keyword args to be passed to the entrypoint.
 
     Returns:
@@ -85,14 +88,18 @@ def torch_hub_load_github(repo: str, hubconf_dir: str, entrypoint: str, *args,
         skip_validation=True,
         **kwargs)
 
-    orig_dir = join(torch.hub.get_dir(), _repo_name_to_dir_name(repo))
-    _remove_dir(hubconf_dir)
-    shutil.move(orig_dir, hubconf_dir)
+    if dst_dir is not None:
+        orig_dir = join(torch.hub.get_dir(), _repo_name_to_dir_name(repo))
+        _remove_dir(dst_dir)
+        shutil.move(orig_dir, dst_dir)
 
     return out
 
 
-def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str, *args,
+def torch_hub_load_uri(uri: str,
+                       entrypoint: str,
+                       *args,
+                       dst_dir: Optional[str] = None,
                        **kwargs) -> Any:
     """Load an entrypoint from a uri.
@@ -103,47 +110,48 @@ def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str, *args,
 
     The zip file should either have hubconf.py at the top level or contain
     a single sub-directory that contains hubconf.py at its top level. In the
-    latter case, the sub-directory will be copied to hubconf_dir.
+    latter case, the sub-directory will be copied to dst_dir.
 
     Args:
         uri (str): A URI.
-        hubconf_dir (str): The target directory where the contents from the
-            uri will finally be saved to.
         entrypoint (str): Name of a callable present in hubconf.py.
         *args: Args to be passed to the entrypoint.
+        dst_dir: If provided, the contents from the uri are copied there.
+            Defaults to None.
         **kwargs: Keyword args to be passed to the entrypoint.
 
     Returns:
         Any: The output from calling the entrypoint.
     """
 
     uri_path = Path(uri)
     is_zip = uri_path.suffix.lower() == '.zip'
     if is_zip:
         # unzip
         zip_path = download_if_needed(uri)
         with get_tmp_dir() as tmp_dir:
             unzip_dir = join(tmp_dir, uri_path.stem)
             _remove_dir(unzip_dir)
             unzip(zip_path, target_dir=unzip_dir)
             unzipped_contents = list(glob(f'{unzip_dir}/*', recursive=False))
 
-            _remove_dir(hubconf_dir)
-
             # if the top level only contains a directory
             if (len(unzipped_contents) == 1) and isdir(unzipped_contents[0]):
                 sub_dir = unzipped_contents[0]
-                shutil.move(sub_dir, hubconf_dir)
+                src_dir = sub_dir
             else:
-                shutil.move(unzip_dir, hubconf_dir)
-    # assume uri is local and attempt copying
+                src_dir = unzip_dir
+
+            out = torch_hub_load_local(src_dir, entrypoint, *args, **kwargs)
+
+            if dst_dir is not None:
+                _remove_dir(dst_dir)
+                shutil.move(src_dir, dst_dir)
     else:
-        # only copy if needed
-        if realpath(uri) != realpath(hubconf_dir):
-            _remove_dir(hubconf_dir)
-            shutil.copytree(uri, hubconf_dir)
+        # assume uri is local
+        out = torch_hub_load_local(uri, entrypoint, *args, **kwargs)
+        if dst_dir is not None and realpath(uri) != realpath(dst_dir):
+            _remove_dir(dst_dir)
+            shutil.copytree(uri, dst_dir)
 
-    out = torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
     return out


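Net effect on the two loaders, mirroring the tests below: dst_dir replaces the previously required hubconf_dir argument and is optional. When dst_dir is given, the fetched module definition is moved (or copied) there; when it is None, the entrypoint is loaded straight from the torch.hub cache or the local uri, and nothing is written elsewhere. A hedged sketch (the destination path is illustrative):

# Rank 0 pins the module definition to a stable location:
loss = torch_hub_load_github(
    repo='AdeelH/pytorch-multi-class-focal-loss:1.1',
    entrypoint='focal_loss',
    dst_dir='/opt/data/modules/focal_loss',
    alpha=[.75, .25],
    gamma=2)

# Other ranks can skip the copy entirely:
loss = torch_hub_load_github(
    repo='AdeelH/pytorch-multi-class-focal-loss:1.1',
    entrypoint='focal_loss',
    dst_dir=None,
    alpha=[.75, .25],
    gamma=2)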
==== File 5 of 5: tests/pytorch_learner/utils/test_torch_hub.py (6 changes: 3 additions & 3 deletions) ====
@@ -72,8 +72,8 @@ def test_torch_hub_load(self):
             hubconf_dir = join(tmp_dir, 'focal_loss')
             loss = torch_hub_load_github(
                 repo='AdeelH/pytorch-multi-class-focal-loss:1.1',
-                hubconf_dir=hubconf_dir,
                 entrypoint='focal_loss',
+                dst_dir=hubconf_dir,
                 alpha=[.75, .25],
                 gamma=2)
             self.assertIsInstance(loss, nn.Module)
@@ -95,8 +95,8 @@ def test_torch_hub_load(self):
             # local, via torch_hub_load_uri
             loss = torch_hub_load_uri(
                 uri=hubconf_dir,
-                hubconf_dir=hubconf_dir,
                 entrypoint='focal_loss',
+                dst_dir=hubconf_dir,
                 alpha=[.75, .25],
                 gamma=2)
             self.assertIsInstance(loss, nn.Module)
@@ -110,8 +110,8 @@ def test_torch_hub_load(self):
             loss = torch_hub_load_uri(
                 uri=
                 'https://github.com/AdeelH/pytorch-multi-class-focal-loss/archive/refs/tags/1.1.zip',  # noqa
-                hubconf_dir=hubconf_dir,
                 entrypoint='focal_loss',
+                dst_dir=hubconf_dir,
                 alpha=[.75, .25],
                 gamma=2)
             self.assertIsInstance(loss, nn.Module)
