[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt/yt-3.0 (pull request #884)

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Fri May 9 15:06:39 PDT 2014


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/b19ce8c82e07/
Changeset:   b19ce8c82e07
Branch:      yt-3.0
User:        MatthewTurk
Date:        2014-05-10 00:06:32
Summary:     Merged in ngoldbaum/yt/yt-3.0 (pull request #884)

Fix handling of YTArray subclasses and add HDF5 I/O for YTArray.
Affected #:  3 files

diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/data_objects/image_array.py
--- a/yt/data_objects/image_array.py
+++ b/yt/data_objects/image_array.py
@@ -12,7 +12,6 @@
 #-----------------------------------------------------------------------------
 
 import numpy as np
-import h5py as h5
 from yt.visualization.image_writer import write_bitmap, write_image
 from yt.units.yt_array import YTArray
 
@@ -26,7 +25,7 @@
     Parameters
     ----------
     input_array: array_like
-        A numpy ndarray, or list. 
+        A numpy ndarray, or list.
 
     Other Parameters
     ----------------
@@ -35,7 +34,7 @@
 
     Returns
     -------
-    obj: ImageArray object 
+    obj: ImageArray object
 
     Raises
     ------
@@ -55,15 +54,15 @@
     --------
     These are written in doctest format, and should illustrate how to
     use the function.  Use the variables 'pf' for the parameter file, 'pc' for
-    a plot collection, 'c' for a center, and 'L' for a vector. 
+    a plot collection, 'c' for a center, and 'L' for a vector.
 
     >>> im = np.zeros([64,128,3])
     >>> for i in xrange(im.shape[0]):
     ...     for k in xrange(im.shape[2]):
     ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
 
-    >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-    ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+    >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+    ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
     ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
     >>> im_arr = ImageArray(im, info=myinfo)
@@ -84,38 +83,36 @@
         super(ImageArray, self).__array_finalize__(obj)
         self.info = getattr(obj, 'info', None)
 
-    def write_hdf5(self, filename):
+    def write_hdf5(self, filename, dataset_name=None):
         r"""Writes ImageArray to hdf5 file.
 
         Parameters
         ----------
         filename: string
-            Note filename not be modified.
-       
+            The filename to create and write a dataset to.
+
+        dataset_name: string, optional
+            Name of the dataset to create; defaults to info["name"] or "image".
+
         Examples
-        -------- 
+        --------
         >>> im = np.zeros([64,128,3])
         >>> for i in xrange(im.shape[0]):
         ...     for k in xrange(im.shape[2]):
         ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
 
-        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
         ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
         >>> im_arr = ImageArray(im, info=myinfo)
         >>> im_arr.write_hdf5('test_ImageArray.h5')
 
         """
-        array_name = self.info.get("name","image")
-
-        f = h5.File(filename)
-        if array_name in f.keys():
-            del f[array_name]
-        d = f.create_dataset(array_name, data=self)
-        for k, v in self.info.iteritems():
-            d.attrs.create(k, v)
-        f.close()
+        if dataset_name is None:
+            dataset_name = (self.info or {}).get("name", "image")
+        super(ImageArray, self).write_hdf5(filename, dataset_name=dataset_name,
+                                           info=dict(self.info or {}))
 
     def add_background_color(self, background='black', inline=True):
         r"""Adds a background color to a 4-channel ImageArray
@@ -126,7 +123,7 @@
 
         Parameters
         ----------
-        background: 
+        background:
             This can be used to set a background color for the image, and can
             take several types of values:
 
@@ -144,7 +141,7 @@
         -------
         out: ImageArray
             The modified ImageArray with a background color added.
-       
+
         Examples
         --------
         >>> im = np.zeros([64,128,4])
@@ -160,8 +157,8 @@
         >>> im_arr.write_png('black_bg.png')
         """
         assert(self.shape[-1] == 4)
-        
-        if background == None:
+
+        if background is None:
             background = (0., 0., 0., 0.)
         elif background == 'white':
             background = (1., 1., 1., 1.)
@@ -175,11 +172,10 @@
             out = self.copy()
 
         for i in range(3):
-            out[:,:,i] = self[:,:,i]*self[:,:,3] + \
-                    background[i]*background[3]*(1.0-self[:,:,3])
-        out[:,:,3] = self[:,:,3] + background[3]*(1.0-self[:,:,3]) 
-        return out 
-
+            out[:, :, i] = self[:, :, i]*self[:, :, 3]
+            out[:, :, i] += background[i]*background[3]*(1.0-self[:, :, 3])
+        out[:, :, 3] = self[:, :, 3]+background[3]*(1.0-self[:, :, 3])
+        return out
 
     def rescale(self, cmax=None, amax=None, inline=True):
         r"""Rescales the image to be in [0,1] range.
@@ -194,7 +190,7 @@
             corresponding to using the maximum value in the alpha channel.
         inline: boolean, optional
             Specifies whether or not the rescaling is done inline. If false,
-            a new copy of the ImageArray will be created, returned. 
+            a new copy of the ImageArray will be created, returned.
             Default:True.
 
         Returns
@@ -207,17 +203,18 @@
         This requires that the shape of the ImageArray to have a length of 3,
         and for the third dimension to be >= 3.  If the third dimension has
         a shape of 4, the alpha channel will also be rescaled.
-       
+
         Examples
-        -------- 
+        --------
         >>> im = np.zeros([64,128,4])
         >>> for i in xrange(im.shape[0]):
         ...     for k in xrange(im.shape[2]):
         ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
 
-        >>> im_arr.write_png('original.png')
-        >>> im_arr.rescale()
-        >>> im_arr.write_png('normalized.png')
+        >>> im = ImageArray(im)
+        >>> im.write_png('original.png')
+        >>> im.rescale()
+        >>> im.write_png('normalized.png')
 
         """
         assert(len(self.shape) == 3)
@@ -226,22 +223,22 @@
             out = self
         else:
             out = self.copy()
-        if cmax is None: 
-            cmax = self[:,:,:3].sum(axis=2).max()
+        if cmax is None:
+            cmax = self[:, :, :3].sum(axis=2).max()
 
-        np.multiply(self[:,:,:3], 1./cmax, out[:,:,:3])
+        np.multiply(self[:, :, :3], 1.0/cmax, out[:, :, :3])
 
         if self.shape[2] == 4:
             if amax is None:
-                amax = self[:,:,3].max()
+                amax = self[:, :, 3].max()
             if amax > 0.0:
-                np.multiply(self[:,:,3], 1./amax, out[:,:,3])
-        
+                np.multiply(self[:, :, 3], 1.0/amax, out[:, :, 3])
+
         np.clip(out, 0.0, 1.0, out)
         return out
 
     def write_png(self, filename, clip_ratio=None, background='black',
-                 rescale=True):
+                  rescale=True):
         r"""Writes ImageArray to png file.
 
         Parameters
@@ -250,9 +247,9 @@
             Note filename not be modified.
         clip_ratio: float, optional
             Image will be clipped before saving to the standard deviation
-            of the image multiplied by this value.  Useful for enhancing 
+            of the image multiplied by this value.  Useful for enhancing
             images. Default: None
-        background: 
+        background:
             This can be used to set a background color for the image, and can
             take several types of values:
 
@@ -265,7 +262,7 @@
         rescale: boolean, optional
             If True, will write out a rescaled image (without modifying the
             original image). Default: True
-       
+
         Examples
         --------
         >>> im = np.zeros([64,128,4])
@@ -292,25 +289,25 @@
         else:
             out = scaled
 
-        if filename[-4:] != '.png': 
+        if filename[-4:] != '.png':
             filename += '.png'
 
         if clip_ratio is not None:
-            nz = out[:,:,:3][out[:,:,:3].nonzero()]
+            nz = out[:, :, :3][out[:, :, :3].nonzero()]
             return write_bitmap(out.swapaxes(0, 1), filename,
-                                nz.mean() + \
-                                clip_ratio * nz.std())
+                                nz.mean() + clip_ratio*nz.std())
         else:
             return write_bitmap(out.swapaxes(0, 1), filename)
 
-    def write_image(self, filename, color_bounds=None, channel=None,  cmap_name="algae", func=lambda x: x):
+    def write_image(self, filename, color_bounds=None, channel=None,
+                    cmap_name="algae", func=lambda x: x):
         r"""Writes a single channel of the ImageArray to a png file.
 
         Parameters
         ----------
         filename: string
             Note filename not be modified.
-       
+
         Other Parameters
         ----------------
         channel: int
@@ -323,43 +320,44 @@
             An acceptable colormap.  See either yt.visualization.color_maps or
             http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps .
         func : function, optional
-            A function to transform the buffer before applying a colormap. 
+            A function to transform the buffer before applying a colormap.
 
         Returns
         -------
         scaled_image : uint8 image that has been saved
-        
+
         Examples
         --------
-        
+
         >>> im = np.zeros([64,128])
         >>> for i in xrange(im.shape[0]):
-        ...     im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
+        ...     im[i,:] = np.linspace(0.,0.3*i, im.shape[1])
 
-        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
         ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
         >>> im_arr = ImageArray(im, info=myinfo)
         >>> im_arr.write_image('test_ImageArray.png')
 
         """
-        if filename[-4:] != '.png': 
+        if filename[-4:] != '.png':
             filename += '.png'
 
+        #TODO: Write info dict as png metadata
         if channel is None:
-            return write_image(self.swapaxes(0,1).to_ndarray(), filename,
+            return write_image(self.swapaxes(0, 1).to_ndarray(), filename,
                                color_bounds=color_bounds, cmap_name=cmap_name,
                                func=func)
         else:
-            return write_image(self.swapaxes(0,1)[:,:,channel].to_ndarray(),
+            return write_image(self.swapaxes(0, 1)[:, :, channel].to_ndarray(),
                                filename,
-                               color_bounds=color_bounds, cmap_name=cmap_name, 
+                               color_bounds=color_bounds, cmap_name=cmap_name,
                                func=func)
 
     def save(self, filename, png=True, hdf5=True):
         """
-        Saves ImageArray. 
+        Saves ImageArray.
 
         Arguments:
           filename: string
@@ -380,6 +378,3 @@
                 self.write_image("%s.png" % filename)
         if hdf5:
             self.write_hdf5("%s.h5" % filename)
-
-    __doc__ += np.ndarray.__doc__
-

diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -14,7 +14,15 @@
 # The full license is in the file COPYING.txt, distributed with this software.
 # ----------------------------------------------------------------------------
 
+import copy
+import cPickle as pickle
+import itertools
+import numpy as np
+import operator
 import os
+import shutil
+import tempfile
+
 from nose.tools import assert_true
 from numpy.testing import \
     assert_array_equal, \
@@ -28,12 +36,6 @@
     YTUnitOperationError, YTUfuncUnitError
 from yt.testing import fake_random_pf, requires_module
 from yt.funcs import fix_length
-import numpy as np
-import copy
-import operator
-import cPickle as pickle
-import tempfile
-import itertools
 
 
 def operate_and_compare(a, b, op, answer):
@@ -675,3 +677,54 @@
     yield assert_equal, yt_quan, YTQuantity(yt_quan.to_astropy())
 
 
+def test_subclass():
+    # Operations involving a YTArray subclass should return that subclass.
+    class YTASubclass(YTArray):
+        pass
+
+    a = YTASubclass([4, 5, 6], 'g')
+    b = YTASubclass([7, 8, 9], 'kg')
+    nu = YTASubclass([10, 11, 12], '')  # dimensionless
+    nda = np.array([3, 4, 5])
+    yta = YTArray([6, 7, 8], 'mg')
+    ytq = YTQuantity(4, 'cm')
+    ndf = np.float64(3)
+
+    def op_comparison(op, inst1, inst2, compare_class):
+        assert_isinstance(op(inst1, inst2), compare_class)
+        assert_isinstance(op(inst2, inst1), compare_class)  # reversed order
+    # mul/div: the subclass wins over arrays, scalars and quantities.
+    for op in (operator.mul, operator.div, operator.truediv):
+        for inst in (b, ytq, ndf, yta, nda):
+            yield op_comparison, op, a, inst, YTASubclass
+        # A YTQuantity combined with an array must give a YTArray.
+        yield op_comparison, op, ytq, nda, YTArray
+        yield op_comparison, op, ytq, yta, YTArray
+    # add/sub: the unitless nda is paired with the dimensionless nu.
+    for op in (operator.add, operator.sub):
+        yield op_comparison, op, nu, nda, YTASubclass
+        yield op_comparison, op, a, b, YTASubclass
+        yield op_comparison, op, a, yta, YTASubclass
+    # Indexing: a scalar element is a YTQuantity, slices keep the subclass.
+    yield assert_isinstance, a[0], YTQuantity
+    yield assert_isinstance, a[:], YTASubclass
+    yield assert_isinstance, a[:2], YTASubclass
+
+def test_h5_io():
+    tmpdir = tempfile.mkdtemp()
+    curdir = os.getcwd()
+    os.chdir(tmpdir)
+    try:
+        ds = fake_random_pf(64, nprocs=1, length_unit=10)
+        warr = ds.arr(np.random.random((256, 256)), 'code_length')
+        warr.write_hdf5('test.h5')
+        iarr = YTArray.from_hdf5('test.h5')
+    finally:
+        # Restore the cwd and clean up even if writing or reading fails;
+        # the yielded checks below only need the in-memory arrays.
+        os.chdir(curdir)
+        shutil.rmtree(tmpdir)
+
+    yield assert_equal, warr, iarr
+    yield assert_equal, warr.units.registry['code_length'], \
+        iarr.units.registry['code_length']

diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -74,7 +74,8 @@
         if ret.shape == ():
             return YTQuantity(ret, units)
         else:
-            return YTArray(ret, units)
+            # This could be a subclass, so don't call YTArray directly.
+            return type(args[0])(ret, units)
     return wrapped
 
 def sqrt_unit(unit):
@@ -464,6 +465,92 @@
     # End unit conversion methods
     #
 
+    def write_hdf5(self, filename, dataset_name=None, info=None):
+        r"""Writes YTArray to hdf5 file.
+
+        Parameters
+        ----------
+        filename: string
+            The filename to create and write a dataset to.
+
+        dataset_name: string, optional
+            Name of the dataset to create in the file (default 'array_data').
+
+        info: dictionary, optional
+            Supplementary info to store as attributes of the dataset; 'units'
+            and 'unit_registry' are added to a copy, not to this dict.
+
+        Examples
+        --------
+        >>> a = YTArray([1,2,3], 'cm')
+
+        >>> myinfo = {'field':'dinosaurs', 'type':'field_data'}
+
+        >>> a.write_hdf5('test_array_data.h5', dataset_name='dinosaurs',
+        ...              info=myinfo)
+
+        """
+        import h5py
+        from yt.extern.six.moves import cPickle as pickle
+        # Copy so the caller's dict is not mutated by the keys added below.
+        info = {} if info is None else dict(info)
+
+        info['units'] = str(self.units)
+        info['unit_registry'] = pickle.dumps(self.units.registry.lut)
+
+        if dataset_name is None:
+            dataset_name = 'array_data'
+
+        # Mode 'a' creates the file if needed; 'with' closes it on error too.
+        with h5py.File(filename, 'a') as f:
+            if dataset_name in f.keys():
+                d = f[dataset_name]
+                # Overwrite without deleting if we can get away with it.
+                if d.shape == self.shape and d.dtype == self.dtype:
+                    d[:] = self
+                    for k in list(d.attrs.keys()):
+                        del d.attrs[k]
+                else:
+                    del f[dataset_name]
+                    d = f.create_dataset(dataset_name, data=self)
+            else:
+                d = f.create_dataset(dataset_name, data=self)
+
+            for k, v in info.items():
+                d.attrs.create(k, v)
+
+    @classmethod
+    def from_hdf5(cls, filename, dataset_name=None):
+        r"""Reads a dataset from an hdf5 file and converts it into a YTArray.
+
+        Parameters
+        ----------
+        filename: string
+            The filename of the hdf5 file.
+
+        dataset_name: string, optional
+            The name of the dataset to read from (default 'array_data').  Units
+            and the unit registry are restored from its attributes if present.
+
+        """
+        import h5py
+        from yt.extern.six.moves import cPickle as pickle
+        if dataset_name is None:
+            dataset_name = 'array_data'
+        # Open read-only; the with-block closes the file after reading.
+        with h5py.File(filename, 'r') as f:
+            dataset = f[dataset_name]
+            data = dataset[:]
+            units = dataset.attrs.get('units', '')
+            if 'unit_registry' in dataset.attrs:
+                # NOTE(review): pickle.loads on file data can run arbitrary code.
+                unit_lut = pickle.loads(dataset.attrs['unit_registry'])
+            else:
+                unit_lut = None
+
+        registry = UnitRegistry(lut=unit_lut, add_default_symbols=False)
+        return cls(data, units, registry=registry)
+
     #
     # Start convenience methods
     #
@@ -766,7 +853,7 @@
 
     @return_arr
     def prod(self, axis=None, dtype=None, out=None):
-        if axis:
+        if axis is not None:
             units = self.units**self.shape[axis]
         else:
             units = self.units**self.size
@@ -814,9 +901,13 @@
             # Raise YTUnitOperationError up here since we know the context now
             except RuntimeError:
                 raise YTUnitOperationError(context[0], u)
+            ret_class = type(self)
         elif context[0] in binary_operators:
             unit1 = getattr(context[1][0], 'units', None)
             unit2 = getattr(context[1][1], 'units', None)
+            cls1 = type(context[1][0])
+            cls2 = type(context[1][1])
+            ret_class = get_binary_op_return_class(cls1, cls2)
             if unit1 is None:
                 unit1 = Unit(registry=getattr(unit2, 'registry', None))
             if unit2 is None and context[0] is not power:
@@ -849,10 +940,15 @@
             out_arr = np.array(out_arr)
             return out_arr
         out_arr.units = unit
-        if out_arr.size > 1:
-            return YTArray(np.array(out_arr), unit)
+        if out_arr.size == 1:
+            return YTQuantity(np.array(out_arr), unit)
         else:
-            return YTQuantity(np.array(out_arr), unit)
+            if ret_class is YTQuantity:
+                # This happens if you do ndarray * YTQuantity. Explicitly
+                # casting to YTArray avoids a YTQuantity with size > 1.
+                return YTArray(np.array(out_arr), unit)
+            return ret_class(np.array(out_arr), unit)
+
 
     def __reduce__(self):
         """Pickle reduction method
@@ -929,3 +1025,22 @@
         return data.pf.arr(x, units)
     else:
         return data.pf.quan(x, units)
+
+def get_binary_op_return_class(cls1, cls2):
+    """Return the class a binary operation on cls1 and cls2 should produce."""
+    if cls1 is cls2:
+        return cls1
+    if cls1 is np.ndarray or issubclass(cls1, numeric_type):
+        return cls2
+    if cls2 is np.ndarray or issubclass(cls2, numeric_type):
+        return cls1
+    if issubclass(cls1, YTQuantity):
+        return cls2
+    if issubclass(cls2, YTQuantity):
+        return cls1
+    if issubclass(cls1, cls2):
+        return cls1
+    if issubclass(cls2, cls1):
+        return cls2
+    raise RuntimeError("Operations are only defined on pairs of objects "
+                       "in which one is a subclass of the other")

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because the commit notification service is enabled and configured
to send to this email address.



More information about the yt-svn mailing list