Skip to content

Commit

Permalink
change default path for esm, provide envvar to control download locat…
Browse files Browse the repository at this point in the history
…ion (#61)
  • Loading branch information
arogozhnikov authored Sep 18, 2024
1 parent bf90332 commit 675cd94
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
9 changes: 7 additions & 2 deletions chai_lab/data/dataset/embeddings/esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.utils.paths import downloads_path
from chai_lab.utils.tensor_utils import move_data_to_device
from chai_lab.utils.typing import typecheck

Expand All @@ -19,6 +20,8 @@
# Did not find a way to filter specifically that logging message :/
tr_logging.set_verbosity_error()

esm_cache_folder = downloads_path.joinpath("esm")


@contextmanager
def esm_model(model_name: str, device):
Expand All @@ -27,7 +30,9 @@ def esm_model(model_name: str, device):

if len(_esm_model) == 0:
# lazy loading of the model
_esm_model.append(EsmModel.from_pretrained(model_name))
_esm_model.append(
EsmModel.from_pretrained(model_name, cache_dir=esm_cache_folder)
)

[model] = _esm_model
model.to(device)
Expand All @@ -46,7 +51,7 @@ def _get_esm_contexts_for_sequences(
from transformers import EsmTokenizer

model_name = "facebook/esm2_t36_3B_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=esm_cache_folder)

seq2embedding_context = {}

Expand Down
11 changes: 9 additions & 2 deletions chai_lab/utils/paths.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import os
from pathlib import Path
from typing import Final

Expand All @@ -8,6 +9,12 @@
# of anything within repository
repo_root: Final[Path] = Path(__file__).parents[2].absolute()

# weights and helper data is downloaded to CHAI_DOWNLOADS_DIR if provided.
# otherwise we use <repo>/downloads, which is gitignored by default
downloads_path = repo_root.joinpath("downloads")
downloads_path = Path(os.environ.get("CHAI_DOWNLOADS_DIR", downloads_path))


# minimal sanity check in case we start moving things around
assert repo_root.exists()

Expand Down Expand Up @@ -44,7 +51,7 @@ def get_path(self) -> Path:

cached_conformers = Downloadable(
url="https://chaiassets.com/chai1-inference-depencencies/conformers.apkl",
path=repo_root.joinpath("downloads", "conformers.apkl"),
path=downloads_path.joinpath("conformers.apkl"),
)


Expand All @@ -55,7 +62,7 @@ def chai1_component(comp_key: str) -> Path:
"""
assert comp_key.endswith(".pt2")
url = f"https://chaiassets.com/chai1-inference-depencencies/models/{comp_key}"
result = repo_root.joinpath("downloads", "models", comp_key)
result = downloads_path.joinpath("models", comp_key)
if not result.exists():
download(url, result)

Expand Down

0 comments on commit 675cd94

Please sign in to comment.