diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 4cdf202..672462b 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -296,8 +296,7 @@ def sorted(self) -> "StructureCandidates": ) -@torch.no_grad() -def run_inference( +def make_all_atom_feature_context( fasta_file: Path, *, output_dir: Path, @@ -306,18 +305,8 @@ def run_inference( msa_server_url: str = "https://api.colabfold.com", msa_directory: Path | None = None, constraint_path: Path | None = None, - # expose some params for easy tweaking - num_trunk_recycles: int = 3, - num_diffn_timesteps: int = 200, - seed: int | None = None, - device: str | None = None, - low_memory: bool = True, -) -> StructureCandidates: - if output_dir.exists(): - assert not any( - output_dir.iterdir() - ), f"Output directory {output_dir} is not empty." - torch_device = torch.device(device if device is not None else "cuda:0") + esm_device: torch.device = torch.device("cpu"), +): assert not ( use_msa_server and msa_directory ), "Cannot specify both MSA server and directory" @@ -385,7 +374,7 @@ def run_inference( # Load ESM embeddings if use_esm_embeddings: - embedding_context = get_esm_embedding_context(chains, device=torch_device) + embedding_context = get_esm_embedding_context(chains, device=esm_device) else: embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens) @@ -420,6 +409,9 @@ def run_inference( else: restraint_context = RestraintContext.empty() + # Handles leaving atoms for glycan bonds in-place + merged_context.drop_glycan_leaving_atoms_inplace() + # Build final feature context feature_context = AllAtomFeatureContext( chains=chains, @@ -430,6 +422,43 @@ def run_inference( embedding_context=embedding_context, restraint_context=restraint_context, ) + return feature_context + + +@torch.no_grad() +def run_inference( + fasta_file: Path, + *, + output_dir: Path, + use_esm_embeddings: bool = True, + use_msa_server: bool = False, + msa_server_url: str = "https://api.colabfold.com", + msa_directory: Path | None 
= None, + constraint_path: Path | None = None, + # expose some params for easy tweaking + num_trunk_recycles: int = 3, + num_diffn_timesteps: int = 200, + seed: int | None = None, + device: str | None = None, + low_memory: bool = True, +) -> StructureCandidates: + if output_dir.exists(): + assert not any( + output_dir.iterdir() + ), f"Output directory {output_dir} is not empty." + + torch_device = torch.device(device if device is not None else "cuda:0") + + feature_context = make_all_atom_feature_context( + fasta_file=fasta_file, + output_dir=output_dir, + use_esm_embeddings=use_esm_embeddings, + use_msa_server=use_msa_server, + msa_server_url=msa_server_url, + msa_directory=msa_directory, + constraint_path=constraint_path, + esm_device=torch_device, + ) return run_folding_on_context( feature_context, diff --git a/chai_lab/data/dataset/structure/all_atom_structure_context.py b/chai_lab/data/dataset/structure/all_atom_structure_context.py index ab02276..1a2d432 100644 --- a/chai_lab/data/dataset/structure/all_atom_structure_context.py +++ b/chai_lab/data/dataset/structure/all_atom_structure_context.py @@ -9,8 +9,10 @@ import torch from torch import Tensor +from chai_lab.data.parsing.structure.entity_type import EntityType from chai_lab.utils.tensor_utils import ( batch_tensorcode_to_string, + cdist, tensorcode_to_string, ) from chai_lab.utils.typing import Bool, Float, Int, UInt8, typecheck @@ -107,6 +109,95 @@ def report_bonds(self) -> None: f"Bond {i} (asym res_idx resname): {asym_a} {res_idx_a} {resname_a} <> {asym_b} {res_idx_b} {resname_b}" ) + @typecheck + def _infer_CO_bonds_within_glycan( + self, + atom_idx: int, + allowed_elements: list[int] | None = None, + ) -> Bool[Tensor, "{self.num_atoms}"]: + """Return mask for atoms that atom_idx might bond to based on distances. 
+ + Always returns an all-False mask when atom_idx does not belong to a MANUAL_GLYCAN entity + """ + tok = self.atom_token_index[atom_idx] + res = self.token_residue_index[tok] + asym = self.token_asym_id[tok] + + if self.token_entity_type[tok].item() != EntityType.MANUAL_GLYCAN.value: + return torch.zeros(self.num_atoms, dtype=torch.bool) + + mask = ( + (self.atom_residue_index == res) + & (self.atom_asym_id == asym) + & self.atom_exists_mask + ) + + # This field contains reference conformers for each residue + # Pairwise distances are therefore valid within each residue + distances = cdist(self.atom_gt_coords) + assert distances.shape == (self.num_atoms, self.num_atoms) + distances[torch.arange(self.num_atoms), torch.arange(self.num_atoms)] = ( + torch.inf + ) + + is_allowed_element = ( + torch.isin( + self.atom_ref_element, test_elements=torch.tensor(allowed_elements) + ) + if allowed_elements is not None + else torch.ones_like(mask) + ) + # Canonical bond length for C-O is 1.43 angstroms; add a bit of headroom + bond_candidates = (distances[atom_idx] < 1.5) & mask & is_allowed_element + return bond_candidates + + def drop_glycan_leaving_atoms_inplace(self) -> None: + """Drop OH groups that leave upon bond formation by setting atom_exists_mask.""" + # For each of the bonds, identify the atoms within bond radius and guess which are leaving + oxygen = 8 + for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)): + # Find the C-O bonds + [bond_candidates_b] = torch.where( + self._infer_CO_bonds_within_glycan( + atom_b.item(), allowed_elements=[oxygen] + ) + ) + # Filter to bonds that link to terminal atoms + # NOTE do not specify element here + bonds_b = [ + candidate + for candidate in bond_candidates_b.tolist() + if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1) + ] + # If there are multiple such bonds, we can't infer which to drop + if len(bonds_b) == 1: + [b_bond] = bonds_b + self.atom_exists_mask[b_bond] = False + logger.info( + f"Bond {i} 
right: Dropping latter atom in bond {self.atom_residue_index[atom_b]} {self.atom_ref_name[atom_b]} -> {self.atom_residue_index[b_bond]} {self.atom_ref_name[b_bond]}" + ) + continue # Only identify one leaving atom per bond + + # Repeat the above for atom_a if we didn't find anything for atom B + [bond_candidates_a] = torch.where( + self._infer_CO_bonds_within_glycan( + atom_a.item(), allowed_elements=[oxygen] + ) + ) + # Filter to bonds that link to terminal atoms + bonds_a = [ + candidate + for candidate in bond_candidates_a.tolist() + if (self._infer_CO_bonds_within_glycan(candidate).sum() == 1) + ] + # If there are multiple such bonds, we can't infer which to drop + if len(bonds_a) == 1: + [a_bond] = bonds_a + self.atom_exists_mask[a_bond] = False + logger.info( + f"Bond {i} left: Dropping latter atom in bond {self.atom_residue_index[atom_a]} {self.atom_ref_element[atom_a]} -> {self.atom_residue_index[a_bond]} {self.atom_ref_element[a_bond]}" + ) + def pad( self, n_tokens: int, @@ -321,6 +412,14 @@ def num_atoms(self) -> int: (n_atoms,) = self.atom_token_index.shape return n_atoms + @property + def atom_residue_index(self) -> Int[Tensor, "n_atoms"]: + return self.token_residue_index[self.atom_token_index] + + @property + def atom_asym_id(self) -> Int[Tensor, "n_atoms"]: + return self.token_asym_id[self.atom_token_index] + def to_dict(self) -> dict[str, torch.Tensor]: return asdict(self) diff --git a/chai_lab/data/parsing/glycans.py b/chai_lab/data/parsing/glycans.py index f29fd36..d3e0a8f 100644 --- a/chai_lab/data/parsing/glycans.py +++ b/chai_lab/data/parsing/glycans.py @@ -33,10 +33,12 @@ def __post_init__(self): @property def src_atom_name(self) -> str: - return f"C{self.src_atom}" + """Links between sugars are O-glycosidic bonds; we use src O dst C.""" + return f"O{self.src_atom}" @property def dst_atom_name(self) -> str: + """Links between sugars are O-glycosidic bonds; we use src O dst C.""" return f"C{self.dst_atom}" diff --git 
a/examples/glycosylation/1ac5.fasta b/examples/glycosylation/1ac5.fasta index 3e952e0..1e3087f 100644 --- a/examples/glycosylation/1ac5.fasta +++ b/examples/glycosylation/1ac5.fasta @@ -1,6 +1,6 @@ >protein|1AC5 LPSSEEYKVAYELLPGLSEVPDPSNIPQMHAGHIPLRSEDADEQDSSDLEYFFWKFTNNDSNGNVDRPLIIWLNGGPGCSSMDGALVESGPFRVNSDGKLYLNEGSWISKGDLLFIDQPTGTGFSVEQNKDEGKIDKNKFDEDLEDVTKHFMDFLENYFKIFPEDLTRKIILSGESYAGQYIPFFANAILNHNKFSKIDGDTYDLKALLIGNGWIDPNTQSLSYLPFAMEKKLIDESNPNFKHLTNAHENCQNLINSASTDEAAHFSYQECENILNLLLSYTRESSQKGTADCLNMYNFNLKDSYPSCGMNWPKDISFVSKFFSTPGVIDSLHLDSDKIDHWKECTNSVGTKLSNPISKPSIHLLPGLLESGIEIVLFNGDKDLICNNKGVLDTIDNLKWGGIKGFSDDAVSFDWIHKSKSTDDSEEFSGYVKYDRNLTFVSVYNASHMVPFDKSLVSRGIVDIYSNDVMIIDNNGKNVMITT >glycan|two-sugar -NAG(1-4 NAG) +NAG(4-1 NAG) >glycan|one-sugar NAG \ No newline at end of file diff --git a/examples/glycosylation/README.md b/examples/glycosylation/README.md index f9ca907..3267c84 100644 --- a/examples/glycosylation/README.md +++ b/examples/glycosylation/README.md @@ -24,9 +24,9 @@ Now, a glycan is also covalently bound to a residue; to specify this, we include chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id |---|---|---|---|---|---|---|---|---|---| -A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1 +A|N436@N|B|@C1|covalent|1.0|0.0|0.0|protein-glycan|bond1 -Breaking this down, this specifies that the within chain A (the first entry in the fasta), the "N" residue at the 436-th position (1-indexed) as indicated by the "N436" prefix is bound, via its nitrogen "N" atom as indicated by the "@N" suffix, to the C4 atom in the first glycan ("@C4"). Ring numbering follows standard glycan ring number schemas. For other ligands, you will need check how the specific version of `rdkit` that we use in `chai-lab` (run `uv pip list | grep rdkit` for version) assigns atom names and use the same atom names to specify your bonds. 
In addition, note that the min and max distance fields are ignored, as is the confidence field. +Breaking this down, this specifies that within chain A (the first entry in the fasta), the "N" residue at the 436-th position (1-indexed) as indicated by the "N436" prefix is bound, via its nitrogen "N" atom as indicated by the "@N" suffix, to the C1 atom in the first glycan ("@C1"). Ring numbering follows standard glycan ring numbering schemes. For other ligands, you will need to check how the specific version of `rdkit` that we use in `chai-lab` (run `uv pip list | grep rdkit` for version) assigns atom names and use the same atom names to specify your bonds. In addition, note that the min and max distance fields are ignored, as is the confidence field. ### Multi-ring glycan @@ -37,30 +37,36 @@ Working through a more complex example, let's say we have a two-ring ligand such >protein|example-protein ...N... >glycan|example-dual-sugar -NAG(1-4 NAG) +NAG(4-1 NAG) ``` -This syntax specifies that the root of the glycan is the leading `NAG` ring. The parentheses indicate that we are attaching another molecule to the ring directly preceding the parentheses. The `1-4` syntax "draws" a bond between the C1 atom of the previous "root" `NAG` and the C4 atom of the subsequent `NAG` ring. To specify how this glycan ought to be connected to the protein, we again use the restraints file to specify a residue and atom to which the glycan is bound, and the carbon atom within the root glycan ring that is bound. +This syntax specifies that the root of the glycan is the leading `NAG` ring. The parentheses indicate that we are attaching another molecule to the ring directly preceding the parentheses. The `4-1` syntax "draws" a bond between the O4 atom of the previous "root" `NAG` and the C1 atom of the subsequent `NAG` ring. Note that this syntax, when read left-to-right, is "building out" the glycan from the root sugar outwards. 
+ +To specify how this glycan ought to be connected to the protein, we again use the restraints file to specify a residue and atom to which the glycan is bound, and the carbon atom within the root glycan ring that is bound. chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id |---|---|---|---|---|---|---|---|---|---| -A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1 +A|N436@N|B|@C1|covalent|1.0|0.0|0.0|protein-glycan|bond1 You can chain this syntax to create longer ligands: ``` >glycan|4-NAG-in-a-linear-chain -NAG(1-4 NAG(1-4 NAG(1-4 NAG))) +NAG(4-1 NAG(4-1 NAG(4-1 NAG))) ``` ...and to create branched ligands ``` >glycan|branched-glycan -NAG(1-4 NAG(1-4 NAG))(3-4 MAN) +NAG(4-1 NAG(4-1 BMA(3-1 MAN)(6-1 MAN))) ``` -This branched example has a root `NAG` ring with a branch with two more `NAG` rings and a branch with a single `MAN` ring. For additional examples, please refer to the examples tested in the `tests/test_glycans.py` test file. +This branched example has a root `NAG` ring followed by a `NAG` and a `BMA`, which then branches to two `MAN` rings. For additional examples of this syntax, please refer to the examples in `tests/test_glycans.py`. ### Example We have included an example of how glycans can be specified under `predict_glycosylated.py` in this directory, along with its corresponding `bonds.restraints` csv file. This example is based on the PDB structure [1AC5](https://www.rcsb.org/structure/1ac5). The predicted structrue (colored, glycans in purple and orange, protein in green) from this script should look like the following when aligned with the ground truth 1AC5 structure (gray): ![glycan example prediction](./output.png) + +### A note on leaving atoms + +One might notice that in the above example, we are specifying CCD codes for sugar rings and connecting them to each other and an amino acid residue via various bonds. 
A subtle point is that the reference conformers for these sugar rings include hydroxyl (OH) groups that leave when bonds are formed. Under the hood, Chai-1 tries to automatically find and remove these atoms (see `AllAtomStructureContext.drop_glycan_leaving_atoms_inplace` for implementation), but this logic only drops leaving hydroxyl groups within glycan sugar rings. For other, non-sugar covalently attached ligands, please specify a SMILES string without the leaving atoms. If this does not work for your use case, please open a GitHub issue. diff --git a/examples/glycosylation/bonds.restraints b/examples/glycosylation/bonds.restraints index d48b6ac..5cd8f5b 100644 --- a/examples/glycosylation/bonds.restraints +++ b/examples/glycosylation/bonds.restraints @@ -1,3 +1,3 @@ chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id -A,N437@N,B,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond1 -A,N445@N,C,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond2 \ No newline at end of file +A,N437@N,B,@C1,covalent,1.0,0.0,0.0,protein-glycan,bond1 +A,N445@N,C,@C1,covalent,1.0,0.0,0.0,protein-glycan,bond2 \ No newline at end of file diff --git a/examples/glycosylation/output.png b/examples/glycosylation/output.png index aa076df..a803d7e 100644 Binary files a/examples/glycosylation/output.png and b/examples/glycosylation/output.png differ diff --git a/tests/test_glycans.py b/tests/test_glycans.py index 4895a74..6bf3405 100644 --- a/tests/test_glycans.py +++ b/tests/test_glycans.py @@ -1,8 +1,13 @@ # Copyright (c) 2024 Chai Discovery, Inc. # Licensed under the Apache License, Version 2.0. # See the LICENSE file for details. 
+from collections import Counter +from pathlib import Path +from tempfile import TemporaryDirectory + import pytest +from chai_lab.chai1 import make_all_atom_feature_context from chai_lab.data.parsing.glycans import _glycan_string_to_sugars_and_bonds @@ -22,13 +27,20 @@ def test_complex_parsing(): assert bond1.src_sugar_index == 0 assert bond1.dst_sugar_index == 1 + assert bond1.src_atom == 6 + assert bond1.dst_atom == 1 assert bond2.src_sugar_index == 0 assert bond2.dst_sugar_index == 2 - assert bond3.src_sugar_index == 2 + assert bond2.src_atom == 4 + assert bond2.dst_atom == 1 assert bond3.src_sugar_index == 2 assert bond3.dst_sugar_index == 3 + assert bond3.src_atom == 6 + assert bond3.dst_atom == 1 assert bond4.src_sugar_index == 3 assert bond4.dst_sugar_index == 4 + assert bond4.src_atom == 6 + assert bond4.dst_atom == 1 def test_complex_parsing_2(): @@ -51,3 +63,46 @@ def test_complex_parsing_2(): for (expected_src, expected_dst), bond in zip(expected_bonds, bonds, strict=True): assert bond.src_sugar_index == expected_src assert bond.dst_sugar_index == expected_dst + + +def test_glycan_tokenization_with_bond(): + """Test that tokenization works, and that atoms are dropped as expected.""" + glycan = ">glycan|foo\nNAG(4-1 NAG)\n" + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + + fasta_file = tmp_path / "input.fasta" + fasta_file.write_text(glycan) + + output_dir = tmp_path / "out" + + feature_context = make_all_atom_feature_context( + fasta_file, + output_dir=output_dir, + use_esm_embeddings=False, # Just a test; no need + ) + + # Each NAG component is C8 H15 N O6 -> 8 + 1 + 6 = 15 heavy atoms + # The bond between them displaces one oxygen, leaving 2 * 15 - 1 = 29 atoms + assert feature_context.structure_context.atom_exists_mask.sum() == 29 + # All 30 atoms are still constructed; the leaving atom is only masked out, not removed + assert feature_context.structure_context.atom_exists_mask.numel() == 30 + elements = Counter( + 
feature_context.structure_context.atom_ref_element[ + feature_context.structure_context.atom_exists_mask + ].tolist() + ) + assert elements[6] == 16 # 6 = Carbon + assert elements[7] == 2 # 7 = Nitrogen + assert elements[8] == 11 # 8 = Oxygen + + # Single bond feature between O and C + left, right = feature_context.structure_context.atom_covalent_bond_indices + assert left.numel() == right.numel() == 1 + bond_elements = set( + [ + feature_context.structure_context.atom_ref_element[left].item(), + feature_context.structure_context.atom_ref_element[right].item(), + ] + ) + assert bond_elements == {8, 6}