[yt-svn] commit/yt: 9 new changesets

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Mon Jun 26 12:50:12 PDT 2017


9 new commits in yt:

https://bitbucket.org/yt_analysis/yt/commits/48687b565a4b/
Changeset:   48687b565a4b
User:        ngoldbaum
Date:        2017-06-22 22:07:58+00:00
Summary:     fix issues with adding dimensionless data to data with units
Affected #:  2 files

diff -r b5fa47f8c92aaee458b551ee225fb6d22b3d5cd2 -r 48687b565a4b0d91acf5b33e97fde92f8fb76ecb yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -119,9 +119,19 @@
     # Catch the different dimensions error
     a1 = YTArray([1, 2, 3], 'm')
     a2 = YTArray([4, 5, 6], 'kg')
+    a3 = [7, 8, 9]
+    a4 = YTArray([10, 11, 12], '')
 
     assert_raises(YTUnitOperationError, operator.add, a1, a2)
     assert_raises(YTUnitOperationError, operator.iadd, a1, a2)
+    assert_raises(YTUnitOperationError, operator.add, a1, a3)
+    assert_raises(YTUnitOperationError, operator.iadd, a1, a3)
+    assert_raises(YTUnitOperationError, operator.add, a3, a1)
+    assert_raises(YTUnitOperationError, operator.iadd, a3, a1)
+    assert_raises(YTUnitOperationError, operator.add, a1, a4)
+    assert_raises(YTUnitOperationError, operator.iadd, a1, a4)
+    assert_raises(YTUnitOperationError, operator.add, a4, a1)
+    assert_raises(YTUnitOperationError, operator.iadd, a4, a1)
 
     # adding with zero is allowed irrespective of the units
     zeros = np.zeros(3)
@@ -200,9 +210,19 @@
     # Catch the different dimensions error
     a1 = YTArray([1, 2, 3], 'm')
     a2 = YTArray([4, 5, 6], 'kg')
+    a3 = [7, 8, 9]
+    a4 = YTArray([10, 11, 12], '')
 
     assert_raises(YTUnitOperationError, operator.sub, a1, a2)
     assert_raises(YTUnitOperationError, operator.isub, a1, a2)
+    assert_raises(YTUnitOperationError, operator.sub, a1, a3)
+    assert_raises(YTUnitOperationError, operator.isub, a1, a3)
+    assert_raises(YTUnitOperationError, operator.sub, a3, a1)
+    assert_raises(YTUnitOperationError, operator.isub, a3, a1)
+    assert_raises(YTUnitOperationError, operator.sub, a1, a4)
+    assert_raises(YTUnitOperationError, operator.isub, a1, a4)
+    assert_raises(YTUnitOperationError, operator.sub, a4, a1)
+    assert_raises(YTUnitOperationError, operator.isub, a4, a1)
 
     # subtracting with zero is allowed irrespective of the units
     zeros = np.zeros(3)

diff -r b5fa47f8c92aaee458b551ee225fb6d22b3d5cd2 -r 48687b565a4b0d91acf5b33e97fde92f8fb76ecb yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -156,9 +156,21 @@
             unit2 = 1.0
     return (inp1, inp2), (unit1, unit2), ret_class
 
-def handle_preserve_units(inps, units, ufunc, ret_class, raise_error=False):
-    # Allow comparisons, addition, and subtraction with
-    # dimensionless quantities or arrays filled with zeros.
+def handle_preserve_units(inps, units, ufunc, ret_class):
+    if units[0] != units[1]:
+        any_nonzero = [np.any(inps[0]), np.any(inps[1])]
+        if any_nonzero[0] == np.bool_(False):
+            units = (units[1], units[1])
+        elif any_nonzero[1] == np.bool_(False):
+            units = (units[0], units[0])
+        else:
+            if not units[0].same_dimensions_as(units[1]):
+                raise YTUnitOperationError(ufunc, *units)
+            inps = (inps[0], ret_class(inps[1]).to(
+                ret_class(inps[0]).units))
+    return inps, units
+
+def handle_comparison_units(inps, units, ufunc, ret_class, raise_error=False):
     if units[0] != units[1]:
         u1d = units[0].is_dimensionless
         u2d = units[1].is_dimensionless
