diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py new file mode 100644 index 0000000..cd0ae5d --- /dev/null +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -0,0 +1,413 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +""" +N.B. this code is copied from https://github.com/sokrypton/ColabFold +and follows the license in that repository +""" + +import logging +import os +import random +import tarfile +import tempfile +import time +from pathlib import Path + +import pandas as pd +import requests +from tqdm import tqdm + +from chai_lab.data.parsing.fasta import read_fasta +from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence + +logger = logging.getLogger(__name__) + +TQDM_BAR_FORMAT = ( + "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]" +) + + +def _run_mmseqs2( + x, + prefix, + use_env=True, + use_filter=True, + use_templates=False, + filter=None, + use_pairing=False, + pairing_strategy="greedy", + host_url="https://api.colabfold.com", + user_agent: str = "", +) -> list[str] | tuple[list[str], list[str]]: + submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" + + headers = {} + if user_agent != "": + headers["User-Agent"] = user_agent + else: + logger.warning( + "No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future." + ) + + def submit(seqs, mode, N=101): + n, query = N, "" + for seq in seqs: + query += f">{n}\n{seq}\n" + n += 1 + + while True: + error_count = 0 + try: + # https://requests.readthedocs.io/en/latest/user/advanced/#advanced + # "good practice to set connect timeouts to slightly larger than a multiple of 3" + res = requests.post( + f"{host_url}/{submission_endpoint}", + data={"q": query, "mode": mode}, + timeout=6.02, + headers=headers, + ) + except requests.exceptions.Timeout: + logger.warning("Timeout while submitting to MSA server. Retrying...") + continue + except Exception as e: + error_count += 1 + logger.warning( + f"Error while fetching result from MSA server. Retrying... ({error_count}/5)" + ) + logger.warning(f"Error: {e}") + time.sleep(5) + if error_count > 5: + raise + continue + break + + try: + out = res.json() + except ValueError: + logger.error(f"Server didn't reply with json: {res.text}") + out = {"status": "ERROR"} + return out + + def status(ID): + while True: + error_count = 0 + try: + res = requests.get( + f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers + ) + except requests.exceptions.Timeout: + logger.warning( + "Timeout while fetching status from MSA server. Retrying..." + ) + continue + except Exception as e: + error_count += 1 + logger.warning( + f"Error while fetching result from MSA server. Retrying... ({error_count}/5)" + ) + logger.warning(f"Error: {e}") + time.sleep(5) + if error_count > 5: + raise + continue + break + try: + out = res.json() + except ValueError: + logger.error(f"Server didn't reply with json: {res.text}") + out = {"status": "ERROR"} + return out + + def download(ID, path): + error_count = 0 + while True: + try: + res = requests.get( + f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers + ) + except requests.exceptions.Timeout: + logger.warning( + "Timeout while fetching result from MSA server. Retrying..." + ) + continue + except Exception as e: + error_count += 1 + logger.warning( + f"Error while fetching result from MSA server. Retrying... ({error_count}/5)" + ) + logger.warning(f"Error: {e}") + time.sleep(5) + if error_count > 5: + raise + continue + break + with open(path, "wb") as out: + out.write(res.content) + + # process input x + seqs = [x] if isinstance(x, str) else x + + # compatibility to old option + if filter is not None: + use_filter = filter + + # setup mode + if use_filter: + mode = "env" if use_env else "all" + else: + mode = "env-nofilter" if use_env else "nofilter" + + if use_pairing: + use_templates = False + mode = "" + # greedy is default, complete was the previous behavior + if pairing_strategy == "greedy": + mode = "pairgreedy" + elif pairing_strategy == "complete": + mode = "paircomplete" + if use_env: + mode = mode + "-env" + + # define path + path = f"{prefix}_{mode}" + if not os.path.isdir(path): + os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f"{path}/out.tar.gz" + N, REDO = 101, True + + # deduplicate and keep track of order + seqs_unique = [] + # TODO this might be slow for large sets + [seqs_unique.append(x) for x in seqs if x not in seqs_unique] + Ms = [N + seqs_unique.index(seq) for seq in seqs] + # lets do it! + if not os.path.isfile(tar_gz_file): + TIME_ESTIMATE = 150 * len(seqs_unique) + with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: + while REDO: + pbar.set_description("SUBMIT") + + # Resubmit job until it goes through + out = submit(seqs_unique, mode, N) + while out["status"] in ["UNKNOWN", "RATELIMIT"]: + sleep_time = 5 + random.randint(0, 5) + logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") + # resubmit + time.sleep(sleep_time) + out = submit(seqs_unique, mode, N) + + if out["status"] == "ERROR": + raise Exception( + "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later." + ) + + if out["status"] == "MAINTENANCE": + raise Exception( + "MMseqs2 API is undergoing maintenance. Please try again in a few minutes." + ) + + # wait for job to finish + ID, TIME = out["id"], 0 + pbar.set_description(out["status"]) + while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]: + t = 5 + random.randint(0, 5) + logger.error(f"Sleeping for {t}s. Reason: {out['status']}") + time.sleep(t) + out = status(ID) + pbar.set_description(out["status"]) + if out["status"] == "RUNNING": + TIME += t + pbar.update(n=t) + # if TIME > 900 and out["status"] != "COMPLETE": + # # something failed on the server side, need to resubmit + # N += 1 + # break + + if out["status"] == "COMPLETE": + if TIME < TIME_ESTIMATE: + pbar.update(n=(TIME_ESTIMATE - TIME)) + REDO = False + + if out["status"] == "ERROR": + REDO = False + raise Exception( + "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later." + ) + + # Download results + download(ID, tar_gz_file) + + # prep list of a3m files + if use_pairing: + a3m_files = [f"{path}/pair.a3m"] + else: + a3m_files = [f"{path}/uniref.a3m"] + if use_env: + a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m") + + # extract a3m files + if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + # print("seq\tpdb\tcid\tevalue") + for line in open(f"{path}/pdb70.m8", "r"): + p = line.rstrip().split() + M, pdb, _, _ = p[0], p[1], p[2], p[10] + M = int(M) + if M not in templates: + templates[M] = [] + templates[M].append(pdb) + # if len(templates[M]) <= 20: + # print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}") + + template_paths = {} + for k, TMPL in templates.items(): + TMPL_PATH = f"{prefix}_{mode}/templates_{k}" + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ",".join(TMPL[:20]) + response = None + while True: + error_count = 0 + try: + # https://requests.readthedocs.io/en/latest/user/advanced/#advanced + # "good practice to set connect timeouts to slightly larger than a multiple of 3" + response = requests.get( + f"{host_url}/template/{TMPL_LINE}", + stream=True, + timeout=6.02, + headers=headers, + ) + except requests.exceptions.Timeout: + logger.warning( + "Timeout while submitting to template server. Retrying..." + ) + continue + except Exception as e: + error_count += 1 + logger.warning( + f"Error while fetching result from template server. Retrying... ({error_count}/5)" + ) + logger.warning(f"Error: {e}") + time.sleep(5) + if error_count > 5: + raise + continue + break + with tarfile.open(fileobj=response.raw, mode="r|gz") as tar: + tar.extractall(path=TMPL_PATH) + os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex") + with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f: + f.write("") + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M, M = True, None + for line in open(a3m_file, "r"): + if len(line) > 0: + if "\x00" in line: + line = line.replace("\x00", "") + update_M = True + if line.startswith(">") and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + + a3m_lines = ["".join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + # print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + return (a3m_lines, template_paths) if use_templates else a3m_lines + + +def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path): + """ + Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold) + server. + + N.B. the MSAs in our technical report were generated using jackhmmer, not + ColabFold, so we would expect some difference in results. + + This implementation also relies on ColabFold's chain pairing algorithm + rather than using Chai-1's own algorithm, which could also lead to + differences in results. + """ + assert msa_dir.is_dir(), "MSA directory must be a dir" + assert not any(msa_dir.iterdir()), "MSA directory must be empty" + + with tempfile.TemporaryDirectory() as tmp_dir_path: + tmp_dir = Path(tmp_dir_path) + + mmseqs_dir = tmp_dir / "mmseqs" + mmseqs_dir.mkdir() + + a3ms_dir = tmp_dir / "a3ms" + a3ms_dir.mkdir() + + # Generate MSAs for each protein chain + print(f"Running MSA generation for {len(protein_seqs)} protein sequences") + msas = _run_mmseqs2( + protein_seqs, + mmseqs_dir, + # N.B. we can set this to False to disable pairing + use_pairing=len(protein_seqs) > 1, + user_agent="chai-lab/0.4.0 feedback@chaidiscovery.com", + ) + assert isinstance(msas, list) + + # Process the MSAs into our internal format + for protein_seq, msa in zip(protein_seqs, msas, strict=True): + # Write out an A3M file + a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m" + a3m_path.write_text(msa) + + # Convert the A3M file into aligned parquet files + msa_fasta = read_fasta(a3m_path) + headers, msa_seqs = zip(*msa_fasta) + + # This shouldn't have much of an effect on the model, but we make + # a best effort to synthesize a source database anyway + source_databases = ["query"] + [ + "uniref90" if h.startswith("UniRef") else "bfd_uniclust" + for h in headers[1:] + ] + + # Map the MSAs to our internal format + aligned_df = pd.DataFrame( + data=dict( + sequence=msa_seqs, + source_database=source_databases, + # ColabFold does not return taxonomies from its API, so we + # can't rely on our internal chain pairing logic. As an + # alternative, we could disable ColabFold pairing and rely + # on a mapping from sequence ~> taxonomy, which would allow + # us to use our internal pairing logic. + pairing_key="", + comment="", + ), + ) + msa_path = msa_dir / expected_basename(protein_seq) + assert not msa_path.exists() + aligned_df.to_parquet(msa_path) diff --git a/examples/msas/README.md b/examples/msas/README.md index 87977e6..2b396d5 100644 --- a/examples/msas/README.md +++ b/examples/msas/README.md @@ -2,6 +2,8 @@ While Chai-1 performs very well in "single-sequence mode," it can also be given additional evolutionary information to further improve performance. As in other folding methods, this evolutionary information is provided in the form of a multiple sequence alignment (MSA). This information is given in the form of a `MSAContext` object (see `chai_lab/data/dataset/msas/msa_context.py`); we provide code for building these `MSAContext` objects through `aligned.pqt` files, though you can play with building out an `MSAContext` yourself as well. +Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. We provide an example of how to generate MSAs using [ColabFold](https://github.com/sokrypton/ColabFold) in `examples/msas/predict_with_msas.py`. Performance will vary depending on the input MSA databases and search algorithms used. + ## The `.aligned.pqt` file format The easiest way to provide MSA information to Chai-1 is through the `.aligned.pqt` file format that we have defined. This file can be thought of as an augmented `a3m` file, and is essentially a dataframe saved in parquet format with the following four (required) columns: diff --git a/examples/msas/predict_with_msas.py b/examples/msas/predict_with_msas.py new file mode 100644 index 0000000..3f260ba --- /dev/null +++ b/examples/msas/predict_with_msas.py @@ -0,0 +1,56 @@ +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from chai_lab.chai1 import run_inference +from chai_lab.data.dataset.inference_dataset import read_inputs +from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas +from chai_lab.data.parsing.structure.entity_type import EntityType + +tmp_dir = Path(tempfile.mkdtemp()) + +# Prepare input fasta +example_fasta = """ +>protein|name=example-of-long-protein +AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQTDRVSLRNLRGYYNQSEAGSHTLQWMFGCDLGPDGRLLRGYDQSAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAAREAEQRRAYLEGTCVEWLRRYLENGKETLQRAEHPKTHVTHHPVSDHEATLRCWALGFYPAEITLTWQWDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLTLRWEP +>protein|name=example-of-short-protein +AIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDM +>protein|name=example-peptide +GAAL +>ligand|name=example-ligand-as-smiles +CCCCCCCCCCCCCC(=O)O +""".strip() +fasta_path = tmp_dir / "example.fasta" +fasta_path.write_text(example_fasta) + +# Generate MSAs +msa_dir = tmp_dir / "msas" +msa_dir.mkdir() +protein_seqs = [ + input.sequence + for input in read_inputs(fasta_path) + if input.entity_type == EntityType.PROTEIN.value +] +generate_colabfold_msas(protein_seqs=protein_seqs, msa_dir=msa_dir) + + +# Generate structure +output_dir = tmp_dir / "outputs" +candidates = run_inference( + fasta_file=fasta_path, + output_dir=output_dir, + # 'default' setup + num_trunk_recycles=3, + num_diffn_timesteps=200, + seed=42, + device=torch.device("cuda:0"), + use_esm_embeddings=True, + msa_directory=msa_dir, +) +cif_paths = candidates.cif_paths +scores = [rd.aggregate_score for rd in candidates.ranking_data] + +# Load pTM, ipTM, pLDDTs and clash scores for sample 2 +scores = np.load(output_dir.joinpath("scores.model_idx_2.npz"))