[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt (pull request #1687)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Thu Aug 13 09:26:59 PDT 2015
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/a46e96ef9fcb/
Changeset: a46e96ef9fcb
Branch: yt
User: MatthewTurk
Date: 2015-08-13 16:26:47+00:00
Summary: Merged in ngoldbaum/yt (pull request #1687)
[perf] [enhancement] Speedups for iterating over YTArray instances
Affected #: 2 files
diff -r 475280b53ae10379ad79ac02519763085ec634c8 -r a46e96ef9fcbbcd1d4ebccae406938d166212f2a yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -554,6 +554,17 @@
yield assert_true, a_slice.base is a
+def test_iteration():
+ """
+ Test that iterating over a YTArray returns a sequence of YTQuantity instances
+ """
+ a = np.arange(3)
+ b = YTArray(np.arange(3), 'cm')
+ for ia, ib, in zip(a, b):
+ yield assert_equal, ia, ib.value
+ yield assert_equal, ib.units, b.units
+
+
def test_fix_length():
"""
Test fixing the length of an array. Used in spheres and other data objects
diff -r 475280b53ae10379ad79ac02519763085ec634c8 -r a46e96ef9fcbbcd1d4ebccae406938d166212f2a yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -135,6 +135,17 @@
raise YTUnitOperationError(op_string, inp.units, dimensionless)
return ret
+def validate_comparison_units(this, other, op_string):
+ # Check that other is a YTArray.
+ if hasattr(other, 'units'):
+ if this.units.expr is other.units.expr:
+ return other
+ if not this.units.same_dimensions_as(other.units):
+ raise YTUnitOperationError(op_string, this.units, other.units)
+ return other.in_units(this.units)
+
+ return other
+
unary_operators = (
negative, absolute, rint, ones_like, sign, conj, exp, exp2, log, log2,
log10, expm1, log1p, sqrt, square, reciprocal, sin, cos, tan, arcsin,
@@ -167,7 +178,13 @@
with a unit registry and this is specified, this will be used instead of
the registry associated with the unit object.
dtype : string or NumPy dtype object
- The dtype of the array data.
+ The dtype of the array data. Defaults to the dtype of the input data,
+ or, if none is found, uses np.float64.
+ bypass_validation : boolean
+ If True, all input validation is skipped. Using this option may produce
+ corrupted, invalid units or array data, but can lead to significant
+ speedups in cases where the input validation logic adds significant overhead. If set,
+ input_units *must* be a valid unit object. Defaults to False.
Examples
--------
@@ -286,9 +303,16 @@
__array_priority__ = 2.0
- def __new__(cls, input_array, input_units=None, registry=None, dtype=None):
+ def __new__(cls, input_array, input_units=None, registry=None, dtype=None,
+ bypass_validation=False):
if dtype is None:
dtype = getattr(input_array, 'dtype', np.float64)
+ if bypass_validation is True:
+ obj = np.asarray(input_array, dtype=dtype).view(cls)
+ obj.units = input_units
+ if registry is not None:
+ obj.units.registry = registry
+ return obj
if input_array is NotImplemented:
return input_array
if registry is None and isinstance(input_units, (str, bytes)):
@@ -911,26 +935,13 @@
# @todo: outsource to a single method with an op argument.
def __lt__(self, other):
""" Test if this is less than the object on the right. """
- # Check that other is a YTArray.
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError('less than', self.units, other.units)
-
- return np.array(self).__lt__(np.array(other.in_units(self.units)))
-
- return np.array(self).__lt__(np.array(other))
+ oth = validate_comparison_units(self, other, 'less_than')
+ return np.array(self).__lt__(np.array(oth))
def __le__(self, other):
""" Test if this is less than or equal to the object on the right. """
- # Check that other is a YTArray.
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError('less than or equal', self.units,
- other.units)
-
- return np.array(self).__le__(np.array(other.in_units(self.units)))
-
- return np.array(self).__le__(np.array(other))
+ oth = validate_comparison_units(self, other, 'less_than or equal')
+ return np.array(self).__le__(np.array(oth))
def __eq__(self, other):
""" Test if this is equal to the object on the right. """
@@ -938,50 +949,28 @@
if other is None:
# self is a YTArray, so it can't be None.
return False
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError("equal", self.units, other.units)
-
- return np.array(self).__eq__(np.array(other.in_units(self.units)))
-
- return np.array(self).__eq__(np.array(other))
+ oth = validate_comparison_units(self, other, 'equal')
+ return np.array(self).__eq__(np.array(oth))
def __ne__(self, other):
""" Test if this is not equal to the object on the right. """
# Check that the other is a YTArray.
if other is None:
return True
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError("not equal", self.units, other.units)
-
- return np.array(self).__ne__(np.array(other.in_units(self.units)))
-
- return np.array(self).__ne__(np.array(other))
+ oth = validate_comparison_units(self, other, 'not equal')
+ return np.array(self).__ne__(np.array(oth))
def __ge__(self, other):
""" Test if this is greater than or equal to other. """
# Check that the other is a YTArray.
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError("greater than or equal",
- self.units, other.units)
-
- return np.array(self).__ge__(np.array(other.in_units(self.units)))
-
- return np.array(self).__ge__(np.array(other))
+ oth = validate_comparison_units(self, other, 'greater than or equal')
+ return np.array(self).__ge__(np.array(oth))
def __gt__(self, other):
""" Test if this is greater than the object on the right. """
# Check that the other is a YTArray.
- if isinstance(other, YTArray):
- if not self.units.same_dimensions_as(other.units):
- raise YTUnitOperationError("greater than", self.units,
- other.units)
-
- return np.array(self).__gt__(np.array(other.in_units(self.units)))
-
- return np.array(self).__gt__(np.array(other))
+ oth = validate_comparison_units(self, other, 'greater than')
+ return np.array(self).__gt__(np.array(oth))
#
# End comparison operators
@@ -1018,7 +1007,7 @@
def __getitem__(self, item):
ret = super(YTArray, self).__getitem__(item)
if ret.shape == ():
- return YTQuantity(ret, self.units)
+ return YTQuantity(ret, self.units, bypass_validation=True)
else:
return ret
@@ -1185,11 +1174,11 @@
"""
def __new__(cls, input_scalar, input_units=None, registry=None,
- dtype=np.float64):
+ dtype=np.float64, bypass_validation=False):
if not isinstance(input_scalar, (numeric_type, np.number, np.ndarray)):
raise RuntimeError("YTQuantity values must be numeric")
ret = YTArray.__new__(cls, input_scalar, input_units, registry,
- dtype=dtype)
+ dtype=dtype, bypass_validation=bypass_validation)
if ret.size > 1:
raise RuntimeError("YTQuantity instances must be scalars")
return ret
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list