Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for specifying covalent bonds and glycans #205

Merged
merged 28 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ For user convenience, we also support automatic MSA generation via the ColabFold
<summary>How can I customize the inputs to the model further?</summary>
<p markdown="1">

For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints. We currently provide an example of how to construct an embeddings context as well as an MSA context, and will be releasing helper methods to build template contexts soon.
For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints, including support for specifying covalent bonds (for example, for specifying branched ligands). We currently provide examples of how to construct an embeddings context, an MSA context, restraint contexts, and covalent bonds. We will be releasing helper methods to build template contexts soon.

</p>
</details>
Expand Down
41 changes: 40 additions & 1 deletion chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.structure.bond_utils import (
get_atom_covalent_bond_pairs_from_constraints,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.features.feature_factory import FeatureFactory
from chai_lab.data.features.feature_type import FeatureType
Expand Down Expand Up @@ -76,6 +79,7 @@
TemplateResTypeGenerator,
TemplateUnitVectorGenerator,
)
from chai_lab.data.features.generators.token_bond import TokenBondRestraint
from chai_lab.data.features.generators.token_dist_restraint import (
TokenDistanceRestraint,
)
Expand Down Expand Up @@ -358,11 +362,32 @@ def run_inference(

# Constraints
if constraint_path is not None:
# Handles contact and pocket restraints
pairs = parse_pairwise_table(constraint_path)
restraint_context = load_manual_restraints_for_chai1(
chains,
crop_idces=None,
provided_constraints=parse_pairwise_table(constraint_path),
provided_constraints=pairs,
)
# Handle covalent bond restraints
cov_a, cov_b = get_atom_covalent_bond_pairs_from_constraints(
provided_constraints=pairs,
token_residue_index=merged_context.token_residue_index,
token_residue_name=merged_context.token_residue_name,
token_subchain_id=merged_context.subchain_id,
token_asym_id=merged_context.token_asym_id,
atom_token_index=merged_context.atom_token_index,
atom_ref_name=merged_context.atom_ref_name,
)
if cov_a.numel() > 0 and cov_b.numel() > 0:
    orig_a, orig_b = merged_context.atom_covalent_bond_indices
    if orig_a.numel() == orig_b.numel() == 0:
        # No pre-existing bonds: the bonds parsed from the constraints table
        # are the full set. (Bug fix: previously this branch re-assigned the
        # empty (orig_a, orig_b), silently dropping every parsed covalent bond.)
        merged_context.atom_covalent_bond_indices = (cov_a, cov_b)
    else:
        # Append the parsed bonds to the bonds already on the context.
        merged_context.atom_covalent_bond_indices = (
            torch.concatenate([orig_a, cov_a]),
            torch.concatenate([orig_b, cov_b]),
        )
else:
restraint_context = RestraintContext.empty()

Expand Down Expand Up @@ -425,6 +450,7 @@ def run_folding_on_context(
raise_if_too_many_templates(feature_context.template_context.num_templates)
raise_if_msa_too_deep(feature_context.msa_context.depth)
# NOTE profile MSA used only for statistics; no depth check
feature_context.structure_context.report_bonds()

##
## Prepare batch
Expand Down Expand Up @@ -468,6 +494,7 @@ def run_folding_on_context(
assert model_size in AVAILABLE_MODEL_SIZES

feature_embedding = load_exported("feature_embedding.pt", device)
bond_loss_input_proj = load_exported("bond_loss_input_proj.pt", device)
token_input_embedder = load_exported("token_embedder.pt", device)
trunk = load_exported("trunk.pt", device)
diffusion_module = load_exported("diffusion_module.pt", device)
Expand All @@ -491,6 +518,18 @@ def run_folding_on_context(
template_input_feats = embedded_features["TEMPLATES"]
msa_input_feats = embedded_features["MSA"]

##
## Bond feature generator
## Separate from other feature embeddings due to export limitations
##
bond_ft_gen = TokenBondRestraint()
jackdent marked this conversation as resolved.
Show resolved Hide resolved
bond_ft = bond_ft_gen.generate(batch=batch).data
trunk_bond_feat, structure_bond_feat = bond_loss_input_proj.forward(
crop_size=model_size, input=bond_ft
).chunk(2, dim=-1)
token_pair_input_feats += trunk_bond_feat
token_pair_structure_input_feats += structure_bond_feat

##
## Run the inputs through the token input embedder
##
Expand Down
2 changes: 1 addition & 1 deletion chai_lab/data/dataset/constraints/restraint_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def load_manual_restraints_for_chai1(
contact_constraints: list[ContactRestraint] = []
pocket_constraints: list[PocketRestraint] = []

logger.info(f"Loading {len(provided_constraints)} constraints...")
logger.info(f"Loading {len(provided_constraints)} restraints...")
for constraint in provided_constraints:
match ctype := constraint.connection_type:
case PairwiseInteractionType.COVALENT:
Expand Down
6 changes: 6 additions & 0 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.fasta import get_residue_name, read_fasta
from chai_lab.data.parsing.glycans import glycan_string_residues
from chai_lab.data.parsing.input_validation import (
constituents_of_modified_fasta,
identify_potential_entity_types,
Expand Down Expand Up @@ -118,6 +119,8 @@ def raw_inputs_to_entitites_data(
for r in parsed_sequence
]
residues = get_polymer_residues(expanded_sequence, entity_type)
case EntityType.MANUAL_GLYCAN:
residues = glycan_string_residues(glycan_string=input.sequence)
case _:
raise NotImplementedError
assert residues is not None
Expand Down Expand Up @@ -145,6 +148,7 @@ def raw_inputs_to_entitites_data(
method="none",
entity_type=entity_type,
subchain_id=_synth_subchain_id(i),
original_record=input.sequence,
)
)

Expand Down Expand Up @@ -232,6 +236,8 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
entity_type = EntityType.RNA
case "dna":
entity_type = EntityType.DNA
case "glycan":
entity_type = EntityType.MANUAL_GLYCAN
case _:
raise ValueError(f"{entity_str} is not a valid entity type")

Expand Down
12 changes: 12 additions & 0 deletions chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.structure.bond_utils import (
get_atom_covalent_bond_pairs_from_glycan_string,
)
from chai_lab.data.dataset.structure.utils import (
backbone_atoms_all_present,
backbone_atoms_indices,
Expand Down Expand Up @@ -510,6 +513,15 @@ def _tokenize_entity(
dtype=torch.bool,
),
symmetries=tokens.symmetries,
atom_covalent_bond_indices=get_atom_covalent_bond_pairs_from_glycan_string(
glycan_string=(
entity_data.original_record
if entity_data.entity_type == EntityType.MANUAL_GLYCAN
else ""
),
token_residue_index=tokens.residue_index,
atom_ref_name=tokens.atom_names,
),
)

def _get_ref_conformer_data(self, residue: Residue) -> ConformerData:
Expand Down
47 changes: 47 additions & 0 deletions chai_lab/data/dataset/structure/all_atom_structure_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class AllAtomStructureContext:
is_distillation: Bool[Tensor, "1"]
# symmetric atom swap indices
symmetries: Int[Tensor, "n_atoms n_symmetries"]
# atom-wise bond feature; corresponding lists of atoms that are covalently bound
atom_covalent_bond_indices: tuple[Int[Tensor, "n_bonds"], Int[Tensor, "n_bonds"]]

def __post_init__(self):
# Resolved residues filter should eliminate PDBs with missing residues, but that
Expand All @@ -82,10 +84,29 @@ def __post_init__(self):
pdb_id = tensorcode_to_string(self.pdb_id[0])
logger.error(f"Incompatible masks for {pdb_id}")

# Check that bonds are specified in atom space
assert torch.all(self.atom_covalent_bond_indices[0] < self.num_atoms)
assert torch.all(self.atom_covalent_bond_indices[1] < self.num_atoms)

@cached_property
def residue_names(self) -> list[str]:
    """Decoded per-token residue names; computed once and cached."""
    encoded_names = self.token_residue_name
    return batch_tensorcode_to_string(encoded_names)

def report_bonds(self) -> None:
    """Log each covalent bond's endpoints as (asym_id, residue_index, residue_name) pairs.

    Intended as a human-readable sanity check of ``atom_covalent_bond_indices``
    before inference; emits one log line per bond.
    """
    for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)):
        # Map each bonded atom to its token, then to chain/residue metadata.
        tok_a = self.atom_token_index[atom_a]
        tok_b = self.atom_token_index[atom_b]
        asym_a = self.token_asym_id[tok_a]
        asym_b = self.token_asym_id[tok_b]
        res_idx_a = self.token_residue_index[tok_a]
        res_idx_b = self.token_residue_index[tok_b]
        resname_a = tensorcode_to_string(self.token_residue_name[tok_a])
        resname_b = tensorcode_to_string(self.token_residue_name[tok_b])
        # Use the module-level logger (as the rest of this file does, e.g.
        # logger.error in __post_init__) instead of the root logger via
        # logging.info, so module-level logging configuration applies.
        logger.info(
            f"Bond {i} (asym res_idx resname): {asym_a} {res_idx_a} {resname_a} <> {asym_b} {res_idx_b} {resname_b}"
        )

def pad(
self,
n_tokens: int,
Expand Down Expand Up @@ -142,6 +163,7 @@ def pad(
resolution=self.resolution,
is_distillation=self.is_distillation,
symmetries=pad_atoms_func(self.symmetries, pad_value=-1),
atom_covalent_bond_indices=self.atom_covalent_bond_indices,
)

@typecheck
Expand Down Expand Up @@ -177,6 +199,30 @@ def merge(
n_tokens = sum(x.num_tokens for x in contexts)
token_index = torch.arange(n_tokens, dtype=torch.int)

# Merge and offset bond indices, which are indexed by *atom*: each context's
# bond endpoints are shifted by that context's atom offset so that they index
# into the merged atom array.
atom_covalent_bond_indices_manual_a = []
atom_covalent_bond_indices_manual_b = []
for ctx, count in zip(contexts, atom_offsets):
    # Defensive skip; the field is declared as a tuple of tensors, so None is
    # not expected here — NOTE(review): confirm whether this check is reachable.
    if ctx.atom_covalent_bond_indices is None:
        continue
    a, b = ctx.atom_covalent_bond_indices
    atom_covalent_bond_indices_manual_a.append(a + count)
    atom_covalent_bond_indices_manual_b.append(b + count)
# The two endpoint lists must stay in lockstep (one entry per context).
assert len(atom_covalent_bond_indices_manual_a) == len(
    atom_covalent_bond_indices_manual_b
)
atom_covalent_bond_indices = (
    (
        torch.concatenate(atom_covalent_bond_indices_manual_a),
        torch.concatenate(atom_covalent_bond_indices_manual_b),
    )
    if atom_covalent_bond_indices_manual_a
    # No context contributed bonds: use empty index tensors.
    else (
        torch.zeros(0, dtype=torch.long),
        torch.zeros(0, dtype=torch.long),
    )
)

# re-index the reference space from 0..n_tokens-1.
zero_indexed_ref_uids = [
torch.unique_consecutive(x.atom_ref_space_uid, return_inverse=True)[1]
Expand Down Expand Up @@ -255,6 +301,7 @@ def merge(
torch.stack([x.is_distillation for x in contexts]), 0
).values,
symmetries=symmetries,
atom_covalent_bond_indices=atom_covalent_bond_indices,
)

def to(self, device: torch.device | str) -> "AllAtomStructureContext":
Expand Down
Loading
Loading