[Yt-svn] yt-commit r373 - in trunk: tests yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Wed Jan 30 16:27:51 PST 2008


Author: mturk
Date: Wed Jan 30 16:27:49 2008
New Revision: 373
URL: http://yt.spacepope.org/changeset/373

Log:

Added UnilinearFieldInterpolator, BilinearFieldInterpolator
and TrilinearFieldInterpolator.  These let you create an object whose
__call__ method accepts a data object and returns the tabulated values
interpolated onto that object's data.  For example, if you have a tabulated set
of values in two dimensions that you want to see in a slice, you create a
bilinear interpolator object with the field names of the varying parameters
(e.g., "Density" and "Temperature") and then call it with the data object you
want your tabulated values interpolated into.

More on this later.

Additionally, changed all datatypes to accept **kwargs, which are stored as
field parameters; this allows passing things like field interpolators into
datatypes.

Added interpolation tests to the increasingly unwieldy test_lagos.



Added:
   trunk/yt/lagos/HelperFunctions.py
Modified:
   trunk/tests/test_lagos.py
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/HierarchyType.py
   trunk/yt/lagos/__init__.py

Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py	(original)
+++ trunk/tests/test_lagos.py	Wed Jan 30 16:27:49 2008
@@ -162,6 +162,12 @@
                                                         accumulation_y)
                 setattr(Data3DBase, name, func)
 
