Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve colabfold MSAs to include unpaired MSA hits #213

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 90 additions & 25 deletions chai_lab/data/dataset/msas/colabfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from chai_lab import __version__
from chai_lab.data.parsing.fasta import read_fasta
from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence
from chai_lab.data.parsing.msas.data_source import MSADataSource

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +27,7 @@
)


# N.B. this code is copied from https://github.com/sokrypton/ColabFold
# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold
# and follows the license in that repository
@typing.no_type_check # Original ColabFold code was not well typed
def _run_mmseqs2(
Expand All @@ -41,6 +42,7 @@ def _run_mmseqs2(
host_url="https://api.colabfold.com",
user_agent: str = "",
) -> list[str] | tuple[list[str], list[str]]:
"""Return a block of a3m lines for each of the input sequences in x."""
submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

headers = {}
Expand Down Expand Up @@ -342,19 +344,26 @@ def download(ID, path):
return (a3m_lines, template_paths) if use_templates else a3m_lines


def _is_padding_msa_row(sequence: str) -> bool:
"""Check if the given MSA sequence is a a padding sequence."""
seq_chars = set(sequence)
return len(seq_chars) == 1 and seq_chars.pop() == "-"


def generate_colabfold_msas(
protein_seqs: list[str],
msa_dir: Path,
msa_server_url: str,
write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging
):
"""
Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold)
server. No-op if no protein sequences are given.

N.B. the MSAs in our technical report were generated using jackhmmer, not
N.B.:
- the MSAs in our technical report were generated using jackhmmer, not
ColabFold, so we would expect some difference in results.

This implementation also relies on ColabFold's chain pairing algorithm
- this implementation relies on ColabFold's chain pairing algorithm
rather than using Chai-1's own algorithm, which could also lead to
differences in results.

Expand All @@ -369,52 +378,108 @@ def generate_colabfold_msas(
with tempfile.TemporaryDirectory() as tmp_dir_path:
tmp_dir = Path(tmp_dir_path)

mmseqs_paired_dir = tmp_dir / "mmseqs_paired"
mmseqs_paired_dir.mkdir()

mmseqs_dir = tmp_dir / "mmseqs"
mmseqs_dir.mkdir()

a3ms_dir = tmp_dir / "a3ms"
a3ms_dir = (tmp_dir if not write_a3m_to_msa_dir else msa_dir) / "a3ms"
a3ms_dir.mkdir()

# Generate MSAs for each protein chain
logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences")
msas = _run_mmseqs2(

# In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows
# and each row is already paired to have the same species. As such, we insert pairing key
# as the i-th index of the sequence so long as it isn't a padding sequence (all -)
if len(protein_seqs) > 1:
paired_msas = _run_mmseqs2(
protein_seqs,
mmseqs_paired_dir,
use_pairing=True,
host_url=msa_server_url,
user_agent=f"chai-lab/{__version__} [email protected]",
)
else:
# If we only have a single protein chain, there are no paired MSAs by definition
paired_msas = ["" for _ in protein_seqs]
assert isinstance(paired_msas, list)
wukevin marked this conversation as resolved.
Show resolved Hide resolved

# MSAs without pairing logic attached; may include sequences not contained in the paired MSA
# Needs a second call as the colabfold server returns either paired or unpaired, not both
per_chain_msas = _run_mmseqs2(
protein_seqs,
mmseqs_dir,
# N.B. we can set this to False to disable pairing
use_pairing=len(protein_seqs) > 1,
use_pairing=False,
host_url=msa_server_url,
user_agent=f"chai-lab/{__version__} [email protected]",
)
assert isinstance(msas, list)

# Process the MSAs into our internal format
for protein_seq, msa in zip(protein_seqs, msas, strict=True):
# Write out an A3M file
a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m"
a3m_path.write_text(msa)
for protein_seq, pair_msa, single_msa in zip(
protein_seqs, paired_msas, per_chain_msas, strict=True
):
# Write out an A3M file for both
hkey = hash_sequence(protein_seq.upper())
pair_a3m_path = a3ms_dir / f"{hkey}.pair.a3m"
pair_a3m_path.write_text(pair_msa)
single_a3m_path = a3ms_dir / f"{hkey}.single.a3m"
single_a3m_path.write_text(single_msa)

## Convert the A3M file into aligned parquet files
# Set the pairing key as the ith-index in the sequences, skip over sequences that have
# been inserted as padding as our internal pairing logic will match on pairing key.
paired_fasta: list[tuple[int, str, str]] = [
(pairkey, record.header, record.sequence)
for pairkey, record in enumerate(read_fasta(pair_a3m_path))
if not _is_padding_msa_row(record.sequence)
]
pairing_key, paired_headers, paired_msa_seqs = (
zip(*paired_fasta) if paired_fasta else ((), (), ())
)

# Convert the A3M file into aligned parquet files
msa_fasta = read_fasta(a3m_path)
headers, msa_seqs = zip(*msa_fasta)
# Non-paired MSA sequences that weren't already covered in the paired MSA; skip header
single_fasta: list[tuple[str, str]] = [
(record.header, record.sequence)
wukevin marked this conversation as resolved.
Show resolved Hide resolved
for i, record in enumerate(read_fasta(single_a3m_path))
if (
i > 0
and not _is_padding_msa_row(record.sequence)
and record.sequence not in set(paired_msa_seqs)
wukevin marked this conversation as resolved.
Show resolved Hide resolved
)
]
single_headers, single_msa_seqs = (
zip(*single_fasta) if single_fasta else ((), ())
)
# Create null pairing keys for each of the entries in the single MSA seq
single_null_pair_keys = ["" for _ in range(len(single_msa_seqs))]
wukevin marked this conversation as resolved.
Show resolved Hide resolved

# This shouldn't have much of an effect on the model, but we make
# a best effort to synthesize a source database anyway
# NOTE we already dropped the query row from the single MSAs so no need to slice
source_databases = ["query"] + [
"uniref90" if h.startswith("UniRef") else "bfd_uniclust"
for h in headers[1:]
(
MSADataSource.UNIREF90.value
if h.startswith("UniRef")
else MSADataSource.BFD_UNICLUST.value
)
for h in (paired_headers + single_headers)[1:]
]

# Combine information across paired and single hits
all_sequences = paired_msa_seqs + single_msa_seqs
all_pairing_keys = [str(k) for k in pairing_key] + single_null_pair_keys
wukevin marked this conversation as resolved.
Show resolved Hide resolved
assert (
len(all_sequences) == len(all_pairing_keys) == len(source_databases)
), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}"

# Map the MSAs to our internal format
aligned_df = pd.DataFrame(
data=dict(
sequence=msa_seqs,
sequence=all_sequences,
source_database=source_databases,
# ColabFold does not return taxonomies from its API, so we
# can't rely on our internal chain pairing logic. As an
# alternative, we could disable ColabFold pairing and rely
# on a mapping from sequence ~> taxonomy, which would allow
# us to use our internal pairing logic.
pairing_key="",
pairing_key=all_pairing_keys,
comment="",
),
)
Expand Down
3 changes: 2 additions & 1 deletion chai_lab/data/dataset/msas/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_UKEY_FOR_QUERY = (-999, -999)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def merge_main_msas_by_chain(msas: list[MSAContext]) -> MSAContext:
Expand Down Expand Up @@ -120,7 +121,7 @@ def pair_and_merge_msas(msas: list[MSAContext]) -> MSAContext:
selected_msa = msa.take_rows_with_padding(all_rowids)

logger.info(
f"Loaded (paired in includes query sequence): "
f"Loaded (paired includes query sequence): "
f"{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} "
)

Expand Down
Loading