Skip to content

Commit

Permalink
formatting and small cleanup
Browse files Browse the repository at this point in the history
- first pass fixing indentations, spaces, and syntax details (guided by flake8)
- using os.cpu_count() instead of our custom function
  • Loading branch information
degiacom committed Jul 20, 2023
1 parent 4a9bea9 commit fafeab7
Show file tree
Hide file tree
Showing 17 changed files with 724 additions and 776 deletions.
67 changes: 23 additions & 44 deletions src/molearn/analysis/GUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..utils import as_numpy


class MolearnGUI(object):
class MolearnGUI:
'''
This class produces an interactive visualisation for data stored in a
:func:`MolearnAnalysis <molearn.analysis.MolearnAnalysis>` object,
Expand All @@ -49,11 +49,10 @@ def __init__(self, MA=None):
else:
self.MA = MA

self.waypoints = [] # collection of all saved waypoints
self.samples = [] # collection of all calculated sampling points
self.waypoints = [] # collection of all saved waypoints
self.samples = [] # collection of all calculated sampling points

self.run()


def update_trails(self):
'''
Expand All @@ -76,7 +75,6 @@ def update_trails(self):
self.latent.data[2].y = self.samples[:, 1]

self.latent.update()


def on_click(self, trace, points, selector):
'''
Expand All @@ -97,14 +95,13 @@ def on_click(self, trace, points, selector):
# update textbox (triggering update of 3D representation)
try:
pt = self.waypoints.flatten().round(decimals=4).astype(str)
#pt = np.array([self.latent.data[3].x, self.latent.data[3].y]).T.flatten().round(decimals=4).astype(str)
# pt = np.array([self.latent.data[3].x, self.latent.data[3].y]).T.flatten().round(decimals=4).astype(str)
self.mybox.value = " ".join(pt)
except Exception:
return

self.update_trails()


def get_samples(self, mybox, samplebox, path):
'''
provide a trail of point between list of waypoints, either connected
Expand All @@ -120,17 +117,17 @@ def get_samples(self, mybox, samplebox, path):
crd = np.array(mybox.split()).astype(float)
crd = crd.reshape((int(len(crd)/2), 2))
except Exception:
raise Exception("Cannot define sampling points")
return
raise Exception("Cannot define sampling points")
return

if use_path:
# connect points via A*
try:
landscape = self.latent.data[0].z
crd = get_path_aggregate(crd, landscape.T, self.MA.xvals, self.MA.yvals)
except Exception as e:
raise Exception(f"Cannot define sampling points: path finding failed. {e})")
return
raise Exception(f"Cannot define sampling points: path finding failed. {e})")
return

else:
# connect points via straight line
Expand All @@ -141,7 +138,6 @@ def get_samples(self, mybox, samplebox, path):
return

return crd


def interact_3D(self, mybox, samplebox, path):
'''
Expand All @@ -152,7 +148,7 @@ def interact_3D(self, mybox, samplebox, path):
crd = self.get_samples(mybox, samplebox, path)
self.samples = crd.copy()
crd = crd.reshape((1, len(crd), 2))
except:
except Exception:
self.button_pdb.disabled = True
return

Expand All @@ -169,12 +165,11 @@ def interact_3D(self, mybox, samplebox, path):
self.mymol.load_new(gen)
view = nv.show_mdanalysis(self.mymol)
view.add_representation("spacefill")
#view.add_representation("cartoon")
# view.add_representation("cartoon")
display.display(view)

self.button_pdb.disabled = False


def drop_background_event(self, change):
'''
control colouring style of latent space surface
Expand All @@ -186,7 +181,7 @@ def drop_background_event(self, change):
mykey = change.new

try:
data = self.MA.surfaces[mykey]
data = self.MA.surfaces[mykey]
except Exception as e:
print(f"{e}")
return
Expand All @@ -204,7 +199,7 @@ def drop_background_event(self, change):
self.latent.data[0].zmax = np.max(data)
self.block0.children[1].min = np.min(data)
self.block0.children[1].max = np.max(data)
except:
except Exception:
self.latent.data[0].zmax = np.max(data)
self.latent.data[0].zmin = np.min(data)
self.block0.children[1].max = np.max(data)
Expand All @@ -214,7 +209,6 @@ def drop_background_event(self, change):

self.update_trails()


def drop_dataset_event(self, change):
'''
control which dataset is displayed
Expand All @@ -226,7 +220,7 @@ def drop_dataset_event(self, change):

else:
try:
data = as_numpy(self.MA.get_encoded(change.new).squeeze(2))
data = as_numpy(self.MA.get_encoded(change.new).squeeze(2))
except Exception as e:
print(f"{e}")
return
Expand All @@ -238,7 +232,6 @@ def drop_dataset_event(self, change):

self.latent.update()


def drop_path_event(self, change):
'''
control way paths are looked for
Expand All @@ -251,7 +244,6 @@ def drop_path_event(self, change):

