Skip to content

Commit

Permalink
Add CLI option to set colabfold server url (#199)
Browse files Browse the repository at this point in the history
* Add CLI option to set colabfold server

* Update README.md

Co-authored-by: Jack Dent <[email protected]>

* Update chai_lab/chai1.py

Co-authored-by: Jack Dent <[email protected]>

* Update chai_lab/chai1.py

Co-authored-by: Jack Dent <[email protected]>

* Update chai_lab/data/dataset/msas/colabfold.py

Co-authored-by: Jack Dent <[email protected]>

* Update chai_lab/data/dataset/msas/colabfold.py

Co-authored-by: Jack Dent <[email protected]>

* Update README.md

Co-authored-by: Jack Dent <[email protected]>

---------

Co-authored-by: Jack Dent <[email protected]>
  • Loading branch information
danpf and jackdent authored Dec 2, 2024
1 parent e80bb3a commit 8f49f99
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ For example, to run the model with MSAs (which we recommend for improved perform
chai fold --use-msa-server input.fasta output_folder
```

If you are hosting your own ColabFold server, additionally pass the `--msa-server` flag with your server:

```shell
chai fold --use-msa-server --msa-server-url "https://api.internalcolabserver.com" input.fasta output_folder
```

### Programmatic inference

The main entrypoint into the Chai-1 folding code is through the `chai_lab.chai1.run_inference` function. The following script demonstrates how to programmatically provide inputs to the model, and obtain a list of PDB files for downstream analysis:
Expand Down Expand Up @@ -71,7 +77,7 @@ CHAI_DOWNLOADS_DIR=/tmp/downloads python ./examples/predict_structure.py

Chai-1 supports MSAs provided as an `aligned.pqt` file. This file format is similar to an `a3m` file, but has additional columns that provide metadata like the source database and sequence pairing keys. We provide code to convert `a3m` files to `aligned.pqt` files. For more information on how to provide MSAs to Chai-1, see [this documentation](examples/msas/README.md).

For user convenience, we also support automatic MSA generation via the ColabFold [MMseqs2](https://github.com/soedinglab/MMseqs2) server via the `--msa-server` flag. As detailed in the ColabFold [repository](https://github.com/sokrypton/ColabFold), please keep in mind that this is a shared resource. Note that the results reported in our preprint and the webserver use a different MSA search strategy than MMseqs2, though we expect results to be broadly similar.
For user convenience, we also support automatic MSA generation via the ColabFold [MMseqs2](https://github.com/soedinglab/MMseqs2) server via the `--use-msa-server` flag. As detailed in the ColabFold [repository](https://github.com/sokrypton/ColabFold), please keep in mind that this is a shared resource. Note that the results reported in our preprint and the webserver use a different MSA search strategy than MMseqs2, though we expect results to be broadly similar.

</p>
</details>
Expand Down
13 changes: 9 additions & 4 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ def run_inference(
*,
output_dir: Path,
use_esm_embeddings: bool = True,
msa_server: bool = False,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
Expand All @@ -285,7 +286,7 @@ def run_inference(
), f"Output directory {output_dir} is not empty."
torch_device = torch.device(device if device is not None else "cuda:0")
assert not (
msa_server and msa_directory
use_msa_server and msa_directory
), "Cannot specify both MSA server and directory"

# Prepare inputs
Expand All @@ -311,15 +312,19 @@ def run_inference(
raise_if_too_many_tokens(n_actual_tokens)

# Generated and/or load MSAs
if msa_server:
if use_msa_server:
protein_sequences = [
chain.entity_data.sequence
for chain in chains
if chain.entity_data.entity_type == EntityType.PROTEIN
]
msa_dir = output_dir / "msas"
msa_dir.mkdir(parents=True, exist_ok=False)
generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir)
generate_colabfold_msas(
protein_seqs=protein_sequences,
msa_dir=msa_dir,
msa_server_url=msa_server_url,
)
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_dir
)
Expand Down
7 changes: 6 additions & 1 deletion chai_lab/data/dataset/msas/colabfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,11 @@ def download(ID, path):
return (a3m_lines, template_paths) if use_templates else a3m_lines


def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
def generate_colabfold_msas(
protein_seqs: list[str],
msa_dir: Path,
msa_server_url: str,
):
"""
Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold)
server.
Expand Down Expand Up @@ -374,6 +378,7 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
mmseqs_dir,
# N.B. we can set this to False to disable pairing
use_pairing=len(protein_seqs) > 1,
host_url=msa_server_url,
user_agent="chai-lab/0.4.0 [email protected]",
)
assert isinstance(msas, list)
Expand Down
2 changes: 1 addition & 1 deletion examples/msas/predict_with_msas.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
# See example .aligned.pqt files in this directory
msa_directory=Path(__file__).parent,
# Exclusive with msa_directory; can be used for MMseqs2 server MSA generation
msa_server=False,
use_msa_server=False,
)
cif_paths = candidates.cif_paths
scores = [rd.aggregate_score for rd in candidates.ranking_data]
Expand Down

0 comments on commit 8f49f99

Please sign in to comment.