Skip to content

Commit

Permalink
Merge pull request #43 from haoxingz/magic_test
Browse files Browse the repository at this point in the history
Magic test
  • Loading branch information
arokem committed Feb 3, 2014
2 parents c230b79 + 2e3bfd9 commit ca5c9f4
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 28 deletions.
85 changes: 57 additions & 28 deletions pymatbridge/matlab_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
has_io = False
no_io_str = "Must have h5py and scipy.io to perform i/o"
no_io_str += "operations with the Matlab session"

from IPython.core.displaypub import publish_display_data
from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic,
line_cell_magic, needs_local_scope)
Expand All @@ -39,7 +39,7 @@

import pymatbridge as pymat


class MatlabInterperterError(RuntimeError):
"""
Some error occurs while matlab is running
Expand All @@ -52,7 +52,7 @@ def __unicode__(self):
s = "Failed to parse and evaluate line %r.\n Matlab error message: %r"%\
(self.line, self.err)
return s

if PY3:
__str__ = __unicode__
else:
Expand All @@ -66,25 +66,54 @@ def loadmat(fname):
"""

f = h5py.File(fname)
data = f.values()[0][:]
if len(data.dtype) > 0:
# must be complex data
data = data['real'] + 1j * data['imag']
return data

for var_name in f.iterkeys():
if isinstance(f[var_name], h5py.Dataset):
# Currently only supports numerical array
data = f[var_name].value
if len(data.dtype) > 0:
# must be complex data
data = data['real'] + 1j * data['imag']
return np.squeeze(data.T)

elif isinstance(f[var_name], h5py.Group):
data = {}
for mem_name in f[var_name].iterkeys():
if isinstance(f[var_name][mem_name], h5py.Dataset):
# Check if the dataset is a string
attr = h5py.AttributeManager(f[var_name][mem_name])
if (attr.__getitem__('MATLAB_class') == 'char'):
is_string = True
else:
is_string = False

data[mem_name] = f[var_name][mem_name].value
data[mem_name] = np.squeeze(data[mem_name].T)

if is_string:
result = ''
for asc in data[mem_name]:
result += chr(asc)
data[mem_name] = result
else:
# Currently doesn't support nested struct
pass

return data


def matlab_converter(matlab, key):
"""
Reach into the matlab namespace and get me the value of the variable
"""
tempdir = tempfile.gettempdir()
# We save as hdf5 in the matlab session, so that we can grab large
# variables:
matlab.run_code("save('%s/%s.mat','%s','-v7.3')"%(tempdir, key, key),
maxtime=matlab.maxtime)

return loadmat('%s/%s.mat'%(tempdir, key))


