[Yt-svn] yt-commit r1196 - in trunk: tests yt yt/fido yt/lagos yt/raven

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Tue Mar 3 12:21:00 PST 2009


Author: mturk
Date: Tue Mar  3 12:20:59 2009
New Revision: 1196
URL: http://yt.spacepope.org/changeset/1196

Log:
Rewrote the parameter file store to use flat text in CSV format, rather than
the shelve module.  It now also stores the type (class name) of each parameter
file.  All parameter file types now self-register and are retrieved correctly.
Unit tests now exist for the parameter file store.

Slices now err on the side of inclusion for coordinates, so grids whose left
edge lies exactly on the slice coordinate are now selected.



Modified:
   trunk/tests/test_lagos.py
   trunk/yt/config.py
   trunk/yt/fido/ParameterFileStorage.py
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/OutputTypes.py
   trunk/yt/lagos/__init__.py
   trunk/yt/raven/PlotTypes.py

Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py	(original)
+++ trunk/tests/test_lagos.py	Tue Mar  3 12:20:59 2009
@@ -17,7 +17,9 @@
 
 import cPickle
 import yt.lagos
+import yt.lagos.OutputTypes
 import numpy as na
+from yt.fido import ParameterFileStore
 
 # The dataset used is located at:
 # http://yt.spacepope.org/DD0018.zip
@@ -41,6 +43,41 @@
         if hasattr(self,'ind_to_get'): del self.ind_to_get
         del self.OutputFile, self.hierarchy
         
+class TestParameterFileStore(unittest.TestCase):
+    def setUp(self):
+        self.original = (yt.config.ytcfg.get("yt","ParameterFileStore"),
+                         yt.config.ytcfg.get("lagos","serialize"))
+        ytcfg['yt','ParameterFileStore'] = "testing.csv"
+        pfs = ParameterFileStore()
+        os.unlink(pfs._get_db_name())
+        self.pfs = ParameterFileStore() # __init__ gets called again
+        ytcfg['lagos', 'serialize'] = "True"
+
+    def testCacheFile(self):
+        pf1 = yt.lagos.EnzoStaticOutput(fn)
+        pf2 = self.pfs.get_pf_hash(pf1._hash())
+        self.assertTrue(pf1 is pf2)
+
+    def testGrabFile(self):
+        pf1 = yt.lagos.EnzoStaticOutput(fn)
+        hash = pf1._hash()
+        del pf1
+        pf2 = self.pfs.get_pf_hash(hash)
+        self.assertTrue(hash == pf2._hash())
+
+    def testGetCurrentTimeID(self):
+        pf1 = yt.lagos.EnzoStaticOutput(fn)
+        hash = pf1._hash()
+        ctid = pf1["CurrentTimeIdentifier"]
+        del pf1
+        pf2 = self.pfs.get_pf_ctid(ctid)
+        self.assertTrue(hash == pf2._hash())
+
+    def tearDown(self):
+        os.unlink(self.pfs._get_db_name())
+        ytcfg['yt', 'ParameterFileStore'] = self.original[0]
+        ytcfg['lagos', 'serialize'] = self.original[1]
+        self.pfs.__init__()
 
 class TestHierarchy(LagosTestingBase, unittest.TestCase):
     def testGetHierarchy(self):

Modified: trunk/yt/config.py
==============================================================================
--- trunk/yt/config.py	(original)
+++ trunk/yt/config.py	Tue Mar  3 12:20:59 2009
@@ -65,6 +65,7 @@
         '__parallel_rank':'0',
         '__parallel_size':'1',
         'StoreParameterFiles': 'True',
