From 45243d51fd78cb6edc45cca50d29b04fb4b35511 Mon Sep 17 00:00:00 2001 From: Mocuto Oshi Date: Thu, 16 May 2019 17:50:31 -0700 Subject: [PATCH] Added configurable `title` for each canvas plot --- hiddenlayer/canvas.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/hiddenlayer/canvas.py b/hiddenlayer/canvas.py index bdeed67..84a630e 100644 --- a/hiddenlayer/canvas.py +++ b/hiddenlayer/canvas.py @@ -68,7 +68,7 @@ def show_images(images, titles=None, cols=5, **kwargs): ############################################################################### class Canvas(): - + def __init__(self): self._context = None self.theme = DEFAULT_THEME @@ -81,7 +81,7 @@ def __enter__(self): self._context = "build" self.drawing_calls = [] return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.render() @@ -91,7 +91,7 @@ def render(self): if 'inline' in self.backend: IPython.display.clear_output(wait=True) self.figure = None - + # Separate the draw_*() calls that generate a grid cell grid_calls = [] silent_calls = [] @@ -102,7 +102,7 @@ def render(self): grid_calls.append(c) # Header area - # TODO: ideally, compute how much header area we need based on the + # TODO: ideally, compute how much header area we need based on the # length of text to show there. Right now, we're just using # a fixed number multiplied by the number of calls. Since there # is only one silent call, draw_summary(), then the header padding @@ -116,7 +116,7 @@ def render(self): # Divide figure area by number of grid calls gs = matplotlib.gridspec.GridSpec(len(grid_calls), 1) - + # Call silent calls for c in silent_calls: getattr(self, c[0])(*c[1], **c[2]) @@ -153,7 +153,7 @@ def wrapper(*args, **kwargs): self.render() return wrapper else: - return object.__getattribute__(self, name) + return object.__getattribute__(self, name) def save(self, file_name): self.figure.savefig(file_name) @@ -168,19 +168,20 @@ def draw_summary(self, history, title=""): summary = title + "\n\n" + summary self.figure.suptitle(summary) - def draw_plot(self, metrics, labels=None, ylabel=""): + def draw_plot(self, metrics, labels=None, ylabel="", title=None): """ metrics: One or more metrics parameters. Each represents the history of one metric. """ metrics = metrics if isinstance(metrics, list) else [metrics] # Loop through metrics - title = "" + default_title = "" for i, m in enumerate(metrics): label = labels[i] if labels else m.name # TODO: use a standard formating function for values - title += (" " if title else "") + "{}: {}".format(label, m.data[-1]) + default_title += (" " if default_title else "") + "{}: {}".format(label, m.data[-1]) self.ax.plot(m.formatted_steps, m.data, label=label) + title = default_title if title is None else title self.ax.set_title(title) self.ax.set_ylabel(ylabel) self.ax.legend()