Skip to content

Commit

Permalink
Don't warn on missing MSAs for non-protein entities.
Browse files: browse the repository at this point in the history
  • Loading branch information
wukevin committed Dec 3, 2024
1 parent 2d2646b commit 6a78e06
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
19 changes: 10 additions & 9 deletions chai_lab/data/dataset/msas/load.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
parse_aligned_pqt_to_msa_context,
)
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.structure.entity_type import EntityType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -43,14 +44,14 @@ def get_msa_contexts(

# MSAs are constructed based on sequence, so use the unique sequences present
# in input chains to determine the MSAs that need to be loaded

def get_msa_contexts_for_seq(seq) -> MSAContext:
def get_msa_contexts_for_seq(seq: str, etype: EntityType) -> MSAContext:
path = msa_directory / expected_basename(seq)
if not path.is_file():
if seq != "X":
# Don't warn for the special "X" sequence
# If the MSA is missing, or the query is not a protein, return an empty MSA
if not path.is_file() or etype != EntityType.PROTEIN:
if etype == EntityType.PROTEIN:
# Warn for proteins that have missing MSAs
logger.warning(f"No MSA found for sequence: {seq}")
[tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
[tokenized_seq], _ = tokenize_sequences_to_arrays([seq])
return MSAContext.create_single_seq(
MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)
)
@@ -61,9 +62,9 @@ def get_msa_contexts_for_seq(seq) -> MSAContext:
# For each chain, either fetch the corresponding MSA or create an empty MSA if it is missing
# + reindex to handle residues that are tokenized per-atom (this also crops if necessary)
msa_contexts = [
get_msa_contexts_for_seq(chain.entity_data.sequence)[
:, chain.structure_context.token_residue_index
]
get_msa_contexts_for_seq(
seq=chain.entity_data.sequence, etype=chain.entity_data.entity_type
)[:, chain.structure_context.token_residue_index]
for chain in chains
]

Expand Down
3 changes: 2 additions & 1 deletion chai_lab/main.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@

from chai_lab.chai1 import run_inference

logging.basicConfig(level=logging.INFO)

CITATION = """
@article{Chai-1-Technical-Report,
title = {Chai-1: Decoding the molecular interactions of life},
@@ -38,5 +40,4 @@ def cli():


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

0 comments on commit 6a78e06

Please sign in to comment.