+class TestDiskDataType(Data3DBase, DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
+    def setUp(self):
+        DataTypeTestingBase.setUp(self)
+        self.data=self.hierarchy.disk(
+                     [0.5,0.5,0.5],[0.2, 0.1, 0.5],0.25,0.25)
+
 class TestRegionDataType(Data3DBase, DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
     def setUp(self):
         DataTypeTestingBase.setUp(self)
@@ -234,5 +240,94 @@
             / self.data.convert("cm")**3.0
         self.assertAlmostEqual(vol,1.0,7)
 
+class TestUnilinearInterpolator(unittest.TestCase):
+    def setUp(self):
+        x0, x1 = na.random.uniform(-100,100,2)
+        nstep_x = na.random.randint(10,200)
+        nvals = na.random.randint(100,1000)
+
+        table = na.mgrid[x0:x1:nstep_x*1j]
+
+        self.ufi_x = yt.lagos.UnilinearFieldInterpolator(table,
+                      (x0,x1),'x')
+        self.my_dict = {}
+        self.my_dict['x'] = na.random.uniform(x0,x1,nvals)
+
+    def testXInt(self):
+        nv = self.ufi_x(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['x'][i], 5)
+
+class TestBilinearInterpolator(unittest.TestCase):
+    def setUp(self):
+        x0, x1 = na.random.uniform(-100,100,2)
+        y0, y1 = na.random.uniform(-100,100,2)
+        nstep_x = na.random.randint(10,200)
+        nstep_y = na.random.randint(10,200)
+        nvals = na.random.randint(100,1000)
+
+        table = na.mgrid[x0:x1:nstep_x*1j,
+                         y0:y1:nstep_y*1j]
+
+        self.bfi_x = yt.lagos.BilinearFieldInterpolator(table[0,...],
+                      (x0,x1,y0,y1),['x','y'])
+        self.bfi_y = yt.lagos.BilinearFieldInterpolator(table[1,...],
+                      (x0,x1,y0,y1),['x','y'])
+        self.my_dict = {}
+        self.my_dict['x'] = na.random.uniform(x0,x1,nvals)
+        self.my_dict['y'] = na.random.uniform(y0,y1,nvals)
+
+    def testXInt(self):
+        nv = self.bfi_x(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['x'][i], 5)
+
+    def testYInt(self):
+        nv = self.bfi_y(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['y'][i], 5)
+
+class TestTrilinearInterpolator(unittest.TestCase):
+    def setUp(self):
+        x0, x1 = na.random.uniform(-100,100,2)
+        y0, y1 = na.random.uniform(-100,100,2)
+        z0, z1 = na.random.uniform(-100,100,2)
+        nstep_x = na.random.randint(10,200)
+        nstep_y = na.random.randint(10,200)
+        nstep_z = na.random.randint(10,200)
+        nvals = na.random.randint(100,1000)
+
+        table = na.mgrid[x0:x1:nstep_x*1j,
+                         y0:y1:nstep_y*1j,
+                         z0:z1:nstep_z*1j]
+
+        self.tfi_x = yt.lagos.TrilinearFieldInterpolator(table[0,...],
+                      (x0,x1,y0,y1,z0,z1),['x','y','z'])
+        self.tfi_y = yt.lagos.TrilinearFieldInterpolator(table[1,...],
+                      (x0,x1,y0,y1,z0,z1),['x','y','z'])
+        self.tfi_z = yt.lagos.TrilinearFieldInterpolator(table[2,...],
+                      (x0,x1,y0,y1,z0,z1),['x','y','z'])
+        self.my_dict = {}
+        self.my_dict['x'] = na.random.uniform(x0,x1,nvals)
+        self.my_dict['y'] = na.random.uniform(y0,y1,nvals)
+        self.my_dict['z'] = na.random.uniform(z0,z1,nvals)
+
+    def testXInt(self):
+        nv = self.tfi_x(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['x'][i], 5)
+
+    def testYInt(self):
+        nv = self.tfi_y(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['y'][i], 5)
+
+    def testZInt(self):
+        nv = self.tfi_z(self.my_dict)
+        for i,v in enumerate(nv):
+            self.assertAlmostEqual(v, self.my_dict['z'][i], 5)
+
+
+
 if __name__ == "__main__":
     unittest.main()

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Wed Jan 30 16:27:49 2008
@@ -52,7 +52,7 @@
     _grids = None
     _num_ghost_zones = 0
 
-    def __init__(self, pf, fields):
+    def __init__(self, pf, fields, **kwargs):
         """
         @param pf: The parameterfile associated with this container
         @type hierarchy: L{EnzoOutput<EnzoOutput>}
@@ -67,6 +67,8 @@
         self.data = {}
         self.field_parameters = {}
         self.__set_default_field_parameters()
+        for key, val in kwargs.items():
+            self.set_field_parameter(key, val)
 
     def __set_default_field_parameters(self):
         self.set_field_parameter("center",na.zeros(3,dtype='float64'))
@@ -144,8 +146,8 @@
 
 class Enzo1DData(EnzoData):
     _spatial = False
-    def __init__(self, pf, fields):
-        EnzoData.__init__(self, pf, fields)
+    def __init__(self, pf, fields, **kwargs):
+        EnzoData.__init__(self, pf, fields, **kwargs)
         self._grids = None
 
     def _generate_field_in_grids(self, field, num_ghost_zones=0):
@@ -169,8 +171,8 @@
 
 
 class EnzoOrthoRayBase(Enzo1DData):
-    def __init__(self, axis, coords, fields=None, pf=None):
-        Enzo1DData.__init__(self, pf, fields)
+    def __init__(self, axis, coords, fields=None, pf=None, **kwargs):
+        Enzo1DData.__init__(self, pf, fields, **kwargs)
         self.axis = axis
         self.px_ax = x_dict[self.axis]
         self.py_ax = y_dict[self.axis]
@@ -231,7 +233,7 @@
     does not have as many actions as the 3-D data types.
     """
     _spatial = False
-    def __init__(self, axis, fields, pf=None):
+    def __init__(self, axis, fields, pf=None, **kwargs):
         """
         Prepares the Enzo2DData.
 
@@ -241,7 +243,7 @@
         @type fields: list of strings
         """
         self.axis = axis
-        EnzoData.__init__(self, pf, fields)
+        EnzoData.__init__(self, pf, fields, **kwargs)
 
     @time_execution
     def get_data(self, fields = None):
@@ -333,7 +335,7 @@
     """
 
     @time_execution
-    def __init__(self, axis, coord, fields = None, center=None, pf=None):
+    def __init__(self, axis, coord, fields = None, center=None, pf=None, **kwargs):
         """
         @param axis: axis to which this data is parallel
         @type axis: integer (0,1,2)
@@ -342,7 +344,7 @@
         @keyword fields: fields to be processed or generated
         @type fields: list of strings
         """
-        Enzo2DData.__init__(self, axis, fields, pf)
+        Enzo2DData.__init__(self, axis, fields, pf, **kwargs)
         self.center = center
         self.coord = coord
         self._refresh_data()
@@ -469,14 +471,14 @@
     the appropriate data onto the plane without interpolation.
     """
     _plane = None
-    def __init__(self, normal, center, fields = None):
+    def __init__(self, normal, center, fields = None, **kwargs):
         """
         @param normal: Vector normal to which the plane will be defined
         @type normal: List or array of floats
         @param center: The center point of the plane
         @type center: List or array of floats
         """
-        Enzo2DData.__init__(self, 4, fields)
+        Enzo2DData.__init__(self, 4, fields, **kwargs)
         self.center = center
         self.set_field_parameter('center',center)
         self._cut_masks = {}
@@ -575,7 +577,7 @@
 class EnzoProjBase(Enzo2DData):
     def __init__(self, axis, field, weight_field = None,
                  max_level = None, center = None, pf = None,
-                 source=None, type=0):
+                 source=None, type=0, **kwargs):
         """
         EnzoProj is a line integral of a field along an axis.  The field
         can be weighted, in which case some degree of averaging takes place.
@@ -591,7 +593,7 @@
         @keyword source: The data source, particularly for parallel projections.
         @type source: L{EnzoData<EnzoData>}
         """
-        Enzo2DData.__init__(self, axis, field, pf)
+        Enzo2DData.__init__(self, axis, field, pf, **kwargs)
         if not source:
             source = EnzoGridCollection(center, self.hierarchy.grids)
         self.source = source
@@ -844,7 +846,7 @@
     """
     _spatial = False
     _num_ghost_zones = 0
-    def __init__(self, center, fields, pf = None):
+    def __init__(self, center, fields, pf = None, **kwargs):
         """
         Returns an instance of Enzo3DData, or prepares one.  Usually only
         used as a base class.
@@ -856,7 +858,7 @@
         @param fields: fields to read/generate
         @type fields: list of strings
         """
-        EnzoData.__init__(self, pf, fields)
+        EnzoData.__init__(self, pf, fields, **kwargs)
         self.center = center
         self.set_field_parameter("center",center)
         self.coords = None
@@ -1109,7 +1111,7 @@
     ExtractedRegions are arbitrarily defined containers of data, useful
     for things like selection along a baryon field.
     """
-    def __init__(self, base_region, indices):
+    def __init__(self, base_region, indices, **kwargs):
         """
         @param base_region: The Enzo3DData we select points from
         @type base_region: L{Enzo3DData<Enzo3DData>}
@@ -1118,7 +1120,7 @@
         """
         cen = base_region.get_field_parameter("center")
         Enzo3DData.__init__(self, center=cen,
-                            fields=None, pf=base_region.pf)
+                            fields=None, pf=base_region.pf, **kwargs)
         self._base_region = base_region
         self._base_indices = indices
         self._grids = None
