[yt-svn] commit/yt: MatthewTurk: Merged in ngoldbaum/yt (pull request #1687)

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Thu Aug 13 09:26:59 PDT 2015


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/a46e96ef9fcb/
Changeset:   a46e96ef9fcb
Branch:      yt
User:        MatthewTurk
Date:        2015-08-13 16:26:47+00:00
Summary:     Merged in ngoldbaum/yt (pull request #1687)

[perf] [enhancement] Speedups for iterating over YTArray instances
Affected #:  2 files

diff -r 475280b53ae10379ad79ac02519763085ec634c8 -r a46e96ef9fcbbcd1d4ebccae406938d166212f2a yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -554,6 +554,17 @@
     yield assert_true, a_slice.base is a
 
 
+def test_iteration():
+    """
+    Test that iterating over a YTArray returns a sequence of YTQuantity instances
+    """
+    a = np.arange(3)
+    b = YTArray(np.arange(3), 'cm')
+    for ia, ib, in zip(a, b):
+        yield assert_equal, ia, ib.value
+        yield assert_equal, ib.units, b.units
+
+
 def test_fix_length():
     """
     Test fixing the length of an array. Used in spheres and other data objects

diff -r 475280b53ae10379ad79ac02519763085ec634c8 -r a46e96ef9fcbbcd1d4ebccae406938d166212f2a yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -135,6 +135,17 @@
             raise YTUnitOperationError(op_string, inp.units, dimensionless)
     return ret
 
+def validate_comparison_units(this, other, op_string):
+    # Check that other is a YTArray.
+    if hasattr(other, 'units'):
+        if this.units.expr is other.units.expr:
+            return other
+        if not this.units.same_dimensions_as(other.units):
+            raise YTUnitOperationError(op_string, this.units, other.units)
+        return other.in_units(this.units)
+
+    return other
+
 unary_operators = (
     negative, absolute, rint, ones_like, sign, conj, exp, exp2, log, log2,
     log10, expm1, log1p, sqrt, square, reciprocal, sin, cos, tan, arcsin,
@@ -167,7 +178,13 @@
         with a unit registry and this is specified, this will be used instead of
         the registry associated with the unit object.
     dtype : string or NumPy dtype object
-        The dtype of the array data.
+        The dtype of the array data. Defaults to the dtype of the input data,
+        or, if none is found, uses np.float64
+    bypass_validation : boolean
+        If True, all input validation is skipped. Using this option may produce
+        corrupted, invalid units or array data, but can lead to significant
+        speedups if the input validation logic adds significant overhead. If set,
+        input_units *must* be a valid unit object. Defaults to False.
 
     Examples
     --------
@@ -286,9 +303,16 @@
 
     __array_priority__ = 2.0
 
-    def __new__(cls, input_array, input_units=None, registry=None, dtype=None):
+    def __new__(cls, input_array, input_units=None, registry=None, dtype=None,
+                bypass_validation=False):
         if dtype is None:
             dtype = getattr(input_array, 'dtype', np.float64)
+        if bypass_validation is True:
+            obj = np.asarray(input_array, dtype=dtype).view(cls)
+            obj.units = input_units
+            if registry is not None:
+                obj.units.registry = registry
+            return obj
         if input_array is NotImplemented:
             return input_array
         if registry is None and isinstance(input_units, (str, bytes)):
@@ -911,26 +935,13 @@
     # @todo: outsource to a single method with an op argument.
     def __lt__(self, other):
         """ Test if this is less than the object on the right. """
-        # Check that other is a YTArray.
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError('less than', self.units, other.units)
-
-            return np.array(self).__lt__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__lt__(np.array(other))
+        oth = validate_comparison_units(self, other, 'less_than')
+        return np.array(self).__lt__(np.array(oth))
 
     def __le__(self, other):
         """ Test if this is less than or equal to the object on the right. """
-        # Check that other is a YTArray.
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError('less than or equal', self.units,
-                                           other.units)
-
-            return np.array(self).__le__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__le__(np.array(other))
+        oth = validate_comparison_units(self, other, 'less_than or equal')
+        return np.array(self).__le__(np.array(oth))
 
     def __eq__(self, other):
         """ Test if this is equal to the object on the right. """
@@ -938,50 +949,28 @@
         if other is None:
             # self is a YTArray, so it can't be None.
             return False
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError("equal", self.units, other.units)
-
-            return np.array(self).__eq__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__eq__(np.array(other))
+        oth = validate_comparison_units(self, other, 'equal')
+        return np.array(self).__eq__(np.array(oth))
 
     def __ne__(self, other):
         """ Test if this is not equal to the object on the right. """
         # Check that the other is a YTArray.
         if other is None:
             return True
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError("not equal", self.units, other.units)
-
-            return np.array(self).__ne__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__ne__(np.array(other))
+        oth = validate_comparison_units(self, other, 'not equal')
+        return np.array(self).__ne__(np.array(oth))
 
     def __ge__(self, other):
         """ Test if this is greater than or equal to other. """
         # Check that the other is a YTArray.
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError("greater than or equal",
-                                           self.units, other.units)
-
-            return np.array(self).__ge__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__ge__(np.array(other))
+        oth = validate_comparison_units(self, other, 'greater than or equal')
+        return np.array(self).__ge__(np.array(oth))
 
     def __gt__(self, other):
         """ Test if this is greater than the object on the right. """
         # Check that the other is a YTArray.
-        if isinstance(other, YTArray):
-            if not self.units.same_dimensions_as(other.units):
-                raise YTUnitOperationError("greater than", self.units,
-                                           other.units)
-
-            return np.array(self).__gt__(np.array(other.in_units(self.units)))
-
-        return np.array(self).__gt__(np.array(other))
+        oth = validate_comparison_units(self, other, 'greater than')
+        return np.array(self).__gt__(np.array(oth))
 
     #
     # End comparison operators
@@ -1018,7 +1007,7 @@
     def __getitem__(self, item):
         ret = super(YTArray, self).__getitem__(item)
         if ret.shape == ():
-            return YTQuantity(ret, self.units)
+            return YTQuantity(ret, self.units, bypass_validation=True)
         else:
             return ret
 
@@ -1185,11 +1174,11 @@
 
     """
     def __new__(cls, input_scalar, input_units=None, registry=None,
-                dtype=np.float64):
+                dtype=np.float64, bypass_validation=False):
         if not isinstance(input_scalar, (numeric_type, np.number, np.ndarray)):
             raise RuntimeError("YTQuantity values must be numeric")
         ret = YTArray.__new__(cls, input_scalar, input_units, registry,
-                              dtype=dtype)
+                              dtype=dtype, bypass_validation=bypass_validation)
         if ret.size > 1:
             raise RuntimeError("YTQuantity instances must be scalars")
         return ret

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.


More information about the yt-svn mailing list