Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cif writing logic for multiple ligands #247

Merged
merged 5 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,9 +902,9 @@ def avg_per_token_1d(x):
bfactors=scaled_plddt_scores_per_atom,
output_batch=inputs,
write_path=cif_out_path,
entity_names={
c.entity_data.entity_id: c.entity_data.entity_name
for c in feature_context.chains
asym_entity_names={
i: c.entity_data.entity_name
for i, c in enumerate(feature_context.chains, start=1)
},
)
cif_paths.append(cif_out_path)
Expand Down
30 changes: 18 additions & 12 deletions chai_lab/data/io/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def token_centre_plddts(
return plddts[atom_idces].tolist(), residue_indices.tolist()


def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit]:
def get_chains_metadata(
context: PDBContext, asymid2entity_name: dict[int, str]
) -> dict[int, AsymUnit]:
"""Return mapping from asym id to AsymUnit objects."""
assert context.asym_id2entity_type.keys() == asymid2entity_name.keys()
# for each chain, get chain id, entity id, full sequence
token_res_names = context.token_res_names_to_string

Expand Down Expand Up @@ -101,15 +105,15 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit

sequence = [chain_token_res_names[i] for i in any_token_in_resi]

entity_id = context.token_entity_id[token_indices[0]]

chain_id_str = _get_chain_letter(asym_id)

asym_id2asym_unit[asym_id] = AsymUnit(
entity=Entity(
# sequence is a list of ChemComponents for aminoacids/bases
sequence=[_to_chem_component(resi, entity_type) for resi in sequence],
description=entity_names[int(entity_id)],
sequence=[
_to_chem_component(resi, entity_type, asym_id) for resi in sequence
],
description=asymid2entity_name[asym_id],
),
details=f"Chain {chain_id_str}",
id=chain_id_str,
Expand All @@ -118,11 +122,10 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit
return asym_id2asym_unit


def _to_chem_component(res_name_3: str, entity_type: int):
def _to_chem_component(res_name_3: str, entity_type: int, asym_id: int):
match entity_type:
case EntityType.LIGAND.value:
code = res_name_3
return NonPolymerChemComp(res_name_3)
return NonPolymerChemComp(id=res_name_3 + str(asym_id))
case EntityType.MANUAL_GLYCAN.value:
return SaccharideChemComp(id=res_name_3, name=res_name_3)
case EntityType.PROTEIN.value:
Expand All @@ -140,11 +143,12 @@ def _to_chem_component(res_name_3: str, entity_type: int):
raise NotImplementedError(f"Cannot handle entity type: {entity_type}")


@typecheck
def save_to_cif(
coords: Float[Tensor, "1 n_atoms 3"],
output_batch: dict,
write_path: Path,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
bfactors: Float[Tensor, "1 n_atoms"] | None = None,
):
write_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -153,7 +157,7 @@ def save_to_cif(
coords=rearrange(coords, "1 n c -> n c", c=3).cpu(),
plddts=None if bfactors is None else rearrange(bfactors, "1 n -> n").cpu(),
context=pdb_context_from_batch(output_batch),
entity_names=entity_names,
asym_entity_names=asym_entity_names,
out_path=write_path,
)
logger.info(f"saved cif file to {write_path}")
Expand All @@ -164,10 +168,12 @@ def new_context_to_cif_atoms(
coords: Float[Tensor, "n_atoms 3"],
plddts: Float[Tensor, "n_atoms"] | None,
context: PDBContext,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
out_path: Path,
):
asym_id2asym_unit = get_chains_metadata(context, entity_names=entity_names)
asym_id2asym_unit = get_chains_metadata(
context, asymid2entity_name=asym_entity_names
)

atom_asym_id = context.token_asym_id[context.atom_token_index]
# atom level attributes
Expand Down
Loading