Skip to content

Commit

Permalink
Merge pull request #9 from chaidiscovery/alex/chailab
Browse files Browse the repository at this point in the history
Warn user about potentially wrong EntityType
  • Loading branch information
arogozhnikov authored Sep 11, 2024
2 parents b6b6b6e + f6a0fa9 commit aeb92d1
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 24 deletions.
33 changes: 27 additions & 6 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
_make_sym_ids,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.fasta import parse_modified_fasta_sequence, read_fasta
from chai_lab.data.parsing.fasta import get_residue_name, read_fasta
from chai_lab.data.parsing.input_validation import (
constituents_of_modified_fasta,
identify_potential_entity_types,
)
from chai_lab.data.parsing.structure.all_atom_entity_data import AllAtomEntityData
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.parsing.structure.residue import Residue, get_restype
Expand Down Expand Up @@ -50,7 +54,7 @@ def get_lig_residues(
def get_polymer_residues(
residue_names: list[str],
entity_type: EntityType,
):
) -> list[Residue]:
residues = []
for i, residue_name in enumerate(residue_names):
residues.append(
Expand Down Expand Up @@ -94,10 +98,17 @@ def raw_inputs_to_entitites_data(
residues = get_lig_residues(smiles=input.sequence)

case EntityType.PROTEIN | EntityType.RNA | EntityType.DNA:
parsed_sequence: list = parse_modified_fasta_sequence(
input.sequence, entity_type
parsed_sequence: list | None = constituents_of_modified_fasta(
input.sequence
)
residues = get_polymer_residues(parsed_sequence, entity_type)
assert (
parsed_sequence is not None
), f"incorrect FASTA: {parsed_sequence=} "
expanded_sequence = [
get_residue_name(r, entity_type=entity_type) if len(r) == 1 else r
for r in parsed_sequence
]
residues = get_polymer_residues(expanded_sequence, entity_type)
case _:
raise NotImplementedError
assert residues is not None
Expand Down Expand Up @@ -192,7 +203,7 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
for desc, sequence in sequences:
logger.info(f"[fasta] [{fasta_file}] {desc} {len(sequence)}")
# get the type of the sequence
entity_str = desc.split("|")[0].strip()
entity_str = desc.split("|")[0].strip().lower()
match entity_str:
case "protein":
entity_type = EntityType.PROTEIN
Expand All @@ -204,6 +215,16 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
entity_type = EntityType.DNA
case _:
raise ValueError(f"{entity_str} is not a valid entity type")

possible_types = identify_potential_entity_types(sequence)
if len(possible_types) == 0:
logger.error(f"Provided {sequence=} is invalid")
elif entity_type not in possible_types:
types_fmt = "/".join(str(et.name) for et in possible_types)
logger.warning(
f"Provided {sequence=} is likely {types_fmt}, not {entity_type.name}"
)

retval.append(Input(sequence, entity_type.value))
total_length += len(sequence)

Expand Down
1 change: 1 addition & 0 deletions chai_lab/data/dataset/structure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_centre_atom_name(residue_name: str) -> str:
}:
return "C1'"
else:
assert len(residue_name) == 3, "residue expected"
return "CA"


Expand Down
20 changes: 2 additions & 18 deletions chai_lab/data/parsing/fasta.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import re
from pathlib import Path
from typing import Iterable

Expand Down Expand Up @@ -35,6 +34,8 @@ def get_residue_name(
fasta_code: str,
entity_type: EntityType,
) -> str:
if len(fasta_code) != 1:
raise ValueError("Cannot handle non-single chars: {}".format(fasta_code))
match entity_type:
case EntityType.PROTEIN:
return restype_1to3_with_x.get(fasta_code, "UNK")
Expand All @@ -44,20 +45,3 @@ def get_residue_name(
return nucleic_acid_1_to_name.get((fasta_code, entity_type), unk)
case _:
raise ValueError(f"Invalid polymer entity type {entity_type}")


def parse_modified_fasta_sequence(sequence: str, entity_type: EntityType) -> list[str]:
"""
Parses a fasta-like string containing modified residues in
brackets, returns a list of residue codes
"""
pattern = r"[A-Z]|\[[A-Z0-9]+\]"
residues = re.findall(pattern, sequence)

# get full residue name if regular fasta code (not in brackets),
# otherwise return what user passed in brackets
parsed_residues = [
get_residue_name(x, entity_type) if not x.startswith("[") else x.strip("[]")
for x in residues
]
return parsed_residues
74 changes: 74 additions & 0 deletions chai_lab/data/parsing/input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Simple heuristics that can help with identification of EntityType
"""

import string
from string import ascii_letters

from chai_lab.data.parsing.structure.entity_type import EntityType


def constituents_of_modified_fasta(x: str) -> list[str] | None:
"""
Accepts RNA/DNA inputs: 'agtc', 'AGT(ASP)TG', etc. Does not accept SMILES strings.
Returns constituents, e.g, [A, G, T, ASP, T, G] or None if string is incorrect.
Everything in returned list is single character, except for blocks specified in brackets.
"""
x = x.strip().upper()
# it is a bit strange that digits are here, but [NH2] was in one protein
allowed_chars = ascii_letters + "()" + string.digits
if not all(letter in allowed_chars for letter in x):
return None

current_modified: str | None = None

constituents = []
for letter in x:
if letter == "(":
if current_modified is not None:
return None # double open bracket
current_modified = ""
elif letter == ")":
if current_modified is None:
return None # closed without opening
if len(current_modified) <= 1:
return None # empty modification: () or single (K)
constituents.append(current_modified)
current_modified = None
else:
if current_modified is not None:
current_modified += letter
else:
if letter not in ascii_letters:
return None # strange single-letter residue
constituents.append(letter)
if current_modified is not None:
return None # did not close bracket
return constituents


def identify_potential_entity_types(sequence: str) -> list[EntityType]:
"""
Provided FASTA sequence or smiles, lists which entities those could be.
Returns an empty list if sequence is invalid for all entity types.
"""
sequence = sequence.strip()
if len(sequence) == 0:
return []
possible_entity_types = []

constituents = constituents_of_modified_fasta(sequence)
if constituents is not None:
# this can be RNA/DNA/protein.
one_letter_constituents = set(x for x in constituents if len(x) == 1)
if set.issubset(one_letter_constituents, set("AGTC")):
possible_entity_types.append(EntityType.DNA)
if set.issubset(one_letter_constituents, set("AGUC")):
possible_entity_types.append(EntityType.RNA)
if "U" not in one_letter_constituents:
possible_entity_types.append(EntityType.PROTEIN)

ascii_symbols = string.ascii_letters + string.digits + ".-+=#$%:/\\[]()<>@"
if set.issubset(set(sequence.upper()), set(ascii_symbols)):
possible_entity_types.append(EntityType.LIGAND)
return possible_entity_types
Empty file added tests/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions tests/example_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
example_ligands = [
"C",
"O",
"C(C1C(C(C(C(O1)O)O)O)O)O",
"[O-]S(=O)(=O)[O-]",
"CC1=C(C(CCC1)(C)C)/C=C/C(=C/C=C/C(=C/C=O)/C)/C",
"CCC1=C(c2cc3c(c(c4n3[Mg]56[n+]2c1cc7n5c8c(c9[n+]6c(c4)C(C9CCC(=O)OC/C=C(\C)/CCC[C@H](C)CCC[C@H](C)CCCC(C)C)C)[C@H](C(=O)c8c7C)C(=O)OC)C)C=C)C=O",
r"C=CC1=C(C)/C2=C/c3c(C)c(CCC(=O)O)c4n3[Fe@TB16]35<-N2=C1/C=c1/c(C)c(C=C)/c(n13)=C/C1=N->5/C(=C\4)C(CCC(=O)O)=C1C",
# different ions
"[Mg+2]",
"[Na+]",
"[Cl-]",
]

example_proteins = [
"AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVR",
"(KCJ)(SEP)(PPN)(B3S)(BAL)(PPN)K(NH2)",
"XDHPX",
]


example_rna = [
"AGUGGCUA",
"AAAAAA",
"AGUC",
]

example_dna = [
"AGTGGCTA",
"AAAAAA",
"AGTC",
]
48 changes: 48 additions & 0 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from chai_lab.data.parsing.input_validation import (
constituents_of_modified_fasta,
identify_potential_entity_types,
)
from chai_lab.data.parsing.structure.entity_type import EntityType

from .example_inputs import example_dna, example_ligands, example_proteins, example_rna


def test_simple_protein_fasta():
parts = constituents_of_modified_fasta("RKDES")
assert parts is not None
assert all(x == y for x, y in zip(parts, ["R", "K", "D", "E", "S"]))


def test_modified_protein_fasta():
parts = constituents_of_modified_fasta("(KCJ)(SEP)(PPN)(B3S)(BAL)(PPN)KX(NH2)")
assert parts is not None
expected = ["KCJ", "SEP", "PPN", "B3S", "BAL", "PPN", "K", "X", "NH2"]
assert all(x == y for x, y in zip(parts, expected))


def test_rna_fasta():
seq = "ACUGACG"
parts = constituents_of_modified_fasta(seq)
assert parts is not None
assert all(x == y for x, y in zip(parts, seq))


def test_dna_fasta():
seq = "ACGACTAGCAT"
parts = constituents_of_modified_fasta(seq)
assert parts is not None
assert all(x == y for x, y in zip(parts, seq))


def test_parsing():
for ligand in example_ligands:
assert EntityType.LIGAND in identify_potential_entity_types(ligand)

for protein in example_proteins:
assert EntityType.PROTEIN in identify_potential_entity_types(protein)

for dna in example_dna:
assert EntityType.DNA in identify_potential_entity_types(dna)

for rna in example_rna:
assert EntityType.RNA in identify_potential_entity_types(rna)

0 comments on commit aeb92d1

Please sign in to comment.