[Yt-svn] yt-commit r346 - in trunk: tests yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Sat Dec 29 10:14:24 PST 2007


Author: mturk
Date: Sat Dec 29 10:14:24 2007
New Revision: 346
URL: http://yt.spacepope.org/changeset/346

Log:
Added unit tests for profiles.

Fixed profiles to use a decorator (preserve_source_parameters) that passes the
profile data source's field parameters on to the child sources being binned.



Modified:
   trunk/tests/test_lagos.py
   trunk/yt/lagos/Profiles.py

Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py	(original)
+++ trunk/tests/test_lagos.py	Sat Dec 29 10:14:24 2007
@@ -3,7 +3,6 @@
 """
 
 # @TODO: Add unit test for deleting field from fieldInfo
-# @TODO: Profile unit testing, including for small spheres
 
 import unittest, glob, os.path, os, sys, StringIO
 
@@ -111,16 +110,58 @@
             pass
     return field_function
 
+def _returnProfile1DFunction(field, weight, accumulation, lazy):
+    def add_field_function(self):
+        self.data.set_field_parameter("center",[.5,.5,.5])
+        profile = yt.lagos.BinnedProfile1D(
+            self.data, 8, "RadiusCode", 0, 1.0, False, lazy)
+        profile.add_fields(field, weight=weight, accumulation=accumulation)
+    return add_field_function
+
+def _returnProfile2DFunction(field, weight, accumulation, lazy):
+    def add_field_function(self):
+        self.data.set_field_parameter("center",[.5,.5,.5])
+        cv_min = self.hierarchy.gridDxs.min()**3.0
+        cv_max = self.hierarchy.gridDxs.max()**3.0
+        profile = yt.lagos.BinnedProfile2D(self.data,
+                    8, "RadiusCode", 1e-3, 1.0, True,
+                    8, "CellVolumeCode", cv_min, cv_max, True, lazy)
+        profile.add_fields(field, weight=weight, accumulation=accumulation)
+    return add_field_function
+
 class DataTypeTestingBase:
     def setUp(self):
         LagosTestingBase.setUp(self)
+
+class Data3DBase:
+    pass
+
 for field in yt.lagos.fieldInfo.values():
-    #if field.name.find("particle") > -1:
-        #continue
-    func = _returnFieldFunction(field)
-    setattr(DataTypeTestingBase, "test%s" % field.name, func)
+    setattr(DataTypeTestingBase, "test%s" % field.name, _returnFieldFunction(field))
+
+field = "Temperature"
+for weight in [None, "CellMassMsun"]:
+    for lazy in [True, False]:
+        for accumulation in [True, False]:
+            func = _returnProfile1DFunction(field, weight, accumulation, lazy)
+            name = "test%sProfile1D_w%s_l%s_a%s" % (field,
+                                                weight, lazy,
+                                                accumulation)
+            setattr(Data3DBase, name, func)
+
+for weight in [None, "CellMassMsun"]:
+    for lazy in [True, False]:
+        for accumulation_x in [True, False]:
+            for accumulation_y in [True, False]:
+                acc = (accumulation_x, accumulation_y)
+                func = _returnProfile2DFunction(field, weight, acc, lazy)
+                name = "test%sProfile2D_w%s_l%s_a%s_a%s" % (field,
+                                                        weight, lazy,
+                                                        accumulation_x,
+                                                        accumulation_y)
+                setattr(Data3DBase, name, func)
 
-class TestRegionDataType(DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
+class TestRegionDataType(Data3DBase, DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
     def setUp(self):
         DataTypeTestingBase.setUp(self)
         self.data=self.hierarchy.region(
@@ -143,6 +184,20 @@
         DataTypeTestingBase.setUp(self)
         self.data = self.hierarchy.slice(0,0.5)
 
+class TestCuttingPlane(DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
+    def setUp(self):
+        DataTypeTestingBase.setUp(self)
+        self.data = self.hierarchy.cutting([0.1,0.3,0.4], [0.5,0.5,0.5])
+    def testAxisVectors(self):
+        x_v = self.data._x_vec
+        y_v = self.data._y_vec
+        z_v = self.data._norm_vec
+        self.assertAlmostEqual(na.dot(x_v, y_v), 0.0, 7)
+        self.assertAlmostEqual(na.dot(x_v, z_v), 0.0, 7)
+        self.assertAlmostEqual(na.dot(y_v, z_v), 0.0, 7)
+    def testZHeight(self):
+        self.assertTrue(na.all(self.data['pz'] < self.data['dx']))
+
 class TestGridDataType(DataTypeTestingBase, LagosTestingBase, unittest.TestCase):
     def setUp(self):
         DataTypeTestingBase.setUp(self)
@@ -178,6 +233,5 @@
             / self.data.convert("cm")**3.0
         self.assertAlmostEqual(vol,1.0,7)
 
-
 if __name__ == "__main__":
     unittest.main()

Modified: trunk/yt/lagos/Profiles.py
==============================================================================
--- trunk/yt/lagos/Profiles.py	(original)
+++ trunk/yt/lagos/Profiles.py	Sat Dec 29 10:14:24 2007
@@ -25,6 +25,21 @@
 
 from yt.lagos import *
 
+def preserve_source_parameters(func):
+    def save_state(*args, **kwargs):
+        prof = args[0]
+        source = args[1]
+        if hasattr(source, 'field_parameters'):
+            old_params = source.field_parameters
+            source.field_parameters = prof._data_source.field_parameters
+            tr = func(*args, **kwargs)
+            source.field_parameters = old_params
+        else:
+            tr = func(*args, **kwargs)
+        #print func.func_name, tr
+        return tr
+    return save_state
+
 # Note we do not inherit from EnzoData.
 # We could, but I think we instead want to deal with the root datasource.
 class BinnedProfile:
@@ -99,6 +114,7 @@
     def _get_empty_field(self):
         return na.zeros(self[self.bin_field].size, dtype='float64')
 
+    @preserve_source_parameters
     def _bin_field(self, source, field, weight, accumulation,
                    args, check_cut=False):
         inv_bin_indices = args
@@ -122,6 +138,7 @@
             binned_field = na.add.accumulate(binned_field)
         return binned_field, weight_field, True
 
+    @preserve_source_parameters
     def _get_bins(self, source, check_cut=False):
         if check_cut:
             cm = self._data_source._get_point_indices(source)
@@ -168,6 +185,7 @@
         return na.zeros((self[self.x_bin_field].size,
                          self[self.y_bin_field].size), dtype='float64')
 
+    @preserve_source_parameters
     def _bin_field(self, source, field, weight, accumulation,
                    args, check_cut=False):
         #mylog.debug("Binning %s", field)
@@ -211,6 +229,7 @@
                 binned_field = na.add.accumulate(binned_field, axis=1)
         return binned_field, weight_field, used_field.astype('bool')
 
+    @preserve_source_parameters
     def _get_bins(self, source, check_cut=False):
         if check_cut:
             cm = self._data_source._get_point_indices(source)



More information about the yt-svn mailing list