+        'ParameterFileStore': 'parameter_files.csv',
          },
     "raven":{
         'ImagePath':".",

Modified: trunk/yt/fido/ParameterFileStorage.py
==============================================================================
--- trunk/yt/fido/ParameterFileStorage.py	(original)
+++ trunk/yt/fido/ParameterFileStorage.py	Tue Mar  3 12:20:59 2009
@@ -26,16 +26,32 @@
 from yt.config import ytcfg
 from yt.fido import *
 from yt.funcs import *
-import shelve
+from yt.lagos.ParallelTools import parallel_simple_proxy
+import csv
 import os.path
 
+output_type_registry = {}
+_field_names = ('hash','bn','fp','tt','ctid','class_name')
+
 class NoParameterShelf(Exception):
     pass
 
+class UnknownStaticOutputType(Exception):
+    def __init__(self, name):
+        self.name = name
+
+    def __str__(self):
+        return "%s" % self.name
+
+    def __repr__(self):
+        return "%s" % self.name
+
 class ParameterFileStore(object):
 
     _shared_state = {}
-    _shelf = None
+    _distributed = True
+    _processing = False
+    _owner = 0
 
     def __new__(cls, *p, **k):
         self = object.__new__(cls, *p, **k)
@@ -43,100 +59,106 @@
         return self
 
     def __init__(self, in_memory = False):
-        self.__init_shelf()
+        if ytcfg.getboolean("yt", "StoreParameterFiles"):
+            self._read_only = False
+            self.init_db()
+            self._records = self.read_db()
+        else:
+            self._read_only = True
+            self._records = {}
+
+    @parallel_simple_proxy
+    def init_db(self):
+        dbn = self._get_db_name()
+        dbdir = os.path.dirname(dbn)
+        try:
+            if not os.path.isdir(dbdir): os.mkdir(dbdir)
+        except OSError:
+            raise NoParameterShelf()
+        open(dbn, 'ab') # make sure it exists, allow to close
+        # Now we read in all our records and return them
+        # these will be broadcast
 
     def _get_db_name(self):
+        base_file_name = ytcfg.get("yt","ParameterFileStore")
         if not os.access(os.path.expanduser("~/"), os.W_OK):
-            return os.path.abspath("parameter_files.db")
-        return os.path.expanduser("~/.yt/parameter_files.db")
-
-    def wipe_hash(self, hash):
-        if hash in self.keys():
-            del self[hash]
+            return os.path.abspath(base_file_name)
+        return os.path.expanduser("~/.yt/%s" % base_file_name)
 
     def get_pf_hash(self, hash):
-        return self._convert_pf(self[hash])
+        return self._convert_pf(self._records[hash])
 
     def get_pf_ctid(self, ctid):
-        for h in self.keys():
-            if self[h]['ctid'] == ctid:
-                return self._convert_pf(self[h])
+        for h in self._records:
+            if self._records[h]['ctid'] == ctid:
+                return self._convert_pf(self._records[h])
 
     def _adapt_pf(self, pf):
         return dict(bn=pf.basename,
                     fp=pf.fullpath,
                     tt=pf["InitialTime"],
-                    ctid=pf["CurrentTimeIdentifier"])
+                    ctid=pf["CurrentTimeIdentifier"],
+                    class_name=pf.__class__.__name__)
 
     def _convert_pf(self, pf_dict):
         bn = pf_dict['bn']
         fp = pf_dict['fp']
         fn = os.path.join(fp, bn)
+        class_name = pf_dict['class_name']
+        if class_name not in output_type_registry:
+            raise UnknownStaticOutputType(class_name)
         mylog.info("Checking %s", fn)
         if os.path.exists(fn):
-            import yt.lagos.OutputTypes as ot
-            pf = ot.EnzoStaticOutput(
-                os.path.join(fp, bn))
+            pf = output_type_registry[class_name](os.path.join(fp, bn))
         else:
             raise IOError
         return pf
 
     def check_pf(self, pf):
-        if pf._hash() not in self.keys():
+        if pf._hash() not in self._records:
             self.insert_pf(pf)
             return
-        pf_dict = self[pf._hash()]
+        pf_dict = self._records[pf._hash()]
         if pf_dict['bn'] != pf.basename \
           or pf_dict['fp'] != pf.fullpath:
             self.wipe_hash(pf._hash())
             self.insert_pf(pf)
 
-    def __read_only(self):
-        if self._shelf is not None: return self._shelf
-        return shelve.open(self._get_db_name(), flag='r', protocol=-1)
-
-    def __read_write(self):
-        if self._shelf is not None: return self._shelf
-        return shelve.open(self._get_db_name(), flag='c', protocol=-1)
-
     def insert_pf(self, pf):
-        self[pf._hash()] = self._adapt_pf(pf)
+        self._records[pf._hash()] = self._adapt_pf(pf)
+        self.flush_db()
 
