[Yt-svn] commit/yt: 2 new changesets
Bitbucket
commits-noreply at bitbucket.org
Wed Oct 19 12:11:21 PDT 2011
2 new changesets in yt:
http://bitbucket.org/yt_analysis/yt/changeset/cfa54be8fa1c/
changeset: cfa54be8fa1c
branch: yt
user: caseywstark
date: 2011-10-14 16:30:47
summary: Nyx frontend: remove dependency on comovingcoordinates input (always comoving now). General tidying up.
affected #: 1 file (-1 bytes)
--- a/yt/frontends/nyx/data_structures.py Thu Oct 13 08:46:35 2011 -0400
+++ b/yt/frontends/nyx/data_structures.py Fri Oct 14 10:30:47 2011 -0400
@@ -76,7 +76,6 @@
def _prepare_grid(self):
""" Copies all the appropriate attributes from the hierarchy. """
- # This is definitely the slowest part of generating the hierarchy
h = self.hierarchy # alias
h.grid_levels[self.id, 0] = self.Level
h.grid_left_edge[self.id,:] = self.LeftEdge[:]
@@ -94,11 +93,12 @@
if len(pIDs) > 0:
self.Parent = [weakref.proxy(h.grids[pID]) for pID in pIDs]
else:
+ # must be root grid
self.Parent = None
def _setup_dx(self):
# So first we figure out what the index is. We don't assume that
- # dx=dy=dz, at least here. We probably do elsewhere.
+ # dx=dy=dz here.
id = self.id - self._id_offset
if self.Parent is not None:
self.dds = self.Parent[0].dds / self.pf.refine_by
@@ -132,7 +132,6 @@
self.read_particle_header()
self.__cache_endianness(self.levels[-1].grids[-1])
- # @todo: should be first line
AMRHierarchy.__init__(self, pf, self.data_style)
self._setup_data_io()
self._setup_field_list()
@@ -142,27 +141,27 @@
""" Read the global header file for an Nyx plotfile output. """
counter = 0
header_file = open(header_path, 'r')
- self.__global_header_lines = header_file.readlines()
+ self._global_header_lines = header_file.readlines()
# parse the file
- self.nyx_version = self.__global_header_lines[0].rstrip()
- self.n_fields = int(self.__global_header_lines[1])
+ self.nyx_pf_version = self._global_header_lines[0].rstrip()
+ self.n_fields = int(self._global_header_lines[1])
# why the 2?
counter = self.n_fields + 2
self.field_list = []
- for i, line in enumerate(self.__global_header_lines[2:counter]):
+ for i, line in enumerate(self._global_header_lines[2:counter]):
self.field_list.append(line.rstrip())
# figure out dimensions and make sure it's 3D
- self.dimension = int(self.__global_header_lines[counter])
+ self.dimension = int(self._global_header_lines[counter])
if self.dimension != 3:
raise RunTimeError("Current data is %iD. yt only supports Nyx data in 3D" % self.dimension)
counter += 1
- self.Time = float(self.__global_header_lines[counter])
+ self.Time = float(self._global_header_lines[counter])
counter += 1
- self.finest_grid_level = int(self.__global_header_lines[counter])
+ self.finest_grid_level = int(self._global_header_lines[counter])
self.n_levels = self.finest_grid_level + 1
counter += 1
@@ -171,28 +170,28 @@
# case in the future we want to enable a "backwards" way of
# taking the data out of the Header file and using it to fill
# in in the case of a missing inputs file
- self.domainLeftEdge_unnecessary = na.array(map(float, self.__global_header_lines[counter].split()))
+ self.domainLeftEdge_unnecessary = na.array(map(float, self._global_header_lines[counter].split()))
counter += 1
- self.domainRightEdge_unnecessary = na.array(map(float, self.__global_header_lines[counter].split()))
+ self.domainRightEdge_unnecessary = na.array(map(float, self._global_header_lines[counter].split()))
counter += 1
- self.refinementFactor_unnecessary = self.__global_header_lines[counter].split() #na.array(map(int, self.__global_header_lines[counter].split()))
+ self.refinementFactor_unnecessary = self._global_header_lines[counter].split() #na.array(map(int, self._global_header_lines[counter].split()))
counter += 1
- self.globalIndexSpace_unnecessary = self.__global_header_lines[counter]
+ self.globalIndexSpace_unnecessary = self._global_header_lines[counter]
counter += 1
- self.timestepsPerLevel_unnecessary = self.__global_header_lines[counter]
+ self.timestepsPerLevel_unnecessary = self._global_header_lines[counter]
counter += 1
self.dx = na.zeros((self.n_levels, 3))
- for i, line in enumerate(self.__global_header_lines[counter:counter + self.n_levels]):
+ for i, line in enumerate(self._global_header_lines[counter:counter + self.n_levels]):
self.dx[i] = na.array(map(float, line.split()))
counter += self.n_levels
- self.geometry = int(self.__global_header_lines[counter])
+ self.geometry = int(self._global_header_lines[counter])
if self.geometry != 0:
raise RunTimeError("yt only supports cartesian coordinates.")
counter += 1
# @todo: this is just to debug. eventually it should go away.
- linebreak = int(self.__global_header_lines[counter])
+ linebreak = int(self._global_header_lines[counter])
if linebreak != 0:
raise RunTimeError("INTERNAL ERROR! This should be a zero.")
counter += 1
@@ -209,11 +208,11 @@
data_files_finder = re.compile(data_files_pattern)
for level in range(0, self.n_levels):
- tmp = self.__global_header_lines[counter].split()
+ tmp = self._global_header_lines[counter].split()
# should this be grid_time or level_time??
lev, ngrids, grid_time = int(tmp[0]), int(tmp[1]), float(tmp[2])
counter += 1
- nsteps = int(self.__global_header_lines[counter])
+ nsteps = int(self._global_header_lines[counter])
counter += 1
self.levels.append(NyxLevel(lev, ngrids))
@@ -227,10 +226,10 @@
key_off = 0
files = {}
offsets = {}
- while nfiles + tmp_offset < len(self.__global_header_lines) \
- and data_files_finder.match(self.__global_header_lines[nfiles + tmp_offset]):
- filen = os.path.join(self.parameter_file.path, \
- self.__global_header_lines[nfiles + tmp_offset].strip())
+ while nfiles + tmp_offset < len(self._global_header_lines) \
+ and data_files_finder.match(self._global_header_lines[nfiles + tmp_offset]):
+ filen = os.path.join(self.parameter_file.path,
+ self._global_header_lines[nfiles + tmp_offset].strip())
# open each "_H" header file, and get the number of
# components within it
level_header_file = open(filen + '_H', 'r').read()
@@ -262,11 +261,11 @@
for grid in range(0, ngrids):
gfn = fn[grid] # filename of file containing this grid
gfo = off[grid] # offset within that file
- xlo, xhi = map(float, self.__global_header_lines[counter].split())
+ xlo, xhi = map(float, self._global_header_lines[counter].split())
counter += 1
- ylo, yhi = map(float, self.__global_header_lines[counter].split())
+ ylo, yhi = map(float, self._global_header_lines[counter].split())
counter += 1
- zlo, zhi = map(float, self.__global_header_lines[counter].split())
+ zlo, zhi = map(float, self._global_header_lines[counter].split())
counter += 1
lo = na.array([xlo, ylo, zlo])
hi = na.array([xhi, yhi, zhi])
@@ -307,6 +306,7 @@
for i in line.split()),
dtype='int64',
count=3*self.num_grids).reshape((self.num_grids, 3))
+ # we need grid_info in `populate_grid_objects`, so save it to self
self.pgrid_info = grid_info
def __cache_endianness(self, test_grid):
@@ -356,17 +356,17 @@
g.NumberOfParticles = pg[1]
g._particle_offset = pg[2]
- self.grid_particle_count[:,0] = self.pgrid_info[:,1]
- del self.pgrid_info # if this is all pgrid_info is used for...
+ self.grid_particle_count[:, 0] = self.pgrid_info[:, 1]
+ del self.pgrid_info
gls = na.concatenate([level.ngrids * [level.level] for level in self.levels])
self.grid_levels[:] = gls.reshape((self.num_grids, 1))
grid_dcs = na.concatenate([level.ngrids*[self.dx[level.level]]
for level in self.levels], axis=0)
- self.grid_dxs = grid_dcs[:,0].reshape((self.num_grids, 1))
- self.grid_dys = grid_dcs[:,1].reshape((self.num_grids, 1))
- self.grid_dzs = grid_dcs[:,2].reshape((self.num_grids, 1))
+ self.grid_dxs = grid_dcs[:, 0].reshape((self.num_grids, 1))
+ self.grid_dys = grid_dcs[:, 1].reshape((self.num_grids, 1))
+ self.grid_dzs = grid_dcs[:, 2].reshape((self.num_grids, 1))
left_edges = []
right_edges = []
@@ -381,7 +381,7 @@
self.grid_dimensions = na.array(dims)
self.gridReverseTree = [] * self.num_grids
self.gridReverseTree = [ [] for i in range(self.num_grids)] # why the same thing twice?
- self.gridTree = [ [] for i in range(self.num_grids)] # meh
+ self.gridTree = [ [] for i in range(self.num_grids)]
mylog.debug("Done creating grid objects")
@@ -389,7 +389,7 @@
self.__setup_grid_tree()
for i, grid in enumerate(self.grids):
- if (i%1e4) == 0:
+ if (i % 1e4) == 0:
mylog.debug("Prepared % 7i / % 7i grids", i, self.num_grids)
grid._prepare_grid()
@@ -469,7 +469,7 @@
pass
def _setup_unknown_fields(self):
- # Doesn't seem useful
+ # not sure what the case for this is.
for field in self.field_list:
if field in self.parameter_file.field_info: continue
mylog.info("Adding %s to list of fields", field)
@@ -588,7 +588,6 @@
Parses the parameter file and establishes the various dictionaries.
"""
- # More boxlib madness...
self._parse_header_file()
if os.path.isfile(self.fparameter_file_path):
@@ -638,27 +637,24 @@
self.domain_dimensions = self.parameters["TopGridDimensions"]
self.refine_by = self.parameters.get("RefineBy", 2) # 2 is silent default? Makes sense I suppose.
- if self.parameters.has_key("ComovingCoordinates") \
- and self.parameters["ComovingCoordinates"]:
- self.cosmological_simulation = 1
- self.omega_lambda = self.parameters["CosmologyOmegaLambdaNow"]
- self.omega_matter = self.parameters["CosmologyOmegaMatterNow"]
- self.hubble_constant = self.parameters["CosmologyHubbleConstantNow"]
+ # Nyx is always cosmological.
+ self.cosmological_simulation = 1
+ self.omega_lambda = self.parameters["CosmologyOmegaLambdaNow"]
+ self.omega_matter = self.parameters["CosmologyOmegaMatterNow"]
+ self.hubble_constant = self.parameters["CosmologyHubbleConstantNow"]
- # So broken. We will fix this in the new Nyx output format
- a_file = open(os.path.join(self.path, "comoving_a"))
- line = a_file.readline().strip()
- a_file.close()
- self.parameters["CosmologyCurrentRedshift"] = 1 / float(line) - 1
- self.cosmological_scale_factor = float(line)
+ # Read in the `comoving_a` file and parse the value. We should fix this
+ # in the new Nyx output format...
+ a_file = open(os.path.join(self.path, "comoving_a"))
+ a_string = a_file.readline().strip()
+ a_file.close()
- # alias
- self.current_redshift = self.parameters["CosmologyCurrentRedshift"]
+ # Set the scale factor and redshift
+ self.cosmological_scale_factor = float(a_string)
+ self.parameters["CosmologyCurrentRedshift"] = 1 / float(a_string) - 1
- else:
- # @todo: automatic defaults
- self.current_redshift = self.omega_lambda = self.omega_matter = \
- self.hubble_constant = self.cosmological_simulation = 0.0
+ # alias
+ self.current_redshift = self.parameters["CosmologyCurrentRedshift"]
def _parse_header_file(self):
"""
@@ -668,13 +664,12 @@
Currently, only Time is read here.
"""
- # @todo: header filename option? probably not.
header_file = open(os.path.join(self.path, "Header"))
lines = header_file.readlines() # hopefully this is small
header_file.close()
n_fields = int(lines[1]) # this could change
- self.current_time = float(lines[3 + n_fields]) # very fragile
+ self.current_time = float(lines[3 + n_fields]) # fragile
def _parse_fparameter_file(self):
"""
@@ -751,7 +746,6 @@
self.time_units["days"] = seconds / (3600 * 24.0)
self.time_units["years"] = seconds / (3600 * 24.0 * 365)
-
# not the most useful right now, but someday
for key in nyx_particle_field_names:
self.conversion_factors[key] = 1.0
http://bitbucket.org/yt_analysis/yt/changeset/c4f82b5495cb/
changeset: c4f82b5495cb
branch: yt
user: caseywstark
date: 2011-10-14 17:23:40
summary: Merged yt/yt
affected #: 11 files (-1 bytes)
--- a/yt/convenience.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/convenience.py Fri Oct 14 11:23:40 2011 -0400
@@ -32,8 +32,7 @@
from yt.funcs import *
from yt.config import ytcfg
from yt.utilities.parameter_file_storage import \
- output_type_registry, \
- EnzoRunDatabase
+ output_type_registry
def all_pfs(basedir='.', skip=None, max_depth=1, name_spec="*.hierarchy", **kwargs):
"""
@@ -92,15 +91,6 @@
if len(candidates) == 1:
return output_type_registry[candidates[0]](*args, **kwargs)
if len(candidates) == 0:
- if ytcfg.get("yt", "enzo_db") != '' \
- and len(args) == 1 \
- and isinstance(args[0], types.StringTypes):
- erdb = EnzoRunDatabase()
- fn = erdb.find_uuid(args[0])
- n = "EnzoStaticOutput"
- if n in output_type_registry \
- and output_type_registry[n]._is_valid(fn):
- return output_type_registry[n](fn)
mylog.error("Couldn't figure out output type for %s", args[0])
return None
mylog.error("Multiple output type candidates for %s:", args[0])
--- a/yt/data_objects/static_output.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/data_objects/static_output.py Fri Oct 14 11:23:40 2011 -0400
@@ -104,6 +104,8 @@
return self.basename
def _hash(self):
+ if "MetaDataDatasetUUID" in self.parameters:
+ return self["MetaDataDatasetUUID"]
s = "%s;%s;%s" % (self.basename,
self.current_time, self.unique_identifier)
try:
--- a/yt/frontends/gdf/api.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/frontends/gdf/api.py Fri Oct 14 11:23:40 2011 -0400
@@ -29,14 +29,15 @@
"""
from .data_structures import \
- ChomboGrid, \
- ChomboHierarchy, \
- ChomboStaticOutput
+ GDFGrid, \
+ GDFHierarchy, \
+ GDFStaticOutput
from .fields import \
- ChomboFieldContainer, \
- ChomboFieldInfo, \
- add_chombo_field
+ GDFFieldContainer, \
+ GDFFieldInfo, \
+ add_gdf_field
from .io import \
- IOHandlerChomboHDF5
+ IOHandlerGDFHDF5
+
--- a/yt/frontends/gdf/data_structures.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/frontends/gdf/data_structures.py Fri Oct 14 11:23:40 2011 -0400
@@ -24,6 +24,9 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
+import h5py
+import numpy as na
+import weakref
from yt.funcs import *
from yt.data_objects.grid_patch import \
AMRGridPatch
@@ -33,6 +36,7 @@
StaticOutput
from .fields import GDFFieldContainer
+import pdb
class GDFGrid(AMRGridPatch):
_id_offset = 0
@@ -58,14 +62,16 @@
self.dds = na.array((RE-LE)/self.ActiveDimensions)
if self.pf.dimensionality < 2: self.dds[1] = 1.0
if self.pf.dimensionality < 3: self.dds[2] = 1.0
+ # pdb.set_trace()
self.data['dx'], self.data['dy'], self.data['dz'] = self.dds
class GDFHierarchy(AMRHierarchy):
grid = GDFGrid
-
+
def __init__(self, pf, data_style='grid_data_format'):
self.parameter_file = weakref.proxy(pf)
+ self.data_style = data_style
# for now, the hierarchy file is the parameter file!
self.hierarchy_filename = self.parameter_file.parameter_filename
self.directory = os.path.dirname(self.hierarchy_filename)
@@ -78,46 +84,39 @@
pass
def _detect_fields(self):
- ncomp = int(self._fhandle['/'].attrs['num_components'])
- self.field_list = [c[1] for c in self._fhandle['/'].attrs.listitems()[-ncomp:]]
-
+ self.field_list = self._fhandle['field_types'].keys()
+
def _setup_classes(self):
dd = self._get_data_reader_dict()
AMRHierarchy._setup_classes(self, dd)
self.object_types.sort()
def _count_grids(self):
- self.num_grids = 0
- for lev in self._levels:
- self.num_grids += self._fhandle[lev]['Processors'].len()
-
+ self.num_grids = self._fhandle['/grid_parent_id'].shape[0]
+
def _parse_hierarchy(self):
f = self._fhandle # shortcut
-
+
# this relies on the first Group in the H5 file being
# 'Chombo_global'
levels = f.listnames()[1:]
self.grids = []
- i = 0
- for lev in levels:
- level_number = int(re.match('level_(\d+)',lev).groups()[0])
- boxes = f[lev]['boxes'].value
- dx = f[lev].attrs['dx']
- for level_id, box in enumerate(boxes):
- si = na.array([box['lo_%s' % ax] for ax in 'ijk'])
- ei = na.array([box['hi_%s' % ax] for ax in 'ijk'])
- pg = self.grid(len(self.grids),self,level=level_number,
- start = si, stop = ei)
- self.grids.append(pg)
- self.grids[-1]._level_id = level_id
- self.grid_left_edge[i] = dx*si.astype(self.float_type)
- self.grid_right_edge[i] = dx*(ei.astype(self.float_type) + 1)
- self.grid_particle_count[i] = 0
- self.grid_dimensions[i] = ei - si + 1
- i += 1
- temp_grids = na.empty(len(grids), dtype='object')
- for gi, g in enumerate(self.grids): temp_grids[gi] = g
- self.grids = temp_grids
+ for i, grid in enumerate(f['data'].keys()):
+ self.grids.append(self.grid(i, self, f['grid_level'][i],
+ f['grid_left_index'][i],
+ f['grid_dimensions'][i]))
+ self.grids[-1]._level_id = f['grid_level'][i]
+
+ dx = (self.parameter_file.domain_right_edge-
+ self.parameter_file.domain_left_edge)/self.parameter_file.domain_dimensions
+ dx = dx/self.parameter_file.refine_by**(f['grid_level'][:])
+
+ self.grid_left_edge = self.parameter_file.domain_left_edge + dx*f['grid_left_index'][:]
+ self.grid_dimensions = f['grid_dimensions'][:]
+ self.grid_right_edge = self.grid_left_edge + dx*self.grid_dimensions
+ self.grid_particle_count = f['grid_particle_count'][:]
+ self.grids = na.array(self.grids, dtype='object')
+ # pdb.set_trace()
def _populate_grid_objects(self):
for g in self.grids:
@@ -145,16 +144,14 @@
class GDFStaticOutput(StaticOutput):
_hierarchy_class = GDFHierarchy
_fieldinfo_class = GDFFieldContainer
-
+
def __init__(self, filename, data_style='grid_data_format',
storage_filename = None):
StaticOutput.__init__(self, filename, data_style)
- self._handle = h5py.File(self.filename, "r")
self.storage_filename = storage_filename
+ self.filename = filename
self.field_info = self._fieldinfo_class()
- self._handle.close()
- del self._handle
-
+
def _set_units(self):
"""
Generates the conversion to various physical _units based on the parameter file
@@ -165,21 +162,25 @@
self._parse_parameter_file()
self.time_units['1'] = 1
self.units['1'] = 1.0
- self.units['unitary'] = 1.0 / (self.domain_right_edge - self.domain_right_edge).max()
+ self.units['unitary'] = 1.0 / (self.domain_right_edge - self.domain_left_edge).max()
seconds = 1
self.time_units['years'] = seconds / (365*3600*24.0)
self.time_units['days'] = seconds / (3600*24.0)
# This should be improved.
+ self._handle = h5py.File(self.parameter_filename, "r")
for field_name in self._handle["/field_types"]:
- self.units[field_name] = self._handle["/%s/field_to_cgs" % field_name]
+ self.units[field_name] = self._handle["/field_types/%s" % field_name].attrs['field_to_cgs']
+ del self._handle
def _parse_parameter_file(self):
+ self._handle = h5py.File(self.parameter_filename, "r")
sp = self._handle["/simulation_parameters"].attrs
self.domain_left_edge = sp["domain_left_edge"][:]
self.domain_right_edge = sp["domain_right_edge"][:]
- self.refine_by = sp["refine_by"][:]
- self.dimensionality = sp["dimensionality"][:]
- self.current_time = sp["current_time"][:]
+ self.domain_dimensions = sp["domain_dimensions"][:]
+ self.refine_by = sp["refine_by"]
+ self.dimensionality = sp["dimensionality"]
+ self.current_time = sp["current_time"]
self.unique_identifier = sp["unique_identifier"]
self.cosmological_simulation = sp["cosmological_simulation"]
if sp["num_ghost_zones"] != 0: raise RuntimeError
@@ -193,7 +194,8 @@
else:
self.current_redshift = self.omega_lambda = self.omega_matter = \
self.hubble_constant = self.cosmological_simulation = 0.0
-
+ del self._handle
+
@classmethod
def _is_valid(self, *args, **kwargs):
try:
@@ -204,4 +206,6 @@
pass
return False
+ def __repr__(self):
+ return self.basename.rsplit(".", 1)[0]
--- a/yt/frontends/gdf/fields.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/frontends/gdf/fields.py Fri Oct 14 11:23:40 2011 -0400
@@ -1,5 +1,5 @@
"""
-Chombo-specific fields
+GDF-specific fields
Author: J. S. Oishi <jsoishi at gmail.com>
Affiliation: KIPAC/SLAC/Stanford
@@ -32,82 +32,45 @@
ValidateGridType
import yt.data_objects.universal_fields
-class ChomboFieldContainer(CodeFieldInfoContainer):
+class GDFFieldContainer(CodeFieldInfoContainer):
_shared_state = {}
_field_list = {}
-ChomboFieldInfo = ChomboFieldContainer()
-add_chombo_field = ChomboFieldInfo.add_field
+GDFFieldInfo = GDFFieldContainer()
+add_gdf_field = GDFFieldInfo.add_field
-add_field = add_chombo_field
+add_field = add_gdf_field
add_field("density", function=lambda a,b: None, take_log=True,
validators = [ValidateDataField("density")],
units=r"\rm{g}/\rm{cm}^3")
-ChomboFieldInfo["density"]._projected_units =r"\rm{g}/\rm{cm}^2"
+GDFFieldInfo["density"]._projected_units =r"\rm{g}/\rm{cm}^2"
-add_field("X-momentum", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("X-Momentum")],
- units=r"",display_name=r"B_x")
-ChomboFieldInfo["X-momentum"]._projected_units=r""
+add_field("specific_energy", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("specific_energy")],
+ units=r"\rm{erg}/\rm{g}")
-add_field("Y-momentum", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("Y-Momentum")],
- units=r"",display_name=r"B_y")
-ChomboFieldInfo["Y-momentum"]._projected_units=r""
+add_field("velocity_x", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("velocity_x")],
+ units=r"\rm{cm}/\rm{s}")
-add_field("Z-momentum", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("Z-Momentum")],
- units=r"",display_name=r"B_z")
-ChomboFieldInfo["Z-momentum"]._projected_units=r""
+add_field("velocity_y", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("velocity_y")],
+ units=r"\rm{cm}/\rm{s}")
-add_field("X-magnfield", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("X-Magnfield")],
- units=r"",display_name=r"B_x")
-ChomboFieldInfo["X-magnfield"]._projected_units=r""
+add_field("velocity_z", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("velocity_z")],
+ units=r"\rm{cm}/\rm{s}")
-add_field("Y-magnfield", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("Y-Magnfield")],
- units=r"",display_name=r"B_y")
-ChomboFieldInfo["Y-magnfield"]._projected_units=r""
+add_field("mag_field_x", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("mag_field_x")],
+ units=r"\rm{cm}/\rm{s}")
-add_field("Z-magnfield", function=lambda a,b: None, take_log=False,
- validators = [ValidateDataField("Z-Magnfield")],
- units=r"",display_name=r"B_z")
-ChomboFieldInfo["Z-magnfield"]._projected_units=r""
+add_field("mag_field_y", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("mag_field_y")],
+ units=r"\rm{cm}/\rm{s}")
-def _MagneticEnergy(field,data):
- return (data["X-magnfield"]**2 +
- data["Y-magnfield"]**2 +
- data["Z-magnfield"]**2)/2.
-add_field("MagneticEnergy", function=_MagneticEnergy, take_log=True,
- units=r"",display_name=r"B^2/8\pi")
-ChomboFieldInfo["MagneticEnergy"]._projected_units=r""
-
-def _xVelocity(field, data):
- """generate x-velocity from x-momentum and density
-
- """
- return data["X-momentum"]/data["density"]
-add_field("x-velocity",function=_xVelocity, take_log=False,
- units=r'\rm{cm}/\rm{s}')
-
-def _yVelocity(field,data):
- """generate y-velocity from y-momentum and density
-
- """
- #try:
- # return data["xvel"]
- #except KeyError:
- return data["Y-momentum"]/data["density"]
-add_field("y-velocity",function=_yVelocity, take_log=False,
- units=r'\rm{cm}/\rm{s}')
-
-def _zVelocity(field,data):
- """generate z-velocity from z-momentum and density
-
- """
- return data["Z-momentum"]/data["density"]
-add_field("z-velocity",function=_zVelocity, take_log=False,
- units=r'\rm{cm}/\rm{s}')
+add_field("mag_field_z", function=lambda a,b: None, take_log=True,
+ validators = [ValidateDataField("mag_field_z")],
+ units=r"\rm{cm}/\rm{s}")
--- a/yt/frontends/gdf/io.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/frontends/gdf/io.py Fri Oct 14 11:23:40 2011 -0400
@@ -25,44 +25,48 @@
"""
from yt.utilities.io_handler import \
BaseIOHandler
+import h5py
-class IOHandlerChomboHDF5(BaseIOHandler):
- _data_style = "chombo_hdf5"
+class IOHandlerGDFHDF5(BaseIOHandler):
+ _data_style = "grid_data_format"
_offset_string = 'data:offsets=0'
_data_string = 'data:datatype=0'
def _field_dict(self,fhandle):
- ncomp = int(fhandle['/'].attrs['num_components'])
- temp = fhandle['/'].attrs.listitems()[-ncomp:]
- val, keys = zip(*temp)
- val = [int(re.match('component_(\d+)',v).groups()[0]) for v in val]
+ keys = fhandle['field_types'].keys()
+ val = fhandle['field_types'].keys()
+ # ncomp = int(fhandle['/'].attrs['num_components'])
+ # temp = fhandle['/'].attrs.listitems()[-ncomp:]
+ # val, keys = zip(*temp)
+ # val = [int(re.match('component_(\d+)',v).groups()[0]) for v in val]
return dict(zip(keys,val))
def _read_field_names(self,grid):
fhandle = h5py.File(grid.filename,'r')
- ncomp = int(fhandle['/'].attrs['num_components'])
-
- return [c[1] for c in f['/'].attrs.listitems()[-ncomp:]]
+ return fhandle['field_types'].keys()
def _read_data_set(self,grid,field):
fhandle = h5py.File(grid.hierarchy.hierarchy_filename,'r')
+ return fhandle['/data/grid_%010i/'%grid.id+field][:]
+ # field_dict = self._field_dict(fhandle)
+ # lstring = 'level_%i' % grid.Level
+ # lev = fhandle[lstring]
+ # dims = grid.ActiveDimensions
+ # boxsize = dims.prod()
+
+ # grid_offset = lev[self._offset_string][grid._level_id]
+ # start = grid_offset+field_dict[field]*boxsize
+ # stop = start + boxsize
+ # data = lev[self._data_string][start:stop]
- field_dict = self._field_dict(fhandle)
- lstring = 'level_%i' % grid.Level
- lev = fhandle[lstring]
- dims = grid.ActiveDimensions
- boxsize = dims.prod()
-
- grid_offset = lev[self._offset_string][grid._level_id]
- start = grid_offset+field_dict[field]*boxsize
- stop = start + boxsize
- data = lev[self._data_string][start:stop]
-
- return data.reshape(dims, order='F')
+ # return data.reshape(dims, order='F')
def _read_data_slice(self, grid, field, axis, coord):
sl = [slice(None), slice(None), slice(None)]
sl[axis] = slice(coord, coord + 1)
- return self._read_data_set(grid,field)[sl]
+ fhandle = h5py.File(grid.hierarchy.hierarchy_filename,'r')
+ return fhandle['/data/grid_%010i/'%grid.id+field][:][sl]
+ # return self._read_data_set(grid,field)[sl]
+
--- a/yt/frontends/setup.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/frontends/setup.py Fri Oct 14 11:23:40 2011 -0400
@@ -6,6 +6,7 @@
config = Configuration('frontends',parent_package,top_path)
config.make_config_py() # installs __config__.py
#config.make_svn_version_py()
+ config.add_subpackage("gdf")
config.add_subpackage("chombo")
config.add_subpackage("enzo")
config.add_subpackage("flash")
--- a/yt/mods.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/mods.py Fri Oct 14 11:23:40 2011 -0400
@@ -77,6 +77,9 @@
from yt.frontends.chombo.api import \
ChomboStaticOutput, ChomboFieldInfo, add_chombo_field
+from yt.frontends.gdf.api import \
+ GDFStaticOutput, GDFFieldInfo, add_gdf_field
+
from yt.frontends.art.api import \
ARTStaticOutput, ARTFieldInfo, add_art_field
--- a/yt/utilities/_amr_utils/VolumeIntegrator.pyx Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/utilities/_amr_utils/VolumeIntegrator.pyx Fri Oct 14 11:23:40 2011 -0400
@@ -737,15 +737,33 @@
dt = (exit_t - enter_t) / tf.ns # 4 samples should be dt=0.25
cdef int offset = ci[0] * (self.dims[1] + 1) * (self.dims[2] + 1) \
+ ci[1] * (self.dims[2] + 1) + ci[2]
+ # The initial and final values can be linearly interpolated between; so
+ # we just have to calculate our initial and final values.
+ cdef np.float64_t slopes[6]
for i in range(3):
- cell_left[i] = ci[i] * self.dds[i] + self.left_edge[i]
- # this gets us dp as the current first sample position
- pos[i] = (enter_t + 0.5 * dt) * v_dir[i] + v_pos[i]
- dp[i] = pos[i] - cell_left[i]
+ dp[i] = (enter_t + 0.5 * dt) * v_dir[i] + v_pos[i]
+ dp[i] -= ci[i] * self.dds[i] + self.left_edge[i]
dp[i] *= self.idds[i]
ds[i] = v_dir[i] * self.idds[i] * dt
- local_dds[i] = v_dir[i] * dt
+ for i in range(self.n_fields):
+ slopes[i] = offset_interpolate(self.dims, dp,
+ self.data[i] + offset)
+ for i in range(3):
+ dp[i] += ds[i] * tf.ns
+ cdef np.float64_t temp
+ for i in range(self.n_fields):
+ temp = slopes[i]
+ slopes[i] -= offset_interpolate(self.dims, dp,
+ self.data[i] + offset)
+ slopes[i] *= -1.0/tf.ns
+ self.dvs[i] = temp
if self.star_list != NULL:
+ for i in range(3):
+ cell_left[i] = ci[i] * self.dds[i] + self.left_edge[i]
+ # this gets us dp as the current first sample position
+ pos[i] = (enter_t + 0.5 * dt) * v_dir[i] + v_pos[i]
+ dp[i] -= tf.ns * ds[i]
+ local_dds[i] = v_dir[i] * dt
ballq = kdtree_utils.kd_nearest_range3(
self.star_list, cell_left[0] + self.dds[0]*0.5,
cell_left[1] + self.dds[1]*0.5,
@@ -753,15 +771,16 @@
self.star_er + 0.9*self.dds[0])
# ~0.866 + a bit
for dti in range(tf.ns):
- for i in range(self.n_fields):
- self.dvs[i] = offset_interpolate(self.dims, dp, self.data[i] + offset)
#if (dv < tf.x_bounds[0]) or (dv > tf.x_bounds[1]):
# continue
- if self.star_list != NULL: self.add_stars(ballq, dt, pos, rgba)
+ if self.star_list != NULL:
+ self.add_stars(ballq, dt, pos, rgba)
+ for i in range(3):
+ dp[i] += ds[i]
+ pos[i] += local_dds[i]
tf.eval_transfer(dt, self.dvs, rgba, grad)
- for i in range(3):
- dp[i] += ds[i]
- pos[i] += local_dds[i]
+ for i in range(self.n_fields):
+ self.dvs[i] += slopes[i]
if ballq != NULL: kdtree_utils.kd_res_free(ballq)
@cython.boundscheck(False)
--- a/yt/utilities/parameter_file_storage.py Fri Oct 14 10:30:47 2011 -0400
+++ b/yt/utilities/parameter_file_storage.py Fri Oct 14 11:23:40 2011 -0400
@@ -32,6 +32,8 @@
from yt.utilities.parallel_tools.parallel_analysis_interface import \
parallel_simple_proxy
+import yt.utilities.peewee as peewee
+
output_type_registry = {}
_field_names = ('hash', 'bn', 'fp', 'tt', 'ctid', 'class_name', 'last_seen')
@@ -48,6 +50,20 @@
def __repr__(self):
return "%s" % self.name
+_field_spec = dict(
+ dset_uuid = peewee.TextField(),
+ output_type = peewee.TextField(),
+ pf_path = peewee.TextField(),
+ creation_time = peewee.IntegerField(),
+ last_seen_time = peewee.IntegerField(),
+ simulation_uuid = peewee.TextField(),
+ redshift = peewee.FloatField(),
+ time = peewee.FloatField(),
+ topgrid0 = peewee.IntegerField(),
+ topgrid1 = peewee.IntegerField(),
+ topgrid2 = peewee.IntegerField(),
+)
+
class ParameterFileStore(object):
"""
This class is designed to be a semi-persistent storage for parameter
@@ -62,6 +78,7 @@
_distributed = True
_processing = False
_owner = 0
+ conn = None
def __new__(cls, *p, **k):
self = object.__new__(cls, *p, **k)
@@ -77,7 +94,6 @@
if ytcfg.getboolean("yt", "StoreParameterFiles"):
self._read_only = False
self.init_db()
- self._records = self.read_db()
else:
self._read_only = True
self._records = {}
@@ -93,9 +109,26 @@
if not os.path.isdir(dbdir): os.mkdir(dbdir)
except OSError:
raise NoParameterShelf()
- open(dbn, 'ab') # make sure it exists, allow to close
- # Now we read in all our records and return them
- # these will be broadcast
+ self.conn = peewee.SqliteDatabase(dbn)
+ class SimulationOutputsMeta:
+ database = self.conn
+ db_table = "simulation_outputs"
+ _field_spec["Meta"] = SimulationOutputsMeta
+ self.output_model = type(
+ "SimulationOutputs",
+ (peewee.Model,),
+ _field_spec,
+ )
+ self.output_model._meta.pk_name = "dset_uuid"
+ try:
+ self.conn.connect()
+ except:
+ self.conn = None
+ try:
+ self.output_model.create_table()
+ except:
+ pass
+ self.conn = None
def _get_db_name(self):
base_file_name = ytcfg.get("yt", "ParameterFileStore")
@@ -104,40 +137,26 @@
return os.path.expanduser("~/.yt/%s" % base_file_name)
def get_pf_hash(self, hash):
+ if self.conn is None: return
""" This returns a parameter file based on a hash. """
- return self._convert_pf(self._records[hash])
+ output = self.output_model.get(dset_uuid = hash)
+ return self._convert_pf(output)
- def get_pf_ctid(self, ctid):
- """ This returns a parameter file based on a CurrentTimeIdentifier. """
- for h in self._records:
- if self._records[h]['ctid'] == ctid:
- return self._convert_pf(self._records[h])
-
- def _adapt_pf(self, pf):
- """ This turns a parameter file into a CSV entry. """
- return dict(bn=pf.basename,
- fp=pf.fullpath,
- tt=pf.current_time,
- ctid=pf.unique_identifier,
- class_name=pf.__class__.__name__,
- last_seen=pf._instantiated)
-
- def _convert_pf(self, pf_dict):
- """ This turns a CSV entry into a parameter file. """
- bn = pf_dict['bn']
- fp = pf_dict['fp']
- fn = os.path.join(fp, bn)
- class_name = pf_dict['class_name']
- if class_name not in output_type_registry:
- raise UnknownStaticOutputType(class_name)
+ def _convert_pf(self, inst):
+ """ This turns a model into a parameter file. """
+ if self.conn is None: return
+ fn = inst.pf_path
+ if inst.output_type not in output_type_registry:
+ raise UnknownStaticOutputType(inst.output_type)
mylog.info("Checking %s", fn)
if os.path.exists(fn):
- pf = output_type_registry[class_name](os.path.join(fp, bn))
+ pf = output_type_registry[inst.output_type](fn)
else:
raise IOError
# This next one is to ensure that we manually update the last_seen
# record *now*, for during write_out.
- self._records[pf._hash()]['last_seen'] = pf._instantiated
+ self.output_model.update(last_seen_time = pf._instantiated).where(
+ dset_uuid = inst.dset_uuid).execute()
return pf
def check_pf(self, pf):
    """
    Make sure the parameter file *pf* is recorded in the storage unit.

    If no record with ``dset_uuid == pf._hash()`` exists, *pf* is handed to
    ``insert_pf``. Otherwise the existing record's path and
    ``last_seen_time`` are updated. Does nothing without a connection.
    """
    if self.conn is None:
        return
    dset_uuid = pf._hash()
    # peewee's SelectQuery.count() issues its own COUNT query, so the select
    # does not need to be executed separately first.
    q = self.output_model.select().where(dset_uuid=dset_uuid)
    if q.count() == 0:
        self.insert_pf(pf)
        return
    # Otherwise we update. The path is directory first, then basename --
    # the same layout insert_pf stores.
    self.output_model.update(
        last_seen_time=pf._instantiated,
        pf_path=os.path.join(pf.fullpath, pf.basename),
    ).where(dset_uuid=dset_uuid).execute()
