Skip to content

Commit

Permalink
Fix CIF writing logic for multiple ligands (#247)
Browse files Browse the repository at this point in the history
* Entity names should be mapped by asym ID

* Add typechecks; map entity names by asym ID

* Fix naming of ligand components

* Minor fix

* Rename vars
  • Loading branch information
wukevin authored Dec 12, 2024
1 parent 696270c commit a54c6bd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
6 changes: 3 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,9 +902,9 @@ def avg_per_token_1d(x):
bfactors=scaled_plddt_scores_per_atom,
output_batch=inputs,
write_path=cif_out_path,
entity_names={
c.entity_data.entity_id: c.entity_data.entity_name
for c in feature_context.chains
asym_entity_names={
i: c.entity_data.entity_name
for i, c in enumerate(feature_context.chains, start=1)
},
)
cif_paths.append(cif_out_path)
Expand Down
30 changes: 18 additions & 12 deletions chai_lab/data/io/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def token_centre_plddts(
return plddts[atom_idces].tolist(), residue_indices.tolist()


def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit]:
def get_chains_metadata(
context: PDBContext, asymid2entity_name: dict[int, str]
) -> dict[int, AsymUnit]:
"""Return mapping from asym id to AsymUnit objects."""
assert context.asym_id2entity_type.keys() == asymid2entity_name.keys()
# for each chain, get chain id, entity id, full sequence
token_res_names = context.token_res_names_to_string

Expand Down Expand Up @@ -101,15 +105,15 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit

sequence = [chain_token_res_names[i] for i in any_token_in_resi]

entity_id = context.token_entity_id[token_indices[0]]

chain_id_str = _get_chain_letter(asym_id)

asym_id2asym_unit[asym_id] = AsymUnit(
entity=Entity(
# sequence is a list of ChemComponents for aminoacids/bases
sequence=[_to_chem_component(resi, entity_type) for resi in sequence],
description=entity_names[int(entity_id)],
sequence=[
_to_chem_component(resi, entity_type, asym_id) for resi in sequence
],
description=asymid2entity_name[asym_id],
),
details=f"Chain {chain_id_str}",
id=chain_id_str,
Expand All @@ -118,11 +122,10 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit
return asym_id2asym_unit


def _to_chem_component(res_name_3: str, entity_type: int):
def _to_chem_component(res_name_3: str, entity_type: int, asym_id: int):
match entity_type:
case EntityType.LIGAND.value:
code = res_name_3
return NonPolymerChemComp(res_name_3)
return NonPolymerChemComp(id=res_name_3 + str(asym_id))
case EntityType.MANUAL_GLYCAN.value:
return SaccharideChemComp(id=res_name_3, name=res_name_3)
case EntityType.PROTEIN.value:
Expand All @@ -140,11 +143,12 @@ def _to_chem_component(res_name_3: str, entity_type: int):
raise NotImplementedError(f"Cannot handle entity type: {entity_type}")


@typecheck
def save_to_cif(
coords: Float[Tensor, "1 n_atoms 3"],
output_batch: dict,
write_path: Path,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
bfactors: Float[Tensor, "1 n_atoms"] | None = None,
):
write_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -153,7 +157,7 @@ def save_to_cif(
coords=rearrange(coords, "1 n c -> n c", c=3).cpu(),
plddts=None if bfactors is None else rearrange(bfactors, "1 n -> n").cpu(),
context=pdb_context_from_batch(output_batch),
entity_names=entity_names,
asym_entity_names=asym_entity_names,
out_path=write_path,
)
logger.info(f"saved cif file to {write_path}")
Expand All @@ -164,10 +168,12 @@ def new_context_to_cif_atoms(
coords: Float[Tensor, "n_atoms 3"],
plddts: Float[Tensor, "n_atoms"] | None,
context: PDBContext,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
out_path: Path,
):
asym_id2asym_unit = get_chains_metadata(context, entity_names=entity_names)
asym_id2asym_unit = get_chains_metadata(
context, asymid2entity_name=asym_entity_names
)

atom_asym_id = context.token_asym_id[context.atom_token_index]
# atom level attributes
Expand Down

0 comments on commit a54c6bd

Please sign in to comment.