-    def __getitem__(self, key):
-        my_shelf = self.__read_only()
-        return my_shelf[key]
-
-    def __store_item(self, key, val):
-        my_shelf = self.__read_write()
-        my_shelf[key] = val
-
-    def __delete_item(self, key):
-        my_shelf = self.__read_write()
-        del my_shelf[key]
-
-    def __init_shelf(self):
-        dbn = self._get_db_name()
-        dbdir = os.path.dirname(dbn)
-        if not ytcfg.getboolean("yt", "StoreParameterFiles"):
-            # This ensures that even if we're not storing them in the file
-            # system, we're at least keeping track of what we load
-            self._shelf = defaultdict(lambda: dict(bn='',fp='',tt='',ctid=''))
-            return
-        try:
-            if not os.path.isdir(dbdir): os.mkdir(dbdir)
-        except OSError:
-            raise NoParameterShelf()
-        only_on_root(shelve.open, dbn, 'c', protocol=-1)
-
-    def __setitem__(self, key, val):
-        only_on_root(self.__store_item, key, val)
-
-    def __delitem__(self, key):
-        only_on_root(self.__delete_item, key)
-
-    def keys(self):
-        my_shelf = self.__read_only()
-        return my_shelf.keys()
+    def wipe_hash(self, hash):
+        if hash not in self._records: return
+        del self._records[hash]
+        self.flush_db()
+
+    def flush_db(self):
+        if self._read_only: return
+        self._write_out()
+        self.read_db()
+
+    @parallel_simple_proxy
+    def _write_out(self):
+        if self._read_only: return
+        fn = self._get_db_name()
+        f = open("%s.tmp" % fn, 'wb')
+        w = csv.DictWriter(f, _field_names)
+        for h,v in sorted(self._records.items()):
+            v['hash'] = h
+            w.writerow(v)
+        f.close()
+        os.rename("%s.tmp" % fn, fn)
+
+    @parallel_simple_proxy
+    def read_db(self):
+        f=open(self._get_db_name(), 'rb')
+        vals = csv.DictReader(f, _field_names)
+        db = {}
+        for v in vals:
+            db[v.pop('hash')] = v
+            
+        return db
 
 class ObjectStorage(object):
     pass

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Tue Mar  3 12:20:59 2009
@@ -691,7 +691,7 @@
 
     def _get_list_of_grids(self):
         goodI = ((self.source.gridRightEdge[:,self.axis] > self.coord)
-              &  (self.source.gridLeftEdge[:,self.axis] < self.coord ))
+              &  (self.source.gridLeftEdge[:,self.axis] <= self.coord ))
         self._grids = self.source._grids[goodI] # Using sources not hierarchy
 
     def __cut_mask_child_mask(self, grid):

Modified: trunk/yt/lagos/OutputTypes.py
==============================================================================
--- trunk/yt/lagos/OutputTypes.py	(original)
+++ trunk/yt/lagos/OutputTypes.py	Tue Mar  3 12:20:59 2009
@@ -38,8 +38,10 @@
 
 class StaticOutput(object):
     class __metaclass__(type):
-        def __call__(cls, *args, **kwargs):
-            return cls.__new__(cls, *args, **kwargs)
+        def __init__(cls, name, b, d):
+            type.__init__(cls, name, b, d)
+            output_type_registry[name]=cls
+            mylog.debug("Registering: %s as %s", name, cls)
 
     def __new__(cls, filename, *args, **kwargs):
         apath = os.path.abspath(filename)
@@ -374,6 +376,10 @@
                (1.0 + self.parameters["CosmologyCurrentRedshift"])
         return k
 
+# We set our default output type to EnzoStaticOutput
+
+output_type_registry[None] = EnzoStaticOutput
+
 class EnzoStaticOutputInMemory(EnzoStaticOutput):
     _hierarchy_class = EnzoHierarchyInMemory
     def __init__(self, parameter_override=None, conversion_override=None):

Modified: trunk/yt/lagos/__init__.py
==============================================================================
--- trunk/yt/lagos/__init__.py	(original)
+++ trunk/yt/lagos/__init__.py	Tue Mar  3 12:20:59 2009
@@ -89,7 +89,7 @@
 # We by-default add universal fields.
 add_field = FieldInfo.add_field
 
-from yt.fido import ParameterFileStore
+from yt.fido import ParameterFileStore, output_type_registry
 
 from DerivedQuantities import DerivedQuantityCollection, GridChildMaskWrapper
 from DataReadingFuncs import *

Modified: trunk/yt/raven/PlotTypes.py
==============================================================================
--- trunk/yt/raven/PlotTypes.py	(original)
+++ trunk/yt/raven/PlotTypes.py	Tue Mar  3 12:20:59 2009
@@ -304,8 +304,8 @@
     def set_width(self, width, unit):
         self["Unit"] = str(unit)
         self["Width"] = float(width)
-        if isinstance(unit, types.StringType):
-            unit = self.data.hierarchy[unit]
+        if isinstance(unit, types.StringTypes):
+            unit = self.data.hierarchy[str(unit)]
         self.width = width / unit
         self._refresh_display_width()
 
@@ -548,8 +548,8 @@
     def set_width(self, width, unit):
         self["Unit"] = str(unit)
         self["Width"] = float(width)
-        if isinstance(unit, types.StringType):
-            unit = self.data.hierarchy[unit]
+        if isinstance(unit, types.StringTypes):
+            unit = self.data.hierarchy[str(unit)]
         self.width = width / unit
         self._refresh_display_width()
 



More information about the yt-svn mailing list