@@ -1243,7 +1255,7 @@
                 inps, units, ret_class = get_inp_u_binary(ufunc, inputs)
                 if unit_operator in (preserve_units, comparison_unit,
                                      arctan2_unit):
-                    inps, units = handle_preserve_units(
+                    inps, units = handle_comparison_units(
                         inps, units, ufunc, ret_class, raise_error=True)
                 unit = unit_operator(*units)
                 if unit_operator in (multiply_units, divide_units):
@@ -1291,10 +1303,12 @@
             elif len(inputs) == 2:
                 unit_operator = self._ufunc_registry[ufunc]
                 inps, units, ret_class = get_inp_u_binary(ufunc, inputs)
-                if unit_operator in (preserve_units, comparison_unit,
-                                     arctan2_unit):
+                if unit_operator in (comparison_unit, arctan2_unit):
+                    inps, units = handle_comparison_units(
+                        inps, units, ufunc, ret_class)
+                elif unit_operator is preserve_units:
                     inps, units = handle_preserve_units(
-                        inps, units, ufunc, ret_class)
+                         inps, units, ufunc, ret_class)
                 unit = unit_operator(*units)
                 out_arr = func(np.asarray(inps[0]), np.asarray(inps[1]),
                                out=out, **kwargs)


https://bitbucket.org/yt_analysis/yt/commits/0c586360554c/
Changeset:   0c586360554c
User:        ngoldbaum
Date:        2017-06-22 23:34:33+00:00
Summary:     fix issues in the ply exporter revealed by previous patch
Affected #:  1 file

diff -r 48687b565a4b0d91acf5b33e97fde92f8fb76ecb -r 0c586360554c6a0dc047e6dee70293aae95f05ed yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -62,6 +62,7 @@
 from yt.fields.field_exceptions import \
     NeedsOriginalGrid
 from yt.frontends.stream.api import load_uniform_grid
+from yt.units.yt_array import YTArray
 import yt.extern.six as six
 
 class YTStreamline(YTSelectionContainer1D):
@@ -1781,6 +1782,11 @@
             DLE = self.ds.domain_left_edge
             DRE = self.ds.domain_right_edge
             bounds = [(DLE[i], DRE[i]) for i in range(3)]
+        elif any([not all([isinstance(be, YTArray) for be in b])
+                  for b in bounds]):
+            bounds = [tuple(be if isinstance(be, YTArray) else
+                            self.ds.quan(be, 'code_length') for be in b)
+                      for b in bounds]
         nv = self.vertices.shape[1]
         vs = [("x", "<f"), ("y", "<f"), ("z", "<f"),
               ("red", "uint8"), ("green", "uint8"), ("blue", "uint8") ]


https://bitbucket.org/yt_analysis/yt/commits/0b76a25d6e2f/
Changeset:   0b76a25d6e2f
User:        ngoldbaum
Date:        2017-06-22 23:37:35+00:00
Summary:     fix units issue in projection data object revealed by units fix
Affected #:  1 file

diff -r 0c586360554c6a0dc047e6dee70293aae95f05ed -r 0b76a25d6e2f777a4bbcdeb7d0db89b46a1b9bfb yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -354,8 +354,8 @@
         # TODO: Add the combine operation
         xax = self.ds.coordinates.x_axis[self.axis]
         yax = self.ds.coordinates.y_axis[self.axis]
-        ox = self.ds.domain_left_edge[xax]
-        oy = self.ds.domain_left_edge[yax]
+        ox = self.ds.domain_left_edge[xax].v
+        oy = self.ds.domain_left_edge[yax].v
         px, py, pdx, pdy, nvals, nwvals = tree.get_all(False, merge_style)
         nvals = self.comm.mpi_allreduce(nvals, op=op)
         nwvals = self.comm.mpi_allreduce(nwvals, op=op)


https://bitbucket.org/yt_analysis/yt/commits/adf3d0c62fc3/
Changeset:   adf3d0c62fc3
User:        ngoldbaum
Date:        2017-06-23 21:00:13+00:00
Summary:     fix more tests
Affected #:  2 files

diff -r 0b76a25d6e2f777a4bbcdeb7d0db89b46a1b9bfb -r adf3d0c62fc35841184a3dbb47115ee29aea8d68 yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
--- a/yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
+++ b/yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
@@ -365,7 +365,7 @@
                                    this_wavelength[lixel]), \
                               continuum['index']) * \
                     (column_density[lixel] / continuum['normalization'])
