diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py
index 4b04f8f..7ac2fd7 100644
--- a/chai_lab/data/dataset/inference_dataset.py
+++ b/chai_lab/data/dataset/inference_dataset.py
@@ -126,8 +126,19 @@ def raw_inputs_to_entitites_data(
         assert residues is not None
 
         # Determine the entity id (unique integer for each distinct sequence)
-        # NOTE very important for recognizing things like homo polymers
-        seq: tuple[str, ...] = tuple(res.name for res in residues)
+        # NOTE because ligand residues have a single "LIG" residue name, the name field
+        # cannot be used to distinguish them. Instead, we use the sequence field itself,
+        # which should contain the SMILES string. This is not ideal, as it fails to
+        # distinguish between different SMILES strings that represent the same molecule,
+        # but should capture most cases.
+        # We do not need a special check for glycans because they are specified as a
+        # string of monosaccharides, which behaves similarly to a string of amino acid
+        # residues.
+        seq: tuple[str, ...] = (
+            (input.sequence,)
+            if input.entity_type == EntityType.LIGAND.value
+            else tuple(res.name for res in residues)
+        )
         entity_key: tuple[EntityType, tuple[str, ...]] = (entity_type, seq)
         if entity_key in entity_to_index:
             entity_id = entity_to_index[entity_key]
diff --git a/tests/test_inference_dataset.py b/tests/test_inference_dataset.py
index 6269351..c769eb8 100644
--- a/tests/test_inference_dataset.py
+++ b/tests/test_inference_dataset.py
@@ -7,11 +7,16 @@
 """
 
 import pytest
+import torch
 
 from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
 from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
     AllAtomResidueTokenizer,
 )
+from chai_lab.data.dataset.structure.all_atom_structure_context import (
+    AllAtomStructureContext,
+)
+from chai_lab.data.dataset.structure.chain import Chain
 from chai_lab.data.parsing.structure.entity_type import EntityType
 from chai_lab.data.sources.rdkit import RefConformerGenerator
 
@@ -53,3 +58,49 @@ def test_ions_parsing(tokenizer: AllAtomResidueTokenizer):
     assert chain.structure_context.num_atoms == 1
     assert chain.structure_context.atom_ref_charge == 2
     assert chain.structure_context.atom_ref_element.item() == 12
+
+
+def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
+    """Complex with multiple duplicated protein chains and SMILES ligands."""
+    # Based on https://www.rcsb.org/structure/1AFS
+    seq = "MDSISLRVALNDGNFIPVLGFGTTVPEKVAKDEVIKATKIAIDNGFRHFDSAYLYEVEEEVGQAIRSKIEDGTVKREDIFYTSKLWSTFHRPELVRTCLEKTLKSTQLDYVDLYIIHFPMALQPGDIFFPRDEHGKLLFETVDICDTWEAMEKCKDAGLAKSIGVSNFNCRQLERILNKPGLKYKPVCNQVECHLYLNQSKMLDYCKSKDIILVSYCTLGSSRDKTWVDQKSPVLLDDPVLCAIAKKYKQTPALVALRYQLQRGVVPLIRSFNAKRIKELTQVFEFQLASEDMKALDGLNRNFRYNNAKYFDDHPNHPFTDEN"
+    nap = "NC(=O)c1ccc[n+](c1)[CH]2O[CH](CO[P]([O-])(=O)O[P](O)(=O)OC[CH]3O[CH]([CH](O[P](O)(O)=O)[CH]3O)n4cnc5c(N)ncnc45)[CH](O)[CH]2O"
+    tes = "O=C4C=C3C(C2CCC1(C(CCC1O)C2CC3)C)(C)CC4"
+    inputs = [
+        Input(seq, EntityType.PROTEIN.value, entity_name="A"),
+        Input(seq, EntityType.PROTEIN.value, entity_name="B"),
+        Input(nap, EntityType.LIGAND.value, entity_name="C"),
+        Input(nap, EntityType.LIGAND.value, entity_name="D"),
+        Input(tes, EntityType.LIGAND.value, entity_name="E"),
+        Input(tes, EntityType.LIGAND.value, entity_name="F"),
+    ]
+    chains: list[Chain] = load_chains_from_raw(inputs, tokenizer=tokenizer)
+    assert len(chains) == len(inputs)
+
+    example = AllAtomStructureContext.merge(
+        [chain.structure_context for chain in chains]
+    )
+
+    # Should be 3 distinct entities (1 protein, 2 ligands) spread over 6 chains
+    assert example.token_entity_id.unique().numel() == 3
+    assert example.token_asym_id.unique().numel() == 6
+
+    # Check protein chains
+    prot_entity_ids = example.token_entity_id[
+        example.token_entity_type == EntityType.PROTEIN.value
+    ]
+    assert torch.unique(prot_entity_ids).numel() == 1
+    prot_sym_ids = example.token_sym_id[
+        example.token_entity_type == EntityType.PROTEIN.value
+    ]
+    assert torch.unique(prot_sym_ids).numel() == 2  # Two copies of this chain
+
+    # Check ligand chains
+    lig_entity_ids = example.token_entity_id[
+        example.token_entity_type == EntityType.LIGAND.value
+    ]
+    assert torch.unique(lig_entity_ids).numel() == 2
+    lig_sym_ids = example.token_sym_id[
+        example.token_entity_type == EntityType.LIGAND.value
+    ]
+    assert torch.unique(lig_sym_ids).numel() == 2  # Two copies of each ligand