Skip to content

Commit

Permalink
Combine perform shift functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Oct 7, 2024
1 parent c21bd0d commit 15c1fd2
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 105 deletions.
25 changes: 25 additions & 0 deletions ensemble_md/tests/data/coord_swap/broken_mol_2D.gro
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Generated with MDTraj, t= 1180.0
22
1C2D S1 1 0.081 2.020 0.959
1C2D C2 2 0.080 2.182 0.895
1C2D N3 3 2.728 -0.486 0.929
1C2D C4 4 2.639 -0.567 1.007
1C2D C5 5 2.684 -0.690 1.040
1C2D C6 6 0.192 2.229 0.800
1C2D C7 7 0.273 2.338 0.873
1C2D C8 8 0.364 2.420 0.778
1C2D H1 9 2.548 -0.517 1.037
1C2D H2 10 2.643 -0.765 1.107
1C2D H3 11 0.260 2.146 0.777
1C2D H4 12 0.151 2.262 0.704
1C2D H6 13 0.329 2.289 0.953
1C2D H7 14 0.202 2.408 0.919
1C2D H8 15 0.338 2.526 0.787
1C2D H9 16 0.469 2.408 0.808
1C2D H10 17 0.356 2.400 0.670
1C2D DC9 18 0.126 2.276 0.672
1C2D HV5 19 0.342 2.383 0.801
1C2D HV11 20 0.206 2.272 0.597
1C2D HV12 21 0.072 2.369 0.688
1C2D HV13 22 0.047 2.207 0.641
2.74964 2.74964 2.74964 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
25 changes: 25 additions & 0 deletions ensemble_md/tests/data/coord_swap/broken_mol_3D.gro
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Generated with MDTraj, t= 1180.0
22
1C2D S1 1 0.081 2.020 0.959
1C2D C2 2 0.080 2.182 0.895
1C2D N3 3 2.728 -0.486 -1.821
1C2D C4 4 2.639 -0.567 -1.742
1C2D C5 5 2.684 -0.690 -1.710
1C2D C6 6 0.192 2.229 0.800
1C2D C7 7 0.273 2.338 0.873
1C2D C8 8 0.364 2.420 0.778
1C2D H1 9 2.548 -0.517 -1.712
1C2D H2 10 2.643 -0.765 -1.642
1C2D H3 11 0.260 2.146 0.777
1C2D H4 12 0.151 2.262 0.704
1C2D H6 13 0.329 2.289 0.953
1C2D H7 14 0.202 2.408 0.919
1C2D H8 15 0.338 2.526 0.787
1C2D H9 16 0.469 2.408 0.808
1C2D H10 17 0.356 2.400 0.670
1C2D DC9 18 0.126 2.276 0.672
1C2D HV5 19 0.342 2.383 0.801
1C2D HV11 20 0.206 2.272 0.597
1C2D HV12 21 0.072 2.369 0.688
1C2D HV13 22 0.047 2.207 0.641
2.74964 2.74964 2.74964 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000
26 changes: 23 additions & 3 deletions ensemble_md/tests/test_coordinate_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,30 @@ def test_fix_break():
assert (test_fix.xyz == fixed_mol.xyz).all


def test_perform_shift_1D():
broken_mol = md.load(f'{input_path}/coord_swap/broken_mol.gro')
def test_perform_shift():
broken_mol = md.load(f'{input_path}/coord_swap/broken_mol_1D.gro')

partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift_1D(broken_mol, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 1) # noqa: E501

broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]])

assert prev_shifted_atoms == [4]
assert was_it_fixed is True
assert len(broken_pairs) == 0

broken_mol_2D = md.load(f'broken_mol_2D.gro')

partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift(broken_mol_2D, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 2) # noqa: E501

broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]])

assert prev_shifted_atoms == [4]
assert was_it_fixed is True
assert len(broken_pairs) == 0

broken_mol_3D = md.load(f'broken_mol_3D.gro')

partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift_1D(broken_mol, [2.74964, 2.74964, 2.74964], [[0, 4]], []) # noqa: E501
partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift(broken_mol_3D, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 3) # noqa: E501

broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]])

Expand Down
128 changes: 26 additions & 102 deletions ensemble_md/utils/coordinate_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,17 @@ def fix_break(mol, resname, box_dimensions, atom_connect_all):
iter += 1

# Fix this break
mol, fixed, shift_atom = perform_shift_1D(mol, box_dimensions, broken_pairs, shift_atom)
mol, fixed, shift_atom = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 1)
if fixed:
broken_pairs = check_break(mol, atom_pairs)
continue
else:
mol, fixed = perform_shift_2D(mol, box_dimensions, broken_pairs, shift_atom)
mol, fixed = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 2)
if fixed:
broken_pairs = check_break(mol, atom_pairs)
continue
else:
mol, fixed = perform_shift_3D(mol, box_dimensions, broken_pairs, shift_atom)
mol, fixed = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 3)
if fixed:
broken_pairs = check_break(mol, atom_pairs)
continue
Expand All @@ -245,7 +245,7 @@ def fix_break(mol, resname, box_dimensions, atom_connect_all):
return mol


