Skip to content

Commit

Permalink
Fix ligand entity IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Dec 13, 2024
1 parent a54c6bd commit f8acdeb
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
15 changes: 13 additions & 2 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,19 @@ def raw_inputs_to_entitites_data(
assert residues is not None

# Determine the entity id (unique integer for each distinct sequence)
# NOTE very important for recognizing things like homo polymers
seq: tuple[str, ...] = tuple(res.name for res in residues)
# NOTE because ligand residues have a single "LIG" residue name, the name field
# cannot be used to distinguish them. Instead, we use the sequence field itself,
# which should contain the SMILES string. This is not ideal, as it fails to
# distinguish betwen different SMILES strings that represent the same molecule,
# but should capture most cases.
# We do not need to do special check on glycans because they are specified as a
# string of monosaccharides, which behaves similarly to a string of amino acid
# residues.
seq: tuple[str, ...] = (
(input.sequence,)
if input.entity_type == EntityType.LIGAND.value
else tuple(res.name for res in residues)
)
entity_key: tuple[EntityType, tuple[str, ...]] = (entity_type, seq)
if entity_key in entity_to_index:
entity_id = entity_to_index[entity_key]
Expand Down
51 changes: 51 additions & 0 deletions tests/test_inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
"""

import pytest
import torch

from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
AllAtomResidueTokenizer,
)
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.sources.rdkit import RefConformerGenerator

Expand Down Expand Up @@ -53,3 +58,49 @@ def test_ions_parsing(tokenizer: AllAtomResidueTokenizer):
assert chain.structure_context.num_atoms == 1
assert chain.structure_context.atom_ref_charge == 2
assert chain.structure_context.atom_ref_element.item() == 12


def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
"""Complex with multiple duplicated protein chains and SMILES ligands."""
# Based on https://www.rcsb.org/structure/1AFS
seq = "MDSISLRVALNDGNFIPVLGFGTTVPEKVAKDEVIKATKIAIDNGFRHFDSAYLYEVEEEVGQAIRSKIEDGTVKREDIFYTSKLWSTFHRPELVRTCLEKTLKSTQLDYVDLYIIHFPMALQPGDIFFPRDEHGKLLFETVDICDTWEAMEKCKDAGLAKSIGVSNFNCRQLERILNKPGLKYKPVCNQVECHLYLNQSKMLDYCKSKDIILVSYCTLGSSRDKTWVDQKSPVLLDDPVLCAIAKKYKQTPALVALRYQLQRGVVPLIRSFNAKRIKELTQVFEFQLASEDMKALDGLNRNFRYNNAKYFDDHPNHPFTDEN"
nap = "NC(=O)c1ccc[n+](c1)[CH]2O[CH](CO[P]([O-])(=O)O[P](O)(=O)OC[CH]3O[CH]([CH](O[P](O)(O)=O)[CH]3O)n4cnc5c(N)ncnc45)[CH](O)[CH]2O"
tes = "O=C4C=C3C(C2CCC1(C(CCC1O)C2CC3)C)(C)CC4"
inputs = [
Input(seq, EntityType.PROTEIN.value, entity_name="A"),
Input(seq, EntityType.PROTEIN.value, entity_name="B"),
Input(nap, EntityType.LIGAND.value, entity_name="C"),
Input(nap, EntityType.LIGAND.value, entity_name="D"),
Input(tes, EntityType.LIGAND.value, entity_name="E"),
Input(tes, EntityType.LIGAND.value, entity_name="F"),
]
chains: list[Chain] = load_chains_from_raw(inputs, tokenizer=tokenizer)
assert len(chains) == len(inputs)

example = AllAtomStructureContext.merge(
[chain.structure_context for chain in chains]
)

# Should be 1 protein chain, 2 ligand chains
assert example.token_entity_id.unique().numel() == 3
assert example.token_asym_id.unique().numel() == 6

# Check protein chains
prot_entity_ids = example.token_entity_id[
example.token_entity_type == EntityType.PROTEIN.value
]
assert torch.unique(prot_entity_ids).numel() == 1
prot_sym_ids = example.token_sym_id[
example.token_entity_type == EntityType.PROTEIN.value
]
assert torch.unique(prot_sym_ids).numel() == 2 # Two copies of this chain

# Check ligand chains
lig_entity_ids = example.token_entity_id[
example.token_entity_type == EntityType.LIGAND.value
]
assert torch.unique(lig_entity_ids).numel() == 2
lig_sym_ids = example.token_sym_id[
example.token_entity_type == EntityType.LIGAND.value
]
assert torch.unique(lig_sym_ids).numel() == 2 # Two copies of each ligand

0 comments on commit f8acdeb

Please sign in to comment.