self.update_trails()


def range_slider_event(self, change):
'''
update surface colouring upon manipulation of range slider
Expand All @@ -261,7 +253,6 @@ def range_slider_event(self, change):
self.latent.data[0].zmax = change.new[1]
self.latent.update()


def trail_update_event(self, change):
'''
update trails (waypoints and way they are connected)
Expand All @@ -270,15 +261,14 @@ def trail_update_event(self, change):
try:
crd = np.array(self.mybox.value.split()).astype(float)
crd = crd.reshape((int(len(crd)/2), 2))
except:
except Exception:
self.button_pdb.disabled = False
return

self.waypoints = crd.copy()

self.update_trails()


def button_pdb_event(self, check):
'''
save PDB file corresponding to the interpolation shown in the 3D view
Expand Down Expand Up @@ -307,7 +297,6 @@ def button_pdb_event(self, check):
for ts in self.mymol.trajectory:
W.write(protein)


def button_save_state_event(self, check):
'''
save class state
Expand All @@ -321,8 +310,7 @@ def button_save_state_event(self, check):
if fname == "":
return

pickle.dump([self.MA, self.waypoints], open( fname, "wb" ) )

pickle.dump([self.MA, self.waypoints], open(fname, "wb"))

def button_load_state_event(self, check):
'''
Expand All @@ -338,7 +326,7 @@ def button_load_state_event(self, check):
return

try:
self.MA, self.waypoints = pickle.load( open( fname, "rb" ) )
self.MA, self.waypoints = pickle.load(open(fname, "rb"))
self.run()
except Exception as e:
raise Exception(f"Cannot load state file. {e}")
Expand All @@ -349,7 +337,7 @@ def run(self):

# create an MDAnalysis instance of input protein (for viewing purposes)
if hasattr(self.MA, "mol"):
self.MA.mol.write_pdb("tmp.pdb", conformations=[0], split_struc = False)
self.MA.mol.write_pdb("tmp.pdb", conformations=[0], split_struc=False)
self.mymol = mda.Universe('tmp.pdb')

### MENU ITEMS ###
Expand All @@ -376,7 +364,6 @@ def run(self):

self.drop_background.observe(self.drop_background_event, names='value')


# dataset selector dropdown menu
options2 = ["none"]
if self.MA is not None:
Expand Down Expand Up @@ -406,7 +393,6 @@ def run(self):

self.drop_path.observe(self.drop_path_event, names='value')


# text box holding current coordinates
self.mybox = widgets.Textarea(placeholder='coordinates',
description='crds:',
Expand All @@ -421,31 +407,27 @@ def run(self):

self.samplebox.observe(self.trail_update_event, names='value')


# button to save PDB file
self.button_pdb = widgets.Button(
description='Save PDB',
disabled=True, layout=Layout(flex='1 1 0%', width='auto'))

self.button_pdb.on_click(self.button_pdb_event)


# button to save state file
self.button_save_state = widgets.Button(
description= 'Save state',
description='Save state',
disabled=False, layout=Layout(flex='1 1 0%', width='auto'))

self.button_save_state.on_click(self.button_save_state_event)


# button to load state file
self.button_load_state = widgets.Button(
description= 'Load state',
description='Load state',
disabled=False, layout=Layout(flex='1 1 0%', width='auto'))

self.button_load_state.on_click(self.button_load_state_event)


# latent space range slider
self.range_slider = widgets.FloatRangeSlider(
description='cmap range:',
Expand All @@ -463,8 +445,7 @@ def run(self):

if self.waypoints == []:
self.button_pdb.disabled = True



### LATENT SPACE REPRESENTATION ###

# surface
Expand Down Expand Up @@ -502,7 +483,7 @@ def run(self):

# path
plot3 = go.Scatter(x=np.array([]), y=np.array([]),
showlegend=False, opacity=0.9, mode = 'lines+markers',
showlegend=False, opacity=0.9, mode='lines+markers',
marker=dict(color='red', size=4))

self.latent = go.FigureWidget([plot1, plot2, plot3])
Expand All @@ -521,7 +502,7 @@ def run(self):
try:
self.range_slider.min = scmin
self.range_slider.max = scmax
except:
except Exception:
self.range_slider.max = scmax
self.range_slider.min = scmin

Expand All @@ -530,8 +511,7 @@ def run(self):

# 3D protein representation (triggered by update of textbox, sampling box, or pathfinding method)
self.protein = widgets.interactive_output(self.interact_3D, {'mybox': self.mybox, 'samplebox': self.samplebox, 'path': self.drop_path})



### WIDGETS ARRANGEMENT ###

self.block0 = widgets.VBox([self.drop_dataset, self.range_slider,
Expand All @@ -555,4 +535,3 @@ def run(self):

display.clear_output(wait=True)
display.display(self.scene)

Loading

0 comments on commit fafeab7

Please sign in to comment.