Skip to content

Commit

Permalink
enh: first support for py3dmol
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Nov 19, 2022
1 parent f81b00c commit 597399b
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 2 deletions.
2 changes: 2 additions & 0 deletions sisl/viz/nodes/canvas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from .plotly import PlotlyCanvas
from .matplotlib import MatplotlibCanvas
from .py3dmol import Py3DmolCanvas
#from .blender import BlenderCanvas

from ..dispatcher import Dispatcher
Expand All @@ -11,4 +12,5 @@ class Canvas(Dispatcher):

Canvas.register("plotly", PlotlyCanvas)
Canvas.register("matplotlib", MatplotlibCanvas)
Canvas.register("py3dmol", Py3DmolCanvas)
#Canvas.register("blender", BlenderCanvas)
2 changes: 1 addition & 1 deletion sisl/viz/nodes/canvas/blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, colle
}

for k, v in style.items():
if (not isinstance(v, collections.abc.Sequence)) or isinstance(v, str):
if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str):
style[k] = itertools.repeat(v)

ball = template_ball
Expand Down
9 changes: 8 additions & 1 deletion sisl/viz/nodes/canvas/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,13 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3,

conebase_xyz = xyz + (1 - arrowhead_scale) * dxyz

rows_cols = {}
if row is not None:
rows_cols['rows'] = [row, row]
if col is not None:
rows_cols['cols'] = [col, col]


self.figure.add_traces([{
"x": arrows_coords[:, 0],
"y": arrows_coords[:, 1],
Expand Down Expand Up @@ -598,7 +605,7 @@ def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_angle=20, arrowhead_scale=0.3,
"legendgroup": name,
"name": name,
"showlegend": True,
}], rows=[row, row], cols=[col, col])
}], **rows_cols)

def draw_heatmap(self, values, x=None, y=None, name=None, zsmooth=False, coloraxis=None, row=None, col=None):

Expand Down
145 changes: 145 additions & 0 deletions sisl/viz/nodes/canvas/py3dmol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
import collections.abc
import itertools

from .canvas import CanvasNode

import numpy as np

import py3Dmol

class Py3DmolCanvas(CanvasNode):
"""Generic canvas for the py3Dmol framework"""

def _init_figure(self, *args, **kwargs):
self.figure = py3Dmol.view()

def draw_line(self, x, y, name="", line={}, marker={}, text=None, row=None, col=None, **kwargs):
z = np.full_like(x, 0)
# x = self._2D_scale[0] * x
# y = self._2D_scale[1] * y
return self.draw_line_3D(x, y, z, name=name, line=line, marker=marker, text=text, row=row, col=col, **kwargs)

def draw_scatter(self, x, y, name=None, marker={}, text=None, row=None, col=None, **kwargs):
z = np.full_like(x, 0)
# x = self._2D_scale[0] * x
# y = self._2D_scale[1] * y
return self.draw_scatter_3D(x, y, z, name=name, marker=marker, text=text, row=row, col=col, **kwargs)

def draw_line_3D(self, x, y, z, line={}, name="", collection=None, frame=None, **kwargs):
"""Draws a line."""

xyz = np.array([x, y, z], dtype=float).T

# To be compatible with other frameworks such as plotly and matplotlib,
# we allow x, y and z to contain None values that indicate discontinuities
# E.g.: x=[0, 1, None, 2, 3] means we should draw a line from 0 to 1 and another
# from 2 to 3.
# Here, we get the breakpoints (i.e. indices where there is a None). We add
# -1 and None at the sides to facilitate iterating.
breakpoint_indices = [-1, *np.where(np.isnan(xyz).any(axis=1))[0], None]

# Now loop through all segments using the known breakpoints
for start_i, end_i in zip(breakpoint_indices, breakpoint_indices[1:]):
# Get the coordinates of the segment
segment_xyz = xyz[start_i+1: end_i]

# If there is nothing to draw, go to next segment
if len(segment_xyz) == 0:
continue

points = [{"x": x, "y": y, "z": z} for x, y, z in segment_xyz]

# If there's only two points, py3dmol doesn't display the curve,
# probably because it can not smooth it.
if len(points) == 2:
points.append(points[-1])

self.figure.addCurve(dict(
points=points,
radius=line.get("width", 0.1),
color=line.get("color"),
opacity=line.get('opacity', 1.) or 1.,
smooth=1
))

return self

def draw_balls_3D(self, x, y, z, name=None, marker={}, row=None, col=None, collection=None, frame=None, **kwargs):
style = {
"color": marker.get("color", "gray"),
"opacity": marker.get("opacity", 1.),
"size": marker.get("size", 1.),
}

for k, v in style.items():
if (not isinstance(v, (collections.abc.Sequence, np.ndarray))) or isinstance(v, str):
style[k] = itertools.repeat(v)

for i, (x_i, y_i, z_i, color, opacity, size) in enumerate(zip(x, y, z, style["color"], style["opacity"], style["size"])):
self.figure.addSphere(dict(
center={"x": float(x_i), "y": float(y_i), "z": float(z_i)}, radius=size, color=color, opacity=opacity,
quality=5., # This does not work, but sphere quality is really bad by default
))

draw_scatter_3D = draw_balls_3D

def draw_arrows_3D(self, x, y, z, dxyz, arrowhead_scale=0.3, arrowhead_angle=15, row=None, col=None, line={},**kwargs):
"""Draws multiple arrows using the generic draw_line method.
Parameters
-----------
xy: np.ndarray of shape (n_arrows, 2)
the positions where the atoms start.
dxy: np.ndarray of shape (n_arrows, 2)
the arrow vector.
arrow_head_scale: float, optional
how big is the arrow head in comparison to the arrow vector.
arrowhead_angle: angle
the angle that the arrow head forms with the direction of the arrow (in degrees).
row: int, optional
If the figure contains subplots, the row where to draw.
col: int, optional
If the figure contains subplots, the column where to draw.
"""
# Make sure we are dealing with numpy arrays
xyz = np.array([x, y, z]).T
dxyz = np.array(dxyz)

for (x, y, z), (dx, dy, dz) in zip(xyz, dxyz):

self.figure.addArrow(dict(
start={"x": x, "y": y, "z": z},
end={"x": x + dx, "y": y + dy, "z": z + dz},
radius=line.get("width", 0.1),
color=line.get("color"),
opacity=line.get("opacity", 1.),
radiusRatio=2,
mid=(1 - arrowhead_scale),
))

def draw_mesh_3D(self, vertices, faces, color=None, opacity=None, name="Mesh", wireframe=False, row=None, col=None, **kwargs):

def vec_to_dict(a, labels="xyz"):
return dict(zip(labels,a))

self.figure.addCustom(dict(
vertexArr=[vec_to_dict(v) for v in vertices.astype(float)],
faceArr=[int(x) for f in faces for x in f],
color=color,
opacity=float(opacity or 1.),
wireframe=wireframe
))

def set_axis(self, *args, **kwargs):
"""There are no axes titles and these kind of things in py3dmol.
At least for now, we might implement it later."""

def set_axes_equal(self, *args, **kwargs):
"""Axes are always "equal" in py3dmol, so we do nothing here"""

def show(self, *args, **kwargs):
self.figure.zoomTo()
return self.figure.show()

0 comments on commit 597399b

Please sign in to comment.