[Yt-svn] yt-commit r1196 - in trunk: tests yt yt/fido yt/lagos yt/raven
mturk at wrangler.dreamhost.com
Tue Mar 3 12:21:00 PST 2009
Author: mturk
Date: Tue Mar 3 12:20:59 2009
New Revision: 1196
URL: http://yt.spacepope.org/changeset/1196
Log:
Rewrote the parameter file store to use flat text in CSV format, rather than
the shelve module. Stores type of parameter file. All parameter file types
now self-register and are retrieved correctly. Unit tests now exist for the
parameter file store.
Slices now err on the side of inclusion for coordinates, so edges-of-grids will
get selected.
Modified:
trunk/tests/test_lagos.py
trunk/yt/config.py
trunk/yt/fido/ParameterFileStorage.py
trunk/yt/lagos/BaseDataTypes.py
trunk/yt/lagos/OutputTypes.py
trunk/yt/lagos/__init__.py
trunk/yt/raven/PlotTypes.py
Modified: trunk/tests/test_lagos.py
==============================================================================
--- trunk/tests/test_lagos.py (original)
+++ trunk/tests/test_lagos.py Tue Mar 3 12:20:59 2009
@@ -17,7 +17,9 @@
import cPickle
import yt.lagos
+import yt.lagos.OutputTypes
import numpy as na
+from yt.fido import ParameterFileStore
# The dataset used is located at:
# http://yt.spacepope.org/DD0018.zip
@@ -41,6 +43,41 @@
if hasattr(self,'ind_to_get'): del self.ind_to_get
del self.OutputFile, self.hierarchy
+class TestParameterFileStore(unittest.TestCase):
+ def setUp(self):
+ self.original = (yt.config.ytcfg.get("yt","ParameterFileStore"),
+ yt.config.ytcfg.get("lagos","serialize"))
+ ytcfg['yt','ParameterFileStore'] = "testing.csv"
+ pfs = ParameterFileStore()
+ os.unlink(pfs._get_db_name())
+ self.pfs = ParameterFileStore() # __init__ gets called again
+ ytcfg['lagos', 'serialize'] = "True"
+
+ def testCacheFile(self):
+ pf1 = yt.lagos.EnzoStaticOutput(fn)
+ pf2 = self.pfs.get_pf_hash(pf1._hash())
+ self.assertTrue(pf1 is pf2)
+
+ def testGrabFile(self):
+ pf1 = yt.lagos.EnzoStaticOutput(fn)
+ hash = pf1._hash()
+ del pf1
+ pf2 = self.pfs.get_pf_hash(hash)
+ self.assertTrue(hash == pf2._hash())
+
+ def testGetCurrentTimeID(self):
+ pf1 = yt.lagos.EnzoStaticOutput(fn)
+ hash = pf1._hash()
+ ctid = pf1["CurrentTimeIdentifier"]
+ del pf1
+ pf2 = self.pfs.get_pf_ctid(ctid)
+ self.assertTrue(hash == pf2._hash())
+
+ def tearDown(self):
+ os.unlink(self.pfs._get_db_name())
+ ytcfg['yt', 'ParameterFileStore'] = self.original[0]
+ ytcfg['lagos', 'serialize'] = self.original[1]
+ self.pfs.__init__()
class TestHierarchy(LagosTestingBase, unittest.TestCase):
def testGetHierarchy(self):
Modified: trunk/yt/config.py
==============================================================================
--- trunk/yt/config.py (original)
+++ trunk/yt/config.py Tue Mar 3 12:20:59 2009
@@ -65,6 +65,7 @@
'__parallel_rank':'0',
'__parallel_size':'1',
'StoreParameterFiles': 'True',
+ 'ParameterFileStore': 'parameter_files.csv',
},
"raven":{
'ImagePath':".",
Modified: trunk/yt/fido/ParameterFileStorage.py
==============================================================================
--- trunk/yt/fido/ParameterFileStorage.py (original)
+++ trunk/yt/fido/ParameterFileStorage.py Tue Mar 3 12:20:59 2009
@@ -26,16 +26,32 @@
from yt.config import ytcfg
from yt.fido import *
from yt.funcs import *
-import shelve
+from yt.lagos.ParallelTools import parallel_simple_proxy
+import csv
import os.path
+output_type_registry = {}
+_field_names = ('hash','bn','fp','tt','ctid','class_name')
+
class NoParameterShelf(Exception):
pass
+class UnknownStaticOutputType(Exception):
+ def __init__(self, name):
+ self.name = name
+
+ def __str__(self):
+ return "%s" % self.name
+
+ def __repr__(self):
+ return "%s" % self.name
+
class ParameterFileStore(object):
_shared_state = {}
- _shelf = None
+ _distributed = True
+ _processing = False
+ _owner = 0
def __new__(cls, *p, **k):
self = object.__new__(cls, *p, **k)
@@ -43,100 +59,106 @@
return self
def __init__(self, in_memory = False):
- self.__init_shelf()
+ if ytcfg.getboolean("yt", "StoreParameterFiles"):
+ self._read_only = False
+ self.init_db()
+ self._records = self.read_db()
+ else:
+ self._read_only = True
+ self._records = {}
+
+ @parallel_simple_proxy
+ def init_db(self):
+ dbn = self._get_db_name()
+ dbdir = os.path.dirname(dbn)
+ try:
+ if not os.path.isdir(dbdir): os.mkdir(dbdir)
+ except OSError:
+ raise NoParameterShelf()
+ open(dbn, 'ab') # make sure it exists, allow to close
+ # Now we read in all our records and return them
+ # these will be broadcast
def _get_db_name(self):
+ base_file_name = ytcfg.get("yt","ParameterFileStore")
if not os.access(os.path.expanduser("~/"), os.W_OK):
- return os.path.abspath("parameter_files.db")
- return os.path.expanduser("~/.yt/parameter_files.db")
-
- def wipe_hash(self, hash):
- if hash in self.keys():
- del self[hash]
+ return os.path.abspath(base_file_name)
+ return os.path.expanduser("~/.yt/%s" % base_file_name)
def get_pf_hash(self, hash):
- return self._convert_pf(self[hash])
+ return self._convert_pf(self._records[hash])
def get_pf_ctid(self, ctid):
- for h in self.keys():
- if self[h]['ctid'] == ctid:
- return self._convert_pf(self[h])
+ for h in self._records:
+ if self._records[h]['ctid'] == ctid:
+ return self._convert_pf(self._records[h])
def _adapt_pf(self, pf):
return dict(bn=pf.basename,
fp=pf.fullpath,
tt=pf["InitialTime"],
- ctid=pf["CurrentTimeIdentifier"])
+ ctid=pf["CurrentTimeIdentifier"],
+ class_name=pf.__class__.__name__)
def _convert_pf(self, pf_dict):
bn = pf_dict['bn']
fp = pf_dict['fp']
fn = os.path.join(fp, bn)
+ class_name = pf_dict['class_name']
+ if class_name not in output_type_registry:
+ raise UnknownStaticOutputType(class_name)
mylog.info("Checking %s", fn)
if os.path.exists(fn):
- import yt.lagos.OutputTypes as ot
- pf = ot.EnzoStaticOutput(
- os.path.join(fp, bn))
+ pf = output_type_registry[class_name](os.path.join(fp, bn))
else:
raise IOError
return pf
def check_pf(self, pf):
- if pf._hash() not in self.keys():
+ if pf._hash() not in self._records:
self.insert_pf(pf)
return
- pf_dict = self[pf._hash()]
+ pf_dict = self._records[pf._hash()]
if pf_dict['bn'] != pf.basename \
or pf_dict['fp'] != pf.fullpath:
self.wipe_hash(pf._hash())
self.insert_pf(pf)
- def __read_only(self):
- if self._shelf is not None: return self._shelf
- return shelve.open(self._get_db_name(), flag='r', protocol=-1)
-
- def __read_write(self):
- if self._shelf is not None: return self._shelf
- return shelve.open(self._get_db_name(), flag='c', protocol=-1)
-
def insert_pf(self, pf):
- self[pf._hash()] = self._adapt_pf(pf)
+ self._records[pf._hash()] = self._adapt_pf(pf)
+ self.flush_db()
- def __getitem__(self, key):
- my_shelf = self.__read_only()
- return my_shelf[key]
-
- def __store_item(self, key, val):
- my_shelf = self.__read_write()
- my_shelf[key] = val
-
- def __delete_item(self, key):
- my_shelf = self.__read_write()
- del my_shelf[key]
-
- def __init_shelf(self):
- dbn = self._get_db_name()
- dbdir = os.path.dirname(dbn)
- if not ytcfg.getboolean("yt", "StoreParameterFiles"):
- # This ensures that even if we're not storing them in the file
- # system, we're at least keeping track of what we load
- self._shelf = defaultdict(lambda: dict(bn='',fp='',tt='',ctid=''))
- return
- try:
- if not os.path.isdir(dbdir): os.mkdir(dbdir)
- except OSError:
- raise NoParameterShelf()
- only_on_root(shelve.open, dbn, 'c', protocol=-1)
-
- def __setitem__(self, key, val):
- only_on_root(self.__store_item, key, val)
-
- def __delitem__(self, key):
- only_on_root(self.__delete_item, key)
-
- def keys(self):
- my_shelf = self.__read_only()
- return my_shelf.keys()
+ def wipe_hash(self, hash):
+ if hash not in self._records: return
+ del self._records[hash]
+ self.flush_db()
+
+ def flush_db(self):
+ if self._read_only: return
+ self._write_out()
+ self.read_db()
+
+ @parallel_simple_proxy
+ def _write_out(self):
+ if self._read_only: return
+ fn = self._get_db_name()
+ f = open("%s.tmp" % fn, 'wb')
+ w = csv.DictWriter(f, _field_names)
+ for h,v in sorted(self._records.items()):
+ v['hash'] = h
+ w.writerow(v)
+ f.close()
+ os.rename("%s.tmp" % fn, fn)
+
+ @parallel_simple_proxy
+ def read_db(self):
+ f=open(self._get_db_name(), 'rb')
+ vals = csv.DictReader(f, _field_names)
+ db = {}
+ for v in vals:
+ db[v.pop('hash')] = v
+
+ return db
class ObjectStorage(object):
pass
Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py (original)
+++ trunk/yt/lagos/BaseDataTypes.py Tue Mar 3 12:20:59 2009
@@ -691,7 +691,7 @@
def _get_list_of_grids(self):
goodI = ((self.source.gridRightEdge[:,self.axis] > self.coord)
- & (self.source.gridLeftEdge[:,self.axis] < self.coord ))
+ & (self.source.gridLeftEdge[:,self.axis] <= self.coord ))
self._grids = self.source._grids[goodI] # Using sources not hierarchy
def __cut_mask_child_mask(self, grid):
Modified: trunk/yt/lagos/OutputTypes.py
==============================================================================
--- trunk/yt/lagos/OutputTypes.py (original)
+++ trunk/yt/lagos/OutputTypes.py Tue Mar 3 12:20:59 2009
@@ -38,8 +38,10 @@
class StaticOutput(object):
class __metaclass__(type):
- def __call__(cls, *args, **kwargs):
- return cls.__new__(cls, *args, **kwargs)
+ def __init__(cls, name, b, d):
+ type.__init__(cls, name, b, d)
+ output_type_registry[name]=cls
+ mylog.debug("Registering: %s as %s", name, cls)
def __new__(cls, filename, *args, **kwargs):
apath = os.path.abspath(filename)
@@ -374,6 +376,10 @@
(1.0 + self.parameters["CosmologyCurrentRedshift"])
return k
+# We set our default output type to EnzoStaticOutput
+
+output_type_registry[None] = EnzoStaticOutput
+
class EnzoStaticOutputInMemory(EnzoStaticOutput):
_hierarchy_class = EnzoHierarchyInMemory
def __init__(self, parameter_override=None, conversion_override=None):
Modified: trunk/yt/lagos/__init__.py
==============================================================================
--- trunk/yt/lagos/__init__.py (original)
+++ trunk/yt/lagos/__init__.py Tue Mar 3 12:20:59 2009
@@ -89,7 +89,7 @@
# We by-default add universal fields.
add_field = FieldInfo.add_field
-from yt.fido import ParameterFileStore
+from yt.fido import ParameterFileStore, output_type_registry
from DerivedQuantities import DerivedQuantityCollection, GridChildMaskWrapper
from DataReadingFuncs import *
Modified: trunk/yt/raven/PlotTypes.py
==============================================================================
--- trunk/yt/raven/PlotTypes.py (original)
+++ trunk/yt/raven/PlotTypes.py Tue Mar 3 12:20:59 2009
@@ -304,8 +304,8 @@
def set_width(self, width, unit):
self["Unit"] = str(unit)
self["Width"] = float(width)
- if isinstance(unit, types.StringType):
- unit = self.data.hierarchy[unit]
+ if isinstance(unit, types.StringTypes):
+ unit = self.data.hierarchy[str(unit)]
self.width = width / unit
self._refresh_display_width()
@@ -548,8 +548,8 @@
def set_width(self, width, unit):
self["Unit"] = str(unit)
self["Width"] = float(width)
- if isinstance(unit, types.StringType):
- unit = self.data.hierarchy[unit]
+ if isinstance(unit, types.StringTypes):
+ unit = self.data.hierarchy[str(unit)]
self.width = width / unit
self._refresh_display_width()
More information about the yt-svn
mailing list