[Yt-svn] yt-commit r1080 - in trunk: tests yt/lagos
mturk at wrangler.dreamhost.com
mturk at wrangler.dreamhost.com
Tue Jan 6 21:33:36 PST 2009
Author: mturk
Date: Tue Jan 6 21:33:35 2009
New Revision: 1080
URL: http://yt.spacepope.org/changeset/1080
Log:
Adding a cache=True option to extract_connected_sets. This will cache ghost zones
for GridIndices and the field being contoured, which speeds things up considerably.
This will eventually need to be passed on to the clump finder as well. Added tests for
extract_connected_sets, not just identify_contours. Added a progress meter
for the (changing) length of the contour check queue.
Modified:
trunk/tests/test_lagos.py
trunk/yt/lagos/BaseDataTypes.py
trunk/yt/lagos/ContourFinder.py
Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py (original)
+++ trunk/tests/test_lagos.py Tue Jan 6 21:33:35 2009
@@ -167,9 +167,32 @@
v2 = na.abs(1.0 - v2/v1)
self.assertAlmostEqual(v2, 0.0, 7)
+ def testExtractConnectedSetsNoCache(self):
+ mi = self.data["Density"].min() * 2.0
+ ma = self.data["Density"].max() * 0.99
+ cons, contours = self.data.extract_connected_sets(
+ "Density", 2, mi, ma)
+ print cons
+ self.assertEqual(len(contours), 2) # number of contour levels
+ self.assertEqual(len(contours[0]), 2)
+ self.assertEqual(len(contours[1]), 1)
+
+ def testExtractConnectedSetsCache(self):
+ mi = self.data["Density"].min() * 2.0
+ ma = self.data["Density"].max() * 0.99
+ cons, contours = self.data.extract_connected_sets(
+ "Density", 2, mi, ma, cache=True)
+ self.assertEqual(len(contours), 2) # number of contour levels
+ self.assertEqual(len(contours[0]), 2)
+ self.assertEqual(len(contours[1]), 1)
+
+ def testContoursCache(self):
+ cid = yt.lagos.identify_contours(self.data, "Density",
+ self.data["Density"].min()*2.00,
+ self.data["Density"].max()*1.01)
+ self.assertEqual(len(cid), 2)
+
def testContoursObtain(self):
- # As a note, unfortunately this dataset only has one sphere.
- # Frownie face.
cid = yt.lagos.identify_contours(self.data, "Density",
self.data["Density"].min()*2.00, self.data["Density"].max()*1.01)
self.assertEqual(len(cid), 2)
Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py (original)
+++ trunk/yt/lagos/BaseDataTypes.py Tue Jan 6 21:33:35 2009
@@ -1368,7 +1368,7 @@
quantities = property(__get_quantities)
def extract_connected_sets(self, field, num_levels, min_val, max_val,
- log_space=True, cumulative=True):
+ log_space=True, cumulative=True, cache=False):
"""
This function will create a set of contour objects, defined
by having connected cell structures, which can then be
@@ -1381,13 +1381,16 @@
else:
cons = na.linspace(min_val, max_val, num_levels+1)
contours = {}
+ if cache: cached_fields = defaultdict(lambda: dict())
+ else: cached_fields = None
for level in range(num_levels):
contours[level] = {}
if cumulative:
mv = max_val
else:
mv = cons[level+1]
- cids = identify_contours(self, field, cons[level], mv)
+ cids = identify_contours(self, field, cons[level], mv,
+ cached_fields)
for cid, cid_ind in cids.items():
contours[level][cid] = self.extract_region(cid_ind)
return cons, contours
Modified: trunk/yt/lagos/ContourFinder.py
==============================================================================
--- trunk/yt/lagos/ContourFinder.py (original)
+++ trunk/yt/lagos/ContourFinder.py Tue Jan 6 21:33:35 2009
@@ -77,15 +77,18 @@
self.n += 1
return tr
+ def progress(self):
+ return self.n, len(self.to_consider)
+
# We want an algorithm that deals with growing a given contour to *all* the
# cells in a grid.
-def identify_contours(data_source, field, min_val, max_val):
+def identify_contours(data_source, field, min_val, max_val, cached_fields=None):
"""
Given a *data_source*, we will search for topologically connected sets
in *field* between *min_val* and *max_val*.
"""
- maxn_cells = 0
+ if cached_fields is None: cached_fields = defaultdict(lambda: dict())
maxn_cells = na.sum([g.ActiveDimensions.prod() for g in data_source._grids])
contour_ind = na.where( (data_source[field] > min_val)
& (data_source[field] < max_val))[0]
@@ -102,12 +105,21 @@
priority_func = lambda g: -1*g["tempContours"].max())
my_queue.add(data_source._grids)
for i,grid in enumerate(my_queue):
+ mylog.info("Examining %s of %s", *my_queue.progress())
max_before = grid["tempContours"].max()
- if na.all(grid.LeftEdge == grid.pf["DomainLeftEdge"]) and \
- na.all(grid.RightEdge == grid.pf["DomainRightEdge"]):
- cg = grid.retrieve_ghost_zones(0,["tempContours","GridIndices"])
- else:
- cg = grid.retrieve_ghost_zones(1,["tempContours","GridIndices"])
+ to_get = ["tempContours"]
+ if field in cached_fields[grid.id] and \
+ not na.any( (cached_fields[grid.id][field] > min_val)
+ & (cached_fields[grid.id][field] < max_val)):
+ continue
+ for f in [field, "GridIndices"]:
+ if f not in cached_fields[grid.id]: to_get.append(f)
+ cg = grid.retrieve_ghost_zones(1,to_get)
+ for f in [field, "GridIndices"]:
+ if f in cached_fields[grid.id]:
+ cg.data[f] = cached_fields[grid.id][f]
+ else:
+ cached_fields[grid.id][f] = cg[f]
local_ind = na.where( (cg[field] > min_val)
& (cg[field] < max_val)
& (cg["tempContours"] == -1) )
@@ -146,7 +158,7 @@
mylog.info("Identified %s contours between %0.5e and %0.5e",
len(contour_ind.keys()),min_val,max_val)
for grid in data_source._grids:
- if grid.data.has_key("tempContours"): del grid.data["tempContours"]
+ grid.data.pop("tempContours", None)
del data_source.data["tempContours"]
return contour_ind
More information about the yt-svn
mailing list