diff --git a/pymatbridge/matlab_magic.py b/pymatbridge/matlab_magic.py index badae16..59a21e1 100644 --- a/pymatbridge/matlab_magic.py +++ b/pymatbridge/matlab_magic.py @@ -28,7 +28,7 @@ has_io = False no_io_str = "Must have h5py and scipy.io to perform i/o" no_io_str += "operations with the Matlab session" - + from IPython.core.displaypub import publish_display_data from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope) @@ -39,7 +39,7 @@ import pymatbridge as pymat - + class MatlabInterperterError(RuntimeError): """ Some error occurs while matlab is running @@ -52,7 +52,7 @@ def __unicode__(self): s = "Failed to parse and evaluate line %r.\n Matlab error message: %r"%\ (self.line, self.err) return s - + if PY3: __str__ = __unicode__ else: @@ -66,25 +66,54 @@ def loadmat(fname): """ f = h5py.File(fname) - data = f.values()[0][:] - if len(data.dtype) > 0: - # must be complex data - data = data['real'] + 1j * data['imag'] - return data + + for var_name in f.iterkeys(): + if isinstance(f[var_name], h5py.Dataset): + # Currently only supports numerical array + data = f[var_name].value + if len(data.dtype) > 0: + # must be complex data + data = data['real'] + 1j * data['imag'] + return np.squeeze(data.T) + + elif isinstance(f[var_name], h5py.Group): + data = {} + for mem_name in f[var_name].iterkeys(): + if isinstance(f[var_name][mem_name], h5py.Dataset): + # Check if the dataset is a string + attr = h5py.AttributeManager(f[var_name][mem_name]) + if (attr.__getitem__('MATLAB_class') == 'char'): + is_string = True + else: + is_string = False + + data[mem_name] = f[var_name][mem_name].value + data[mem_name] = np.squeeze(data[mem_name].T) + + if is_string: + result = '' + for asc in data[mem_name]: + result += chr(asc) + data[mem_name] = result + else: + # Currently doesn't support nested struct + pass + + return data def matlab_converter(matlab, key): """ Reach into the matlab namespace and get me the value of the variable - + """ tempdir = tempfile.gettempdir() # We save as hdf5 in the matlab session, so that we can grab large # variables: matlab.run_code("save('%s/%s.mat','%s','-v7.3')"%(tempdir, key, key), maxtime=matlab.maxtime) - + return loadmat('%s/%s.mat'%(tempdir, key)) @@ -113,17 +142,17 @@ def __init__(self, shell, maxtime : float The maximal time to wait for responses for matlab (in seconds). Default: 10 seconds. - + pyconverter : callable To be called on matlab variables returning into the ipython namespace - + matlab_converter : callable - To be called on values in ipython namespace before + To be called on values in ipython namespace before assigning to variables in matlab. cache_display_data : bool - If True, the published results of the final call to R are + If True, the published results of the final call to R are cached in the variable 'display_cache'. """ @@ -133,7 +162,7 @@ def __init__(self, shell, self.Matlab = pymat.Matlab(matlab, maxtime=maxtime) self.Matlab.start() self.pyconverter = pyconverter - self.matlab_converter = matlab_converter + self.matlab_converter = matlab_converter def __del__(self): """shut down the Matlab server when the object dies. @@ -154,9 +183,9 @@ def eval(self, line): if run_dict['success'] == 'false': raise MatlabInterperterError(line, run_dict['content']['stdout']) - # This is the matlab stdout: + # This is the matlab stdout: return run_dict - + @magic_arguments() @argument( '-i', '--input', action='append', @@ -180,7 +209,7 @@ def matlab(self, line, cell=None, local_ns=None): """ Execute code in matlab - + """ args = parse_argstring(self.matlab, line) @@ -210,7 +239,7 @@ def matlab(self, line, cell=None, local_ns=None): except KeyError: val = self.shell.user_ns[input] # We save these input arguments into a .mat file: - tempdir = tempfile.gettempdir() + tempdir = tempfile.gettempdir() sio.savemat('%s/%s.mat'%(tempdir, input), eval("dict(%s=val)"%input), oned_as='row') @@ -219,7 +248,7 @@ def matlab(self, line, cell=None, local_ns=None): else: raise RuntimeError(no_io_str) - + text_output = '' #imgfiles = [] @@ -234,14 +263,14 @@ def matlab(self, line, cell=None, local_ns=None): e_s += "\n-----------------------" e_s += "\nAre you sure Matlab is started?" raise RuntimeError(e_s) - - + + text_output += result_dict['content']['stdout'] # Figures get saved by matlab in reverse order... imgfiles = result_dict['content']['figures'][::-1] data_dir = result_dict['content']['datadir'] - + display_data = [] if text_output: display_data.append(('MatlabMagic.matlab', @@ -251,7 +280,7 @@ def matlab(self, line, cell=None, local_ns=None): if len(imgf): # Store the path to the directory so that you can delete it # later on: - image = open(imgf, 'rb').read() + image = open(imgf, 'rb').read() display_data.append(('MatlabMagic.matlab', {'image/png': image})) @@ -261,7 +290,7 @@ def matlab(self, line, cell=None, local_ns=None): # Delete the temporary data files created by matlab: if len(data_dir): rmtree(data_dir) - + if args.output: if has_io: for output in ','.join(args.output).split(','): @@ -269,8 +298,8 @@ def matlab(self, line, cell=None, local_ns=None): output)}) else: raise RuntimeError(no_io_str) - - + + _loaded = False def load_ipython_extension(ip, **kwargs): """Load the extension in IPython.""" @@ -278,7 +307,7 @@ def load_ipython_extension(ip, **kwargs): if not _loaded: ip.register_magics(MatlabMagics(ip, **kwargs)) _loaded = True - + def unload_ipython_extension(ip): global _loaded if _loaded: diff --git a/pymatbridge/tests/test_magic.py b/pymatbridge/tests/test_magic.py new file mode 100644 index 0000000..47987bf --- /dev/null +++ b/pymatbridge/tests/test_magic.py @@ -0,0 +1,83 @@ +import pymatbridge as pymat +import IPython + +import numpy.testing as npt + +class TestMagic: + + # Create an IPython shell and load Matlab magic + @classmethod + def setup_class(cls): + cls.ip = IPython.InteractiveShell() + cls.ip.run_cell('import random') + cls.ip.run_cell('import numpy as np') + pymat.load_ipython_extension(cls.ip) + + # Unload the magic, shut down Matlab + @classmethod + def teardown_class(cls): + pymat.unload_ipython_extension(cls.ip) + + + # Test single operation on different data structures + def test_cell_magic_number(self): + # A double precision real number + self.ip.run_cell("a = np.float64(random.random())") + self.ip.run_cell_magic('matlab', '-i a -o b', 'b = a*2;') + npt.assert_almost_equal(self.ip.user_ns['b'], + self.ip.user_ns['a']*2, decimal=7) + + # A complex number + self.ip.run_cell("x = 3.34+4.56j") + self.ip.run_cell_magic('matlab', '-i x -o y', 'y = x*(11.35 - 23.098j)') + self.ip.run_cell("res = x*(11.35 - 23.098j)") + npt.assert_almost_equal(self.ip.user_ns['y'], + self.ip.user_ns['res'], decimal=7) + + + def test_cell_magic_array(self): + # Random array multiplication + self.ip.run_cell("val1 = np.random.random_sample((3,3))") + self.ip.run_cell("val2 = np.random.random_sample((3,3))") + self.ip.run_cell("respy = np.dot(val1, val2)") + self.ip.run_cell_magic('matlab', '-i val1,val2 -o resmat', + 'resmat = val1 * val2') + npt.assert_almost_equal(self.ip.user_ns['resmat'], + self.ip.user_ns['respy'], decimal=7) + + + def test_line_magic(self): + # Some operation in Matlab + self.ip.run_line_magic('matlab', 'a = [1 2 3]') + self.ip.run_line_magic('matlab', 'res = a*2') + # Get the result back to Python + self.ip.run_cell_magic('matlab', '-o actual', 'actual = res') + + self.ip.run_cell("expected = np.array([2, 4, 6])") + npt.assert_almost_equal(self.ip.user_ns['actual'], + self.ip.user_ns['expected'], decimal=7) + + def test_figure(self): + # Just make a plot to get more testing coverage + self.ip.run_line_magic('matlab', 'plot([1 2 3])') + + + def test_matrix(self): + self.ip.run_cell("in_array = np.array([[1,2,3], [4,5,6]])") + self.ip.run_cell_magic('matlab', '-i in_array -o out_array', + 'out_array = in_array;') + npt.assert_almost_equal(self.ip.user_ns['out_array'], + self.ip.user_ns['in_array'], + decimal=7) + + # Matlab struct type should be converted to a Python dict + def test_struct(self): + self.ip.run_cell('num = 2.567') + self.ip.run_cell('num_array = np.array([1.2,3.4,5.6])') + self.ip.run_cell('str = "Hello World"') + self.ip.run_cell_magic('matlab', '-i num,num_array,str -o obj', + 'obj.num = num;obj.num_array = num_array;obj.str = str;') + npt.assert_equal(isinstance(self.ip.user_ns['obj'], dict), True) + npt.assert_equal(self.ip.user_ns['obj']['num'], self.ip.user_ns['num']) + npt.assert_equal(self.ip.user_ns['obj']['num_array'], self.ip.user_ns['num_array']) + npt.assert_equal(self.ip.user_ns['obj']['str'], self.ip.user_ns['str'])