Skip to content

Commit

Permalink
Added configurable title for each canvas plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Mocuto authored and waleedka committed Apr 24, 2020
1 parent eb409c2 commit 45243d5
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions hiddenlayer/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def show_images(images, titles=None, cols=5, **kwargs):
###############################################################################

class Canvas():

def __init__(self):
self._context = None
self.theme = DEFAULT_THEME
Expand All @@ -81,7 +81,7 @@ def __enter__(self):
self._context = "build"
self.drawing_calls = []
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.render()

Expand All @@ -91,7 +91,7 @@ def render(self):
if 'inline' in self.backend:
IPython.display.clear_output(wait=True)
self.figure = None

# Separate the draw_*() calls that generate a grid cell
grid_calls = []
silent_calls = []
Expand All @@ -102,7 +102,7 @@ def render(self):
grid_calls.append(c)

# Header area
# TODO: ideally, compute how much header area we need based on the
# TODO: ideally, compute how much header area we need based on the
# length of text to show there. Right now, we're just using
# a fixed number multiplied by the number of calls. Since there
# is only one silent call, draw_summary(), then the header padding
Expand All @@ -116,7 +116,7 @@ def render(self):

# Divide figure area by number of grid calls
gs = matplotlib.gridspec.GridSpec(len(grid_calls), 1)

# Call silent calls
for c in silent_calls:
getattr(self, c[0])(*c[1], **c[2])
Expand Down Expand Up @@ -153,7 +153,7 @@ def wrapper(*args, **kwargs):
self.render()
return wrapper
else:
return object.__getattribute__(self, name)
return object.__getattribute__(self, name)

def save(self, file_name):
self.figure.savefig(file_name)
Expand All @@ -168,19 +168,20 @@ def draw_summary(self, history, title=""):
summary = title + "\n\n" + summary
self.figure.suptitle(summary)

def draw_plot(self, metrics, labels=None, ylabel=""):
def draw_plot(self, metrics, labels=None, ylabel="", title=None):
"""
metrics: One or more metrics parameters. Each represents the history
of one metric.
"""
metrics = metrics if isinstance(metrics, list) else [metrics]
# Loop through metrics
title = ""
default_title = ""
for i, m in enumerate(metrics):
label = labels[i] if labels else m.name
# TODO: use a standard formating function for values
title += (" " if title else "") + "{}: {}".format(label, m.data[-1])
default_title += (" " if default_title else "") + "{}: {}".format(label, m.data[-1])
self.ax.plot(m.formatted_steps, m.data, label=label)
title = default_title if title is None else title
self.ax.set_title(title)
self.ax.set_ylabel(ylabel)
self.ax.legend()
Expand Down

0 comments on commit 45243d5

Please sign in to comment.