-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix for cases where tokenization failed (#54)
* Append None to structure_contexts on exception * Assign chain_id based on non-null contexts * Add test case * Apply suggestions from code review Co-authored-by: Alex Rogozhnikov <[email protected]> * Add comment --------- Co-authored-by: Alex Rogozhnikov <[email protected]>
- Loading branch information
1 parent
1d3e499
commit bf90332
Showing
2 changed files
with
45 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
Tests for inference dataset. | ||
""" | ||
|
||
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw | ||
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import ( | ||
AllAtomResidueTokenizer, | ||
) | ||
from chai_lab.data.parsing.structure.entity_type import EntityType | ||
from chai_lab.data.sources.rdkit import RefConformerGenerator | ||
|
||
|
||
def test_malformed_smiles(): | ||
"""Malformed SMILES should be dropped.""" | ||
# Zn ligand is malformed (should be [Zn+2]) | ||
inputs = [ | ||
Input("RKDESES", entity_type=EntityType.PROTEIN.value, entity_name="foo"), | ||
Input("Zn", entity_type=EntityType.LIGAND.value, entity_name="bar"), | ||
Input("RKEEE", entity_type=EntityType.PROTEIN.value, entity_name="baz"), | ||
Input("EEEEEEEEEEEE", entity_type=EntityType.PROTEIN.value, entity_name="boz"), | ||
] | ||
chains = load_chains_from_raw( | ||
inputs, | ||
identifier="test", | ||
tokenizer=AllAtomResidueTokenizer(RefConformerGenerator()), | ||
) | ||
assert len(chains) == 3 | ||
for chain in chains: | ||
# NOTE this check is only valid because there are no residues that are tokenized per-atom | ||
# Ensures that the entity data and the structure context in each chain are paired correctly | ||
assert chain.structure_context.num_tokens == len( | ||
chain.entity_data.full_sequence | ||
) |