diff --git a/ensemble_md/tests/data/coord_swap/broken_mol.gro b/ensemble_md/tests/data/coord_swap/broken_mol_1D.gro similarity index 100% rename from ensemble_md/tests/data/coord_swap/broken_mol.gro rename to ensemble_md/tests/data/coord_swap/broken_mol_1D.gro diff --git a/ensemble_md/tests/data/coord_swap/broken_mol_2D.gro b/ensemble_md/tests/data/coord_swap/broken_mol_2D.gro new file mode 100644 index 0000000..93f3c0c --- /dev/null +++ b/ensemble_md/tests/data/coord_swap/broken_mol_2D.gro @@ -0,0 +1,25 @@ +Generated with MDTraj, t= 1180.0 + 22 + 1C2D S1 1 0.081 2.020 0.959 + 1C2D C2 2 0.080 2.182 0.895 + 1C2D N3 3 2.728 -0.486 0.929 + 1C2D C4 4 2.639 -0.567 1.007 + 1C2D C5 5 2.684 -0.690 1.040 + 1C2D C6 6 0.192 2.229 0.800 + 1C2D C7 7 0.273 2.338 0.873 + 1C2D C8 8 0.364 2.420 0.778 + 1C2D H1 9 2.548 -0.517 1.037 + 1C2D H2 10 2.643 -0.765 1.107 + 1C2D H3 11 0.260 2.146 0.777 + 1C2D H4 12 0.151 2.262 0.704 + 1C2D H6 13 0.329 2.289 0.953 + 1C2D H7 14 0.202 2.408 0.919 + 1C2D H8 15 0.338 2.526 0.787 + 1C2D H9 16 0.469 2.408 0.808 + 1C2D H10 17 0.356 2.400 0.670 + 1C2D DC9 18 0.126 2.276 0.672 + 1C2D HV5 19 0.342 2.383 0.801 + 1C2D HV11 20 0.206 2.272 0.597 + 1C2D HV12 21 0.072 2.369 0.688 + 1C2D HV13 22 0.047 2.207 0.641 + 2.74964 2.74964 2.74964 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000 diff --git a/ensemble_md/tests/data/coord_swap/broken_mol_3D.gro b/ensemble_md/tests/data/coord_swap/broken_mol_3D.gro new file mode 100644 index 0000000..b788551 --- /dev/null +++ b/ensemble_md/tests/data/coord_swap/broken_mol_3D.gro @@ -0,0 +1,25 @@ +Generated with MDTraj, t= 1180.0 + 22 + 1C2D S1 1 0.081 2.020 0.959 + 1C2D C2 2 0.080 2.182 0.895 + 1C2D N3 3 2.728 -0.486 -1.821 + 1C2D C4 4 2.639 -0.567 -1.742 + 1C2D C5 5 2.684 -0.690 -1.710 + 1C2D C6 6 0.192 2.229 0.800 + 1C2D C7 7 0.273 2.338 0.873 + 1C2D C8 8 0.364 2.420 0.778 + 1C2D H1 9 2.548 -0.517 -1.712 + 1C2D H2 10 2.643 -0.765 -1.642 + 1C2D H3 11 0.260 2.146 0.777 + 1C2D H4 12 0.151 2.262 0.704 + 1C2D H6 13 0.329 2.289 0.953 + 1C2D H7 14 0.202 2.408 0.919 + 1C2D H8 15 0.338 2.526 0.787 + 1C2D H9 16 0.469 2.408 0.808 + 1C2D H10 17 0.356 2.400 0.670 + 1C2D DC9 18 0.126 2.276 0.672 + 1C2D HV5 19 0.342 2.383 0.801 + 1C2D HV11 20 0.206 2.272 0.597 + 1C2D HV12 21 0.072 2.369 0.688 + 1C2D HV13 22 0.047 2.207 0.641 + 2.74964 2.74964 2.74964 0.00000 0.00000 0.00000 0.00000 0.00000 0.00000 diff --git a/ensemble_md/tests/test_coordinate_swap.py b/ensemble_md/tests/test_coordinate_swap.py index 6b1bf6b..4b71189 100644 --- a/ensemble_md/tests/test_coordinate_swap.py +++ b/ensemble_md/tests/test_coordinate_swap.py @@ -94,10 +94,30 @@ def test_fix_break(): assert (test_fix.xyz == fixed_mol.xyz).all -def test_perform_shift_1D(): - broken_mol = md.load(f'{input_path}/coord_swap/broken_mol.gro') +def test_perform_shift(): + broken_mol = md.load(f'{input_path}/coord_swap/broken_mol_1D.gro') + + partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift_1D(broken_mol, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 1) # noqa: E501 + + broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]]) + + assert prev_shifted_atoms == [4] + assert was_it_fixed is True + assert len(broken_pairs) == 0 + + broken_mol_2D = md.load(f'broken_mol_2D.gro') + + partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift(broken_mol_2D, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 2) # noqa: E501 + + broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]]) + + assert prev_shifted_atoms == [4] + assert was_it_fixed is True + assert len(broken_pairs) == 0 + + broken_mol_3D = md.load(f'broken_mol_3D.gro') - partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift_1D(broken_mol, [2.74964, 2.74964, 2.74964], [[0, 4]], []) # noqa: E501 + partial_fix, was_it_fixed, prev_shifted_atoms = coordinate_swap.perform_shift(broken_mol_3D, [2.74964, 2.74964, 2.74964], [[0, 4]], [], 3) # noqa: E501 broken_pairs = coordinate_swap.check_break(partial_fix, [[0, 4]]) diff --git a/ensemble_md/utils/coordinate_swap.py b/ensemble_md/utils/coordinate_swap.py index adc4c92..73f430e 100644 --- a/ensemble_md/utils/coordinate_swap.py +++ b/ensemble_md/utils/coordinate_swap.py @@ -225,17 +225,17 @@ def fix_break(mol, resname, box_dimensions, atom_connect_all): iter += 1 # Fix this break - mol, fixed, shift_atom = perform_shift_1D(mol, box_dimensions, broken_pairs, shift_atom) + mol, fixed, shift_atom = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 1) if fixed: broken_pairs = check_break(mol, atom_pairs) continue else: - mol, fixed = perform_shift_2D(mol, box_dimensions, broken_pairs, shift_atom) + mol, fixed = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 2) if fixed: broken_pairs = check_break(mol, atom_pairs) continue else: - mol, fixed = perform_shift_3D(mol, box_dimensions, broken_pairs, shift_atom) + mol, fixed = perform_shift(mol, box_dimensions, broken_pairs, shift_atom, 3) if fixed: broken_pairs = check_break(mol, atom_pairs) continue @@ -245,7 +245,7 @@ def fix_break(mol, resname, box_dimensions, atom_connect_all): return mol -def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): +def perform_shift(mol, box_dimensions, broken_pairs_init, prev_shift_atom, num_shift_dimensions): """ Shifts the input trajectory across the periodic boundaries in 1D. @@ -259,6 +259,7 @@ def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): Which pairs of atoms were found to be broken. prev_shift_atom : int Which atoms have already been shifted so we don't undo what we've done. + num_shift_dimensions : int Returns ------- @@ -274,101 +275,25 @@ def perform_shift_1D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): if broken_atom in prev_shift_atom: broken_atom = atom_pair[0] fixed = False - for x in range(3): # Loop through x, y, and z - for shift_dir in [1, -1]: - mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] + (shift_dir * box_dimensions[x]) # positive shift # noqa: E501 - dist_check = md.compute_distances(mol, atom_pairs=[atom_pair], periodic=False) - if dist_check > 0.2: # Didn't work so reverse and try again - mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] - (shift_dir * box_dimensions[x]) - else: # Yay fixed break - fixed = True - break - if fixed: - break - if fixed: - prev_shift_atom.append(broken_atom) - return mol, fixed, prev_shift_atom - - -def perform_shift_2D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): - """ - Shifts the input trajectory across the periodic boundaries in 2D. - - Parameters - ---------- - mol : :func:`mdtraj.Trajectory` object - Trajecotry with the original coordinates prior to the shift. - box_dimensions : list - Dimensions of the periodic boundary box. - broken_pairs_init : int - Which pairs of atoms were found to be broken. - prev_shift_atom : int - Which atoms have already been shifted so we don't undo what we've done. - - Returns - ------- - mol : :func:`mdtraj.Trajectory` object - Trajectory with the new coordintes. - fixed : bool - A boolean indicating whether the break was actually fixed. - prev_shift_atom : int - Which atoms have already been shifted so we don't undo what we've done. - """ - atom_pair = broken_pairs_init[0] - broken_atom = atom_pair[1] - if broken_atom in prev_shift_atom: - broken_atom = atom_pair[0] - fixed = False - shift_combos = product([0, 1, 2], [0, 1, 2]) - for pair in shift_combos: # Loop through x, y, and z - for shift_dir in product([1, -1], [1, -1]): # Try all combos of shift direction - x, y = pair - mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] + (shift_dir[0] * box_dimensions[x]) # positive shift # noqa: E501 - mol.xyz[0, broken_atom, y] = mol.xyz[0, broken_atom, y] + (shift_dir[1] * box_dimensions[y]) - dist_check = md.compute_distances(mol, atom_pairs=[atom_pair], periodic=False) - if dist_check > 0.2: # Didn't work so reverse and try again - mol.xyz[0, broken_atom, x] = mol.xyz[0, broken_atom, x] - (shift_dir[0] * box_dimensions[x]) - mol.xyz[0, broken_atom, y] = mol.xyz[0, broken_atom, y] - (shift_dir[1] * box_dimensions[y]) - else: # Yay fixed break - fixed = True - break - if fixed: - break - if fixed: - prev_shift_atom.append(broken_atom) - return mol, fixed - - -def perform_shift_3D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): - """ - Shifts the input trajectory across the periodic boundaries in 3D - - Parameters - ---------- - mol : :func:`mdtraj.Trajectory` object - Trajecotry with the original coordinates prior to the shift. - box_dimensions : list - Dimensions of the periodic boundary box. - broken_pairs_init : int - Which pairs of atoms were found to be broken. - prev_shift_atom : int - Which atoms have already been shifted so we don't undo what we've done. - - Returns - ------- - mol : :func:`mdtraj.Trajectory` object - Trajectory with the new coordintes. - fixed : bool - A boolean indicating whether the break was actually fixed. - prev_shift_atom : int - Which atoms have already been shifted so we don't undo what we've done. - """ - atom_pair = broken_pairs_init[0] - broken_atom = atom_pair[1] - if broken_atom in prev_shift_atom: - broken_atom = atom_pair[0] - fixed = False - for shift_dir in product([1, -1], [1, -1], [1, -1]): # Try all combos of shift direction + if num_shift_dimensions == 1: + shift_combos = np.concatenate((np.identity(3), -1*np.identity(3)), axis=0) + elif num_shift_dimensions == 2: + shift_combos = [[1, 1, 0], + [1, -1, 0], + [-1, 1, 0], + [-1, -1, 0], + [0, 1, 1], + [0, 1, -1], + [0, -1, 1], + [0, -1, -1], + [1, 0, 1], + [1, 0, -1], + [-1, 0, 1], + [-1, 0, -1]] + else: + shift_combos = product([1, -1], [1, -1], [1, -1]) + + for shift_dir in shift_combos: # Try all combos of shift direction mol.xyz[0, broken_atom, 0] = mol.xyz[0, broken_atom, 0] + (shift_dir[0] * box_dimensions[0]) mol.xyz[0, broken_atom, 1] = mol.xyz[0, broken_atom, 1] + (shift_dir[1] * box_dimensions[1]) mol.xyz[0, broken_atom, 2] = mol.xyz[0, broken_atom, 2] + (shift_dir[2] * box_dimensions[2]) @@ -382,8 +307,7 @@ def perform_shift_3D(mol, box_dimensions, broken_pairs_init, prev_shift_atom): break if fixed: prev_shift_atom.append(broken_atom) - return mol, fixed - + return mol, fixed, prev_shift_atom def check_break(mol, atom_pairs): """ @@ -464,7 +388,7 @@ def compute_angle(coords): return angle - + # Create a new column for coordinates if one does not exist if 'X Coordinates' not in df_atom_swap.columns: df_atom_swap['X Coordinates'] = np.NaN