Reproduce PoseBusters score #210

Open
Bae-SungHan opened this issue Dec 4, 2024 · 2 comments

@Bae-SungHan

Bae-SungHan commented Dec 4, 2024

Hi.
I ran into some problems while trying to reproduce the PoseBusters score you reported for Chai-1 (a success rate of about 77% at ligand RMSD under 2.0 Å).
First, I measured the RMSD between the Chai-1 predictions provided with the paper and the ground-truth (GT) crystal structures, and got a success rate of 71%, which is lower than the 77% you reported.
I ran the measurement with the code below, assuming that the ligand atom order is the same in the GT SDF file and in the prediction file.
Is there a problem with my approach?
Thanks.

from pymol import cmd, stored
from rdkit import Chem
from rdkit.Chem.rdMolAlign import CalcRMS
from rdkit.Geometry import Point3D
import os
import copy
from tqdm import tqdm

data_root = "./data/posebusters/posebusters_benchmark_set"
pred_root = "./data/posebusters/chai_1_posebusters_predictions.personal"
data_ids = [i for i in os.listdir(data_root) if i != "7D6O_MTE"]
rmsd_dict = dict()
for data_id in tqdm(data_ids):
    ref_protein = os.path.join(data_root, data_id, f"{data_id}_protein.pdb")
    ref_ligand = os.path.join(data_root, data_id, f"{data_id}_ligand.sdf")
    cmd.reinitialize()
    cmd.set("retain_order", 1)  # keep atoms in file order so they can be matched by index
    cmd.load(ref_protein, "crystal_prot")
    cmd.load(ref_ligand, "crystal_lig")
    # pocket residues around the crystal ligand (10 Å, heavy atoms only)
    cmd.select("receptor", "br. (%crystal_prot and not h.) w. 10.0 of (%crystal_lig and not h.)")
    pdb_code, lig_ccd = data_id.split("_")
    cmd.load(os.path.join(pred_root, f"{pdb_code}__{lig_ccd}_pred.pdb"), "pred")
    # superpose the prediction onto the crystal pocket (a single alignment per entry)
    protein_rmsd = cmd.align("pred", "receptor")[0]
    cmd.select("pred_lig", f"pred and resn {lig_ccd}")
    if cmd.count_atoms("pred_lig") == 0:
        print(f"{data_id} has no ligand")
        continue
    chains = cmd.get_chains("pred_lig")

    ref_lig = Chem.MolFromMolFile(ref_ligand, removeHs=True)
    rmsds = []
    for chain in chains:
        cmd.select("sub_pred_lig", f"pred_lig and chain {chain}")
        stored.pred_coords = []
        cmd.iterate_state(1, "sub_pred_lig", "stored.pred_coords.append((x, y, z))")
        # assumes predicted atoms are in the same order as the GT SDF atoms;
        # zip() would silently truncate on a count mismatch, so check explicitly
        if len(stored.pred_coords) != ref_lig.GetNumAtoms():
            print(f"{data_id} chain {chain}: atom count mismatch")
            continue

        # copy the reference topology and swap in the predicted coordinates
        pred_lig = copy.deepcopy(ref_lig)
        pred_lig.RemoveAllConformers()
        conf = Chem.Conformer(pred_lig.GetNumAtoms())
        for idx, (x, y, z) in enumerate(stored.pred_coords):
            conf.SetAtomPosition(idx, Point3D(x, y, z))
        pred_lig.AddConformer(conf)
        # CalcRMS: symmetry-aware RMSD computed in place (no ligand re-superposition)
        rmsds.append(CalcRMS(pred_lig, ref_lig))
    if rmsds:
        rmsd_dict[data_id] = min(rmsds)

print(len([k for k, v in rmsd_dict.items() if v <= 2.0]) / len(rmsd_dict))
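The script relies on the ligand atoms appearing in the same order in the GT SDF and in the predicted PDB. A minimal stdlib-only sketch of a stricter version of that step follows; it assumes both ligands are already heavy-atom-only lists of `(element, (x, y, z))` in the same frame (e.g. after the pocket alignment), and, unlike RDKit's `CalcRMS`, it does not account for molecular symmetry. `ordered_rmsd` and `success_rate` are hypothetical helper names.

```python
import math
from typing import Dict, Sequence, Tuple

# (element symbol, (x, y, z)); heavy atoms only, already in a common frame
Atom = Tuple[str, Tuple[float, float, float]]


def ordered_rmsd(ref: Sequence[Atom], pred: Sequence[Atom]) -> float:
    """Index-matched RMSD with no superposition and no symmetry handling.

    Raises instead of silently truncating (as zip() would) when the atom
    counts differ or when elements disagree at the same index.
    """
    if not ref:
        raise ValueError("empty ligand")
    if len(ref) != len(pred):
        raise ValueError(f"atom count mismatch: {len(ref)} vs {len(pred)}")
    squared = 0.0
    for i, ((e_ref, p), (e_pred, q)) in enumerate(zip(ref, pred)):
        if e_ref.upper() != e_pred.upper():
            raise ValueError(f"element mismatch at atom {i}: {e_ref} vs {e_pred}")
        squared += sum((a - b) ** 2 for a, b in zip(p, q))
    return math.sqrt(squared / len(ref))


def success_rate(rmsds: Dict[str, float], cutoff: float = 2.0) -> float:
    """Fraction of entries with RMSD <= cutoff (same comparison as the script)."""
    return sum(v <= cutoff for v in rmsds.values()) / len(rmsds)
```

Raising on a mismatch makes silently mis-paired ligands show up as explicit failures rather than as inflated or deflated RMSDs.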
@jacquesboitreaud
Contributor

jacquesboitreaud commented Dec 4, 2024

Hi, thanks for sharing the code. We found a few differences from what we ran that should explain the difference in metrics:

  • As the ground-truth PDB, we use the first biological assembly, not the asymmetric unit. We found this makes a difference for complexes with symmetry.

  • When there are multiple copies of the ligand of interest, we run the pocket alignment and RMSD calculation independently for each copy, then report the best value. In your code, the pocket alignment is done only once, on the union of the pockets.

  • For the pocket alignment, we followed the PoseBusters part of the Methods section of the AlphaFold3 paper:

For pocket-aligned r.m.s.d., first alignment between the predicted and ground-truth structures was conducted by aligning to the ground-truth pocket backbone atoms (CA, C or N atoms within 10 Å of the ligand of interest) from the primary protein chain (the chain with the greatest number of contacts within 10 Å of the ligand)

with the exception that, in our case, the alignment was done on CA atoms only, because we were confused by contradictory statements in the AF3 supplementary information and the AF3 main text. I doubt this changes the results significantly.
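The AF3 pocket definition quoted above can be sketched in plain Python. This is an illustration under stated assumptions, not Chai's or AF3's code: "contacts" is taken to mean protein atoms within 10 Å of any ligand heavy atom, ties between chains go to the first chain encountered, and `ProtAtom` and `pocket_atoms` are hypothetical names. Passing `names=("CA",)` gives the CA-only variant described in the reply.

```python
import math
from collections import Counter
from typing import List, NamedTuple, Sequence, Tuple

Vec3 = Tuple[float, float, float]


class ProtAtom(NamedTuple):
    chain: str
    resi: int
    name: str  # PDB atom name, e.g. "CA"
    xyz: Vec3


def _within(xyz: Vec3, lig_xyz: Sequence[Vec3], cutoff: float) -> bool:
    """True if xyz lies within `cutoff` Å of any ligand atom."""
    return any(math.dist(xyz, l) <= cutoff for l in lig_xyz)


def pocket_atoms(
    protein: Sequence[ProtAtom],
    lig_xyz: Sequence[Vec3],
    cutoff: float = 10.0,
    names: Tuple[str, ...] = ("CA", "C", "N"),
) -> List[ProtAtom]:
    """Backbone atoms of the primary chain within `cutoff` of the ligand."""
    # Assumed meaning of "contacts": protein atoms within cutoff of the ligand
    contacts = Counter(a.chain for a in protein if _within(a.xyz, lig_xyz, cutoff))
    if not contacts:
        return []
    primary_chain = contacts.most_common(1)[0][0]
    return [
        a
        for a in protein
        if a.chain == primary_chain and a.name in names and _within(a.xyz, lig_xyz, cutoff)
    ]
```

The returned atoms are what the prediction would be superposed onto; restricting to one chain keeps symmetric partner chains from pulling the fit.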

I'm attaching a CSV file with the RMSD values we obtained, as it may be helpful for debugging some examples.

posebusters_summary.csv
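One way to use the attached per-entry values for debugging is to list the entries where two pipelines disagree on success. A minimal sketch: the CSV's column names are not given here, so `id_col` and `rmsd_col` are placeholders to set from the file's actual header (entry IDs may also need normalizing to match the local `data_id` format), and `load_rmsds`, `load_rmsds_from_csv` and `disagreements` are hypothetical helpers.

```python
import csv
from typing import Dict, Iterable, List, Mapping


def load_rmsds(rows: Iterable[Mapping[str, str]], id_col: str, rmsd_col: str) -> Dict[str, float]:
    """Map entry id -> RMSD from dict rows, skipping rows without a value."""
    out: Dict[str, float] = {}
    for row in rows:
        value = (row.get(rmsd_col) or "").strip()
        if value:
            out[row[id_col]] = float(value)
    return out


def load_rmsds_from_csv(path: str, id_col: str, rmsd_col: str) -> Dict[str, float]:
    """Read a CSV file with a header row; column names must be supplied by the caller."""
    with open(path, newline="") as f:
        return load_rmsds(csv.DictReader(f), id_col, rmsd_col)


def disagreements(mine: Dict[str, float], theirs: Dict[str, float], cutoff: float = 2.0) -> List[str]:
    """IDs present in both where pass/fail at `cutoff` differs, largest RMSD gap first."""
    shared = mine.keys() & theirs.keys()
    diff = [k for k in shared if (mine[k] <= cutoff) != (theirs[k] <= cutoff)]
    return sorted(diff, key=lambda k: abs(mine[k] - theirs[k]), reverse=True)
```

Entries at the top of the list are the ones most likely to reveal a systematic difference, such as the assembly or pocket choice.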

@Bae-SungHan
Author

Bae-SungHan commented Dec 5, 2024

Thank you for your prompt and kind response!

I would like to confirm whether I understood your explanation correctly.

First, you mentioned that you used the biological assembly instead of the asymmetric unit. However, for cases like 7A9E_R4W among the predicted structures you provided, the ground-truth PDB is the asymmetric unit, while the predicted structure is the biological assembly. Was this intentional?
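For the asymmetric-unit vs. biological-assembly question: a PDB-format file that keeps its REMARK 350 records lists the assembly operators as BIOMT rows, and the assembly is built by applying each rotation plus translation to the listed chains. A minimal sketch that reads the operators of BIOMOLECULE 1 and applies one to a coordinate; it assumes whitespace-separated fields, ignores the "APPLY THE FOLLOWING TO CHAINS" lines, and the function names are hypothetical. A dedicated structure library is the more robust route for real files.

```python
import re
from typing import Iterable, List, Tuple

Vec3 = Tuple[float, float, float]
Operator = Tuple[Tuple[Vec3, Vec3, Vec3], Vec3]  # (rotation rows, translation)


def first_assembly_operators(pdb_lines: Iterable[str]) -> List[Operator]:
    """(rotation, translation) pairs from REMARK 350 BIOMT rows of BIOMOLECULE 1."""
    rows = {}  # operator serial -> list of [r1, r2, r3, t]
    in_first = False
    for line in pdb_lines:
        if not line.startswith("REMARK 350"):
            continue
        m = re.search(r"BIOMOLECULE:\s*(\d+)", line)
        if m:
            in_first = m.group(1) == "1"
        elif in_first and re.match(r"REMARK 350\s+BIOMT[123]", line):
            # fields: REMARK 350 BIOMTn serial r1 r2 r3 t
            fields = line.split()
            rows.setdefault(int(fields[3]), []).append([float(x) for x in fields[4:8]])
    return [
        (tuple(tuple(r[:3]) for r in rs), tuple(r[3] for r in rs))
        for _, rs in sorted(rows.items())
    ]


def apply_operator(op: Operator, xyz: Vec3) -> Vec3:
    """x' = R x + t for one BIOMT operator."""
    rot, trans = op
    return tuple(sum(rot[i][j] * xyz[j] for j in range(3)) + trans[i] for i in range(3))
```

Applying every operator to the chains it covers yields the first biological assembly; comparing chain counts against the prediction is a quick way to spot entries like the one above.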

Also, for cases with multiple copies of the ligand of interest, my understanding is that for each ground-truth ligand copy you define a pocket, align the predicted structure onto it, measure the RMSD for every ground-truth/predicted ligand pair, and return the lowest value. Is that correct?
I am attaching modified code based on my understanding.

Thank you!

from pymol import cmd, stored
from rdkit import Chem
from rdkit.Chem.rdMolAlign import CalcRMS
from rdkit.Geometry import Point3D
import tempfile
import os
import copy
from tqdm import tqdm

data_root = "./data/posebusters/posebusters_benchmark_set"
pred_root = "./data/posebusters/chai_1_posebusters_predictions.personal"
data_ids = [i for i in os.listdir(data_root) if i != "7D6O_MTE"]
rmsd_dict = dict()

for data_id in tqdm(data_ids):
    ref_protein = os.path.join(data_root, data_id, f"{data_id}_protein.pdb")
    pdb_code, lig_ccd = data_id.split("_")
    pred_file = os.path.join(pred_root, f"{pdb_code}__{lig_ccd}_pred.pdb")
    # all GT copies of the ligand of interest
    ref_ligand_mols = Chem.SDMolSupplier(os.path.join(data_root, data_id, f"{data_id}_ligands.sdf"), removeHs=True)
    rmsds = []
    for ref_lig in ref_ligand_mols:
        if ref_lig is None:  # SDMolSupplier yields None for records RDKit cannot parse
            continue
        cmd.reinitialize()
        cmd.set("retain_order", 1)
        cmd.load(ref_protein, "crystal_prot")
        # write this single GT copy to its own SDF so the pocket is defined around it alone
        with tempfile.TemporaryDirectory() as tmp_dir:
            ref_ligand = os.path.join(tmp_dir, "ref_ligand.sdf")
            with Chem.SDWriter(ref_ligand) as w:
                w.write(ref_lig)
            cmd.load(ref_ligand, "crystal_lig")
        # CA-based pocket around this copy (10 Å)
        cmd.select("receptor", "br. (%crystal_prot and name CA) w. 10.0 of (%crystal_lig and not h.)")
        # reload the prediction and align it onto this copy's pocket, on CA atoms only
        cmd.load(pred_file, "pred")
        protein_rmsd = cmd.align("pred and name CA", "receptor and name CA")[0]
        cmd.select("pred_lig", f"pred and resn {lig_ccd}")
        if cmd.count_atoms("pred_lig") == 0:
            print(f"{data_id} has no ligand")
            continue

        # compare every predicted copy with this GT copy
        for chain in cmd.get_chains("pred_lig"):
            cmd.select("sub_pred_lig", f"pred_lig and chain {chain}")
            stored.pred_coords = []
            cmd.iterate_state(1, "sub_pred_lig", "stored.pred_coords.append((x, y, z))")
            # assumes the same heavy-atom order as the GT SDF; skip on a count mismatch
            if len(stored.pred_coords) != ref_lig.GetNumAtoms():
                print(f"{data_id} chain {chain}: atom count mismatch")
                continue

            pred_lig = copy.deepcopy(ref_lig)
            pred_lig.RemoveAllConformers()
            conf = Chem.Conformer(pred_lig.GetNumAtoms())
            for idx, (x, y, z) in enumerate(stored.pred_coords):
                conf.SetAtomPosition(idx, Point3D(x, y, z))
            pred_lig.AddConformer(conf)
            rmsds.append(CalcRMS(pred_lig, ref_lig))
    # best value over all (GT copy, predicted copy) pairs
    if rmsds:
        rmsd_dict[data_id] = min(rmsds)
print(len([k for k, v in rmsd_dict.items() if v <= 2.0]) / len(rmsd_dict))
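Stripped of the PyMOL and RDKit calls, the per-copy protocol described in the reply reduces to a minimum over (GT copy, predicted copy) pairs, with the pocket alignment redone for each GT copy rather than once on the union of pockets. A minimal generic sketch; `align_to_pocket` and `ligand_rmsd` are hypothetical callables standing in for the alignment and RMSD steps of the script.

```python
import math
from typing import Callable, Iterable, Sequence, TypeVar

Lig = TypeVar("Lig")  # a ligand copy, in whatever representation the callables use


def best_pocket_aligned_rmsd(
    gt_copies: Iterable[Lig],
    pred_copies: Sequence[Lig],
    align_to_pocket: Callable[[Lig], Callable[[Lig], Lig]],
    ligand_rmsd: Callable[[Lig, Lig], float],
) -> float:
    """Minimum RMSD over all (GT copy, predicted copy) pairs.

    align_to_pocket(gt) returns a function that maps a predicted ligand into
    the frame obtained by superposing the prediction onto gt's pocket.
    Returns math.inf when there are no pairs.
    """
    best = math.inf
    for gt in gt_copies:
        to_pocket_frame = align_to_pocket(gt)  # one alignment per GT copy
        for pred in pred_copies:
            best = min(best, ligand_rmsd(to_pocket_frame(pred), gt))
    return best
```

Because the alignment depends on the GT copy, a predicted copy can score well only in the frame of its own pocket, which a single union-of-pockets alignment cannot capture.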
