Skip to content

Commit

Permalink
Fix kwargs edge case, clean code, and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinchern committed Nov 21, 2023
1 parent 7acafe3 commit 4fe8ffc
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 47 deletions.
76 changes: 29 additions & 47 deletions dwave_networkx/drawing/qubit_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,61 +131,44 @@ def node_color(v):

# since we're applying the colormap here, matplotlib throws warnings if
# we provide these arguments and it doesn't use them.
cmap = kwargs.pop('cmap', plt.get_cmap('coolwarm'))
vmin = kwargs.pop('vmin', -1 * vmag)
vmax = kwargs.pop('vmax', vmag)
cmap = kwargs.pop('cmap', None) or plt.get_cmap('coolwarm')
vmin = kwargs.pop('vmin', None) or -1 * vmag
vmax = kwargs.pop('vmax', None) or vmag

edge_cmap = kwargs.pop('edge_cmap', plt.get_cmap('coolwarm'))
edge_vmin = kwargs.pop('edge_vmin', -1 * vmag)
edge_vmax = kwargs.pop('edge_vmax', vmag)
edge_cmap = kwargs.pop('edge_cmap', None) or plt.get_cmap('coolwarm')
edge_vmin = kwargs.pop('edge_vmin', None) or -1 * vmag
edge_vmax = kwargs.pop('edge_vmax', None) or vmag

if linear_biases and quadratic_biases:
global_vmin = min(edge_vmin, vmin)
global_vmax = max(edge_vmax, vmax)
final_vmin = min(edge_vmin, vmin)
final_vmax = max(edge_vmax, vmax)
mapper = cmap

if midpoint is None:
midpoint = (global_vmax + global_vmin) / 2.0
norm_map = mpl.colors.TwoSlopeNorm(midpoint, vmin=global_vmin, vmax=global_vmax)

node_color = [cmap(norm_map(node)) for node in node_color]
edge_color = [cmap(norm_map(edge)) for edge in edge_color]
mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm_map, orientation='vertical')

# if the biases are provided, then add a legend explaining the color map
elif linear_biases:
if midpoint is None:
midpoint = (vmax + vmin) / 2.0
norm_map = mpl.colors.TwoSlopeNorm(midpoint, vmin=vmin, vmax=vmax)
node_color = [cmap(norm_map(node)) for node in node_color]
mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm_map, orientation='vertical')
final_vmin = vmin
final_vmax = vmax
mapper = cmap

elif quadratic_biases:
if midpoint is None:
midpoint = (edge_vmax + edge_vmin) / 2.0
norm_map = mpl.colors.TwoSlopeNorm(midpoint, vmin=edge_vmin, vmax=edge_vmax)
edge_color = [edge_cmap(norm_map(edge)) for edge in edge_color]
mpl.colorbar.ColorbarBase(cax, cmap=edge_cmap, norm=norm_map, orientation='vertical')
final_vmin = edge_vmin
final_vmax = edge_vmax
mapper = edge_cmap

kwargs['edge_color'] = edge_color
kwargs['node_color'] = node_color
midpoint = midpoint or (final_vmax + final_vmin) / 2.0
norm_map = mpl.colors.TwoSlopeNorm(midpoint, vmin=final_vmin, vmax=final_vmax)
mpl.colorbar.ColorbarBase(cax, cmap=mapper, norm=norm_map, orientation='vertical')
kwargs['node_color'] = [mapper(norm_map(node)) for node in node_color]
kwargs['edge_color'] = [mapper(norm_map(edge)) for edge in edge_color]

else:
if ax is None:
ax = fig.add_axes([0.01, 0.01, 0.98, 0.98])

if linear_biases and not quadratic_biases:
kwargs['edge_vmin'] = edge_vmin
kwargs['edge_vmax'] = edge_vmax
kwargs['edge_cmap'] = edge_cmap
if quadratic_biases and not linear_biases:
kwargs['vmin'] = vmin
kwargs['vmax'] = vmax
kwargs['cmap'] = cmap
draw(G, layout, ax=ax, nodelist=nodelist, edgelist=edgelist, **kwargs)


