diff --git a/README.md b/README.md
index b565e9d..195c617 100644
--- a/README.md
+++ b/README.md
@@ -2,19 +2,39 @@
### Visual representation of the branch-and-cut tree of SCIP using spatial dissimilarities of LP solutions -- [Interactive Example](http://www.zib.de/miltenberger/treed-showcase.html)
-[![Example](res/PlotlyTree.gif)](https://plot.ly/~mattmilten/103/)
-
-## Usage:
-- run `TreeD.py` to get usage information
-
-## Dependencies:
-- [PySCIPOpt](https://github.com/SCIP-Interfaces/PySCIPOpt) to solve the instance and generate the necessary tree data
-- [Plot.ly](https://plot.ly/) to draw the 3D visualization
+[![Example](res/treed-example.png)](https://plot.ly/~mattmilten/103/)
+## Installation
+
+```
+python -m pip install treed
+```
+
+## Usage
+- run Python script `bin/treed` (will be installed into your PATH on Linux/macOS when using `pip install treed`) to get usage information or use this code snippet in a Jupyter notebook:
+
+```
+from treed import TreeD
+
+treed = TreeD(
+ probpath="model.mps",
+ nodelimit=20,
+ transformation='mds',
+ showcuts=True
+)
+
+treed.solve()
+fig = treed.draw()
+fig.show(renderer='notebook')
+```
+
+## Dependencies
+- [PySCIPOpt](https://github.com/scipopt/PySCIPOpt) to solve the instance and generate the necessary tree data
+- [Plotly](https://plot.ly/) to draw the 3D visualization
- [pandas](https://pandas.pydata.org/) to organize the collected data
- [sklearn](http://scikit-learn.org/stable/) for multi-dimensional scaling
-- [pysal](https://github.com/pysal) to compute statistics based on spatial (dis)similarity
+- [pysal](https://github.com/pysal) to compute statistics based on spatial (dis)similarity; this is optional
-## Export to [Amira](https://amira.zib.de/):
+## Export to [Amira](https://amira.zib.de/)
- run `AmiraTreeD.py` to get usage information.
`AmiraTreeD.py` generates the '.am' data files to be loaded by Amira software to draw the tree using LineRaycast.
diff --git a/bin/treed b/bin/treed
index 5076d1d..8190381 100644
--- a/bin/treed
+++ b/bin/treed
@@ -36,5 +36,5 @@ treed = TreeD(
nodelimit=args.nodelimit,
)
-treed.main()
+treed.solve()
treed.draw()
diff --git a/res/treed-example.png b/res/treed-example.png
new file mode 100644
index 0000000..9a0919b
Binary files /dev/null and b/res/treed-example.png differ
diff --git a/setup.py b/setup.py
index f8fbd02..c4e744a 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
setup(
name="treed",
- version="0.0.2",
+ version="1.0.0",
author="Matthias Miltenberger",
author_email="matthias.miltenberger@gmail.com",
description="3D Visualization of Branch-and-Cut Trees using PySCIPOpt",
diff --git a/src/treed/treed.py b/src/treed/treed.py
index 4983c54..31cb6dd 100644
--- a/src/treed/treed.py
+++ b/src/treed/treed.py
@@ -122,10 +122,16 @@ def transform(self):
df = self.df["LPsol"].apply(pd.Series).fillna(value=0)
if self.transformation == "tsne":
mf = manifold.TSNE(n_components=2)
+ elif self.transformation == "lle":
+ mf = manifold.LocallyLinearEmbedding(n_components=2)
+ elif self.transformation == "ltsa":
+ mf = manifold.LocallyLinearEmbedding(n_components=2, method="ltsa")
+ elif self.transformation == "spectral":
+ mf = manifold.SpectralEmbedding(n_components=2)
else:
mf = manifold.MDS(n_components=2)
self.xy = mf.fit_transform(df)
- self.stress = mf.stress_
+ # self.stress = mf.stress_ # no available with all transformations
self.df["x"] = self.xy[:, 0]
self.df["y"] = self.xy[:, 1]
@@ -259,6 +265,60 @@ def _create_nodes_and_projections(self):
)
return node_object, proj_object
+ def _create_nodes_frames(self):
+ colorbar = go.scatter3d.marker.ColorBar(title="", thickness=10, x=0)
+ marker = go.scatter3d.Marker(
+ symbol=self.df["symbol"],
+ size=self.nodesize,
+ color=self.df["age"],
+ colorscale=self.colorscale,
+ colorbar=colorbar,
+ )
+
+ frames = []
+ sliders_dict = dict(
+ active=0,
+ yanchor="top",
+ xanchor="left",
+ currentvalue={"prefix": "Age:", "visible": True, "xanchor": "right",},
+ len=0.9,
+ x=0.05,
+ y=0.1,
+ steps=[],
+ )
+
+ for a in self.df["age"]:
+ adf = self.df[self.df["age"] <= a]
+ node_object = go.Scatter3d(
+ x=adf["x"],
+ y=adf["y"],
+ z=adf["objval"],
+ mode="markers+text",
+ marker=marker,
+ hovertext=adf["number"],
+ hovertemplate="LP obj: %{z}
node number: %{hovertext}
%{marker.color}",
+ hoverinfo="z+text+name",
+ opacity=0.7,
+ name="LP solutions",
+ )
+ frames.append(go.Frame(data=node_object, name=str(a)))
+
+ slider_step = {
+ "args": [
+ [a],
+ {
+ "frame": {"redraw": True, "restyle": False},
+ "fromcurrent": True,
+ "mode": "immediate",
+ },
+ ],
+ "label": a,
+ "method": "animate",
+ }
+ sliders_dict["steps"].append(slider_step)
+
+ return frames, sliders_dict
+
def draw(self):
"""Draw the tree, depending on the mode"""
@@ -266,6 +326,8 @@ def draw(self):
nodes, nodeprojs = self._create_nodes_and_projections()
+ frames, sliders = self._create_nodes_frames()
+
edges = go.Scatter3d(
x=self.Xe,
y=self.Ye,
@@ -324,45 +386,104 @@ def draw(self):
scene=scene,
)
- updatemenus = list(
+ layout["updatemenus"] = list(
[
dict(
buttons=list(
[
dict(
- args=["marker.color", [self.df["age"]]],
label="Node Age",
method="restyle",
+ args=[
+ {
+ "marker.color": [self.df["age"]],
+ "marker.cauto": min(self.df["age"]),
+ "marker.cmax": max(self.df["age"]),
+ }
+ ],
),
dict(
- args=["marker.color", [self.df["depth"]]],
label="Tree Depth",
method="restyle",
+ args=[
+ {
+ "marker.color": [self.df["depth"]],
+ "marker.cauto": min(self.df["depth"]),
+ "marker.cmax": max(self.df["depth"]),
+ }
+ ],
),
dict(
- args=["marker.color", [self.df["condition"]]],
label="LP Condition (log 10)",
method="restyle",
+ args=[
+ {
+ "marker.color": [self.df["condition"]],
+ "marker.cmin": 1,
+ "marker.cmax": 20,
+ }
+ ],
),
dict(
- args=["marker.color", [self.df["iterations"]]],
label="LP Iterations",
method="restyle",
+ args=[
+ {
+ "marker.color": [self.df["iterations"]],
+ "marker.cauto": min(self.df["iterations"]),
+ "marker.cmax": max(self.df["iterations"]),
+ }
+ ],
),
]
),
direction="down",
showactive=True,
type="buttons",
- # x = 1.2,
- # y = 0.6,
+ ),
+ dict(
+ buttons=list(
+ [
+ dict(
+ label="▶",
+ method="animate",
+ args=[
+ None,
+ {
+ "frame": {
+ "duration": 50,
+ "redraw": True,
+ },
+ "fromcurrent": True,
+ },
+ ],
+ args2=[
+ [None],
+ {
+ "frame": {"duration": 0, "redraw": False},
+ "mode": "immediate",
+ "transition": {"duration": 0},
+ },
+ ],
+ )
+ ]
+ ),
+ direction="left",
+ yanchor="top",
+ xanchor="right",
+ showactive=True,
+ type="buttons",
+ x=0,
+ y=0,
),
]
)
- layout["updatemenus"] = updatemenus
+ layout["sliders"] = [sliders]
- self.fig = go.Figure(data=[nodes, nodeprojs, edges, optval], layout=layout)
+ self.fig = go.Figure(
+ data=[nodes, nodeprojs, edges, optval], layout=layout, frames=frames,
+ )
self.fig.write_html(file=filename, include_plotlyjs=self.include_plotlyjs)
@@ -373,7 +494,7 @@ def draw(self):
return self.fig
- def main(self):
+ def solve(self):
"""Solve the instance and collect and generate the tree data"""
self.nodelist = []