[Yt-svn] yt-commit r902 - trunk/yt/lagos

mturk at wrangler.dreamhost.com mturk at wrangler.dreamhost.com
Thu Nov 6 10:28:18 PST 2008


Author: mturk
Date: Thu Nov  6 10:28:18 2008
New Revision: 902
URL: http://yt.spacepope.org/changeset/902

Log:
Added a parallel_passthrough decorator (returns the input unchanged when not
running in parallel) and a new __mpi_recvarray/_mpi_catarray pair of functions.



Modified:
   trunk/yt/lagos/ParallelTools.py

Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py	(original)
+++ trunk/yt/lagos/ParallelTools.py	Thu Nov  6 10:28:18 2008
@@ -108,6 +108,13 @@
         if not self.just_list: self.pobj._finalize_parallel()
         raise StopIteration
 
+def parallel_passthrough(func):
+    @wraps(func)
+    def _guarded(self, data):
+        # Serial runs have nothing to combine, so the input is the result.
+        return func(self, data) if parallel_capable else data
+    return _guarded
+
 class ParallelAnalysisInterface(object):
     _grids = None
 
@@ -163,13 +170,13 @@
 
         if padding > 0:
             return True, \
-                self.hierarchy.periodic_region(self.center, LE-padding, RE+padding)
+                LE, RE, self.hierarchy.periodic_region(self.center, LE-padding, RE+padding)
 
-        return True, self.hierarchy.region(self.center, LE, RE)
+        return True, LE, RE, self.hierarchy.region(self.center, LE, RE)  # same (flag, LE, RE, region) shape as the padded branch
         
 
+    @parallel_passthrough
     def _mpi_catdict(self, data):
-        if not parallel_capable: return data
         mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
         MPI.COMM_WORLD.Barrier()
         if MPI.COMM_WORLD.rank == 0:
@@ -181,6 +188,7 @@
         MPI.COMM_WORLD.Barrier()
         return data
 
+    @parallel_passthrough  # no-op (returns data as-is) unless parallel_capable
     def __mpi_recvdict(self, data):
         # First we receive, then we make a new dict.
         for i in range(1,MPI.COMM_WORLD.size):
@@ -188,6 +196,7 @@
             for j in buf: data[j] = na.concatenate([data[j],buf[j]], axis=-1)
         return data
 
+    @parallel_passthrough  # NOTE(review): likely redundant if only reached from an already-guarded _mpi_catlist - confirm
     def __mpi_recvlist(self, data):
         # First we receive, then we make a new list.
         for i in range(1,MPI.COMM_WORLD.size):
@@ -195,8 +204,8 @@
             data += buf
         return na.array(data)
 
+    @parallel_passthrough  # replaces the inline parallel_capable guard removed below
     def _mpi_catlist(self, data):
-        if not parallel_capable: return data
         mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
         MPI.COMM_WORLD.Barrier()
         if MPI.COMM_WORLD.rank == 0:
@@ -208,6 +217,28 @@
         MPI.COMM_WORLD.Barrier()
         return data
 
+    @parallel_passthrough
+    def __mpi_recvarray(self, data):
+        # First we receive, then we make a new array.
+        for i in range(1,MPI.COMM_WORLD.size):
+            buf = MPI.COMM_WORLD.Recv(source=i, tag=0)
+            data = na.concatenate([data, buf])
+        return data
+
+    @parallel_passthrough
+    def _mpi_catarray(self, data):
+        # Gather every rank's array on rank 0 (in rank order), then broadcast the result.
+        mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
+        MPI.COMM_WORLD.Barrier()
+        if MPI.COMM_WORLD.rank == 0:
+            data = self.__mpi_recvarray(data)
+        else:
+            MPI.COMM_WORLD.Send(data, dest=0, tag=0)
+        mylog.debug("Opening MPI Broadcast on %s", MPI.COMM_WORLD.rank)
+        data = MPI.COMM_WORLD.Bcast(data, root=0)
+        MPI.COMM_WORLD.Barrier()
+        return data
+
     def _should_i_write(self):
         if not parallel_capable: return True
-        return (MPI.COMM_WORLD == 0)
+        return (MPI.COMM_WORLD.rank == 0)
@@ -218,10 +249,11 @@
         if not parallel_capable: return
         queue.preload(grids, fields)
 
+    @parallel_passthrough
     def _mpi_allsum(self, data):
         MPI.COMM_WORLD.Barrier()
         return MPI.COMM_WORLD.Allreduce(data, op=MPI.SUM)
 
     def _get_dependencies(self, fields):
         deps = []
         for field in fields:



More information about the yt-svn mailing list