Skip to content

Commit

Permalink
Easy command line interface with MSA server flag (#190)
Browse files Browse the repository at this point in the history
Co-authored-by: Jack Dent <[email protected]>
  • Loading branch information
wukevin and jackdent authored Nov 29, 2024
1 parent c086906 commit d5e389f
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 38 deletions.
29 changes: 27 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@ This Python package requires Linux, and a GPU with CUDA and bfloat16 support. We

## Running the model

The model accepts inputs in the FASTA file format, and allows you to specify the number of trunk recycles and diffusion timesteps via the `chai_lab.chai1.run_inference` function. By default, the model generates five sample predictions, and uses embeddings without MSAs or templates.
### Command line inference

The following script demonstrates how to provide inputs to the model, and obtain a list of PDB files for downstream analysis:
You can fold a FASTA file containing all the sequences (including modified residues, nucleotides, and ligands as SMILES strings) in a complex of interest by calling:
```shell
chai fold input.fasta output_folder
```

By default, the model generates five sample predictions, and uses embeddings without MSAs or templates. For additional information about how to supply MSAs and restraints to the model, see the documentation below, or run `chai fold --help`.

### Programmatic inference

The main entrypoint into the Chai-1 folding code is through the `chai_lab.chai1.run_inference` function. The following script demonstrates how to programmatically provide inputs to the model, and obtain a list of PDB files for downstream analysis:

```shell
python examples/predict_structure.py
Expand Down Expand Up @@ -56,6 +65,8 @@ CHAI_DOWNLOADS_DIR=/tmp/downloads python ./examples/predict_structure.py

Chai-1 supports MSAs provided as an `aligned.pqt` file. This file format is similar to an `a3m` file, but has additional columns that provide metadata like the source database and sequence pairing keys. We provide code to convert `a3m` files to `aligned.pqt` files. For more information on how to provide MSAs to Chai-1, see [this documentation](examples/msas/README.md).

For user convenience, we also support automatic MSA generation via the ColabFold [MMseqs2](https://github.com/soedinglab/MMseqs2) server via the `--msa-server` flag. As detailed in the ColabFold [repository](https://github.com/sokrypton/ColabFold), please keep in mind that this is a shared resource. Note that the results reported in our preprint and the webserver use a different MSA search strategy than MMseqs2, though we expect results to be broadly similar.

</p>
</details>

Expand Down Expand Up @@ -121,6 +132,20 @@ If you find Chai-1 useful in your research or use any structures produced by the
}
```

You can also access this information by running `chai citation`.

Additionally, if you use the automatic MMseqs2 MSA generation described above, please also cite:

```
@article{mirdita2022colabfold,
title={ColabFold: making protein folding accessible to all},
author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin},
journal={Nature methods},
year={2022},
}
```


## Licence

Chai-1 is released under an Apache 2.0 License (both code and model weights), which means it can be used for both academic and commerical purposes, including for drug discovery.
Expand Down
39 changes: 33 additions & 6 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
Expand Down Expand Up @@ -83,6 +84,7 @@
)
from chai_lab.data.io.cif_utils import outputs_to_cif
from chai_lab.data.parsing.restraints import parse_pairwise_table
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
from chai_lab.model.utils import center_random_augmentation
from chai_lab.ranking.frames import get_frames_and_mask
Expand Down Expand Up @@ -268,14 +270,24 @@ def run_inference(
*,
output_dir: Path,
use_esm_embeddings: bool = True,
msa_server: bool = False,
msa_directory: Path | None = None,
constraint_path: Path | str | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: torch.device | None = None,
device: str | None = None,
) -> StructureCandidates:
if output_dir.exists():
assert not any(
output_dir.iterdir()
), f"Output directory {output_dir} is not empty."
torch_device = torch.device(device if device is not None else "cuda:0")
assert not (
msa_server and msa_directory
), "Cannot specify both MSA server and directory"

