Skip to content

Commit

Permalink
Rename vars
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Dec 12, 2024
1 parent b1a8b72 commit 58bfece
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def avg_per_token_1d(x):
bfactors=scaled_plddt_scores_per_atom,
output_batch=inputs,
write_path=cif_out_path,
entity_names={
asym_entity_names={
i: c.entity_data.entity_name
for i, c in enumerate(feature_context.chains, start=1)
},
Expand Down
16 changes: 9 additions & 7 deletions chai_lab/data/io/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def token_centre_plddts(


def get_chains_metadata(
context: PDBContext, entity_names: dict[int, str]
context: PDBContext, asymid2entity_name: dict[int, str]
) -> dict[int, AsymUnit]:
"""Return mapping from asym id to AsymUnit objects."""
assert context.asym_id2entity_type.keys() == entity_names.keys()
assert context.asym_id2entity_type.keys() == asymid2entity_name.keys()
# for each chain, get chain id, entity id, full sequence
token_res_names = context.token_res_names_to_string

Expand Down Expand Up @@ -113,7 +113,7 @@ def get_chains_metadata(
sequence=[
_to_chem_component(resi, entity_type, asym_id) for resi in sequence
],
description=entity_names[asym_id],
description=asymid2entity_name[asym_id],
),
details=f"Chain {chain_id_str}",
id=chain_id_str,
Expand Down Expand Up @@ -148,7 +148,7 @@ def save_to_cif(
coords: Float[Tensor, "1 n_atoms 3"],
output_batch: dict,
write_path: Path,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
bfactors: Float[Tensor, "1 n_atoms"] | None = None,
):
write_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -157,7 +157,7 @@ def save_to_cif(
coords=rearrange(coords, "1 n c -> n c", c=3).cpu(),
plddts=None if bfactors is None else rearrange(bfactors, "1 n -> n").cpu(),
context=pdb_context_from_batch(output_batch),
entity_names=entity_names,
asym_entity_names=asym_entity_names,
out_path=write_path,
)
logger.info(f"saved cif file to {write_path}")
Expand All @@ -168,10 +168,12 @@ def new_context_to_cif_atoms(
coords: Float[Tensor, "n_atoms 3"],
plddts: Float[Tensor, "n_atoms"] | None,
context: PDBContext,
entity_names: dict[int, str],
asym_entity_names: dict[int, str],
out_path: Path,
):
asym_id2asym_unit = get_chains_metadata(context, entity_names=entity_names)
asym_id2asym_unit = get_chains_metadata(
context, asymid2entity_name=asym_entity_names
)

atom_asym_id = context.token_asym_id[context.atom_token_index]
# atom level attributes
Expand Down

0 comments on commit 58bfece

Please sign in to comment.