@@ -1156,10 +1158,11 @@
 
 class EnzoCylinderBase(Enzo3DData):
     """
-    We define a disk as have an 'up' vector, a radius and a height.
+    We define a disk as having an 'up' vector, a radius and a height.
     """
-    def __init__(self, center, normal, radius, height, fields=None, pf=None):
-        Enzo3DData.__init__(self, na.array(center), fields, pf)
+    def __init__(self, center, normal, radius, height, fields=None,
+                 pf=None, **kwargs):
+        Enzo3DData.__init__(self, na.array(center), fields, pf, **kwargs)
         self._norm_vec = na.array(normal)/na.sqrt(na.dot(normal,normal))
         self.set_field_parameter("height_vector", self._norm_vec)
         self._height = height
@@ -1211,7 +1214,8 @@
     """
     EnzoRegions are rectangular prisms of data.
     """
-    def __init__(self, center, left_edge, right_edge, fields = None, pf = None):
+    def __init__(self, center, left_edge, right_edge, fields = None,
+                 pf = None, **kwargs):
         """
         @note: Center does not have to be (rightEdge - leftEdge) / 2.0
         @param center: The center for calculations that require it
@@ -1221,7 +1225,7 @@
         @param right_edge: The right boundary
         @type right_edge: list or array of floats
         """
-        Enzo3DData.__init__(self, center, fields, pf)
+        Enzo3DData.__init__(self, center, fields, pf, **kwargs)
         self.left_edge = left_edge
         self.right_edge = right_edge
         self._cut_masks = {}
@@ -1250,14 +1254,14 @@
     An arbitrary selection of grids, within which we accept all points.
     """
     def __init__(self, center, grid_list, fields = None, connection_pool = True,
-                 pf = None):
+                 pf = None, **kwargs):
         """
         @param center: The center of the region, for derived fields
         @type center: List or array of floats
         @param grid_list: The grids we are composed of
         @type grid_list: List or array of Grid objects
         """
-        Enzo3DData.__init__(self, center, fields, pf)
+        Enzo3DData.__init__(self, center, fields, pf, **kwargs)
         self._grids = na.array(grid_list)
         self.fields = fields
         self._cut_masks = {}
@@ -1281,7 +1285,7 @@
     """
     A sphere of points
     """
-    def __init__(self, center, radius, fields = None, pf = None):
+    def __init__(self, center, radius, fields = None, pf = None, **kwargs):
         """
         @param center: center of the region
         @type center: array of floats
