From 3f0e65ddd82138abb3843688ebce1c5adaa0d4cd Mon Sep 17 00:00:00 2001 From: Felix Effenberger Date: Wed, 1 Jun 2016 17:37:32 +0200 Subject: [PATCH] Added visualization code --- hdnet/patterns.py | 16 ++++++++++-- hdnet/visualization.py | 59 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/hdnet/patterns.py b/hdnet/patterns.py index 31e9241..cb2fbaf 100755 --- a/hdnet/patterns.py +++ b/hdnet/patterns.py @@ -183,15 +183,27 @@ def counts_by_label(self): def patterns(self): """ Returns the patterns encountered in the raw data - as 1d vectors. + as 01-strings. + Returns + ------- + patterns : list of 01-strings + """ + return self._patterns + + @property + def patterns_as_binary(self): + """ + Returns the patterns encountered in the raw data + as binary vectors. + Returns ------- patterns : 2d numpy array, int Binary array of patterns encountered in the raw data, as 1d vectors """ - return self._patterns + return np.array([self.pattern_to_binary_matrix(i) for i in range(len(self._patterns))]) @property def num_patterns(self): diff --git a/hdnet/visualization.py b/hdnet/visualization.py index 7b16dca..4629877 100755 --- a/hdnet/visualization.py +++ b/hdnet/visualization.py @@ -11,12 +11,15 @@ """ +import os import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib.ticker import NullFormatter +from hdnet.util import hdlog + HAS_PRETTYPLOTLIB = False try: import prettyplotlib as ppl @@ -50,9 +53,9 @@ def plot_matrix_whole_canvas(matrix, **kwargs): Value : Type Description """ - plt.axis("off") ax = plt.axes([0, 0, 1, 1]) ax.matshow(matrix, **kwargs) + plt.axis('off') return ax @@ -77,7 +80,7 @@ def save_matrix_whole_canvas(matrix, fname, **kwargs): plt.figure() plot_matrix_whole_canvas(matrix, **kwargs) plt.savefig(fname) - plt.close + plt.close() def raster_plot_psth(spikes, @@ -614,4 +617,56 @@ def plot_graph(g, nodeval=None, cmap_nodes='cool', cmap_edges='autumn', return fig +def plot_network(network, filename = 'Jtheta.png', cmap = 'jet', axis = False, colorbar = True, overwrite = False): + if os.path.exists(filename) and not overwrite: + hdlog.error('plot_network: file name exists: {}, pass overwrite = True to overwrite'.format(filename)) + return + + plt.figure() + mat = network.J.copy() + mat[np.diag_indices(mat.shape[0])] = network.theta.ravel() + plt.matshow(mat, cmap = cmap) + if colorbar: + plt.colorbar() + if not axis: + plt.axis('off') + plt.savefig(filename) + plt.close() + + +def plot_hopfield_patterns(patterns, path, format = 'png', window_size = 1, memories = True, mtas = True, overwrite = False): + if os.path.exists(path): + if not overwrite: + hdlog.error('plot_overview_hofield_patterns: path exists: {}, pass overwrite = True to overwrite'.format(path)) + return + else: + os.makedirs(path) + + def _save_mat(fn, mat): + plt.figure() + ax = plt.axes([0, 0, 1, 1]) + ax.matshow(mat, cmap = 'gray') + ax.axes.get_xaxis().set_ticks([]) + ax.axes.get_yaxis().set_ticks([]) + plt.savefig(fn) + plt.close() + + npats = len(patterns.patterns) + digits = int(np.ceil(np.log10(npats))) + 1 + suffix = '{{0:0>{}d}}.{}'.format(digits, format) + if memories: + fn = 'memory' + suffix + for i in range(npats): + pat = patterns.pattern_to_binary_matrix(i) + patmat = pat.reshape((len(pat) // window_size, window_size)) + _save_mat(os.path.join(path, fn.format(i)), patmat) + + if mtas: + fn = 'mta' + suffix + for i in range(npats): + pat = patterns.pattern_to_mta_matrix(i) + patmat = pat.reshape((len(pat) // window_size, window_size)) + _save_mat(os.path.join(path, fn.format(i)), patmat) + + # end of source