[yt-svn] commit/yt: brittonsmith: Merged in ngoldbaum/yt (pull request #1950)
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Thu Jan 28 09:08:19 PST 2016
1 new commit in yt:
https://bitbucket.org/yt_analysis/yt/commits/e14103b2c595/
Changeset: e14103b2c595
Branch: yt
User: brittonsmith
Date: 2016-01-28 17:08:14+00:00
Summary: Merged in ngoldbaum/yt (pull request #1950)
Fix a number of issues with ds.find_field_values_at_point[s]. Closes #1161
Affected #: 3 files
diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/data_objects/static_output.py
--- a/yt/data_objects/static_output.py
+++ b/yt/data_objects/static_output.py
@@ -665,7 +665,15 @@
coordinates. Returns a list of field values in the same order as
the input *fields*.
"""
- return self.point(coords)[fields]
+ point = self.point(coords)
+ ret = []
+ field_list = ensure_list(fields)
+ for field in field_list:
+ ret.append(point[field])
+ if len(field_list) == 1:
+ return ret[0]
+ else:
+ return ret
def find_field_values_at_points(self, fields, coords):
"""
@@ -673,19 +681,26 @@
[(x1, y1, z2), (x2, y2, z2),...] points. Returns a list of field
values in the same order as the input *fields*.
- This is quite slow right now as it creates a new data object for each
- point. If an optimized version exists on the Index object we'll use
- that instead.
"""
- if hasattr(self,"index") and \
- hasattr(self.index,"_find_field_values_at_points"):
- return self.index._find_field_values_at_points(fields,coords)
+ # If an optimized version exists on the Index object we'll use that
+ # (looked up, not try/except'd, so errors raised inside it still surface)
+ finder = getattr(self.index, "_find_field_values_at_points", None)
+ if finder is not None:
+ return finder(fields, coords)
fields = ensure_list(fields)
- out = np.zeros((len(fields),len(coords)), dtype=np.float64)
- for i,coord in enumerate(coords):
- out[:][i] = self.point(coord)[fields]
- return out
+ out = []
+
+ # This may be slow because it creates a data object for each point
+ for field_index, field in enumerate(fields):
+ funit = self._get_field_info(field).units
+ out.append(self.arr(np.empty((len(coords),)), funit))
+ for coord_index, coord in enumerate(coords):
+ out[field_index][coord_index] = self.point(coord)[field]
+ if len(fields) == 1:
+ return out[0]
+ else:
+ return out
# Now all the object related stuff
def all_data(self, find_max=False, **kwargs):
diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/data_objects/tests/test_points.py
--- a/yt/data_objects/tests/test_points.py
+++ b/yt/data_objects/tests/test_points.py
@@ -1,10 +1,66 @@
-from yt.testing import fake_random_ds
+import numpy as np
+import yt
+
+from yt.testing import \
+ fake_random_ds, \
+ assert_equal, \
+ requires_file
def setup():
from yt.config import ytcfg
ytcfg["yt","__withintesting"] = "True"
def test_domain_point():
- ds = fake_random_ds(16)
+ nparticles = 3
+ ds = fake_random_ds(16, particles=nparticles)
p = ds.point(ds.domain_center)
- p['density']
+
+ # ensure accessing one field works, store for comparison later
+ point_den = p['density']
+ point_vel = p['velocity_x']
+
+ ad = ds.all_data()
+ ppos = ad['all', 'particle_position']
+
+ fpoint_den = ds.find_field_values_at_point('density', ds.domain_center)
+
+ fpoint_den_vel = ds.find_field_values_at_point(
+ ['density', 'velocity_x'], ds.domain_center)
+
+ assert_equal(point_den, fpoint_den)
+ assert_equal(point_den, fpoint_den_vel[0])
+ assert_equal(point_vel, fpoint_den_vel[1])
+
+ ppos_den = ds.find_field_values_at_points('density', ppos)
+ ppos_vel = ds.find_field_values_at_points('velocity_x', ppos)
+ ppos_den_vel = ds.find_field_values_at_points(
+ ['density', 'velocity_x'], ppos)
+
+ assert_equal(ppos_den.shape, (nparticles,))
+ assert_equal(ppos_vel.shape, (nparticles,))
+ assert_equal(len(ppos_den_vel), 2)
+ assert_equal(ppos_den_vel[0], ppos_den)
+ assert_equal(ppos_den_vel[1], ppos_vel)
+
+g30 = "IsolatedGalaxy/galaxy0030/galaxy0030"
+
+@requires_file(g30)
+def test_fast_find_field_values_at_points():
+ ds = yt.load(g30)
+ ad = ds.all_data()
+ # right now this is slow for large numbers of particles, so randomly
+ # sample 100 particles (randint's upper bound is exclusive: valid indices)
+ nparticles = 100
+ ppos = ad['all', 'particle_position']
+ ppos = ppos[np.random.randint(0, len(ppos), size=nparticles)]
+
+ ppos_den = ds.find_field_values_at_points('density', ppos)
+ ppos_vel = ds.find_field_values_at_points('velocity_x', ppos)
+ ppos_den_vel = ds.find_field_values_at_points(
+ ['density', 'velocity_x'], ppos)
+
+ assert_equal(ppos_den.shape, (nparticles,))
+ assert_equal(ppos_vel.shape, (nparticles,))
+ assert_equal(len(ppos_den_vel), 2)
+ assert_equal(ppos_den_vel[0], ppos_den)
+ assert_equal(ppos_den_vel[1], ppos_vel)
diff -r bc0cc00c150c6e5fca1b85310935dfe9ccd905a2 -r e14103b2c5950477b6ed23c988277d03c9db408a yt/geometry/grid_geometry_handler.py
--- a/yt/geometry/grid_geometry_handler.py
+++ b/yt/geometry/grid_geometry_handler.py
@@ -26,7 +26,6 @@
ensure_list, ensure_numpy_array
from yt.geometry.geometry_handler import \
Index, YTDataChunk, ChunkDataCache
-from yt.units.yt_array import YTArray
from yt.utilities.definitions import MAXLEVEL
from yt.utilities.logger import ytLogger as mylog
from .grid_container import \
@@ -203,8 +202,8 @@
Returns the values [field1, field2,...] of the fields at the given
(x, y, z) points. Returns a numpy array of field values cross coords
"""
- coords = YTArray(ensure_numpy_array(coords),'code_length', registry=self.ds.unit_registry)
- grids = self._find_points(coords[:,0], coords[:,1], coords[:,2])[0]
+ coords = self.ds.arr(ensure_numpy_array(coords), 'code_length')
+ grids = self._find_points(coords[:, 0], coords[:, 1], coords[:, 2])[0]
fields = ensure_list(fields)
mark = np.zeros(3, dtype=np.int)
out = []
@@ -216,13 +215,21 @@
grid_index[grid] = []
grid_index[grid].append(coord_index)
- out = np.zeros((len(fields),len(coords)), dtype=np.float64)
+ out = []
+ for field in fields:
+ funit = self.ds._get_field_info(field).units
+ out.append(self.ds.arr(np.empty((len(coords))), funit))
+
for grid in grid_index:
cellwidth = (grid.RightEdge - grid.LeftEdge) / grid.ActiveDimensions
- for field in fields:
+ for field_index, field in enumerate(fields):
for coord_index in grid_index[grid]:
- mark = ((coords[coord_index,:] - grid.LeftEdge) / cellwidth).astype('int')
- out[:,coord_index] = grid[field][mark[0],mark[1],mark[2]]
+ mark = ((coords[coord_index, :] - grid.LeftEdge) / cellwidth)
+ mark = np.array(mark, dtype='int64')
+ out[field_index][coord_index] = \
+ grid[field][mark[0], mark[1], mark[2]]
+ if len(fields) == 1:
+ return out[0]
return out
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list