Skip to content

Commit

Permalink
Change of input/output formats (#60)
Browse files Browse the repository at this point in the history
* minor: reformatting, dedup

* minor

* minor

* minor

* minor

* introduce a structure returned from folding methods

* return StructureCandidates from folding functions

* clarify inputs

* typing cleanup

* use kwarg-only interface, use check that names are unique

* use name=... in example + rewording

* raise Error on too long input, not return empty

* more generic format of fasta seq_ids

* change error class

* minor

* update doc as recommended

* typo

Co-authored-by: Jack Dent <[email protected]>

* Make error message more detailed

Co-authored-by: Jack Dent <[email protected]>

* typo

* reformat + __post_init__ check

---------

Co-authored-by: Jack Dent <[email protected]>
  • Loading branch information
arogozhnikov and jackdent authored Sep 20, 2024
1 parent edeb695 commit 128b691
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 99 deletions.
118 changes: 67 additions & 51 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
import math
from collections import Counter
from dataclasses import dataclass
from pathlib import Path

Expand Down Expand Up @@ -203,7 +204,7 @@ def raise_if_too_many_templates(n_actual_templates: int):
def raise_if_msa_too_deep(msa_depth: int):
    """Reject an MSA whose depth exceeds MAX_MSA_DEPTH.

    Raises:
        UnsupportedInputError: if ``msa_depth`` is greater than MAX_MSA_DEPTH.
    """
    if msa_depth > MAX_MSA_DEPTH:
        # Only the corrected wording is kept; the stale "MSA to deep" line
        # would otherwise be concatenated into a duplicated message.
        raise UnsupportedInputError(
            f"MSA too deep: {msa_depth} > {MAX_MSA_DEPTH}. "
            "Please limit the MSA depth."
        )

Expand All @@ -212,26 +213,61 @@ def raise_if_msa_too_deep(msa_depth: int):
# Inference logic


@typecheck
@dataclass(frozen=True)
class StructureCandidates:
    """Candidates generated by a model, with confidence predictions and ranking scores.

    The predicted structure is the candidate with the highest score.
    All per-candidate fields are expected to have one entry per candidate.
    """

    # locations of CIF files, one file per candidate
    cif_paths: list[Path]
    # scores for each of the candidates + info that was used for scoring
    ranking_data: list[SampleRanking]
    # iff MSA search was performed, we also save a plot as PDF
    msa_coverage_plot_path: Path | None

    # Predicted aligned error (PAE)
    pae: Float[Tensor, "candidate num_tokens num_tokens"]
    # Predicted distance error (PDE)
    pde: Float[Tensor, "candidate num_tokens num_tokens"]
    # Predicted local distance difference test (pLDDT)
    plddt: Float[Tensor, "candidate num_tokens"]

    def __post_init__(self):
        # Every per-candidate field must describe the same number of
        # candidates as there are CIF files.
        n_candidates = len(self.cif_paths)
        for field_name in ("ranking_data", "pae", "pde", "plddt"):
            n_entries = len(getattr(self, field_name))
            assert n_entries == n_candidates, (
                f"{field_name} has {n_entries} entries, "
                f"but there are {n_candidates} CIF paths"
            )


@torch.no_grad()
def run_inference(
fasta_file: Path,
*,
output_dir: Path,
use_esm_embeddings: bool = True,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 2,
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: torch.device | None = None,
) -> list[Path]:
) -> StructureCandidates:
# Prepare inputs
assert fasta_file.exists(), fasta_file
fasta_inputs = read_inputs(fasta_file, length_limit=None)

assert len(fasta_inputs) > 0, "No inputs found in fasta file"

for name, count in Counter([inp.entity_name for inp in fasta_inputs]).items():
if count > 1:
raise UnsupportedInputError(
f"{name=} used more than once in inputs. Each entity must have a unique name"
)

# Load structure context
chains = load_chains_from_raw(fasta_inputs)
contexts = [c.structure_context for c in chains]
merged_context = AllAtomStructureContext.merge(contexts)
merged_context = AllAtomStructureContext.merge(
[c.structure_context for c in chains]
)
n_actual_tokens = merged_context.num_tokens
raise_if_too_many_tokens(n_actual_tokens)

Expand Down Expand Up @@ -271,7 +307,7 @@ def run_inference(
constraint_context=constraint_context,
)

