diff --git a/chai_lab/data/io/cif_utils.py b/chai_lab/data/io/cif_utils.py index 98d65b9..491b546 100644 --- a/chai_lab/data/io/cif_utils.py +++ b/chai_lab/data/io/cif_utils.py @@ -81,6 +81,7 @@ def get_chains_metadata( token_res_names = context.token_res_names_to_string asym_id2asym_unit = {} + entity_key2ihm_entity = {} for asym_id, entity_type in context.asym_id2entity_type.items(): assert asym_id != 0, "zero is padding for asym_id" @@ -96,27 +97,33 @@ def get_chains_metadata( any_token_in_resi = residue_indices.new_zeros([max_res + 1]) - 99 any_token_in_resi[residue_indices] = torch.arange( - len(residue_indices), - dtype=residue_indices.dtype, - device=residue_indices.device, + len(residue_indices), dtype=any_token_in_resi.dtype ) assert any_token_in_resi.min() >= 0 sequence = [chain_token_res_names[i] for i in any_token_in_resi] - chain_id_str = _get_chain_letter(asym_id) + if entity_type == EntityType.LIGAND.value: + entity_key = (entity_type, asym_id) # each ligand = separate entity. + else: + entity_key = (entity_type, *sequence) - asym_id2asym_unit[asym_id] = AsymUnit( - entity=Entity( + asym_entity_name = asymid2entity_name[asym_id] + if entity_key not in entity_key2ihm_entity: + entity_key2ihm_entity[entity_key] = Entity( # sequence is a list of ChemComponents for aminoacids/bases sequence=[ _to_chem_component(resi, entity_type, asym_id) for resi in sequence ], - description=asymid2entity_name[asym_id], - ), - details=f"Chain {chain_id_str}", - id=chain_id_str, + # will be named same as first among replicas + description=f"Entity {asym_entity_name}", + ) + + asym_id2asym_unit[asym_id] = AsymUnit( + entity=entity_key2ihm_entity[entity_key], + details=f"Chain {asym_entity_name}", + id=asym_entity_name, ) return asym_id2asym_unit