# Prepare inputs
assert fasta_file.exists(), fasta_file
fasta_inputs = read_inputs(fasta_file, length_limit=None)
Expand All @@ -290,14 +302,28 @@ def run_inference(

# Load structure context
chains = load_chains_from_raw(fasta_inputs)
del fasta_inputs # Do not reference inputs after creating chains from them

merged_context = AllAtomStructureContext.merge(
[c.structure_context for c in chains]
)
n_actual_tokens = merged_context.num_tokens
raise_if_too_many_tokens(n_actual_tokens)

# Load MSAs
if msa_directory is not None:
# Generated and/or load MSAs
if msa_server:
protein_sequences = [
chain.entity_data.sequence
for chain in chains
if chain.entity_data.entity_type == EntityType.PROTEIN
]
msa_dir = output_dir / "msas"
msa_dir.mkdir(parents=True, exist_ok=False)
generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir)
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_dir
)
elif msa_directory is not None:
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_directory
)
Expand All @@ -308,6 +334,7 @@ def run_inference(
msa_profile_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)

assert (
msa_context.num_tokens == merged_context.num_tokens
), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"
Expand All @@ -320,7 +347,7 @@ def run_inference(

# Load ESM embeddings
if use_esm_embeddings:
embedding_context = get_esm_embedding_context(chains, device=device)
embedding_context = get_esm_embedding_context(chains, device=torch_device)
else:
embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens)

Expand Down Expand Up @@ -351,7 +378,7 @@ def run_inference(
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
seed=seed,
device=device,
device=torch_device,
)


Expand Down
9 changes: 6 additions & 3 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _synth_subchain_id(idx: int) -> str:
def raw_inputs_to_entitites_data(
inputs: list[Input], identifier: str = "test"
) -> list[AllAtomEntityData]:
"""Load an entity for each raw input."""
entities = []

# track unique entities
Expand Down Expand Up @@ -157,7 +158,7 @@ def load_chains_from_raw(
tokenizer: AllAtomResidueTokenizer | None = None,
) -> list[Chain]:
"""
loads and tokenizes each input chain
Loads and tokenizes each input chain; skips over inputs that fail to tokenize.
"""