output_cif_paths, _, _, _ = run_folding_on_context(
return run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
Expand All @@ -280,45 +316,25 @@ def run_inference(
device=device,
)

return output_cif_paths


def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor:
return torch.linspace(min_bin, max_bin, 2 * no_bins + 1)[1::2]


@typecheck
@dataclass(frozen=True)
class ConfidenceScores:
    """Confidence tensors, each with a leading ``bs`` dimension."""

    # Predicted aligned error (PAE)
    pae: Float[Tensor, "bs num_tokens num_tokens"]

    # Predicted distance error (PDE)
    pde: Float[Tensor, "bs num_tokens num_tokens"]

    # Predicted local distance difference test (pLDDT)
    plddt: Float[Tensor, "bs num_tokens"]


@torch.no_grad()
def run_folding_on_context(
feature_context: AllAtomFeatureContext,
*,
output_dir: Path,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: torch.device | None = None,
) -> tuple[list[Path], ConfidenceScores, list[SampleRanking], Path]:
) -> StructureCandidates:
"""
Function for in-depth explorations.
User completely controls folding inputs.
Returns:
- list of Path corresponding to folding outputs
- ConfidenceScores object
- SampleRanking data
- Path to plot of MSA coverage
"""
# Set seed
if seed is not None:
Expand Down Expand Up @@ -611,12 +627,6 @@ def avg_per_token_1d(x):

plddt_scores = torch.stack([avg_per_token_1d(x) for x in plddt_scores_atom])

confidence_scores = ConfidenceScores(
pae=pae_scores,
pde=pde_scores,
plddt=plddt_scores,
)

##
## Write the outputs
##
Expand All @@ -628,13 +638,17 @@ def avg_per_token_1d(x):

# Plot coverage of tokens by MSA, save plot
output_dir.mkdir(parents=True, exist_ok=True)
msa_plot_path = plot_msa(
input_tokens=feature_context.structure_context.token_residue_type,
msa_tokens=feature_context.msa_context.tokens,
out_fname=output_dir / "msa_depth.pdf",
)

output_paths: list[Path] = []
if feature_context.msa_context.mask.any():
msa_plot_path = plot_msa(
input_tokens=feature_context.structure_context.token_residue_type,
msa_tokens=feature_context.msa_context.tokens,
out_fname=output_dir / "msa_depth.pdf",
)
else:
msa_plot_path = None

cif_paths: list[Path] = []
ranking_data: list[SampleRanking] = []

for idx in range(num_diffn_samples):
Expand Down Expand Up @@ -677,8 +691,8 @@ def avg_per_token_1d(x):
##

cif_out_path = output_dir.joinpath(f"pred.model_idx_{idx}.cif")

print(f"Writing output to {cif_out_path}")
aggregate_score = ranking_outputs.aggregate_score.item()
print(f"Score={aggregate_score:.3f}, writing output to {cif_out_path} ")

# use 0-100 scale for pLDDT in pdb outputs
scaled_plddt_scores_per_atom = 100 * plddt_scores_atom[idx : idx + 1]
Expand All @@ -693,15 +707,17 @@ def avg_per_token_1d(x):
for c in feature_context.chains
},
)
output_paths.append(cif_out_path)
cif_paths.append(cif_out_path)

scores_basename = f"scores.model_idx_{idx}.npz"
scores_out_path = output_dir / scores_basename
scores_out_path = output_dir.joinpath(f"scores.model_idx_{idx}.npz")

scores = get_scores(ranking_outputs)
np.savez(
scores_out_path,
**scores,
)
np.savez(scores_out_path, **get_scores(ranking_outputs))

return output_paths, confidence_scores, ranking_data, msa_plot_path
return StructureCandidates(
cif_paths=cif_paths,
ranking_data=ranking_data,
msa_coverage_plot_path=msa_plot_path,
pae=pae_scores,
pde=pde_scores,
plddt=plddt_scores,
)
27 changes: 21 additions & 6 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,14 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
total_length: int = 0
for desc, sequence in sequences:
logger.info(f"[fasta] [{fasta_file}] {desc} {len(sequence)}")
# get the type of the sequence
entity_str = desc.split("|")[0].strip().lower()
entity_name = desc.split("|")[1].strip().lower()
# examples of inputs
# 'protein|example-of-protein'
# 'protein|name=example-of-protein'
# 'protein|name=example-of-protein|use_esm=true' # example how it can be in the future

match entity_str:
entity_str, *desc_parts = desc.split("|")

match entity_str.lower().strip():
case "protein":
entity_type = EntityType.PROTEIN
case "ligand":
Expand All @@ -225,6 +228,19 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
case _:
raise ValueError(f"{entity_str} is not a valid entity type")

match desc_parts:
case []:
raise ValueError(f"label is not provided in {desc=}")
case [label_part]:
label_part = label_part.strip()
if "=" in label_part:
field_name, entity_name = label_part.split("=")
assert field_name == "name"
else:
entity_name = label_part
case _:
raise ValueError(f"Unsupported inputs: {desc=}")

possible_types = identify_potential_entity_types(sequence)
if len(possible_types) == 0:
logger.error(f"Provided {sequence=} is invalid")
Expand All @@ -238,9 +254,8 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
total_length += len(sequence)

