From fafeab7ad468ea89b927d8b94fd7625a0d7e6e6a Mon Sep 17 00:00:00 2001 From: Matteo Degiacomi Date: Thu, 20 Jul 2023 22:30:31 +0100 Subject: [PATCH] formatting and small cleanup - first pass fixing indentations, spaces, and syntax details (guided by flake8) - using os.cpu_count() instead of our custom function --- src/molearn/analysis/GUI.py | 67 +- src/molearn/analysis/analyser.py | 78 +- src/molearn/analysis/path.py | 14 +- src/molearn/data/pdb_data.py | 21 +- src/molearn/loss_functions/openmm_thread.py | 56 +- .../loss_functions/torch_protein_energy.py | 174 ++--- .../torch_protein_energy_utils.py | 739 +++++++++--------- src/molearn/models/CNN_autoencoder.py | 20 +- src/molearn/models/foldingnet.py | 31 +- src/molearn/models/small_foldingnet.py | 17 +- src/molearn/scoring/dope_score.py | 44 +- src/molearn/scoring/ramachandran_score.py | 49 +- .../trainers/openmm_physics_trainer.py | 16 +- src/molearn/trainers/sinkhorn_trainer.py | 67 +- src/molearn/trainers/torch_physics_trainer.py | 14 +- src/molearn/trainers/trainer.py | 60 +- src/molearn/utils.py | 33 +- 17 files changed, 724 insertions(+), 776 deletions(-) diff --git a/src/molearn/analysis/GUI.py b/src/molearn/analysis/GUI.py index cef5594..a2fa6a6 100644 --- a/src/molearn/analysis/GUI.py +++ b/src/molearn/analysis/GUI.py @@ -32,7 +32,7 @@ from ..utils import as_numpy -class MolearnGUI(object): +class MolearnGUI: ''' This class produces an interactive visualisation for data stored in a :func:`MolearnAnalysis ` object, @@ -49,11 +49,10 @@ def __init__(self, MA=None): else: self.MA = MA - self.waypoints = [] # collection of all saved waypoints - self.samples = [] # collection of all calculated sampling points + self.waypoints = [] # collection of all saved waypoints + self.samples = [] # collection of all calculated sampling points self.run() - def update_trails(self): ''' @@ -76,7 +75,6 @@ def update_trails(self): self.latent.data[2].y = self.samples[:, 1] self.latent.update() - def on_click(self, trace, points, selector): ''' @@ -97,14 +95,13 @@ def on_click(self, trace, points, selector): # update textbox (triggering update of 3D representation) try: pt = self.waypoints.flatten().round(decimals=4).astype(str) - #pt = np.array([self.latent.data[3].x, self.latent.data[3].y]).T.flatten().round(decimals=4).astype(str) + # pt = np.array([self.latent.data[3].x, self.latent.data[3].y]).T.flatten().round(decimals=4).astype(str) self.mybox.value = " ".join(pt) except Exception: return self.update_trails() - def get_samples(self, mybox, samplebox, path): ''' provide a trail of point between list of waypoints, either connected @@ -120,8 +117,8 @@ def get_samples(self, mybox, samplebox, path): crd = np.array(mybox.split()).astype(float) crd = crd.reshape((int(len(crd)/2), 2)) except Exception: - raise Exception("Cannot define sampling points") - return + raise Exception("Cannot define sampling points") + return if use_path: # connect points via A* @@ -129,8 +126,8 @@ def get_samples(self, mybox, samplebox, path): landscape = self.latent.data[0].z crd = get_path_aggregate(crd, landscape.T, self.MA.xvals, self.MA.yvals) except Exception as e: - raise Exception(f"Cannot define sampling points: path finding failed. {e})") - return + raise Exception(f"Cannot define sampling points: path finding failed. 
{e})") + return else: # connect points via straight line @@ -141,7 +138,6 @@ def get_samples(self, mybox, samplebox, path): return return crd - def interact_3D(self, mybox, samplebox, path): ''' @@ -152,7 +148,7 @@ def interact_3D(self, mybox, samplebox, path): crd = self.get_samples(mybox, samplebox, path) self.samples = crd.copy() crd = crd.reshape((1, len(crd), 2)) - except: + except Exception: self.button_pdb.disabled = True return @@ -169,12 +165,11 @@ def interact_3D(self, mybox, samplebox, path): self.mymol.load_new(gen) view = nv.show_mdanalysis(self.mymol) view.add_representation("spacefill") - #view.add_representation("cartoon") + # view.add_representation("cartoon") display.display(view) self.button_pdb.disabled = False - def drop_background_event(self, change): ''' control colouring style of latent space surface @@ -186,7 +181,7 @@ def drop_background_event(self, change): mykey = change.new try: - data = self.MA.surfaces[mykey] + data = self.MA.surfaces[mykey] except Exception as e: print(f"{e}") return @@ -204,7 +199,7 @@ def drop_background_event(self, change): self.latent.data[0].zmax = np.max(data) self.block0.children[1].min = np.min(data) self.block0.children[1].max = np.max(data) - except: + except Exception: self.latent.data[0].zmax = np.max(data) self.latent.data[0].zmin = np.min(data) self.block0.children[1].max = np.max(data) @@ -214,7 +209,6 @@ def drop_background_event(self, change): self.update_trails() - def drop_dataset_event(self, change): ''' control which dataset is displayed @@ -226,7 +220,7 @@ def drop_dataset_event(self, change): else: try: - data = as_numpy(self.MA.get_encoded(change.new).squeeze(2)) + data = as_numpy(self.MA.get_encoded(change.new).squeeze(2)) except Exception as e: print(f"{e}") return @@ -238,7 +232,6 @@ def drop_dataset_event(self, change): self.latent.update() - def drop_path_event(self, change): ''' control way paths are looked for @@ -251,7 +244,6 @@ def drop_path_event(self, change): self.update_trails() - def range_slider_event(self, change): ''' update surface colouring upon manipulation of range slider @@ -261,7 +253,6 @@ def range_slider_event(self, change): self.latent.data[0].zmax = change.new[1] self.latent.update() - def trail_update_event(self, change): ''' update trails (waypoints and way they are connected) @@ -270,7 +261,7 @@ def trail_update_event(self, change): try: crd = np.array(self.mybox.value.split()).astype(float) crd = crd.reshape((int(len(crd)/2), 2)) - except: + except Exception: self.button_pdb.disabled = False return @@ -278,7 +269,6 @@ def trail_update_event(self, change): self.update_trails() - def button_pdb_event(self, check): ''' save PDB file corresponding to the interpolation shown in the 3D view @@ -307,7 +297,6 @@ def button_pdb_event(self, check): for ts in self.mymol.trajectory: W.write(protein) - def button_save_state_event(self, check): ''' save class state @@ -321,8 +310,7 @@ def button_save_state_event(self, check): if fname == "": return - pickle.dump([self.MA, self.waypoints], open( fname, "wb" ) ) - + pickle.dump([self.MA, self.waypoints], open(fname, "wb")) def button_load_state_event(self, check): ''' @@ -338,7 +326,7 @@ def button_load_state_event(self, check): return try: - self.MA, self.waypoints = pickle.load( open( fname, "rb" ) ) + self.MA, self.waypoints = pickle.load(open(fname, "rb")) self.run() except Exception as e: raise Exception(f"Cannot load state file. 
{e}") @@ -349,7 +337,7 @@ def run(self): # create an MDAnalysis instance of input protein (for viewing purposes) if hasattr(self.MA, "mol"): - self.MA.mol.write_pdb("tmp.pdb", conformations=[0], split_struc = False) + self.MA.mol.write_pdb("tmp.pdb", conformations=[0], split_struc=False) self.mymol = mda.Universe('tmp.pdb') ### MENU ITEMS ### @@ -376,7 +364,6 @@ def run(self): self.drop_background.observe(self.drop_background_event, names='value') - # dataset selector dropdown menu options2 = ["none"] if self.MA is not None: @@ -406,7 +393,6 @@ def run(self): self.drop_path.observe(self.drop_path_event, names='value') - # text box holding current coordinates self.mybox = widgets.Textarea(placeholder='coordinates', description='crds:', @@ -421,7 +407,6 @@ def run(self): self.samplebox.observe(self.trail_update_event, names='value') - # button to save PDB file self.button_pdb = widgets.Button( description='Save PDB', @@ -429,23 +414,20 @@ def run(self): self.button_pdb.on_click(self.button_pdb_event) - # button to save state file self.button_save_state = widgets.Button( - description= 'Save state', + description='Save state', disabled=False, layout=Layout(flex='1 1 0%', width='auto')) self.button_save_state.on_click(self.button_save_state_event) - # button to load state file self.button_load_state = widgets.Button( - description= 'Load state', + description='Load state', disabled=False, layout=Layout(flex='1 1 0%', width='auto')) self.button_load_state.on_click(self.button_load_state_event) - # latent space range slider self.range_slider = widgets.FloatRangeSlider( description='cmap range:', @@ -463,8 +445,7 @@ def run(self): if self.waypoints == []: self.button_pdb.disabled = True - - + ### LATENT SPACE REPRESENTATION ### # surface @@ -502,7 +483,7 @@ def run(self): # path plot3 = go.Scatter(x=np.array([]), y=np.array([]), - showlegend=False, opacity=0.9, mode = 'lines+markers', + showlegend=False, opacity=0.9, mode='lines+markers', marker=dict(color='red', size=4)) self.latent = go.FigureWidget([plot1, plot2, plot3]) @@ -521,7 +502,7 @@ def run(self): try: self.range_slider.min = scmin self.range_slider.max = scmax - except: + except Exception: self.range_slider.max = scmax self.range_slider.min = scmin @@ -530,8 +511,7 @@ def run(self): # 3D protein representation (triggered by update of textbox, sampling box, or pathfinding method) self.protein = widgets.interactive_output(self.interact_3D, {'mybox': self.mybox, 'samplebox': self.samplebox, 'path': self.drop_path}) - - + ### WIDGETS ARRANGEMENT ### self.block0 = widgets.VBox([self.drop_dataset, self.range_slider, @@ -555,4 +535,3 @@ def run(self): display.clear_output(wait=True) display.display(self.scene) - diff --git a/src/molearn/analysis/analyser.py b/src/molearn/analysis/analyser.py index aacce1e..83bfe8d 100644 --- a/src/molearn/analysis/analyser.py +++ b/src/molearn/analysis/analyser.py @@ -30,7 +30,7 @@ warnings.filterwarnings("ignore") -class MolearnAnalysis(object): +class MolearnAnalysis: ''' This class provides methods dedicated to the quality analysis of a trained model. 
@@ -125,7 +125,7 @@ def get_decoded(self, key): encoded = self.get_encoded(key) decoded = torch.empty(encoded.shape[0], *self.shape).float() for i in tqdm(range(0, encoded.shape[0], batch_size), desc=f'Decoding {key}'): - decoded[i:i+batch_size] = self.network.decode(encoded[i:i+batch_size].to(self.device))[:,:,:self.shape[1]].cpu() + decoded[i:i+batch_size] = self.network.decode(encoded[i:i+batch_size].to(self.device))[:, :, :self.shape[1]].cpu() self._decoded[key] = decoded return self._decoded[key] @@ -157,20 +157,20 @@ def get_error(self, key, align=True): m = deepcopy(self.mol) for i in range(dataset.shape[0]): crd_ref = as_numpy(dataset[i].permute(1,0).unsqueeze(0))*self.stdval + self.meanval - crd_mdl = as_numpy(decoded[i].permute(1,0).unsqueeze(0))[:, :dataset.shape[2]]*self.stdval + self.meanval #clip the padding of models - if align: # use Molecule Biobox class to calculate RMSD + crd_mdl = as_numpy(decoded[i].permute(1,0).unsqueeze(0))[:, :dataset.shape[2]]*self.stdval + self.meanval # clip the padding of models + # use Molecule Biobox class to calculate RMSD + if align: m.coordinates = deepcopy(crd_ref) m.set_current(0) m.add_xyz(crd_mdl[0]) rmsd = m.rmsd(0, 1) else: - rmsd = np.sqrt(np.sum((crd_ref.flatten()-crd_mdl.flatten())**2)/crd_mdl.shape[1]) # Cartesian L2 norm + rmsd = np.sqrt(np.sum((crd_ref.flatten()-crd_mdl.flatten())**2)/crd_mdl.shape[1]) # Cartesian L2 norm err.append(rmsd) return np.array(err) - def get_dope(self, key, refine=True, **kwargs): ''' :param str key: key pointing to a dataset previously loaded with :func:`set_dataset ` @@ -180,11 +180,11 @@ def get_dope(self, key, refine=True, **kwargs): dataset = self.get_dataset(key) decoded = self.get_decoded(key) - dope_dataset = self.get_all_dope_score(dataset, refine=refine,**kwargs) - dope_decoded = self.get_all_dope_score(decoded, refine=refine,**kwargs) + dope_dataset = self.get_all_dope_score(dataset, refine=refine, **kwargs) + dope_decoded = self.get_all_dope_score(decoded, refine=refine, **kwargs) - return dict(dataset_dope = dope_dataset, - decoded_dope = dope_decoded) + return dict(dataset_dope=dope_dataset, + decoded_dope=dope_decoded) def get_ramachandran(self, key): ''' @@ -213,7 +213,7 @@ def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1): if bounds_from is None: bounds_from = "all" - bounds = self._get_bounds(bounds_from, exclude = key) + bounds = self._get_bounds(bounds_from, exclude=key) bx = (bounds[1]-bounds[0])*padding by = (bounds[3]-bounds[2])*padding @@ -221,12 +221,12 @@ def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1): self.yvals = np.linspace(bounds[2]-by, bounds[3]+by, samples) self.n_samples = samples meshgrid = np.meshgrid(self.xvals, self.yvals) - stack = np.stack(meshgrid, axis=2).reshape(-1,1,2) + stack = np.stack(meshgrid, axis=2).reshape(-1, 1, 2) self.set_encoded(key, stack) return key - def _get_bounds(self, bounds_from, exclude = ['grid', 'grid_decoded']): + def _get_bounds(self, bounds_from, exclude=['grid', 'grid_decoded']): ''' :param bounds_from: keys of datasets to be considered for identification of boundaries in latent space :param exclude: keys of dataset not to consider @@ -243,10 +243,10 @@ def _get_bounds(self, bounds_from, exclude = ['grid', 'grid_decoded']): xmin, ymin, xmax, ymax = [], [], [], [] for key in bounds_from: z = self.get_encoded(key) - xmin.append(z[:,0].min()) - ymin.append(z[:,1].min()) - xmax.append(z[:,0].max()) - ymax.append(z[:,1].max()) + xmin.append(z[:, 0].min()) + ymin.append(z[:, 
1].min()) + xmax.append(z[:, 0].max()) + ymax.append(z[:, 1].max()) xmin, ymin = min(xmin), min(ymin) xmax, ymax = max(xmax), max(ymax) @@ -276,12 +276,12 @@ def scan_error_from_target(self, key, index=None, align=True): decoded = self.get_decoded('grid') if align: - crd_ref = as_numpy(target.permute(0,2,1))*self.stdval - crd_mdl = as_numpy(decoded.permute(0,2,1))*self.stdval + crd_ref = as_numpy(target.permute(0, 2, 1))*self.stdval + crd_mdl = as_numpy(decoded.permute(0, 2, 1))*self.stdval m = deepcopy(self.mol) m.coordinates = np.concatenate([crd_ref, crd_mdl]) m.set_current(0) - rmsd = np.array([m.rmsd(0,i) for i in range(1, len(m.coordinates))]) + rmsd = np.array([m.rmsd(0, i) for i in range(1, len(m.coordinates))]) else: rmsd = (((decoded-target)*self.stdval)**2).sum(axis=1).mean(axis=-1).sqrt() self.surfaces[s_key] = as_numpy(rmsd.reshape(self.n_samples, self.n_samples)) @@ -304,18 +304,19 @@ def scan_error(self, s_key='Network_RMSD', z_key='Network_z_drift'): z_key = 'Network_z_drift' if s_key not in self.surfaces: assert 'grid' in self._encoded, 'make sure to call MolearnAnalysis.setup_grid first' - decoded = self.get_decoded('grid') # decode grid - #self.set_dataset('grid_decoded', decoded) # add back as dataset w. different name + decoded = self.get_decoded('grid') # decode grid + # self.set_dataset('grid_decoded', decoded) # add back as dataset w. different name self._datasets['grid_decoded'] = decoded - decoded_2 = self.get_decoded('grid_decoded') # encode, and decode a second time - grid = self.get_encoded('grid') # retrieve original grid - grid_2 = self.get_encoded('grid_decoded') # retrieve decoded encoded grid + decoded_2 = self.get_decoded('grid_decoded') # encode, and decode a second time + grid = self.get_encoded('grid') # retrieve original grid + grid_2 = self.get_encoded('grid_decoded') # retrieve decoded encoded grid rmsd = (((decoded-decoded_2)*self.stdval)**2).sum(axis=1).mean(axis=-1).sqrt() z_drift = ((grid-grid_2)**2).mean(axis=2).mean(axis=1).sqrt() self.surfaces[s_key] = rmsd.reshape(self.n_samples, self.n_samples).numpy() self.surfaces[z_key] = z_drift.reshape(self.n_samples, self.n_samples).numpy() + return self.surfaces[s_key], self.surfaces[z_key], self.xvals, self.yvals def _ramachandran_score(self, frame): @@ -324,10 +325,10 @@ def _ramachandran_score(self, frame): AsyncResult.get() will return the result ''' if not hasattr(self, 'ramachandran_score_class'): - self.ramachandran_score_class = Parallel_Ramachandran_Score(self.mol, self.processes) #Parallel_Ramachandran_Score(self.mol) + self.ramachandran_score_class = Parallel_Ramachandran_Score(self.mol, self.processes) assert len(frame.shape) == 2, f'We wanted 2D data but got {len(frame.shape)} dimensions' if frame.shape[0] == 3: - f = frame.permute(1,0) + f = frame.permute(1, 0) else: assert frame.shape[1] == 3 f = frame @@ -335,9 +336,8 @@ def _ramachandran_score(self, frame): f = f.data.cpu().numpy() return self.ramachandran_score_class.get_score(f*self.stdval) - #nf, na, no, nt = self.ramachandran_score_class.get_score(f*self.stdval) - #return {'favored':nf, 'allowed':na, 'outliers':no, 'total':nt} - + # nf, na, no, nt = self.ramachandran_score_class.get_score(f*self.stdval) + # return {'favored':nf, 'allowed':na, 'outliers':no, 'total':nt} def _dope_score(self, frame, refine=True, **kwargs): ''' @@ -349,9 +349,9 @@ def _dope_score(self, frame, refine=True, **kwargs): assert len(frame.shape) == 2, f'We wanted 2D data but got {len(frame.shape)} dimensions' if frame.shape[0] == 3: - f = 
frame.permute(1,0) + f = frame.permute(1, 0) else: - assert frame.shape[1] ==3 + assert frame.shape[1] == 3 f = frame if isinstance(f,torch.Tensor): f = f.data.cpu().numpy() @@ -368,7 +368,7 @@ def get_all_ramachandran_score(self, tensor): results = [] for f in tensor: results.append(self._ramachandran_score(f)) - for r in tqdm(results,desc=f'Calc rama'): + for r in tqdm(results,desc='Calc rama'): favored, allowed, outliers, total = r.get() rama['favored'].append(favored) rama['allowed'].append(allowed) @@ -386,7 +386,7 @@ def get_all_dope_score(self, tensor, refine=True): results = [] for f in tensor: results.append(self._dope_score(f, refine=refine)) - results = np.array([r.get() for r in tqdm(results, desc=f'Calc Dope')]) + results = np.array([r.get() for r in tqdm(results, desc='Calc Dope')]) return results def reference_dope_score(self, frame): @@ -395,7 +395,7 @@ def reference_dope_score(self, frame): :return: DOPE score ''' self.mol.coordinates = deepcopy(frame) - self.mol.write_pdb('tmp.pdb', split_struc = False) + self.mol.write_pdb('tmp.pdb', split_struc=False) env = Environ() env.libs.topology.read(file='$(LIB)/top_heav.lib') env.libs.parameters.read(file='$(LIB)/par.lib') @@ -429,7 +429,7 @@ def scan_dope(self, key=None, refine=True, **kwargs): decoded = self.get_decoded('grid') result = self.get_all_dope_score(decoded, refine=refine, **kwargs) if refine=='both': - self.surfaces[key] = as_numpy(result.reshape(self.n_samples, self.n_samples,2)) + self.surfaces[key] = as_numpy(result.reshape(self.n_samples, self.n_samples, 2)) else: self.surfaces[key] = as_numpy(result.reshape(self.n_samples, self.n_samples)) @@ -468,8 +468,8 @@ def scan_custom(self, fct, params, key): ''' decoded = self.get_decoded('grid') results = [] - for i,j in enumerate(decoded): - s = (j.view(1,3,-1).permute(0,2,1)*self.stdval).numpy() + for i, j in enumerate(decoded): + s = (j.view(1, 3, -1).permute(0, 2, 1)*self.stdval).numpy() results.append(fct(s, *params)) self.surfaces[key] = np.array(results).reshape(self.n_samples, self.n_samples) @@ -489,7 +489,5 @@ def generate(self, crd): return s*self.stdval + self.meanval - def __getstate__(self): return {key:value for key, value in dict(self.__dict__).items() if key not in ['dope_score_class', 'ramachandran_score_class']} - diff --git a/src/molearn/analysis/path.py b/src/molearn/analysis/path.py index 62203c5..acba0f2 100644 --- a/src/molearn/analysis/path.py +++ b/src/molearn/analysis/path.py @@ -6,7 +6,8 @@ :synopsis: Tools for linking waypoints with paths in latent space """ -class PriorityQueue(object): + +class PriorityQueue: ''' Queue for shortest path algorithms. 
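For reference, the queue this class provides for the A* search below can be sketched on top of heapq. An illustrative stand-in, not the implementation in this file:

    import heapq

    class MinPriorityQueue:
        # pops the lowest-priority item first, as A*-style searches require
        def __init__(self):
            self._items = []
            self._count = 0  # tie-breaker keeps equal priorities FIFO

        def empty(self):
            return not self._items

        def put(self, item, priority):
            heapq.heappush(self._items, (priority, self._count, item))
            self._count += 1

        def get(self):
            return heapq.heappop(self._items)[2]

    q = MinPriorityQueue()
    q.put("far", 5.0)
    q.put("near", 1.0)
    assert q.get() == "near"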
@@ -73,7 +74,7 @@ def _neighbors(idx, gridshape, flattened=True): idx = np.unravel_index(idx, gridshape) elif len(idx) != 2: raise Exception("Expecting 2D coordinates") - except: + except Exception: raise Exception("idx should be either integer or an iterable") # generate neighbour list @@ -103,6 +104,7 @@ def _cost(pt, graph): ''' :return: scalar value, reporting on the cost of moving onto a grid cell ''' + # separate function for clarity, and in case in the future we want to alter this return graph[pt] @@ -177,7 +179,9 @@ def get_path(idx_start, idx_end, landscape, xvals, yvals, smooth=3): coords = [] score = [] idx_flat = np.ravel_multi_index(idx_end, landscape.shape) - while cnt<1000: #safeguad for (unlikely) unfinished paths + + # safeguard for (unlikely) unfinished paths + while cnt<1000: if idx_flat == mypath[idx_flat]: break @@ -203,7 +207,6 @@ def get_path(idx_start, idx_end, landscape, xvals, yvals, smooth=3): return traj_smooth, np.array(score)[::-1] - def _get_point_index(crd, xvals, yvals): ''' Extract index (of 2D surface) closest to a given real value coordinate @@ -259,6 +262,7 @@ def oversample(crd, pts=10): :param int pts: number of extra points to add in each interval :return: Mx2 numpy array, with M>=N. ''' + pts += 1 steps = np.linspace(1./pts, 1, pts) pts = [crd[0]] @@ -267,4 +271,4 @@ def oversample(crd, pts=10): newpt = crd[i-1] + (crd[i]-crd[i-1])*j pts.append(newpt) - return np.array(pts) \ No newline at end of file + return np.array(pts) diff --git a/src/molearn/data/pdb_data.py b/src/molearn/data/pdb_data.py index 376a060..36f4320 100644 --- a/src/molearn/data/pdb_data.py +++ b/src/molearn/data/pdb_data.py @@ -3,9 +3,10 @@ from copy import deepcopy import biobox as bb + class PDBData: - def __init__(self, filename = None, fix_terminal = False, atoms = None, ): + def __init__(self, filename=None, fix_terminal=False, atoms=None): ''' Create object enabling the manipulation of multi-PDB files into a dataset suitable for training. 
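To make the intended call sequence concrete, a hypothetical usage of this class based only on the methods visible in this diff (the import path, file name, and split arguments are placeholders):

    from molearn.data import PDBData

    # load a multi-PDB file, optionally fixing terminal atom names
    # and restricting the selection to a subset of atoms
    data = PDBData(filename="trajectory.pdb", fix_terminal=True,
                   atoms=["CA", "C", "N", "CB", "O"])

    # reproducible train/validation split (arguments forwarded to get_datasets)
    train, valid = data.split(validation_split=0.1, manual_seed=25)
    print(train.dataset.shape, valid.dataset.shape)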
@@ -22,7 +23,7 @@ def __init__(self, filename = None, fix_terminal = False, atoms = None, ): if fix_terminal: self.fix_terminal() if atoms is not None: - self.atomselect(atoms = atoms) + self.atomselect(atoms=atoms) def import_pdb(self, filename): ''' @@ -60,7 +61,7 @@ def atomselect(self, atoms, ignore_atoms=[]): if to_remove in _atoms: _atoms.remove(to_remove) elif atoms == "no_hydrogen": - _atoms = self.atoms #list(np.unique(self._mol.data["name"].values)) #all the atoms + _atoms = self.atoms # list(np.unique(self._mol.data["name"].values)) #all the atoms _plain_atoms = [] for a in _atoms: if a in self._mol.knowledge['atomtype']: @@ -68,7 +69,7 @@ def atomselect(self, atoms, ignore_atoms=[]): elif a[:-1] in self._mol.knowledge['atomtype']: _plain_atoms.append(self._mol.knowledge['atomtype'][a[:-1]]) else: - _plain_atoms.append(self._mol.knowledge['atomtype'][a]) # if above failed just raise the keyerror + _plain_atoms.append(self._mol.knowledge['atomtype'][a]) # if above failed just raise the keyerror _atoms = [atom for atom, element in zip(_atoms, _plain_atoms) if element != 'H'] else: _atoms = [_a for _a in atoms if _a not in ignore_atoms] @@ -155,7 +156,7 @@ def split(self, *args, **kwargs): :return: :func:`PDBData ` object corresponding to train set :return: :func:`PDBData ` object corresponding to validation set ''' - #validation_split=0.1, valid_size=None, train_size=None, manual_seed = None): + # validation_split=0.1, valid_size=None, train_size=None, manual_seed = None): train_dataset, valid_dataset = self.get_datasets(*args, **kwargs) train = PDBData() valid = PDBData() @@ -166,7 +167,7 @@ def split(self, *args, **kwargs): valid.dataset = valid_dataset return train, valid - def get_datasets(self, validation_split=0.1, valid_size=None, train_size=None, manual_seed = None): + def get_datasets(self, validation_split=0.1, valid_size=None, train_size=None, manual_seed=None): ''' Create a training and validation set from the imported data @@ -190,7 +191,7 @@ def get_datasets(self, validation_split=0.1, valid_size=None, train_size=None, m _valid_size = valid_size from torch import randperm if manual_seed is not None: - indices = randperm(len(self.dataset), generator = torch.Generator().manual_seed(manual_seed)) + indices = randperm(len(self.dataset), generator=torch.Generator().manual_seed(manual_seed)) else: indices = randperm(len(self.dataset)) @@ -201,12 +202,8 @@ def get_datasets(self, validation_split=0.1, valid_size=None, train_size=None, m @property def atoms(self): - return list(np.unique(self._mol.data["name"].values)) #all the atoms + return list(np.unique(self._mol.data["name"].values)) # all the atoms @property def mol(self): return self.frame() - - - - diff --git a/src/molearn/loss_functions/openmm_thread.py b/src/molearn/loss_functions/openmm_thread.py index 6ed4c66..2f20c11 100644 --- a/src/molearn/loss_functions/openmm_thread.py +++ b/src/molearn/loss_functions/openmm_thread.py @@ -18,7 +18,7 @@ class ModifiedForceField(ForceField): - def __init__(self, *args, alternative_residue_names = None, **kwargs): + def __init__(self, *args, alternative_residue_names=None, **kwargs): ''' Takes all `*args` and `**kwargs` of `openmm.app.ForceField`, plus an optional parameter described here. @@ -80,7 +80,7 @@ def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, matches = m return [template, matches] print(f'multiple for {t.name}') - # We found multiple matches. This is OK if and only if they assign identical types and parameters to all atoms. 
+ # We found multiple matches. This is OK if and only if they assign identical types and parameters to all atoms. t1, m1 = allMatches[0] for t2, m2 in allMatches[1:]: @@ -90,15 +90,16 @@ def _getResidueTemplateMatches(self, res, bondedToAtom, templateSignatures=None, matches = allMatches[0][1] return [template, matches] + class OpenmmPluginScore(): ''' This will use the new OpenMM Plugin to calculate forces and energy. The intention is that this will be fast enough to be able to calculate forces and energy during training. N.B.: The current torchintegratorplugin only supports float on GPU and double on CPU. ''' - def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', remove_NB=False, - alternative_residue_names = dict(HIS='HIE', HSE='HIE'), atoms=['CA', 'C', 'N', 'CB','O'], - soft=False): + def __init__(self, mol=None, xml_file=['amber14-all.xml'], platform='CUDA', remove_NB=False, + alternative_residue_names=dict(HIS='HIE', HSE='HIE'), atoms=['CA', 'C', 'N', 'CB','O'], + soft=False): ''' :param `biobox.Molecule` mol: if pldataloader is not given, then a biobox object will be taken from this parameter. If neither are given then an error will be thrown. :param str xml_file: xml parameter file @@ -110,12 +111,12 @@ def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', ''' self.mol = mol for key, value in alternative_residue_names.items(): - #self.mol.data.loc[:,'resname'][self.mol.data['resname']==key]=value + # self.mol.data.loc[:,'resname'][self.mol.data['resname']==key]=value self.mol.data.loc[self.mol.data['resname']==key,'resname']=value - #self.mol.data.loc[lambda df: df['resname']==key, key]=value + # self.mol.data.loc[lambda df: df['resname']==key, key]=value tmp_file = f'tmp{np.random.randint(1e10)}.pdb' self.atoms = atoms - self.mol.write_pdb(tmp_file, split_struc = False) + self.mol.write_pdb(tmp_file, split_struc=False) self.pdb = PDBFile(tmp_file) if soft: print('attempting soft forcefield') @@ -125,9 +126,9 @@ def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', self.system = self.forcefield.createSystem(self.pdb.topology) else: if isinstance(xml_file,str): - self.forcefield = ModifiedForceField(xml_file, alternative_residue_names = alternative_residue_names) + self.forcefield = ModifiedForceField(xml_file, alternative_residue_names=alternative_residue_names) elif len(xml_file)>0: - self.forcefield = ModifiedForceField(*xml_file, alternative_residue_names = alternative_residue_names) + self.forcefield = ModifiedForceField(*xml_file, alternative_residue_names=alternative_residue_names) else: raise ValueError(f'xml_file: {xml_file} needs to be a str or a list of str') @@ -135,14 +136,14 @@ def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', self.ignore_hydrogen() else: self.atomselect(atoms) - #save pdb and reload in modeller + # save pdb and reload in modeller templates, unique_unmatched_residues = self.forcefield.generateTemplatesForUnmatchedResidues(self.pdb.topology) self.system = self.forcefield.createSystem(self.pdb.topology) if remove_NB: forces = self.system.getForces() for idx in reversed(range(len(forces))): force = forces[idx] - if isinstance(force, (#openmm.PeriodicTorsionForce, + if isinstance(force, ( # openmm.PeriodicTorsionForce, openmm.CustomGBForce, openmm.NonbondedForce, openmm.CMMotionRemover)): @@ -154,7 +155,6 @@ def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', if isinstance(force, openmm.CustomGBForce): 
self.system.removeForce(idx) - self.integrator = TorchExposedIntegrator() self.platform = Platform.getPlatformByName(platform) self.simulation = Simulation(self.pdb.topology, self.system, self.integrator, self.platform) @@ -167,7 +167,7 @@ def __init__(self, mol=None, xml_file = ['amber14-all.xml'], platform = 'CUDA', os.remove(tmp_file) def ignore_hydrogen(self): - #ignore = ['ASH', 'LYN', 'GLH', 'HID', 'HIP', 'CYM', ] + # ignore = ['ASH', 'LYN', 'GLH', 'HID', 'HIP', 'CYM', ] ignore = [] for name, template in self.forcefield._templates.items(): if name in ignore: @@ -215,7 +215,6 @@ def atomselect(self, atoms): self.forcefield.registerTemplatePatch(name, name+'_leave_only_'+'_'.join(atoms), 0) self.forcefield.registerPatch(patchData) - def get_energy(self, pos_ptr, force_ptr, energy_ptr, n_particles, batch_size): ''' :param pos_ptr: tensor.data_ptr() @@ -234,12 +233,13 @@ def execute(self, x): :param `torch.Tensor` x: shape [b, N, 3]. dtype=float. device = gpu ''' force = torch.zeros_like(x) - energy = torch.zeros(x.shape[0], device = torch.device('cpu'), dtype=torch.double) + energy = torch.zeros(x.shape[0], device=torch.device('cpu'), dtype=torch.double) self.get_energy(x.data_ptr(), force.data_ptr(), energy.data_ptr(), x.shape[1], x.shape[0]) return force, energy class OpenmmTorchEnergyMinimizer(OpenmmPluginScore): + def minimize(self, x, maxIterations=10, threshold=10000): minimized_x = torch.empty_like(x) for i,s in enumerate(x.unsqueeze(1)): @@ -261,13 +261,13 @@ def minimize(self, x, maxIterations=10, threshold=10000): return minimized_x - class OpenMMPluginScoreSoftForceField(OpenmmPluginScore): + def __init__(self, mol=None, platform='CUDA', atoms=['CA','C','N','CB','O']): self.mol = mol tmp_file = 'tmp.pdb' self.atoms = atoms - self.mol.write_pdb(tmp_file, split_struc = False) + self.mol.write_pdb(tmp_file, split_struc=False) self.pdb = PDBFile(tmp_file) from pdbfixer import PDBFixer f = PDBFixer(tmp_file) @@ -304,19 +304,20 @@ def forward(ctx, plugin, x): force = torch.tensor(force).float() energy = torch.tensor(energy).float() else: - #torch.cuda.synchronize(x.device) + # torch.cuda.synchronize(x.device) force, energy = plugin.execute(x) - #torch.cuda.synchronize(x.device) + # torch.cuda.synchronize(x.device) ctx.save_for_backward(force) energy = energy.float().to(x.device) return energy @staticmethod def backward(ctx, grad_output): - force = ctx.saved_tensors[0] # force shape [B, N, 3] - #embed(header='23 openmm_loss_function') + force = ctx.saved_tensors[0] # force shape [B, N, 3] + # embed(header='23 openmm_loss_function') return None, -force*grad_output.view(-1,1,1) + class openmm_clamped_energy_function(torch.autograd.Function): @staticmethod @@ -347,10 +348,12 @@ def forward(ctx, plugin, x, clamp): @staticmethod def backward(ctx, grad_output): force = ctx.saved_tensors[0] - return None, -force*grad_output.view(-1,1,1), None + return None, -force*grad_output.view(-1, 1, 1), None + class openmm_energy(torch.nn.Module): - def __init__(self, mol, std, clamp = None, **kwargs): + + def __init__(self, mol, std, clamp=None, **kwargs): super().__init__() self.openmmplugin = OpenmmPluginScore(mol, **kwargs) self.std = std/10 @@ -365,7 +368,7 @@ def _forward(self, x): :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N :returns: torch energy tensor dtype should be float and on same device as x ''' - _x = (x*self.std).permute(0,2,1).contiguous() + _x = (x*self.std).permute(0, 2, 1).contiguous() energy = openmm_energy_function.apply(self.openmmplugin, _x) return 
energy @@ -374,7 +377,6 @@ def _clamp_forward(self, x): :param `torch.Tensor` x: dtype=torch.float, device=CUDA, shape B, 3, N :returns: torch energy tensor dtype should be float and on same device as x ''' - _x = (x*self.std).permute(0,2,1).contiguous() + _x = (x*self.std).permute(0, 2, 1).contiguous() energy = openmm_clamped_energy_function.apply(self.openmmplugin, _x, self.clamp) return energy - diff --git a/src/molearn/loss_functions/torch_protein_energy.py b/src/molearn/loss_functions/torch_protein_energy.py index 7d55699..078f43e 100644 --- a/src/molearn/loss_functions/torch_protein_energy.py +++ b/src/molearn/loss_functions/torch_protein_energy.py @@ -9,18 +9,16 @@ # You should have received a copy of the GNU General Public License along with molearn ; # if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. - -import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F -from copy import deepcopy +# import torch.nn as nn +# import torch.nn.functional as F from molearn.loss_functions.torch_protein_energy_utils import get_convolutions + class TorchProteinEnergy(): def __init__(self, frame, pdb_atom_names, padded_residues=False, - method =('indexed', 'convolutional', 'roll')[2], + method=('indexed', 'convolutional', 'roll')[2], device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'), fix_h=False,alt_vdw=[], NB='repulsive'): ''' @@ -45,7 +43,7 @@ def __init__(self, frame, pdb_atom_names, ''' self.device = device self.method = method - if padded_residues == True: + if padded_residues: if method == 'indexed': self._padded_indexed_init(frame, pdb_atom_names) else: @@ -69,61 +67,60 @@ def get_energy(self, x, nonbonded=False): else: return self.get_bonded_energy(x) - - def _roll_init(self, frame, pdb_atom_names, NB='full', fix_h=False,alt_vdw=[]): + def _roll_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=[]): (b_masks, b_equil, b_force, b_weights, a_masks, a_equil, a_force, a_weights, t_masks, t_para, t_weights, vdw_R, vdw_e, vdw_14R, vdw_14e, - q1q2, q1q2_14 )=get_convolutions(frame, pdb_atom_names, fix_slice_method=True, fix_h=fix_h,alt_vdw=alt_vdw) + q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=True, fix_h=fix_h, alt_vdw=alt_vdw) - self.brdiff=[] - self.br_equil=[] - self.br_force=[] + self.brdiff = [] + self.br_equil = [] + self.br_force = [] for i,j in enumerate(b_weights): - atom1=j.index(1) - atom2=j.index(-1) + atom1 = j.index(1) + atom2 = j.index(-1) d = j.index(-1)-j.index(1) - padding=len(j)-2 + padding = len(j)-2 self.brdiff.append(d) - #b_equil[:,0] is just padding so can roll(-1,1) to get correct padding - self.br_equil.append(torch.tensor(b_equil[i,padding-1:]).roll(-1).to(self.device).float()) - self.br_force.append(torch.tensor(b_force[i,padding-1:]).roll(-1).to(self.device).float()) - self.ardiff=[] - self.arsign=[] - self.arroll=[] - self.ar_equil=[] - self.ar_force=[] - self.ar_masks=[] + # b_equil[:,0] is just padding so can roll(-1,1) to get correct padding + self.br_equil.append(torch.tensor(b_equil[i, padding-1:]).roll(-1).to(self.device).float()) + self.br_force.append(torch.tensor(b_force[i, padding-1:]).roll(-1).to(self.device).float()) + self.ardiff = [] + self.arsign = [] + self.arroll = [] + self.ar_equil = [] + self.ar_force = [] + self.ar_masks = [] for i, j in enumerate(a_weights): - atom1=j[0].index(1) - atom2=j[0].index(-1) - atom3=j[1].index(1) - diff1=atom2-atom1 - diff2=atom2-atom3 - 
padding=len(j[0])-3 + atom1 = j[0].index(1) + atom2 = j[0].index(-1) + atom3 = j[1].index(1) + diff1 = atom2-atom1 + diff2 = atom2-atom3 + padding = len(j[0])-3 self.arroll.append([min(atom1,atom2), min(atom2,atom3)]) self.ardiff.append([abs(diff1)-1, abs(diff2)-1]) self.arsign.append([diff1/abs(diff1), diff2/abs(diff2)]) self.ar_equil.append(torch.tensor(a_equil[i,padding-2:]).roll(-2).to(self.device).float()) self.ar_force.append(torch.tensor(a_force[i,padding-2:]).roll(-2).to(self.device).float()) - self.trdiff=[] - self.trsign=[] - self.trroll=[] - self.tr_para=[] + self.trdiff = [] + self.trsign = [] + self.trroll = [] + self.tr_para = [] for i, j in enumerate(t_weights): - atom1=j[0].index(1) #i-j 0 - atom2=j[0].index(-1) #i-j 2 - atom3=j[1].index(-1) #j-k 3 - atom4=j[2].index(1) #l-k 4 - diff1=atom2-atom1 #ij 2 - diff2=atom3-atom2 #jk 1 - diff3=(atom4-atom3)*-1 #lk 1 - padding=len(j[0])-4 - self.trroll.append([min(atom1,atom2),min(atom2,atom3),min(atom3,atom4)]) + atom1 = j[0].index(1) # i-j 0 + atom2 = j[0].index(-1) # i-j 2 + atom3 = j[1].index(-1) # j-k 3 + atom4 = j[2].index(1) # l-k 4 + diff1 = atom2-atom1 # ij 2 + diff2 = atom3-atom2 # jk 1 + diff3 = (atom4-atom3)*-1 # lk 1 + padding = len(j[0])-4 + self.trroll.append([min(atom1,atom2), min(atom2,atom3), min(atom3,atom4)]) self.trsign.append([diff1/abs(diff1), diff2/abs(diff2), diff3/abs(diff3)]) self.trdiff.append([abs(diff1)-1, abs(diff2)-1, abs(diff3)-1]) - self.tr_para.append(torch.tensor(t_para[i,padding-3:]).roll(-3,0).to(self.device).float()) + self.tr_para.append(torch.tensor(t_para[i, padding-3:]).roll(-3, 0).to(self.device).float()) self.vdw_A = (vdw_e*(vdw_R**12)).to(self.device) self.vdw_B = (2*vdw_e*(vdw_R**6)).to(self.device) @@ -135,24 +132,24 @@ def _roll_init(self, frame, pdb_atom_names, NB='full', fix_h=False,alt_vdw=[]): elif NB == 'repulsive': self._nb_loss = self._cdist_nb - def _convolutional_init(self, frame, pdb_atom_names, NB='full', fix_h=False,alt_vdw=[]): + def _convolutional_init(self, frame, pdb_atom_names, NB='full', fix_h=False, alt_vdw=[]): (b_masks, b_equil, b_force, b_weights, a_masks, a_equil, a_force, a_weights, t_masks, t_para, t_weights, vdw_R, vdw_e, vdw_14R, vdw_14e, - q1q2, q1q2_14 )=get_convolutions(frame, pdb_atom_names, fix_slice_method=False, fix_h=fix_h, alt_vdw=alt_vdw) + q1q2, q1q2_14) = get_convolutions(frame, pdb_atom_names, fix_slice_method=False, fix_h=fix_h, alt_vdw=alt_vdw) - self.b_equil =torch.tensor(b_equil ).to(self.device) - self.b_force =torch.tensor(b_force ).to(self.device) - self.b_weights=torch.tensor(b_weights).to(self.device) + self.b_equil = torch.tensor(b_equil ).to(self.device) + self.b_force = torch.tensor(b_force ).to(self.device) + self.b_weights = torch.tensor(b_weights).to(self.device) - self.a_equil =torch.tensor(a_equil ).to(self.device).float() - self.a_force =torch.tensor(a_force ).to(self.device).float() - self.a_weights=torch.tensor(a_weights).to(self.device) - self.a_masks =torch.tensor(a_masks ).to(self.device) + self.a_equil = torch.tensor(a_equil ).to(self.device).float() + self.a_force = torch.tensor(a_force ).to(self.device).float() + self.a_weights = torch.tensor(a_weights).to(self.device) + self.a_masks = torch.tensor(a_masks ).to(self.device) - self.t_para =torch.tensor(t_para ).to(self.device) - self.t_weights=torch.tensor(t_weights).to(self.device) + self.t_para = torch.tensor(t_para ).to(self.device) + self.t_weights = torch.tensor(t_weights).to(self.device) self.vdw_A = (vdw_e*(vdw_R**12)).to(self.device) self.vdw_B = 
(2*vdw_e*(vdw_R**6)).to(self.device) @@ -164,7 +161,7 @@ def _convolutional_init(self, frame, pdb_atom_names, NB='full', fix_h=False,alt_ elif NB == 'repulsive': self._nb_loss = self._cdist_nb - def _padded_indexed_init(self, frame, pdb_atom_names, NB = 'full'): + def _padded_indexed_init(self, frame, pdb_atom_names, NB='full'): from molearn import get_conv_pad_res (bond_idxs, bond_para, angle_idxs, angle_para, angle_mask, ij_jk, @@ -203,12 +200,12 @@ def _bonded_roll_loss(self, x): return bloss/bs, aloss/bs, tloss/bs def _bonded_padded_residues_loss(self, x): - #x.shape [B, R, M, 3] + # x.shape [B, R, M, 3] x = x.view(x.shape[0], -1, 3)[:,] - v = x[:,self.bond_idxs[:,1]]-x[:,self.bond_idxs[:,0]] #j-i == i->j + v = x[:,self.bond_idxs[:,1]]-x[:,self.bond_idxs[:,0]] # j-i == i->j bloss = (((v.norm(dim=2)-self.bond_para[:,0])**2)*self.bond_para[:,1]).sum() - v1 = v[:,self.ij_jk[0]]*self.angle_mask[0].view(1,-1,1) - v2 = v[:,self.ij_jk[1]]*self.angle_mask[1].view(1,-1,1) + v1 = v[:,self.ij_jk[0]]*self.angle_mask[0].view(1, -1, 1) + v2 = v[:,self.ij_jk[1]]*self.angle_mask[1].view(1, -1, 1) xyz=torch.sum(v1*v2, dim=2) / (torch.norm(v1, dim=2) * torch.norm(v2, dim=2)) theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999)) aloss = (((theta-self.angle_para[:,0])**2)*self.angle_para[:,1]).sum() @@ -233,51 +230,51 @@ def _cdist_nb_full(self, x, cutoff=9.0, mask=False): return torch.nansum(LJpA-LJpB+Cp) def _cdist_nb(self, x, cutoff=9.0, mask=False): - dmat = torch.cdist(x.permute(0,2,1),x.permute(0,2,1)) + dmat = torch.cdist(x.permute(0, 2, 1),x.permute(0, 2, 1)) LJp = self.vdw_A/(self._warp_domain(dmat, 1.9)**12) Cp = (self.q1q2/self._warp_domain(dmat, 0.4)) return torch.nansum(LJp+Cp) - def _warp_domain(self,x,k): + def _warp_domain(self, x, k): return torch.nn.functional.elu(x-k, 1.0)+k def _conv_bond_loss(self, x): - #x shape[B, 3, N] + # x shape[B, 3, N] loss=torch.tensor(0.0).float().to(self.device) for i, weight in enumerate(self.b_weights): - y = torch.nn.functional.conv1d(x, weight.view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight)-2)) + y = torch.nn.functional.conv1d(x, weight.view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight)-2)) loss+=(self.b_force[i]*((y.norm(dim=1)-self.b_equil[i])**2)).sum() return loss def _conv_angle_loss(self, x): - #x shape[X, 3, N] + # x shape[X, 3, N] loss=torch.tensor(0.0).float().to(self.device) for i, weight in enumerate(self.a_weights): - v1 = torch.nn.functional.conv1d(x, weight[0].view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight[0])-3)) - v2 = torch.nn.functional.conv1d(x, weight[1].view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight[1])-3)) - xyz=torch.sum(v1*v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1)) + v1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[0])-3)) + v2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[1])-3)) + xyz = torch.sum(v1*v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1)) theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999)) energy = (self.a_force[i]*((theta-self.a_equil[i])**2)).sum(dim=0)[self.a_masks[i]].sum() loss+=energy return loss def _conv_torsion_loss(self, x): - #x shape[X, 3, N] + # x shape[X, 3, N] loss=torch.tensor(0.0).float().to(self.device) for i, weight in enumerate(self.t_weights): - b1 = torch.nn.functional.conv1d(x, 
weight[0].view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight[0])-4))#i-j - b2 = torch.nn.functional.conv1d(x, weight[1].view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight[1])-4))#j-k - b3 = torch.nn.functional.conv1d(x, weight[2].view(1,1,-1).repeat(3,1,1).to(self.device), groups=3, padding=(len(weight[2])-4))#l-k - c32=torch.cross(b3,b2) - c12=torch.cross(b1,b2) - torsion=torch.atan2((b2*torch.cross(c32,c12)).sum(dim=1), + b1 = torch.nn.functional.conv1d(x, weight[0].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[0])-4)) # i-j + b2 = torch.nn.functional.conv1d(x, weight[1].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[1])-4)) # j-k + b3 = torch.nn.functional.conv1d(x, weight[2].view(1, 1, -1).repeat(3, 1, 1).to(self.device), groups=3, padding=(len(weight[2])-4)) # l-k + c32 = torch.cross(b3,b2) + c12 = torch.cross(b1,b2) + torsion = torch.atan2((b2*torch.cross(c32,c12)).sum(dim=1), b2.norm(dim=1)*((c12*c32).sum(dim=1))) p = self.t_para[i,:,:,:].unsqueeze(0) - loss+=((p[:,:,1]/p[:,:,0])*(1+torch.cos((p[:,:,3]*torsion.unsqueeze(2))-p[:,:,2]))).sum() + loss+=((p[:,:,1]/p[:, :, 0])*(1+torch.cos((p[:, :, 3]*torsion.unsqueeze(2))-p[:, :, 2]))).sum() return loss def _roll_bond_angle_torsion_loss(self, x): - #x.shape [5,3,2145] + # x.shape [5,3,2145] bloss = torch.tensor(0.0).float().to(self.device) aloss = torch.tensor(0.0).float().to(self.device) tloss = torch.tensor(0.0).float().to(self.device) @@ -289,21 +286,20 @@ def _roll_bond_angle_torsion_loss(self, x): for i, diff in enumerate(self.ardiff): v1 = self.arsign[i][0]*(v[diff[0]].roll(-self.arroll[i][0],2)) v2 = self.arsign[i][1]*(v[diff[1]].roll(-self.arroll[i][1],2)) - xyz=torch.sum(v1*v2, dim=1) / (torch.norm(v1, dim=1) * torch.norm(v2, dim=1)) + xyz = torch.sum(v1*v2, dim=1)/(torch.norm(v1, dim=1)*torch.norm(v2, dim=1)) theta = torch.acos(torch.clamp(xyz, min=-0.999999, max=0.999999)) energy=(self.ar_force[i]*((theta-self.ar_equil[i])**2)) sum_e = energy.sum() - aloss+=(sum_e) + aloss += (sum_e) for i, diff in enumerate(self.trdiff): - b1 = self.trsign[i][0]*(v[diff[0]].roll(-self.trroll[i][0],2)) - b2 = self.trsign[i][1]*(v[diff[1]].roll(-self.trroll[i][1],2)) - b3 = self.trsign[i][2]*(v[diff[2]].roll(-self.trroll[i][2],2)) - c32=torch.cross(b3,b2) - c12=torch.cross(b1,b2) - torsion=torch.atan2((b2*torch.cross(c32,c12)).sum(dim=1), + b1 = self.trsign[i][0]*(v[diff[0]].roll(-self.trroll[i][0], 2)) + b2 = self.trsign[i][1]*(v[diff[1]].roll(-self.trroll[i][1], 2)) + b3 = self.trsign[i][2]*(v[diff[2]].roll(-self.trroll[i][2], 2)) + c32 = torch.cross(b3,b2) + c12 = torch.cross(b1,b2) + torsion = torch.atan2((b2*torch.cross(c32,c12)).sum(dim=1), b2.norm(dim=1)*((c12*c32).sum(dim=1))) p = self.tr_para[i].unsqueeze(0) - tloss+=( ((p[:,:,1]/p[:,:,0])*(1+torch.cos((p[:,:,3]*torsion.unsqueeze(2))-p[:,:,2]))).sum()) - return bloss,aloss,tloss - + tloss += (((p[:, :, 1]/p[:, :, 0])*(1+torch.cos((p[:, :, 3]*torsion.unsqueeze(2))-p[:, :, 2]))).sum()) + return bloss, aloss, tloss diff --git a/src/molearn/loss_functions/torch_protein_energy_utils.py b/src/molearn/loss_functions/torch_protein_energy_utils.py index 393fc05..4def9e5 100644 --- a/src/molearn/loss_functions/torch_protein_energy_utils.py +++ b/src/molearn/loss_functions/torch_protein_energy_utils.py @@ -9,12 +9,9 @@ # You should have received a copy of the GNU General Public License along with molearn ; # if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, 
Boston, MA 02111-1307 USA. - import numpy as np import torch -import time from copy import deepcopy -import biobox import os import pkg_resources @@ -23,37 +20,37 @@ def read_lib_file(file_name, amber_atoms, atom_charge, connectivity): try: f_location = pkg_resources.resource_filename('molearn', 'parameters') f_in = open(f'{f_location}/{file_name}') - print('File %s opened' % file_name) - except Exception as ex: - raise Exception('ERROR: file %s not found!' % file_name) - + print(f'File {file_name} opened') + except Exception: + raise Exception(f'ERROR: file {file_name} not found!') lines = f_in.readlines() depth = 0 indexs = {} for tline in lines: - if tline.split()==['!!index', 'array', 'str']: + if tline.split() == ['!!index', 'array', 'str']: depth+=1 for line in lines[depth:]: - if line[0]!=' ': + if line[0] != ' ': break contents = line.split() if len(contents)!=1 and len(contents[0])!=5: break res=contents[0] if res[0]=='"' and res[-1] =='"': - amber_atoms[res[1:-1]]={} - atom_charge[res[1:-1]]={} - indexs[res[1:-1]]={} - connectivity[res[1:-1]]={} + amber_atoms[res[1:-1]] = {} + atom_charge[res[1:-1]] = {} + indexs[res[1:-1]] = {} + connectivity[res[1:-1]] = {} else: - raise Exception(('I was expecting something of the form' + raise Exception(('I was expecting something of the form' +'"XXX" but got %s instead' % res)) - depth+=1 + depth += 1 break - depth+=1 + depth += 1 + for i, tline in enumerate(lines): - entry, res, unit_atoms, unit_connectivity = tline[0:7],tline[7:10], tline[10:22], tline[10:29] + entry, res, unit_atoms, unit_connectivity = tline[0:7], tline[7:10], tline[10:22], tline[10:29] if entry=='!entry.' and unit_atoms=='.unit.atoms ': depth=i+1 for line in lines[depth:]: @@ -63,37 +60,38 @@ def read_lib_file(file_name, amber_atoms, atom_charge, connectivity): if len(contents)<3 and len(contents[0])>4 and len(contents[1])>4: break pdb_name, amber_name,_,_,_,index,element_number,charge = contents - if (pdb_name[0]=='"' and pdb_name[-1]=='"' - and amber_name[0]=='"' and amber_name[-1]=='"'): + if (pdb_name[0] == '"' and pdb_name[-1] == '"' + and amber_name[0] == '"' and amber_name[-1] == '"'): amber_atoms[res][contents[0][1:-1]] = contents[1][1:-1] atom_charge[res][amber_name[1:-1]] = float(charge) - #indexs[res][amber_name[1:-1]] = int(index) + # indexs[res][amber_name[1:-1]] = int(index) indexs[res][int(index)] = pdb_name[1:-1] connectivity[res][pdb_name[1:-1]] = [] else: - raise Exception(('I was expecting something of the form' + raise Exception(('I was expecting something of the form' +'"XXX" but got %s instead' % res)) - elif entry=='!entry.' and unit_connectivity=='.unit.connectivity ': - depth=i+1 + elif entry == '!entry.' 
and unit_connectivity == '.unit.connectivity ': + depth = i+1 for line in lines[depth:]: - if line[0]!=' ': + if line[0] != ' ': break contents = line.split() - if len(contents)!=3: + if len(contents) != 3: break - a1,a2,flag = contents + a1, a2, flag = contents connectivity[res][indexs[res][int(a1)]].append(indexs[res][int(a2)]) connectivity[res][indexs[res][int(a2)]].append(indexs[res][int(a1)]) + def get_amber_parameters(order=False, radians=True): file_names =('amino12.lib', 'parm10.dat', 'frcmod.ff14SB') - #amber19 is dangerous because they've replaced parameters with cmap - ###### pdb atom names to amber atom names using amino19.lib ###### - amber_atoms = {} # knowledge[res][pdb_atom] = amber_atom + # amber19 is dangerous because they've replaced parameters with cmap + # pdb atom names to amber atom names using amino19.lib + amber_atoms = {} # knowledge[res][pdb_atom] = amber_atom atom_mass = {} atom_polarizability = {} bond_force = {} @@ -121,10 +119,10 @@ def get_amber_parameters(order=False, radians=True): f_location = pkg_resources.resource_filename('molearn', 'parameters') f_in = open(f'{f_location}/{file_names[1]}') print('File %s opened' % file_names[1]) - except Exception as ex: + except Exception: raise Exception('ERROR: file %s not found!' % file_names[1]) - #section 1 title + # section 1 title line = f_in.readline() print(line) @@ -138,103 +136,106 @@ def get_amber_parameters(order=False, radians=True): amber_card_type_8(f_in, other_parameters) amber_card_type_9(f_in, other_parameters) for line in f_in: - if len(line.split())>1: - if line.split()[1]=='RE': + if len(line.split()) > 1: + if line.split()[1] == 'RE': amber_card_type_10B(f_in, other_parameters) - elif line[0:3]=='END': + elif line[0:3] == 'END': print('parameters loaded') f_in.close() - #open frcmod file, should be identifcal format but missing any or all cards + # pen frcmod file, should be identifcal format but missing any or all cards try: f_location = pkg_resources.resource_filename('molearn', 'parameters') f_in = open(f'{f_location}/{file_names[2]}') print('File %s opened' % file_names[2]) - except Exception as ex: + except Exception: raise Exception('ERROR: file %s not found!' 
% file_names[2]) - - #section 1 title + # section 1 title line = f_in.readline() print(line) for line in f_in: - if line[:4]=='MASS': + if line[:4] == 'MASS': amber_card_type_2(f_in, atom_mass, atom_polarizability) - if line[:4]=='BOND': + if line[:4] == 'BOND': amber_card_type_4(f_in, bond_force, bond_equil) - if line[:4]=='ANGL': + if line[:4] == 'ANGL': amber_card_type_5(f_in, angle_force, angle_equil) - if line[:4]=='DIHE': + if line[:4] == 'DIHE': amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, torsion_period) - if line[:4]=='IMPR': + if line[:4] == 'IMPR': amber_card_type_7(f_in, improper_factor, improper_barrier, improper_phase, improper_period) - if line[:4]=='HBON': + if line[:4] == 'HBON': amber_card_type_8(f_in, other_parameters) - if line[:4]=='NONB': + if line[:4] == 'NONB': amber_card_type_10B(f_in, other_parameters) - if line[:4]=='CMAP': + if line[:4] == 'CMAP': print('Yeah, Im not bothering to implement cmap') - elif line[0:3]=='END': + elif line[0:3] == 'END': print('parameters loaded') f_in.close() if radians: for angle in angle_equil: - angle_equil[angle]=np.deg2rad(angle_equil[angle]) + angle_equil[angle] = np.deg2rad(angle_equil[angle]) for torsion in torsion_phase: - torsion_phase[torsion]=list(np.deg2rad(torsion_phase[torsion])) + torsion_phase[torsion] = list(np.deg2rad(torsion_phase[torsion])) return (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil, angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase, torsion_period, improper_factor, improper_barrier, improper_phase, improper_period, other_parameters) + def amber_card_type_2(f_in, atom_mass, atom_polarizability): - #section 2 input for atom symbols and masses + # section 2 input for atom symbols and masses for line in f_in: - if line=='\n' or line.strip()=='': + if line == '\n' or line.strip() == '': break atom = line[0:2].strip() contents = line[2:24].split() if len(contents)==2: mass, polarizability = float(contents[0]),float(contents[1]) - atom_mass[atom]=mass - atom_polarizability[atom]=polarizability - elif len(contents)==1: # sometimes a polarizability is not listed + atom_mass[atom] = mass + atom_polarizability[atom] = polarizability + elif len(contents) == 1: # sometimes a polarizability is not listed mass = float(contents[0]) - atom_mass[atom]=mass - atom_polarizability[atom]=polarizability + atom_mass[atom] = mass + atom_polarizability[atom] = polarizability else: raise Exception('Should be 2A, X, F10.2, F10.2, comments but got %s' % line) + def amber_card_type_3(f_in): - #section 3 input for atom symbols that are hydrophilic + # section 3 input for atom symbols that are hydrophilic line = f_in.readline() + def amber_card_type_4(f_in, bond_force, bond_equil, order=False): - #section 4 bond length paramters + # section 4 bond length paramters for line in f_in: - if line=='\n' or line.strip()=='': + if line == '\n' or line.strip() == '': break atom1 = line[0:2]. 
strip() atom2 = line[3:5].strip() if order: - bond = tuple(sorted((atom1, atom2))) # put in alphabetical order + bond = tuple(sorted((atom1, atom2))) # put in alphabetical order else: bond = (atom1, atom2) contents = line[5:25].split() - if len(contents)!=2: + if len(contents) != 2: raise Exception('Expected 2 floats but got %s' % line[6:26]) force_constant, equil_length = float(contents[0]), float(contents[1]) - bond_force[bond]=force_constant - bond_equil[bond]=equil_length - #this should throw an error if there are not + bond_force[bond] = force_constant + bond_equil[bond] = equil_length + # this should throw an error if there are not + def amber_card_type_5(f_in, angle_force, angle_equil, order=False): - #section 5 + # section 5 for line in f_in: if line=='\n' or line.strip()=='': break @@ -254,11 +255,12 @@ def amber_card_type_5(f_in, angle_force, angle_equil, order=False): angle_force[angle]=force_constant angle_equil[angle]=equil_angle + def amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, torsion_period, order=False): - #secti on 6 torsion / proper dihedral + # section 6 torsion / proper dihedral for line in f_in: - if line=='\n' or line.strip()=='': + if line == '\n' or line.strip() == '': break atom1 = line[0:2].strip() atom2 = line[3:5].strip() @@ -266,13 +268,13 @@ def amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, atom4 = line[9:11].strip() if order: sort23 = sorted([(atom2, atom1), (atom3, atom4)], key=lambda x: x[0]) - torsion = tuple( (sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1]) ) + torsion = tuple((sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1])) else: torsion = (atom1, atom2, atom3, atom4) contents = line[11:55].split() - if len(contents)!=4: + if len(contents) != 4: raise Exception('I wanted four values here?') - #the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase)) + # the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase)) if torsion in torsion_period: if torsion_period[torsion][-1]>0: torsion_factor[torsion] = [int(contents[0]) ] @@ -290,9 +292,10 @@ def amber_card_type_6(f_in, torsion_factor, torsion_barrier, torsion_phase, torsion_phase[torsion] = [float(contents[2])] torsion_period[torsion] = [float(contents[3])] + def amber_card_type_7(f_in, improper_factor, improper_barrier, improper_phase, improper_period, order=False): - #section 7 improper dihedrals + # section 7 improper dihedrals for line in f_in: if line=='\n' or line.strip()=='': break @@ -302,25 +305,26 @@ def amber_card_type_7(f_in, improper_factor, improper_barrier, atom4 = line[9:11].strip() if order: sort23 = sorted([(atom2, atom1), (atom3, atom4)], key=lambda x: x[0]) - torsion = tuple( (sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1]) ) + torsion = tuple((sort23[0][1], sort23[0][0], sort23[1][0], sort23[1][1])) else: torsion = (atom1, atom2, atom3, atom4) contents = line[11:55].split() - if len(contents)==3: + if len(contents) == 3: improper_barrier[torsion] = float(contents[0]) improper_phase[torsion] = float(contents[1]) improper_period[torsion] = float(contents[2]) - elif len(contents)==4: + elif len(contents) == 4: raise Exception('This seems allowed in the doc but doesnt appear in reality') improper_factor[torsion] = int(contents[0]) improper_barrier[torsion] = float(contents[1]) improper_phase[torsion] = float(contents[2]) improper_period[torsion] = float(contents[3]) - #the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase)) - #it seems improper potential 
don't divide by the factor + # the actual torsion potential is (barrier/factor)*(1+cos(period*phi-phase)) + # it seems improper potential don't divide by the factor + def amber_card_type_8(f_in, other_parameters, order=False): - #section 8 H-bond 10-12 potential parameters + # section 8 H-bond 10-12 potential parameters for line in f_in: if line=='\n' or line.strip()=='': break @@ -333,22 +337,25 @@ def amber_card_type_8(f_in, other_parameters, order=False): contents = line[8:].split() other_parameters['H_bond_10_12_parameters'][pair]=contents + def amber_card_type_9(f_in, other_parameters): - #section 9 equi valencing atom symbols for non-bonded 6-12 potential parameters + # section 9 equivalencing atom symbols for non-bonded 6-12 potential parameters for line in f_in: if line=='\n' or line.strip()=='': break contents = line.split() - other_parameters['equivalences'][contents[0]]=contents + other_parameters['equivalences'][contents[0]] = contents + def amber_card_type_10B(f_in, other_parameters): - #section 10 6-12 potential parameters + # section 10 6-12 potential parameters for line in f_in: if line== '\n' or line.strip()=='': break contents = line.split() other_parameters['vdw_potential_well_depth'][contents[0]] = [float(i) for i in contents[1:3]] + def get_convolutions(dataset, pdb_atom_names, atom_label=('set','string')[0], perform_checks=True, @@ -361,7 +368,7 @@ def get_convolutions(dataset, pdb_atom_names, fix_charmm_residues=True, fix_slice_method=False, fix_h=False, - alt_vdw = [], + alt_vdw=[], permitivity=1.0 ): ''' @@ -405,138 +412,137 @@ def get_convolutions(dataset, pdb_atom_names, ''' - - #get amber parameters + # get amber parameters (amber_atoms, atom_mass, atom_polarizability, bond_force, bond_equil, angle_force, angle_equil, torsion_factor, torsion_barrier, torsion_phase, torsion_period, improper_factor, improper_barrier, improper_phase, - improper_period, other_parameters) = get_amber_parameters() + improper_period, other_parameters) = get_amber_parameters() if fix_terminal: - pdb_atom_names[pdb_atom_names[:,0]=='OXT',0]='O' + pdb_atom_names[pdb_atom_names[:, 0] == 'OXT', 0] = 'O' if fix_charmm_residues: - pdb_atom_names[pdb_atom_names[:,1]=='HSD',1]='HID' - pdb_atom_names[pdb_atom_names[:,1]=='HSE',1]='HIE' - for i in np.unique(pdb_atom_names[:,2]): - res_mask = pdb_atom_names[:,2]==i - if (pdb_atom_names[res_mask, 1]=='HIS').all(): # if a HIS residue - if (pdb_atom_names[res_mask, 0]=='HD1').any() and (pdb_atom_names[res_mask, 0]=='HE2').any(): - pdb_atom_names[res_mask, 1]='HIP' - elif (pdb_atom_names[res_mask, 0]=='HD1').any(): - pdb_atom_names[res_mask, 1]='HID' - elif (pdb_atom_names[res_mask, 0]=='HE2').any(): - pdb_atom_names[res_mask, 1]='HIE' - #if any HIS are remaining it doesn't matter which because the H is dealt with above - pdb_atom_names[pdb_atom_names[:,1]=='HIS',1]='HIE' + pdb_atom_names[pdb_atom_names[:, 1] == 'HSD', 1] = 'HID' + pdb_atom_names[pdb_atom_names[:, 1] == 'HSE', 1] = 'HIE' + for i in np.unique(pdb_atom_names[:, 2]): + res_mask = pdb_atom_names[:, 2] == i + if (pdb_atom_names[res_mask, 1]=='HIS').all(): # if a HIS residue + if (pdb_atom_names[res_mask, 0] == 'HD1').any() and (pdb_atom_names[res_mask, 0] == 'HE2').any(): + pdb_atom_names[res_mask, 1] = 'HIP' + elif (pdb_atom_names[res_mask, 0] == 'HD1').any(): + pdb_atom_names[res_mask, 1] ='HID' + elif (pdb_atom_names[res_mask, 0] == 'HE2').any(): + pdb_atom_names[res_mask, 1] = 'HIE' + # if any HIS are remaining it does not matter which because the H is dealt with above + 
pdb_atom_names[pdb_atom_names[:, 1]=='HIS', 1]='HIE' if fix_h: - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='MET'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='MET'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='ASN'),0]='HB3' - pdb_atom_names[pdb_atom_names[:,0]=='HN',0]='H' - pdb_atom_names[pdb_atom_names[:,0]=='1HD2',0]='HD21' - pdb_atom_names[pdb_atom_names[:,0]=='2HD2',0]='HD22' - pdb_atom_names[pdb_atom_names[:,0]=='1HG2',0]='HG21' - pdb_atom_names[pdb_atom_names[:,0]=='2HG2',0]='HG22' - pdb_atom_names[pdb_atom_names[:,0]=='3HG2',0]='HG23' - pdb_atom_names[pdb_atom_names[:,0]=='3HG1',0]='HG13' - pdb_atom_names[pdb_atom_names[:,0]=='1HG1',0]='HG11' - pdb_atom_names[pdb_atom_names[:,0]=='2HG1',0]='HG12' - pdb_atom_names[pdb_atom_names[:,0]=='1HD1',0]='HD11' - pdb_atom_names[pdb_atom_names[:,0]=='2HD1',0]='HD12' - pdb_atom_names[pdb_atom_names[:,0]=='3HD1',0]='HD13' - pdb_atom_names[pdb_atom_names[:,0]=='3HD2',0]='HD23' - pdb_atom_names[pdb_atom_names[:,0]=='1HH1',0]='HH11' - pdb_atom_names[pdb_atom_names[:,0]=='2HH1',0]='HH12' - pdb_atom_names[pdb_atom_names[:,0]=='1HH2',0]='HH21' - pdb_atom_names[pdb_atom_names[:,0]=='2HH2',0]='HH22' - pdb_atom_names[pdb_atom_names[:,0]=='1HE2',0]='HE21' - pdb_atom_names[pdb_atom_names[:,0]=='2HE2',0]='HE22' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG11', pdb_atom_names[:,1]=='ILE'),0]='HG13' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='CD', pdb_atom_names[:,1]=='ILE'),0]='CD1' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD1', pdb_atom_names[:,1]=='ILE'),0]='HD11' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD2', pdb_atom_names[:,1]=='ILE'),0]='HD12' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD3', pdb_atom_names[:,1]=='ILE'),0]='HD13' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='PHE'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='GLU'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='GLU'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='LEU'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='ARG'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='ARG'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD1', pdb_atom_names[:,1]=='ARG'),0]='HD3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='ASP'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HA1', pdb_atom_names[:,1]=='GLY'),0]='HA3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='LYS'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='LYS'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD1', pdb_atom_names[:,1]=='LYS'),0]='HD3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HE1', pdb_atom_names[:,1]=='LYS'),0]='HE3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='TYR'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='HIP'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='SER'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='SER'),0]='HG' - 
pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='PRO'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='PRO'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HD1', pdb_atom_names[:,1]=='PRO'),0]='HD3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='LEU'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='GLN'),0]='HB3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HG1', pdb_atom_names[:,1]=='GLN'),0]='HG3' - pdb_atom_names[np.logical_and(pdb_atom_names[:,0]=='HB1', pdb_atom_names[:,1]=='TRP'),0]='HB3' - #writes termini as H because we haven't loaded in termini parameters - atom_names = [[amber_atoms[res][atom],res, resid] if atom not in ['H2', 'H3'] else [amber_atoms[res]['H'], res, resid] for atom, res, resid in pdb_atom_names ] - p_atom_names = [[atom,res, resid] if atom not in ['H2', 'H3'] else ['H', res, resid] for atom, res, resid in pdb_atom_names ] + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0]=='HB1', pdb_atom_names[:, 1]=='MET'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0]=='HG1', pdb_atom_names[:, 1]=='MET'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0]=='HB1', pdb_atom_names[:, 1]=='ASN'), 0] = 'HB3' + pdb_atom_names[pdb_atom_names[:, 0] == 'HN', 0] = 'H' + pdb_atom_names[pdb_atom_names[:, 0] == '1HD2', 0] = 'HD21' + pdb_atom_names[pdb_atom_names[:, 0] == '2HD2', 0] = 'HD22' + pdb_atom_names[pdb_atom_names[:, 0] == '1HG2', 0] = 'HG21' + pdb_atom_names[pdb_atom_names[:, 0] == '2HG2', 0] = 'HG22' + pdb_atom_names[pdb_atom_names[:, 0] == '3HG2', 0] = 'HG23' + pdb_atom_names[pdb_atom_names[:, 0] == '3HG1', 0] = 'HG13' + pdb_atom_names[pdb_atom_names[:, 0] == '1HG1', 0] = 'HG11' + pdb_atom_names[pdb_atom_names[:, 0] == '2HG1', 0] = 'HG12' + pdb_atom_names[pdb_atom_names[:, 0] == '1HD1', 0] = 'HD11' + pdb_atom_names[pdb_atom_names[:, 0] == '2HD1', 0] = 'HD12' + pdb_atom_names[pdb_atom_names[:, 0] == '3HD1', 0] = 'HD13' + pdb_atom_names[pdb_atom_names[:, 0] == '3HD2', 0] = 'HD23' + pdb_atom_names[pdb_atom_names[:, 0] == '1HH1', 0] = 'HH11' + pdb_atom_names[pdb_atom_names[:, 0] == '2HH1', 0] = 'HH12' + pdb_atom_names[pdb_atom_names[:, 0] == '1HH2', 0] = 'HH21' + pdb_atom_names[pdb_atom_names[:, 0] == '2HH2', 0] = 'HH22' + pdb_atom_names[pdb_atom_names[:, 0] == '1HE2', 0] = 'HE21' + pdb_atom_names[pdb_atom_names[:, 0] == '2HE2', 0] = 'HE22' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG11', pdb_atom_names[:,1] == 'ILE'), 0] = 'HG13' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'CD', pdb_atom_names[:, 1] == 'ILE'), 0] = 'CD1' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1]=='ILE'), 0] = 'HD11' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD2', pdb_atom_names[:, 1]=='ILE'), 0] = 'HD12' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD3', pdb_atom_names[:, 1]=='ILE'), 0] = 'HD13' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='PHE'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='GLU'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='GLU'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='LEU'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='ARG'), 0] = 'HB3' + 
pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='ARG'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1]=='ARG'), 0] = 'HD3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='ASP'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HA1', pdb_atom_names[:, 1]=='GLY'), 0] = 'HA3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='LYS'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='LYS'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1]=='LYS'), 0] = 'HD3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HE1', pdb_atom_names[:, 1]=='LYS'), 0] = 'HE3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='TYR'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='HIP'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='SER'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='SER'), 0] = 'HG' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='PRO'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='PRO'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HD1', pdb_atom_names[:, 1]=='PRO'), 0] = 'HD3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='LEU'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='GLN'), 0] = 'HB3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HG1', pdb_atom_names[:, 1]=='GLN'), 0] = 'HG3' + pdb_atom_names[np.logical_and(pdb_atom_names[:, 0] == 'HB1', pdb_atom_names[:, 1]=='TRP'), 0] = 'HB3' + # writes termini as H because we haven't loaded in termini parameters + atom_names = [[amber_atoms[res][atom], res, resid] if atom not in ['H2', 'H3'] else [amber_atoms[res]['H'], res, resid] for atom, res, resid in pdb_atom_names] + p_atom_names = [[atom,res, resid] if atom not in ['H2', 'H3'] else ['H', res, resid] for atom, res, resid in pdb_atom_names] - #atom_names = [[amber_atoms[res][atom],res] for atom, res, resid in pdb_atom_names ] + # atom_names = [[amber_atoms[res][atom],res] for atom, res, resid in pdb_atom_names ] atom_charges=[other_parameters['charge'][res][atom] for atom, res, resid in atom_names] if NB == 'matrix': - equiv_t=other_parameters['equivalences'] + equiv_t = other_parameters['equivalences'] vdw_para = other_parameters['vdw_potential_well_depth'] - #switch these around so that values point to key + # switch these around so that values point to key equiv = {} for i in equiv_t.keys(): j = equiv_t[i] for k in j: - equiv[k]=i - atom_R = torch.tensor([vdw_para[equiv.get(atom,atom)][0] for atom, res, resid in atom_names]) #radius - atom_e = torch.tensor([vdw_para[equiv.get(atom,atom)][1] for atom, res, resid in atom_names]) #welldepth + equiv[k] = i + atom_R = torch.tensor([vdw_para[equiv.get(atom,atom)][0] for atom, res, resid in atom_names]) # radius + atom_e = torch.tensor([vdw_para[equiv.get(atom,atom)][1] for atom, res, resid in atom_names]) # welldepth print('Determining bonds') - version = v # method of selecting bonded atoms - N = dataset.shape[1] #2145 + version = v # method of selecting bonded atoms + N = dataset.shape[1] 
# 2145
     cmat=(torch.nn.functional.pdist((dataset).permute(1,0))).cpu().numpy()
     if version == 1:
-        bond_idxs=np.argpartition(cmat, (N-1,N))
-        #this will work for any non cyclic monomeric protein
-        #that in mind will break if enough proline atoms to make a cycle are selected
+        bond_idxs = np.argpartition(cmat, (N-1,N))
+        # this will work for any non cyclic monomeric protein
+        # bear in mind it will break if enough proline atoms are selected to form a cycle
         bond_idxs, u = bond_idxs[:N-1], bond_idxs[N-1]
         if cmat[u]-cmat[bond_idxs[-1]]<0.25:
             raise Exception("WARNING: May not have correctly selected the bonded distances: value " +(cmat[u]-cmat[bond_idxs[-1]])+
-                            "should be roughly between 0.42 and 0.57 (>0.25)" )# should be 0.42-0.57
-            version+=1 #try version 2 instead
-            mid = cmat[bond_idxs[-1]]+((cmat[u]-cmat[bond_idxs[-1]])/2) #mid point
+                            "should be roughly between 0.42 and 0.57 (>0.25)")  # should be 0.42-0.57
+            version += 1  # try version 2 instead
+            mid = cmat[bond_idxs[-1]]+((cmat[u]-cmat[bond_idxs[-1]])/2)  # mid point
         full_mask = (cmatcdist).triu(diagonal=1).numpy())
             remove = np.where(np.abs(j-i)>30)
-            max_bond_dist[i[remove],j[remove]]=0.0
+            max_bond_dist[i[remove], j[remove]]=0.0
             max_bond_dist = max_bond_dist.numpy()
         else:
-            max_bond_dist = (0.6*(atom_R.view(1,-1)+atom_R.view(-1,1))).cpu().numpy()
+            max_bond_dist = (0.6*(atom_R.view(1, -1)+atom_R.view(-1, 1))).cpu().numpy()
         max_bond_dist = max_bond_dist[np.where(np.triu(np.ones((N,N)),k=1))]
         full_mask = np.greater(max_bond_dist,cmat)
-        bond_idxs = np.where(full_mask)[0] # for some reason returns tuple with one array
+        bond_idxs = np.where(full_mask)[0]  # for some reason returns tuple with one array
     if version == 4:
-        #fix_hydrogens = [[atom, res, resid] for atom, res, resid in pdb_atom_names if atom in ['H2', 'H3']]
+        # fix_hydrogens = [[atom, res, resid] for atom, res, resid in pdb_atom_names if atom in ['H2', 'H3']]
         connectivity = other_parameters['connectivity']
         bond_types = []
         bond_idxs = []
-        #tracker = [[]]*N doesn't work because of mutability
+        # tracker = [[]]*N doesn't work because of mutability
         tracker = [[] for i in range(N)]
         current_resid = -9999
         current_atoms = []
@@ -548,24 +554,24 @@ def get_convolutions(dataset, pdb_atom_names,
                 continue
             elif not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]):
                 continue
-            # if not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]):
-            #     if resid != current_resid and atom2 == 'C':
-            #         current_resid = resid
-            #         current_atoms = []
-            #     else:
-            #         continue
+            # if not (atom2 in connectivity[res][atom1] and atom1 in connectivity[res][atom2]):
+            #     if resid != current_resid and atom2 == 'C':
+            #         current_resid = resid
+            #         current_atoms = []
+            #     else:
+            #         continue
             if atom1=='N' and atom2=='CA':
                 continue
             tracker[i1].append(i2)
             tracker[i2].append(i1)
-            if atom_label=='set':
+            if atom_label == 'set':
                 if order:
                     names = tuple(sorted((atom_names[i2][0], atom_names[i1][0])))
                 else:
                     names = tuple((atom_names[i2][0], atom_names[i1][0]))
                 bond_types.append(names)
                 bond_idxs.append([i2,i1])
-            if resid !=current_resid:# and atom2 == 'C':
+            if resid !=current_resid:  # and atom2 == 'C':
                 current_resid = resid
                 current_atoms = []
             current_atoms.append([atom1,i1])
@@ -604,25 +610,25 @@ def get_convolutions(dataset, pdb_atom_names,
     bond_types = []
     bond_idxs = []
 
-    tracker = [[]] # this will keep track of some of the bonds to help
work out the angles + atom1 = 0 + atom2 = 1 + counter = 0 # index of the distance N,N+1 for bond in all_bond_idxs: if bond < counter+(N-atom1-1): atom2 = atom1+bond-counter+1 # 0-0+1 tracker[-1].append(atom2) # while bond > counter+(N-atom1-2): - counter+=(N-atom1-1) - atom1 +=1 + counter += (N-atom1-1) + atom1 += 1 tracker.append([]) if bond < counter+(N-atom1-1): atom2 = atom1+bond-counter+1 tracker[-1].append(atom2) - if atom_label=='string': #string of atom labels, doesn't handle Proline alternate ordering + if atom_label=='string': # string of atom labels, doesn't handle Proline alternate ordering bond_types.append(atom_names[atom1][0]+'_'+atom_names[atom2][0]) bond_idxs.append([atom1, atom2]) - elif atom_label=='set': #set of atom labels + elif atom_label=='set': # set of atom labels if order: names = tuple(sorted((atom_names[atom1][0], atom_names[atom2][0]))) else: @@ -631,7 +637,7 @@ def get_convolutions(dataset, pdb_atom_names, bond_idxs.append([atom1, atom2]) while len(tracker)atom1 prevents duplicates later ) + counter = 0 + # add missing bonds (each bond counted twice after but atom3>atom1 prevents duplicates later ) if version < 4: - for atom1, atom1_bonds in enumerate(deepcopy(tracker)): # for _, [] in enum [[]] - for atom2 in atom1_bonds: # for _ in [] + for atom1, atom1_bonds in enumerate(deepcopy(tracker)): # for _, [] in enum [[]] + for atom2 in atom1_bonds: # for _ in [] tracker[atom2].append(atom1) # find every angle and add it for atom1, atom1_bonds in enumerate(tracker): for atom2 in atom1_bonds: for atom3 in tracker[atom2]: - if atom3>atom1: #each angle will only be counter once + if atom3>atom1: # each angle will only be counter once if order: - sort13 = sorted([ (atom_names[atom1][0], atom1), (atom_names[atom3][0], atom3) ], key=lambda x: x[0]) - names = tuple( (sort13[0][0], atom_names[atom2][0], sort13[1][0]) ) + sort13 = sorted([(atom_names[atom1][0], atom1), (atom_names[atom3][0], atom3)], key=lambda x: x[0]) + names = tuple((sort13[0][0], atom_names[atom2][0], sort13[1][0])) angle_types.append(names) angle_idxs.append([sort13[0][1], atom2, sort13[1][1]]) @@ -666,12 +672,12 @@ def get_convolutions(dataset, pdb_atom_names, angle_idxs.append([atom1, atom2, atom3]) if atom3 != atom1: for atom4 in tracker[atom3]: - if atom4>atom1 and atom2!=atom4:# each torsion will be counter once - #torsions are done based on the 2 3 atoms, so sort 23 + if atom4>atom1 and atom2!=atom4: # each torsion will be counter once + # torsions are done based on the 2 3 atoms, so sort 23 if order: - sort23 = sorted([ (atom_names[atom2][0], atom2, atom_names[atom1][0], atom1), - (atom_names[atom3][0], atom3, atom_names[atom4][0], atom4) ], key=lambda x: x[0]) - names = tuple( (sort23[0][2], sort23[0][0], sort23[1][0], sort23[1][2]) ) + sort23 = sorted([(atom_names[atom2][0], atom2, atom_names[atom1][0], atom1), + (atom_names[atom3][0], atom3, atom_names[atom4][0], atom4)], key=lambda x: x[0]) + names = tuple((sort23[0][2], sort23[0][0], sort23[1][0], sort23[1][2])) torsion_types.append(names) torsion_idxs.append([sort23[0][3], sort23[0][1], sort23[1][1], sort23[1][3]]) else: @@ -679,45 +685,45 @@ def get_convolutions(dataset, pdb_atom_names, atom_names[atom3][0], atom_names[atom4][0])) torsion_idxs.append([atom1, atom2, atom3, atom4]) bond_14_idxs.append([atom1,atom4]) - #currently have bond_types, angle_types, and torsion_typs + idxs + # currently have bond_types, angle_types, and torsion_typs + idxs bond_idxs = np.array(bond_idxs) angle_idxs = np.array(angle_idxs) torsion_idxs = 
np.array(torsion_idxs) bond_max_conv = (bond_idxs.max(axis=1)-bond_idxs.min(axis=1)).max()+1 if bond_max_conv<3 and fix_slice_method: - bond_max_conv=3 + bond_max_conv = 3 angle_max_conv = (angle_idxs.max(axis=1)-angle_idxs.min(axis=1)).max()+1 if angle_max_conv<5 and fix_slice_method: - angle_max_conv=5 + angle_max_conv = 5 torsion_max_conv = (torsion_idxs.max(axis=1)-torsion_idxs.min(axis=1)).max()+1 if torsion_max_conv<7 and fix_slice_method: torsion_max_conv=7 - #there is a problem where i accidentally index [padding-3] so if (len -4) < 3 we index -1 which breaks things - #it shouldn't affect anything to say the max conv is greater than 6 - #this little bit just turns the 'types' list into equivalent parameters - #key error if you don't have the parameter - bond_para= np.array([ [bond_equil[bond], bond_force[bond]] if bond in bond_equil + # there is a problem where i accidentally index [padding-3] so if (len -4) < 3 we index -1 which breaks things + # it shouldn't affect anything to say the max conv is greater than 6 + # this little bit just turns the 'types' list into equivalent parameters + # key error if you don't have the parameter + bond_para = np.array([[bond_equil[bond], bond_force[bond]] if bond in bond_equil else [bond_equil[(bond[1],bond[0])], bond_force[(bond[1], bond[0])]] for bond in bond_types]) - angle_para=np.array([ [angle_equil[angle], angle_force[angle]] if angle in angle_equil + angle_para=np.array([[angle_equil[angle], angle_force[angle]] if angle in angle_equil else [angle_equil[(angle[2],angle[1],angle[0])], angle_force[(angle[2], angle[1], angle[0])]] for angle in angle_types]) - torsion_para=[] - t_unique=list(set(torsion_types)) + torsion_para = [] + t_unique = list(set(torsion_types)) t_unique_para = {} max_para = 0 for torsion in t_unique: - torsion_b = (torsion[3],torsion[2],torsion[1],torsion[0]) - torsion_xx = ('X', torsion[2], torsion[1], 'X') + torsion_b = (torsion[3], torsion[2], torsion[1], torsion[0]) + torsion_xx = ('X', torsion[2], torsion[1], 'X') torsion_xb = ('X', torsion[1], torsion[2], 'X') if torsion in torsion_barrier: max_para = max(max_para, len(torsion_barrier[torsion])) - t_unique_para[torsion]= [torsion_factor[torsion], torsion_barrier[torsion], + t_unique_para[torsion] = [torsion_factor[torsion], torsion_barrier[torsion], torsion_phase[torsion], torsion_period[torsion]] elif torsion_b in torsion_barrier: max_para = max(max_para, len(torsion_barrier[torsion_b])) - t_unique_para[torsion]= [torsion_factor[torsion_b], torsion_barrier[torsion_b], - torsion_phase[torsion_b], torsion_period[torsion_b]] + t_unique_para[torsion] = [torsion_factor[torsion_b], torsion_barrier[torsion_b], + torsion_phase[torsion_b], torsion_period[torsion_b]] elif torsion_xx in torsion_barrier: max_para = max(max_para, len(torsion_barrier[torsion_xx])) t_unique_para[torsion]= [torsion_factor[torsion_xx], torsion_barrier[torsion_xx], @@ -729,17 +735,15 @@ def get_convolutions(dataset, pdb_atom_names, else: print('ERROR: Torsion %s cannot be found in torsion_barrier and will not be included'% torsion) torsion_para = np.zeros((len(torsion_types),4,max_para)) - #we don't want barrier/factor to return nan so set factor to 1 by default + # we do not want barrier/factor to return nan so set factor to 1 by default torsion_para[:,0,:]=1.0 for i,torsion in enumerate(torsion_types): para = t_unique_para[torsion] - torsion_para[i,:,:len(para[0])]=para - ##### make phase positive ##### + torsion_para[i,:,:len(para[0])] = para + # make phase positive if absolute_torsion_period: 
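        # note: Amber dihedral cards conventionally mark a torsion that has
        # further terms on the following card with a negative periodicity; once
        # all terms have been gathered, only the magnitude of the period is
        # physical, which is what taking the absolute value below restores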
torsion_para[:,3,:] = np.abs(torsion_para[:,3,:]) - - ############################### bonds ################################# bond_masks = np.zeros((bond_max_conv-1, N-(bond_max_conv-1) + 2*(bond_max_conv-2)),dtype=np.bool) @@ -750,96 +754,92 @@ def get_convolutions(dataset, pdb_atom_names, b_force = np.zeros(bond_masks.shape) for i in range(bond_max_conv-1): weight = [0.0]*bond_max_conv - weight[0]=1.0 - weight[i+1]=-1.0 + weight[0] = 1.0 + weight[i+1] = -1.0 bond_weights.append(weight) - mask_index=bond_idxs.min(axis=1)[bond_conv==i]+bond_max_conv-2 - bond_masks[i, mask_index]=True - b_equil[i, mask_index]=bond_para[bond_conv==i,0] - b_force[i, mask_index] =bond_para[bond_conv==i,1] + mask_index = bond_idxs.min(axis=1)[bond_conv==i]+bond_max_conv-2 + bond_masks[i, mask_index] = True + b_equil[i, mask_index] = bond_para[bond_conv==i, 0] + b_force[i, mask_index] = bond_para[bond_conv==i, 1] ############################### angles ################################# - angle_conv = (angle_idxs-angle_idxs.min(axis=1).reshape(-1,1))#relative positions of atoms - angle_conv = np.where((angle_conv[:,0] amber atom names and residues + # pdb atoms -> amber atom names and residues padded_atom_names = np.array([[amber_atoms[res][atom],res, resid] if atom is not None else [atom, res, resid] for atom, res, resid in pdb_atom_names]) - unpadded_atom_names = [[amber_atoms[res][atom],res, resid] for atom, res, resid in pdb_atom_names if atom is not None] + # unpadded_atom_names = [[amber_atoms[res][atom],res, resid] for atom, res, resid in pdb_atom_names if atom is not None] padded_atom_charges = np.array([other_parameters['charge'][res][atom] if atom is not None else np.nan for atom, res, _ in padded_atom_names]) - if padded_atom_names.shape != dataset.shape: # just a little check + if padded_atom_names.shape != dataset.shape: # just a little check raise Exception('996 padded_atom_names!=dataset.shape') atom_names = padded_atom_names atom_charges = padded_atom_charges print('Determining bonds') - connect = other_parameters['connectivity'] - #connectivity = [[]]*N # careful with mutability + # connectivity = [[]]*N # careful with mutability connectivity = [[] for i in range(N)] current_resid = -9999 current_atoms = [] @@ -919,12 +914,12 @@ def get_conv_pad_res(dataset, pdb_atom_names, if atom2 in connect[res][atom1] and atom1 in connect[res][atom2]: connectivity[i1].append(i2) connectivity[i2].append(i1) - # cmat = torch.cdist(dataset, dataset) #[R*M,3 ]-> [R*M, R*M] - # #1.643 was max bond distance in MurD test, 2.129 was the smallest nonbonded distance - # #can't say what the best solution is but somewhere in the middle will probably be okay - # all_bond_mask = (cmat<(1.643+2.1269)/2).triu(diagonal=1) # [R*M,R*M] - # bond_idxs = all_bond_mask.nonzero() # [B x 2] - # #name_set = set(atom_names[:,0]) +# cmat = torch.cdist(dataset, dataset) #[R*M,3 ]-> [R*M, R*M] +# #1.643 was max bond distance in MurD test, 2.129 was the smallest nonbonded distance +# #can't say what the best solution is but somewhere in the middle will probably be okay +# all_bond_mask = (cmat<(1.643+2.1269)/2).triu(diagonal=1) # [R*M,R*M] +# bond_idxs = all_bond_mask.nonzero() # [B x 2] +# #name_set = set(atom_names[:,0]) # # connectivity = [[] for i in range(N)] # this will keep track of some of the bonds to help work out the angles # for i,j in bond_idxs: @@ -945,43 +940,43 @@ def get_conv_pad_res(dataset, pdb_atom_names, for atom1, atom2_list in enumerate(connectivity): for atom2 in atom2_list: a1, a2 = atom_names[atom1][0], 
atom_names[atom2][0]
-            if atom1 < atom2: #stops any pair of atoms being selected twice
+            if atom1 < atom2:  # stops any pair of atoms being selected twice
                 bond_idxs_.append([atom1, atom2])
-                for b in [(a1,a2), (a2,a1)]:
+                for b in [(a1, a2), (a2, a1)]:
                     if b in bond_equil:
                         bond_para.append([bond_equil[b], bond_force[b]])
-                        break # break prevents any bond from beind added twice
+                        break  # break prevents any bond from being added twice
                 else:
                     raise Exception('No associated bond parameter')
             for atom3 in connectivity[atom2]:
                 a3 = atom_names[atom3][0]
-                if atom3 > atom1: #each angle will only be counter once
-                    angle_idxs.append([atom1,atom2,atom3])
-                    for a in [(a1,a2,a3), (a3,a2,a1)]:
+                if atom3 > atom1:  # each angle will only be counted once
+                    angle_idxs.append([atom1, atom2, atom3])
+                    for a in [(a1, a2, a3), (a3, a2, a1)]:
                         if a in angle_equil:
                             angle_para.append([angle_equil[a], angle_force[a]])
                             break
                     else:
                         raise Exception('No associated angle parameter')
-                if atom3 != atom1: #don't go back to same atom
+                if atom3 != atom1:  # don't go back to same atom
                     for atom4 in connectivity[atom3]:
-                        if atom4 > atom1 and atom2!=atom4:
+                        if atom4 > atom1 and atom2 != atom4:
                             torsion_idxs.append([atom1, atom2, atom3, atom4])
                             bond_14_idxs.append([atom1, atom4])
                             a4 = atom_names[atom4][0]
-                            for t in [(a1,a2,a3,a4),(a4,a3,a2,a1),('X',a2,a3,'X'),('X',a3,a2,'X')]:
+                            for t in [(a1, a2, a3, a4), (a4, a3, a2, a1), ('X', a2, a3, 'X'), ('X', a3, a2, 'X')]:
                                 if t in torsion_barrier:
                                     torsion_para_.append(torch.tensor([
                                         torsion_factor[t], torsion_barrier[t],
                                         torsion_phase[t], torsion_period[t]]))
-                                    break #each torsion only counter once
+                                    break  # each torsion only counted once
                             else:
                                 raise Exception('No associated torsion parameter')
 
-    bond_idxs_ = torch.as_tensor(bond_idxs_)
+    bond_idxs = torch.as_tensor(bond_idxs_)
     angle_idxs = torch.tensor(angle_idxs)
     torsion_idxs = torch.tensor(torsion_idxs)
     bond_14_idxs = torch.tensor(bond_14_idxs)
@@ -989,29 +984,30 @@ def get_conv_pad_res(dataset, pdb_atom_names,
     angle_para = torch.tensor(angle_para)
     max_number_torsion_para = max([tf.shape[1] for tf in torsion_para_])
     torsion_para = torch.zeros(torsion_idxs.shape[0],4,max_number_torsion_para)
-    torsion_para[:,0,:]=1.0
+    torsion_para[:,0,:] = 1.0
     for i,tf in enumerate(torsion_para_):
-        torsion_para[i,:,0:tf.shape[1]]=tf
+        torsion_para[i, :, 0:tf.shape[1]] = tf
     if absolute_torsion_period:
-        torsion_para[:,3,:] = np.abs(torsion_para[:,3,:])
+        torsion_para[:, 3, :] = np.abs(torsion_para[:, 3, :])
+
     ###### Gather based potential ######
-    #currently for data [B, R*M, 3] or [B, 3, N]
-    aij0 = bond_idxs.reshape(-1,2,1).eq(angle_idxs[:,(0,1)].view(-1,2,1).permute(2,1,0)).all(dim=1)
-    aij1 = bond_idxs.reshape(-1,2,1).eq(angle_idxs[:,(1,0)].view(-1,2,1).permute(2,1,0)).all(dim=1)
-    ajk0 = bond_idxs.reshape(-1,2,1).eq(angle_idxs[:,(1,2)].view(-1,2,1).permute(2,1,0)).all(dim=1)
-    ajk1 = bond_idxs.reshape(-1,2,1).eq(angle_idxs[:,(2,1)].view(-1,2,1).permute(2,1,0)).all(dim=1)
+    # currently for data [B, R*M, 3] or [B, 3, N]
+    aij0 = bond_idxs.reshape(-1, 2, 1).eq(angle_idxs[:, (0, 1)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1)
+    aij1 = bond_idxs.reshape(-1, 2, 1).eq(angle_idxs[:, (1, 0)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1)
+    ajk0 = bond_idxs.reshape(-1, 2, 1).eq(angle_idxs[:, (1, 2)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1)
+    ajk1 = bond_idxs.reshape(-1, 2, 1).eq(angle_idxs[:, (2, 1)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1)
     ij_jk = torch.stack([torch.where((aij0+aij1).T)[1], torch.where((ajk0+ajk1).T)[1]])
-    aij_ = aij1.float()-aij0.float()
#sign change needed for loss_function equation + aij_ = aij1.float()-aij0.float() # sign change needed for loss_function equation ajk_ = ajk0.float()-ajk1.float() angle_mask = torch.stack([aij_.sum(dim=0), ajk_.sum(dim=0)]) - #following are [N_bonds, N_torsions] arrays comparing if the ij or jk are the same - ij0 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(0,1)].view(-1,2,1).permute(2,1,0)).all(dim=1) - ij1 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(1,0)].view(-1,2,1).permute(2,1,0)).all(dim=1) - jk0 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(1,2)].view(-1,2,1).permute(2,1,0)).all(dim=1) - jk1 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(2,1)].view(-1,2,1).permute(2,1,0)).all(dim=1) - kl0 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(2,3)].view(-1,2,1).permute(2,1,0)).all(dim=1) - kl1 = bond_idxs.reshape(-1,2,1).eq(torsion_idxs[:,(3,2)].view(-1,2,1).permute(2,1,0)).all(dim=1) + # following are [N_bonds, N_torsions] arrays comparing if the ij or jk are the same + ij0 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (0, 1)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) + ij1 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (1, 0)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) + jk0 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (1, 2)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) + jk1 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (2, 1)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) + kl0 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (2, 3)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) + kl1 = bond_idxs.reshape(-1, 2, 1).eq(torsion_idxs[:, (3, 2)].view(-1, 2, 1).permute(2, 1, 0)).all(dim=1) ij_jk_kl = torch.stack([torch.where((ij0+ij1).T)[1], torch.where((jk0+jk1).T)[1], torch.where((kl0+kl1).T)[1]]) @@ -1020,74 +1016,67 @@ def get_conv_pad_res(dataset, pdb_atom_names, kl_ = kl0.float()-kl1.float() torsion_mask = torch.stack([ij_.sum(dim=0), jk_.sum(dim=0), kl_.sum(dim=0)]) - #j-i i->j - #i-j j->i reverse - #k-j j->k - #j-k k->j reverse - #l-k k->l - #k-l l->k reverse - + # j-i i->j + # i-j j->i reverse + # k-j j->k + # j-k k->j reverse + # l-k k->l + # k-l l->k reverse - if NB=='matrix': - equiv_t=other_parameters['equivalences'] + if NB == 'matrix': + equiv_t = other_parameters['equivalences'] vdw_para = other_parameters['vdw_potential_well_depth'] - #switch these around so that values point to key + # switch these around so that values point to key equiv = {} for i in equiv_t.keys(): j = equiv_t[i] for k in j: - equiv[k]=i - atom_R = torch.tensor([vdw_para[equiv.get(i,i)][0] if i is not None else np.nan for i, j, k in atom_names]) #radius - atom_e = torch.tensor([vdw_para[equiv.get(i,i)][1] if i is not None else np.nan for i, j, k in atom_names]) #welldepth - #cdist is easier to work with than pdist, batch pdist doesn't seem to exist too + equiv[k] = i + atom_R = torch.tensor([vdw_para[equiv.get(i,i)][0] if i is not None else np.nan for i, j, k in atom_names]) # radius + atom_e = torch.tensor([vdw_para[equiv.get(i,i)][1] if i is not None else np.nan for i, j, k in atom_names]) # welldepth + # cdist is easier to work with than pdist, batch pdist doesn't seem to exist too vdw_R = 0.5*torch.cdist(atom_R.view(-1,1), -atom_R.view(-1, 1)).triu(diagonal=1) vdw_e = (atom_e.view(1,-1)*atom_e.view(-1, 1)).triu(diagonal=1).sqrt() - #set 1-2, and 1-3 distances to 0.0 - vdw_R[list(bond_idxs.T)]=0.0 - vdw_e[list(bond_idxs.T)]=0.0 - vdw_R[list(angle_idxs[:,(0,2)].T)]=0.0 - vdw_e[list(angle_idxs[:,(0,2)].T)]=0.0 + # set 1-2, and 1-3 distances to 0.0 + vdw_R[list(bond_idxs.T)] = 0.0 + 
vdw_e[list(bond_idxs.T)] = 0.0
+        vdw_R[list(angle_idxs[:,(0, 2)].T)] = 0.0
+        vdw_e[list(angle_idxs[:,(0, 2)].T)] = 0.0
         if correct_1_4:
             # sum A/R**12 - B/R**6; A = e* (R**12); B = 2*e *(R**6)
             # therefore scale vdw by setting e /= 2.0
-            #vdw_R[list(torsion_idxs[:,(0,3)].T)]/=2.0
-            vdw_e[list(torsion_idxs[:,(0,3)].T)]/=2.0
+            # vdw_R[list(torsion_idxs[:,(0,3)].T)]/=2.0
+            vdw_e[list(torsion_idxs[:,(0, 3)].T)] /= 2.0
         else:
-            vdw_R[list(torsion_idxs[:,(0,3)].T)]=0.0
-            vdw_e[list(torsion_idxs[:,(0,3)].T)]=0.0
-        vdw_R[torch.isnan(vdw_R)]=0.0
-        vdw_e[torch.isnan(vdw_e)]=0.0
-
-        #partial charges are given as fragments of electron charge.
-        #Can convert coulomb energy into kcal/mol by multiplying with 332.05.
-        #therofore multiply q by sqrt(332.05)=18.22
-        e_=permitivity #permittivity
+            vdw_R[list(torsion_idxs[:,(0, 3)].T)] = 0.0
+            vdw_e[list(torsion_idxs[:,(0, 3)].T)] = 0.0
+        vdw_R[torch.isnan(vdw_R)] = 0.0
+        vdw_e[torch.isnan(vdw_e)] = 0.0
+
+        # partial charges are given as fractions of the electron charge.
+        # Coulomb energy converts to kcal/mol on multiplying by 332.05,
+        # therefore multiply q by sqrt(332.05)=18.22
+        e_ = permitivity  # permittivity
         atom_charges=torch.tensor(atom_charges)
-        q1q2=(atom_charges.view(1,-1)*atom_charges.view(-1,1)/e_).triu(diagonal=1) #Aij=bi*bj
-        q1q2[list(bond_idxs.T)]=0.0
-        q1q2[list(angle_idxs[:,(0,2)].T)]=0.0
+        q1q2 = (atom_charges.view(1,-1)*atom_charges.view(-1, 1)/e_).triu(diagonal=1)  # Aij=bi*bj
+        q1q2[list(bond_idxs.T)] = 0.0
+        q1q2[list(angle_idxs[:,(0, 2)].T)] = 0.0
         if correct_1_4:
-            q1q2[list(torsion_idxs[:,(0,3)].T)]/=1.2
+            q1q2[list(torsion_idxs[:,(0, 3)].T)] /= 1.2
         else:
-            q1q2[list(torsion_idxs[:,(0,3)].T)]=0.0
-        #1-4 are should be included but scaled
-        return (
-            bond_idxs, bond_para,
+            q1q2[list(torsion_idxs[:,(0, 3)].T)] = 0.0
+        # 1-4 should be included but scaled
+        return (bond_idxs, bond_para,
                 angle_idxs, angle_para, angle_mask, ij_jk,
                 torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
                 vdw_R, vdw_e,
-            q1q2,
-            )
-    return (
-        bond_idxs, bond_para,
+                q1q2)
+    return (bond_idxs, bond_para,
             angle_idxs, angle_para, angle_mask, ij_jk,
-        torsion_idxs, torsion_para, torsion_mask, ij_jk_kl,
-        )
+            torsion_idxs, torsion_para, torsion_mask, ij_jk_kl)
+
 
 if __name__ == '__main__':
+    import sys
-    import os
     sys.path.insert(0, os.path.abspath('../'))
-    import biobox
-
-
diff --git a/src/molearn/models/CNN_autoencoder.py b/src/molearn/models/CNN_autoencoder.py
index 0d70141..0d7481a 100644
--- a/src/molearn/models/CNN_autoencoder.py
+++ b/src/molearn/models/CNN_autoencoder.py
@@ -16,27 +16,31 @@ class ResidualBlock(nn.Module):
 
     def __init__(self, f):
         super(ResidualBlock, self).__init__()
-        conv_block = [ nn.Conv1d(f,f, 3, stride=1, padding=1, bias=False),
-                       nn.BatchNorm1d(f),
-                       nn.ReLU(inplace=True),
-                       nn.Conv1d(f,f, 3, stride=1, padding=1, bias=False),
-                       nn.BatchNorm1d(f) ]
+        conv_block = [nn.Conv1d(f, f, 3, stride=1, padding=1, bias=False),
+                      nn.BatchNorm1d(f),
+                      nn.ReLU(inplace=True),
+                      nn.Conv1d(f, f, 3, stride=1, padding=1, bias=False),
+                      nn.BatchNorm1d(f)]
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
-        #return torch.relu(x + self.conv_block(x)) #earlier runs were with 'return x + self.conv_block(x)' but not an issue (really?)
+        # return torch.relu(x + self.conv_block(x))  # earlier runs were with 'return x + self.conv_block(x)' but not an issue (really?)
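For orientation: ResidualBlock implements the standard identity-shortcut unit y = x + F(x), where F is the conv-BN-ReLU-conv-BN stack above; both 1D convolutions use kernel 3, stride 1 and padding 1 with f channels in and out, so shapes are preserved and the shortcut is always valid. A minimal usage sketch (assuming molearn is importable as laid out in this patch; batch/channel/length sizes below are illustrative only):

    import torch
    from molearn.models.CNN_autoencoder import ResidualBlock

    block = ResidualBlock(f=64)    # 64 channels in and out
    x = torch.randn(8, 64, 128)    # (batch, channels, sequence length)
    y = block(x)                   # computes x + conv_block(x)
    assert y.shape == x.shape      # shape-preserving by construction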
+ class To2D(nn.Module): + def __init__(self): super(To2D, self).__init__() pass + def forward(self, x): - z = torch.nn.functional.adaptive_avg_pool2d(x, output_size=(2,1)) + z = torch.nn.functional.adaptive_avg_pool2d(x, output_size=(2, 1)) z = torch.sigmoid(z) return z + class From2D(nn.Module): def __init__(self): super(From2D, self).__init__() @@ -49,8 +53,6 @@ def forward(self, x): return x - - class Autoencoder(nn.Module): ''' This is the autoencoder used in our `Ramaswamy 2021 paper `_. diff --git a/src/molearn/models/foldingnet.py b/src/molearn/models/foldingnet.py index 6b5f353..a25e613 100644 --- a/src/molearn/models/foldingnet.py +++ b/src/molearn/models/foldingnet.py @@ -1,8 +1,8 @@ import torch -import biobox as bb from torch import nn import torch.nn.functional as F + def index_points(point_clouds, index): ''' Given a batch of tensor and index, select sub-tensor. @@ -34,7 +34,7 @@ def knn(x, k): xx = torch.sum(x ** 2, dim=1, keepdim=True) # (B, 1, N) pairwise_distance = -xx - inner - xx.transpose(2, 1) # (B, 1, N), (B, N, N), (B, N, 1) -> (B, N, N) - idx = pairwise_distance.topk(k=k, dim=-1)[1] # (B, N, k) + idx = pairwise_distance.topk(k=k, dim=-1)[1] # (B, N, k) return idx @@ -70,7 +70,7 @@ class Encoder(nn.Module): ''' Graph based encoder ''' - def __init__(self, latent_dimension = 2,**kwargs): + def __init__(self, latent_dimension=2, **kwargs): super(Encoder, self).__init__() self.latent_dimension = latent_dimension self.conv1 = nn.Conv1d(12, 64, 1) @@ -104,7 +104,6 @@ def forward(self, x): x = F.relu(self.bn2(self.conv2(x))) x = F.relu(self.bn3(self.conv3(x))) - # two consecutive graph layers x = self.graph_layer1(x) x = self.graph_layer2(x) @@ -142,9 +141,9 @@ def forward(self, *args): :param grids: reshaped 2D grids or intermediam reconstructed point clouds """ # concatenate - #try: + # try: # x = torch.cat([*args], dim=1) - #except: + # except: # for arg in args: # print(arg.shape) # raise @@ -154,6 +153,7 @@ def forward(self, *args): return x + class Decoder_Layer(nn.Module): ''' Decoder Module of FoldingNet @@ -163,14 +163,14 @@ def __init__(self, in_points, out_points, in_channel, out_channel,**kwargs): super(Decoder_Layer, self).__init__() # Sample the grids in 2D space - #xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #self.grid = np.meshgrid(xx, yy) # (2, 45, 45) + # xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # self.grid = np.meshgrid(xx, yy) # (2, 45, 45) self.out_points = out_points self.grid = torch.linspace(-0.5, 0.5, out_points).view(1,-1) # reshape - #self.grid = torch.Tensor(self.grid).view(2, -1) # (2, 45, 45) -> (2, 45 * 45) - assert out_points%in_points==0 + # self.grid = torch.Tensor(self.grid).view(2, -1) # (2, 45, 45) -> (2, 45 * 45) + assert out_points % in_points == 0 self.m = out_points//in_points self.fold1 = FoldingLayer(in_channel + 1, [512, 512, out_channel]) @@ -195,6 +195,7 @@ def forward(self, x): return recon2 + class Decoder(nn.Module): ''' Decoder Module of FoldingNet @@ -205,14 +206,12 @@ def __init__(self, out_points, latent_dimension=2, **kwargs): self.latent_dimension = latent_dimension # Sample the grids in 2D space - #xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #self.grid = np.meshgrid(xx, yy) # (2, 45, 45) + # xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # self.grid = np.meshgrid(xx, yy) # 
(2, 45, 45) - start_out = (out_points//128) +1 - self.out_points = out_points self.layer1 = Decoder_Layer(1, start_out, latent_dimension,3*128) diff --git a/src/molearn/models/small_foldingnet.py b/src/molearn/models/small_foldingnet.py index cdfae34..c31e837 100644 --- a/src/molearn/models/small_foldingnet.py +++ b/src/molearn/models/small_foldingnet.py @@ -1,7 +1,4 @@ -import torch -import biobox as bb from torch import nn -import torch.nn.functional as F from .foldingnet import * @@ -11,15 +8,15 @@ class Small_Decoder(nn.Module): ''' def __init__(self, out_points, in_channel=2, **kwargs): + super().__init__() # Sample the grids in 2D space - #xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) - #self.grid = np.meshgrid(xx, yy) # (2, 45, 45) + # xx = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # yy = np.linspace(-0.3, 0.3, 45, dtype=np.float32) + # self.grid = np.meshgrid(xx, yy) # (2, 45, 45) start_out = (out_points//8) +1 - self.out_points = out_points self.layer1 = Decoder_Layer(1, start_out, in_channel,3*32) @@ -35,17 +32,20 @@ def forward(self, x): return x + class Small_AutoEncoder(AutoEncoder): ''' autoencoder architecture derived from FoldingNet. ''' def __init__(self, *args, **kwargs): + super(AutoEncoder, self).__init__() self.encoder = Encoder(*args, **kwargs) self.decoder = Small_Decoder(*args, **kwargs) + class Big_Skinny_Decoder(nn.Module): ''' Decoder Module of FoldingNet @@ -55,7 +55,7 @@ def __init__(self, out_points, latent_dimension=2, **kwargs): super(Decoder, self).__init__() self.latent_dimension = latent_dimension - start_out = (out_points//16) +1 + start_out = (out_points//16)+1 self.out_points = out_points @@ -76,5 +76,6 @@ def forward(self, x): return x + if __name__=='__main__': print('Nothing to see here') diff --git a/src/molearn/scoring/dope_score.py b/src/molearn/scoring/dope_score.py index da7779e..21fb2aa 100644 --- a/src/molearn/scoring/dope_score.py +++ b/src/molearn/scoring/dope_score.py @@ -1,7 +1,7 @@ import numpy as np from copy import deepcopy -from ..utils import ShutUp, cpu_count, random_string +from ..utils import ShutUp, random_string try: import modeller from modeller import * @@ -10,11 +10,11 @@ except Exception as e: print('Error importing modeller: ') print(e) - -from multiprocessing import Pool, Event, get_context +from multiprocessing import get_context import os + class DOPE_Score: ''' This class contains methods to calculate dope without saving to save and load PDB files for every structure. Atoms in a biobox coordinate tensor are mapped to the coordinates in the modeller model directly. @@ -25,7 +25,7 @@ def __init__(self, mol): :param biobox.Molecule mol: One example frame to gain access to the topology. Mol will also be used to save a temporary pdb file that will be reloaded in modeller to create the initial modeller Model. 
''' - #set residues names with protonated histidines back to generic HIS name (needed by DOPE score function) + # set residues names with protonated histidines back to generic HIS name (needed by DOPE score function) testH = mol.data["resname"].values testH[testH == "HIE"] = "HIS" testH[testH == "HID"] = "HIS" @@ -34,10 +34,9 @@ def __init__(self, mol): alternate_residue_names = dict(CSS=('CYX',)) atoms = ' '.join(list(_mol.data['name'].unique())) - #tmp_file = f'tmp{np.random.randint(1e10)}.pdb' tmp_file = f'tmp{random_string()}.pdb' - _mol.write_pdb(tmp_file, conformations=[0], split_struc = False) - log.level(0,0,0,0,0) + _mol.write_pdb(tmp_file, conformations=[0], split_struc=False) + log.level(0, 0, 0, 0, 0) env = environ() env.libs.topology.read(file='$(LIB)/top_heav.lib') env.libs.parameters.read(file='$(LIB)/par.lib') @@ -47,7 +46,7 @@ def __init__(self, mol): atom_residue = _mol.get_data(columns=['name', 'resname', 'resid']) atom_order = [] first_index = next(iter(self.fast_ss)).residue.index - offset = atom_residue[0,2]-first_index + offset = atom_residue[0, 2]-first_index for i, j in enumerate(self.fast_ss): if i < len(atom_residue): for j_residue_name in alternate_residue_names.get(j.residue.name, (j.residue.name,)): @@ -59,7 +58,7 @@ def __init__(self, mol): atom_order.append(int(where)) self.fast_atom_order = atom_order # check fast dope atoms - for i,j in enumerate(self.fast_ss): + for i, j in enumerate(self.fast_ss): if i 0: - processes = min(processes, cpu_count()) + processes = min(processes, os.cpu_count()) else: - processes = cpu_count() + processes = os.cpu_count() self.processes = processes self.mol = deepcopy(mol) score = DOPE_Score ctx = get_context(context) self.pool = ctx.Pool(processes=processes, initializer=set_global_score, initargs=(score, dict(mol=mol)), - **kwargs, - ) + **kwargs) self.process_function = process_dope def __reduce__(self): @@ -182,5 +187,6 @@ def get_score(self, coords, **kwargs): ''' :param np.array coords: # shape (N, 3) numpy array ''' - #is copy necessary? + + # is copy necessary? return self.pool.apply_async(self.process_function, (coords.copy(), kwargs)) diff --git a/src/molearn/scoring/ramachandran_score.py b/src/molearn/scoring/ramachandran_score.py index 83df671..dbad82e 100644 --- a/src/molearn/scoring/ramachandran_score.py +++ b/src/molearn/scoring/ramachandran_score.py @@ -1,35 +1,38 @@ import numpy as np from copy import deepcopy -from multiprocessing import Pool, Event, get_context +from multiprocessing import get_context from scipy.spatial.distance import cdist from iotbx.data_manager import DataManager from mmtbx.validation.ramalyze import ramalyze from scitbx.array_family import flex -from ..utils import cpu_count, random_string +from ..utils import random_string import os -class Ramachandran_Score(): + +class Ramachandran_Score: ''' This class contains methods that use iotbx/mmtbx to calulate the quality of phi and psi values in a protein. ''' + def __init__(self, mol, threshold=1e-3): ''' :param biobox.Molecule mol: One example frame to gain access to the topology. Mol will also be used to save a temporary pdb file that will be reloaded to create the initial iotbx Model. :param float threshold: (default: 1e-3) Threshold used to determine similarity between biobox.molecule coordinates and iotbx model coordinates. Determine that iotbx model was created successfully. 
''' + tmp_file = f'rama_tmp{random_string()}.pdb' - mol.write_pdb(tmp_file, split_struc = False)#'rama_tmp.pdb') - filename = tmp_file#'rama_tmp.pdb' + mol.write_pdb(tmp_file, split_struc=False) + filename = tmp_file self.mol = mol - self.dm = DataManager(datatypes = ['model']) + self.dm = DataManager(datatypes=['model']) self.dm.process_model_file(filename) self.model = self.dm.get_model(filename) - self.score = ramalyze(self.model.get_hierarchy()) # get score to see if this works + self.score = ramalyze(self.model.get_hierarchy()) # get score to see if this works self.shape = self.model.get_sites_cart().as_numpy_array().shape - #tests + # tests x = self.mol.coordinates[0] m = self.model.get_sites_cart().as_numpy_array() assert m.shape == x.shape @@ -38,15 +41,15 @@ def __init__(self, mol, threshold=1e-3): assert not np.any(((m-x[self.idxs])>threshold)) os.remove(tmp_file) - def get_score(self, coords, as_ratio = False): + def get_score(self, coords, as_ratio=False): ''' Given coords (corresponding to self.mol) will calculate Ramachandran scores using cctbux ramalyze module Returns the counts of number of torsion angles that fall within favored, allowed, and outlier regions and finally the total number of torsion angles analysed. :param numpy.ndarray coords: shape (N, 3) :returns: (favored, allowed, outliers, total) :rtype: tuple of ints - ''' + assert coords.shape == self.shape self.model.set_sites_cart(flex.vec3_double(coords[self.idxs].astype(np.double))) self.score = ramalyze(self.model.get_hierarchy()) @@ -60,23 +63,26 @@ def get_score(self, coords, as_ratio = False): return nf, na, no, nt - def set_global_score(score, kwargs): ''' make score a global variable This is used when initializing a multiprocessing process ''' + global worker_ramachandran_score - worker_ramachandran_score = score(**kwargs)#mol = mol, data_dir=data_dir, **kwargs) + worker_ramachandran_score = score(**kwargs) # mol = mol, data_dir=data_dir, **kwargs) + def process_ramachandran(coords, kwargs): ''' ramachandran worker Worker function for multiprocessing class ''' - return worker_ramachandran_score.get_score(coords,**kwargs) + + return worker_ramachandran_score.get_score(coords, **kwargs) -class Parallel_Ramachandran_Score(): + +class Parallel_Ramachandran_Score: ''' A multiprocessing class to get Ramachandran scores. A typical use case would looke like:: @@ -104,29 +110,24 @@ def __init__(self, mol, processes=-1): # set a number of processes as user desires, capped on number of CPUs if processes > 0: - processes = min(processes, cpu_count()) + processes = min(processes, os.cpu_count()) else: - processes = cpu_count() + processes = os.cpu_count() self.mol = deepcopy(mol) score = Ramachandran_Score ctx = get_context('spawn') self.pool = ctx.Pool(processes=processes, initializer=set_global_score, - initargs=(score, dict(mol=mol)), - ) + initargs=(score, dict(mol=mol))) self.process_function = process_ramachandran def __reduce__(self): return (self.__class__, (self.mol,)) - - def get_score(self, coords,**kwargs): + def get_score(self, coords, **kwargs): ''' :param coords: # shape (N, 3) numpy array ''' - #is copy necessary? + # is copy necessary? 
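        # note: apply_async dispatches to a worker initialised via
        # set_global_score and returns a multiprocessing AsyncResult; callers
        # retrieve the (favored, allowed, outliers, total) counts later with .get()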
return self.pool.apply_async(self.process_function, (coords.copy(), kwargs)) - - - diff --git a/src/molearn/trainers/openmm_physics_trainer.py b/src/molearn/trainers/openmm_physics_trainer.py index 8442f63..61917e2 100644 --- a/src/molearn/trainers/openmm_physics_trainer.py +++ b/src/molearn/trainers/openmm_physics_trainer.py @@ -13,7 +13,7 @@ class OpenMM_Physics_Trainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold = 1e8, clamp=False, start_physics_at=0, **kwargs): + def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, **kwargs): ''' Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy ` Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data` @@ -28,11 +28,10 @@ def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold = 1e8, cla self.start_physics_at = start_physics_at self.psf = physics_scaling_factor if clamp: - clamp_kwargs = dict(max=clamp_threshold, min = -clamp_threshold) + clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold) else: clamp_kwargs = None - self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform = 'CUDA' if self.device == torch.device('cuda') else 'Reference', atoms = self._data.atoms, **kwargs) - + self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, **kwargs) def common_physics_step(self, batch, latent): ''' @@ -45,14 +44,14 @@ def common_physics_step(self, batch, latent): alpha = torch.rand(int(len(batch)//2), 1, 1).type_as(latent) latent_interpolated = (1-alpha)*latent[:-1:2] + alpha*latent[1::2] - generated = self.autoencoder.decode(latent_interpolated)[:,:,:batch.size(2)] + generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)] self._internal['generated'] = generated energy = self.physics_loss(generated) - energy[energy.isinf()]=1e35 + energy[energy.isinf()] = 1e35 energy = torch.clamp(energy, max=1e34) energy = energy.nanmean() - return {'physics_loss':energy}#a if not energy.isinf() else torch.tensor(0.0)} + return {'physics_loss':energy} # a if not energy.isinf() else torch.tensor(0.0)} def train_step(self, batch): ''' @@ -92,10 +91,11 @@ def valid_step(self, batch): results = self.common_step(batch) results.update(self.common_physics_step(batch, self._internal['encoded'])) - #scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5) + # scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5) final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss']) results['loss'] = final_loss return results + if __name__=='__main__': pass diff --git a/src/molearn/trainers/sinkhorn_trainer.py b/src/molearn/trainers/sinkhorn_trainer.py index 2e6df73..e0cde84 100644 --- a/src/molearn/trainers/sinkhorn_trainer.py +++ b/src/molearn/trainers/sinkhorn_trainer.py @@ -1,4 +1,3 @@ -import sys import os import glob import numpy as np @@ -6,7 +5,6 @@ from molearn.loss_functions import openmm_energy from molearn.data import PDBData import json -import biobox as bb from time import time try: from geomloss import SamplesLoss @@ -17,10 +15,14 @@ import shutil from copy import deepcopy + class TrainingFailure(Exception): pass + + class Sinkhorn_Trainer(): - def __init__(self, device = None, latent_dim=2, 
log_filename = 'default_log_file.dat'): + + def __init__(self, device=None, latent_dim=2, log_filename='default_log_file.dat'): if not device: self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') else: @@ -32,28 +34,26 @@ def __init__(self, device = None, latent_dim=2, log_filename = 'default_log_file self.verbose = True self.log_filename = log_filename self.latent_dim = latent_dim - self.sinkhorn = SamplesLoss(loss = 'sinkhorn', p=2, blur=0.05) + self.sinkhorn = SamplesLoss(loss='sinkhorn', p=2, blur=0.05) self.save_time = time() - - def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold = 10000, clamp=False, start_physics_at=0, **kwargs): + def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=10000, clamp=False, start_physics_at=0, **kwargs): self.start_physics_at = start_physics_at self.psf = physics_scaling_factor if clamp: - clamp_kwargs = dict(max=clamp_threshold, min = -clamp_threshold) + clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold) else: clamp_kwargs = None - self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform = 'CUDA' if self.device == torch.device('cuda') else 'Reference', atoms = self._data.atoms, **kwargs) + self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, **kwargs) def get_network_summary(self,): + def get_parameters(trainable_only, model): return sum(p.numel() for p in model.parameters() if (p.requires_grad and trainable_only)) return dict( - decoder_trainable = get_parameters(True, self.decoder), - decoder_total = get_parameters(False, self.decoder), - ) - + decoder_trainable=get_parameters(True, self.decoder), + decoder_total=get_parameters(False, self.decoder)) def set_data(self, data,*args, **kwargs): if isinstance(data, PDBData): @@ -69,10 +69,12 @@ def set_data(self, data,*args, **kwargs): def set_dataloader(self, data, *args, **kwargs): if isinstance(data, PDBData): train_dataloader, valid_dataloader = data.get_dataloader(*args, **kwargs) + def cycle(iterable): while True: for x in iterable: yield x + self.train_iterator = iter(cycle(train_dataloader)) self.valid_iterator = iter(cycle(valid_dataloader)) else: @@ -82,7 +84,6 @@ def cycle(iterable): self.mol = data.mol self._data = data - def get_adam_opt(self, *args, **kwargs): self.opt = torch.optim.AdamW(self.decoder.parameters(), *args, **kwargs) @@ -93,8 +94,7 @@ def log(self, log_dict, verbose=None): with open(self.log_filename, 'a') as f: f.write(dump+'\n') - - def run(self, steps=100, validate_every=10, log_filename = None, checkpoint_frequency=1, checkpoint_folder='checkpoints', verbose=None): + def run(self, steps=100, validate_every=10, log_filename=None, checkpoint_frequency=1, checkpoint_folder='checkpoints', verbose=None): if log_filename is not None: self.log_filename = log_filename if verbose is not None: @@ -104,7 +104,7 @@ def run(self, steps=100, validate_every=10, log_filename = None, checkpoint_freq number_of_validations = 0 while self.step valid_logs['valid_loss']: self.checkpoint(valid_logs, checkpoint_folder) - elif number_of_validations%checkpoint_frequency==0: + elif number_of_validations % checkpoint_frequency==0: self.checkpoint(valid_logs, checkpoint_folder) time4 = time() logs = {'step':self.step, **train_logs, **valid_logs, @@ -141,8 +141,6 @@ def training_n_steps(self, steps): results[key] += train_result[key].item() return {f'train_{key}': 
@@ -141,8 +141,6 @@ def training_n_steps(self, steps):
                 results[key] += train_result[key].item()
         return {f'train_{key}': results[key]/steps for key in results.keys()}
 
-
-
     def validation_one_step(self):
         self.decoder.eval()
         result = self.valid_step()
@@ -152,7 +150,7 @@ def train_step(self):
         data = self.train_data
         z = torch.randn(data.shape[0], self.latent_dim,1).to(self.device)
         structures = self.decoder(z)[:,:,:data.shape[2]]
-        loss = self.sinkhorn(structures.reshape(structures.size(0),-1), data.reshape(data.size(0),-1))
+        loss = self.sinkhorn(structures.reshape(structures.size(0), -1), data.reshape(data.size(0), -1))
         return dict(loss=loss)
 
     def valid_step(self):
@@ -160,14 +158,14 @@
         data = self.valid_data
         z = torch.randn(data.shape[0], self.latent_dim).to(self.device)
         structures = self.decoder(z)[:,:,:data.shape[2]]
-        loss = self.sinkhorn(structures.reshape(structures.size(0),-1), data.reshape(data.size(0),-1))
+        loss = self.sinkhorn(structures.reshape(structures.size(0), -1), data.reshape(data.size(0), -1))
         energy = self.physics_loss(structures)
-        energy[energy.isinf()]=1e35
+        energy[energy.isinf()] = 1e35
         energy = torch.clamp(energy, max=1e34)
         energy = energy.nanmean()
 
         z_0 = torch.zeros_like(z).requires_grad_()
-        structures_0 = self.decoder(z_0)[:,:,:data.shape[2]]
+        structures_0 = self.decoder(z_0)[:, :, :data.shape[2]]
         inner_loss = ((structures_0-data)**2).sum(1).mean()
         encoded = -torch.autograd.grad(inner_loss, [z_0], create_graph=True, retain_graph=True)[0]
         with torch.no_grad():
@@ -176,29 +174,26 @@
             mse_loss = se.mean()
             rmsd = se.sum(1).mean().sqrt()*self.std
         if time()-self.save_time>120.:
-            coords = []
+            # coords = []
             with torch.no_grad():
                 z1_index, z2_index = self.get_extrema()
                 z1 = encoded[z1_index].unsqueeze(0)
                 z2 = encoded[z2_index].unsqueeze(0)
                 frames = 20
                 ts = torch.linspace(0,1,frames).to(self.device).unsqueeze(-1)
-                #from IPython import embed
-                #embed(headre='valid')
+                # from IPython import embed
+                # embed(headre='valid')
                 zinterp =(1-ts)*z1 + ts*z2
                 if zinterp.shape == (2,frames):
                     zinterp = zinterp.permute(1,0)
-                interp_structures = self.decoder(zinterp)[:,:,:data.shape[2]]
+                interp_structures = self.decoder(zinterp)[:, :, :data.shape[2]]
                 mol = deepcopy(self.mol)
-                mol.coordinates = (interp_structures.permute(0,2,1)*self.std).detach().cpu().numpy()
-                mol.write_pdb('sample_interp.pdb', split_struc = False)
+                mol.coordinates = (interp_structures.permute(0, 2, 1)*self.std).detach().cpu().numpy()
+                mol.write_pdb('sample_interp.pdb', split_struc=False)
                 self.save_time = time()
-
         return dict(loss=loss, physics_loss=energy,mse_loss=mse_loss, rmsd=rmsd)
 
-
-
     def update_optimiser_hyperparameters(self, **kwargs):
         for g in self.opt.param_groups:
             for key, value in kwargs.items():
@@ -252,13 +247,13 @@ def load_checkpoint(self, checkpoint_name, checkpoint_folder, load_optimiser=Tru
         step = checkpoint['step']
         self.step = step
 
-    def get_extrema(self, ):
-        #self.train_data [B, 3, N]
+    def get_extrema(self):
+        # self.train_data [B, 3, N]
         if hasattr(self, '_extrema'):
             return self._extrema
         a = self.valid_data
         B = a.shape[0]
         with torch.no_grad():
-            m = ((a.repeat_interleave(B,dim=0)-a.repeat(B,1,1))**2).sum(1).mean(-1).argmax()
-        self._extrema = (m//B, m%B)
+            m = ((a.repeat_interleave(B,dim=0)-a.repeat(B, 1, 1))**2).sum(1).mean(-1).argmax()
+        self._extrema = (m//B, m % B)
         return self._extrema
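Worth noting in valid_step above: this trainer has no encoder, so latent codes are estimated as the negative gradient of the reconstruction error with respect to z, evaluated at z = 0. A toy sketch of that gradient-encoding step, using only torch; the decoder and shapes are placeholders:

    import torch

    decoder = torch.nn.Linear(2, 30)          # toy decoder: R^2 -> 30 flattened coordinates
    data = torch.randn(8, 30)                 # target structures, flattened

    z0 = torch.zeros(8, 2, requires_grad=True)
    inner_loss = ((decoder(z0) - data)**2).sum(1).mean()
    # one gradient step away from the origin acts as a cheap surrogate encoder
    encoded = -torch.autograd.grad(inner_loss, [z0])[0]
    print(encoded.shape)                      # torch.Size([8, 2])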
diff --git a/src/molearn/trainers/torch_physics_trainer.py b/src/molearn/trainers/torch_physics_trainer.py
index 7ecad9d..1e6b142 100644
--- a/src/molearn/trainers/torch_physics_trainer.py
+++ b/src/molearn/trainers/torch_physics_trainer.py
@@ -2,6 +2,7 @@
 from molearn.loss_functions import TorchProteinEnergy
 from .trainer import Trainer
 
+
 class Torch_Physics_Trainer(Trainer):
     '''
     Torch_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step.
@@ -18,7 +19,7 @@ def prepare_physics(self, physics_scaling_factor=0.1):
         :param float physics_scaling_factor: (default: 0.1) scaling factor saved to ``self.psf`` that is used in :func:`train_step`. It will control the relative importance of mse_loss and physics_loss in training.
         '''
         self.psf = physics_scaling_factor
-        self.physics_loss = TorchProteinEnergy(self._data.dataset[0]*self.std, pdb_atom_names = self._data.get_atominfo(), device = self.device, method = 'roll')
+        self.physics_loss = TorchProteinEnergy(self._data.dataset[0]*self.std, pdb_atom_names=self._data.get_atominfo(), device=self.device, method='roll')
 
     def common_physics_step(self, batch, latent):
         '''
@@ -30,8 +31,8 @@ def common_physics_step(self, batch, latent):
         '''
         alpha = torch.rand(int(len(batch)//2), 1, 1).type_as(latent)
         latent_interpolated = (1-alpha)*latent[:-1:2] + alpha*latent[1::2]
-        generated = self.autoencoder.decode(latent_interpolated)[:,:,:batch.size(2)]
-        bond, angle, torsion =  self.physics_loss._roll_bond_angle_torsion_loss(generated*self.std)
+        generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)]
+        bond, angle, torsion = self.physics_loss._roll_bond_angle_torsion_loss(generated*self.std)
         n = len(generated)
         bond/=n
         angle/=n
@@ -39,11 +40,10 @@
         _all = torch.tensor([bond, angle, torsion])
         _all[_all.isinf()]=1e35
         total_physics = _all.nansum()
-        #total_physics = torch.nansum(torch.tensor([bond ,angle ,torsion]))
+        # total_physics = torch.nansum(torch.tensor([bond ,angle ,torsion]))
 
         return {'physics_loss':total_physics, 'bond_energy':bond, 'angle_energy':angle, 'torsion_energy':torsion}
 
-
     def train_step(self, batch):
         '''
         This method overrides :func:`Trainer.train_step` and adds an additional 'Physics_loss' term.
@@ -80,11 +80,11 @@ def valid_step(self, batch):
         '''
         results = self.common_step(batch)
         results.update(self.common_physics_step(batch, self._internal['encoded']))
-        #scale = self.psf*results['mse_loss']/(results['physics_loss']+1e-5)
+        # scale = self.psf*results['mse_loss']/(results['physics_loss']+1e-5)
         final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss'])
         results['loss'] = final_loss
         return results
 
 if __name__=='__main__':
-    pass
\ No newline at end of file
+    pass
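Both physics trainers combine the two terms in log space, final_loss = log(mse_loss) + psf*log(physics_loss), so the energy term's huge dynamic range (clamped at 1e34 above) cannot swamp the reconstruction term. A quick numeric check using only torch, with illustrative values:

    import torch

    psf = 0.1
    mse_loss = torch.tensor(0.01)
    physics_loss = torch.tensor(1e4)
    final_loss = torch.log(mse_loss) + psf*torch.log(physics_loss)
    print(final_loss.item())   # -4.605 + 0.1*9.210, approx -3.684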
diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py
index 0f10fee..950dd6c 100644
--- a/src/molearn/trainers/trainer.py
+++ b/src/molearn/trainers/trainer.py
@@ -5,14 +5,14 @@
 import time
 import torch
 from molearn.data import PDBData
-import warnings
-from decimal import Decimal
 import json
 
+
 class TrainingFailure(Exception):
     pass
 
-class Trainer():
+
+class Trainer:
     '''
     Trainer class that defines a number of useful methods for training an autoencoder.
@@ -31,9 +31,7 @@ class Trainer():
     '''
 
-
-
-    def __init__(self, device = None, log_filename = 'log_file.dat'):
+    def __init__(self, device=None, log_filename='log_file.dat'):
         '''
         :param torch.Device device: if not given, will be determined automatically based on torch.cuda.is_available()
         :param str log_filename: (default: 'default_log_filename.json') file used to log outputs to
         '''
@@ -51,7 +49,7 @@ def __init__(self, device = None, log_filename = 'log_file.dat'):
         self.log_filename = 'default_log_filename.json'
         self.scheduler_key = None
 
-    def get_network_summary(self,):
+    def get_network_summary(self):
         '''
         returns a dictionary containing information about the size of the autoencoder.
         '''
@@ -59,13 +57,12 @@ def get_parameters(trainable_only, model):
             return sum(p.numel() for p in model.parameters() if (p.requires_grad and trainable_only))
 
         return dict(
-            encoder_trainable = get_parameters(True, self.autoencoder.encoder),
-            encoder_total = get_parameters(False, self.autoencoder.encoder),
-            decoder_trainable = get_parameters(True, self.autoencoder.decoder),
-            decoder_total = get_parameters(False, self.autoencoder.decoder),
-            autoencoder_trainable = get_parameters(True, self.autoencoder),
-            autoencoder_total = get_parameters(False, self.autoencoder),
-            )
+            encoder_trainable=get_parameters(True, self.autoencoder.encoder),
+            encoder_total=get_parameters(False, self.autoencoder.encoder),
+            decoder_trainable=get_parameters(True, self.autoencoder.decoder),
+            decoder_total=get_parameters(False, self.autoencoder.decoder),
+            autoencoder_trainable=get_parameters(True, self.autoencoder),
+            autoencoder_total=get_parameters(False, self.autoencoder))
 
     def set_autoencoder(self, autoencoder, **kwargs):
         '''
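The base class is built to be subclassed: as the hunks below show, common_step caches the encoding in self._internal so that train_step/valid_step overrides can extend it, which is exactly what the physics trainers do. A minimal subclass sketch; the import path and the extra penalty term are assumptions for illustration, not part of this patch:

    import torch
    from molearn.trainers import Trainer

    class MyTrainer(Trainer):

        def valid_step(self, batch):
            results = self.common_step(batch)                 # fills self._internal['encoded']
            # hypothetical extra term penalising large latent norms
            results['latent_norm'] = self._internal['encoded'].pow(2).mean()
            results['loss'] = results['mse_loss'] + 0.01*results['latent_norm']
            return results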
@@ -106,8 +103,7 @@ def set_data(self, data, **kwargs):
             self.mol = data.mol
         self._data = data
 
-
-    def prepare_optimiser(self, lr = 1e-3, weight_decay = 0.0001, **optimiser_kwargs):
+    def prepare_optimiser(self, lr=1e-3, weight_decay=0.0001, **optimiser_kwargs):
         '''
         The default optimiser is ``AdamW`` and is saved in ``self.optimiser``.
         With no optional arguments this function is the same as doing:
@@ -117,7 +113,7 @@
         :param float weight_decay: (default: 0.0001) optimiser weight_decay
         :param \*\*optimiser_kwargs: other kwargs that are passed onto AdamW
         '''
-        self.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=lr, weight_decay = weight_decay, **optimiser_kwargs)
+        self.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=lr, weight_decay=weight_decay, **optimiser_kwargs)
 
     def log(self, log_dict, verbose=None):
         '''
@@ -142,7 +138,7 @@ def scheduler_step(self, logs):
         '''
         pass
 
-    def run(self, max_epochs=100, log_filename = None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
+    def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
         '''
         Calls the following in a loop:
 
@@ -183,20 +179,20 @@ def run(self, max_epochs=100, log_filename = None, log_folder=None, checkpoint_f
                     self.scheduler_step(logs)
                     if self.best is None or self.best > logs['valid_loss']:
                         self.checkpoint(epoch, logs, checkpoint_folder)
-                    elif epoch%checkpoint_frequency==0:
+                    elif epoch % checkpoint_frequency == 0:
                         self.checkpoint(epoch, logs, checkpoint_folder)
                     time4 = time.time()
-                    logs.update(epoch = epoch,
+                    logs.update(epoch=epoch,
                                 train_seconds=time2-time1,
                                 valid_seconds=time3-time2,
-                                checkpoint_seconds= time4-time3,
+                                checkpoint_seconds=time4-time3,
                                 total_seconds=time4-time1)
                     self.log(logs)
                     if np.isnan(logs['valid_loss']) or np.isnan(logs['train_loss']):
                         raise TrainingFailure('nan received, failing')
                     self.epoch+= 1
             except TrainingFailure:
-                if attempt==(allow_n_failures-1):
+                if attempt == (allow_n_failures-1):
                     failure_message = f'Training Failure due to Nan in attempt {attempt}, end now/n'
                     self.log({'Failure':failure_message})
                     raise TrainingFailure('nan received, failing')
@@ -207,7 +203,6 @@
             else:
                 break
 
-
     def train_epoch(self,epoch):
         '''
         Train one epoch. Called once an epoch from :func:`trainer.run`
@@ -270,10 +265,9 @@ def common_step(self, batch):
         self._internal['encoded'] = encoded
         decoded = self.autoencoder.decode(encoded)[:,:,:batch.size(2)]
         self._internal['decoded'] = decoded
-        return dict(mse_loss = ((batch-decoded)**2).mean())
+        return dict(mse_loss=((batch-decoded)**2).mean())
 
-
-    def valid_epoch(self,epoch):
+    def valid_epoch(self, epoch):
         '''
         Called once an epoch from :func:`trainer.run` within a no_grad context.
         This method performs the following functions:
@@ -313,7 +307,7 @@ def valid_step(self, batch):
         results['loss'] = results['mse_loss']
         return results
 
-    def learning_rate_sweep(self, max_lr=100, min_lr=1e-5, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep',train_on='mse_loss', save=['loss', 'mse_loss']):
+    def learning_rate_sweep(self, max_lr=100, min_lr=1e-5, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep', train_on='mse_loss', save=['loss', 'mse_loss']):
         '''
         Deprecated method.
         Performs a sweep of learning rate between ``max_lr`` and ``min_lr`` over ``number_of_iterations``.
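self.log receives the merged logs dict once per epoch. Its body is not shown in these hunks; assuming it appends one JSON object per line, as the sinkhorn trainer's log method above does, a training curve can be recovered with a sketch like this (the filename is the default from __init__):

    import json

    entries = []
    with open('log_file.dat') as f:
        for line in f:
            entries.append(json.loads(line))

    train_loss = [e['train_loss'] for e in entries if 'train_loss' in e]
    valid_loss = [e['valid_loss'] for e in entries if 'valid_loss' in e]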
@@ -328,10 +322,12 @@ def learning_rate_sweep(self, max_lr=100, min_lr=1e-5, number_of_iterations=1000
         :rtype: numpy.ndarray
         '''
         self.autoencoder.train()
+
         def cycle(iterable):
             while True:
                 for i in iterable:
                     yield i
+
         init_loss = 0.0
         values = []
         data = iter(cycle(self.train_dataloader))
@@ -342,14 +338,13 @@ def cycle(iterable):
             self.optimiser.zero_grad()
             result = self.train_step(batch)
 
-            #result['loss']/=len(batch)
+            # result['loss']/=len(batch)
             result[train_on].backward()
             self.optimiser.step()
             values.append((lr,)+tuple((result[name].item() for name in save)))
-            #print(i,lr, result['loss'].item())
             if i==0:
                 init_loss = result[train_on].item()
-            #if result[train_on].item()>1e6*init_loss:
+            # if result[train_on].item()>1e6*init_loss:
             #    break
         values = np.array(values)
         print('min value ', values[np.nanargmin(values[:,1])])
@@ -397,7 +392,7 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
             self.best_epoch = epoch
             self.best = valid_loss
 
-    def load_checkpoint(self, checkpoint_name ='best', checkpoint_folder = '', load_optimiser=True):
+    def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_optimiser=True):
         '''
         Load checkpoint.
 
@@ -417,7 +412,7 @@ def load_checkpoint(self, checkpoint_name ='best', checkpoint_folder = '', load_
             _name = f'{checkpoint_folder}/last.ckpt'
         else:
             _name = f'{checkpoint_folder}/{checkpoint_name}'
-        checkpoint = torch.load(_name, map_location = self.device)
+        checkpoint = torch.load(_name, map_location=self.device)
         if not hasattr(self, 'autoencoder'):
             raise NotImplementedError('self.autoencoder does not exist, I have no way of knowing what network you want to load checkpoint weights into yet, please set the network first')
 
@@ -429,5 +424,6 @@ def load_checkpoint(self, checkpoint_name ='best', checkpoint_folder = '', load_
         epoch = checkpoint['epoch']
         self.epoch = epoch+1
 
+
 if __name__=='__main__':
     pass
diff --git a/src/molearn/utils.py b/src/molearn/utils.py
index 1e9feae..8c17f5f 100644
--- a/src/molearn/utils.py
+++ b/src/molearn/utils.py
@@ -1,17 +1,19 @@
-import os, sys
+import os
+import sys
 import numpy as np
 import torch
 import random
 import string
 
+
 def random_string(length=32):
     '''
     generate a random string of arbitrary characters. Useful to generate temporary file names.
 
     :param length: length of random string
     '''
-    return ''.join([random.choice(string.ascii_letters)
-                    for n in range(length)])
+    return ''.join(random.choice(string.ascii_letters)
+                   for n in range(length))
 
 def as_numpy(tensor):
@@ -22,33 +24,14 @@ def as_numpy(tensor):
     else:
         return np.array(tensor)
 
-
-def cpu_count():
-    """ detect the number of available CPU """
-    if hasattr(os, "sysconf"):
-        if "SC_NPROCESSORS_ONLN" in list(os.sysconf_names):
-            # Linux & Unix
-            ncpus = os.sysconf("SC_NPROCESSORS_ONLN")
-            if isinstance(ncpus, int) and ncpus > 0:
-                return ncpus
-        else:
-            # OSX
-            return int(os.popen2("sysctl -n hw.ncpu")[1].read())
-    # Windows
-    if "NUMBER_OF_PROCESSORS" in list(os.environ):
-        ncpus = int(os.environ["NUMBER_OF_PROCESSORS"]);
-        if ncpus > 0:
-            return ncpus
-
-    return 1
+class ShutUp:
 
-
-class ShutUp(object):
     def __enter__(self):
         self._stdout = sys.stdout
         sys.stdout = open(os.devnull, 'w')
 
     def __exit__(self, *args):
         sys.stdout.close()
-        sys.stdout = self._stdout
\ No newline at end of file
+        sys.stdout = self._stdout
+
\ No newline at end of file
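The retained ShutUp helper is used as a context manager to silence noisy third-party output, and the deleted cpu_count helper is covered by the standard library's os.cpu_count(). A short usage sketch, assuming the package import path in this patch:

    import os
    from molearn.utils import ShutUp

    with ShutUp():
        print('this goes to os.devnull')       # suppressed
    print('stdout is restored here')

    n_cpus = os.cpu_count()                    # stdlib replacement for the removed helper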