Skip to content

Commit

Permalink
functionality to use pretrained pie model (#168)
Browse files Browse the repository at this point in the history
* functionality to use pretrained pie model

* fix NoneType error

* add documentation to src/train.py

---------

Co-authored-by: ArneBinder <[email protected]>
  • Loading branch information
Bhuvanesh-Verma and ArneBinder authored Jul 19, 2024
1 parent fc55427 commit c9ac6f4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 3 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ monitor_mode: "max"
# seed for random number generators in pytorch, numpy and python.random
seed: null

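# path to a pretrained pytorch-ie (PIE) model; if set, its weights are loaded into the model right after instantiation.
# Optionally, pretrained_pie_model_prefix_mapping (source prefix -> target prefix) restricts loading to parameters whose
# names start with a source prefix and renames them to the target prefix (weights are then loaded non-strictly).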
pretrained_pie_model_path: null

# simply provide checkpoint path to resume training
ckpt_path: null

Expand Down
24 changes: 23 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
import pytorch_lightning as pl
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pie_modules.models import * # noqa: F403
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pie_modules.taskmodules import * # noqa: F403
from pytorch_ie import AutoModel
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
Expand Down Expand Up @@ -140,9 +143,28 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:

# initialize the model
model: PyTorchIEModel = hydra.utils.instantiate(
cfg.model, _convert_="partial", **additional_model_kwargs
cfg.model,
_convert_="partial",
# In the case of loading weights from a pretrained PIE model, we do not need to download the base (transformer) model in the model constructors. We disable that by passing is_from_pretrained=True in these cases.
is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None,
**additional_model_kwargs,
)

if cfg.get("pretrained_pie_model_path", None) is not None:
pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"])
loaded_state_dict = pie_model.state_dict()
has_prefix_mapping = cfg.get("pretrained_pie_model_prefix_mapping", None) is not None
if has_prefix_mapping:
state_dict_to_load = {}
for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items():
for name, value in loaded_state_dict.items():
if name.startswith(prefix_from):
new_name = prefix_to + name[len(prefix_from) :]
state_dict_to_load[new_name] = value
else:
state_dict_to_load = loaded_state_dict
model.load_state_dict(state_dict_to_load, strict=not has_prefix_mapping)

log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")

Expand Down

0 comments on commit c9ac6f4

Please sign in to comment.