Fix for cases where tokenization failed (#54)
* Append None to structure_contexts on exception

* Assign chain_id based on non-null contexts

* Add test case

* Apply suggestions from code review

Co-authored-by: Alex Rogozhnikov <[email protected]>

* Add comment

---------

Co-authored-by: Alex Rogozhnikov <[email protected]>
wukevin and arogozhnikov authored Sep 17, 2024
1 parent 1d3e499 commit bf90332
Showing 2 changed files with 45 additions and 6 deletions.
18 changes: 12 additions & 6 deletions chai_lab/data/dataset/inference_dataset.py
@@ -10,6 +10,9 @@
     AllAtomResidueTokenizer,
     _make_sym_ids,
 )
+from chai_lab.data.dataset.structure.all_atom_structure_context import (
+    AllAtomStructureContext,
+)
 from chai_lab.data.dataset.structure.chain import Chain
 from chai_lab.data.parsing.fasta import get_residue_name, read_fasta
 from chai_lab.data.parsing.input_validation import (
@@ -164,19 +167,22 @@ def load_chains_from_raw(
         )
 
     # Tokenize the entity data
-    structure_contexts = []
+    structure_contexts: list[AllAtomStructureContext | None] = []
     sym_ids = _make_sym_ids([x.entity_id for x in entities])
-    for idx, (entity_data, sym_id) in enumerate(zip(entities, sym_ids)):
+    for entity_data, sym_id in zip(entities, sym_ids):
+        # chain index should not count null contexts that result from failed tokenization
+        chain_index = sum(ctx is not None for ctx in structure_contexts) + 1
         try:
             tok = tokenizer._tokenize_entity(
                 entity_data,
-                chain_id=idx + 1,
+                chain_id=chain_index,
                 sym_id=sym_id,
             )
-            structure_contexts.append(tok)
         except Exception:
-            logger.exception(f"Failed to tokenize input {inputs[idx]}")
-
+            logger.exception(f"Failed to tokenize input {entity_data=} {sym_id=}")
+            tok = None
+        structure_contexts.append(tok)
+    assert len(structure_contexts) == len(entities)
     # Join the untokenized entity data with the tokenized chain data, removing
     # chains we failed to tokenize
     chains = [
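The hunk cuts off at the list comprehension that pairs entities with their tokenized contexts. A minimal sketch of that pairing step follows; the attribute names entity_data and structure_context appear in the test file below, but the constructor call itself is an assumption, not part of this diff. Because the loop now appends a None placeholder for every failed entity, entities and structure_contexts stay index-aligned, and the comprehension can drop failures without mis-pairing the survivors.

# Sketch only, not repository code: pair each entity with its (possibly None)
# tokenized context and keep just the ones that tokenized successfully.
chains = [
    Chain(entity_data=entity_data, structure_context=context)
    for entity_data, context in zip(entities, structure_contexts)
    if context is not None
]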
33 changes: 33 additions & 0 deletions tests/test_inference_dataset.py
@@ -0,0 +1,33 @@
"""
Tests for inference dataset.
"""

from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
AllAtomResidueTokenizer,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.sources.rdkit import RefConformerGenerator


def test_malformed_smiles():
"""Malformed SMILES should be dropped."""
# Zn ligand is malformed (should be [Zn+2])
inputs = [
Input("RKDESES", entity_type=EntityType.PROTEIN.value, entity_name="foo"),
Input("Zn", entity_type=EntityType.LIGAND.value, entity_name="bar"),
Input("RKEEE", entity_type=EntityType.PROTEIN.value, entity_name="baz"),
Input("EEEEEEEEEEEE", entity_type=EntityType.PROTEIN.value, entity_name="boz"),
]
chains = load_chains_from_raw(
inputs,
identifier="test",
tokenizer=AllAtomResidueTokenizer(RefConformerGenerator()),
)
assert len(chains) == 3
for chain in chains:
# NOTE this check is only valid because there are no residues that are tokenized per-atom
# Ensures that the entity data and the structure context in each chain are paired correctly
assert chain.structure_context.num_tokens == len(
chain.entity_data.full_sequence
)
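For reference, the reason the second input is expected to fail: bare "Zn" is not valid SMILES, since zinc is not in the organic subset and has to be written as a bracket atom such as [Zn+2], which is what makes tokenization of that ligand raise. A quick standalone check, assuming RDKit's parser behaves as usual in this environment:

from rdkit import Chem

# RDKit returns None for SMILES it cannot parse; "Zn" is rejected because
# zinc must be written inside brackets.
assert Chem.MolFromSmiles("Zn") is None
# The corrected bracket form with an explicit charge parses fine.
assert Chem.MolFromSmiles("[Zn+2]") is not None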