-                self.tau_field[left_index[lixel]:right_index[lixel]] += cont_tau
+                self.tau_field[left_index[lixel]:right_index[lixel]] += cont_tau.d
                 pbar.update(i)
             pbar.finish()
 

diff -r 0b76a25d6e2f777a4bbcdeb7d0db89b46a1b9bfb -r adf3d0c62fc35841184a3dbb47115ee29aea8d68 yt/analysis_modules/halo_finding/halo_objects.py
--- a/yt/analysis_modules/halo_finding/halo_objects.py
+++ b/yt/analysis_modules/halo_finding/halo_objects.py
@@ -1281,7 +1281,7 @@
 
         def haloCmp(h1, h2):
             def cmp(a, b):
-                return (a > b) - (a < b)
+                return (a > b) ^ (a < b)
             c = cmp(h1.total_mass(), h2.total_mass())
             if c != 0:
                 return -1 * c


https://bitbucket.org/yt_analysis/yt/commits/c9aa2ef75915/
Changeset:   c9aa2ef75915
User:        ngoldbaum
Date:        2017-06-23 21:05:32+00:00
Summary:     bump unstructured mesh answer numbers
Affected #:  1 file

diff -r adf3d0c62fc35841184a3dbb47115ee29aea8d68 -r c9aa2ef759150f47d057eb21886f9b7faa290920 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -64,7 +64,7 @@
     - yt/analysis_modules/photon_simulator/tests/test_spectra.py
     - yt/analysis_modules/photon_simulator/tests/test_sloshing.py
 
-  local_unstructured_005:
+  local_unstructured_006:
     - yt/visualization/volume_rendering/tests/test_mesh_render.py
     - yt/visualization/tests/test_mesh_slices.py:test_tri2
     - yt/visualization/tests/test_mesh_slices.py:test_quad2


https://bitbucket.org/yt_analysis/yt/commits/54918985df0f/
Changeset:   54918985df0f
User:        ngoldbaum
Date:        2017-06-25 22:02:12+00:00
Summary:     fix more test failures
Affected #:  2 files

diff -r c9aa2ef759150f47d057eb21886f9b7faa290920 -r 54918985df0fb9a024128529f9c244474c4518ed yt/analysis_modules/photon_simulator/photon_models.py
--- a/yt/analysis_modules/photon_simulator/photon_models.py
+++ b/yt/analysis_modules/photon_simulator/photon_models.py
@@ -224,7 +224,7 @@
                         tot_spec *= norm_factor
                         eidxs = self.prng.choice(nchan, size=cn, p=tot_spec)
                         cell_e = emid[eidxs]
-                    energies[ei:ei+cn] = cell_e
+                    energies[int(ei):int(ei + cn)] = cell_e
                     cell_counter += 1
                     pbar.update(cell_counter)
                     ei += cn

diff -r c9aa2ef759150f47d057eb21886f9b7faa290920 -r 54918985df0fb9a024128529f9c244474c4518ed yt/frontends/owls_subfind/io.py
--- a/yt/frontends/owls_subfind/io.py
+++ b/yt/frontends/owls_subfind/io.py
@@ -105,7 +105,7 @@
                                 field_data = f[ptype][fname].value.astype("float64")
                                 my_div = field_data.size / pcount
                                 if my_div > 1:
-                                    field_data = np.resize(field_data, (pcount, my_div))
+                                    field_data = np.resize(field_data, (int(pcount), int(my_div)))
                                     findex = int(field[field.rfind("_") + 1:])
                                     field_data = field_data[:, findex]
                         data = field_data[mask]


https://bitbucket.org/yt_analysis/yt/commits/cf3bac3a44ac/
Changeset:   cf3bac3a44ac
User:        ngoldbaum
Date:        2017-06-25 23:17:07+00:00
Summary:     attempt fix photon simulator test failure
Affected #:  1 file

