[Yt-svn] yt-commit r1223 - in trunk/yt: . lagos raven

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Fri Mar 20 08:24:45 PDT 2009


Author: mturk
Date: Fri Mar 20 08:24:43 2009
New Revision: 1223
URL: http://yt.spacepope.org/changeset/1223

Log:
 * New generalized parallel object iterator
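 * New modification method for plot callbacks: the plot.modify[] dictionary,
which is keyed by each callback's _type_name and returns a wrapper that builds
the callback and adds it to the plot. This commit also adds a velocity callback!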
 * New x_names and y_names lists in EnzoDefs (axis names matching x_dict and y_dict)



Modified:
   trunk/yt/lagos/EnzoDefs.py
   trunk/yt/lagos/ParallelTools.py
   trunk/yt/lagos/setup.py
   trunk/yt/mods.py
   trunk/yt/raven/Callbacks.py
   trunk/yt/raven/PlotTypes.py

Modified: trunk/yt/lagos/EnzoDefs.py
==============================================================================
--- trunk/yt/lagos/EnzoDefs.py	(original)
+++ trunk/yt/lagos/EnzoDefs.py	Fri Mar 20 08:24:43 2009
@@ -41,6 +41,9 @@
 x_dict = [1,0,0]
 y_dict = [2,2,1]
 
+x_names = ['y','x','x']
+y_names = ['z','z','y']
+
 mh = 1.67e-24
 mu = 1.22
 

Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py	(original)
+++ trunk/yt/lagos/ParallelTools.py	Fri Mar 20 08:24:43 2009
@@ -54,66 +54,50 @@
 else:
     parallel_capable = False
 
-class GridIterator(object):
-    def __init__(self, pobj, just_list = False):
+class ObjectIterator(object):
+    def __init__(self, pobj, just_list = False, attr='_grids'):
         self.pobj = pobj
-        if hasattr(pobj, '_grids') and pobj._grids is not None:
-            gs = pobj._grids
+        if hasattr(pobj, attr) and getattr(pobj, attr) is not None:
+            gs = getattr(pobj, attr)
         else:
-            gs = pobj._data_source._grids
+            gs = getattr(pobj._data_source, attr)
         if hasattr(gs[0], 'proc_num'):
             # This one sort of knows about MPI, but not quite
-            self._grids = [g for g in gs if g.proc_num ==
-                            ytcfg.getint('yt','__parallel_rank')]
+            self._objs = [g for g in gs if g.proc_num ==
+                          ytcfg.getint('yt','__parallel_rank')]
             self._use_all = True
         else:
-            self._grids = sorted(gs, key = lambda g: g.filename)
+            self._objs = gs
+            if hasattr(self._objs[0], 'filename'):
+                self._objs = sorted(self._objs, key = lambda g: g.filename)
             self._use_all = False
-        self.ng = len(self._grids)
+        self.ng = len(self._objs)
         self.just_list = just_list
 
     def __iter__(self):
-        self.pos = 0
-        return self
-
-    def next(self):
-        # We do this manually in case
-        # something else asks for us.pos
-        if self.pos < len(self._grids):
-            self.pos += 1
-            return self._grids[self.pos - 1]
-        raise StopIteration
-
-class ParallelGridIterator(GridIterator):
+        for obj in self._objs: yield obj
+        
+class ParallelObjectIterator(ObjectIterator):
     """
     This takes an object, pobj, that implements ParallelAnalysisInterface,
     and then does its thing.
     """
-    def __init__(self, pobj, just_list = False):
-        GridIterator.__init__(self, pobj, just_list)
+    def __init__(self, pobj, just_list = False, attr='_grids'):
+        ObjectIterator.__init__(self, pobj, just_list)
         self._offset = MPI.COMM_WORLD.rank
         self._skip = MPI.COMM_WORLD.size
         # Note that we're doing this in advance, and with a simple means
         # of choosing them; more advanced methods will be explored later.
         if self._use_all:
-            self.my_grid_ids = na.arange(len(self._grids))
+            self.my_obj_ids = na.arange(len(self._objs))
         else:
-            #upper, lower = na.mgrid[0:self.ng:(self._skip+1)*1j][self._offset:self._offset+2]
-            #self.my_grid_ids = na.mgrid[upper:lower-1].astype("int64")
-            self.my_grid_ids = na.array_split(
-                            na.arange(len(self._grids)), self._skip)[self._offset]
+            self.my_obj_ids = na.array_split(
+                            na.arange(len(self._objs)), self._skip)[self._offset]
         
     def __iter__(self):
-        self.pos = 0
-        return self
-
-    def next(self):
-        if self.pos < len(self.my_grid_ids):
-            gid = self.my_grid_ids[self.pos]
-            self.pos += 1
-            return self._grids[gid]
+        for gid in self.my_obj_ids:
+            yield self._objs[gid]
         if not self.just_list: self.pobj._finalize_parallel()
-        raise StopIteration
 
 def parallel_simple_proxy(func):
     if not parallel_capable: return func
@@ -170,16 +154,22 @@
     _grids = None
     _distributed = parallel_capable
 
+    def _get_objs(self, attr, *args, **kwargs):
+        if parallel_capable:
+            self._initialize_parallel(*args, **kwargs)
+            return ParallelObjectIterator(self, attr=attr)
+        return ObjectIterator(self, attr=attr)
+
     def _get_grids(self, *args, **kwargs):
         if parallel_capable:
             self._initialize_parallel(*args, **kwargs)
-            return ParallelGridIterator(self)
-        return GridIterator(self)
+            return ParallelObjectIterator(self, attr='_grids')
+        return ObjectIterator(self, attr='_grids')
 
     def _get_grid_objs(self):
         if parallel_capable:
-            return ParallelGridIterator(self, True)
-        return GridIterator(self, True)
+            return ParallelObjectIterator(self, True, attr='_grids')
+        return ObjectIterator(self, True, attr='_grids')
 
     def _initialize_parallel(self):
         pass
@@ -255,6 +245,18 @@
         return data
 
     @parallel_passthrough
+    def _mpi_joindict(self, data):
+        self._barrier()
+        if MPI.COMM_WORLD.rank == 0:
+            for i in range(1,MPI.COMM_WORLD.size):
+                data.update(MPI.COMM_WORLD.recv(source=i, tag=0))
+        else:
+            MPI.COMM_WORLD.send(data, dest=0, tag=0)
+        data = MPI.COMM_WORLD.bcast(data, root=0)
+        self._barrier()
+        return data
+
+    @parallel_passthrough
     def __mpi_recvlist(self, data):
         # First we receive, then we make a new list.
         data = ensure_list(data)

Modified: trunk/yt/lagos/setup.py
==============================================================================
--- trunk/yt/lagos/setup.py	(original)
+++ trunk/yt/lagos/setup.py	Fri Mar 20 08:24:43 2009
@@ -20,9 +20,9 @@
     config.make_config_py() # installs __config__.py
     config.make_svn_version_py()
     config.add_extension("PointCombine", "yt/lagos/PointCombine.c", libraries=["m"])
-    #config.add_extension("RTIntegrator", "yt/lagos/RTIntegrator.c")
+    config.add_extension("RTIntegrator", "yt/lagos/RTIntegrator.c")
     config.add_extension("Interpolators", "yt/lagos/Interpolators.c")
-    #config.add_extension("DepthFirstOctree", "yt/lagos/DepthFirstOctree.c")
+    config.add_extension("DepthFirstOctree", "yt/lagos/DepthFirstOctree.c")
     config.add_subpackage("hop")
     H5dir = check_for_hdf5()
     if H5dir is not None:
@@ -33,6 +33,6 @@
                              libraries=["m","hdf5"],
                              library_dirs=library_dirs, include_dirs=include_dirs)
     # Uncomment the next two lines if you want particle_density support
-    #config.add_extension("cic_deposit", ["yt/lagos/enzo_routines/cic_deposit.pyf",
-    #                                     "yt/lagos/enzo_routines/cic_deposit.f"])
+    config.add_extension("cic_deposit", ["yt/lagos/enzo_routines/cic_deposit.pyf",
+                                         "yt/lagos/enzo_routines/cic_deposit.f"])
     return config

Modified: trunk/yt/mods.py
==============================================================================
--- trunk/yt/mods.py	(original)
+++ trunk/yt/mods.py	Fri Mar 20 08:24:43 2009
@@ -52,8 +52,8 @@
 # Now individual component imports from raven
 from yt.raven import PlotCollection, PlotCollectionInteractive, get_multi_plot
 from yt.raven.Callbacks import callback_registry
-for name, cls in callback_registry:
-    exec("from yt.raven import %s" % name)
+for name, cls in callback_registry.items():
+    exec("%s = cls" % name)
 
 # Optional component imports from raven
 try:

Modified: trunk/yt/raven/Callbacks.py
==============================================================================
--- trunk/yt/raven/Callbacks.py	(original)
+++ trunk/yt/raven/Callbacks.py	Fri Mar 20 08:24:43 2009
@@ -32,13 +32,13 @@
 
 import _MPL
 import copy
-callback_registry = []
+callback_registry = {}
 
 class PlotCallback(object):
     class __metaclass__(type):
         def __init__(cls, name, b, d):
             type.__init__(cls, name, b, d)
-            callback_registry.append((name, cls))
+            callback_registry[name] = cls
 
     def __init__(self, *args, **kwargs):
         pass
@@ -54,7 +54,30 @@
         return ((coord[0] - int(offset)*x0)*dx,
                 (coord[1] - int(offset)*y0)*dy)
 
+class VelocityCallback(PlotCallback):
+    _type_name = "velocity"
+    def __init__(self, factor):
+        """
+        Adds a 'quiver' plot of velocity to the plot, skipping all but
+        every *factor* datapoint
+        """
+        PlotCallback.__init__(self)
+        self.factor = factor
+
+    def __call__(self, plot):
+        # Instantiation of these is cheap
+        if plot._type_name == "CuttingPlane":
+            qcb = CuttingQuiverCallback("CuttingPlaneVelocityX",
+                                        "CuttingPlaneVelocityY",
+                                        self.factor)
+        else:
+            xv = "%s-velocity" % (lagos.x_names[plot.data.axis])
+            yv = "%s-velocity" % (lagos.y_names[plot.data.axis])
+            qcb = QuiverCallback(xv, yv, self.factor)
+        return qcb(plot)
+
 class QuiverCallback(PlotCallback):
+    _type_name = "quiver"
     def __init__(self, field_x, field_y, factor):
         """
         Adds a 'quiver' plot to any plot, using the *field_x* and *field_y*
@@ -96,6 +119,7 @@
         plot._axes.hold(False)
 
 class ParticleCallback(PlotCallback):
+    _type_name = "particles"
     def __init__(self, axis, width, p_size=1.0, col='k', stride=1.0):
         """
         Adds particle positions, based on a thick slab along *axis* with a
@@ -163,6 +187,7 @@
         plot._axes.hold(False)
 
 class ContourCallback(PlotCallback):
+    _type_name = "contour"
     def __init__(self, field, ncont=5, factor=4, take_log=False, clim=None,
                  plot_args = None):
         """
@@ -235,6 +260,7 @@
         plot._axes.hold(False)
 
 class GridBoundaryCallback(PlotCallback):
+    _type_name = "grids"
     def __init__(self, alpha=1.0, min_pix = 1):
         """
         Adds grid boundaries to a plot, optionally with *alpha*-blending.
@@ -280,6 +306,7 @@
             plot._axes.hold(False)
 
 class LabelCallback(PlotCallback):
+    _type_name = "axis_label"
     def __init__(self, label):
         PlotCallback.__init__(self)
         self.label = label
@@ -302,6 +329,7 @@
     return good_u
 
 class UnitBoundaryCallback(PlotCallback):
+    _type_name = "units"
     def __init__(self, unit = "au", factor=4, text_annotate=True, text_which=-2):
         """
         Add on a plot indicating where *factor*s of *unit* are shown.
@@ -361,6 +389,7 @@
         plot._axes.hold(False)
 
 class LinePlotCallback(PlotCallback):
+    _type_name = "line"
     def __init__(self, x, y, plot_args = None):
         """
         Over plot *x* and *y* with *plot_args* fed into the plot.
@@ -377,6 +406,7 @@
         plot._axes.hold(False)
 
 class CuttingQuiverCallback(PlotCallback):
+    _type_name = "quiver"
     def __init__(self, field_x, field_y, factor):
         """
         Get a quiver plot on top of a cutting plane, using *field_x* and
@@ -418,7 +448,8 @@
         plot._axes.hold(False)
 
 class ClumpContourCallback(PlotCallback):
-    def __init__(self, clumps, axis = None, plot_args = None):
+    _type_name = "clumps"
+    def __init__(self, clumps, plot_args = None):
         """
         Take a list of *clumps* and plot them as a set of contours.
         """
@@ -481,6 +512,7 @@
         plot._axes.hold(False)
 
 class ArrowCallback(PlotCallback):
+    _type_name = "arrow"
     def __init__(self, pos, code_size, plot_args = None):
         self.pos = pos
         self.code_size = code_size
@@ -496,6 +528,7 @@
         plot._axes.add_patch(arrow)
 
 class PointAnnotateCallback(PlotCallback):
+    _type_name = "point"
     def __init__(self, pos, text, text_args = None):
         self.pos = pos
         self.text = text
@@ -506,6 +539,7 @@
         plot._axes.text(x, y, self.text, **self.text_args)
 
 class MarkerAnnotateCallback(PlotCallback):
+    _type_name = "marker"
     def __init__(self, pos, marker='x', plot_args=None):
         self.pos = pos
         self.marker = marker
@@ -524,6 +558,7 @@
         plot._axes.hold(False)
 
 class SphereCallback(PlotCallback):
+    _type_name = "sphere"
     def __init__(self, center, radius, circle_args = None,
                  text = None, text_args = None):
         self.center = center
@@ -554,11 +589,11 @@
                             **self.text_args)
 
 class HopCircleCallback(PlotCallback):
-    def __init__(self, hop_output, axis, max_number=None,
+    _type_name = "hop_circles"
+    def __init__(self, hop_output, max_number=None,
                  annotate=False, min_size=20, max_size=10000000,
                  font_size=8, print_halo_size=False,
                  print_halo_mass=False):
-        self.axis = axis
         self.hop_output = hop_output
         self.max_number = max_number
         self.annotate = annotate
@@ -604,7 +639,8 @@
     larger than *min_size* are plotted with *p_size* pixels per particle; 
     *alpha* determines the opacity of each particle.
     """
-    def __init__(self, hop_output, axis, p_size=1.0,
+    _type_name = "hop_particles"
+    def __init__(self, hop_output, p_size=1.0,
                 max_number=None, min_size=20, alpha=0.2):
         self.axis = axis
         self.hop_output = hop_output
@@ -637,6 +673,7 @@
             plot._axes.hold(False)
 
 class FloorToValueInPlot(PlotCallback):
+    _type_name = "floor"
     def __init__(self):
         pass
 
@@ -647,7 +684,8 @@
 
 
 class VobozCircleCallback(PlotCallback):
-    def __init__(self, voboz_output, axis, max_number=None,
+    _type_name = "voboz_circle"
+    def __init__(self, voboz_output, max_number=None,
                  annotate=False, min_size=20, font_size=8, print_halo_size=False):
         self.axis = axis
         self.voboz_output = voboz_output
@@ -689,6 +727,7 @@
     attempt to guess the proper units to use.
 
     """
+    _type_name = "coord_axes"
     def __init__(self,unit=None,coords=False):
         PlotCallback.__init__(self)
         self.unit = unit

Modified: trunk/yt/raven/PlotTypes.py
==============================================================================
--- trunk/yt/raven/PlotTypes.py	(original)
+++ trunk/yt/raven/PlotTypes.py	Fri Mar 20 08:24:43 2009
@@ -42,7 +42,36 @@
     engineVals["canvas"] = FigureCanvas
     return
 
-class RavenPlot:
+class CallbackRegistryHandler(object):
+    def __init__(self, plot):
+        self.plot = plot
+        self._callbacks = {}
+
+    def __getitem__(self, item):
+        if item not in self._callbacks:
+            raise KeyError(item)
+        cb = self._callbacks[item]
+
+        @wraps(cb)
+        def get_wrapper(*args, **kwargs):
+            cbo = cb(*args, **kwargs)
+            return self.plot.add_callback(cbo)
+
+        return get_wrapper
+
+    def __setitem__(self, item, val):
+        self._callbacks[item] = val
+
+    def __delitem__(self, item):
+        del self._callbacks[item]
+
+    def __iter__(self):
+        for k in sorted(self._callbacks): yield k
+
+    def keys(self):
+        return self._callbacks.keys()
+
+class RavenPlot(object):
 
     datalabel = None
     colorbar = None
@@ -64,6 +93,7 @@
         else:
             self._axes = axes
         self._callbacks = []
+        self._setup_callback_registry()
 
     def set_autoscale(self, val):
         self.do_autoscale = val
@@ -178,6 +208,13 @@
             self.xmin = self.ymin = 0.0
             self.xmax = self.ymax = 1.0
 
+    def _setup_callback_registry(self):
+        from yt.raven.Callbacks import callback_registry
+        self.modify = CallbackRegistryHandler(self)
+        for c in callback_registry.values():
+            if not hasattr(c, '_type_name'): continue
+            self.modify[c._type_name] = c
+
 class VMPlot(RavenPlot):
     _antialias = True
     _period = (0.0, 0.0)



More information about the yt-svn mailing list