def insert_pf(self, pf):
    """
    Insert a new record describing the parameter file *pf*.

    The record is keyed by ``pf._hash()`` and stores the class name, the
    on-disk path (``fullpath`` joined with ``basename``), timing, redshift
    and top-grid dimensions. Does nothing when there is no connection.
    """
    if self.conn is None:
        return
    record = dict(
        dset_uuid=pf._hash(),
        output_type=pf.__class__.__name__,
        pf_path=os.path.join(pf.fullpath, pf.basename),
        # NOTE(review): original comment said "Get os.stat" -- presumably a
        # TODO to use the file's stat time when this parameter is missing.
        creation_time=pf.parameters.get("CurrentTimeIdentifier", 0),
        last_seen_time=pf._instantiated,
        # NOTE(review): original comment said "NULL" -- the empty-string
        # fallback may be meant to become a NULL; confirm.
        simulation_uuid=pf.parameters.get("SimulationUUID", ""),
        redshift=pf.current_redshift,
        time=pf.current_time,
        topgrid0=pf.domain_dimensions[0],
        topgrid1=pf.domain_dimensions[1],
        topgrid2=pf.domain_dimensions[2],
    )
    insert_query = self.output_model.insert(**record)
    insert_query.execute()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/yt/utilities/peewee.py Fri Oct 14 11:23:40 2011 -0400