if length_limit is not None and total_length > length_limit:
logger.warning(
raise ValueError(
f"[fasta] [{fasta_file}] too many chars ({total_length} > {length_limit}); skipping"
)
return []

return retval
4 changes: 1 addition & 3 deletions chai_lab/data/parsing/fasta.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import logging
from pathlib import Path
from typing import Iterable

from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.residue_constants import restype_1to3_with_x

logger = logging.getLogger(__name__)

Fasta = tuple[str, str]
Fastas = list[Fasta]


nucleic_acid_1_to_name: dict[tuple[str, EntityType], str] = {
Expand All @@ -23,7 +21,7 @@
}


def read_fasta(file_path: str | Path) -> Iterable[Fasta]:
def read_fasta(file_path: str | Path) -> list[Fasta]:
from Bio import SeqIO

fasta_sequences = SeqIO.parse(open(file_path), "fasta")
Expand Down
22 changes: 13 additions & 9 deletions chai_lab/ranking/clashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from einops import rearrange, reduce, repeat
from torch import Tensor

import chai_lab.ranking.utils as rutils
import chai_lab.ranking.utils as rank_utils
from chai_lab.utils.tensor_utils import cdist, und_self
from chai_lab.utils.typing import Bool, Float, Int, typecheck

Expand All @@ -22,8 +22,7 @@ class ClashScores:

total_clashes: Int[Tensor, "..."]
total_inter_chain_clashes: Int[Tensor, "..."]
chain_intra_clashes: Int[Tensor, "... n_chains"]
chain_chain_inter_clashes: Int[Tensor, "... n_chains n_chains"]
chain_chain_clashes: Int[Tensor, "... n_chains n_chains"]
has_inter_chain_clashes: Bool[Tensor, "..."]


Expand Down Expand Up @@ -62,7 +61,7 @@ def has_inter_chain_clashes(
"""
has_clashes = per_chain_pair_clashes >= max_clashes

atoms_per_chain = rutils.num_atoms_per_chain(
atoms_per_chain = rank_utils.num_atoms_per_chain(
atom_mask=atom_mask,
asym_id=atom_asym_id,
)
Expand All @@ -80,7 +79,7 @@ def has_inter_chain_clashes(
).ge(max_clash_ratio)

# only consider clashes between pairs of polymer chains
polymer_chains = rutils.chain_is_polymer(
polymer_chains = rank_utils.chain_is_polymer(
asym_id=atom_asym_id,
mask=atom_mask,
entity_type=atom_entity_type,
Expand Down Expand Up @@ -129,8 +128,14 @@ def get_scores(
)
# i, j enumerate chains
total_clashes = reduce(clashes_chain_chain, "... i j -> ...", "sum") // 2
# NB: diagonal term (self-interaction of chain), contains doubled self-interaction
per_chain_intra_clashes = torch.einsum("... i i -> ... i", clashes_chain_chain) // 2

# NB: self-interaction of chain contains doubled self-interaction,
# we compensate for this.
clashes_chain_chain = clashes_chain_chain // (
1 + torch.diag(clashes_a_a.new_ones(n_chains))
)
# in case anyone needs
# per_chain_intra_clashes = torch.einsum("... i i -> ... i", clashes_chain_chain)
# delete self-interaction for simplicity
non_diag = 1 - torch.diag(clashes_a_a.new_ones(n_chains))
inter_chain_chain = non_diag * clashes_chain_chain
Expand All @@ -142,8 +147,7 @@ def get_scores(
return ClashScores(
total_clashes=total_clashes,
total_inter_chain_clashes=inter_chain_clashes,
chain_intra_clashes=per_chain_intra_clashes,
chain_chain_inter_clashes=inter_chain_chain,
chain_chain_clashes=clashes_chain_chain,
has_inter_chain_clashes=has_inter_chain_clashes(
atom_mask=atom_mask,
atom_asym_id=atom_asym_id,
Expand Down
6 changes: 3 additions & 3 deletions chai_lab/ranking/plddt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from einops import repeat
from torch import Tensor

import chai_lab.ranking.utils as rutils
import chai_lab.ranking.utils as rank_utils
from chai_lab.utils.tensor_utils import masked_mean
from chai_lab.utils.typing import Bool, Float, Int, typecheck

Expand All @@ -29,7 +29,7 @@ def plddt(
bin_centers: Float[Tensor, "bins"],
per_residue: bool = False,
) -> Float[Tensor, "..."] | Float[Tensor, "... a"]:
expectations = rutils.expectation(logits, bin_centers)
expectations = rank_utils.expectation(logits, bin_centers)
if per_residue:
return expectations
else:
Expand All @@ -43,7 +43,7 @@ def per_chain_plddt(
asym_id: Int[Tensor, "... a"],
bin_centers: Float[Tensor, "bins"],
) -> Float[Tensor, "... c"]:
chain_masks, _ = rutils.get_chain_masks_and_asyms(asym_id, atom_mask)
chain_masks, _ = rank_utils.get_chain_masks_and_asyms(asym_id, atom_mask)
logits = repeat(logits, "... a b -> ... c a b", c=chain_masks.shape[-2])
return plddt(logits, chain_masks, bin_centers, per_residue=False)

Expand Down
Loading

0 comments on commit 128b691

Please sign in to comment.