def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
def perform_shift(mol, box_dimensions, broken_pairs_init, prev_shift_atom, num_shift_dimensions):
"""
Shifts the input trajectory across the periodic boundaries in 1D.
Expand All @@ -259,6 +259,7 @@ def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
Which pairs of atoms were found to be broken.
prev_shift_atom : int
Which atoms have already been shifted so we don't undo what we've done.
num_shift_dimensions : int
Returns
-------
Expand All @@ -274,101 +275,25 @@ def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
if broken_atom in prev_shift_atom:
broken_atom = atom_pair[0]
fixed = False
for x in range(3): # Loop through x, y, and z
for shift_dir in [1, -1]:
mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] + (shift_dir * box_dimensions[x]) # positive shift # noqa: E501
dist_check = md.compute_distances(mol, atom_pairs=[atom_pair], periodic=False)
if dist_check > 0.2: # Didn't work so reverse and try again
mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] - (shift_dir * box_dimensions[x])
else: # Yay fixed break
fixed = True
break
if fixed:
break
if fixed:
prev_shift_atom.append(broken_atom)
return mol, fixed, prev_shift_atom


def perform_shift_2D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
"""
Shifts the input trajectory across the periodic boundaries in 2D.
Parameters
----------
mol : :func:`mdtraj.Trajectory` object
Trajecotry with the original coordinates prior to the shift.
box_dimensions : list
Dimensions of the periodic boundary box.
broken_pairs_init : int
Which pairs of atoms were found to be broken.
prev_shift_atom : int
Which atoms have already been shifted so we don't undo what we've done.
Returns
-------
mol : :func:`mdtraj.Trajectory` object
Trajectory with the new coordintes.
fixed : bool
A boolean indicating whether the break was actually fixed.
prev_shift_atom : int
Which atoms have already been shifted so we don't undo what we've done.
"""
atom_pair = broken_pairs_init[0]
broken_atom = atom_pair[1]
if broken_atom in prev_shift_atom:
broken_atom = atom_pair[0]
fixed = False
shift_combos = product([0, 1, 2], [0, 1, 2])
for pair in shift_combos: # Loop through x, y, and z
for shift_dir in product([1, -1], [1, -1]): # Try all combos of shift direction
x, y = pair
mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] + (shift_dir[0] * box_dimensions[x]) # positive shift # noqa: E501
mol.xyz[0, broken_atom, y] = mol.xyz[0, broken_atom, y] + (shift_dir[1] * box_dimensions[y])
dist_check = md.compute_distances(mol, atom_pairs=[atom_pair], periodic=False)
if dist_check > 0.2: # Didn't work so reverse and try again
mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] - (shift_dir[0] * box_dimensions[x])
mol.xyz[0, broken_atom, y] = mol.xyz[0, broken_atom, y] - (shift_dir[1] * box_dimensions[y])
else: # Yay fixed break
fixed = True
break
if fixed:
break
if fixed:
prev_shift_atom.append(broken_atom)
return mol, fixed


def perform_shift_3D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
"""
Shifts the input trajectory across the periodic boundaries in 3D
Parameters
----------
mol : :func:`mdtraj.Trajectory` object
Trajecotry with the original coordinates prior to the shift.
box_dimensions : list
Dimensions of the periodic boundary box.
broken_pairs_init : int
Which pairs of atoms were found to be broken.
prev_shift_atom : int
Which atoms have already been shifted so we don't undo what we've done.
Returns
-------
mol : :func:`mdtraj.Trajectory` object
Trajectory with the new coordintes.
fixed : bool
A boolean indicating whether the break was actually fixed.
prev_shift_atom : int
Which atoms have already been shifted so we don't undo what we've done.
"""
atom_pair = broken_pairs_init[0]
broken_atom = atom_pair[1]
if broken_atom in prev_shift_atom:
broken_atom = atom_pair[0]
fixed = False
for shift_dir in product([1, -1], [1, -1], [1, -1]): # Try all combos of shift direction
if num_shift_dimensions == 1:
shift_combos = np.concatenate((np.identity(3), -1*np.identity(3)), axis=0)
elif num_shift_dimensions == 2:
shift_combos = [[1, 1, 0],
[1, -1, 0],
[-1, 1, 0],
[-1, -1, 0],
[0, 1, 1],
[0, 1, -1],
[0, -1, 1],
[0, -1, -1],
[1, 0, 1],
[1, 0, -1],
[-1, 0, 1],
[-1, 0, -1]]
else:
shift_combos = product([1, -1], [1, -1], [1, -1])

for shift_dir in shift_combos: # Try all combos of shift direction
mol.xyz[0, broken_atom, 0] = mol.xyz[0, broken_atom, 0] + (shift_dir[0] * box_dimensions[0])
mol.xyz[0, broken_atom, 1] = mol.xyz[0, broken_atom, 1] + (shift_dir[1] * box_dimensions[1])
mol.xyz[0, broken_atom, 2] = mol.xyz[0, broken_atom, 2] + (shift_dir[2] * box_dimensions[2])
Expand All @@ -382,8 +307,7 @@ def perform_shift_3D(mol, box_dimensions, broken_pairs_init, prev_shift_atom):
break
if fixed:
prev_shift_atom.append(broken_atom)
return mol, fixed

return mol, fixed, prev_shift_atom

def check_break(mol, atom_pairs):
"""
Expand Down Expand Up @@ -464,7 +388,7 @@ def compute_angle(coords):

return angle


# Create a new column for coordinates if one does not exist
if 'X Coordinates' not in df_atom_swap.columns:
df_atom_swap['X Coordinates'] = np.NaN
Expand Down

0 comments on commit 15c1fd2

Please sign in to comment.