@@ -0,0 +1,1425 @@
+# (\
+# ( \ /(o)\ caw!
+# ( \/ ()/ /)
+# ( `;.))'".)
+# `(/////.-'
+# =====))=))===()
+# ///'
+# //
+# '
+from datetime import datetime
+import logging
+import os
+import re
+import time
+
+try:
+ import sqlite3
+except ImportError:
+ sqlite3 = None
+
+try:
+ import psycopg2
+except ImportError:
+ psycopg2 = None
+
+try:
+ import MySQLdb as mysql
+except ImportError:
+ mysql = None
+
class ImproperlyConfigured(Exception):
    """Raised at import time when no supported database driver is available."""


# At least one DB-API driver must be importable for this module to be usable.
# The exception class above must exist before this check; previously the name
# was undefined and the check would have died with a NameError instead.
if sqlite3 is None and psycopg2 is None and mysql is None:
    raise ImproperlyConfigured('Either sqlite3, psycopg2 or MySQLdb must be installed')


# Default database file name, overridable through the environment.
DATABASE_NAME = os.environ.get('PEEWEE_DATABASE', 'peewee.db')
logger = logging.getLogger('peewee.logger')
+
+
class BaseAdapter(object):
    """
    Bridge between the high-level `Database` API and a concrete DB-API driver.

    For one database engine, a subclass describes:
      - how filter lookups are written in SQL (`operations`) and which
        parameter placeholder the driver expects (`interpolation`);
      - which column type each abstract field type maps to
        (`get_field_types`, customised through `get_field_overrides`);
      - how to open/close connections and read the last insert id and the
        affected-row count from a cursor.
    """
    operations = {'eq': '= %s'}
    interpolation = '%s'

    def get_field_types(self):
        """Return a fresh dict of abstract field type -> SQL column type."""
        column_types = dict(
            integer='INTEGER',
            float='REAL',
            decimal='NUMERIC',
            string='VARCHAR',
            text='TEXT',
            datetime='DATETIME',
            primary_key='INTEGER',
            foreign_key='INTEGER',
            boolean='SMALLINT',
        )
        column_types.update(self.get_field_overrides())
        return column_types

    def get_field_overrides(self):
        """Engine-specific replacements for `get_field_types`; none by default."""
        return {}

    def connect(self, database, **kwargs):
        """Open and return a driver connection; must be provided by subclasses."""
        raise NotImplementedError

    def close(self, conn):
        conn.close()

    def lookup_cast(self, lookup, value):
        """Wrap *value* in LIKE wildcards for substring and prefix lookups."""
        if lookup == 'contains' or lookup == 'icontains':
            return '%%%s%%' % value
        if lookup == 'startswith' or lookup == 'istartswith':
            return '%s%%' % value
        return value

    def last_insert_id(self, cursor, model):
        return cursor.lastrowid

    def rows_affected(self, cursor):
        return cursor.rowcount
