From d5e389f24e35d93d18e4fb8f9b5a557e55999d0e Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Thu, 28 Nov 2024 16:34:03 -0800 Subject: [PATCH] Easy command line interface with MSA server flag (#190) Co-authored-by: Jack Dent --- README.md | 29 ++++++++++++- chai_lab/chai1.py | 39 ++++++++++++++--- chai_lab/data/dataset/inference_dataset.py | 9 ++-- chai_lab/data/dataset/msas/colabfold.py | 4 +- chai_lab/data/dataset/msas/load.py | 4 +- chai_lab/main.py | 42 +++++++++++++++++++ examples/msas/README.md | 21 ++++++++-- examples/msas/predict_with_msas.py | 21 +++------- examples/predict_structure.py | 3 +- .../restraints/predict_with_restraints.py | 4 +- pyproject.toml | 5 ++- 11 files changed, 143 insertions(+), 38 deletions(-) create mode 100644 chai_lab/main.py diff --git a/README.md b/README.md index 4c21383..8969ac5 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,18 @@ This Python package requires Linux, and a GPU with CUDA and bfloat16 support. We ## Running the model -The model accepts inputs in the FASTA file format, and allows you to specify the number of trunk recycles and diffusion timesteps via the `chai_lab.chai1.run_inference` function. By default, the model generates five sample predictions, and uses embeddings without MSAs or templates. +### Command line inference -The following script demonstrates how to provide inputs to the model, and obtain a list of PDB files for downstream analysis: +You can fold a FASTA file containing all the sequences (including modified residues, nucleotides, and ligands as SMILES strings) in a complex of interest by calling: +```shell +chai fold input.fasta output_folder +``` + +By default, the model generates five sample predictions, and uses embeddings without MSAs or templates. For additional information about how to supply MSAs and restraints to the model, see the documentation below, or run `chai fold --help`. + +### Programmatic inference + +The main entrypoint into the Chai-1 folding code is through the `chai_lab.chai1.run_inference` function. The following script demonstrates how to programmatically provide inputs to the model, and obtain a list of PDB files for downstream analysis: ```shell python examples/predict_structure.py @@ -56,6 +65,8 @@ CHAI_DOWNLOADS_DIR=/tmp/downloads python ./examples/predict_structure.py Chai-1 supports MSAs provided as an `aligned.pqt` file. This file format is similar to an `a3m` file, but has additional columns that provide metadata like the source database and sequence pairing keys. We provide code to convert `a3m` files to `aligned.pqt` files. For more information on how to provide MSAs to Chai-1, see [this documentation](examples/msas/README.md). +For user convenience, we also support automatic MSA generation via the ColabFold [MMseqs2](https://github.com/soedinglab/MMseqs2) server via the `--msa-server` flag. As detailed in the ColabFold [repository](https://github.com/sokrypton/ColabFold), please keep in mind that this is a shared resource. Note that the results reported in our preprint and the webserver use a different MSA search strategy than MMseqs2, though we expect results to be broadly similar. +

@@ -121,6 +132,20 @@ If you find Chai-1 useful in your research or use any structures produced by the } ``` +You can also access this information by running `chai citation`. + +Additionally, if you use the automatic MMseqs2 MSA generation described above, please also cite: + +``` +@article{mirdita2022colabfold, + title={ColabFold: making protein folding accessible to all}, + author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin}, + journal={Nature methods}, + year={2022}, +} +``` + + ## Licence Chai-1 is released under an Apache 2.0 License (both code and model weights), which means it can be used for both academic and commerical purposes, including for drug discovery. diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 73c7eb2..387af28 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -29,6 +29,7 @@ from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs +from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas from chai_lab.data.dataset.msas.load import get_msa_contexts from chai_lab.data.dataset.msas.msa_context import MSAContext from chai_lab.data.dataset.structure.all_atom_structure_context import ( @@ -83,6 +84,7 @@ ) from chai_lab.data.io.cif_utils import outputs_to_cif from chai_lab.data.parsing.restraints import parse_pairwise_table +from chai_lab.data.parsing.structure.entity_type import EntityType from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule from chai_lab.model.utils import center_random_augmentation from chai_lab.ranking.frames import get_frames_and_mask @@ -268,14 +270,24 @@ def run_inference( *, output_dir: Path, use_esm_embeddings: bool = True, + msa_server: bool = False, msa_directory: Path | None = None, - constraint_path: Path | str | None = None, + constraint_path: Path | None = None, # expose some params for easy tweaking num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, seed: int | None = None, - device: torch.device | None = None, + device: str | None = None, ) -> StructureCandidates: + if output_dir.exists(): + assert not any( + output_dir.iterdir() + ), f"Output directory {output_dir} is not empty." + torch_device = torch.device(device if device is not None else "cuda:0") + assert not ( + msa_server and msa_directory + ), "Cannot specify both MSA server and directory" + # Prepare inputs assert fasta_file.exists(), fasta_file fasta_inputs = read_inputs(fasta_file, length_limit=None) @@ -290,14 +302,28 @@ def run_inference( # Load structure context chains = load_chains_from_raw(fasta_inputs) + del fasta_inputs # Do not reference inputs after creating chains from them + merged_context = AllAtomStructureContext.merge( [c.structure_context for c in chains] ) n_actual_tokens = merged_context.num_tokens raise_if_too_many_tokens(n_actual_tokens) - # Load MSAs - if msa_directory is not None: + # Generated and/or load MSAs + if msa_server: + protein_sequences = [ + chain.entity_data.sequence + for chain in chains + if chain.entity_data.entity_type == EntityType.PROTEIN + ] + msa_dir = output_dir / "msas" + msa_dir.mkdir(parents=True, exist_ok=False) + generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir) + msa_context, msa_profile_context = get_msa_contexts( + chains, msa_directory=msa_dir + ) + elif msa_directory is not None: msa_context, msa_profile_context = get_msa_contexts( chains, msa_directory=msa_directory ) @@ -308,6 +334,7 @@ def run_inference( msa_profile_context = MSAContext.create_empty( n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH ) + assert ( msa_context.num_tokens == merged_context.num_tokens ), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}" @@ -320,7 +347,7 @@ def run_inference( # Load ESM embeddings if use_esm_embeddings: - embedding_context = get_esm_embedding_context(chains, device=device) + embedding_context = get_esm_embedding_context(chains, device=torch_device) else: embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens) @@ -351,7 +378,7 @@ def run_inference( num_trunk_recycles=num_trunk_recycles, num_diffn_timesteps=num_diffn_timesteps, seed=seed, - device=device, + device=torch_device, ) diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py index d589669..5caaa55 100644 --- a/chai_lab/data/dataset/inference_dataset.py +++ b/chai_lab/data/dataset/inference_dataset.py @@ -93,6 +93,7 @@ def _synth_subchain_id(idx: int) -> str: def raw_inputs_to_entitites_data( inputs: list[Input], identifier: str = "test" ) -> list[AllAtomEntityData]: + """Load an entity for each raw input.""" entities = [] # track unique entities @@ -157,7 +158,7 @@ def load_chains_from_raw( tokenizer: AllAtomResidueTokenizer | None = None, ) -> list[Chain]: """ - loads and tokenizes each input chain + Loads and tokenizes each input chain; skips over inputs that fail to tokenize. """ if tokenizer is None: @@ -186,12 +187,14 @@ def load_chains_from_raw( logger.exception(f"Failed to tokenize input {entity_data=} {sym_id=}") tok = None structure_contexts.append(tok) - assert len(structure_contexts) == len(entities) + # Join the untokenized entity data with the tokenized chain data, removing # chains we failed to tokenize chains = [ Chain(entity_data=entity_data, structure_context=structure_context) - for entity_data, structure_context in zip(entities, structure_contexts) + for entity_data, structure_context in zip( + entities, structure_contexts, strict=True + ) if structure_context is not None ] diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index 69d3a3a..87a85af 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -352,6 +352,8 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path): This implementation also relies on ColabFold's chain pairing algorithm rather than using Chai-1's own algorithm, which could also lead to differences in results. + + Places .aligned.pqt files in msa_dir; does not save intermediate a3m files. """ assert msa_dir.is_dir(), "MSA directory must be a dir" assert not any(msa_dir.iterdir()), "MSA directory must be empty" @@ -366,7 +368,7 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path): a3ms_dir.mkdir() # Generate MSAs for each protein chain - print(f"Running MSA generation for {len(protein_seqs)} protein sequences") + logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences") msas = _run_mmseqs2( protein_seqs, mmseqs_dir, diff --git a/chai_lab/data/dataset/msas/load.py b/chai_lab/data/dataset/msas/load.py index 18567ea..5e177ed 100644 --- a/chai_lab/data/dataset/msas/load.py +++ b/chai_lab/data/dataset/msas/load.py @@ -47,7 +47,9 @@ def get_msa_contexts( def get_msa_contexts_for_seq(seq) -> MSAContext: path = msa_directory / expected_basename(seq) if not path.is_file(): - logger.warning(f"No MSA found for sequence: {seq}") + if seq != "X": + # Don't warn for the special "X" sequence + logger.warning(f"No MSA found for sequence: {seq}") [tokenized_seq] = tokenize_sequences_to_arrays([seq])[0] return MSAContext.create_single_seq( MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq) diff --git a/chai_lab/main.py b/chai_lab/main.py new file mode 100644 index 0000000..124a94d --- /dev/null +++ b/chai_lab/main.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +"""Command line interface.""" + +import logging + +import typer + +from chai_lab.chai1 import run_inference + +CITATION = """ +@article{Chai-1-Technical-Report, + title = {Chai-1: Decoding the molecular interactions of life}, + author = {{Chai Discovery}}, + year = 2024, + journal = {bioRxiv}, + publisher = {Cold Spring Harbor Laboratory}, + doi = {10.1101/2024.10.10.615955}, + url = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955}, + elocation-id = {2024.10.10.615955}, + eprint = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955.full.pdf} +} +""".strip() + + +def citation(): + """Print citation information""" + typer.echo(CITATION) + + +def cli(): + app = typer.Typer() + app.command("fold", help="Run Chai-1 to fold a complex.")(run_inference) + app.command("citation", help="Print citation information")(citation) + app() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + cli() diff --git a/examples/msas/README.md b/examples/msas/README.md index 2b396d5..dd47f4a 100644 --- a/examples/msas/README.md +++ b/examples/msas/README.md @@ -2,8 +2,6 @@ While Chai-1 performs very well in "single-sequence mode," it can also be given additional evolutionary information to further improve performance. As in other folding methods, this evolutionary information is provided in the form of a multiple sequence alignment (MSA). This information is given in the form of a `MSAContext` object (see `chai_lab/data/dataset/msas/msa_context.py`); we provide code for building these `MSAContext` objects through `aligned.pqt` files, though you can play with building out an `MSAContext` yourself as well. -Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. We provide an example of how to generate MSAs using [ColabFold](https://github.com/sokrypton/ColabFold) in `examples/msas/predict_with_msas.py`. Performance will vary depending on the input MSA databases and search algorithms used. - ## The `.aligned.pqt` file format The easiest way to provide MSA information to Chai-1 is through the `.aligned.pqt` file format that we have defined. This file can be thought of as an augmented `a3m` file, and is essentially a dataframe saved in parquet format with the following four (required) columns: @@ -58,4 +56,21 @@ import pandas as pd aligned_pqt = pd.read_parquet("examples/msas/703adc2c74b8d7e613549b6efcf37126da7963522dc33852ad3c691eef1da06f.aligned.pqt") aligned_pqt.head() -``` \ No newline at end of file +``` + + +## Additional MSA generation strategies + +Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. In this vein, we provide support for automatic MSA generation via the [ColabFold](https://github.com/sokrypton/ColabFold) server using `chai fold input.fasta output_directory --msa-server` or by invoking `run_inference` as follows: + +```python +candidates = run_inference( + ... + msa_sever=True, + ... +) +``` + +Please note that performance will vary depending on the input MSA databases and search algorithms used. + +In addition, people have found that tweaking MSA inputs can be a fruitful path to improving folding results -- we such exploration of this for Chai-1 as well! \ No newline at end of file diff --git a/examples/msas/predict_with_msas.py b/examples/msas/predict_with_msas.py index 3f260ba..4196f27 100644 --- a/examples/msas/predict_with_msas.py +++ b/examples/msas/predict_with_msas.py @@ -2,12 +2,8 @@ from pathlib import Path import numpy as np -import torch from chai_lab.chai1 import run_inference -from chai_lab.data.dataset.inference_dataset import read_inputs -from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas -from chai_lab.data.parsing.structure.entity_type import EntityType tmp_dir = Path(tempfile.mkdtemp()) @@ -25,16 +21,6 @@ fasta_path = tmp_dir / "example.fasta" fasta_path.write_text(example_fasta) -# Generate MSAs -msa_dir = tmp_dir / "msas" -msa_dir.mkdir() -protein_seqs = [ - input.sequence - for input in read_inputs(fasta_path) - if input.entity_type == EntityType.PROTEIN.value -] -generate_colabfold_msas(protein_seqs=protein_seqs, msa_dir=msa_dir) - # Generate structure output_dir = tmp_dir / "outputs" @@ -45,9 +31,12 @@ num_trunk_recycles=3, num_diffn_timesteps=200, seed=42, - device=torch.device("cuda:0"), + device="cuda:0", use_esm_embeddings=True, - msa_directory=msa_dir, + # See example .aligned.pqt files in this directory + msa_directory=Path(__file__).parent, + # Exclusive with msa_directory; can be used for MMseqs2 server MSA generation + msa_server=False, ) cif_paths = candidates.cif_paths scores = [rd.aggregate_score for rd in candidates.ranking_data] diff --git a/examples/predict_structure.py b/examples/predict_structure.py index 4e5a51d..2049304 100644 --- a/examples/predict_structure.py +++ b/examples/predict_structure.py @@ -1,7 +1,6 @@ from pathlib import Path import numpy as np -import torch from chai_lab.chai1 import run_inference @@ -36,7 +35,7 @@ num_trunk_recycles=3, num_diffn_timesteps=200, seed=42, - device=torch.device("cuda:0"), + device="cuda:0", use_esm_embeddings=True, ) diff --git a/examples/restraints/predict_with_restraints.py b/examples/restraints/predict_with_restraints.py index 74757d7..0adce06 100644 --- a/examples/restraints/predict_with_restraints.py +++ b/examples/restraints/predict_with_restraints.py @@ -1,8 +1,6 @@ import logging from pathlib import Path -import torch - from chai_lab.chai1 import run_inference logging.basicConfig(level=logging.INFO) @@ -32,6 +30,6 @@ num_trunk_recycles=3, num_diffn_timesteps=200, seed=42, - device=torch.device("cuda:0"), + device="cuda:0", use_esm_embeddings=True, ) diff --git a/pyproject.toml b/pyproject.toml index 6f1e10f..83823c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,4 +68,7 @@ exclude = [ ] [tool.hatch.build.targets.wheel] -# should use packages from sdist section \ No newline at end of file +# should use packages from sdist section + +[project.scripts] +chai = "chai_lab.main:cli" \ No newline at end of file