[Yt-svn] yt-commit r637 - trunk/yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Mon Jun 30 15:22:51 PDT 2008


Author: mturk
Date: Mon Jun 30 15:22:50 2008
New Revision: 637
URL: http://yt.spacepope.org/changeset/637

Log:
Refactoring the profiles code for readability and generality with particles.

Next is hybrid particle/baryon fields.



Modified:
   trunk/yt/lagos/Profiles.py

Modified: trunk/yt/lagos/Profiles.py
==============================================================================
--- trunk/yt/lagos/Profiles.py	(original)
+++ trunk/yt/lagos/Profiles.py	Mon Jun 30 15:22:50 2008
@@ -112,6 +112,16 @@
     def __setitem__(self, key, value):
         self._data[key] = value
 
+    def _get_field(self, source, field, check_cut):
+        if check_cut:
+            if field in fieldInfo and fieldInfo[field].particle_type:
+                pointI = self._data_source._get_particle_indices(source)
+            else:
+                pointI = self._data_source._get_point_indices(source)
+        else:
+            pointI = slice(None)
+        return source[field][pointI].ravel().astype('float64')
+
 # @todo: Fix accumulation with overriding
 class BinnedProfile1D(BinnedProfile):
     def __init__(self, data_source, n_bins, bin_field,
@@ -154,13 +164,8 @@
         mi, inv_bin_indices = args # Args has the indices to use as input
         # check_cut is set if source != self._data_source
         # (i.e., lazy_reader)
-        if check_cut: # Only use the points inside the source
-            cm = self._data_source._get_point_indices(source)
-            source_data = source[field][cm].astype('float64')[mi]
-            if weight: weight_data = source[weight][cm].astype('float64')[mi]
-        else:
-            source_data = source[field].astype('float64')[mi]
-            if weight: weight_data = source[weight].astype('float64')[mi]
+        source_data = self._get_field(source, field, check_cut)[mi]
+        if weight: weight_data = self._get_field(source, weight, check_cut)[mi]
         binned_field = self._get_empty_field()
         weight_field = self._get_empty_field()
         used_field = na.ones(weight_field.shape, dtype='bool')
@@ -183,11 +188,7 @@
 
     @preserve_source_parameters
     def _get_bins(self, source, check_cut=False):
-        if check_cut: # if source != self._data_source
-            cm = self._data_source._get_point_indices(source)
-            source_data = source[self.bin_field][cm]
-        else:
-            source_data = source[self.bin_field]
+        source_data = self._get_field(source, self.bin_field, check_cut)
         if source_data.size == 0: # Nothing for us here.
             return
         # Truncate at boundaries.
@@ -264,15 +265,9 @@
     @preserve_source_parameters
     def _bin_field(self, source, field, weight, accumulation,
                    args, check_cut=False):
-        if check_cut:
-            pointI = self._data_source._get_point_indices(source)
-            source_data = source[field][pointI].ravel().astype('float64')
-            weight_data = na.ones(source_data.shape).astype('float64')
-            if weight: weight_data = source[weight][pointI].ravel().astype('float64')
-        else:
-            source_data = source[field].ravel().astype('float64')
-            weight_data = na.ones(source_data.shape).astype('float64')
-            if weight: weight_data = source[weight].ravel().astype('float64')
+        source_data = self._get_field(source, field, check_cut)
+        if weight: weight_data = self._get_field(source, weight, check_cut)
+        else: weight_data = na.ones(source_data.shape, dtype='float64')
         self.total_stuff = source_data.sum()
         binned_field = self._get_empty_field()
         weight_field = self._get_empty_field()
@@ -297,13 +292,8 @@
 
     @preserve_source_parameters
     def _get_bins(self, source, check_cut=False):
-        if check_cut:
-            cm = self._data_source._get_point_indices(source)
-            source_data_x = source[self.x_bin_field][cm]
-            source_data_y = source[self.y_bin_field][cm]
-        else:
-            source_data_x = source[self.x_bin_field]
-            source_data_y = source[self.y_bin_field]
+        source_data_x = self._get_field(source, self.x_bin_field, check_cut)
+        source_data_y = self._get_field(source, self.y_bin_field, check_cut)
         if source_data_x.size == 0:
             return
         mi = na.where( (source_data_x > self[self.x_bin_field].min())
@@ -384,15 +374,10 @@
     @preserve_source_parameters
     def _bin_field(self, source, field, weight, accumulation,
                    args, check_cut=False):
-        if check_cut:
-            pointI = self._data_source._get_point_indices(source)
-            source_data = source[field][pointI].ravel().astype('float64')
-            weight_data = na.ones(source_data.shape).astype('float64')
-            if weight: weight_data = source[weight][pointI].ravel().astype('float64')
-        else:
-            source_data = source[field].ravel().astype('float64')
-            weight_data = na.ones(source_data.shape).astype('float64')
-            if weight: weight_data = source[weight].ravel().astype('float64')
+        source_data = self._get_field(source, field, check_cut)
+        weight_data = na.ones(source_data.shape).astype('float64')
+        if weight: weight_data = self._get_field(source, weight, check_cut)
+        else: weight_data = na.ones(source_data.shape).astype('float64')
         self.total_stuff = source_data.sum()
         binned_field = self._get_empty_field()
         weight_field = self._get_empty_field()
@@ -420,15 +405,9 @@
 
     @preserve_source_parameters
     def _get_bins(self, source, check_cut=False):
-        if check_cut:
-            cm = self._data_source._get_point_indices(source)
-            source_data_x = source[self.x_bin_field][cm]
-            source_data_y = source[self.y_bin_field][cm]
-            source_data_z = source[self.z_bin_field][cm]
-        else:
-            source_data_x = source[self.x_bin_field]
-            source_data_y = source[self.y_bin_field]
-            source_data_z = source[self.z_bin_field]
+        source_data_x = self._get_field(source, self.x_bin_field, check_cut)
+        source_data_y = self._get_field(source, self.y_bin_field, check_cut)
+        source_data_z = self._get_field(source, self.z_bin_field, check_cut)
         if source_data_x.size == 0:
             return
         mi = na.where( (source_data_x > self[self.x_bin_field].min())



More information about the yt-svn mailing list