def draw_embedding(G, layout, emb, embedded_graph=None, interaction_edges=None,
chain_color=None, unused_color=(0.9,0.9,0.9,1.0), cmap=None,
chain_color=None, unused_color=(0.9, 0.9, 0.9, 1.0), cmap=None,
show_labels=False, overlapped_embedding=False, **kwargs):
"""Draws an embedding onto the graph G, according to layout.
Expand Down Expand Up @@ -423,10 +406,9 @@ def unoverlapped_embedding(G, emb, interaction_edges):
return new_G, new_emb, new_interaction_edges


def draw_yield(G, layout, perfect_graph, unused_color=(0.9,0.9,0.9,1.0),
fault_color=(1.0,0.0,0.0,1.0), fault_shape='x',
fault_style='dashed', **kwargs):

def draw_yield(G, layout, perfect_graph, unused_color=(0.9, 0.9, 0.9, 1.0),
fault_color=(1.0, 0.0, 0.0, 1.0), fault_shape='x',
fault_style='dashed', **kwargs):
"""Draws the given graph G with highlighted faults, according to layout.
Parameters
Expand Down Expand Up @@ -482,9 +464,9 @@ def draw_yield(G, layout, perfect_graph, unused_color=(0.9,0.9,0.9,1.0),

# Draw faults with different style and shape
draw(perfect_graph, layout, nodelist=faults_nodelist, edgelist=faults_edgelist,
node_color=faults_node_color, edge_color=faults_edge_color,
style=fault_style, node_shape=fault_shape,
**kwargs )
node_color=faults_node_color, edge_color=faults_edge_color,
style=fault_style, node_shape=fault_shape,
**kwargs)

# Draw rest of graph
if unused_color is not None:
Expand All @@ -497,5 +479,5 @@ def draw_yield(G, layout, perfect_graph, unused_color=(0.9,0.9,0.9,1.0),
unused_edge_color = [unused_color for v in edgelist]

draw(perfect_graph, layout, nodelist=nodelist, edgelist=edgelist,
node_color=unused_node_color, edge_color=unused_edge_color,
**kwargs)
node_color=unused_node_color, edge_color=unused_edge_color,
**kwargs)
59 changes: 59 additions & 0 deletions tests/test_qubit_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2018 D-Wave Systems Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import dwave_networkx as dnx

try:
import matplotlib.pyplot as plt
except ImportError:
plt = False

try:
import numpy as np
except ImportError:
np = False

_display = os.environ.get('DISPLAY', '') != ''


@unittest.skipUnless(np and plt, "No numpy or matplotlib")
class TestDrawing(unittest.TestCase):
@unittest.skipUnless(_display, " No display found")
def test_draw_qubit_graph_kwargs(self):
G = dnx.chimera_graph(2, 2, 4)
pos = dnx.chimera_layout(G)
linear_biases = {v: -v / max(G) if v % 2 == 0 else v / max(G) for v in G}
quadratic_biases = {(u, v): (u - v) / max(abs(u), abs(v)) for u, v in G.edges}
cm = plt.get_cmap("spring_r")

# Don't supply biases
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos)

# Supply both biases
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, linear_biases, quadratic_biases)

# Supply linear but not quadratic biases
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, linear_biases)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, linear_biases, None)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, linear_biases, None, cmap=None)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, linear_biases, None, cmap=cm)

# Supply quadratic but not linear biases
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, {}, quadratic_biases)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, None, quadratic_biases)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, None, quadratic_biases, edge_cmap=None)
dnx.drawing.qubit_layout.draw_qubit_graph(G, pos, None, quadratic_biases, edge_cmap=cm)

0 comments on commit 4fe8ffc

Please sign in to comment.