[Yt-svn] yt-commit r1123 - branches/grid-optimization/yt/lagos

mturk at wrangler.dreamhost.com
Fri Jan 16 12:37:37 PST 2009


Author: mturk
Date: Fri Jan 16 12:37:37 2009
New Revision: 1123
URL: http://yt.spacepope.org/changeset/1123

Log:
The new mpi4py bindings (trunk from Google Code) require us to use the buffer
interface for array data; this replaces pickling of the data.  I've wrapped all
the functions so that they can be used in basically the same way, but I have
also broken backwards compatibility.

This may still have bugs, and it has only really been tested with projections.
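
To make the distinction concrete, here is a minimal sketch (not part of the
commit) of the two mpi4py messaging styles the log message refers to: the
lowercase send/recv methods pickle arbitrary Python objects, while the
uppercase Send/Recv methods ship a raw buffer, which is the pattern the
_send_array/_recv_array helpers below follow.  It assumes an MPI run with at
least two ranks (e.g. mpirun -np 2); the `na` alias matches yt's NumPy import.

    from mpi4py import MPI
    import numpy as na  # same alias as in yt.lagos

    comm = MPI.COMM_WORLD

    if comm.rank == 0:
        arr = na.arange(10, dtype='float64')
        # Pickle-based: lowercase send handles arbitrary Python objects,
        # here the (dtype, shape) metadata.
        comm.send((arr.dtype.str, arr.shape), dest=1, tag=0)
        # Buffer-based: uppercase Send ships the raw bytes, viewed as CHAR,
        # mirroring the _send_array helper added in this commit.
        comm.Send([arr.view('c'), MPI.CHAR], dest=1, tag=0)
    elif comm.rank == 1:
        dt, sh = comm.recv(source=0, tag=0)          # pickled metadata
        arr = na.empty(sh, dtype=dt)
        comm.Recv([arr.view('c'), MPI.CHAR], source=0, tag=0)  # raw bytes
        print(arr)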



Modified:
   branches/grid-optimization/yt/lagos/ParallelTools.py

Modified: branches/grid-optimization/yt/lagos/ParallelTools.py
==============================================================================
--- branches/grid-optimization/yt/lagos/ParallelTools.py	(original)
+++ branches/grid-optimization/yt/lagos/ParallelTools.py	Fri Jan 16 12:37:37 2009
@@ -235,30 +235,23 @@
             mylog.debug("Joining %s (%s) on %s", key, type(data[key]),
                         MPI.COMM_WORLD.rank)
             if MPI.COMM_WORLD.rank == 0:
-                data[key] = na.concatenate([data[key]] +
-                 [MPI.COMM_WORLD.Recv(source=i, tag=0) for i in range(1, np)],
-                    axis=-1)
+                temp_data = []
+                for i in range(1,np):
+                    temp_data.append(_recv_array(source=i, tag=0))
+                data[key] = na.concatenate([data[key]] + temp_data, axis=-1)
             else:
-                MPI.COMM_WORLD.Send(data[key], dest=0, tag=0)
+                _send_array(data[key], dest=0, tag=0)
             self._barrier()
-            data[key] = MPI.COMM_WORLD.Bcast(data[key], root=0)
+            data[key] = _bcast_array(data[key])
         self._barrier()
         return data
 
     @parallel_passthrough
-    def __mpi_recvdict(self, data):
-        # First we receive, then we make a new dict.
-        for i in range(1,MPI.COMM_WORLD.size):
-            buf = MPI.COMM_WORLD.Recv(source=i, tag=0)
-            for j in buf: data[j] = na.concatenate([data[j],buf[j]], axis=-1)
-        return data
-
-    @parallel_passthrough
     def __mpi_recvlist(self, data):
         # First we receive, then we make a new list.
         data = ensure_list(data)
         for i in range(1,MPI.COMM_WORLD.size):
-            buf = ensure_list(MPI.COMM_WORLD.Recv(source=i, tag=0))
+            buf = ensure_list(MPI.COMM_WORLD.recv(source=i, tag=0))
             data += buf
         return na.array(data)
 
@@ -268,17 +261,17 @@
         if MPI.COMM_WORLD.rank == 0:
             data = self.__mpi_recvlist(data)
         else:
-            MPI.COMM_WORLD.Send(data, dest=0, tag=0)
+            MPI.COMM_WORLD.send(data, dest=0, tag=0)
         mylog.debug("Opening MPI Broadcast on %s", MPI.COMM_WORLD.rank)
-        data = MPI.COMM_WORLD.Bcast(data, root=0)
+        data = MPI.COMM_WORLD.bcast(data, root=0)
         self._barrier()
         return data
 
     @parallel_passthrough
-    def __mpi_recvarray(self, data):
+    def __mpi_recvarrays(self, data):
         # First we receive, then we make a new list.
         for i in range(1,MPI.COMM_WORLD.size):
-            buf = MPI.COMM_WORLD.Recv(source=i, tag=0)
+            buf = _recv_array(source=i, tag=0)
             data = na.concatenate([data, buf])
         return data
 
@@ -286,9 +279,9 @@
     def _mpi_catarray(self, data):
         self._barrier()
         if MPI.COMM_WORLD.rank == 0:
-            data = self.__mpi_recvarray(data)
+            data = self.__mpi_recvarrays(data)
         else:
-            MPI.COMM_WORLD.Send(data, dest=0, tag=0)
+            _send_array(data, dest=0, tag=0)
         mylog.debug("Opening MPI Broadcast on %s", MPI.COMM_WORLD.rank)
         data = MPI.COMM_WORLD.Bcast(data, root=0)
         self._barrier()
@@ -307,7 +300,9 @@
     @parallel_passthrough
     def _mpi_allsum(self, data):
         self._barrier()
-        return MPI.COMM_WORLD.Allreduce(data, op=MPI.SUM)
+        # We use old-school pickling here on the assumption the arrays are
+        # relatively small ( < 1e7 elements )
+        return MPI.COMM_WORLD.allreduce(data, op=MPI.SUM)
 
     def _mpi_info_dict(self, info):
         mylog.info("Parallel capable: %s", parallel_capable)
@@ -317,11 +312,11 @@
         if MPI.COMM_WORLD.rank == 0:
             data = {0:info}
             for i in range(1, MPI.COMM_WORLD.size):
-                data[i] = MPI.COMM_WORLD.Recv(source=i, tag=0)
+                data[i] = MPI.COMM_WORLD.recv(source=i, tag=0)
         else:
-            MPI.COMM_WORLD.Send(info, dest=0, tag=0)
+            MPI.COMM_WORLD.send(info, dest=0, tag=0)
         mylog.debug("Opening MPI Broadcast on %s", MPI.COMM_WORLD.rank)
-        data = MPI.COMM_WORLD.Bcast(data, root=0)
+        data = MPI.COMM_WORLD.bcast(data, root=0)
         self._barrier()
         return MPI.COMM_WORLD.rank, data
 
@@ -343,3 +338,36 @@
             return open(fn, "w")
         else:
             return cStringIO.StringIO()
+
+__tocast = 'c'
+
+def _send_array(arr, dest, tag = 0):
+    if not isinstance(arr, na.ndarray):
+        MPI.COMM_WORLD.send((None,None), dest=dest, tag=tag)
+        MPI.COMM_WORLD.send(arr, dest=dest, tag=tag)
+        return
+    tmp = arr.view(__tocast) # Cast to CHAR
+    # communicate type and shape
+    MPI.COMM_WORLD.send((arr.dtype.str, arr.shape), dest=dest, tag=tag)
+    MPI.COMM_WORLD.Send([arr, MPI.CHAR], dest=dest, tag=tag)
+    del tmp
+
+def _recv_array(source, tag = 0):
+    dt, ne = MPI.COMM_WORLD.recv(source=source, tag=tag)
+    if dt is None and ne is None:
+        return MPI.COMM_WORLD.recv(source=source, tag=tag)
+    arr = na.empty(ne, dtype=dt)
+    tmp = arr.view(__tocast)
+    MPI.COMM_WORLD.Recv([tmp, MPI.CHAR], source=source, tag=tag)
+    return arr
+
+def _bcast_array(arr, root = 0):
+    if MPI.COMM_WORLD.rank == root:
+        tmp = arr.view(__tocast) # Cast to CHAR
+        MPI.COMM_WORLD.bcast((arr.dtype.str, arr.shape), root=root)
+    else:
+        dt, ne = MPI.COMM_WORLD.bcast(None, root=root)
+        arr = na.empty(ne, dtype=dt)
+        tmp = arr.view(__tocast)
+    MPI.COMM_WORLD.Bcast([tmp, MPI.CHAR], root=root)
+    return arr
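
For reference, a minimal usage sketch of the new helpers (again not part of the
commit; it assumes the branch's ParallelTools module is importable and that MPI
was launched with at least two ranks):

    from yt.lagos.ParallelTools import _send_array, _recv_array, _bcast_array
    from mpi4py import MPI
    import numpy as na

    comm = MPI.COMM_WORLD

    if comm.rank == 0:
        payload = na.linspace(0.0, 1.0, 16)
        _send_array(payload, dest=1, tag=0)      # metadata via pickle, data via buffer
    elif comm.rank == 1:
        payload = _recv_array(source=0, tag=0)   # dtype/shape arrive first, then bytes

    # Collective: every rank calls _bcast_array; non-root ranks pass a
    # placeholder and get back a copy of the root's array.
    result = _bcast_array(payload if comm.rank == 0 else None, root=0)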


