move ESM to/from GPU once per complex, not once per chain #5

Merged: 1 commit, Sep 10, 2024
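The gist of the change, as a rough sketch rather than the repo's exact code: previously each protein chain was embedded by a separate call that entered the esm_model(...) context manager, so the ~3B-parameter ESM model was moved to the GPU and back to the CPU once per chain. After this PR, the context manager is entered once per complex and every unique protein sequence is embedded inside it. The names protein_chain_sequences, unique_protein_sequences, and embed_one below are hypothetical and used only for illustration.

# Before (simplified): one CPU<->GPU round-trip of the ESM weights per chain
for seq in protein_chain_sequences:          # hypothetical: one entry per chain
    with esm_model(model_name=model_name, device=device) as model:
        embeddings = embed_one(model, seq)   # hypothetical helper

# After (simplified): a single round-trip for the whole complex
with esm_model(model_name=model_name, device=device) as model:
    for seq in unique_protein_sequences:     # hypothetical: one entry per unique sequence
        embeddings = embed_one(model, seq)

Deduplicating the sequences also means identical chains (e.g. a homodimer) are only embedded once per complex.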
50 changes: 32 additions & 18 deletions chai_lab/data/dataset/embeddings/esm.py
@@ -33,44 +33,58 @@ def esm_model(model_name: str, device):
     model.to(device)
     model.eval()
     yield model
-    # move model back to CPU
-    model.to("cpu")
+    model.to("cpu")  # move model back to CPU when done


-def embedding_context_from_sequence(seq: str, device) -> EmbeddingContext:
+def _get_esm_contexts_for_sequences(
+    prot_sequences: set[str], device
+) -> dict[str, EmbeddingContext]:
+    if len(prot_sequences) == 0:
+        return {}  # skip loading ESM
+
     # local import, requires huggingface transformers
     from transformers import EsmTokenizer

     model_name = "facebook/esm2_t36_3B_UR50D"
     tokenizer = EsmTokenizer.from_pretrained(model_name)

-    inputs = tokenizer(seq, return_tensors="pt")
-    inputs = move_data_to_device(dict(**inputs), device=device)
+    seq2embedding_context = {}

     with torch.no_grad():
         with esm_model(model_name=model_name, device=device) as model:
-            outputs = model(**inputs)
+            for seq in prot_sequences:
+                inputs = tokenizer(seq, return_tensors="pt")
+                inputs = move_data_to_device(dict(**inputs), device=device)
+                outputs = model(**inputs)
+                # remove BOS/EOS, back to CPU
+                esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
+                seq_len, _emb_dim = esm_embeddings.shape
+                assert seq_len == len(seq)
+
+                seq2embedding_context[seq] = EmbeddingContext(
+                    esm_embeddings=esm_embeddings
+                )

-    # remove BOS/EOS, back to CPU
-    esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
-    seq_len, _emb_dim = esm_embeddings.shape
-    assert seq_len == len(seq)
-    return EmbeddingContext(esm_embeddings=esm_embeddings)
+    return seq2embedding_context


 @typecheck
 def get_esm_embedding_context(chains: list[Chain], device) -> EmbeddingContext:
     # device is used for computing, but result is still on CPU
-    chain_embs = []

+    protein_seq2emb_context = _get_esm_contexts_for_sequences(
+        prot_sequences=set(
+            chain.entity_data.sequence
+            for chain in chains
+            if chain.entity_data.entity_type == EntityType.PROTEIN
+        ),
+        device=device,
+    )
+
+    chain_embs = []
     for chain in chains:
         if chain.entity_data.entity_type == EntityType.PROTEIN:
-            emb = embedding_context_from_sequence(
-                # modified residues represented as X
-                seq=chain.entity_data.sequence,
-                device=device,
-            )
-            chain_embs.append(emb)
+            chain_embs.append(protein_seq2emb_context[chain.entity_data.sequence])
         else:
             # embed non-proteins with zeros
             chain_embs.append(
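For context, a hedged usage sketch of the new helper. The example sequence is illustrative, importing the underscore-prefixed function is for demonstration only, and treating esm_embeddings as an attribute of EmbeddingContext is an assumption not shown in this diff.

import torch

from chai_lab.data.dataset.embeddings.esm import _get_esm_contexts_for_sequences

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Duplicate chains collapse into one set entry, so a homodimer is embedded once.
seqs = {"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"}  # illustrative sequence

# The ESM model is moved to the device once, used for every sequence in the
# set, then moved back to the CPU by the esm_model context manager.
seq_to_ctx = _get_esm_contexts_for_sequences(prot_sequences=seqs, device=device)

for seq, ctx in seq_to_ctx.items():
    # per the diff, embeddings come back on the CPU with BOS/EOS stripped
    print(len(seq), ctx.esm_embeddings.shape)  # (seq_len, emb_dim), assumed attribute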