+
+
class SqliteAdapter(BaseAdapter):
    """
    Adapter for the standard-library ``sqlite3`` driver.

    sqlite3 uses ``?`` placeholders instead of ``%s``. Case-sensitive
    substring/prefix lookups are done with GLOB (``*`` wildcard), and
    case-insensitive ones with LIKE (``%`` wildcard).
    """
    interpolation = '?'
    operations = {
        'lt': '< ?',
        'lte': '<= ?',
        'gt': '> ?',
        'gte': '>= ?',
        'eq': '= ?',
        'ne': '!= ?', # watch yourself with this one
        'in': 'IN (%s)', # %s is later filled with one '?' per value
        'is': 'IS ?',
        'icontains': "LIKE ? ESCAPE '\\'", # value wrapped in %'s
        'contains': "GLOB ?", # value wrapped in *'s
        'istartswith': "LIKE ? ESCAPE '\\'",
        'startswith': "GLOB ?",
    }

    def connect(self, database, **kwargs):
        return sqlite3.connect(database, **kwargs)

    def lookup_cast(self, lookup, value):
        """Add GLOB (``*``) or LIKE (``%``) wildcards matching the operator used."""
        wildcard_patterns = {
            'contains': '*%s*',
            'icontains': '%%%s%%',
            'startswith': '%s*',
            'istartswith': '%s%%',
        }
        pattern = wildcard_patterns.get(lookup)
        if pattern is None:
            return value
        return pattern % value
+
+
class PostgresqlAdapter(BaseAdapter):
    """
    Adapter for PostgreSQL through psycopg2 (``%s`` placeholders).

    Case-insensitive lookups use ILIKE. Primary keys are SERIAL columns, so
    the last insert id is read back with CURRVAL on the
    ``<table>_<pk>_seq`` sequence.
    """
    operations = {
        'lt': '< %s',
        'lte': '<= %s',
        'gt': '> %s',
        'gte': '>= %s',
        'eq': '= %s',
        'ne': '!= %s', # watch yourself with this one
        'in': 'IN (%s)', # %s is later filled with one placeholder per value
        'is': 'IS %s',
        'icontains': 'ILIKE %s', # value wrapped in %'s
        'contains': 'LIKE %s', # value wrapped in %'s
        'istartswith': 'ILIKE %s',
        'startswith': 'LIKE %s',
    }

    def connect(self, database, **kwargs):
        return psycopg2.connect(database=database, **kwargs)

    def get_field_overrides(self):
        overrides = {}
        overrides['primary_key'] = 'SERIAL'
        overrides['datetime'] = 'TIMESTAMP'
        return overrides

    def last_insert_id(self, cursor, model):
        meta = model._meta
        sequence_sql = "SELECT CURRVAL('\"%s_%s_seq\"')" % (
            meta.db_table, meta.pk_name)
        cursor.execute(sequence_sql)
        return cursor.fetchone()[0]
+
+
class MySQLAdapter(BaseAdapter):
    """
    Adapter for MySQL through MySQLdb (``%s`` placeholders).

    Case-sensitive lookups use ``LIKE BINARY``; plain LIKE handles the
    case-insensitive ones.
    """
    operations = {
        'lt': '< %s',
        'lte': '<= %s',
        'gt': '> %s',
        'gte': '>= %s',
        'eq': '= %s',
        'ne': '!= %s', # watch yourself with this one
        'in': 'IN (%s)', # %s is later filled with one placeholder per value
        'is': 'IS %s',
        'icontains': 'LIKE %s', # value wrapped in %'s
        'contains': 'LIKE BINARY %s', # value wrapped in %'s
        'istartswith': 'LIKE %s',
        'startswith': 'LIKE BINARY %s',
    }

    def connect(self, database, **kwargs):
        return mysql.connect(db=database, **kwargs)

    def get_field_overrides(self):
        return dict([
            ('primary_key', 'integer AUTO_INCREMENT'),
            ('boolean', 'bool'),
            ('float', 'double precision'),
            ('text', 'longtext'),
        ])
