[Yt-svn] yt-commit r637 - trunk/yt/lagos
mturk at wrangler.dreamhost.com
mturk at wrangler.dreamhost.com
Mon Jun 30 15:22:51 PDT 2008
Author: mturk
Date: Mon Jun 30 15:22:50 2008
New Revision: 637
URL: http://yt.spacepope.org/changeset/637
Log:
Refactoring the profiles code for readability and to generalize it to particle fields.
Next up: hybrid particle/baryon fields.
Modified:
trunk/yt/lagos/Profiles.py
Modified: trunk/yt/lagos/Profiles.py
==============================================================================
--- trunk/yt/lagos/Profiles.py (original)
+++ trunk/yt/lagos/Profiles.py Mon Jun 30 15:22:50 2008
@@ -112,6 +112,16 @@
def __setitem__(self, key, value):
self._data[key] = value
+ def _get_field(self, source, field, check_cut):  # flat float64 array of source[field]; if check_cut, only the indices self._data_source selects
+ if check_cut:
+ if field in fieldInfo and fieldInfo[field].particle_type:  # particle fields use particle indices, all others point indices
+ pointI = self._data_source._get_particle_indices(source)
+ else:
+ pointI = self._data_source._get_point_indices(source)
+ else:
+ pointI = slice(None)  # no cut requested: keep every element
+ return source[field][pointI].ravel().astype('float64')
+
# @todo: Fix accumulation with overriding
class BinnedProfile1D(BinnedProfile):
def __init__(self, data_source, n_bins, bin_field,
@@ -154,13 +164,8 @@
mi, inv_bin_indices = args # Args has the indices to use as input
# check_cut is set if source != self._data_source
# (i.e., lazy_reader)
- if check_cut: # Only use the points inside the source
- cm = self._data_source._get_point_indices(source)
- source_data = source[field][cm].astype('float64')[mi]
- if weight: weight_data = source[weight][cm].astype('float64')[mi]
- else:
- source_data = source[field].astype('float64')[mi]
- if weight: weight_data = source[weight].astype('float64')[mi]
+ source_data = self._get_field(source, field, check_cut)[mi]  # cut/flatten via _get_field, then select the input indices mi
+ if weight: weight_data = self._get_field(source, weight, check_cut)[mi]  # weight gets the same cut and mi selection
binned_field = self._get_empty_field()
weight_field = self._get_empty_field()
used_field = na.ones(weight_field.shape, dtype='bool')
@@ -183,11 +188,7 @@
@preserve_source_parameters
def _get_bins(self, source, check_cut=False):
- if check_cut: # if source != self._data_source
- cm = self._data_source._get_point_indices(source)
- source_data = source[self.bin_field][cm]
- else:
- source_data = source[self.bin_field]
+ source_data = self._get_field(source, self.bin_field, check_cut)  # unlike the removed code, this is also flattened and cast to float64
if source_data.size == 0: # Nothing for us here.
return
# Truncate at boundaries.
@@ -264,15 +265,9 @@
@preserve_source_parameters
def _bin_field(self, source, field, weight, accumulation,
args, check_cut=False):
- if check_cut:
- pointI = self._data_source._get_point_indices(source)
- source_data = source[field][pointI].ravel().astype('float64')
- weight_data = na.ones(source_data.shape).astype('float64')
- if weight: weight_data = source[weight][pointI].ravel().astype('float64')
- else:
- source_data = source[field].ravel().astype('float64')
- weight_data = na.ones(source_data.shape).astype('float64')
- if weight: weight_data = source[weight].ravel().astype('float64')
+ source_data = self._get_field(source, field, check_cut)
+ if weight: weight_data = self._get_field(source, weight, check_cut)
+ else: weight_data = na.ones(source_data.shape, dtype='float64')  # unweighted: unit weight per element
self.total_stuff = source_data.sum()
binned_field = self._get_empty_field()
weight_field = self._get_empty_field()
@@ -297,13 +292,8 @@
@preserve_source_parameters
def _get_bins(self, source, check_cut=False):
- if check_cut:
- cm = self._data_source._get_point_indices(source)
- source_data_x = source[self.x_bin_field][cm]
- source_data_y = source[self.y_bin_field][cm]
- else:
- source_data_x = source[self.x_bin_field]
- source_data_y = source[self.y_bin_field]
+ source_data_x = self._get_field(source, self.x_bin_field, check_cut)  # NOTE(review): x and y share one index set only if both fields have the same particle_type
+ source_data_y = self._get_field(source, self.y_bin_field, check_cut)
if source_data_x.size == 0:
return
mi = na.where( (source_data_x > self[self.x_bin_field].min())
@@ -384,15 +374,10 @@
@preserve_source_parameters
def _bin_field(self, source, field, weight, accumulation,
args, check_cut=False):
- if check_cut:
- pointI = self._data_source._get_point_indices(source)
- source_data = source[field][pointI].ravel().astype('float64')
- weight_data = na.ones(source_data.shape).astype('float64')
- if weight: weight_data = source[weight][pointI].ravel().astype('float64')
- else:
- source_data = source[field].ravel().astype('float64')
- weight_data = na.ones(source_data.shape).astype('float64')
- if weight: weight_data = source[weight].ravel().astype('float64')
+ source_data = self._get_field(source, field, check_cut)
+ # Unweighted profiles fall back to a unit weight per element.
+ if weight: weight_data = self._get_field(source, weight, check_cut)
+ else: weight_data = na.ones(source_data.shape, dtype='float64')
self.total_stuff = source_data.sum()
binned_field = self._get_empty_field()
weight_field = self._get_empty_field()
@@ -420,15 +405,9 @@
@preserve_source_parameters
def _get_bins(self, source, check_cut=False):
- if check_cut:
- cm = self._data_source._get_point_indices(source)
- source_data_x = source[self.x_bin_field][cm]
- source_data_y = source[self.y_bin_field][cm]
- source_data_z = source[self.z_bin_field][cm]
- else:
- source_data_x = source[self.x_bin_field]
- source_data_y = source[self.y_bin_field]
- source_data_z = source[self.z_bin_field]
+ source_data_x = self._get_field(source, self.x_bin_field, check_cut)
+ source_data_y = self._get_field(source, self.y_bin_field, check_cut)
+ source_data_z = self._get_field(source, self.z_bin_field, check_cut)  # z data in its own variable; must not overwrite source_data_y
if source_data_x.size == 0:
return
mi = na.where( (source_data_x > self[self.x_bin_field].min())
More information about the yt-svn
mailing list