Skip to content

Commit

Permalink
contacts: Allow saving all atom-level contacts in neighbor-based cont…
Browse files Browse the repository at this point in the history
…act assignment
  • Loading branch information
avivrosenberg committed May 26, 2024
1 parent d69f767 commit 5a5d15a
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions src/pp5/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import subprocess
from abc import ABC, abstractmethod
from time import time
from typing import Set, Dict, List, Tuple, Union, Optional, Sequence
from typing import Any, Set, Dict, List, Tuple, Union, Optional, Sequence
from pathlib import Path
from functools import partial
from itertools import chain
Expand Down Expand Up @@ -165,6 +165,13 @@ def from_atoms(cls, src: Atom, tgt: Atom, with_altlocs: bool = True) -> AtomCont
def __str__(self):
return f"{self.src_key!s} -> {self.tgt_key!s} ({self.type}): {self.dist:.2f}"

def as_dict(self) -> Dict[str, Any]:
return {
**{f"src_{k}": v for k, v in attrs.asdict(self.src_key).items()},
**{f"tgt_{k}": v for k, v in attrs.asdict(self.tgt_key).items()},
**attrs.asdict(self, filter=attrs.filters.exclude("src_key", "tgt_key")),
}


@attrs.define(repr=True, eq=True, hash=True)
class ResidueContact:
Expand Down Expand Up @@ -267,8 +274,17 @@ def __init__(
pdb_source: str,
contact_radius: float = CONTACT_DEFAULT_RADIUS,
with_altlocs: bool = False,
with_atom_contacts: bool = False,
pdb_dict: Optional[dict] = None,
):
"""
:param pdb_id: The PDB ID to assign contacts for.
:param pdb_source: The source from which to obtain the PDB file.
:param contact_radius: The radius (in angstroms) to use for contact detection.
:param with_altlocs: Whether to include altloc atoms in the contact detection.
:param with_atom_contacts: Whether to include atom-level contacts in the output.
:param pdb_dict: If provided, the PDB structure will be loaded from this dict.
"""
super().__init__(
pdb_id=pdb_id,
pdb_source=pdb_source,
Expand All @@ -288,6 +304,7 @@ def __init__(
)
)
self._contacts_from = NeighborSearch(list(atoms))
self._with_atom_contacts = with_atom_contacts

def assign(self, res: Residue) -> Dict[str, Optional[ResidueContacts]]:

Expand All @@ -314,7 +331,10 @@ def assign(self, res: Residue) -> Dict[str, Optional[ResidueContacts]]:
*all_atoms, allow_disjoint=True, include_none=True
)

contacts = {NO_ALTLOC: None}
atom_contacts: List[AtomContact] = []
altloc_to_residue_contacts: Dict[str, Optional[ResidueContacts]] = {
NO_ALTLOC: None
}

# For each altloc, we want to move all the atoms to it (if it exists for a
# particular atom) and then calculate the contacts from all the moved atoms.
Expand Down Expand Up @@ -373,17 +393,20 @@ def assign(self, res: Residue) -> Dict[str, Optional[ResidueContacts]]:
alt_atom, a, with_altlocs=self.with_altlocs
)
curr_altloc_contacts.add(atom_contact)
atom_contacts.append(atom_contact)

# Convert the atom contacts to residue contacts
residue_contacts = self._aggregate_atom_contacts(
src_res_contact_key, tuple(curr_altloc_contacts)
)

contacts[altloc_id] = ResidueContacts.from_contacts(
res, residue_contacts
altloc_to_residue_contacts[altloc_id] = ResidueContacts.from_contacts(
res,
residue_contacts,
atom_contacts if self._with_atom_contacts else None,
)

return contacts
return altloc_to_residue_contacts

def _aggregate_atom_contacts(
self, src_key: ResidueContactKey, atom_contacts: Sequence[AtomContact]
Expand Down Expand Up @@ -457,6 +480,7 @@ def __init__(
contact_ooc: Union[Sequence[str], str],
contact_non_aa: Union[Sequence[str], str],
contact_aas: Union[Sequence[str], str],
atom_contacts: Optional[Sequence[AtomContact]] = None,
**kwargs_ignored, # ignore any other args (passed in from Arpeggio)
):
def _split(s: str):
Expand Down Expand Up @@ -484,6 +508,7 @@ def _split(s: str):
self.contact_ooc = tuple(contact_ooc)
self.contact_non_aa = tuple(contact_non_aa)
self.contact_aas = tuple(contact_aas)
self.atom_contacts = atom_contacts

def as_dict(self, key_postfix: str = "", join_lists: bool = True):
def _join(s):
Expand Down Expand Up @@ -512,7 +537,22 @@ def __hash__(self):
return hash(tuple(self.as_dict().values()))

@classmethod
def from_contacts(cls, src_res: Residue, contacts: Sequence[ResidueContact]):
def from_contacts(
cls,
src_res: Residue,
contacts: Sequence[ResidueContact],
atom_contacts: Sequence[AtomContact] = None,
) -> ResidueContacts:
"""
Creates a ResidueContacts from a sequence of ResidueContact objects.
:param src_res: The source residue.
:param contacts: The sequence of ResidueContact objects.
:param atom_contacts: Optional sequence of AtomContact objects. If provided,
they will be included as-is in the ResidueContacts object.
:return: A ResidueContacts object.
"""

def _format(_contacts: Sequence[ResidueContact]) -> Sequence[str]:
return tuple(
format_residue_contact(
Expand Down Expand Up @@ -543,7 +583,7 @@ def _format(_contacts: Sequence[ResidueContact]) -> Sequence[str]:
contact_ooc=contact_ooc,
contact_non_aa=contact_non_aa,
contact_aas=contact_aas,
# atom_contacts=atom_contacts,
atom_contacts=atom_contacts,
)


Expand Down

0 comments on commit 5a5d15a

Please sign in to comment.