diff -r 54918985df0fb9a024128529f9c244474c4518ed -r cf3bac3a44acb440e1068082fe0e0dfacf17f2a8 yt/analysis_modules/photon_simulator/photon_simulator.py
--- a/yt/analysis_modules/photon_simulator/photon_simulator.py
+++ b/yt/analysis_modules/photon_simulator/photon_simulator.py
@@ -678,9 +678,9 @@
             x *= delta
             y *= delta
             z *= delta
-            x += self.photons["x"][obs_cells]
-            y += self.photons["y"][obs_cells]
-            z += self.photons["z"][obs_cells]
+            x += self.photons["x"][obs_cells].d
+            y += self.photons["y"][obs_cells].d
+            z += self.photons["z"][obs_cells].d
 
             xsky = x*x_hat[0] + y*x_hat[1] + z*x_hat[2]
             ysky = x*y_hat[0] + y*y_hat[1] + z*y_hat[2]


https://bitbucket.org/yt_analysis/yt/commits/abbb0b239edd/
Changeset:   abbb0b239edd
User:        ngoldbaum
Date:        2017-06-26 01:42:51+00:00
Summary:     bump PlotWindow answer number as well
Affected #:  1 file

diff -r cf3bac3a44acb440e1068082fe0e0dfacf17f2a8 -r abbb0b239edd49d9e6935a4d644c1efeaeccfe98 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -42,7 +42,7 @@
   local_owls_001:
     - yt/frontends/owls/tests/test_outputs.py
 
-  local_pw_015:
+  local_pw_016:
     - yt/visualization/tests/test_plotwindow.py:test_attributes
     - yt/visualization/tests/test_plotwindow.py:test_attributes_wt
     - yt/visualization/tests/test_profile_plots.py:test_phase_plot_attributes


https://bitbucket.org/yt_analysis/yt/commits/fc58054806fe/
Changeset:   fc58054806fe
User:        xarthisius
Date:        2017-06-26 19:49:36+00:00
Summary:     Merge pull request #1466 from ngoldbaum/units-fix

fix issues with adding dimensionless data to data with units
Affected #:  9 files

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 tests/tests.yaml
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -42,7 +42,7 @@
   local_owls_001:
     - yt/frontends/owls/tests/test_outputs.py
 
-  local_pw_015:
+  local_pw_016:
     - yt/visualization/tests/test_plotwindow.py:test_attributes
     - yt/visualization/tests/test_plotwindow.py:test_attributes_wt
     - yt/visualization/tests/test_profile_plots.py:test_phase_plot_attributes
@@ -64,7 +64,7 @@
     - yt/analysis_modules/photon_simulator/tests/test_spectra.py
     - yt/analysis_modules/photon_simulator/tests/test_sloshing.py
 
-  local_unstructured_005:
+  local_unstructured_006:
     - yt/visualization/volume_rendering/tests/test_mesh_render.py
     - yt/visualization/tests/test_mesh_slices.py:test_tri2
     - yt/visualization/tests/test_mesh_slices.py:test_quad2

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
--- a/yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
+++ b/yt/analysis_modules/absorption_spectrum/absorption_spectrum.py
@@ -365,7 +365,7 @@
                                    this_wavelength[lixel]), \
                               continuum['index']) * \
                     (column_density[lixel] / continuum['normalization'])
-                self.tau_field[left_index[lixel]:right_index[lixel]] += cont_tau
+                self.tau_field[left_index[lixel]:right_index[lixel]] += cont_tau.d
                 pbar.update(i)
             pbar.finish()
 

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/analysis_modules/halo_finding/halo_objects.py
--- a/yt/analysis_modules/halo_finding/halo_objects.py
+++ b/yt/analysis_modules/halo_finding/halo_objects.py
@@ -1281,7 +1281,7 @@
 
         def haloCmp(h1, h2):
             def cmp(a, b):
-                return (a > b) - (a < b)
+                return (a > b) ^ (a < b)
             c = cmp(h1.total_mass(), h2.total_mass())
             if c != 0:
                 return -1 * c

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/analysis_modules/photon_simulator/photon_models.py
--- a/yt/analysis_modules/photon_simulator/photon_models.py
+++ b/yt/analysis_modules/photon_simulator/photon_models.py
@@ -224,7 +224,7 @@
                         tot_spec *= norm_factor
                         eidxs = self.prng.choice(nchan, size=cn, p=tot_spec)
                         cell_e = emid[eidxs]
