[yt-svn] commit/yt: brittonsmith: Merged in ngoldbaum/yt (pull request #1950)

commits-noreply at bitbucket.org
Thu Jan 28 09:08:19 PST 2016


1 new commit in yt:

https://bitbucket.org/yt_analysis/yt/commits/e14103b2c595/
Changeset:   e14103b2c595
Branch:      yt
User:        brittonsmith
Date:        2016-01-28 17:08:14+00:00
Summary:     Merged in ngoldbaum/yt (pull request #1950)

Fix a number of issues with ds.find_field_values_at_point[s]. Closes #1161
Affected #:  3 files
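For context, a minimal usage sketch of the two convenience methods this change touches (the dataset path is the IsolatedGalaxy sample used in the new tests; the coordinates and field names are illustrative only):

    import yt

    ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")

    # one field at one (x, y, z) point -> the 'density' value there
    den = ds.find_field_values_at_point('density', ds.domain_center)

    # several fields at several points -> one array per field
    points = [[0.25, 0.25, 0.25], [0.75, 0.75, 0.75]]
    den, vel = ds.find_field_values_at_points(
        ['density', 'velocity_x'], points)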

diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/data_objects/static_output.py
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -665,7 +665,15 @@
         coordinates. Returns a list of field values in the same order as
         the input *fields*.
         """
-        return self.point(coords)[fields]
+        point = self.point(coords)
+        ret = []
+        field_list = ensure_list(fields)
+        for field in field_list:
+            ret.append(point[field])
+        if len(field_list) == 1:
+            return ret[0]
+        else:
+            return ret
 
     def find_field_values_at_points(self, fields, coords):
         """
@@ -673,19 +681,26 @@
         [(x1, y1, z1), (x2, y2, z2),...] points.  Returns a list of field
         values in the same order as the input *fields*.
 
-        This is quite slow right now as it creates a new data object for each
-        point.  If an optimized version exists on the Index object we'll use
-        that instead.
         """
-        if hasattr(self,"index") and \
-                hasattr(self.index,"_find_field_values_at_points"):
-            return self.index._find_field_values_at_points(fields,coords)
+        # If an optimized version exists on the Index object we'll use that
+        try:
+            return self.index._find_field_values_at_points(fields, coords)
+        except AttributeError:
+            pass
 
         fields = ensure_list(fields)
-        out = np.zeros((len(fields),len(coords)), dtype=np.float64)
-        for i,coord in enumerate(coords):
-            out[:][i] = self.point(coord)[fields]
-        return out
+        out = []
+
+        # This may be slow because it creates a data object for each point
+        for field_index, field in enumerate(fields):
+            funit = self._get_field_info(field).units
+            out.append(self.arr(np.empty((len(coords),)), funit))
+            for coord_index, coord in enumerate(coords):
+                out[field_index][coord_index] = self.point(coord)[field]
+        if len(fields) == 1:
+            return out[0]
+        else:
+            return out
 
     # Now all the object related stuff
     def all_data(self, find_max=False, **kwargs):
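The switch above from a hasattr check to try/except AttributeError is the EAFP idiom for delegating to an optimized index-level implementation when one exists. A standalone sketch of the dispatch pattern (SlowIndex, FastIndex, and the constant return values are hypothetical, purely for illustration):

    class SlowIndex:
        pass  # no optimized point lookup available

    class FastIndex:
        def _find_field_values_at_points(self, fields, coords):
            # stand-in for a vectorized, per-grid lookup
            return [0.0] * len(fields)

    def find_values(index, fields, coords):
        # prefer the index's optimized method if it provides one
        try:
            return index._find_field_values_at_points(fields, coords)
        except AttributeError:
            pass
        # otherwise fall back to the generic, per-point path
        return [None] * len(fields)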

diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/data_objects/tests/test_points.py
--- a/yt/data_objects/tests/test_points.py
+++ b/yt/data_objects/tests/test_points.py
@@ -1,10 +1,66 @@
-from yt.testing import fake_random_ds
+import numpy as np
+import yt
+
+from yt.testing import \
+    fake_random_ds, \
+    assert_equal, \
+    requires_file
 
 def setup():
     from yt.config import ytcfg
     ytcfg["yt","__withintesting"] = "True"
 
 def test_domain_point():
-    ds = fake_random_ds(16)
+    nparticles = 3
+    ds = fake_random_ds(16, particles=nparticles)
     p = ds.point(ds.domain_center)
-    p['density']
+
+    # ensure accessing one field works, store for comparison later
+    point_den = p['density']
+    point_vel = p['velocity_x']
+
+    ad = ds.all_data()
+    ppos = ad['all', 'particle_position']
+
+    fpoint_den = ds.find_field_values_at_point('density', ds.domain_center)
+
+    fpoint_den_vel = ds.find_field_values_at_point(
+        ['density', 'velocity_x'], ds.domain_center)
+
+    assert_equal(point_den, fpoint_den)
+    assert_equal(point_den, fpoint_den_vel[0])
+    assert_equal(point_vel, fpoint_den_vel[1])
+
+    ppos_den = ds.find_field_values_at_points('density', ppos)
+    ppos_vel = ds.find_field_values_at_points('velocity_x', ppos)
+    ppos_den_vel = ds.find_field_values_at_points(
+        ['density', 'velocity_x'], ppos)
+
+    assert_equal(ppos_den.shape, (nparticles,))
+    assert_equal(ppos_vel.shape, (nparticles,))
+    assert_equal(len(ppos_den_vel), 2)
+    assert_equal(ppos_den_vel[0], ppos_den)
+    assert_equal(ppos_den_vel[1], ppos_vel)
+
+g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"
+
+@requires_file(g30)
+def test_fast_find_field_values_at_points():
+    ds = yt.load(g30)
+    ad = ds.all_data()
+    # right now this is slow for large numbers of particles, so randomly
+    # sample 100 particles
+    nparticles = 100
+    ppos = ad['all', 'particle_position']
+    ppos = ppos[np.random.randint(0, len(ppos), size=nparticles)]
+
+    ppos_den = ds.find_field_values_at_points('density', ppos)
+    ppos_vel = ds.find_field_values_at_points('velocity_x', ppos)
+    ppos_den_vel = ds.find_field_values_at_points(
+        ['density', 'velocity_x'], ppos)
+
+    assert_equal(ppos_den.shape, (nparticles,))
+    assert_equal(ppos_vel.shape, (nparticles,))
+    assert_equal(len(ppos_den_vel), 2)
+    assert_equal(ppos_den_vel[0], ppos_den)
+    assert_equal(ppos_den_vel[1], ppos_vel)
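The new tests sample particle positions and use them as query coordinates. A minimal sketch of the same pattern against a synthetic dataset (fake_random_ds and the field names come from yt.testing and the tests above; the particle count and sample size are arbitrary):

    import numpy as np
    from yt.testing import fake_random_ds

    ds = fake_random_ds(16, particles=64)
    ad = ds.all_data()

    # use a random subset of the particle positions as query points
    ppos = ad['all', 'particle_position']
    idx = np.random.randint(0, len(ppos), size=8)

    den, vel = ds.find_field_values_at_points(
        ['density', 'velocity_x'], ppos[idx])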

diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/geometry/grid_geometry_handler.py
--- a/yt/geometry/grid_geometry_handler.py
+++ b/yt/geometry/grid_geometry_handler.py
@@ -26,7 +26,6 @@
     ensure_list, ensure_numpy_array
 from yt.geometry.geometry_handler import \
     Index, YTDataChunk, ChunkDataCache
-from yt.units.yt_array import YTArray
 from yt.utilities.definitions import MAXLEVEL
 from yt.utilities.logger import ytLogger as mylog
 from .grid_container import \
@@ -203,8 +202,8 @@
         Returns the values [field1, field2,...] of the fields at the given
         (x, y, z) points. Returns a numpy array of field values cross coords
         """
-        coords = YTArray(ensure_numpy_array(coords),'code_length', registry=self.ds.unit_registry)
-        grids = self._find_points(coords[:,0], coords[:,1], coords[:,2])[0]
+        coords = self.ds.arr(ensure_numpy_array(coords), 'code_length')
+        grids = self._find_points(coords[:, 0], coords[:, 1], coords[:, 2])[0]
         fields = ensure_list(fields)
         mark = np.zeros(3, dtype=np.int)
         out = []
@@ -216,13 +215,21 @@
                 grid_index[grid] = []
             grid_index[grid].append(coord_index)
 
-        out = np.zeros((len(fields),len(coords)), dtype=np.float64)
+        out = []
+        for field in fields:
+            funit = self.ds._get_field_info(field).units
+            out.append(self.ds.arr(np.empty((len(coords))), funit))
+
         for grid in grid_index:
             cellwidth = (grid.RightEdge - grid.LeftEdge) / grid.ActiveDimensions
-            for field in fields:
+            for field_index, field in enumerate(fields):
                 for coord_index in grid_index[grid]:
-                    mark = ((coords[coord_index,:] - grid.LeftEdge) / cellwidth).astype('int')
-                    out[:,coord_index] = grid[field][mark[0],mark[1],mark[2]]
+                    mark = ((coords[coord_index, :] - grid.LeftEdge) / cellwidth)
+                    mark = np.array(mark, dtype='int64')
+                    out[field_index][coord_index] = \
+                        grid[field][mark[0], mark[1], mark[2]]
+        if len(fields) == 1:
+            return out[0]
         return out
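The per-grid lookup above maps a physical coordinate to an integer cell index by rescaling against the grid's left edge and cell width. A small standalone illustration of that arithmetic (the edges and dimensions are made-up numbers):

    import numpy as np

    left_edge = np.array([0.0, 0.0, 0.0])
    right_edge = np.array([1.0, 1.0, 1.0])
    active_dims = np.array([16, 16, 16])

    cellwidth = (right_edge - left_edge) / active_dims  # 0.0625 per cell

    coord = np.array([0.30, 0.55, 0.90])
    # integer (i, j, k) index of the cell containing coord
    mark = ((coord - left_edge) / cellwidth).astype('int64')
    # mark -> array([ 4,  8, 14])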

Repository URL: https://bitbucket.org/yt_analysis/yt/

--

This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.


