[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt/yt-3.0 (pull request #884)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Fri May 9 15:06:39 PDT 2014
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/b19ce8c82e07/
Changeset: b19ce8c82e07
Branch: yt-3.0
User: MatthewTurk
Date: 2014-05-10 00:06:32
Summary: Merged in ngoldbaum/yt/yt-3.0 (pull request #884)
Fixes for YTArray subclasses and adding HDF5 I/O for YTArray.
Affected #: 3 files
diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/data_objects/image_array.py
--- a/yt/data_objects/image_array.py
+++ b/yt/data_objects/image_array.py
@@ -12,7 +12,6 @@
#-----------------------------------------------------------------------------
import numpy as np
-import h5py as h5
from yt.visualization.image_writer import write_bitmap, write_image
from yt.units.yt_array import YTArray
@@ -26,7 +25,7 @@
Parameters
----------
input_array: array_like
- A numpy ndarray, or list.
+ A numpy ndarray, or list.
Other Parameters
----------------
@@ -35,7 +34,7 @@
Returns
-------
- obj: ImageArray object
+ obj: ImageArray object
Raises
------
@@ -55,15 +54,15 @@
--------
These are written in doctest format, and should illustrate how to
use the function. Use the variables 'pf' for the parameter file, 'pc' for
- a plot collection, 'c' for a center, and 'L' for a vector.
+ a plot collection, 'c' for a center, and 'L' for a vector.
>>> im = np.zeros([64,128,3])
>>> for i in xrange(im.shape[0]):
... for k in xrange(im.shape[2]):
... im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
- >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
- ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
+ >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+ ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
... 'width':0.245, 'units':'cm', 'type':'rendering'}
>>> im_arr = ImageArray(im, info=myinfo)
@@ -84,38 +83,36 @@
super(ImageArray, self).__array_finalize__(obj)
self.info = getattr(obj, 'info', None)
- def write_hdf5(self, filename):
+ def write_hdf5(self, filename, dataset_name=None):
r"""Writes ImageArray to hdf5 file.
Parameters
----------
filename: string
- Note filename not be modified.
-
+ The filename to create and write a dataset to
+
+ dataset_name: string
+ The name of the dataset to create in the file.
+
Examples
- --------
+ --------
>>> im = np.zeros([64,128,3])
>>> for i in xrange(im.shape[0]):
... for k in xrange(im.shape[2]):
... im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
- >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
- ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
+ >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+ ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
... 'width':0.245, 'units':'cm', 'type':'rendering'}
>>> im_arr = ImageArray(im, info=myinfo)
>>> im_arr.write_hdf5('test_ImageArray.h5')
"""
- array_name = self.info.get("name","image")
-
- f = h5.File(filename)
- if array_name in f.keys():
- del f[array_name]
- d = f.create_dataset(array_name, data=self)
- for k, v in self.info.iteritems():
- d.attrs.create(k, v)
- f.close()
+ if dataset_name is None:
+ dataset_name = self.info.get("name", "image")
+ super(ImageArray, self).write_hdf5(filename, dataset_name=dataset_name,
+ info=self.info)
def add_background_color(self, background='black', inline=True):
r"""Adds a background color to a 4-channel ImageArray
@@ -126,7 +123,7 @@
Parameters
----------
- background:
+ background:
This can be used to set a background color for the image, and can
take several types of values:
@@ -144,7 +141,7 @@
-------
out: ImageArray
The modified ImageArray with a background color added.
-
+
Examples
--------
>>> im = np.zeros([64,128,4])
@@ -160,8 +157,8 @@
>>> im_arr.write_png('black_bg.png')
"""
assert(self.shape[-1] == 4)
-
- if background == None:
+
+ if background is None:
background = (0., 0., 0., 0.)
elif background == 'white':
background = (1., 1., 1., 1.)
@@ -175,11 +172,10 @@
out = self.copy()
for i in range(3):
- out[:,:,i] = self[:,:,i]*self[:,:,3] + \
- background[i]*background[3]*(1.0-self[:,:,3])
- out[:,:,3] = self[:,:,3] + background[3]*(1.0-self[:,:,3])
- return out
-
+ out[:, :, i] = self[:, :, i]*self[:, :, 3]
+ out[:, :, i] += background[i]*background[3]*(1.0-self[:, :, 3])
+ out[:, :, 3] = self[:, :, 3]+background[3]*(1.0-self[:, :, 3])
+ return out
def rescale(self, cmax=None, amax=None, inline=True):
r"""Rescales the image to be in [0,1] range.
@@ -194,7 +190,7 @@
corresponding to using the maximum value in the alpha channel.
inline: boolean, optional
Specifies whether or not the rescaling is done inline. If false,
- a new copy of the ImageArray will be created, returned.
+ a new copy of the ImageArray will be created, returned.
Default:True.
Returns
@@ -207,17 +203,18 @@
This requires that the shape of the ImageArray to have a length of 3,
and for the third dimension to be >= 3. If the third dimension has
a shape of 4, the alpha channel will also be rescaled.
-
+
Examples
- --------
+ --------
>>> im = np.zeros([64,128,4])
>>> for i in xrange(im.shape[0]):
... for k in xrange(im.shape[2]):
... im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
- >>> im_arr.write_png('original.png')
- >>> im_arr.rescale()
- >>> im_arr.write_png('normalized.png')
+ >>> im = ImageArray(im)
+ >>> im.write_png('original.png')
+ >>> im.rescale()
+ >>> im.write_png('normalized.png')
"""
assert(len(self.shape) == 3)
@@ -226,22 +223,22 @@
out = self
else:
out = self.copy()
- if cmax is None:
- cmax = self[:,:,:3].sum(axis=2).max()
+ if cmax is None:
+ cmax = self[:, :, :3].sum(axis=2).max()
- np.multiply(self[:,:,:3], 1./cmax, out[:,:,:3])
+ np.multiply(self[:, :, :3], 1.0/cmax, out[:, :, :3])
if self.shape[2] == 4:
if amax is None:
- amax = self[:,:,3].max()
+ amax = self[:, :, 3].max()
if amax > 0.0:
- np.multiply(self[:,:,3], 1./amax, out[:,:,3])
-
+ np.multiply(self[:, :, 3], 1.0/amax, out[:, :, 3])
+
np.clip(out, 0.0, 1.0, out)
return out
def write_png(self, filename, clip_ratio=None, background='black',
- rescale=True):
+ rescale=True):
r"""Writes ImageArray to png file.
Parameters
@@ -250,9 +247,9 @@
Note filename not be modified.
clip_ratio: float, optional
Image will be clipped before saving to the standard deviation
- of the image multiplied by this value. Useful for enhancing
+ of the image multiplied by this value. Useful for enhancing
images. Default: None
- background:
+ background:
This can be used to set a background color for the image, and can
take several types of values:
@@ -265,7 +262,7 @@
rescale: boolean, optional
If True, will write out a rescaled image (without modifying the
original image). Default: True
-
+
Examples
--------
>>> im = np.zeros([64,128,4])
@@ -292,25 +289,25 @@
else:
out = scaled
- if filename[-4:] != '.png':
+ if filename[-4:] != '.png':
filename += '.png'
if clip_ratio is not None:
- nz = out[:,:,:3][out[:,:,:3].nonzero()]
+ nz = out[:, :, :3][out[:, :, :3].nonzero()]
return write_bitmap(out.swapaxes(0, 1), filename,
- nz.mean() + \
- clip_ratio * nz.std())
+ nz.mean() + clip_ratio*nz.std())
else:
return write_bitmap(out.swapaxes(0, 1), filename)
- def write_image(self, filename, color_bounds=None, channel=None, cmap_name="algae", func=lambda x: x):
+ def write_image(self, filename, color_bounds=None, channel=None,
+ cmap_name="algae", func=lambda x: x):
r"""Writes a single channel of the ImageArray to a png file.
Parameters
----------
filename: string
Note filename not be modified.
-
+
Other Parameters
----------------
channel: int
@@ -323,43 +320,44 @@
An acceptable colormap. See either yt.visualization.color_maps or
http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps .
func : function, optional
- A function to transform the buffer before applying a colormap.
+ A function to transform the buffer before applying a colormap.
Returns
-------
scaled_image : uint8 image that has been saved
-
+
Examples
--------
-
+
>>> im = np.zeros([64,128])
>>> for i in xrange(im.shape[0]):
- ... im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
+ ... im[i,:] = np.linspace(0.,0.3*i, im.shape[1])
- >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
- ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
+ >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
+ ... 'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
... 'width':0.245, 'units':'cm', 'type':'rendering'}
>>> im_arr = ImageArray(im, info=myinfo)
>>> im_arr.write_image('test_ImageArray.png')
"""
- if filename[-4:] != '.png':
+ if filename[-4:] != '.png':
filename += '.png'
+ #TODO: Write info dict as png metadata
if channel is None:
- return write_image(self.swapaxes(0,1).to_ndarray(), filename,
+ return write_image(self.swapaxes(0, 1).to_ndarray(), filename,
color_bounds=color_bounds, cmap_name=cmap_name,
func=func)
else:
- return write_image(self.swapaxes(0,1)[:,:,channel].to_ndarray(),
+ return write_image(self.swapaxes(0, 1)[:, :, channel].to_ndarray(),
filename,
- color_bounds=color_bounds, cmap_name=cmap_name,
+ color_bounds=color_bounds, cmap_name=cmap_name,
func=func)
def save(self, filename, png=True, hdf5=True):
"""
- Saves ImageArray.
+ Saves ImageArray.
Arguments:
filename: string
@@ -380,6 +378,3 @@
self.write_image("%s.png" % filename)
if hdf5:
self.write_hdf5("%s.h5" % filename)
-
- __doc__ += np.ndarray.__doc__
-
diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -14,7 +14,15 @@
# The full license is in the file COPYING.txt, distributed with this software.
# ----------------------------------------------------------------------------
+import copy
+import cPickle as pickle
+import itertools
+import numpy as np
+import operator
import os
+import shutil
+import tempfile
+
from nose.tools import assert_true
from numpy.testing import \
assert_array_equal, \
@@ -28,12 +36,6 @@
YTUnitOperationError, YTUfuncUnitError
from yt.testing import fake_random_pf, requires_module
from yt.funcs import fix_length
-import numpy as np
-import copy
-import operator
-import cPickle as pickle
-import tempfile
-import itertools
def operate_and_compare(a, b, op, answer):
@@ -675,3 +677,54 @@
yield assert_equal, yt_quan, YTQuantity(yt_quan.to_astropy())
+def test_subclass():
+
+ class YTASubclass(YTArray):
+ pass
+
+ a = YTASubclass([4, 5, 6], 'g')
+ b = YTASubclass([7, 8, 9], 'kg')
+ nu = YTASubclass([10, 11, 12], '')
+ nda = np.array([3, 4, 5])
+ yta = YTArray([6, 7, 8], 'mg')
+ ytq = YTQuantity(4, 'cm')
+ ndf = np.float64(3)
+
+ def op_comparison(op, inst1, inst2, compare_class):
+ assert_isinstance(op(inst1, inst2), compare_class)
+ assert_isinstance(op(inst2, inst1), compare_class)
+
+ for op in (operator.mul, operator.div, operator.truediv):
+ for inst in (b, ytq, ndf, yta, nda):
+ yield op_comparison, op, a, inst, YTASubclass
+
+ yield op_comparison, op, ytq, nda, YTArray
+ yield op_comparison, op, ytq, yta, YTArray
+
+ for op in (operator.add, operator.sub):
+ yield op_comparison, op, nu, nda, YTASubclass
+ yield op_comparison, op, a, b, YTASubclass
+ yield op_comparison, op, a, yta, YTASubclass
+
+ yield assert_isinstance, a[0], YTQuantity
+ yield assert_isinstance, a[:], YTASubclass
+ yield assert_isinstance, a[:2], YTASubclass
+
+def test_h5_io():
+ tmpdir = tempfile.mkdtemp()
+ curdir = os.getcwd()
+ os.chdir(tmpdir)
+
+ ds = fake_random_pf(64, nprocs=1, length_unit=10)
+
+ warr = ds.arr(np.random.random((256, 256)), 'code_length')
+
+ warr.write_hdf5('test.h5')
+
+ iarr = YTArray.from_hdf5('test.h5')
+
+ yield assert_equal, warr, iarr
+ yield assert_equal, warr.units.registry['code_length'], iarr.units.registry['code_length']
+
+ os.chdir(curdir)
+ shutil.rmtree(tmpdir)
diff -r d6681af789175edcbcf4360dd03b1a8ec726ea25 -r b19ce8c82e07be0ebc07f8e8943ac56c1bf40a21 yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -74,7 +74,8 @@
if ret.shape == ():
return YTQuantity(ret, units)
else:
- return YTArray(ret, units)
+ # This could be a subclass, so don't call YTArray directly.
+ return type(args[0])(ret, units)
return wrapped
def sqrt_unit(unit):
@@ -464,6 +465,92 @@
# End unit conversion methods
#
+ def write_hdf5(self, filename, dataset_name=None, info=None):
+ r"""Writes YTArray to hdf5 file.
+
+ Parameters
+ ----------
+ filename: string
+ The filename to create and write a dataset to
+
+ dataset_name: string
+ The name of the dataset to create in the file.
+
+ info: dictionary
+ A dictionary of supplementary info to append as attributes
+ to the dataset.
+
+ Examples
+ --------
+ >>> a = YTArray([1,2,3], 'cm')
+
+ >>> myinfo = {'field':'dinosaurs', 'type':'field_data'}
+
+ >>> a.write_hdf5('test_array_data.h5', dataset_name='dinosaurs',
+ ... info=myinfo)
+
+ """
+ import h5py
+ from yt.extern.six.moves import cPickle as pickle
+ if info is None:
+ info = {}
+
+ info['units'] = str(self.units)
+ info['unit_registry'] = pickle.dumps(self.units.registry.lut)
+
+ if dataset_name is None:
+ dataset_name = 'array_data'
+
+ f = h5py.File(filename)
+ if dataset_name in f.keys():
+ d = f[dataset_name]
+ # Overwrite without deleting if we can get away with it.
+ if d.shape == self.shape and d.dtype == self.dtype:
+ d[:] = self
+ for k in d.attrs.keys():
+ del d.attrs[k]
+ else:
+ del f[dataset_name]
+ d = f.create_dataset(dataset_name, data=self)
+ else:
+ d = f.create_dataset(dataset_name, data=self)
+
+ for k, v in info.iteritems():
+ d.attrs.create(k, v)
+ f.close()
+
+ @classmethod
+ def from_hdf5(cls, filename, dataset_name=None):
+ r"""Attempts to read in and convert a dataset in an hdf5 file into a YTArray.
+
+ Parameters
+ ----------
+ filename: string
+ The filename of the hdf5 file.
+
+ dataset_name: string
+ The name of the dataset to read from. If the dataset has a units
+ attribute, attempt to infer units as well.
+
+ """
+ import h5py
+ from yt.extern.six.moves import cPickle as pickle
+
+ if dataset_name is None:
+ dataset_name = 'array_data'
+
+ f = h5py.File(filename)
+ dataset = f[dataset_name]
+ data = dataset[:]
+ units = dataset.attrs.get('units', '')
+ if 'unit_registry' in dataset.attrs.keys():
+ unit_lut = pickle.loads(dataset.attrs['unit_registry'])
+ else:
+ unit_lut = None
+
+ registry = UnitRegistry(lut=unit_lut, add_default_symbols=False)
+ return cls(data, units, registry=registry)
+
#
# Start convenience methods
#
@@ -766,7 +853,7 @@
@return_arr
def prod(self, axis=None, dtype=None, out=None):
- if axis:
+ if axis is not None:
units = self.units**self.shape[axis]
else:
units = self.units**self.size
@@ -814,9 +901,13 @@
# Raise YTUnitOperationError up here since we know the context now
except RuntimeError:
raise YTUnitOperationError(context[0], u)
+ ret_class = type(self)
elif context[0] in binary_operators:
unit1 = getattr(context[1][0], 'units', None)
unit2 = getattr(context[1][1], 'units', None)
+ cls1 = type(context[1][0])
+ cls2 = type(context[1][1])
+ ret_class = get_binary_op_return_class(cls1, cls2)
if unit1 is None:
unit1 = Unit(registry=getattr(unit2, 'registry', None))
if unit2 is None and context[0] is not power:
@@ -849,10 +940,15 @@
out_arr = np.array(out_arr)
return out_arr
out_arr.units = unit
- if out_arr.size > 1:
- return YTArray(np.array(out_arr), unit)
+ if out_arr.size == 1:
+ return YTQuantity(np.array(out_arr), unit)
else:
- return YTQuantity(np.array(out_arr), unit)
+ if ret_class is YTQuantity:
+ # This happens if you do ndarray * YTQuantity. Explicitly
+ # casting to YTArray avoids creating a YTQuantity with size > 1
+ return YTArray(np.array(out_arr, unit))
+ return ret_class(np.array(out_arr), unit)
+
def __reduce__(self):
"""Pickle reduction method
@@ -929,3 +1025,22 @@
return data.pf.arr(x, units)
else:
return data.pf.quan(x, units)
+
+def get_binary_op_return_class(cls1, cls2):
+ if cls1 is cls2:
+ return cls1
+ if cls1 is np.ndarray or issubclass(cls1, numeric_type):
+ return cls2
+ if cls2 is np.ndarray or issubclass(cls2, numeric_type):
+ return cls1
+ if issubclass(cls1, YTQuantity):
+ return cls2
+ if issubclass(cls2, YTQuantity):
+ return cls1
+ if issubclass(cls1, cls2):
+ return cls1
+ if issubclass(cls2, cls1):
+ return cls2
+ else:
+ raise RuntimeError("Operations are only defined on pairs of objects"
+ "in which one is a subclass of the other")
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list