-                    energies[ei:ei+cn] = cell_e
+                    energies[int(ei):int(ei + cn)] = cell_e
                     cell_counter += 1
                     pbar.update(cell_counter)
                     ei += cn

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/analysis_modules/photon_simulator/photon_simulator.py
--- a/yt/analysis_modules/photon_simulator/photon_simulator.py
+++ b/yt/analysis_modules/photon_simulator/photon_simulator.py
@@ -678,9 +678,9 @@
             x *= delta
             y *= delta
             z *= delta
-            x += self.photons["x"][obs_cells]
-            y += self.photons["y"][obs_cells]
-            z += self.photons["z"][obs_cells]
+            x += self.photons["x"][obs_cells].d
+            y += self.photons["y"][obs_cells].d
+            z += self.photons["z"][obs_cells].d
 
             xsky = x*x_hat[0] + y*x_hat[1] + z*x_hat[2]
             ysky = x*y_hat[0] + y*y_hat[1] + z*y_hat[2]

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/data_objects/construction_data_containers.py
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -62,6 +62,7 @@
 from yt.fields.field_exceptions import \
     NeedsOriginalGrid
 from yt.frontends.stream.api import load_uniform_grid
+from yt.units.yt_array import YTArray
 import yt.extern.six as six
 
 class YTStreamline(YTSelectionContainer1D):
@@ -353,8 +354,8 @@
         # TODO: Add the combine operation
         xax = self.ds.coordinates.x_axis[self.axis]
         yax = self.ds.coordinates.y_axis[self.axis]
-        ox = self.ds.domain_left_edge[xax]
-        oy = self.ds.domain_left_edge[yax]
+        ox = self.ds.domain_left_edge[xax].v
+        oy = self.ds.domain_left_edge[yax].v
         px, py, pdx, pdy, nvals, nwvals = tree.get_all(False, merge_style)
         nvals = self.comm.mpi_allreduce(nvals, op=op)
         nwvals = self.comm.mpi_allreduce(nwvals, op=op)
@@ -1781,6 +1782,11 @@
             DLE = self.ds.domain_left_edge
             DRE = self.ds.domain_right_edge
             bounds = [(DLE[i], DRE[i]) for i in range(3)]
+        elif any([not all([isinstance(be, YTArray) for be in b])
+                  for b in bounds]):
+            bounds = [tuple(be if isinstance(be, YTArray) else
+                            self.ds.quan(be, 'code_length') for be in b)
+                      for b in bounds]
         nv = self.vertices.shape[1]
         vs = [("x", "<f"), ("y", "<f"), ("z", "<f"),
               ("red", "uint8"), ("green", "uint8"), ("blue", "uint8") ]

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/frontends/owls_subfind/io.py
--- a/yt/frontends/owls_subfind/io.py
+++ b/yt/frontends/owls_subfind/io.py
@@ -105,7 +105,7 @@
                                 field_data = f[ptype][fname].value.astype("float64")
                                 my_div = field_data.size / pcount
                                 if my_div > 1:
-                                    field_data = np.resize(field_data, (pcount, my_div))
+                                    field_data = np.resize(field_data, (int(pcount), int(my_div)))
                                     findex = int(field[field.rfind("_") + 1:])
                                     field_data = field_data[:, findex]
                         data = field_data[mask]

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/units/tests/test_ytarray.py
--- a/yt/units/tests/test_ytarray.py
+++ b/yt/units/tests/test_ytarray.py
@@ -119,9 +119,19 @@
     # Catch the different dimensions error
     a1 = YTArray([1, 2, 3], 'm')
     a2 = YTArray([4, 5, 6], 'kg')
+    a3 = [7, 8, 9]
+    a4 = YTArray([10, 11, 12], '')
 
     assert_raises(YTUnitOperationError, operator.add, a1, a2)
     assert_raises(YTUnitOperationError, operator.iadd, a1, a2)
+    assert_raises(YTUnitOperationError, operator.add, a1, a3)
+    assert_raises(YTUnitOperationError, operator.iadd, a1, a3)
+    assert_raises(YTUnitOperationError, operator.add, a3, a1)
+    assert_raises(YTUnitOperationError, operator.iadd, a3, a1)
+    assert_raises(YTUnitOperationError, operator.add, a1, a4)
+    assert_raises(YTUnitOperationError, operator.iadd, a1, a4)
+    assert_raises(YTUnitOperationError, operator.add, a4, a1)
+    assert_raises(YTUnitOperationError, operator.iadd, a4, a1)
 
     # adding with zero is allowed irrespective of the units
     zeros = np.zeros(3)