Expand Down Expand Up @@ -113,17 +142,17 @@ def __init__(self, shell,
maxtime : float
The maximal time to wait for responses for matlab (in seconds).
Default: 10 seconds.
pyconverter : callable
To be called on matlab variables returning into the ipython
namespace
matlab_converter : callable
To be called on values in ipython namespace before
To be called on values in ipython namespace before
assigning to variables in matlab.
cache_display_data : bool
If True, the published results of the final call to R are
If True, the published results of the final call to R are
cached in the variable 'display_cache'.
"""
Expand All @@ -133,7 +162,7 @@ def __init__(self, shell,
self.Matlab = pymat.Matlab(matlab, maxtime=maxtime)
self.Matlab.start()
self.pyconverter = pyconverter
self.matlab_converter = matlab_converter
self.matlab_converter = matlab_converter

def __del__(self):
"""shut down the Matlab server when the object dies.
Expand All @@ -154,9 +183,9 @@ def eval(self, line):
if run_dict['success'] == 'false':
raise MatlabInterperterError(line, run_dict['content']['stdout'])

# This is the matlab stdout:
# This is the matlab stdout:
return run_dict

@magic_arguments()
@argument(
'-i', '--input', action='append',
Expand All @@ -180,7 +209,7 @@ def matlab(self, line, cell=None, local_ns=None):
"""
Execute code in matlab
"""
args = parse_argstring(self.matlab, line)

Expand Down Expand Up @@ -210,7 +239,7 @@ def matlab(self, line, cell=None, local_ns=None):
except KeyError:
val = self.shell.user_ns[input]
# We save these input arguments into a .mat file:
tempdir = tempfile.gettempdir()
tempdir = tempfile.gettempdir()
sio.savemat('%s/%s.mat'%(tempdir, input),
eval("dict(%s=val)"%input), oned_as='row')

Expand All @@ -219,7 +248,7 @@ def matlab(self, line, cell=None, local_ns=None):

else:
raise RuntimeError(no_io_str)

text_output = ''
#imgfiles = []

Expand All @@ -234,14 +263,14 @@ def matlab(self, line, cell=None, local_ns=None):
e_s += "\n-----------------------"
e_s += "\nAre you sure Matlab is started?"
raise RuntimeError(e_s)



text_output += result_dict['content']['stdout']
# Figures get saved by matlab in reverse order...
imgfiles = result_dict['content']['figures'][::-1]
data_dir = result_dict['content']['datadir']

display_data = []
if text_output:
display_data.append(('MatlabMagic.matlab',
Expand All @@ -251,7 +280,7 @@ def matlab(self, line, cell=None, local_ns=None):
if len(imgf):
# Store the path to the directory so that you can delete it
# later on:
image = open(imgf, 'rb').read()
image = open(imgf, 'rb').read()
display_data.append(('MatlabMagic.matlab',
{'image/png': image}))

Expand All @@ -261,24 +290,24 @@ def matlab(self, line, cell=None, local_ns=None):
# Delete the temporary data files created by matlab:
if len(data_dir):
rmtree(data_dir)

if args.output:
if has_io:
for output in ','.join(args.output).split(','):
self.shell.push({output:self.matlab_converter(self.Matlab,
output)})
else:
raise RuntimeError(no_io_str)


_loaded = False
def load_ipython_extension(ip, **kwargs):
"""Load the extension in IPython."""
global _loaded
if not _loaded:
ip.register_magics(MatlabMagics(ip, **kwargs))
_loaded = True

def unload_ipython_extension(ip):
global _loaded
if _loaded:
Expand Down
83 changes: 83 additions & 0 deletions pymatbridge/tests/test_magic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pymatbridge as pymat
import IPython

import numpy.testing as npt

class TestMagic:

# Create an IPython shell and load Matlab magic
@classmethod
def setup_class(cls):
cls.ip = IPython.InteractiveShell()
cls.ip.run_cell('import random')
cls.ip.run_cell('import numpy as np')
pymat.load_ipython_extension(cls.ip)

# Unload the magic, shut down Matlab
@classmethod
def teardown_class(cls):
pymat.unload_ipython_extension(cls.ip)


# Test single operation on different data structures
def test_cell_magic_number(self):
# A double precision real number
self.ip.run_cell("a = np.float64(random.random())")
self.ip.run_cell_magic('matlab', '-i a -o b', 'b = a*2;')
npt.assert_almost_equal(self.ip.user_ns['b'],
self.ip.user_ns['a']*2, decimal=7)

# A complex number
self.ip.run_cell("x = 3.34+4.56j")
self.ip.run_cell_magic('matlab', '-i x -o y', 'y = x*(11.35 - 23.098j)')
self.ip.run_cell("res = x*(11.35 - 23.098j)")
npt.assert_almost_equal(self.ip.user_ns['y'],
self.ip.user_ns['res'], decimal=7)


def test_cell_magic_array(self):
# Random array multiplication
self.ip.run_cell("val1 = np.random.random_sample((3,3))")
self.ip.run_cell("val2 = np.random.random_sample((3,3))")
self.ip.run_cell("respy = np.dot(val1, val2)")
self.ip.run_cell_magic('matlab', '-i val1,val2 -o resmat',
'resmat = val1 * val2')
npt.assert_almost_equal(self.ip.user_ns['resmat'],
self.ip.user_ns['respy'], decimal=7)


def test_line_magic(self):
# Some operation in Matlab
self.ip.run_line_magic('matlab', 'a = [1 2 3]')
self.ip.run_line_magic('matlab', 'res = a*2')
# Get the result back to Python
self.ip.run_cell_magic('matlab', '-o actual', 'actual = res')

self.ip.run_cell("expected = np.array([2, 4, 6])")
npt.assert_almost_equal(self.ip.user_ns['actual'],
self.ip.user_ns['expected'], decimal=7)

def test_figure(self):
# Just make a plot to get more testing coverage
self.ip.run_line_magic('matlab', 'plot([1 2 3])')


def test_matrix(self):
self.ip.run_cell("in_array = np.array([[1,2,3], [4,5,6]])")
self.ip.run_cell_magic('matlab', '-i in_array -o out_array',
'out_array = in_array;')
npt.assert_almost_equal(self.ip.user_ns['out_array'],
self.ip.user_ns['in_array'],
decimal=7)

# Matlab struct type should be converted to a Python dict
def test_struct(self):
self.ip.run_cell('num = 2.567')
self.ip.run_cell('num_array = np.array([1.2,3.4,5.6])')
self.ip.run_cell('str = "Hello World"')
self.ip.run_cell_magic('matlab', '-i num,num_array,str -o obj',
'obj.num = num;obj.num_array = num_array;obj.str = str;')
npt.assert_equal(isinstance(self.ip.user_ns['obj'], dict), True)
npt.assert_equal(self.ip.user_ns['obj']['num'], self.ip.user_ns['num'])
npt.assert_equal(self.ip.user_ns['obj']['num_array'], self.ip.user_ns['num_array'])
npt.assert_equal(self.ip.user_ns['obj']['str'], self.ip.user_ns['str'])

0 comments on commit ca5c9f4

Please sign in to comment.