+
+
class Database(object):
    """
    A high-level api for working with the supported database engines.

    `Database` wraps an adapter (see `BaseAdapter`) and a database name and
    additionally provides:
      - execution of SQL queries
      - creating and dropping tables and indexes
    """
    def __init__(self, adapter, database, **connect_kwargs):
        self.adapter = adapter
        self.database = database
        self.connect_kwargs = connect_kwargs

    def connect(self):
        """Open a connection through the adapter and store it as ``self.conn``."""
        self.conn = self.adapter.connect(self.database, **self.connect_kwargs)

    def close(self):
        self.adapter.close(self.conn)

    def execute(self, sql, params=None, commit=False):
        """
        Run *sql* with *params* on a new cursor, committing if *commit*.

        Returns the cursor so callers can fetch rows, the last insert id or
        the affected-row count.
        """
        cursor = self.conn.cursor()
        cursor.execute(sql, params or ())
        if commit:
            self.conn.commit()
        logger.debug((sql, params))
        return cursor

    def last_insert_id(self, cursor, model):
        return self.adapter.last_insert_id(cursor, model)

    def rows_affected(self, cursor):
        return self.adapter.rows_affected(cursor)

    def column_for_field(self, db_field):
        """
        Return the SQL column type for the abstract field type *db_field*.

        Raises AttributeError listing the valid types when the adapter does
        not know *db_field*.
        """
        field_types = self.adapter.get_field_types()
        try:
            return field_types[db_field]
        except KeyError:
            # Both values must be supplied as one tuple for the two %s
            # placeholders; otherwise formatting itself raises TypeError.
            raise AttributeError(
                'Unknown field type: "%s", valid types are: %s' % (
                    db_field, ', '.join(sorted(field_types))))

    def create_table(self, model_class):
        """Issue CREATE TABLE with one column per field of *model_class*."""
        framing = "CREATE TABLE %s (%s);"
        columns = [field.to_sql() for field in model_class._meta.fields.values()]
        query = framing % (model_class._meta.db_table, ', '.join(columns))
        self.execute(query, commit=True)

    def create_index(self, model_class, field, unique=False):
        """Create an (optionally UNIQUE) index named ``<table>_<field>``."""
        framing = 'CREATE %(unique)s INDEX %(model)s_%(field)s ON %(model)s(%(field)s);'

        if field not in model_class._meta.fields:
            raise AttributeError(
                'Field %s not on model %s' % (field, model_class)
            )

        unique_expr = ternary(unique, 'UNIQUE', '')

        query = framing % {
            'unique': unique_expr,
            'model': model_class._meta.db_table,
            'field': field
        }

        self.execute(query, commit=True)

    def drop_table(self, model_class, fail_silently=False):
        """Drop the model's table; with *fail_silently* use DROP TABLE IF EXISTS."""
        framing = fail_silently and 'DROP TABLE IF EXISTS %s;' or 'DROP TABLE %s;'
        self.execute(framing % model_class._meta.db_table, commit=True)

    def get_indexes_for_table(self, table):
        raise NotImplementedError
+
+
class SqliteDatabase(Database):
    """`Database` backed by the standard-library sqlite3 driver."""
    def __init__(self, database, **connect_kwargs):
        adapter = SqliteAdapter()
        super(SqliteDatabase, self).__init__(adapter, database, **connect_kwargs)

    def get_indexes_for_table(self, table):
        """
        Return sorted ``(index_name, is_unique)`` pairs for *table*.

        They are built from columns 1 and 2 of ``PRAGMA index_list``.
        """
        cursor = self.execute('PRAGMA index_list(%s);' % table)
        return sorted([(row[1], row[2] == 1) for row in cursor.fetchall()])
+
+
class PostgresqlDatabase(Database):
    """`Database` backed by psycopg2."""
    def __init__(self, database, **connect_kwargs):
        adapter = PostgresqlAdapter()
        super(PostgresqlDatabase, self).__init__(adapter, database, **connect_kwargs)

    def get_indexes_for_table(self, table):
        """Return sorted ``(index_name, is_primary)`` pairs from pg_catalog."""
        catalog_sql = """
 SELECT c2.relname, i.indisprimary, i.indisunique
 FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, pg_catalog.pg_index i
 WHERE c.relname = %s AND c.oid = i.indrelid AND i.indexrelid = c2.oid
 ORDER BY i.indisprimary DESC, i.indisunique DESC, c2.relname"""
        cursor = self.execute(catalog_sql, (table,))
        pairs = [(row[0], row[1]) for row in cursor.fetchall()]
        pairs.sort()
        return pairs
+
class MySQLDatabase(Database):
    """`Database` backed by MySQLdb."""
    def __init__(self, database, **connect_kwargs):
        adapter = MySQLAdapter()
        super(MySQLDatabase, self).__init__(adapter, database, **connect_kwargs)

    def get_indexes_for_table(self, table):
        """
        Return sorted ``(key_name, is_unique)`` pairs for *table*.

        They come from SHOW INDEXES columns 2 and 1; a Non_unique value of 0
        means unique.
        """
        cursor = self.execute('SHOW INDEXES IN %s;' % table)
        return sorted([(row[2], row[1] == 0) for row in cursor.fetchall()])
+
+
class QueryResultWrapper(object):
    """
    Provides an iterator over the results of a raw Query, additionally doing
    two things:
      - converts rows from the database into model instances
      - ensures that multiple iterations do not result in multiple queries

    The first full pass reads from the cursor and caches each instance;
    once the cursor is exhausted, later iterations replay the cache.
    Implements both the Python 2 (``next``) and Python 3 (``__next__``)
    iterator protocols.
    """
    def __init__(self, model, cursor):
        self.model = model
        self.cursor = cursor
        self._result_cache = []
        self._populated = False

    def model_from_rowset(self, model_class, row_dict):
        """
        Build a *model_class* instance from a column-name -> value dict.

        Known fields go through ``python_value``; other columns are set as
        plain attributes.
        """
        instance = model_class()
        fields = instance._meta.fields
        for attr, value in row_dict.items():
            if attr in fields:
                value = fields[attr].python_value(value)
            setattr(instance, attr, value)
        return instance

    def _row_to_dict(self, row, result_cursor):
        """Pair each value of *row* with its column name from the cursor description."""
        return dict((result_cursor.description[i][0], value)
                    for i, value in enumerate(row))

    def __iter__(self):
        if not self._populated:
            return self
        return iter(self._result_cache)

    def next(self):
        row = self.cursor.fetchone()
        # DB-API fetchone() signals exhaustion with None.
        if row is None:
            self._populated = True
            raise StopIteration
        row_dict = self._row_to_dict(row, self.cursor)
        instance = self.model_from_rowset(self.model, row_dict)
        self._result_cache.append(instance)
        return instance

    # Python 3 iterator protocol.
    __next__ = next
+
+
class DoesNotExist(Exception):
    """Exception signalling that a requested record does not exist."""
+
+
# Ordering specs accepted by `SelectQuery.order_by`: ``(field, direction)``.
def asc(f):
    """Sort ascending on field *f*."""
    direction = 'ASC'
    return (f, direction)

def desc(f):
    """Sort descending on field *f*."""
    direction = 'DESC'
    return (f, direction)

# Aggregate specs used inside dict-style select queries:
# ``(SQL function, field, result alias)``.
def Count(f, alias='count'):
    """``COUNT(f) AS alias``."""
    func = 'COUNT'
    return (func, f, alias)

def Max(f, alias='max'):
    """``MAX(f) AS alias``."""
    func = 'MAX'
    return (func, f, alias)

def Min(f, alias='min'):
    """``MIN(f) AS alias``."""
    func = 'MIN'
    return (func, f, alias)
