diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index e0f9b13..a003b22 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -1,5 +1,6 @@ # %% import math +from collections import Counter from dataclasses import dataclass from pathlib import Path @@ -203,7 +204,7 @@ def raise_if_too_many_templates(n_actual_templates: int): def raise_if_msa_too_deep(msa_depth: int): if msa_depth > MAX_MSA_DEPTH: raise UnsupportedInputError( - f"MSA to deep: {msa_depth} > {MAX_MSA_DEPTH}. " + f"MSA too deep: {msa_depth} > {MAX_MSA_DEPTH}. " "Please limit the MSA depth." ) @@ -212,26 +213,61 @@ def raise_if_msa_too_deep(msa_depth: int): # Inference logic +@typecheck +@dataclass(frozen=True) +class StructureCandidates: + # We provide candidates generated by a model, + # with confidence predictions and ranking scores. + # Predicted structure is a candidate with the highest score. + + # locations of CIF files, one file per candidate + cif_paths: list[Path] + # scores for each of candidates + info that was used for scoring. + ranking_data: list[SampleRanking] + # iff MSA search was performed, we also save a plot as PDF + msa_coverage_plot_path: Path | None + + # Predicted aligned error(PAE) + pae: Float[Tensor, "candidate num_tokens num_tokens"] + # Predicted distance error (PDE) + pde: Float[Tensor, "candidate num_tokens num_tokens"] + # Predicted local distance difference test (pLDDT) + plddt: Float[Tensor, "candidate num_tokens"] + + def __post_init__(self): + assert len(self.cif_paths) == len(self.ranking_data) + assert len(self.cif_paths) == len(self.pae) + + @torch.no_grad() def run_inference( fasta_file: Path, + *, output_dir: Path, use_esm_embeddings: bool = True, # expose some params for easy tweaking num_trunk_recycles: int = 3, - num_diffn_timesteps: int = 2, + num_diffn_timesteps: int = 200, seed: int | None = None, device: torch.device | None = None, -) -> list[Path]: +) -> StructureCandidates: # Prepare inputs assert fasta_file.exists(), fasta_file fasta_inputs = read_inputs(fasta_file, length_limit=None) + assert len(fasta_inputs) > 0, "No inputs found in fasta file" + for name, count in Counter([inp.entity_name for inp in fasta_inputs]).items(): + if count > 1: + raise UnsupportedInputError( + f"{name=} used more than once in inputs. Each entity must have a unique name" + ) + # Load structure context chains = load_chains_from_raw(fasta_inputs) - contexts = [c.structure_context for c in chains] - merged_context = AllAtomStructureContext.merge(contexts) + merged_context = AllAtomStructureContext.merge( + [c.structure_context for c in chains] + ) n_actual_tokens = merged_context.num_tokens raise_if_too_many_tokens(n_actual_tokens) @@ -271,7 +307,7 @@ def run_inference( constraint_context=constraint_context, ) - output_cif_paths, _, _, _ = run_folding_on_context( + return run_folding_on_context( feature_context, output_dir=output_dir, num_trunk_recycles=num_trunk_recycles, @@ -280,45 +316,25 @@ def run_inference( device=device, ) - return output_cif_paths - def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor: return torch.linspace(min_bin, max_bin, 2 * no_bins + 1)[1::2] -@typecheck -@dataclass(frozen=True) -class ConfidenceScores: - # Predicted aligned error(PAE) - pae: Float[Tensor, "bs num_tokens num_tokens"] - - # Predicted distance error (PDE) - pde: Float[Tensor, "bs num_tokens num_tokens"] - - # Predicted local distance difference test (pLDDT) - plddt: Float[Tensor, "bs num_tokens"] - - @torch.no_grad() def run_folding_on_context( feature_context: AllAtomFeatureContext, + *, output_dir: Path, # expose some params for easy tweaking num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, seed: int | None = None, device: torch.device | None = None, -) -> tuple[list[Path], ConfidenceScores, list[SampleRanking], Path]: +) -> StructureCandidates: """ Function for in-depth explorations. User completely controls folding inputs. - - Returns: - - list of Path corresponding to folding outputs - - ConfidenceScores object - - SampleRanking data - - Path to plot of MSA coverage """ # Set seed if seed is not None: @@ -611,12 +627,6 @@ def avg_per_token_1d(x): plddt_scores = torch.stack([avg_per_token_1d(x) for x in plddt_scores_atom]) - confidence_scores = ConfidenceScores( - pae=pae_scores, - pde=pde_scores, - plddt=plddt_scores, - ) - ## ## Write the outputs ## @@ -628,13 +638,17 @@ def avg_per_token_1d(x): # Plot coverage of tokens by MSA, save plot output_dir.mkdir(parents=True, exist_ok=True) - msa_plot_path = plot_msa( - input_tokens=feature_context.structure_context.token_residue_type, - msa_tokens=feature_context.msa_context.tokens, - out_fname=output_dir / "msa_depth.pdf", - ) - output_paths: list[Path] = [] + if feature_context.msa_context.mask.any(): + msa_plot_path = plot_msa( + input_tokens=feature_context.structure_context.token_residue_type, + msa_tokens=feature_context.msa_context.tokens, + out_fname=output_dir / "msa_depth.pdf", + ) + else: + msa_plot_path = None + + cif_paths: list[Path] = [] ranking_data: list[SampleRanking] = [] for idx in range(num_diffn_samples): @@ -677,8 +691,8 @@ def avg_per_token_1d(x): ## cif_out_path = output_dir.joinpath(f"pred.model_idx_{idx}.cif") - - print(f"Writing output to {cif_out_path}") + aggregate_score = ranking_outputs.aggregate_score.item() + print(f"Score={aggregate_score:.3f}, writing output to {cif_out_path} ") # use 0-100 scale for pLDDT in pdb outputs scaled_plddt_scores_per_atom = 100 * plddt_scores_atom[idx : idx + 1] @@ -693,15 +707,17 @@ def avg_per_token_1d(x): for c in feature_context.chains }, ) - output_paths.append(cif_out_path) + cif_paths.append(cif_out_path) - scores_basename = f"scores.model_idx_{idx}.npz" - scores_out_path = output_dir / scores_basename + scores_out_path = output_dir.joinpath(f"scores.model_idx_{idx}.npz") - scores = get_scores(ranking_outputs) - np.savez( - scores_out_path, - **scores, - ) + np.savez(scores_out_path, **get_scores(ranking_outputs)) - return output_paths, confidence_scores, ranking_data, msa_plot_path + return StructureCandidates( + cif_paths=cif_paths, + ranking_data=ranking_data, + msa_coverage_plot_path=msa_plot_path, + pae=pae_scores, + pde=pde_scores, + plddt=plddt_scores, + ) diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py index e2189ef..01554cf 100644 --- a/chai_lab/data/dataset/inference_dataset.py +++ b/chai_lab/data/dataset/inference_dataset.py @@ -209,11 +209,14 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list total_length: int = 0 for desc, sequence in sequences: logger.info(f"[fasta] [{fasta_file}] {desc} {len(sequence)}") - # get the type of the sequence - entity_str = desc.split("|")[0].strip().lower() - entity_name = desc.split("|")[1].strip().lower() + # examples of inputs + # 'protein|example-of-protein' + # 'protein|name=example-of-protein' + # 'protein|name=example-of-protein|use_esm=true' # example how it can be in the future - match entity_str: + entity_str, *desc_parts = desc.split("|") + + match entity_str.lower().strip(): case "protein": entity_type = EntityType.PROTEIN case "ligand": @@ -225,6 +228,19 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list case _: raise ValueError(f"{entity_str} is not a valid entity type") + match desc_parts: + case []: + raise ValueError(f"label is not provided in {desc=}") + case [label_part]: + label_part = label_part.strip() + if "=" in label_part: + field_name, entity_name = label_part.split("=") + assert field_name == "name" + else: + entity_name = label_part + case _: + raise ValueError(f"Unsupported inputs: {desc=}") + possible_types = identify_potential_entity_types(sequence) if len(possible_types) == 0: logger.error(f"Provided {sequence=} is invalid") @@ -238,9 +254,8 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list total_length += len(sequence) if length_limit is not None and total_length > length_limit: - logger.warning( + raise ValueError( f"[fasta] [{fasta_file}] too many chars ({total_length} > {length_limit}); skipping" ) - return [] return retval diff --git a/chai_lab/data/parsing/fasta.py b/chai_lab/data/parsing/fasta.py index b15631e..6e5867a 100644 --- a/chai_lab/data/parsing/fasta.py +++ b/chai_lab/data/parsing/fasta.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import Iterable from chai_lab.data.parsing.structure.entity_type import EntityType from chai_lab.data.residue_constants import restype_1to3_with_x @@ -8,7 +7,6 @@ logger = logging.getLogger(__name__) Fasta = tuple[str, str] -Fastas = list[Fasta] nucleic_acid_1_to_name: dict[tuple[str, EntityType], str] = { @@ -23,7 +21,7 @@ } -def read_fasta(file_path: str | Path) -> Iterable[Fasta]: +def read_fasta(file_path: str | Path) -> list[Fasta]: from Bio import SeqIO fasta_sequences = SeqIO.parse(open(file_path), "fasta") diff --git a/chai_lab/ranking/clashes.py b/chai_lab/ranking/clashes.py index aacf222..2fb9982 100644 --- a/chai_lab/ranking/clashes.py +++ b/chai_lab/ranking/clashes.py @@ -4,7 +4,7 @@ from einops import rearrange, reduce, repeat from torch import Tensor -import chai_lab.ranking.utils as rutils +import chai_lab.ranking.utils as rank_utils from chai_lab.utils.tensor_utils import cdist, und_self from chai_lab.utils.typing import Bool, Float, Int, typecheck @@ -22,8 +22,7 @@ class ClashScores: total_clashes: Int[Tensor, "..."] total_inter_chain_clashes: Int[Tensor, "..."] - chain_intra_clashes: Int[Tensor, "... n_chains"] - chain_chain_inter_clashes: Int[Tensor, "... n_chains n_chains"] + chain_chain_clashes: Int[Tensor, "... n_chains n_chains"] has_inter_chain_clashes: Bool[Tensor, "..."] @@ -62,7 +61,7 @@ def has_inter_chain_clashes( """ has_clashes = per_chain_pair_clashes >= max_clashes - atoms_per_chain = rutils.num_atoms_per_chain( + atoms_per_chain = rank_utils.num_atoms_per_chain( atom_mask=atom_mask, asym_id=atom_asym_id, ) @@ -80,7 +79,7 @@ def has_inter_chain_clashes( ).ge(max_clash_ratio) # only consider clashes between pairs of polymer chains - polymer_chains = rutils.chain_is_polymer( + polymer_chains = rank_utils.chain_is_polymer( asym_id=atom_asym_id, mask=atom_mask, entity_type=atom_entity_type, @@ -129,8 +128,14 @@ def get_scores( ) # i, j enumerate chains total_clashes = reduce(clashes_chain_chain, "... i j -> ...", "sum") // 2 - # NB: diagonal term (self-interaction of chain), contains doubled self-interaction - per_chain_intra_clashes = torch.einsum("... i i -> ... i", clashes_chain_chain) // 2 + + # NB: self-interaction of chain contains doubled self-interaction, + # we compensate for this. + clashes_chain_chain = clashes_chain_chain // ( + 1 + torch.diag(clashes_a_a.new_ones(n_chains)) + ) + # in case anyone needs + # per_chain_intra_clashes = torch.einsum("... i i -> ... i", clashes_chain_chain) # delete self-interaction for simplicity non_diag = 1 - torch.diag(clashes_a_a.new_ones(n_chains)) inter_chain_chain = non_diag * clashes_chain_chain @@ -142,8 +147,7 @@ def get_scores( return ClashScores( total_clashes=total_clashes, total_inter_chain_clashes=inter_chain_clashes, - chain_intra_clashes=per_chain_intra_clashes, - chain_chain_inter_clashes=inter_chain_chain, + chain_chain_clashes=clashes_chain_chain, has_inter_chain_clashes=has_inter_chain_clashes( atom_mask=atom_mask, atom_asym_id=atom_asym_id, diff --git a/chai_lab/ranking/plddt.py b/chai_lab/ranking/plddt.py index 525c0dd..93108d4 100644 --- a/chai_lab/ranking/plddt.py +++ b/chai_lab/ranking/plddt.py @@ -3,7 +3,7 @@ from einops import repeat from torch import Tensor -import chai_lab.ranking.utils as rutils +import chai_lab.ranking.utils as rank_utils from chai_lab.utils.tensor_utils import masked_mean from chai_lab.utils.typing import Bool, Float, Int, typecheck @@ -29,7 +29,7 @@ def plddt( bin_centers: Float[Tensor, "bins"], per_residue: bool = False, ) -> Float[Tensor, "..."] | Float[Tensor, "... a"]: - expectations = rutils.expectation(logits, bin_centers) + expectations = rank_utils.expectation(logits, bin_centers) if per_residue: return expectations else: @@ -43,7 +43,7 @@ def per_chain_plddt( asym_id: Int[Tensor, "... a"], bin_centers: Float[Tensor, "bins"], ) -> Float[Tensor, "... c"]: - chain_masks, _ = rutils.get_chain_masks_and_asyms(asym_id, atom_mask) + chain_masks, _ = rank_utils.get_chain_masks_and_asyms(asym_id, atom_mask) logits = repeat(logits, "... a b -> ... c a b", c=chain_masks.shape[-2]) return plddt(logits, chain_masks, bin_centers, per_residue=False) diff --git a/chai_lab/ranking/rank.py b/chai_lab/ranking/rank.py index 6c7853e..4d0e21a 100644 --- a/chai_lab/ranking/rank.py +++ b/chai_lab/ranking/rank.py @@ -7,7 +7,7 @@ import chai_lab.ranking.clashes as clashes import chai_lab.ranking.plddt as plddt import chai_lab.ranking.ptm as ptm -import chai_lab.ranking.utils as rutils +import chai_lab.ranking.utils as rank_utils from chai_lab.utils.typing import Bool, Float, Int, typecheck @@ -15,16 +15,15 @@ @dataclass class SampleRanking: """Sample Ranking Data - asym ids: a tensor of shape (c,) containing the unique asym ids for - each chain in the sample. The asym ids are sorted numerically. - aggregate_score: a tensor of shape (...) containing the aggregate ranking - score for the sample + asym ids: tensor with unique asym ids for each chain in the sample. + The asym ids are sorted numerically, starting from 1. + aggregate_score: aggregate ranking score for the sample ptm_scores: see ptm.get_scores for a description of the ptm scores clash_scores: a dictionary of clash scores plddt_scores: see plddt.PLDDTScores for a description of the plddt scores """ - asym_ids: Int[Tensor, "c"] + asym_ids: Int[Tensor, "chain"] aggregate_score: Float[Tensor, "..."] ptm_scores: ptm.PTMScores clash_scores: clashes.ClashScores @@ -65,14 +64,12 @@ def rank( bin_centers=pae_bin_centers, token_asym_id=token_asym_id, ) + atom_asym_id = torch.gather(token_asym_id, dim=-1, index=atom_token_index.long()) + clash_scores = clashes.get_scores( atom_coords=atom_coords, atom_mask=atom_mask, - atom_asym_id=torch.gather( - token_asym_id, - dim=-1, - index=atom_token_index.long(), - ), + atom_asym_id=atom_asym_id, atom_entity_type=torch.gather( token_entity_type, dim=-1, @@ -87,11 +84,7 @@ def rank( lddt_logits=lddt_logits, atom_mask=atom_mask, bin_centers=lddt_bin_centers, - atom_asym_id=torch.gather( - token_asym_id, - dim=-1, - index=atom_token_index.long(), - ), + atom_asym_id=atom_asym_id, ) # aggregate score @@ -101,7 +94,7 @@ def rank( - 100 * clash_scores.has_inter_chain_clashes.float() ) - _, asyms = rutils.get_chain_masks_and_asyms( + _, asyms = rank_utils.get_chain_masks_and_asyms( asym_id=token_asym_id, mask=token_exists_mask, ) @@ -123,8 +116,6 @@ def get_scores(ranking_data: SampleRanking) -> dict[str, np.ndarray]: "per_chain_ptm": ranking_data.ptm_scores.per_chain_ptm, "per_chain_pair_iptm": ranking_data.ptm_scores.per_chain_pair_iptm, "has_inter_chain_clashes": ranking_data.clash_scores.has_inter_chain_clashes, - # TODO replace with just one tensor that contains both - "chain_intra_clashes": ranking_data.clash_scores.chain_intra_clashes, - "chain_chain_inter_clashes": ranking_data.clash_scores.chain_chain_inter_clashes, + "chain_chain_clashes": ranking_data.clash_scores.chain_chain_clashes, } return {k: v.cpu().numpy() for k, v in scores.items()} diff --git a/examples/predict_structure.py b/examples/predict_structure.py index a55ca39..02c797c 100644 --- a/examples/predict_structure.py +++ b/examples/predict_structure.py @@ -6,17 +6,20 @@ from chai_lab.chai1 import run_inference # We use fasta-like format for inputs. -# Every record may encode protein, ligand, RNA or DNA -# see example below +# - each entity encodes protein, ligand, RNA or DNA +# - each entity is labeled with unique name; +# - ligands are encoded with SMILES; modified residues encoded like AAA(SEP)AAA + +# Example given below, just modify it example_fasta = """ ->protein|example-of-long-protein +>protein|name=example-of-long-protein AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQTDRVSLRNLRGYYNQSEAGSHTLQWMFGCDLGPDGRLLRGYDQSAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAAREAEQRRAYLEGTCVEWLRRYLENGKETLQRAEHPKTHVTHHPVSDHEATLRCWALGFYPAEITLTWQWDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLTLRWEP ->protein|example-of-short-protein +>protein|name=example-of-short-protein AIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDM ->protein|example-of-peptide +>protein|name=example-peptide GAAL ->ligand|and-example-for-ligand-encoded-as-smiles +>ligand|name=example-ligand-as-smiles CCCCCCCCCCCCCC(=O)O """.strip() @@ -24,7 +27,8 @@ fasta_path.write_text(example_fasta) output_dir = Path("/tmp/outputs") -output_cif_paths = run_inference( + +candidates = run_inference( fasta_file=fasta_path, output_dir=output_dir, # 'default' setup @@ -35,5 +39,9 @@ use_esm_embeddings=True, ) +cif_paths = candidates.cif_paths +scores = [rd.aggregate_score for rd in candidates.ranking_data] + + # Load pTM, ipTM, pLDDTs and clash scores for sample 2 scores = np.load(output_dir.joinpath("scores.model_idx_2.npz"))