if tokenizer is None:
Expand Down Expand Up @@ -186,12 +187,14 @@ def load_chains_from_raw(
logger.exception(f"Failed to tokenize input {entity_data=} {sym_id=}")
tok = None
structure_contexts.append(tok)
assert len(structure_contexts) == len(entities)

# Join the untokenized entity data with the tokenized chain data, removing
# chains we failed to tokenize
chains = [
Chain(entity_data=entity_data, structure_context=structure_context)
for entity_data, structure_context in zip(entities, structure_contexts)
for entity_data, structure_context in zip(
entities, structure_contexts, strict=True
)
if structure_context is not None
]

Expand Down
4 changes: 3 additions & 1 deletion chai_lab/data/dataset/msas/colabfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
This implementation also relies on ColabFold's chain pairing algorithm
rather than using Chai-1's own algorithm, which could also lead to
differences in results.
Places .aligned.pqt files in msa_dir; does not save intermediate a3m files.
"""
assert msa_dir.is_dir(), "MSA directory must be a dir"
assert not any(msa_dir.iterdir()), "MSA directory must be empty"
Expand All @@ -366,7 +368,7 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
a3ms_dir.mkdir()

# Generate MSAs for each protein chain
print(f"Running MSA generation for {len(protein_seqs)} protein sequences")
logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences")
msas = _run_mmseqs2(
protein_seqs,
mmseqs_dir,
Expand Down
4 changes: 3 additions & 1 deletion chai_lab/data/dataset/msas/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def get_msa_contexts(
def get_msa_contexts_for_seq(seq) -> MSAContext:
path = msa_directory / expected_basename(seq)
if not path.is_file():
logger.warning(f"No MSA found for sequence: {seq}")
if seq != "X":
# Don't warn for the special "X" sequence
logger.warning(f"No MSA found for sequence: {seq}")
[tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
return MSAContext.create_single_seq(
MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)
Expand Down
42 changes: 42 additions & 0 deletions chai_lab/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

"""Command line interface."""

import logging

import typer

from chai_lab.chai1 import run_inference

CITATION = """
@article{Chai-1-Technical-Report,
title = {Chai-1: Decoding the molecular interactions of life},
author = {{Chai Discovery}},
year = 2024,
journal = {bioRxiv},
publisher = {Cold Spring Harbor Laboratory},
doi = {10.1101/2024.10.10.615955},
url = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955},
elocation-id = {2024.10.10.615955},
eprint = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955.full.pdf}
}
""".strip()


def citation():
"""Print citation information"""
typer.echo(CITATION)


def cli():
app = typer.Typer()
app.command("fold", help="Run Chai-1 to fold a complex.")(run_inference)
app.command("citation", help="Print citation information")(citation)
app()


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()
21 changes: 18 additions & 3 deletions examples/msas/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

While Chai-1 performs very well in "single-sequence mode," it can also be given additional evolutionary information to further improve performance. As in other folding methods, this evolutionary information is provided in the form of a multiple sequence alignment (MSA). This information is given in the form of a `MSAContext` object (see `chai_lab/data/dataset/msas/msa_context.py`); we provide code for building these `MSAContext` objects through `aligned.pqt` files, though you can play with building out an `MSAContext` yourself as well.

Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. We provide an example of how to generate MSAs using [ColabFold](https://github.com/sokrypton/ColabFold) in `examples/msas/predict_with_msas.py`. Performance will vary depending on the input MSA databases and search algorithms used.

## The `.aligned.pqt` file format

The easiest way to provide MSA information to Chai-1 is through the `.aligned.pqt` file format that we have defined. This file can be thought of as an augmented `a3m` file, and is essentially a dataframe saved in parquet format with the following four (required) columns:
Expand Down Expand Up @@ -58,4 +56,21 @@ import pandas as pd

aligned_pqt = pd.read_parquet("examples/msas/703adc2c74b8d7e613549b6efcf37126da7963522dc33852ad3c691eef1da06f.aligned.pqt")
aligned_pqt.head()
```
```


## Additional MSA generation strategies

Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. In this vein, we provide support for automatic MSA generation via the [ColabFold](https://github.com/sokrypton/ColabFold) server using `chai fold input.fasta output_directory --msa-server` or by invoking `run_inference` as follows:

```python
candidates = run_inference(
...
msa_sever=True,
...
)
```

Please note that performance will vary depending on the input MSA databases and search algorithms used.

In addition, people have found that tweaking MSA inputs can be a fruitful path to improving folding results -- we such exploration of this for Chai-1 as well!
21 changes: 5 additions & 16 deletions examples/msas/predict_with_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
from pathlib import Path

import numpy as np
import torch

from chai_lab.chai1 import run_inference
from chai_lab.data.dataset.inference_dataset import read_inputs
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.parsing.structure.entity_type import EntityType

tmp_dir = Path(tempfile.mkdtemp())

Expand All @@ -25,16 +21,6 @@
fasta_path = tmp_dir / "example.fasta"
fasta_path.write_text(example_fasta)

# Generate MSAs
msa_dir = tmp_dir / "msas"
msa_dir.mkdir()
protein_seqs = [
input.sequence
for input in read_inputs(fasta_path)
if input.entity_type == EntityType.PROTEIN.value
]
generate_colabfold_msas(protein_seqs=protein_seqs, msa_dir=msa_dir)


# Generate structure
output_dir = tmp_dir / "outputs"
Expand All @@ -45,9 +31,12 @@
num_trunk_recycles=3,
num_diffn_timesteps=200,
seed=42,
device=torch.device("cuda:0"),
device="cuda:0",
use_esm_embeddings=True,
msa_directory=msa_dir,
# See example .aligned.pqt files in this directory
msa_directory=Path(__file__).parent,
# Exclusive with msa_directory; can be used for MMseqs2 server MSA generation
msa_server=False,
)
cif_paths = candidates.cif_paths
scores = [rd.aggregate_score for rd in candidates.ranking_data]
Expand Down
3 changes: 1 addition & 2 deletions examples/predict_structure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pathlib import Path

import numpy as np
import torch

from chai_lab.chai1 import run_inference

Expand Down Expand Up @@ -36,7 +35,7 @@
num_trunk_recycles=3,
num_diffn_timesteps=200,
seed=42,
device=torch.device("cuda:0"),
device="cuda:0",
use_esm_embeddings=True,
)

Expand Down
4 changes: 1 addition & 3 deletions examples/restraints/predict_with_restraints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
from pathlib import Path

import torch

from chai_lab.chai1 import run_inference

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -32,6 +30,6 @@
num_trunk_recycles=3,
num_diffn_timesteps=200,
seed=42,
device=torch.device("cuda:0"),
device="cuda:0",
use_esm_embeddings=True,
)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,7 @@ exclude = [
]

[tool.hatch.build.targets.wheel]
# should use packages from sdist section
# should use packages from sdist section

[project.scripts]
chai = "chai_lab.main:cli"

0 comments on commit d5e389f

Please sign in to comment.