Skip to content

Commit

Permalink
Added visualization code
Browse files Browse the repository at this point in the history
  • Loading branch information
exilef committed Jun 1, 2016
1 parent b3c2caf commit 3f0e65d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
16 changes: 14 additions & 2 deletions hdnet/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,27 @@ def counts_by_label(self):
def patterns(self):
"""
Returns the patterns encountered in the raw data
as 1d vectors.
as 01-strings.
Returns
-------
patterns : list of 01-strings
"""
return self._patterns

@property
def patterns_as_binary(self):
"""
Returns the patterns encountered in the raw data
as binary vectors.
Returns
-------
patterns : 2d numpy array, int
Binary array of patterns encountered in the
raw data, as 1d vectors
"""
return self._patterns
return np.array([self.pattern_to_binary_matrix(i) for i in range(len(self._patterns))])

@property
def num_patterns(self):
Expand Down
59 changes: 57 additions & 2 deletions hdnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
"""

import os
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter

from hdnet.util import hdlog

HAS_PRETTYPLOTLIB = False
try:
import prettyplotlib as ppl
Expand Down Expand Up @@ -50,9 +53,9 @@ def plot_matrix_whole_canvas(matrix, **kwargs):
Value : Type
Description
"""
plt.axis("off")
ax = plt.axes([0, 0, 1, 1])
ax.matshow(matrix, **kwargs)
plt.axis('off')
return ax


Expand All @@ -77,7 +80,7 @@ def save_matrix_whole_canvas(matrix, fname, **kwargs):
plt.figure()
plot_matrix_whole_canvas(matrix, **kwargs)
plt.savefig(fname)
plt.close
plt.close()


def raster_plot_psth(spikes,
Expand Down Expand Up @@ -614,4 +617,56 @@ def plot_graph(g, nodeval=None, cmap_nodes='cool', cmap_edges='autumn',
return fig


def plot_network(network, filename = 'Jtheta.png', cmap = 'jet', axis = False, colorbar = True, overwrite = False):
if os.path.exists(filename) and not overwrite:
hdlog.error('plot_network: file name exists: {}, pass overwrite = True to overwrite'.format(filename))
return

plt.figure()
mat = network.J.copy()
mat[np.diag_indices(mat.shape[0])] = network.theta.ravel()
plt.matshow(mat, cmap = cmap)
if colorbar:
plt.colorbar()
if not axis:
plt.axis('off')
plt.savefig(filename)
plt.close()


def plot_hopfield_patterns(patterns, path, format = 'png', window_size = 1, memories = True, mtas = True, overwrite = False):
if os.path.exists(path):
if not overwrite:
hdlog.error('plot_overview_hofield_patterns: path exists: {}, pass overwrite = True to overwrite'.format(path))
return
else:
os.makedirs(path)

def _save_mat(fn, mat):
plt.figure()
ax = plt.axes([0, 0, 1, 1])
ax.matshow(mat, cmap = 'gray')
ax.axes.get_xaxis().set_ticks([])
ax.axes.get_yaxis().set_ticks([])
plt.savefig(fn)
plt.close()

npats = len(patterns.patterns)
digits = int(np.ceil(np.log10(npats))) + 1
suffix = '{{0:0>{}d}}.{}'.format(digits, format)
if memories:
fn = 'memory' + suffix
for i in range(npats):
pat = patterns.pattern_to_binary_matrix(i)
patmat = pat.reshape((len(pat) // window_size, window_size))
_save_mat(os.path.join(path, fn.format(i)), patmat)

if mtas:
fn = 'mta' + suffix
for i in range(npats):
pat = patterns.pattern_to_mta_matrix(i)
patmat = pat.reshape((len(pat) // window_size, window_size))
_save_mat(os.path.join(path, fn.format(i)), patmat)


# end of source

0 comments on commit 3f0e65d

Please sign in to comment.