@@ -1290,7 +1294,7 @@
         @keyword fields: fields to read/generate
         @type fields: list of strings
         """
-        Enzo3DData.__init__(self, center, fields, pf)
+        Enzo3DData.__init__(self, center, fields, pf, **kwargs)
         self._cut_masks = {}
         self.set_field_parameter('radius',radius)
         self.radius = radius
@@ -1329,7 +1333,7 @@
     """
     _spatial = True
     def __init__(self, level, left_edge, right_edge, dims, fields = None,
-                 pf = None, num_ghost_zones = 0, use_pbar = True):
+                 pf = None, num_ghost_zones = 0, use_pbar = True, **kwargs):
         """
         @param level: The maximum level to consider when creating the grid
         @note: Level does not have to be related to the dx of the object.
@@ -1341,7 +1345,7 @@
         @type dims: List or array of integers
         @note: It is faster to feed all the fields in at the initialization
         """
-        Enzo3DData.__init__(self, center=None, fields=fields, pf=pf)
+        Enzo3DData.__init__(self, center=None, fields=fields, pf=pf, **kwargs)
         self.left_edge = na.array(left_edge)
         self.right_edge = na.array(right_edge)
         self.level = level

Added: trunk/yt/lagos/HelperFunctions.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/HelperFunctions.py	Wed Jan 30 16:27:49 2008
@@ -0,0 +1,129 @@
+"""
+A collection of helper functions, mostly for things
+that I expected SciPy to provide but it doesn't.
+
+@author: U{Matthew Turk<http://www.stanford.edu/~mturk/>}
+@organization: U{KIPAC<http://www-group.slac.stanford.edu/KIPAC/>}
+@contact: U{mturk@slac.stanford.edu<mailto:mturk@slac.stanford.edu>}
+@license:
+  Copyright (C) 2007 Matthew Turk.  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+from yt.lagos import *
+
+class UnilinearFieldInterpolator:
+    def __init__(self, table, boundaries, field_names):
+        self.table = table
+        x0, x1 = boundaries
+        self.x_name = field_names
+        self.x_bins = na.linspace(x0, x1, table.shape[0])
+
+    def __call__(self, data_object):
+        orig_shape = data_object[self.x_name].shape
+        x_vals = data_object[self.x_name].ravel()
+
+        x_i = na.digitize(data_object[self.x_name], self.x_bins) - 1
+        if na.any((x_i == -1) | (x_i == len(self.x_bins)-1)):
+            mylog.error("Sorry, but your values are outside" + \
+                        " the table!  Dunno what to do, so dying.")
+            mylog.error("Error was in: %s", data_object)
+            raise ValueError
+
+        x = (x_vals - self.x_bins[x_i]) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        xm = (self.x_bins[x_i+1] - x_vals) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        my_vals = self.table[x_i  ] * (xm) \
+                + self.table[x_i+1] * (x )
+        return my_vals.reshape(orig_shape)
+
+class BilinearFieldInterpolator:
+    def __init__(self, table, boundaries, field_names):
+        self.table = table
+        x0, x1, y0, y1 = boundaries
+        self.x_name, self.y_name = field_names
+        self.x_bins = na.linspace(x0, x1, table.shape[0])
+        self.y_bins = na.linspace(y0, y1, table.shape[1])
+
+    def __call__(self, data_object):
+        orig_shape = data_object[self.x_name].shape
+        x_vals = data_object[self.x_name].ravel()
+        y_vals = data_object[self.y_name].ravel()
+
+        x_i = na.digitize(data_object[self.x_name], self.x_bins) - 1
+        y_i = na.digitize(data_object[self.y_name], self.y_bins) - 1
+        if na.any((x_i == -1) | (x_i == len(self.x_bins)-1)) \
+            or na.any((y_i == -1) | (y_i == len(self.y_bins)-1)):
+            mylog.error("Sorry, but your values are outside" + \
+                        " the table!  Dunno what to do, so dying.")
+            mylog.error("Error was in: %s", data_object)
+            raise ValueError
+
+        x = (x_vals - self.x_bins[x_i]) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        y = (y_vals - self.y_bins[y_i]) / (self.y_bins[y_i+1] - self.y_bins[y_i])
+        xm = (self.x_bins[x_i+1] - x_vals) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        ym = (self.y_bins[y_i+1] - y_vals) / (self.y_bins[y_i+1] - self.y_bins[y_i])
+        my_vals = \
+                  self.table[x_i  ,y_i  ] * (xm*ym) \
+                + self.table[x_i+1,y_i  ] * (x *ym) \
+                + self.table[x_i  ,y_i+1] * (xm*y ) \
+                + self.table[x_i+1,y_i+1] * (x *y )
+        return my_vals.reshape(orig_shape)
+
+class TrilinearFieldInterpolator:
+    def __init__(self, table, boundaries, field_names):
+        self.table = table
+        x0, x1, y0, y1, z0, z1 = boundaries
+        self.x_name, self.y_name, self.z_name = field_names
+        self.x_bins = na.linspace(x0, x1, table.shape[0])
+        self.y_bins = na.linspace(y0, y1, table.shape[1])
+        self.z_bins = na.linspace(z0, z1, table.shape[2])
+
+    def __call__(self, data_object):
+        orig_shape = data_object[self.x_name].shape
+        x_vals = data_object[self.x_name].ravel()
+        y_vals = data_object[self.y_name].ravel()
+        z_vals = data_object[self.z_name].ravel()
+
+        x_i = na.digitize(data_object[self.x_name], self.x_bins) - 1
+        y_i = na.digitize(data_object[self.y_name], self.y_bins) - 1
+        z_i = na.digitize(data_object[self.z_name], self.z_bins) - 1
+        if na.any((x_i == -1) | (x_i == len(self.x_bins)-1)) \
+            or na.any((y_i == -1) | (y_i == len(self.y_bins)-1)) \
+            or na.any((z_i == -1) | (z_i == len(self.z_bins)-1)):
+            mylog.error("Sorry, but your values are outside" + \
+                        " the table!  Dunno what to do, so dying.")
+            mylog.error("Error was in: %s", data_object)
+            raise ValueError
+
+        # Use notation from Paul Bourke's page on interpolation
+        # http://local.wasp.uwa.edu.au/~pbourke/other/interpolation/
+        x = (x_vals - self.x_bins[x_i]) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        y = (y_vals - self.y_bins[y_i]) / (self.y_bins[y_i+1] - self.y_bins[y_i])
+        z = (z_vals - self.z_bins[z_i]) / (self.z_bins[z_i+1] - self.z_bins[z_i])
+        xm = (self.x_bins[x_i+1] - x_vals) / (self.x_bins[x_i+1] - self.x_bins[x_i])
+        ym = (self.y_bins[y_i+1] - y_vals) / (self.y_bins[y_i+1] - self.y_bins[y_i])
+        zm = (self.z_bins[z_i+1] - z_vals) / (self.z_bins[z_i+1] - self.z_bins[z_i])
+        my_vals = \
+                  self.table[x_i  ,y_i  ,z_i  ] * (xm*ym*zm) \
+                + self.table[x_i+1,y_i  ,z_i  ] * (x *ym*zm) \
+                + self.table[x_i  ,y_i+1,z_i  ] * (xm*y *zm) \
+                + self.table[x_i  ,y_i  ,z_i+1] * (xm*ym*z ) \
+                + self.table[x_i+1,y_i  ,z_i+1] * (x *ym*z ) \
+                + self.table[x_i  ,y_i+1,z_i+1] * (xm*y *z ) \
+                + self.table[x_i+1,y_i+1,z_i  ] * (x *y *zm) \
+                + self.table[x_i+1,y_i+1,z_i+1] * (x *y *z )
+        return my_vals.reshape(orig_shape)
\ No newline at end of file

Modified: trunk/yt/lagos/HierarchyType.py
==============================================================================
--- trunk/yt/lagos/HierarchyType.py	(original)
+++ trunk/yt/lagos/HierarchyType.py	Wed Jan 30 16:27:49 2008
@@ -62,6 +62,8 @@
         # Now we search backwards from the end of the file to find out how many
         # grids we have, which allows us to preallocate memory
         self.__hierarchy_lines = open(self.hierarchy_filename).readlines()
+        if len(self.__hierarchy_lines) == 0:
+            raise IOError(-1,"File empty", self.hierarchy_filename)
         self.__hierarchy_string = open(self.hierarchy_filename).read()
         for i in xrange(len(self.__hierarchy_lines)-1,0,-1):
             line = self.__hierarchy_lines[i]

Modified: trunk/yt/lagos/__init__.py
==============================================================================
--- trunk/yt/lagos/__init__.py	(original)
+++ trunk/yt/lagos/__init__.py	Wed Jan 30 16:27:49 2008
@@ -78,6 +78,7 @@
 from HierarchyType import *
 from OutputTypes import *
 from Profiles import *
+from HelperFunctions import *
 
 # We load plugins.  Keep in mind, this can be fairly dangerous -
 # the primary purpose is to allow people to have a set of functions



More information about the yt-svn mailing list