[Yt-svn] yt-commit r1226 - in trunk/yt: extensions lagos

mturk at wrangler.dreamhost.com
Sat Mar 21 14:20:49 PDT 2009


Author: mturk
Date: Sat Mar 21 14:20:48 2009
New Revision: 1226
URL: http://yt.spacepope.org/changeset/1226

Log:
Britton and I hammered this down this morning, and the halo profiler now
works in parallel.  It currently round-robins over the halos, so the
'single' mode will not work in parallel yet.  The check for
parallel_capable in ParallelAnalysisInterface has been replaced with a
check for _distributed.



Modified:
   trunk/yt/extensions/HaloProfiler.py
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/DerivedQuantities.py
   trunk/yt/lagos/ParallelTools.py

Modified: trunk/yt/extensions/HaloProfiler.py
==============================================================================
--- trunk/yt/extensions/HaloProfiler.py	(original)
+++ trunk/yt/extensions/HaloProfiler.py	Sat Mar 21 14:20:48 2009
@@ -33,7 +33,7 @@
 
 PROFILE_RADIUS_THRESHOLD = 2
 
-class HaloProfiler(object):
+class HaloProfiler(lagos.ParallelAnalysisInterface):
     def __init__(self,dataset,HaloProfilerParameterFile,halos='multiple',radius=0.1,radius_units='1',hop_style='new'):
         self.dataset = dataset
         self.HaloProfilerParameterFile = HaloProfilerParameterFile
@@ -49,7 +49,7 @@
         self.halos = halos
         if not(self.halos is 'multiple' or self.halos is 'single'):
             mylog.error("Keyword, halos, must be either 'single' or 'multiple'.")
-            exit(1)
+            return None
 
         # Set hop file style.
         # old: enzo_hop output.
@@ -57,7 +57,7 @@
         self.hop_style = hop_style
         if not(self.hop_style is 'old' or self.hop_style is 'new'):
             mylog.error("Keyword, hop_style, must be either 'old' or 'new'.")
-            exit(1)
+            return None
 
         # Set some parameter defaults.
         self._SetParameterDefaults()
@@ -71,25 +71,26 @@
             if self.haloProfilerParameters['VelocityCenter'][1] == 'halo' and \
                     self.halos is 'single':
                 mylog.error("Parameter, VelocityCenter, must be set to 'bulk sphere' or 'max <field>' with halos flag set to 'single'.")
-                exit(1)
+                return None
             if self.haloProfilerParameters['VelocityCenter'][1] == 'halo' and \
                     self.hop_style is 'old':
                 mylog.error("Parameter, VelocityCenter, must be 'bulk sphere' for old style hop output files.")
-                exit(1)
+                return None
             if not(self.haloProfilerParameters['VelocityCenter'][1] == 'halo' or 
                    self.haloProfilerParameters['VelocityCenter'][1] == 'sphere'):
                 mylog.error("Second value of VelocityCenter must be either 'halo' or 'sphere' if first value is 'bulk'.")
-                exit(1)
+                return None
         elif self.haloProfilerParameters['VelocityCenter'][0] == 'max':
             if self.halos is 'multiple':
                 mylog.error("Getting velocity center from a max field value only works with halos='single'.")
-                exit(1)
+                return None
         else:
             mylog.error("First value of parameter, VelocityCenter, must be either 'bulk' or 'max'.")
-            exit(1)
+            return None
 
         # Create dataset object.
         self.pf = lagos.EnzoStaticOutput(self.dataset)
+        self.pf.h
         if self.halos is 'single' or hop_style is 'old':
             self.haloRadius = radius / self.pf[radius_units]
 
@@ -122,9 +123,8 @@
         else:
             os.mkdir(outputDir)
 
-        pbar = lagos.get_pbar("Profiling halos ", len(self.hopHalos))
-        for q,halo in enumerate(self.hopHalos):
-            filename = "%s/Halo_%04d_profile.dat" % (outputDir,q)
+        for q,halo in enumerate(self._get_objs('hopHalos', round_robin=True)):
+            filename = "%s/Halo_%04d_profile.dat" % (outputDir,halo['id'])
 
             # Read profile from file if it already exists.
             # If not, profile will be None.
@@ -140,6 +140,7 @@
                     continue
 
                 sphere = self.pf.h.sphere(halo['center'],halo['r_max']/self.pf.units['mpc'])
+                if len(sphere._grids) == 0: continue
 
                 # Set velocity to zero out radial velocity profiles.
                 if self.haloProfilerParameters['VelocityCenter'][0] == 'bulk':
@@ -157,7 +158,7 @@
 
                 profile = lagos.BinnedProfile1D(sphere,self.haloProfilerParameters['n_bins'],"RadiusMpc",
                                                 r_min,halo['r_max'],
-                                                log_space=True, lazy_reader=True)
+                                                log_space=True, lazy_reader=False)
                 for field in self.profileFields.keys():
                     profile.add_fields(field,weight=self.profileFields[field][0],
                                        accumulation=self.profileFields[field][1])
@@ -165,13 +166,13 @@
             self._AddActualOverdensity(profile)
 
             virial = self._CalculateVirialQuantities(profile)
-            virial['center'] = self.hopHalos[q]['center']
+            virial['center'] = halo['center']
+            virial['id'] = halo['id']
 
-            if (virial['TotalMassMsun'] < self.haloProfilerParameters['VirialMassCutoff']):
-                self.virialQuantities.append(None)
-            else:
+            if (virial['TotalMassMsun'] >= self.haloProfilerParameters['VirialMassCutoff']):
                 self.virialQuantities.append(virial)
             if newProfile:
+                mylog.info("Writing halo %d" % virial['id'])
                 profile.write_out(filename, format='%0.6e')
             del profile
 
@@ -182,11 +183,21 @@
                 sphere.clear_data()
                 del sphere
 
-            pbar.update(q)
-
-        pbar.finish()
         self._WriteVirialQuantities()
 
+    def _finalize_parallel(self):
+        self.virialQuantities = self._mpi_catlist(self.virialQuantities)
+        self.virialQuantities.sort(key = lambda a:a['id'])
+
+    @lagos.parallel_root_only
+    def __check_directory(self, outputDir):
+        if (os.path.exists(outputDir)):
+            if not(os.path.isdir(outputDir)):
+                mylog.error("Output directory exists, but is not a directory: %s." % outputDir)
+                raise IOError(outputDir)
+        else:
+            os.mkdir(outputDir)
+
     def makeProjections(self,save_images=True,save_cube=True,**kwargs):
         "Make projections of all halos using specified fields."
 
@@ -204,13 +215,7 @@
             projectionResolution = int(self.haloProfilerParameters['ProjectionWidth'] / proj_dx)
 
         outputDir = "%s/%s" % (self.pf.fullpath,self.haloProfilerParameters['ProjectionOutputDir'])
-
-        if (os.path.exists(outputDir)):
-            if not(os.path.isdir(outputDir)):
-                mylog.error("Output directory exists, but is not a directory: %s." % outputDir)
-                return
-        else:
-            os.mkdir(outputDir)
+        self.__check_directory(outputDir)
 
         center = [0.5 * (self.pf.parameters['DomainLeftEdge'][w] + self.pf.parameters['DomainRightEdge'][w])
                   for w in range(self.pf.parameters['TopGridRank'])]
@@ -218,7 +223,7 @@
         # Create a plot collection.
         pc = raven.PlotCollection(self.pf,center=center)
 
-        for q,halo in enumerate(self.virialQuantities):
+        for halo in self._get_objs('hopHalos', round_robin=True):
             if halo is None:
                 continue
             # Check if region will overlap domain edge.
@@ -229,7 +234,7 @@
                          for w in range(len(halo['center']))]
 
             mylog.info("Projecting halo %04d in region: [%f, %f, %f] to [%f, %f, %f]." %
-                       (q,leftEdge[0],leftEdge[1],leftEdge[2],rightEdge[0],rightEdge[1],rightEdge[2]))
+                       (halo['id'],leftEdge[0],leftEdge[1],leftEdge[2],rightEdge[0],rightEdge[1],rightEdge[2]))
 
             need_per = False
             for w in range(len(halo['center'])):
@@ -256,12 +261,11 @@
 
                 # Set x and y limits, shift image if it overlaps domain boundary.
                 if need_per:
+                    pw = self.haloProfilerParameters['ProjectionWidth']/self.pf.units['mpc']
                     ShiftProjections(self.pf,pc,halo['center'],center,w)
                     # Projection has now been shifted to center of box.
-                    proj_left = [center[x_axis]-0.5 * self.haloProfilerParameters['ProjectionWidth']/self.pf.units['mpc'],
-                                 center[y_axis]-0.5 * self.haloProfilerParameters['ProjectionWidth']/self.pf.units['mpc']]
-                    proj_right = [center[x_axis]+0.5 * self.haloProfilerParameters['ProjectionWidth']/self.pf.units['mpc'],
-                                  center[y_axis]+0.5 * self.haloProfilerParameters['ProjectionWidth']/self.pf.units['mpc']]
+                    proj_left = [center[x_axis]-0.5*pw, center[y_axis]-0.5*pw]
+                    proj_right = [center[x_axis]+0.5*pw, center[y_axis]+0.5*pw]
                 else:
                     proj_left = [leftEdge[x_axis],leftEdge[y_axis]]
                     proj_right = [rightEdge[x_axis],rightEdge[y_axis]]
@@ -272,7 +276,8 @@
                 # Save projection data to hdf5 file.
                 if save_cube:
                     axes = ['x','y','z']
-                    dataFilename = "%s/Halo_%04d_%s_data.h5" % (outputDir,q,axes[w])
+                    dataFilename = "%s/Halo_%04d_%s_data.h5" % \
+                            (outputDir,halo['id'],axes[w])
                     mylog.info("Saving projection data to %s." % dataFilename)
 
                     output = h5.openFile(dataFilename, "a")
@@ -285,7 +290,7 @@
                     output.close()
 
                 if save_images:
-                    pc.save("%s/Halo_%04d" % (outputDir,q))
+                    pc.save("%s/Halo_%04d" % (outputDir,halo['id']))
 
                 pc.clear_plots()
 
@@ -293,18 +298,17 @@
 
         del pc
 
+    @lagos.parallel_root_only
     def _WriteVirialQuantities(self):
         "Write out file with halo centers and virial masses and radii."
         filename = "%s/%s" % (self.pf.fullpath,self.haloProfilerParameters['VirialQuantitiesOutputFile'])
         mylog.info("Writing virial quantities to %s." % filename)
         file = open(filename,'w')
         file.write("#Index\tx\ty\tz\tMass [Msolar]\tRadius [Mpc]\n")
-        for q in range(len(self.virialQuantities)):
-            if (self.virialQuantities[q] is not None):
-                file.write("%04d %.10f %.10f %.10f %.6e %.6e\n" % (q,self.hopHalos[q]['center'][0],self.hopHalos[q]['center'][1],
-                                                                   self.hopHalos[q]['center'][2],
-                                                                   self.virialQuantities[q]['TotalMassMsun'],
-                                                                   self.virialQuantities[q]['RadiusMpc']))
+        for vq in self.virialQuantities:
+            if vq is not None:
+                file.write("%04d %.10f %.10f %.10f %.6e %.6e\n" % (vq['id'],vq['center'][0],vq['center'][1],vq['center'][2],
+                                                                   vq['TotalMassMsun'], vq['RadiusMpc']))
         file.close()
 
     def _AddActualOverdensity(self,profile):
@@ -390,12 +394,13 @@
             line = line.strip()
             if not(line.startswith('#')):
                 onLine = line.split()
+                id = int(onLine[0])
                 mass = float(onLine[1])
                 if (mass >= self.haloProfilerParameters['VirialMassCutoff']):
                     center = [float(onLine[7]),float(onLine[8]),float(onLine[9])]
                     velocity = [float(onLine[10]),float(onLine[11]),float(onLine[12])]
                     r_max = float(onLine[13]) * self.pf.units['mpc']
-                    halo = {'center': center, 'r_max': r_max, 'velocity': velocity}
+                    halo = {'id': id, 'center': center, 'r_max': r_max, 'velocity': velocity}
                     self.hopHalos.append(halo)
 
         mylog.info("Loaded %d halos with total dark matter mass af at least %e Msolar." % 
@@ -518,8 +523,6 @@
                 virial['TotalMassMsun'] = float(onLine[4])
                 virial['RadiusMpc'] = float(onLine[5])
                 if (virial['TotalMassMsun'] >= self.haloProfilerParameters['VirialMassCutoff']):
-                    for q in range(index - len(self.virialQuantities)):
-                        self.virialQuantities.append(None)
                     self.virialQuantities.append(virial)
                     halos += 1
 

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Sat Mar 21 14:20:48 2009
@@ -946,6 +946,7 @@
             self._check_region = check
             #self._okay_to_serialize = (not check)
         else:
+            self._distributed = False
             self._okay_to_serialize = False
             self._check_region = True
         self.source = source

Modified: trunk/yt/lagos/DerivedQuantities.py
==============================================================================
--- trunk/yt/lagos/DerivedQuantities.py	(original)
+++ trunk/yt/lagos/DerivedQuantities.py	Sat Mar 21 14:20:48 2009
@@ -73,7 +73,7 @@
         return self.c_func(self._data_source, *self.retvals)
 
     def _finalize_parallel(self):
-        self.retvals = [self._mpi_catlist(my_list) for my_list in self.retvals]
+        self.retvals = [na.array(self._mpi_catlist(my_list)) for my_list in self.retvals]
         
     def _call_func_unlazy(self, args, kwargs):
         retval = self.func(self._data_source, *args, **kwargs)

Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py	(original)
+++ trunk/yt/lagos/ParallelTools.py	Sat Mar 21 14:20:48 2009
@@ -82,8 +82,9 @@
     This takes an object, pobj, that implements ParallelAnalysisInterface,
     and then does its thing.
     """
-    def __init__(self, pobj, just_list = False, attr='_grids'):
-        ObjectIterator.__init__(self, pobj, just_list)
+    def __init__(self, pobj, just_list = False, attr='_grids',
+                 round_robin=False):
+        ObjectIterator.__init__(self, pobj, just_list, attr=attr)
         self._offset = MPI.COMM_WORLD.rank
         self._skip = MPI.COMM_WORLD.size
         # Note that we're doing this in advance, and with a simple means
@@ -91,8 +92,11 @@
         if self._use_all:
             self.my_obj_ids = na.arange(len(self._objs))
         else:
-            self.my_obj_ids = na.array_split(
-                            na.arange(len(self._objs)), self._skip)[self._offset]
+            if not round_robin:
+                self.my_obj_ids = na.array_split(
+                                na.arange(len(self._objs)), self._skip)[self._offset]
+            else:
+                self.my_obj_ids = na.arange(len(self._objs))[self._offset::self._skip]
         
     def __iter__(self):
         for gid in self.my_obj_ids:
@@ -132,7 +136,7 @@
 def parallel_passthrough(func):
     @wraps(func)
     def passage(self, data):
-        if not parallel_capable: return data
+        if not self._distributed: return data
         return func(self, data)
     return passage
 
@@ -150,24 +154,44 @@
     else:
         return func
 
+def parallel_root_only(func):
+    @wraps(func)
+    def root_only(*args, **kwargs):
+        if MPI.COMM_WORLD.rank == 0:
+            try:
+                func(*args, **kwargs)
+                all_clear = 1
+            except:
+                traceback.print_last()
+                all_clear = 0
+        else:
+            all_clear = None
+        MPI.COMM_WORLD.Barrier()
+        all_clear = MPI.COMM_WORLD.bcast(all_clear, root=0)
+        if not all_clear: raise RuntimeError
+    if parallel_capable: return root_only
+    return func
+
 class ParallelAnalysisInterface(object):
     _grids = None
     _distributed = parallel_capable
 
     def _get_objs(self, attr, *args, **kwargs):
-        if parallel_capable:
+        if self._distributed:
+            rr = kwargs.pop("round_robin", False)
             self._initialize_parallel(*args, **kwargs)
-            return ParallelObjectIterator(self, attr=attr)
+            return ParallelObjectIterator(self, attr=attr,
+                    round_robin=rr)
         return ObjectIterator(self, attr=attr)
 
     def _get_grids(self, *args, **kwargs):
-        if parallel_capable:
+        if self._distributed:
             self._initialize_parallel(*args, **kwargs)
             return ParallelObjectIterator(self, attr='_grids')
         return ObjectIterator(self, attr='_grids')
 
     def _get_grid_objs(self):
-        if parallel_capable:
+        if self._distributed:
             return ParallelObjectIterator(self, True, attr='_grids')
         return ObjectIterator(self, True, attr='_grids')
 
@@ -178,7 +202,7 @@
         pass
 
     def _partition_hierarchy_2d(self, axis):
-        if not parallel_capable:
+        if not self._distributed:
            return False, self.hierarchy.grid_collection(self.center, self.hierarchy.grids)
 
         xax, yax = x_dict[axis], y_dict[axis]
@@ -202,7 +226,7 @@
 
     def _partition_hierarchy_3d(self, padding=0.0):
         LE, RE = self.pf["DomainLeftEdge"], self.pf["DomainRightEdge"]
-        if not parallel_capable:
+        if not self._distributed:
            return False, LE, RE, self.hierarchy.grid_collection(self.center, self.hierarchy.grids)
 
         cc = MPI.Compute_dims(MPI.COMM_WORLD.size, 3)
@@ -222,7 +246,7 @@
         return False, LE, RE, self.hierarchy.region_strict(self.center, LE, RE)
         
     def _barrier(self):
-        if not parallel_capable: return
+        if not self._distributed: return
         mylog.debug("Opening MPI Barrier on %s", MPI.COMM_WORLD.rank)
         MPI.COMM_WORLD.Barrier()
 
@@ -266,7 +290,7 @@
         for i in range(1,MPI.COMM_WORLD.size):
             buf = ensure_list(MPI.COMM_WORLD.recv(source=i, tag=0))
             data += buf
-        return na.array(data)
+        return data
 
     @parallel_passthrough
     def _mpi_catlist(self, data):
@@ -301,13 +325,13 @@
         return data
 
     def _should_i_write(self):
-        if not parallel_capable: return True
+        if not self._distributed: return True
         return (MPI.COMM_WORLD == 0)
 
     def _preload(self, grids, fields, queue):
         # This will preload if it detects we are parallel capable and
         # if so, we load *everything* that we need.  Use with some care.
-        if not parallel_capable: return
+        if not self._distributed: return
         queue.preload(grids, fields)
 
     @parallel_passthrough
@@ -318,8 +342,7 @@
         return MPI.COMM_WORLD.allreduce(data, op=MPI.SUM)
 
     def _mpi_info_dict(self, info):
-        mylog.info("Parallel capable: %s", parallel_capable)
-        if not parallel_capable: return 0, {0:info}
+        if not self._distributed: return 0, {0:info}
         self._barrier()
         data = None
         if MPI.COMM_WORLD.rank == 0:
@@ -341,19 +364,19 @@
         return list(set(deps))
 
     def _claim_object(self, obj):
-        if not parallel_capable: return
+        if not self._distributed: return
         obj._owner = MPI.COMM_WORLD.rank
         obj._distributed = True
 
     def _write_on_root(self, fn):
-        if not parallel_capable: return open(fn, "w")
+        if not self._distributed: return open(fn, "w")
         if MPI.COMM_WORLD.rank == 0:
             return open(fn, "w")
         else:
             return cStringIO.StringIO()
 
     def _get_filename(self, prefix):
-        if not parallel_capable: return prefix
+        if not self._distributed: return prefix
         return "%s_%03i" % (prefix, MPI.COMM_WORLD.rank)
 
     def _is_mine(self, obj):

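For anyone who wants to try this out, a hypothetical driver script (the
dataset and parameter file names below are placeholders; under MPI a yt
script of this era is typically launched with something like
"mpirun -np 4 python script.py --parallel"):

    from yt.extensions.HaloProfiler import HaloProfiler

    # halos='multiple' is the mode that round-robins in parallel;
    # halos='single' will not work in parallel yet (see log message above).
    hp = HaloProfiler("RedshiftOutput0005", "HaloProfilerParameters",
                      halos='multiple', hop_style='new')
    hp.makeProjections(save_images=True, save_cube=True)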

