diff --git a/chai_lab/data/dataset/embeddings/esm.py b/chai_lab/data/dataset/embeddings/esm.py index 22c2632..5c565ed 100644 --- a/chai_lab/data/dataset/embeddings/esm.py +++ b/chai_lab/data/dataset/embeddings/esm.py @@ -7,6 +7,7 @@ from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext from chai_lab.data.dataset.structure.chain import Chain from chai_lab.data.parsing.structure.entity_type import EntityType +from chai_lab.utils.paths import downloads_path from chai_lab.utils.tensor_utils import move_data_to_device from chai_lab.utils.typing import typecheck @@ -19,6 +20,8 @@ # Did not find a way to filter specifically that logging message :/ tr_logging.set_verbosity_error() +esm_cache_folder = downloads_path.joinpath("esm") + @contextmanager def esm_model(model_name: str, device): @@ -27,7 +30,9 @@ def esm_model(model_name: str, device): if len(_esm_model) == 0: # lazy loading of the model - _esm_model.append(EsmModel.from_pretrained(model_name)) + _esm_model.append( + EsmModel.from_pretrained(model_name, cache_dir=esm_cache_folder) + ) [model] = _esm_model model.to(device) @@ -46,7 +51,7 @@ def _get_esm_contexts_for_sequences( from transformers import EsmTokenizer model_name = "facebook/esm2_t36_3B_UR50D" - tokenizer = EsmTokenizer.from_pretrained(model_name) + tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=esm_cache_folder) seq2embedding_context = {} diff --git a/chai_lab/utils/paths.py b/chai_lab/utils/paths.py index d816cd6..72e75e0 100644 --- a/chai_lab/utils/paths.py +++ b/chai_lab/utils/paths.py @@ -1,4 +1,5 @@ import dataclasses +import os from pathlib import Path from typing import Final @@ -8,6 +9,12 @@ # of anything within repository repo_root: Final[Path] = Path(__file__).parents[2].absolute() +# weights and helper data is downloaded to CHAI_DOWNLOADS_DIR if provided. +# otherwise we use /downloads, which is gitignored by default +downloads_path = repo_root.joinpath("downloads") +downloads_path = Path(os.environ.get("CHAI_DOWNLOADS_DIR", downloads_path)) + + # minimal sanity check in case we start moving things around assert repo_root.exists() @@ -44,7 +51,7 @@ def get_path(self) -> Path: cached_conformers = Downloadable( url="https://chaiassets.com/chai1-inference-depencencies/conformers.apkl", - path=repo_root.joinpath("downloads", "conformers.apkl"), + path=downloads_path.joinpath("conformers.apkl"), ) @@ -55,7 +62,7 @@ def chai1_component(comp_key: str) -> Path: """ assert comp_key.endswith(".pt2") url = f"https://chaiassets.com/chai1-inference-depencencies/models/{comp_key}" - result = repo_root.joinpath("downloads", "models", comp_key) + result = downloads_path.joinpath("models", comp_key) if not result.exists(): download(url, result)