+
def returns_clone(func):
    """
    Decorate a query-builder method so that it operates on a copy.

    The wrapped method is called on ``self.clone()`` with the original
    arguments, its own return value is discarded, and the clone is returned.
    ``functools.wraps`` preserves the method's name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def inner(self, *args, **kwargs):
        clone = self.clone()
        func(clone, *args, **kwargs)
        return clone
    return inner
+
# helpers
def ternary(cond, t, f):
    """Return *t* when *cond* is truthy, else *f* (a portable conditional)."""
    if cond:
        return t
    return f
+
+
class Node(object):
    """
    A boolean combination (AND/OR) of ``Q`` objects and nested ``Node``s.

    It can be negated with ``~``; ``parse_node`` turns it into SQL.
    """
    def __init__(self, connector='AND'):
        self.connector = connector
        self.children = []
        self.negated = False

    def connect(self, rhs, connector):
        """
        Combine this node with *rhs* under *connector*.

        A ``Q`` with the same connector is appended to this node, which is
        returned. Otherwise a new parent ``Node`` holding ``[self, rhs]`` is
        returned. Any other *rhs* type yields None.
        """
        if isinstance(rhs, Q):
            if connector != self.connector:
                parent = Node(connector)
                parent.children = [self, rhs]
                return parent
            self.children.append(rhs)
            return self
        if isinstance(rhs, Node):
            parent = Node(connector)
            parent.children = [self, rhs]
            return parent
        return None

    def __or__(self, rhs):
        return self.connect(rhs, 'OR')

    def __and__(self, rhs):
        return self.connect(rhs, 'AND')

    def __invert__(self):
        """Toggle negation in place and return self."""
        self.negated = not self.negated
        return self

    def __unicode__(self):
        # Q children are rendered first, then parenthesised nested nodes.
        q_parts = [unicode(child) for child in self.children
                   if isinstance(child, Q)]
        node_parts = ['(%s)' % unicode(child) for child in self.children
                      if isinstance(child, Node)]
        rendered = (' %s ' % self.connector).join(q_parts + node_parts)
        if self.negated:
            rendered = 'NOT %s' % rendered
        return rendered
+
+
class Q(object):
    """
    A group of ``field=value`` lookups, ANDed together, for ``where()``.

    Combine with ``|`` / ``&`` to build ``Node`` trees and negate with ``~``.
    """
    def __init__(self, **kwargs):
        self.query = kwargs
        self.parent = None
        self.negated = False

    def connect(self, connector):
        """On first use, create the parent ``Node`` (with *connector*) holding self."""
        if self.parent is not None:
            return
        self.parent = Node(connector)
        self.parent.children.append(self)

    def __or__(self, rhs):
        self.connect('OR')
        return self.parent | rhs

    def __and__(self, rhs):
        self.connect('AND')
        return self.parent & rhs

    def __invert__(self):
        """Toggle negation in place and return self."""
        self.negated = not self.negated
        return self

    def __unicode__(self):
        items = self.query.items()
        bits = ['%s = %s' % pair for pair in items]
        if len(items) > 1:
            expr = '(%s)' % ' AND '.join(bits)
        else:
            expr = bits[0]
        if self.negated:
            expr = 'NOT %s' % expr
        return expr
+
+
def parseq(*args, **kwargs):
    """
    Collect ``where()`` arguments into a single AND ``Node``.

    Positional arguments must be ``Q`` or ``Node`` instances; keyword
    arguments are wrapped in one extra ``Q``. Raises TypeError for any other
    positional argument.
    """
    node = Node()

    for piece in args:
        if not isinstance(piece, (Q, Node)):
            # Format the message; passing piece as a second argument would
            # leave a literal "%s" in the error text.
            raise TypeError('Unknown object: %s' % (piece,))
        node.children.append(piece)

    if kwargs:
        node.children.append(Q(**kwargs))

    return node
+
+
class EmptyResultException(Exception):
    """Raised while building a query that cannot match anything (e.g. ``__in=[]``)."""
+
+
class BaseQuery(object):
    """
    Shared state and SQL-compilation helpers for the query classes.

    A query tracks:
      - ``_where``: dict mapping a model to the list of ``Node`` trees added
        by ``where()`` while that model was the query context;
      - ``_joins``: ``(model, join_type, on)`` triples added by ``join()``;
      - ``query_context``: the model that subsequent ``where``/``join``
        calls apply to (changed by ``join`` and ``switch``).

    ``where``, ``join`` and ``switch`` are wrapped by ``returns_clone``, so
    they modify and return ``self.clone()``.
    """
    query_separator = '__'
    requires_commit = True
    force_alias = False

    def __init__(self, model):
        self.model = model
        self.query_context = model
        self.database = self.model._meta.database
        self.operations = self.database.adapter.operations
        self.interpolation = self.database.adapter.interpolation

        self._dirty = True
        self._where = {}
        self._joins = []

    def clone(self):
        raise NotImplementedError

    def lookup_cast(self, lookup, value):
        return self.database.adapter.lookup_cast(lookup, value)

    def parse_query_args(self, model, **query):
        """
        Translate ``field__op=value`` lookups into SQL operator fragments.

        Returns a dict mapping column name to ``(operation, value)``.
        *operation* is the adapter's operator string with placeholder(s).
        *value* has been converted with the field's ``db_value`` and the
        adapter's ``lookup_cast``. A key without ``__op`` means equality.

        Raises EmptyResultException for ``__in`` with an empty sequence and
        ValueError for ``__is`` with anything but None.
        """
        parsed = {}
        for lhs, rhs in query.items():
            if self.query_separator in lhs:
                lhs, op = lhs.rsplit(self.query_separator, 1)
            else:
                op = 'eq'

            try:
                field = model._meta.get_field_by_name(lhs)
            except AttributeError:
                field = model._meta.get_related_field_by_name(lhs)
                if field is None:
                    raise
                # Related lookups may pass a model instance; compare on its pk.
                if isinstance(rhs, Model):
                    rhs = rhs.get_pk()

            if op == 'in':
                if isinstance(rhs, SelectQuery):
                    lookup_value = rhs
                    operation = 'IN (%s)'
                else:
                    if not rhs:
                        raise EmptyResultException
                    lookup_value = [field.db_value(o) for o in rhs]
                    operation = self.operations[op] % \
                        (','.join([self.interpolation for v in lookup_value]))
            elif op == 'is':
                if rhs is not None:
                    raise ValueError('__is lookups only accept None')
                operation = 'IS NULL'
                lookup_value = []
            else:
                lookup_value = field.db_value(rhs)
                operation = self.operations[op]

            parsed[field.name] = (operation, self.lookup_cast(op, lookup_value))

        return parsed

    @returns_clone
    def where(self, *args, **kwargs):
        self._where.setdefault(self.query_context, [])
        self._where[self.query_context].append(parseq(*args, **kwargs))

    @returns_clone
    def join(self, model, join_type=None, on=None):
        if self.query_context._meta.rel_exists(model):
            self._joins.append((model, join_type, on))
            self.query_context = model
        else:
            raise AttributeError('No foreign key found between %s and %s' % \
                (self.query_context.__name__, model.__name__))

    @returns_clone
    def switch(self, model):
        if model == self.model:
            self.query_context = model
            return

        for klass, join_type, on in self._joins:
            if model == klass:
                self.query_context = model
                return
        raise AttributeError('You must JOIN on %s' % model.__name__)

    def use_aliases(self):
        return len(self._joins) > 0 or self.force_alias

    def combine_field(self, alias, field_name):
        if alias:
            return '%s.%s' % (alias, field_name)
        return field_name

    def compile_where(self):
        """
        Compile joins and WHERE clauses.

        Returns ``(join_sql_list, where_sql_list, where_params, alias_map)``.
        *alias_map* maps each model in the join chain to its table alias
        (``'t1'``, ``'t2'``, ... when aliases are in use, else ``''``).
        """
        alias_count = 0
        alias_map = {}

        alias_required = self.use_aliases()

        # The base model always heads the join chain so that alias_map has an
        # entry for it even without WHERE clauses or joins; SelectQuery.sql()
        # and parse_select_query() look it up whenever aliases are in use
        # (e.g. a subquery compiled with force_alias=True). With no joins it
        # contributes no SQL.
        joins = [(self.model, None, None)] + list(self._joins)

        where_with_alias = []
        where_data = []
        computed_joins = []

        for i, (model, join_type, on) in enumerate(joins):
            if alias_required:
                alias_count += 1
                alias_map[model] = 't%d' % alias_count
            else:
                alias_map[model] = ''

            if i > 0:
                # Each join is relative to the model joined just before it.
                from_model = joins[i-1][0]
                field = from_model._meta.get_related_field_for_model(model, on)
                if field:
                    left_field = field.name
                    right_field = model._meta.pk_name
                else:
                    field = from_model._meta.get_reverse_related_field_for_model(model, on)
                    left_field = from_model._meta.pk_name
                    right_field = field.name

                if join_type is None:
                    # Nullable FKs need an outer join unless filtered on.
                    if field.null and model not in self._where:
                        join_type = 'LEFT OUTER'
                    else:
                        join_type = 'INNER'

                computed_joins.append(
                    '%s JOIN %s AS %s ON %s = %s' % (
                        join_type,
                        model._meta.db_table,
                        alias_map[model],
                        self.combine_field(alias_map[from_model], left_field),
                        self.combine_field(alias_map[model], right_field),
                    )
                )

        for (model, join_type, on) in joins:
            if model in self._where:
                for node in self._where[model]:
                    query, data = self.parse_node(node, model, alias_map)
                    where_with_alias.append(query)
                    where_data.extend(data)

        return computed_joins, where_with_alias, where_data, alias_map

    def convert_where_to_params(self, where_data):
        """Flatten per-clause values (lists/tuples are expanded) into one list."""
        flattened = []
        for clause in where_data:
            if isinstance(clause, (tuple, list)):
                flattened.extend(clause)
            else:
                flattened.append(clause)
        return flattened

    def parse_node(self, node, model, alias_map):
        """
        Render a ``Node`` tree as SQL joined by the node's connector.

        ``Q`` children are rendered by ``parse_q``; nested ``Node`` children
        recursively, in parentheses. Returns ``(sql, params)``.
        """
        query = []
        query_data = []
        for child in node.children:
            if isinstance(child, Q):
                parsed, data = self.parse_q(child, model, alias_map)
                query.append(parsed)
                query_data.extend(data)
            elif isinstance(child, Node):
                parsed, data = self.parse_node(child, model, alias_map)
                query.append('(%s)' % parsed)
                query_data.extend(data)
        connector = ' %s ' % node.connector
        query = connector.join(query)
        if node.negated:
            query = 'NOT (%s)' % query
        return query, query_data

    def parse_q(self, q, model, alias_map):
        """Render one ``Q`` as ANDed ``column operation`` terms; returns ``(sql, params)``."""
        query = []
        query_data = []
        parsed = self.parse_query_args(model, **q.query)
        for (name, lookup) in parsed.items():
            operation, value = lookup
            if isinstance(value, SelectQuery):
                # Inline the subquery SQL; its params become this term's value.
                sql, value = self.convert_subquery(value)
                operation = operation % sql

            query_data.append(value)

            combined = self.combine_field(alias_map[model], name)
            query.append('%s %s' % (combined, operation))

        if len(query) > 1:
            query = '(%s)' % (' AND '.join(query))
        else:
            query = query[0]

        if q.negated:
            query = 'NOT %s' % query

        return query, query_data

    def convert_subquery(self, subquery):
        """
        Compile *subquery* to select only its primary key, with aliases forced.

        The subquery's ``query`` and ``force_alias`` are restored afterwards.
        Returns ``(sql, params)``.
        """
        subquery.query, orig_query = subquery.model._meta.pk_name, subquery.query
        subquery.force_alias, orig_alias = True, subquery.force_alias
        sql, data = subquery.sql()
        subquery.query = orig_query
        subquery.force_alias = orig_alias
        return sql, data

    def raw_execute(self):
        query, params = self.sql()
        return self.database.execute(query, params, self.requires_commit)
+
+
class RawQuery(BaseQuery):
    """
    Query wrapping a hand-written SQL string and its parameters.

    Iterating it executes the SQL and yields model instances through
    ``QueryResultWrapper``. The programmatic builder methods (``join``,
    ``where``, ``switch``) are unsupported and raise AttributeError.
    """
    def __init__(self, model, query, *params):
        self._sql = query
        self._params = list(params)
        super(RawQuery, self).__init__(model)

    def sql(self):
        return self._sql, self._params

    def execute(self):
        return QueryResultWrapper(self.model, self.raw_execute())

    # The unsupported builder methods accept any arguments so that callers
    # get the explanatory AttributeError rather than a TypeError about the
    # number of arguments.
    def join(self, *args, **kwargs):
        raise AttributeError('Raw queries do not support joining programmatically')

    def where(self, *args, **kwargs):
        raise AttributeError('Raw queries do not support querying programmatically')

    def switch(self, *args, **kwargs):
        raise AttributeError('Raw queries do not support switching contexts')

    def __iter__(self):
        return self.execute()
+
+
class SelectQuery(BaseQuery):
    """
    Builder for SELECT statements.

    *query* chooses the selected columns: ``'*'`` (the default), a raw SQL
    string, or a dict mapping models to lists of column names and aggregate
    tuples (as built by ``Count``/``Max``/``Min``). Iterating the query
    executes it and yields model instances.
    """
    requires_commit = False

    def __init__(self, model, query=None):
        self.query = query or '*'
        self._group_by = []
        self._having = []
        self._order_by = []
        self._pagination = None # return all by default
        self._distinct = False
        self._qr = None
        super(SelectQuery, self).__init__(model)

    def clone(self):
        query = SelectQuery(self.model, self.query)
        query.query_context = self.query_context
        query._group_by = list(self._group_by)
        query._having = list(self._having)
        query._order_by = list(self._order_by)
        query._pagination = self._pagination and tuple(self._pagination) or None
        query._distinct = self._distinct
        query._qr = self._qr
        # Copy each per-model clause list too: where() appends to these lists
        # in place, so sharing them would leak clauses back into this query.
        query._where = dict((model, list(nodes))
                            for model, nodes in self._where.items())
        query._joins = list(self._joins)
        return query

    @returns_clone
    def paginate(self, page_num, paginate_by=20):
        """Limit results to 1-based page *page_num* of *paginate_by* rows."""
        self._pagination = (page_num, paginate_by)

    def count(self):
        """
        Return the number of matching rows, ignoring pagination.

        The select list is temporarily swapped for ``COUNT(pk)``; the
        original select list and pagination are restored even on error.
        """
        tmp_pagination = self._pagination
        tmp_query = self.query
        self._pagination = None

        if self.use_aliases():
            self.query = 'COUNT(t1.%s)' % (self.model._meta.pk_name)
        else:
            self.query = 'COUNT(%s)' % (self.model._meta.pk_name)

        try:
            res = self.database.execute(*self.sql())
        finally:
            self.query = tmp_query
            self._pagination = tmp_pagination

        return res.fetchone()[0]

    @returns_clone
    def group_by(self, clause):
        """
        Group by a field name, a list/tuple of names, or every field of a model.

        Raises TypeError for any other kind of *clause*.
        """
        model = self.query_context

        if isinstance(clause, basestring):
            fields = (clause,)
        elif isinstance(clause, (list, tuple)):
            fields = clause
        elif issubclass(clause, Model):
            model = clause
            fields = clause._meta.get_field_names()
        else:
            raise TypeError('Unknown type encountered in group_by: %r' % (clause,))

        self._group_by.append((model, fields))

    @returns_clone
    def having(self, clause):
        self._having.append(clause)

    @returns_clone
    def distinct(self):
        self._distinct = True

    @returns_clone
    def order_by(self, field_or_string):
        """Order by a field name or an ``asc``/``desc`` tuple (ASC by default)."""
        if isinstance(field_or_string, tuple):
            field_or_string, ordering = field_or_string
        else:
            ordering = 'ASC'

        self._order_by.append(
            (self.query_context, field_or_string, ordering)
        )

    def parse_select_query(self, alias_map):
        """Render the select list (see the class docstring for accepted forms)."""
        if isinstance(self.query, basestring):
            if self.query in ('*', self.model._meta.pk_name) and self.use_aliases():
                return '%s.%s' % (alias_map[self.model], self.query)
            return self.query
        elif isinstance(self.query, dict):
            qparts = []
            aggregates = []
            for model, cols in self.query.items():
                alias = alias_map.get(model, '')
                for col in cols:
                    if isinstance(col, tuple):
                        func, col, col_alias = col
                        aggregates.append('%s(%s) AS %s' % \
                            (func, self.combine_field(alias, col), col_alias)
                        )
                    else:
                        qparts.append(self.combine_field(alias, col))
            return ', '.join(qparts + aggregates)
        else:
            raise TypeError('Unknown type encountered parsing select query')

    def sql(self):
        """Return ``(sql, params)`` for the full SELECT statement."""
        joins, where, where_data, alias_map = self.compile_where()

        table = self.model._meta.db_table

        params = []
        group_by = []

        use_aliases = self.use_aliases()
        if use_aliases:
            table = '%s AS %s' % (table, alias_map[self.model])

        # _group_by holds (model, field_names) pairs; flatten them, qualifying
        # each name with its model's alias when aliases are in use.
        for model, fields in self._group_by:
            if use_aliases:
                alias = alias_map[model]
            else:
                alias = ''
            for field in fields:
                group_by.append(self.combine_field(alias, field))

        parsed_query = self.parse_select_query(alias_map)

        if self._distinct:
            sel = 'SELECT DISTINCT'
        else:
            sel = 'SELECT'

        select = '%s %s FROM %s' % (sel, parsed_query, table)
        joins = '\n'.join(joins)
        where = ' AND '.join(where)
        group_by = ', '.join(group_by)
        having = ' AND '.join(self._having)

        order_by = []
        for piece in self._order_by:
            model, field, ordering = piece
            if use_aliases and field in model._meta.fields:
                field = '%s.%s' % (alias_map[model], field)
            order_by.append('%s %s' % (field, ordering))

        pieces = [select]

        if joins:
            pieces.append(joins)
        if where:
            pieces.append('WHERE %s' % where)
            params.extend(self.convert_where_to_params(where_data))

        if group_by:
            pieces.append('GROUP BY %s' % group_by)
        if having:
            pieces.append('HAVING %s' % having)
        if order_by:
            pieces.append('ORDER BY %s' % ', '.join(order_by))
        if self._pagination:
            page, paginate_by = self._pagination
            if page > 0:
                page -= 1
            pieces.append('LIMIT %d OFFSET %d' % (paginate_by, page * paginate_by))

        return ' '.join(pieces), params

    def execute(self):
        """
        Run the query once and cache the result wrapper.

        Returns an empty iterator when the query can never match
        (EmptyResultException while compiling).
        """
        if self._dirty or not self._qr:
            try:
                self._qr = QueryResultWrapper(self.model, self.raw_execute())
                self._dirty = False
                return self._qr
            except EmptyResultException:
                return iter([])
        else:
            # call the __iter__ method directly
            return iter(self._qr)

    def __iter__(self):
        return self.execute()
+
+ def __iter__(self):
+ return self.execute()
+
+
class UpdateQuery(BaseQuery):
    """
    Builder for ``UPDATE table SET col=<placeholder>, ... [WHERE ...]``.

    Keyword arguments give the new column values; ``execute()`` returns the
    number of affected rows.
    """
    def __init__(self, model, **kwargs):
        self.update_query = kwargs
        super(UpdateQuery, self).__init__(model)

    def clone(self):
        query = UpdateQuery(self.model, **self.update_query)
        # Copy each per-model clause list too: where() appends to these lists
        # in place, so sharing them would leak clauses back into this query.
        query._where = dict((model, list(nodes))
                            for model, nodes in self._where.items())
        query._joins = list(self._joins)
        return query

    def parse_update(self):
        """Map column names to db-converted values (related field names allowed)."""
        sets = {}
        for k, v in self.update_query.items():
            try:
                field = self.model._meta.get_field_by_name(k)
            except AttributeError:
                field = self.model._meta.get_related_field_by_name(k)
                if field is None:
                    raise

            sets[field.name] = field.db_value(v)

        return sets

    def sql(self):
        joins, where, where_data, alias_map = self.compile_where()
        set_statement = self.parse_update()

        params = []
        update_params = []

        for k, v in set_statement.items():
            params.append(v)
            update_params.append('%s=%s' % (k, self.interpolation))

        update = 'UPDATE %s SET %s' % (
            self.model._meta.db_table, ', '.join(update_params))
        where = ' AND '.join(where)

        pieces = [update]

        if where:
            pieces.append('WHERE %s' % where)
            params.extend(self.convert_where_to_params(where_data))

        return ' '.join(pieces), params

    def join(self, *args, **kwargs):
        raise AttributeError('Update queries do not support JOINs in sqlite')

    def execute(self):
        result = self.raw_execute()
        return self.database.rows_affected(result)
+
+
class DeleteQuery(BaseQuery):
    """Builder for ``DELETE FROM table [WHERE ...]``; execute() returns rows affected."""
    def clone(self):
        query = DeleteQuery(self.model)
        # Copy each per-model clause list too: where() appends to these lists
        # in place, so sharing them would leak clauses back into this query.
        query._where = dict((model, list(nodes))
                            for model, nodes in self._where.items())
        query._joins = list(self._joins)
        return query

    def sql(self):
        joins, where, where_data, alias_map = self.compile_where()

        params = []

        delete = 'DELETE FROM %s' % (self.model._meta.db_table)
        where = ' AND '.join(where)

        pieces = [delete]

        if where:
            pieces.append('WHERE %s' % where)
            params.extend(self.convert_where_to_params(where_data))

        return ' '.join(pieces), params

    def join(self, *args, **kwargs):
        raise AttributeError('Delete queries do not support JOINs in sqlite')

    def execute(self):
        result = self.raw_execute()
        return self.database.rows_affected(result)
+
+
class InsertQuery(BaseQuery):
    """
    Builder for ``INSERT INTO table (cols) VALUES (placeholders)``.

    Keyword arguments give column values; ``execute()`` returns the new
    row's id as reported by the adapter.
    """
    def __init__(self, model, **kwargs):
        self.insert_query = kwargs
        super(InsertQuery, self).__init__(model)

    def parse_insert(self):
        """Return parallel lists of column names and db-converted values."""
        get_field = self.model._meta.get_field_by_name
        pairs = [(name, get_field(name).db_value(value))
                 for name, value in self.insert_query.items()]
        cols = [name for name, _ in pairs]
        vals = [value for _, value in pairs]
        return cols, vals

    def sql(self):
        cols, vals = self.parse_insert()
        placeholders = ','.join([self.interpolation] * len(vals))
        insert = 'INSERT INTO %s (%s) VALUES (%s)' % (
            self.model._meta.db_table, ','.join(cols), placeholders)
        return insert, vals

    def where(self, *args, **kwargs):
        raise AttributeError('Insert queries do not support WHERE clauses')

    def join(self, *args, **kwargs):
        raise AttributeError('Insert queries do not support JOINs')

    def execute(self):
        cursor = self.raw_execute()
        return self.database.last_insert_id(cursor, self.model)
+
+
class Field(object):
    """
    Base class for model columns.

    Subclasses set ``db_field`` (an abstract type resolved to a SQL type by
    the database adapter) and ``field_template`` (the column definition,
    filled from ``self.attributes``). Extra keyword arguments to
    ``__init__`` become template attributes.
    """
    db_field = ''
    default = None
    field_template = "%(column_type)s%(nullable)s"

    def get_attributes(self):
        """Default template attributes for this field type."""
        return {}

    def __init__(self, null=False, db_index=False, *args, **kwargs):
        self.null = null
        self.db_index = db_index
        self.attributes = self.get_attributes()
        self.default = kwargs.get('default', None)

        if self.null:
            kwargs['nullable'] = ''
        else:
            kwargs['nullable'] = ' NOT NULL'
        self.attributes.update(kwargs)

    def add_to_class(self, klass, name):
        """Bind this field to model *klass* as attribute *name* (initially None)."""
        self.name = name
        self.model = klass
        setattr(klass, name, None)

    def render_field_template(self):
        attrs = self.attributes
        attrs['column_type'] = self.model._meta.database.column_for_field(self.db_field)
        return self.field_template % attrs

    def to_sql(self):
        """Column definition for CREATE TABLE: ``<name> <type...>``."""
        return '%s %s' % (self.name, self.render_field_template())

    def null_wrapper(self, value, default=None):
        """
        Substitute *default* for falsy values.

        The value passes through unchanged when no *default* is given, or
        when it is None on a nullable field.
        """
        keep_as_is = (self.null and value is None) or default is None
        if keep_as_is:
            return value
        return value or default

    def db_value(self, value):
        return value

    def python_value(self, value):
        return value

    def lookup_value(self, lookup_type, value):
        return self.db_value(value)
+
+
class CharField(Field):
    """
    Bounded string column, ``VARCHAR(max_length)``.

    ``max_length`` defaults to 255 and can be overridden with a
    ``max_length=`` keyword argument.
    """
    db_field = 'string'
    field_template = '%(column_type)s(%(max_length)d)%(nullable)s'

    def get_attributes(self):
        return {'max_length': 255}

    def db_value(self, value):
        """Map None to '' (unless nullable) and truncate to max_length."""
        if value is None and self.null:
            return value
        limit = self.attributes['max_length']
        return (value or '')[:limit]

    def lookup_value(self, lookup_type, value):
        converted = self.db_value(value)
        if lookup_type == 'contains':
            return '*%s*' % converted
        if lookup_type == 'icontains':
            return '%%%s%%' % converted
        return converted
+
+
class TextField(Field):
    """Unbounded text column; None becomes '' unless the field is nullable."""
    db_field = 'text'

    def db_value(self, value):
        return self.null_wrapper(value, '')

    def lookup_value(self, lookup_type, value):
        converted = self.db_value(value)
        if lookup_type == 'contains':
            return '*%s*' % converted
        if lookup_type == 'icontains':
            return '%%%s%%' % converted
        return converted
+
+
class DateTimeField(Field):
    """Date-and-time column."""
    db_field = 'datetime'

    def python_value(self, value):
        """
        Parse ``'YYYY-MM-DD HH:MM:SS[.fraction]'`` strings into a datetime.

        Any fractional seconds after the last '.' are dropped. Non-string
        values are returned unchanged.
        """
        if not isinstance(value, basestring):
            return value
        whole_seconds = value.rsplit('.', 1)[0]
        parsed = time.strptime(whole_seconds, '%Y-%m-%d %H:%M:%S')
        return datetime(*parsed[:6])
+
+
class IntegerField(Field):
    """Integer column; falsy values are stored as 0 unless nullable and None."""
    db_field = 'integer'

    def db_value(self, value):
        return self.null_wrapper(value, 0)

    def python_value(self, value):
        if value is None:
            return None
        return int(value)
+
+
+class BooleanField(IntegerField):
+ db_field = 'boolean'
+
+ def db_value(self, value):
+ if value:
+ return 1
+ return 0
+
+ def python_value(self, value):
+ return bool(value)
+
+
+class FloatField(Field):
+ db_field = 'float'
+
+ def db_value(self, value):
+ return self.null_wrapper(value, 0.0)
+
+ def python_value(self, value):
+ if value is not None:
+ return float(value)
+
+
+class PrimaryKeyField(IntegerField):
+ db_field = 'primary_key'
+ field_template = "%(column_type)s NOT NULL PRIMARY KEY"
+
+
+class ForeignRelatedObject(object):
+ def __init__(self, to, name):
+ self.field_name = name
+ self.to = to
+ self.cache_name = '_cache_%s' % name
+
+ def __get__(self, instance, instance_type=None):
+ if not getattr(instance, self.cache_name, None):
+ id = getattr(instance, self.field_name, 0)
+ qr = self.to.select().where(**{self.to._meta.pk_name: id}).execute()
+ setattr(instance, self.cache_name, qr.next())
+ return getattr(instance, self.cache_name)
+
+ def __set__(self, instance, obj):
+ assert isinstance(obj, self.to), "Cannot assign %s, invalid type" % obj
+ setattr(instance, self.field_name, obj.get_pk())
+ setattr(instance, self.cache_name, obj)
+
+
+class ReverseForeignRelatedObject(object):
+ def __init__(self, related_model, name):
+ self.field_name = name
+ self.related_model = related_model
+
+ def __get__(self, instance, instance_type=None):
+ query = {self.field_name: instance.get_pk()}
+ qr = self.related_model.select().where(**query)
+ return qr
+
+
+class ForeignKeyField(IntegerField):
+ db_field = 'foreign_key'
+ field_template = '%(column_type)s%(nullable)s REFERENCES %(to_table)s (%(to_pk)s)'
+
+ def __init__(self, to, null=False, related_name=None, *args, **kwargs):
+ self.to = to
+ self.related_name = related_name
+ kwargs.update({
+ 'to_table': to._meta.db_table,
+ 'to_pk': to._meta.pk_name
+ })
+ super(ForeignKeyField, self).__init__(null=null, *args, **kwargs)
+
+ def add_to_class(self, klass, name):
+ self.descriptor = name
+ self.name = name + '_id'
+ self.model = klass
+
+ if self.related_name is None:
+ self.related_name = klass._meta.db_table + '_set'
+
+ klass._meta.rel_fields[name] = self.name
+ setattr(klass, self.descriptor, ForeignRelatedObject(self.to, self.name))
+ setattr(klass, self.name, None)
+
+ reverse_rel = ReverseForeignRelatedObject(klass, self.name)
+ setattr(self.to, self.related_name, reverse_rel)
+
+ def lookup_value(self, lookup_type, value):
+ if isinstance(value, Model):
+ return value.get_pk()
+ return value or None
+
+ def db_value(self, value):
+ if isinstance(value, Model):
+ return value.get_pk()
+ return value
+
+
+# Module-level default database. BaseModelOptions falls back to it when a
+# model supplies no options; DATABASE_NAME is defined elsewhere in the module.
+database = SqliteDatabase(DATABASE_NAME)
+
+
+class BaseModelOptions(object):
+ def __init__(self, model_class, options=None):
+ # configurable options
+ options = options or {'database': database}
+ for k, v in options.items():
+ setattr(self, k, v)
+
+ self.rel_fields = {}
+ self.fields = {}
+ self.model_class = model_class
+
+ def get_field_names(self):
+ fields = [self.pk_name]
+ fields.extend([f for f in sorted(self.fields.keys()) if f != self.pk_name])
+ return fields
+
+ def get_field_by_name(self, name):
+ if name in self.fields:
+ return self.fields[name]
+ raise AttributeError('Field named %s not found' % name)
+
+ def get_related_field_by_name(self, name):
+ if name in self.rel_fields:
+ return self.fields[self.rel_fields[name]]
+
+ def get_related_field_for_model(self, model, name=None):
+ for field in self.fields.values():
+ if isinstance(field, ForeignKeyField) and field.to == model:
+ if name is None or name == field.name or name == field.descriptor:
+ return field
+
+ def get_reverse_related_field_for_model(self, model, name=None):
+ for field in model._meta.fields.values():
+ if isinstance(field, ForeignKeyField) and field.to == self.model_class:
+ if name is None or name == field.name or name == field.descriptor:
+ return field
+
+ def rel_exists(self, model):
+ return self.get_related_field_for_model(model) or \
+ self.get_reverse_related_field_for_model(model)
+
+
+class BaseModel(type):
+ inheritable_options = ['database']
+
+ def __new__(cls, name, bases, attrs):
+ cls = super(BaseModel, cls).__new__(cls, name, bases, attrs)
+
+ attr_dict = {}
+ meta = attrs.pop('Meta', None)
+ if meta:
+ attr_dict = meta.__dict__
+
+ for b in bases:
+ base_meta = getattr(b, '_meta', None)
+ if not base_meta:
+ continue
+
+ for (k, v) in base_meta.__dict__.items():
+ if k in cls.inheritable_options and k not in attr_dict:
+ attr_dict[k] = v
+
+ _meta = BaseModelOptions(cls, attr_dict)
+
+ if not hasattr(_meta, 'db_table'):
+ _meta.db_table = re.sub('[^a-z]+', '_', cls.__name__.lower())
+
+ setattr(cls, '_meta', _meta)
+
+ _meta.pk_name = None
+
+ for name, attr in cls.__dict__.items():
+ if isinstance(attr, Field):
+ attr.add_to_class(cls, name)
+ _meta.fields[attr.name] = attr
+ if isinstance(attr, PrimaryKeyField):
+ _meta.pk_name = attr.name
+
+ if _meta.pk_name is None:
+ _meta.pk_name = 'id'
+ pk = PrimaryKeyField()
+ pk.add_to_class(cls, _meta.pk_name)
+ _meta.fields[_meta.pk_name] = pk
+
+ _meta.model_name = cls.__name__
+
+ if hasattr(cls, '__unicode__'):
+ setattr(cls, '__repr__', lambda self: '<%s: %s>' % (
+ _meta.model_name, self.__unicode__()))
+
+ exception_class = type('%sDoesNotExist' % _meta.model_name, (DoesNotExist,), {})
+ cls.DoesNotExist = exception_class
+
+ return cls
+
+
+class Model(object):
+    """Base class for ORM models; subclasses declare Field class attributes.
+
+    The BaseModel metaclass attaches ``_meta`` (fields, table name,
+    database) and a per-model ``DoesNotExist`` exception to each subclass.
+    """
+    __metaclass__ = BaseModel
+
+    def __init__(self, *args, **kwargs):
+        # Positional arguments are accepted but ignored; each keyword
+        # argument is set as an attribute on the instance.
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def __eq__(self, other):
+        # Equal when both have the same class and the same truthy pk, so
+        # unsaved instances (falsy pk) never compare equal. The result may be
+        # a falsy pk value such as None rather than False. NOTE(review): no
+        # __ne__ is visible here, so under Python 2 ``!=`` does not use this.
+        return other.__class__ == self.__class__ and \
+            self.get_pk() and \
+            other.get_pk() == self.get_pk()
+
+    def get_field_dict(self):
+        """Return ``{column name: value}`` for every field of the model.
+
+        For an instance without a (truthy) pk, a field whose value is None
+        and whose ``default`` is not None gets the default, calling it when
+        it is callable. The filled-in value is also set on the instance.
+        """
+        def get_field_val(field):
+            field_value = getattr(self, field.name)
+            if not self.get_pk() and field_value is None and field.default is not None:
+                if callable(field.default):
+                    field_value = field.default()
+                else:
+                    field_value = field.default
+                setattr(self, field.name, field_value)
+            return (field.name, field_value)
+
+        pairs = map(get_field_val, self._meta.fields.values())
+        return dict(pairs)
+
+    @classmethod
+    def create_table(cls):
+        """Create the table, then indexes for the pk, FKs and ``db_index`` fields."""
+        cls._meta.database.create_table(cls)
+
+        for field_name, field_obj in cls._meta.fields.items():
+            if isinstance(field_obj, PrimaryKeyField):
+                # Third argument True: presumably "unique" - confirm against
+                # the database's create_index signature.
+                cls._meta.database.create_index(cls, field_obj.name, True)
+            elif isinstance(field_obj, ForeignKeyField):
+                cls._meta.database.create_index(cls, field_obj.name)
+            elif field_obj.db_index:
+                cls._meta.database.create_index(cls, field_obj.name)
+
+    @classmethod
+    def drop_table(cls, fail_silently=False):
+        """Drop the table; ``fail_silently`` is passed on to the database."""
+        cls._meta.database.drop_table(cls, fail_silently)
+
+    @classmethod
+    def select(cls, query=None):
+        """Return a SelectQuery over this model."""
+        return SelectQuery(cls, query)
+
+    @classmethod
+    def update(cls, **query):
+        """Return an UpdateQuery setting the given field values."""
+        return UpdateQuery(cls, **query)
+
+    @classmethod
+    def insert(cls, **query):
+        """Return an InsertQuery for the given field values."""
+        return InsertQuery(cls, **query)
+
+    @classmethod
+    def delete(cls, **query):
+        """Return a DeleteQuery for this model."""
+        return DeleteQuery(cls, **query)
+
+    @classmethod
+    def raw(cls, sql, *params):
+        """Return a RawQuery running ``sql`` with ``params``."""
+        return RawQuery(cls, sql, *params)
+
+    @classmethod
+    def create(cls, **query):
+        """Instantiate with ``query`` as attributes, save, and return it."""
+        inst = cls(**query)
+        inst.save()
+        return inst
+
+    @classmethod
+    def get_or_create(cls, **query):
+        """Return the first instance matching ``query``, creating it if none exists."""
+        try:
+            inst = cls.get(**query)
+        except cls.DoesNotExist:
+            inst = cls.create(**query)
+        return inst
+
+    @classmethod
+    def get(cls, *args, **kwargs):
+        """Return the first instance matching the filters.
+
+        Raises ``cls.DoesNotExist`` when no row matches.
+        """
+        query = cls.select().where(*args, **kwargs).paginate(1, 1)
+        try:
+            return query.execute().next()
+        except StopIteration:
+            # NOTE(review): assumes query.sql() returns a 2-tuple
+            # (sql, params) filling both %s placeholders - TODO confirm.
+            raise cls.DoesNotExist('instance matching query does not exist:\nSQL: %s\nPARAMS: %s' % (
+                query.sql()
+            ))
+
+    def get_pk(self):
+        """Return the primary-key value, or None if it is not set."""
+        return getattr(self, self._meta.pk_name, None)
+
+    def save(self):
+        """UPDATE the row by pk when the pk is truthy, else INSERT a new one.
+
+        After an insert, the pk is set to the value returned by the insert
+        query's ``execute()`` (presumably the new row id - confirm).
+        """
+        field_dict = self.get_field_dict()
+        field_dict.pop(self._meta.pk_name)
+        if self.get_pk():
+            update = self.update(
+                **field_dict
+            ).where(**{self._meta.pk_name: self.get_pk()})
+            update.execute()
+        else:
+            insert = self.insert(**field_dict)
+            new_pk = insert.execute()
+            setattr(self, self._meta.pk_name, new_pk)
Repository URL: https://bitbucket.org/yt_analysis/yt/
--
This is a commit notification from bitbucket.org. You are receiving
this because you have the service enabled, addressing the recipient of
this email.
More information about the yt-svn
mailing list