improve hyperparameter logging #153

Merged
merged 1 commit on Jan 18, 2024
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 # --------- pytorch-ie --------- #
 pytorch-ie>=0.28.0,<0.30.0
 pie-datasets>=0.8.1,<0.9.0
-pie-modules>=0.8.0,<0.9.0
+pie-modules>=0.9.0,<0.10.0
 
 # --------- hydra --------- #
 hydra-core>=1.3.0
src/evaluate.py (2 changes: 1 addition & 1 deletion)
@@ -107,7 +107,7 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
 
     if logger:
         log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
 
     log.info("Starting testing!")
     trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
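The helper now takes explicit keyword arguments instead of the old catch-all `object_dict`, and it no longer needs the trainer because the logger list is passed in directly (see the reworked `src/utils/logging_utils.py` further down). A minimal usage sketch, assuming the helper is importable as `src.utils.log_hyperparameters` (the scripts here call it as `utils.log_hyperparameters`) and using a `CSVLogger` as a stand-in logger:

```python
from omegaconf import OmegaConf
from pytorch_lightning.loggers import CSVLogger

from src.utils import log_hyperparameters  # assumed import path within this repository

# A toy config standing in for the Hydra config `cfg`.
cfg = OmegaConf.create({"seed": 42, "trainer": {"max_epochs": 3}})
loggers = [CSVLogger(save_dir="logs", name="hparam_demo")]

# model/taskmodule are optional now; without them only the config and the extra
# kwargs are logged. The DictConfig is stored under "_config", the extra keyword
# under "_best_checkpoint" (the default key_prefix is "_").
log_hyperparameters(
    logger=loggers,
    config=cfg,
    best_checkpoint="epoch_007.ckpt",  # illustrative extra kwarg
)
```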
src/train.py (16 changes: 13 additions & 3 deletions)
@@ -33,13 +33,14 @@
 # https://github.com/ashleve/pyrootutils
 # ------------------------------------------------------------------------------------ #
 
+import os
 from typing import Any, Dict, List, Optional, Tuple
 
 import hydra
 import pytorch_lightning as pl
-from hydra.utils import get_class
 from omegaconf import DictConfig
 from pie_datasets import DatasetDict
+from pie_modules.models.interface import RequiresTaskmoduleConfig
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import * # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -121,7 +122,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     log.info(f"Instantiating model <{cfg.model._target_}>")
     # get additional model arguments
     additional_model_kwargs: Dict[str, Any] = {}
-    model_cls = get_class(cfg.model["_target_"])
+    model_cls = hydra.utils.get_class(cfg.model["_target_"])
     # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
     # SEE EXAMPLES BELOW.
     if issubclass(model_cls, RequiresNumClasses):
@@ -134,6 +135,9 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     if isinstance(taskmodule, ChangesTokenizerVocabSize):
         additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
 
+    if issubclass(model_cls, RequiresTaskmoduleConfig):
+        additional_model_kwargs["taskmodule_config"] = taskmodule.config
+
     # initialize the model
     model: PyTorchIEModel = hydra.utils.instantiate(
         cfg.model, _convert_="partial", **additional_model_kwargs
@@ -160,7 +164,7 @@
 
     if logger:
         log.info("Logging hyperparameters!")
-        utils.log_hyperparameters(object_dict)
+        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
 
     if cfg.model_save_dir is not None:
         log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
@@ -177,6 +181,12 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     best_ckpt_path = trainer.checkpoint_callback.best_model_path
     if best_ckpt_path != "":
         log.info(f"Best ckpt path: {best_ckpt_path}")
+        best_checkpoint_file = os.path.basename(best_ckpt_path)
+        utils.log_hyperparameters(
+            logger=logger,
+            best_checkpoint=best_checkpoint_file,
+            checkpoint_dir=trainer.checkpoint_callback.dirpath,
+        )
 
     if not cfg.trainer.get("fast_dev_run"):
         if cfg.model_save_dir is not None:
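The new `RequiresTaskmoduleConfig` branch extends the existing `additional_model_kwargs` dispatch in `train.py`: a model class that declares the requirement gets the prepared taskmodule's config passed to its constructor. Below is a self-contained sketch of the pattern using hypothetical stand-in classes (not the real pie-modules / pytorch-ie types):

```python
from typing import Any, Dict


class RequiresTaskmoduleConfig:
    """Hypothetical stand-in for pie_modules.models.interface.RequiresTaskmoduleConfig:
    a marker meaning 'pass me the taskmodule config at construction time'."""


class DummyTaskmodule:
    """Hypothetical stand-in; the real taskmodule exposes its config via `.config`."""

    config = {"taskmodule_type": "DummyTaskmodule", "labels": ["PER", "ORG"]}


class DummyModel(RequiresTaskmoduleConfig):
    """Hypothetical model that opts in to receiving the taskmodule config."""

    def __init__(self, taskmodule_config: Dict[str, Any]):
        self.taskmodule_config = taskmodule_config


taskmodule = DummyTaskmodule()
model_cls = DummyModel

# Same dispatch as in train.py: only model classes that declare the requirement
# receive the extra keyword argument.
additional_model_kwargs: Dict[str, Any] = {}
if issubclass(model_cls, RequiresTaskmoduleConfig):
    additional_model_kwargs["taskmodule_config"] = taskmodule.config

model = model_cls(**additional_model_kwargs)
print(model.taskmodule_config)  # {'taskmodule_type': 'DummyTaskmodule', 'labels': ['PER', 'ORG']}
```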
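The second change in `train.py` logs the best checkpoint after training through the helper's `**kwargs` pass-through. Because free-form keyword arguments are prefixed with `key_prefix` (default `"_"`), the values reach the logger as `_best_checkpoint` and `_checkpoint_dir`. A small sketch of that key construction (the paths are made up):

```python
import os

# Hypothetical values standing in for trainer.checkpoint_callback attributes.
best_ckpt_path = "logs/runs/2024-01-18/checkpoints/epoch_007.ckpt"
checkpoint_dir = os.path.dirname(best_ckpt_path)

key_prefix = "_"  # default of utils.log_hyperparameters
hparams = {
    f"{key_prefix}best_checkpoint": os.path.basename(best_ckpt_path),
    f"{key_prefix}checkpoint_dir": checkpoint_dir,
}
print(hparams)
# {'_best_checkpoint': 'epoch_007.ckpt', '_checkpoint_dir': 'logs/runs/2024-01-18/checkpoints'}
```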
src/utils/logging_utils.py (82 changes: 44 additions & 38 deletions)
@@ -1,6 +1,11 @@
 import logging
 from importlib.util import find_spec
+from typing import List, Optional, Union
 
+from omegaconf import DictConfig, OmegaConf
+from pie_modules.models.interface import RequiresTaskmoduleConfig
+from pytorch_ie import PyTorchIEModel, TaskModule
+from pytorch_lightning.loggers import Logger
 from pytorch_lightning.utilities import rank_zero_only
 
 
@@ -22,7 +27,14 @@ def get_pylogger(name=__name__) -> logging.Logger:
 
 
 @rank_zero_only
-def log_hyperparameters(object_dict: dict) -> None:
+def log_hyperparameters(
+    logger: Optional[List[Logger]] = None,
+    config: Optional[Union[dict, DictConfig]] = None,
+    model: Optional[PyTorchIEModel] = None,
+    taskmodule: Optional[TaskModule] = None,
+    key_prefix: str = "_",
+    **kwargs,
+) -> None:
     """Controls which config parts are saved by lightning loggers.
 
     Additional saves:
@@ -31,48 +43,42 @@ def log_hyperparameters(object_dict: dict) -> None:
 
     hparams = {}
 
-    cfg = object_dict["cfg"]
-    model = object_dict["model"]
-    taskmodule = object_dict["taskmodule"]
-    trainer = object_dict["trainer"]
-
-    if not trainer.logger:
+    if not logger:
         log.warning("Logger not found! Skipping hyperparameter logging...")
         return
 
-    # choose which parts of hydra config will be saved to loggers
-    # here we use the taskmodule/model config how it is after preparation/initialization
-    hparams["taskmodule"] = taskmodule._config()
-    hparams["model"] = model._config()
-
-    # save number of model parameters
-    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(
-        p.numel() for p in model.parameters() if p.requires_grad
-    )
-    hparams["model/params/non_trainable"] = sum(
-        p.numel() for p in model.parameters() if not p.requires_grad
-    )
-
-    hparams["dataset"] = cfg["dataset"]
-    hparams["trainer"] = cfg["trainer"]
-
-    hparams["callbacks"] = cfg.get("callbacks")
-    hparams["extras"] = cfg.get("extras")
-
-    hparams["pipeline_type"] = cfg.get("pipeline_type")
-    hparams["tags"] = cfg.get("tags")
-    hparams["ckpt_path"] = cfg.get("ckpt_path")
-    hparams["seed"] = cfg.get("seed")
-
-    hparams["monitor_metric"] = cfg.get("monitor_metric")
-    hparams["monitor_mode"] = cfg.get("monitor_mode")
-
-    hparams["model_save_dir"] = cfg.get("model_save_dir")
+    # this is just for backwards compatibility: usually, the taskmodule_config should be passed to
+    # the model and, thus, be logged there automatically
+    if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
+        if taskmodule is None:
+            raise ValueError(
+                "If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!"
+            )
+        # here we use the taskmodule/model config how it is after preparation/initialization
+        hparams["taskmodule_config"] = taskmodule.config
+
+    if model is not None:
+        # save number of model parameters
+        hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
+        hparams[f"{key_prefix}num_params/trainable"] = sum(
+            p.numel() for p in model.parameters() if p.requires_grad
+        )
+        hparams[f"{key_prefix}num_params/non_trainable"] = sum(
+            p.numel() for p in model.parameters() if not p.requires_grad
+        )
+
+    if config is not None:
+        hparams[f"{key_prefix}config"] = (
+            OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
+        )
+
+    # add additional hparams
+    for k, v in kwargs.items():
+        hparams[f"{key_prefix}{k}"] = v
 
     # send hparams to all loggers
-    for logger in trainer.loggers:
-        logger.log_hyperparams(hparams)
+    for current_logger in logger:
+        current_logger.log_hyperparams(hparams)
 
 
 def close_loggers() -> None:
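One detail of the reworked helper: the whole config is materialised with `OmegaConf.to_container(config, resolve=True)`, so Hydra/OmegaConf interpolations are expanded before anything reaches the loggers. A small self-contained illustration (the keys are made up):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "paths": {"output_dir": "/tmp/run1"},
        "checkpoint_dir": "${paths.output_dir}/checkpoints",
    }
)

# resolve=True expands interpolations such as ${paths.output_dir}
print(OmegaConf.to_container(cfg, resolve=True))
# {'paths': {'output_dir': '/tmp/run1'}, 'checkpoint_dir': '/tmp/run1/checkpoints'}
```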