Skip to content

Commit

Permalink
Add test for ion charge
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Dec 4, 2024
1 parent f896eff commit 3b43b2a
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions tests/test_inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Tests for inference dataset.
"""

import pytest

from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
AllAtomResidueTokenizer,
Expand All @@ -14,7 +16,12 @@
from chai_lab.data.sources.rdkit import RefConformerGenerator


def test_malformed_smiles():
@pytest.fixture
def tokenizer() -> AllAtomResidueTokenizer:
    """Provide an AllAtomResidueTokenizer built on a new RefConformerGenerator."""
    conformer_generator = RefConformerGenerator()
    return AllAtomResidueTokenizer(conformer_generator)


def test_malformed_smiles(tokenizer: AllAtomResidueTokenizer):
"""Malformed SMILES should be dropped."""
# Zn ligand is malformed (should be [Zn+2])
inputs = [
Expand All @@ -26,7 +33,7 @@ def test_malformed_smiles():
chains = load_chains_from_raw(
inputs,
identifier="test",
tokenizer=AllAtomResidueTokenizer(RefConformerGenerator()),
tokenizer=tokenizer,
)
assert len(chains) == 3
for chain in chains:
Expand All @@ -35,3 +42,13 @@ def test_malformed_smiles():
assert chain.structure_context.num_tokens == len(
chain.entity_data.full_sequence
)


def test_ions_parsing(tokenizer: AllAtomResidueTokenizer):
    """An ion given as a SMILES ligand should load as one single-atom chain with charge 2."""
    ion_input = Input(
        "[Mg+2]",
        entity_type=EntityType.LIGAND.value,
        entity_name="foo",
    )
    chains = load_chains_from_raw(
        [ion_input],
        identifier="foo",
        tokenizer=tokenizer,
    )
    # Exactly one chain is expected for the single ligand input.
    assert len(chains) == 1
    context = chains[0].structure_context
    assert context.num_atoms == 1
    assert context.atom_ref_charge == 2

0 comments on commit 3b43b2a

Please sign in to comment.