[Yt-svn] yt-commit r1080 - in trunk: tests yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Tue Jan 6 21:33:36 PST 2009


Author: mturk
Date: Tue Jan  6 21:33:35 2009
New Revision: 1080
URL: http://yt.spacepope.org/changeset/1080

Log:
Added a cache=True option to extract_connected_sets.  When enabled, the
ghost-zone data for GridIndices and for the field being contoured is cached
per grid and reused across contour levels, which speeds things up
considerably.  This option will eventually need to be passed on to the clump
finder.  Added tests for extract_connected_sets, not just identify_contours.
Added a progress meter for the (changing) length of the contour check queue.



Modified:
   trunk/tests/test_lagos.py
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/ContourFinder.py

Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py	(original)
+++ trunk/tests/test_lagos.py	Tue Jan  6 21:33:35 2009
@@ -167,9 +167,32 @@
         v2 = na.abs(1.0 - v2/v1)
         self.assertAlmostEqual(v2, 0.0, 7)
 
+    def testExtractConnectedSetsNoCache(self):
+        mi = self.data["Density"].min() * 2.0
+        ma = self.data["Density"].max() * 0.99
+        cons, contours = self.data.extract_connected_sets(
+            "Density", 2, mi, ma)
+        print cons
+        self.assertEqual(len(contours), 2) # number of contour levels
+        self.assertEqual(len(contours[0]), 2)
+        self.assertEqual(len(contours[1]), 1)
+
+    def testExtractConnectedSetsCache(self):
+        mi = self.data["Density"].min() * 2.0
+        ma = self.data["Density"].max() * 0.99
+        cons, contours = self.data.extract_connected_sets(
+            "Density", 2, mi, ma, cache=True)
+        self.assertEqual(len(contours), 2) # number of contour levels
+        self.assertEqual(len(contours[0]), 2)
+        self.assertEqual(len(contours[1]), 1)
+
+    def testContoursCache(self):
+        cid = yt.lagos.identify_contours(self.data, "Density",
+                self.data["Density"].min()*2.00,
+                self.data["Density"].max()*1.01)
+        self.assertEqual(len(cid), 2)
+
     def testContoursObtain(self):
-        # As a note, unfortunately this dataset only has one sphere.
-        # Frownie face.
         cid = yt.lagos.identify_contours(self.data, "Density",
                 self.data["Density"].min()*2.00, self.data["Density"].max()*1.01)
         self.assertEqual(len(cid), 2)

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Tue Jan  6 21:33:35 2009
@@ -1368,7 +1368,7 @@
     quantities = property(__get_quantities)
 
     def extract_connected_sets(self, field, num_levels, min_val, max_val,
-                                log_space=True, cumulative=True):
+                                log_space=True, cumulative=True, cache=False):
         """
         This function will create a set of contour objects, defined
         by having connected cell structures, which can then be
@@ -1381,13 +1381,16 @@
         else:
             cons = na.linspace(min_val, max_val, num_levels+1)
         contours = {}
+        if cache: cached_fields = defaultdict(lambda: dict())
+        else: cached_fields = None
         for level in range(num_levels):
             contours[level] = {}
             if cumulative:
                 mv = max_val
             else:
                 mv = cons[level+1]
-            cids = identify_contours(self, field, cons[level], mv)
+            cids = identify_contours(self, field, cons[level], mv,
+                                     cached_fields)
             for cid, cid_ind in cids.items():
                 contours[level][cid] = self.extract_region(cid_ind)
         return cons, contours

Modified: trunk/yt/lagos/ContourFinder.py
==============================================================================
--- trunk/yt/lagos/ContourFinder.py	(original)
+++ trunk/yt/lagos/ContourFinder.py	Tue Jan  6 21:33:35 2009
@@ -77,15 +77,18 @@
         self.n += 1
         return tr
 
+    def progress(self):
+        return self.n, len(self.to_consider)
+
 # We want an algorithm that deals with growing a given contour to *all* the
 # cells in a grid.
 
-def identify_contours(data_source, field, min_val, max_val):
+def identify_contours(data_source, field, min_val, max_val, cached_fields=None):
     """
     Given a *data_source*, we will search for topologically connected sets
     in *field* between *min_val* and *max_val*.
     """
-    maxn_cells = 0
+    if cached_fields is None: cached_fields = defaultdict(lambda: dict())
     maxn_cells = na.sum([g.ActiveDimensions.prod() for g in data_source._grids])
     contour_ind = na.where( (data_source[field] > min_val)
                           & (data_source[field] < max_val))[0]
@@ -102,12 +105,21 @@
                     priority_func = lambda g: -1*g["tempContours"].max())
     my_queue.add(data_source._grids)
     for i,grid in enumerate(my_queue):
+        mylog.info("Examining %s of %s", *my_queue.progress())
         max_before = grid["tempContours"].max()
-        if na.all(grid.LeftEdge == grid.pf["DomainLeftEdge"]) and \
-           na.all(grid.RightEdge == grid.pf["DomainRightEdge"]):
-            cg = grid.retrieve_ghost_zones(0,["tempContours","GridIndices"])
-        else:
-            cg = grid.retrieve_ghost_zones(1,["tempContours","GridIndices"])
+        to_get = ["tempContours"]
+        if field in cached_fields[grid.id] and \
+            not na.any( (cached_fields[grid.id][field] > min_val)
+                      & (cached_fields[grid.id][field] < max_val)):
+            continue
+        for f in [field, "GridIndices"]:
+            if f not in cached_fields[grid.id]: to_get.append(f)
+        cg = grid.retrieve_ghost_zones(1,to_get)
+        for f in [field, "GridIndices"]:
+            if f in cached_fields[grid.id]:
+                cg.data[f] = cached_fields[grid.id][f]
+            else:
+                cached_fields[grid.id][f] = cg[f] 
         local_ind = na.where( (cg[field] > min_val)
                             & (cg[field] < max_val)
                             & (cg["tempContours"] == -1) )
@@ -146,7 +158,7 @@
     mylog.info("Identified %s contours between %0.5e and %0.5e",
                len(contour_ind.keys()),min_val,max_val)
     for grid in data_source._grids:
-        if grid.data.has_key("tempContours"): del grid.data["tempContours"]
+        grid.data.pop("tempContours", None)
     del data_source.data["tempContours"]
     return contour_ind
 



More information about the yt-svn mailing list