@@ -200,9 +210,19 @@
     # Catch the different dimensions error
     a1 = YTArray([1, 2, 3], 'm')
     a2 = YTArray([4, 5, 6], 'kg')
+    a3 = [7, 8, 9]
+    a4 = YTArray([10, 11, 12], '')
 
     assert_raises(YTUnitOperationError, operator.sub, a1, a2)
     assert_raises(YTUnitOperationError, operator.isub, a1, a2)
+    assert_raises(YTUnitOperationError, operator.sub, a1, a3)
+    assert_raises(YTUnitOperationError, operator.isub, a1, a3)
+    assert_raises(YTUnitOperationError, operator.sub, a3, a1)
+    assert_raises(YTUnitOperationError, operator.isub, a3, a1)
+    assert_raises(YTUnitOperationError, operator.sub, a1, a4)
+    assert_raises(YTUnitOperationError, operator.isub, a1, a4)
+    assert_raises(YTUnitOperationError, operator.sub, a4, a1)
+    assert_raises(YTUnitOperationError, operator.isub, a4, a1)
 
     # subtracting with zero is allowed irrespective of the units
     zeros = np.zeros(3)

diff -r 2fc4b8bc2f00c540b11c37bd82aefb39ab1537a0 -r fc58054806fe06993c0fc585939e8fd48552c117 yt/units/yt_array.py
--- a/yt/units/yt_array.py
+++ b/yt/units/yt_array.py
@@ -156,9 +156,21 @@
             unit2 = 1.0
     return (inp1, inp2), (unit1, unit2), ret_class
 
-def handle_preserve_units(inps, units, ufunc, ret_class, raise_error=False):
-    # Allow comparisons, addition, and subtraction with
-    # dimensionless quantities or arrays filled with zeros.
+def handle_preserve_units(inps, units, ufunc, ret_class):
+    if units[0] != units[1]:
+        any_nonzero = [np.any(inps[0]), np.any(inps[1])]
+        if any_nonzero[0] == np.bool_(False):
+            units = (units[1], units[1])
+        elif any_nonzero[1] == np.bool_(False):
+            units = (units[0], units[0])
+        else:
+            if not units[0].same_dimensions_as(units[1]):
+                raise YTUnitOperationError(ufunc, *units)
+            inps = (inps[0], ret_class(inps[1]).to(
+                ret_class(inps[0]).units))
+    return inps, units
+
+def handle_comparison_units(inps, units, ufunc, ret_class, raise_error=False):
     if units[0] != units[1]:
         u1d = units[0].is_dimensionless
         u2d = units[1].is_dimensionless
@@ -1243,7 +1255,7 @@
                 inps, units, ret_class = get_inp_u_binary(ufunc, inputs)
                 if unit_operator in (preserve_units, comparison_unit,
                                      arctan2_unit):
-                    inps, units = handle_preserve_units(
+                    inps, units = handle_comparison_units(
                         inps, units, ufunc, ret_class, raise_error=True)
                 unit = unit_operator(*units)
                 if unit_operator in (multiply_units, divide_units):
@@ -1291,10 +1303,12 @@
             elif len(inputs) == 2:
                 unit_operator = self._ufunc_registry[ufunc]
                 inps, units, ret_class = get_inp_u_binary(ufunc, inputs)
-                if unit_operator in (preserve_units, comparison_unit,
-                                     arctan2_unit):
+                if unit_operator in (comparison_unit, arctan2_unit):
+                    inps, units = handle_comparison_units(
+                        inps, units, ufunc, ret_class)
+                elif unit_operator is preserve_units:
                     inps, units = handle_preserve_units(
-                        inps, units, ufunc, ret_class)
+                         inps, units, ufunc, ret_class)
                 unit = unit_operator(*units)
                 out_arr = func(np.asarray(inps[0]), np.asarray(inps[1]),
                                out=out, **kwargs)

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled and it is configured to send
notifications to this email address.


More information about the yt-svn mailing list