[Yt-svn] yt-commit r902 - trunk/yt/lagos
mturk at wrangler.dreamhost.com
Thu Nov 6 10:28:18 PST 2008
Author: mturk
Date: Thu Nov 6 10:28:18 2008
New Revision: 902
URL: http://yt.spacepope.org/changeset/902
Log:
Added a parallel_passthrough decorator for the MPI helper methods, plus a new
recvarray and catarray pair of functions.
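
In short, the new decorator lets each MPI helper fall back to a no-op in
serial runs. A minimal sketch of the pattern, assuming `parallel_capable` is
the module-level flag ParallelTools.py already uses (here stubbed to False so
the snippet runs standalone):

    from functools import wraps

    parallel_capable = False  # stand-in; the real flag is set at startup

    def parallel_passthrough(func):
        @wraps(func)
        def passage(self, data):
            # Serial run: no other ranks to talk to, so hand the data back.
            if not parallel_capable: return data
            return func(self, data)
        return passage

Any method decorated this way, such as _mpi_catdict below, then only pays for
Barrier/Send/Recv traffic when MPI is actually available.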
Modified:
trunk/yt/lagos/ParallelTools.py
Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py (original)
+++ trunk/yt/lagos/ParallelTools.py Thu Nov 6 10:28:18 2008
@@ -108,6 +108,13 @@
             if not self.just_list: self.pobj._finalize_parallel()
             raise StopIteration
 
+def parallel_passthrough(func):
+    @wraps(func)
+    def passage(self, data):
+        if not parallel_capable: return data
+        return func(self, data)
+    return passage
+
 class ParallelAnalysisInterface(object):
     _grids = None
@@ -163,13 +170,13 @@
         if padding > 0:
             return True, \
-                self.hierarchy.periodic_region(self.center, LE-padding, RE+padding)
+                LE, RE, self.hierarchy.periodic_region(self.center, LE-padding, RE+padding)
-        return True, self.hierarchy.region(self.center, LE, RE)
+        return LE, RE, self.hierarchy.region(self.center, LE, RE)
 
+    @parallel_passthrough
     def _mpi_catdict(self, data):
-        if not parallel_capable: return data
         mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
         MPI.COMM_WORLD.Barrier()
         if MPI.COMM_WORLD.rank == 0:
@@ -181,6 +188,7 @@
         MPI.COMM_WORLD.Barrier()
         return data
 
+    @parallel_passthrough
     def __mpi_recvdict(self, data):
         # First we receive, then we make a new dict.
         for i in range(1,MPI.COMM_WORLD.size):
@@ -188,6 +196,7 @@
             for j in buf: data[j] = na.concatenate([data[j],buf[j]], axis=-1)
         return data
 
+    @parallel_passthrough
     def __mpi_recvlist(self, data):
         # First we receive, then we make a new list.
         for i in range(1,MPI.COMM_WORLD.size):
@@ -195,8 +204,8 @@
             data += buf
         return na.array(data)
 
+    @parallel_passthrough
     def _mpi_catlist(self, data):
-        if not parallel_capable: return data
         mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
         MPI.COMM_WORLD.Barrier()
         if MPI.COMM_WORLD.rank == 0:
@@ -208,6 +217,27 @@
         MPI.COMM_WORLD.Barrier()
         return data
 
+    @parallel_passthrough
+    def __mpi_recvarray(self, data):
+        # First we receive, then we concatenate into a new array.
+        for i in range(1,MPI.COMM_WORLD.size):
+            buf = MPI.COMM_WORLD.Recv(source=i, tag=0)
+            data = na.concatenate([data, buf])
+        return data
+
+    @parallel_passthrough
+    def _mpi_catarray(self, data):
+        mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
+        MPI.COMM_WORLD.Barrier()
+        if MPI.COMM_WORLD.rank == 0:
+            data = self.__mpi_recvarray(data)
+        else:
+            MPI.COMM_WORLD.Send(data, dest=0, tag=0)
+        mylog.debug("Opening MPI Broadcast on %s", MPI.COMM_WORLD.rank)
+        data = MPI.COMM_WORLD.Bcast(data, root=0)
+        MPI.COMM_WORLD.Barrier()
+        return data
+
     def _should_i_write(self):
         if not parallel_capable: return True
         return (MPI.COMM_WORLD.rank == 0)
@@ -218,10 +247,12 @@
         if not parallel_capable: return
         queue.preload(grids, fields)
 
+    @parallel_passthrough
     def _mpi_allsum(self, data):
         MPI.COMM_WORLD.Barrier()
         return MPI.COMM_WORLD.Allreduce(data, op=MPI.SUM)
 
+
     def _get_dependencies(self, fields):
         deps = []
         for field in fields:
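
The recvarray/catarray pair above follows the same gather-then-broadcast
shape as the existing dict and list helpers: rank 0 receives each other
rank's array, concatenates it onto its own, and broadcasts the result so all
ranks end up with the full array. A rough standalone sketch of that pattern,
assuming a pickle-style Recv/Send/Bcast interface that returns values (as the
mpi4py API of this era did) and `na` bound to numpy per yt convention:

    import numpy as na

    def cat_arrays(comm, data):
        # Rank 0 gathers every other rank's array and concatenates it on.
        if comm.rank == 0:
            for i in range(1, comm.size):
                data = na.concatenate([data, comm.Recv(source=i, tag=0)])
        else:
            comm.Send(data, dest=0, tag=0)
        # Broadcast so every rank returns the identical concatenated array.
        return comm.Bcast(data, root=0)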