[Yt-svn] yt-commit r1037 - in branches/yt-object-serialization: tests yt/fido yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Wed Dec 24 13:54:22 PST 2008


Author: mturk
Date: Wed Dec 24 13:54:22 2008
New Revision: 1037
URL: http://yt.spacepope.org/changeset/1037

Log:
* Added a test (testPickle) for the pickling of objects.
* Changed the ParameterFileStorage module to use shelve rather than sqlite3.
* Grids should now be serializable (I think; not yet confirmed).



Modified:
   branches/yt-object-serialization/tests/test_lagos.py
   branches/yt-object-serialization/yt/fido/ParameterFileStorage.py
   branches/yt-object-serialization/yt/lagos/BaseGridType.py

Modified: branches/yt-object-serialization/tests/test_lagos.py
==============================================================================
--- branches/yt-object-serialization/tests/test_lagos.py	(original)
+++ branches/yt-object-serialization/tests/test_lagos.py	Wed Dec 24 13:54:22 2008
@@ -15,6 +15,7 @@
 ytcfg["yt","suppressStreamLogging"] = "True"
 ytcfg["lagos","serialize"] = "False"
 
+import cPickle
 import yt.lagos
 import numpy as na
 
@@ -189,6 +190,10 @@
                     and na.all(v2 > self.data["Density"][cid[0]]))
         self.assertEqual(len(cid), 3)
 
+    def testPickle(self):
+        ps = cPickle.dumps(self.data)
+        pf, obj = cPickle.loads(ps)
+        self.assertEqual(obj["CellMassMsun"].sum(), self.data["CellMassMsun"].sum())
 
 for field_name in yt.lagos.FieldInfo:
     field = yt.lagos.FieldInfo[field_name]

Modified: branches/yt-object-serialization/yt/fido/ParameterFileStorage.py
==============================================================================
--- branches/yt-object-serialization/yt/fido/ParameterFileStorage.py	(original)
+++ branches/yt-object-serialization/yt/fido/ParameterFileStorage.py	Wed Dec 24 13:54:22 2008
@@ -39,7 +39,8 @@
 ##      Hash            hash    text
 
 from yt.fido import *
-import sqlite3
+from yt.funcs import *
+import shelve
 import os.path
 
 #sqlite3.register_adapter(yt.lagos.OutputTypes.EnzoStaticOutput, _adapt_pf)
@@ -48,7 +49,7 @@
 class ParameterFileStore(object):
 
     _shared_state = {}
-    _conn = None
+    _shelve = None
 
     def __new__(cls, *p, **k):
         self = object.__new__(cls, *p, **k)
@@ -56,59 +57,34 @@
         return self
 
     def __init__(self, in_memory = False):
-        if self._conn is None:
-            self._conn = sqlite3.connect(self._get_db_name(),
-                    detect_types=sqlite3.PARSE_DECLTYPES)
-            self._conn.row_factory = self._convert_pf
-            self._initialize_new()
+        if self._shelve is None:
+            self._shelve = shelve.open(self._get_db_name())
         
     def _get_db_name(self):
-        return os.path.expanduser("~/.yt/pfdb.sql")
-
-    def _initialize_new(self, filename = None):
-        c = self._conn.cursor()
-        try:
-            c.execute("""create table parameter_files
-                            (pf text, path text, time real,
-                             ctid real, hash text primary key unique)""")
-            self._conn.commit()
-        except sqlite3.OperationalError:
-            pass
-        c.close()
+        return os.path.expanduser("~/.yt/parameter_files.db")
 
     def wipe_hash(self, hash):
-        c = self._conn.cursor()
-        c.execute("delete from parameter_files where hash=?", (hash,))
-        self._conn.commit()
-        c.close()
+        if hash in self._shelve:
+            del self._shelve[hash]
+            self._shelve.sync()
 
     def get_pf_hash(self, hash):
-        c = self._conn.cursor()
-        c.execute("""select * from parameter_files where hash=?""",
-                  (hash,))
-        return c.fetchall()[0] # Unique
+        return self._convert_pf(self._shelve[hash])
 
     def get_pf_ctid(self, ctid):
-        c = self._conn.cursor()
-        c.execute("""select * from parameter_files where ctid=?""",
-                  (ctid,))
-        return c.fetchall()
-
-    def get_count_hash(self, hash):
-        c = self._conn.cursor()
-        c.execute("""select pf, path from parameter_files where hash=?""",
-                  (hash,))
-        res = c.fetchall()
-        return len(res), res
+        for h in self._shelve:
+            if self._shelve[h]['ctid'] == ctid:
+                return self._convert_pf(self._shelve[h])
 
     def _adapt_pf(self, pf):
-        return (pf.basename, pf.fullpath,
-                pf["InitialTime"], pf["CurrentTimeIdentifier"],
-                pf._hash())
-
-    def _convert_pf(self, cursor, row):
-        if len(row) != 5: return row
-        bn, fp, t1, ctid, hash = row
+        return dict(bn=pf.basename,
+                    fp=pf.fullpath,
+                    tt=pf["InitialTime"],
+                    ctid=pf["CurrentTimeIdentifier"])
+
+    def _convert_pf(self, pf_dict):
+        bn = pf_dict['bn']
+        fp = pf_dict['fp']
         fn = os.path.join(fp, bn)
         if os.path.exists(fn):
             import yt.lagos.OutputTypes as ot
@@ -119,25 +95,18 @@
         return pf
 
     def check_pf(self, pf):
-        rc, res = self.get_count_hash(pf._hash())
-        if rc == 0:
-            self.insert_pf(pf)
-            return
-        elif rc > 1:
-            self.wipe_hash(pf._hash())
+        if pf._hash() not in self._shelve:
             self.insert_pf(pf)
             return
-        bn, fp = res[0]
-        if bn != pf.basename or fp != pf.fullpath:
+        pf_dict = self._shelve[pf._hash()]
+        if pf_dict['bn'] != pf.basename \
+          or pf_dict['fp'] != pf.fullpath:
             self.wipe_hash(pf._hash())
             self.insert_pf(pf)
 
     def insert_pf(self, pf):
-        c = self._conn.cursor()
-        c.execute("""insert into parameter_files values
-                     (?,?,?,?,?)""", self._adapt_pf(pf))
-        self._conn.commit()
-        c.close()
+        self._shelve[pf._hash()] = self._adapt_pf(pf)
+        self._shelve.sync()
 
 class ObjectStorage(object):
     pass

Modified: branches/yt-object-serialization/yt/lagos/BaseGridType.py
==============================================================================
--- branches/yt-object-serialization/yt/lagos/BaseGridType.py	(original)
+++ branches/yt-object-serialization/yt/lagos/BaseGridType.py	Wed Dec 24 13:54:22 2008
@@ -33,6 +33,9 @@
     _grids = None
     _id_offset = 1
 
+    _type_name = 'grid'
+    _con_args = ['id', 'filename']
+
     def __init__(self, id, filename=None, hierarchy = None):
         self.data = {}
         self.field_parameters = {}



More information about the yt-svn mailing list