From 3b43b2a2480127701bf437432c3868c152a63d08 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 23:27:05 +0000 Subject: [PATCH] Add test for ion charge --- tests/test_inference_dataset.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test_inference_dataset.py b/tests/test_inference_dataset.py index da7fd66..fb98616 100644 --- a/tests/test_inference_dataset.py +++ b/tests/test_inference_dataset.py @@ -6,6 +6,8 @@ Tests for inference dataset. """ +import pytest + from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import ( AllAtomResidueTokenizer, @@ -14,7 +16,12 @@ from chai_lab.data.sources.rdkit import RefConformerGenerator -def test_malformed_smiles(): +@pytest.fixture +def tokenizer() -> AllAtomResidueTokenizer: + return AllAtomResidueTokenizer(RefConformerGenerator()) + + +def test_malformed_smiles(tokenizer: AllAtomResidueTokenizer): """Malformed SMILES should be dropped.""" # Zn ligand is malformed (should be [Zn+2]) inputs = [ @@ -26,7 +33,7 @@ def test_malformed_smiles(): chains = load_chains_from_raw( inputs, identifier="test", - tokenizer=AllAtomResidueTokenizer(RefConformerGenerator()), + tokenizer=tokenizer, ) assert len(chains) == 3 for chain in chains: @@ -35,3 +42,13 @@ def test_malformed_smiles(): assert chain.structure_context.num_tokens == len( chain.entity_data.full_sequence ) + + +def test_ions_parsing(tokenizer: AllAtomResidueTokenizer): + """Ions as SMILES strings should carry the correct charge.""" + inputs = [Input("[Mg+2]", entity_type=EntityType.LIGAND.value, entity_name="foo")] + chains = load_chains_from_raw(inputs, identifier="foo", tokenizer=tokenizer) + assert len(chains) == 1 + chain = chains[0] + assert chain.structure_context.num_atoms == 1 + assert chain.structure_context.atom_ref_charge == 2