Skip to content

Commit

Permalink
Allow bonds etc to be additively guessed when present (#4761)
Browse files Browse the repository at this point in the history
Allow bonds to be additively guessed (fixes #4759)

---------

Co-authored-by: Irfan Alibay <[email protected]>
  • Loading branch information
lilyminium and IAlibay authored Nov 11, 2024
1 parent c9a3778 commit e6bc096
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 29 deletions.
3 changes: 3 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ The rules for this file:
* 2.8.0

Fixes
* Allows bond/angle/dihedral connectivity to be guessed additively with
`to_guess`, and as a replacement of existing values with `force_guess`.
Also updates cached bond attributes when updating bonds. (Issue #4759, PR #4761)
* Fixes bug where deleting connections by index would only delete
one of multiple, if multiple are present (Issue #4762, PR #4763)
* Changes error to warning on Universe creation if guessing fails
Expand Down
1 change: 0 additions & 1 deletion package/MDAnalysis/core/topologyattrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3117,7 +3117,6 @@ def _add_bonds(self, values, types=None, guessed=True, order=None):
guessed = itertools.cycle((guessed,))
if order is None:
order = itertools.cycle((None,))

existing = set(self.values)
for v, t, g, o in zip(values, types, guessed, order):
if v not in existing:
Expand Down
67 changes: 46 additions & 21 deletions package/MDAnalysis/core/universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@
Atom, Residue, Segment,
AtomGroup, ResidueGroup, SegmentGroup)
from .topology import Topology
from .topologyattrs import AtomAttr, ResidueAttr, SegmentAttr, BFACTOR_WARNING
from .topologyattrs import (
AtomAttr, ResidueAttr, SegmentAttr,
BFACTOR_WARNING, _Connection
)
from .topologyobjects import TopologyObject
from ..guesser.base import get_guesser

Expand Down Expand Up @@ -454,7 +457,10 @@ def __init__(self, topology=None, *coordinates, all_coordinates=False,
"the previous Context values.",
DeprecationWarning
)
force_guess = list(force_guess) + ['bonds', 'angles', 'dihedrals']
# Original behaviour is to add additionally guessed bond info
# this is achieved by adding to the `to_guess` list (unliked `force_guess`
# which replaces existing bonds).
to_guess = list(to_guess) + ['bonds', 'angles', 'dihedrals']

self.guess_TopologyAttrs(
context, to_guess, force_guess, error_if_missing=False
Expand Down Expand Up @@ -1180,7 +1186,6 @@ def _add_topology_objects(self, object_type, values, types=None, guessed=False,
self.add_TopologyAttr(object_type, [])
attr = getattr(self._topology, object_type)


attr._add_bonds(indices, types=types, guessed=guessed, order=order)

def add_bonds(self, values, types=None, guessed=False, order=None):
Expand Down Expand Up @@ -1231,6 +1236,16 @@ def add_bonds(self, values, types=None, guessed=False, order=None):
"""
self._add_topology_objects('bonds', values, types=types,
guessed=guessed, order=order)
self._invalidate_bond_related_caches()

def _invalidate_bond_related_caches(self):
"""
Invalidate caches related to bonds and fragments.
This should be called whenever the Universe's bonds are modified.
.. versionadded: 2.8.0
"""
# Invalidate bond-related caches
self._cache.pop('fragments', None)
self._cache['_valid'].pop('fragments', None)
Expand Down Expand Up @@ -1307,7 +1322,7 @@ def _delete_topology_objects(self, object_type, values):
Parameters
----------
object_type : {'bonds', 'angles', 'dihedrals', 'impropers'}
The type of TopologyObject to add.
The type of TopologyObject to delete.
values : iterable of tuples, AtomGroups, or TopologyObjects; or TopologyGroup
An iterable of: tuples of atom indices, or AtomGroups,
or TopologyObjects.
Expand All @@ -1330,7 +1345,6 @@ def _delete_topology_objects(self, object_type, values):
attr = getattr(self._topology, object_type)
except AttributeError:
raise ValueError('There are no {} to delete'.format(object_type))

attr._delete_bonds(indices)

def delete_bonds(self, values):
Expand Down Expand Up @@ -1371,10 +1385,7 @@ def delete_bonds(self, values):
.. versionadded:: 1.0.0
"""
self._delete_topology_objects('bonds', values)
# Invalidate bond-related caches
self._cache.pop('fragments', None)
self._cache['_valid'].pop('fragments', None)
self._cache['_valid'].pop('fragindices', None)
self._invalidate_bond_related_caches()

def delete_angles(self, values):
"""Delete Angles from this Universe.
Expand Down Expand Up @@ -1613,7 +1624,12 @@ def guess_TopologyAttrs(
# in the same order that the user provided
total_guess = list(dict.fromkeys(total_guess))

objects = ['bonds', 'angles', 'dihedrals', 'impropers']
# Set of all Connectivity related attribute names
# used to special case attribute replacement after calling the guesser
objects = set(
topattr.attrname for topattr in _TOPOLOGY_ATTRS.values()
if issubclass(topattr, _Connection)
)

# Checking if the universe is empty to avoid errors
# from guesser methods
Expand All @@ -1640,23 +1656,32 @@ def guess_TopologyAttrs(
fg = attr in force_guess
try:
values = guesser.guess_attr(attr, fg)
except ValueError as e:
except NoDataError as e:
if error_if_missing or fg:
raise e
else:
warnings.warn(str(e))
continue

if values is not None:
if attr in objects:
self._add_topology_objects(
attr, values, guessed=True)
else:
guessed_attr = _TOPOLOGY_ATTRS[attr](values, True)
self.add_TopologyAttr(guessed_attr)
logger.info(
f'attribute {attr} has been guessed'
' successfully.')
# None indicates no additional guessing was done
if values is None:
continue
if attr in objects:
# delete existing connections if they exist
if fg and hasattr(self.atoms, attr):
group = getattr(self.atoms, attr)
self._delete_topology_objects(attr, group)
# this method appends any new bonds in values to existing bonds
self._add_topology_objects(
attr, values, guessed=True)
if attr == "bonds":
self._invalidate_bond_related_caches()
else:
guessed_attr = _TOPOLOGY_ATTRS[attr](values, True)
self.add_TopologyAttr(guessed_attr)
logger.info(
f'attribute {attr} has been guessed'
' successfully.')

else:
raise ValueError(f'{context} guesser can not guess the'
Expand Down
34 changes: 27 additions & 7 deletions package/MDAnalysis/guesser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
.. autofunction:: get_guesser
"""
from .. import _GUESSERS
from .. import _GUESSERS, _TOPOLOGY_ATTRS
from ..core.topologyattrs import _Connection
import numpy as np
from .. import _TOPOLOGY_ATTRS
import logging
from typing import Dict
import copy
Expand Down Expand Up @@ -136,21 +136,41 @@ def guess_attr(self, attr_to_guess, force_guess=False):
NDArray of guessed values
"""
try:
top_attr = _TOPOLOGY_ATTRS[attr_to_guess]
except KeyError:
raise KeyError(
f"{attr_to_guess} is not a recognized MDAnalysis "
"topology attribute"
)
# make attribute to guess plural
attr_to_guess = top_attr.attrname

try:
guesser_method = self._guesser_methods[attr_to_guess]
except KeyError:
raise ValueError(f'{type(self).__name__} cannot guess this '
f'attribute: {attr_to_guess}')

# Connection attributes should be just returned as they are always
# appended to the Universe. ``force_guess`` handling should happen
# at Universe level.
if issubclass(top_attr, _Connection):
return guesser_method()

# check if the topology already has the attribute to partially guess it
if hasattr(self._universe.atoms, attr_to_guess) and not force_guess:
attr_values = np.array(
getattr(self._universe.atoms, attr_to_guess, None))

top_attr = _TOPOLOGY_ATTRS[attr_to_guess]

empty_values = top_attr.are_values_missing(attr_values)

if True in empty_values:
# pass to the guesser_method boolean mask to only guess the
# empty values
attr_values[empty_values] = self._guesser_methods[attr_to_guess](
indices_to_guess=empty_values)
attr_values[empty_values] = guesser_method(
indices_to_guess=empty_values
)
return attr_values

else:
Expand All @@ -159,7 +179,7 @@ def guess_attr(self, attr_to_guess, force_guess=False):
f'not guess any new values for {attr_to_guess} attribute')
return None
else:
return np.array(self._guesser_methods[attr_to_guess]())
return np.array(guesser_method())


def get_guesser(context, u=None, **kwargs):
Expand Down
138 changes: 138 additions & 0 deletions testsuite/MDAnalysisTests/guesser/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
from MDAnalysis.core.topology import Topology
from MDAnalysis.core.topologyattrs import Masses, Atomnames, Atomtypes
import MDAnalysis.tests.datafiles as datafiles
from MDAnalysis.exceptions import NoDataError
from numpy.testing import assert_allclose, assert_equal

from MDAnalysis import _TOPOLOGY_ATTRS, _GUESSERS


class TestBaseGuesser():

Expand Down Expand Up @@ -101,6 +104,141 @@ def test_partial_guess_attr_with_unknown_no_value_label(self):
u = mda.Universe(top, to_guess=['types'])
assert_equal(u.atoms.types, ['', '', '', ''])

def test_guess_topology_objects_existing_read(self):
u = mda.Universe(datafiles.CONECT)
assert len(u.atoms.bonds) == 72
assert list(u.bonds[0].indices) == [623, 630]

# delete some bonds
u.delete_bonds(u.atoms.bonds[:10])
assert len(u.atoms.bonds) == 62
# first bond has changed
assert list(u.bonds[0].indices) == [1545, 1552]
# count number of (1545, 1552) bonds
ag = u.atoms[[1545, 1552]]
bonds = ag.bonds.atomgroup_intersection(ag, strict=True)
assert len(bonds) == 1
assert not bonds[0].is_guessed

all_indices = [tuple(x.indices) for x in u.bonds]
assert (623, 630) not in all_indices

# test guessing new bonds doesn't remove old ones
u.guess_TopologyAttrs("default", to_guess=["bonds"])
assert len(u.atoms.bonds) == 1922
old_bonds = ag.bonds.atomgroup_intersection(ag, strict=True)
assert len(old_bonds) == 1
# test guessing new bonds doesn't duplicate old ones
assert not old_bonds[0].is_guessed

new_ag = u.atoms[[623, 630]]
new_bonds = new_ag.bonds.atomgroup_intersection(new_ag, strict=True)
assert len(new_bonds) == 1
assert new_bonds[0].is_guessed

def test_guess_topology_objects_existing_in_universe(self):
u = mda.Universe(datafiles.CONECT, to_guess=["bonds"])
assert len(u.atoms.bonds) == 1922
assert list(u.bonds[0].indices) == [0, 1]

# delete some bonds
u.delete_bonds(u.atoms.bonds[:100])
assert len(u.atoms.bonds) == 1822
assert list(u.bonds[0].indices) == [94, 99]

all_indices = [tuple(x.indices) for x in u.bonds]
assert (0, 1) not in all_indices

# guess old bonds back
u.guess_TopologyAttrs("default", to_guess=["bonds"])
assert len(u.atoms.bonds) == 1922
# check TopologyGroup contains new (old) bonds
assert list(u.bonds[0].indices) == [0, 1]

def test_guess_topology_objects_force(self):
u = mda.Universe(datafiles.CONECT, force_guess=["bonds"])
assert len(u.atoms.bonds) == 1922

with pytest.raises(NoDataError):
u.atoms.angles

def test_guess_topology_objects_out_of_order_init(self):
u = mda.Universe(
datafiles.PDB_small,
to_guess=["dihedrals", "angles", "bonds"],
guess_bonds=False
)
assert len(u.atoms.angles) == 6123
assert len(u.atoms.dihedrals) == 8921

def test_guess_topology_objects_out_of_order_guess(self):
u = mda.Universe(datafiles.PDB_small)
with pytest.raises(NoDataError):
u.atoms.angles

u.guess_TopologyAttrs(
"default",
to_guess=["dihedrals", "angles", "bonds"]
)
assert len(u.atoms.angles) == 6123
assert len(u.atoms.dihedrals) == 8921

def test_force_guess_overwrites_existing_bonds(self):
u = mda.Universe(datafiles.CONECT)
assert len(u.atoms.bonds) == 72

# This low radius should find no bonds
vdw = dict.fromkeys(set(u.atoms.types), 0.1)
u.guess_TopologyAttrs("default", to_guess=["bonds"], vdwradii=vdw)
assert len(u.atoms.bonds) == 72

# Now force guess bonds
u.guess_TopologyAttrs("default", force_guess=["bonds"], vdwradii=vdw)
assert len(u.atoms.bonds) == 0

def test_guessing_angles_respects_bond_kwargs(self):
u = mda.Universe(datafiles.PDB)
assert not hasattr(u.atoms, "angles")

# This low radius should find no angles
vdw = dict.fromkeys(set(u.atoms.types), 0.01)

u.guess_TopologyAttrs("default", to_guess=["angles"], vdwradii=vdw)
assert len(u.atoms.angles) == 0

# set higher radii for lots of angles!
vdw = dict.fromkeys(set(u.atoms.types), 1)
u.guess_TopologyAttrs("default", force_guess=["angles"], vdwradii=vdw)
assert len(u.atoms.angles) == 89466

def test_guessing_dihedrals_respects_bond_kwargs(self):
u = mda.Universe(datafiles.CONECT)
assert len(u.atoms.bonds) == 72

u.guess_TopologyAttrs("default", to_guess=["dihedrals"])
assert len(u.atoms.dihedrals) == 3548
assert not hasattr(u.atoms, "angles")

def test_guess_invalid_attribute(self):
default_guesser = get_guesser("default")
err = "not a recognized MDAnalysis topology attribute"
with pytest.raises(KeyError, match=err):
default_guesser.guess_attr('not_an_attribute')

def test_guess_unsupported_attribute(self):
default_guesser = get_guesser("default")
err = "cannot guess this attribute"
with pytest.raises(ValueError, match=err):
default_guesser.guess_attr('tempfactors')

def test_guess_singular(self):
default_guesser = get_guesser("default")
u = mda.Universe(datafiles.PDB, to_guess=[])
assert not hasattr(u.atoms, "masses")

default_guesser._universe = u
masses = default_guesser.guess_attr('mass')


def test_Universe_guess_bonds_deprecated():
with pytest.warns(
Expand Down

0 comments on commit e6bc096

Please sign in to comment.