[Yt-svn] yt-commit r1205 - in trunk: tests yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Thu Mar 12 12:45:04 PDT 2009


Author: mturk
Date: Thu Mar 12 12:45:01 2009
New Revision: 1205
URL: http://yt.spacepope.org/changeset/1205

Log:
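Added a join method for extracted regions.  The joined region uses the center
of the region on which join is called.  Added simple tests.  Also registered
GridCollection as "grid_collection" (dropping its connection_pool argument) and
made the octree generator in HierarchyType.py return a dict of per-field arrays.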



Modified:
   trunk/tests/test_lagos.py
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/HierarchyType.py

Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py	(original)
+++ trunk/tests/test_lagos.py	Thu Mar 12 12:45:01 2009
@@ -476,6 +476,13 @@
             / self.data.convert("cm")**3.0
         self.assertAlmostEqual(vol,1.0,7)
 
+    def testJoin(self):
+        new_region = self.region.extract_region(
+                self.region["Temperature"]<=500)
+        joined_region = self.data.join(new_region)
+        self.assertEqual(joined_region["CellMassMsun"].sum(),
+                         self.region["CellMassMsun"].sum())
+
 class TestExtractFromRegion(TestRegionDataType):
     def setUp(self):
         TestRegionDataType.setUp(self)
@@ -491,6 +498,14 @@
             / self.data.convert("cm")**3.0
         self.assertAlmostEqual(vol,1.0,7)
 
+    def testJoin(self):
+        new_region = self.region.extract_region(
+                self.region["Temperature"]<=500)
+        joined_region = self.data.join(new_region)
+        self.assertEqual(joined_region["CellMassMsun"].sum(),
+                         self.region["CellMassMsun"].sum())
+
+
 class TestUnilinearInterpolator(unittest.TestCase):
     def setUp(self):
         x0, x1 = na.random.uniform(-100,100,2)

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Thu Mar 12 12:45:01 2009
@@ -1507,6 +1507,28 @@
                        for i in self._con_args if i != "_indices"])
         return s
 
+    def join(self, other):
+        ng = {}
+        gs = set(self._indices.keys() + other._indices.keys())
+        for g in gs:
+            grid = self.pf.h.grids[g]
+            if g in other._indices and g in self._indices:
+                # We now join the indices
+                ind = na.zeros(grid.ActiveDimensions, dtype='bool')
+                ind[self._indices[g]] = True
+                ind[other._indices[g]] = True
+                if ind.prod() == grid.ActiveDimensions.prod(): ind = None
+            elif g in self._indices:
+                ind = self._indices[g]
+            elif g in other._indices:
+                ind = other._indices[g]
+            # Okay we have indices
+            if ind is not None: ind = ind.copy()
+            ng[g] = ind
+        gl = self.pf.h.grids[list(gs)]
+        gc = self.pf.h.grid_collection(
+            self._base_region.get_field_parameter("center"), gl)
+        return self.pf.h.extracted_region(gc, ng)
 
 class InLineExtractedRegionBase(AMR3DData):
     """
@@ -1713,7 +1735,9 @@
     """
     An arbitrary selection of grids, within which we accept all points.
     """
-    def __init__(self, center, grid_list, fields = None, connection_pool = True,
+    _type_name = "grid_collection"
+    _con_args = ("center", "grid_list")
+    def __init__(self, center, grid_list, fields = None,
                  pf = None, **kwargs):
         """
         By selecting an arbitrary *grid_list*, we can act on those grids.
@@ -1721,7 +1745,6 @@
         """
         AMR3DData.__init__(self, center, fields, pf, **kwargs)
         self._grids = na.array(grid_list)
-        self.connection_pool = True
 
     def _get_list_of_grids(self):
         pass

Modified: trunk/yt/lagos/HierarchyType.py
==============================================================================
--- trunk/yt/lagos/HierarchyType.py	(original)
+++ trunk/yt/lagos/HierarchyType.py	Thu Mar 12 12:45:01 2009
@@ -553,7 +553,9 @@
                 ogl[0].dimensions[2],
                 position, 1,
                 output, refined, ogl)
-        return output, refined
+        dd = {}
+        for i,field in enumerate(fields): dd[field] = output[:,i]
+        return dd, refined
 
     def _generate_levels_octree(self, fields):
         import DepthFirstOctree as dfo



More information about the yt-svn mailing list