diff --git a/lightly/cli/_helpers.py b/lightly/cli/_helpers.py
index c2bc6763d..b4b182946 100644
--- a/lightly/cli/_helpers.py
+++ b/lightly/cli/_helpers.py
@@ -2,7 +2,10 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved
 
+from __future__ import annotations
+
 import os
+from typing import Any
 
 import hydra
 import torch
@@ -17,7 +20,7 @@ from lightly.utils.version_compare import version_compare
 
 
-def cpu_count():
+def cpu_count() -> int | None:
     """Returns the number of CPUs which are present in the system.
 
     This number is not equivalent to the number of available CPUs to the process.
 
@@ -26,21 +29,26 @@ def cpu_count():
     return os.cpu_count()
 
 
-def fix_input_path(path):
+def fix_input_path(path: str) -> str:
     """Fix broken relative paths."""
     if not os.path.isabs(path):
         path = utils.to_absolute_path(path)
     return path
 
 
-def fix_hydra_arguments(config_path: str = "config", config_name: str = "config"):
+def fix_hydra_arguments(
+    config_path: str = "config", config_name: str = "config"
+) -> dict[str, str | None]:
     """Helper to make hydra arugments adaptive to installed hydra version
 
     Hydra introduced the `version_base` argument in version 1.2.0
     We use this helper to provide backwards compatibility to older hydra verisons.
 
     """
-    hydra_args = {"config_path": config_path, "config_name": config_name}
+    hydra_args: dict[str, str | None] = {
+        "config_path": config_path,
+        "config_name": config_name,
+    }
 
     try:
         if version_compare(hydra.__version__, "1.2.0") >= 0:
@@ -53,13 +61,13 @@ def fix_hydra_arguments(config_path: str = "config", config_name: str = "config"
     return hydra_args
 
 
-def is_url(checkpoint):
+def is_url(checkpoint: str) -> bool:
     """Check whether the checkpoint is a url or not."""
     is_url = "https://storage.googleapis.com" in checkpoint
     return is_url
 
 
-def get_ptmodel_from_config(model):
+def get_ptmodel_from_config(model: dict[str, Any]) -> tuple[str, str]:
     """Get a pre-trained model from the lightly model zoo."""
     key = model["name"]
     key += "/simclr"
@@ -72,10 +80,14 @@ def get_ptmodel_from_config(model):
     return "", key
 
 
-def load_state_dict_from_url(url, map_location=None):
+def load_state_dict_from_url(
+    url: str, map_location: torch.device | None = None
+) -> dict[str, torch.Tensor | None]:
     """Try to load the checkopint from the given url."""
     try:
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location)
+        state_dict: dict[str, torch.Tensor] = torch.hub.load_state_dict_from_url(
+            url, map_location=map_location
+        )
         return state_dict
     except Exception:
         print("Not able to load state dict from %s" % (url))
@@ -89,10 +101,15 @@ def load_state_dict_from_url(url, map_location=None):
 
         # in this case downloading the pre-trained model was not possible
        # notify the user and return
+        return {"state_dict": None}
 
 
-def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits):
+def _maybe_expand_batchnorm_weights(
+    model_dict: dict[str, torch.Tensor],
+    state_dict: dict[str, torch.Tensor],
+    num_splits: int,
+) -> dict[str, torch.Tensor]:
     """Expands the weights of the BatchNorm2d to the size of SplitBatchNorm."""
     running_mean = "running_mean"
     running_var = "running_var"
 
@@ -116,7 +133,9 @@ def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits):
     return state_dict
 
 
-def _filter_state_dict(state_dict, remove_model_prefix_offset: int = 1):
+def _filter_state_dict(
+    state_dict: dict[str, torch.Tensor], remove_model_prefix_offset: int = 1
+) -> dict[str, torch.Tensor]:
     """Makes the state_dict compatible with the model.
 
     Prevents unexpected key error when loading PyTorch-Lightning checkpoints.
@@ -141,7 +160,9 @@ def _filter_state_dict(state_dict, remove_model_prefix_offset: int = 1):
     return new_state_dict
 
 
-def _fix_projection_head_keys(state_dict):
+def _fix_projection_head_keys(
+    state_dict: dict[str, torch.Tensor]
+) -> dict[str, torch.Tensor]:
     """Makes the state_dict compatible with the refactored projection heads.
 
     TODO: Remove once the models are refactored and the old checkpoints were
@@ -173,12 +194,12 @@
 
 
 def load_from_state_dict(
-    model,
-    state_dict,
+    model: nn.Module,
+    state_dict: dict[str, torch.Tensor],
     strict: bool = True,
     apply_filter: bool = True,
     num_splits: int = 0,
-):
+) -> None:
     """Loads the model weights from the state dictionary."""
 
     # step 1: filter state dict
@@ -196,7 +217,9 @@ def load_from_state_dict(
     model.load_state_dict(state_dict, strict=strict)
 
 
-def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbedding:
+def get_model_from_config(
+    cfg: dict[str, Any], is_cli_call: bool = False
+) -> SelfSupervisedEmbedding:
     checkpoint = cfg["checkpoint"]
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -233,5 +256,5 @@ def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbed
     if state_dict is not None:
         load_from_state_dict(model, state_dict)
 
-    encoder = SelfSupervisedEmbedding(model, None, None, None)
+    encoder = SelfSupervisedEmbedding(model, None, None, None)  # type: ignore
     return encoder
diff --git a/pyproject.toml b/pyproject.toml
index 405d16f07..1b6d634fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -191,7 +191,6 @@ exclude = '''(?x)(
     lightly/cli/config/get_config.py |
     lightly/cli/train_cli.py |
     lightly/cli/_cli_simclr.py |
-    lightly/cli/_helpers.py |
     lightly/loss/ntx_ent_loss.py |
     lightly/loss/vicreg_loss.py |
     lightly/loss/tico_loss.py |
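A note on the first hunk: the new annotations use PEP 604 unions (`int | None`, runtime-valid only on Python >= 3.10) and PEP 585 builtin generics (`dict[...]`, runtime-valid on >= 3.9). The added `from __future__ import annotations` (PEP 563) keeps the module importable on the older interpreters lightly supports, because annotations are then stored as strings instead of being evaluated. A minimal standalone illustration, not part of the patch:

    from __future__ import annotations  # PEP 563: annotations become lazy strings

    import os


    def cpu_count() -> int | None:
        # Without the __future__ import, evaluating `int | None` at function
        # definition time raises TypeError on Python < 3.10.
        return os.cpu_count()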
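The new return type of `fix_hydra_arguments` also hints at how the helper is consumed: the dict is unpacked into `hydra.main`, and the version check adds a `version_base` key (presumably `None`, matching the `str | None` value type; its assignment falls between the hunks) only when hydra >= 1.2.0 is installed, so older hydra never sees an argument it does not know. A usage sketch under that assumption; `my_cli` and the config names are placeholders rather than lightly's actual entry points:

    import hydra
    from omegaconf import DictConfig

    from lightly.cli._helpers import fix_hydra_arguments


    # On hydra >= 1.2.0 the unpacked dict also passes version_base; on older
    # hydra it passes only config_path and config_name.
    @hydra.main(**fix_hydra_arguments(config_path="config", config_name="config"))
    def my_cli(cfg: DictConfig) -> None:
        print(cfg)


    if __name__ == "__main__":
        my_cli()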