[yt-svn] commit/yt: 2 new changesets

Bitbucket commits-noreply at bitbucket.org
Tue Mar 20 09:18:09 PDT 2012


2 new commits in yt:


https://bitbucket.org/yt_analysis/yt/changeset/b03443ed6d88/
changeset:   b03443ed6d88
branch:      yt
user:        sskory
date:        2012-03-19 22:09:28
summary:     Modifying the merger tree so that only the root task
ever touches the SQLite database for writes AND reads.
Previously, writes were done only by the root task,
but reads were done by all tasks. This should hopefully
A) speed things up and B) reduce the memory footprint
from the file opens on the database.
affected #:  2 files
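
For reference, the root-only access pattern this changeset moves to looks
roughly like the sketch below. This is not yt code: it uses plain mpi4py
and sqlite3 rather than yt's ParallelAnalysisInterface wrapper, and the
database name and query are simplified stand-ins for the ones in the diff.

# Sketch: only rank 0 opens the SQLite database; the query result is then
# broadcast (pickled) so every task sees the same answer without opening
# the file itself.  Assumes mpi4py is installed and a halos.db file with a
# Halos table exists; run with e.g. `mpirun -np 4 python sketch.py`.
import sqlite3
from mpi4py import MPI

comm = MPI.COMM_WORLD

result = None
if comm.rank == 0:
    conn = sqlite3.connect("halos.db")
    cursor = conn.cursor()
    cursor.execute("SELECT GlobalHaloID FROM Halos WHERE SnapHaloID=0;")
    result = cursor.fetchone()
    conn.close()

# Pickled broadcast from root; mirrors comm.mpi_bcast_pickled() in the diff.
result = comm.bcast(result, root=0)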

diff -r d82e3fe882046290ae7a21d1f2fb1272d500ddd5 -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 yt/analysis_modules/halo_merger_tree/merger_tree.py
--- a/yt/analysis_modules/halo_merger_tree/merger_tree.py
+++ b/yt/analysis_modules/halo_merger_tree/merger_tree.py
@@ -109,7 +109,7 @@
 class MergerTree(DatabaseFunctions, ParallelAnalysisInterface):
     def __init__(self, restart_files=[], database='halos.db',
             halo_finder_function=HaloFinder, halo_finder_threshold=80.0,
-            FOF_link_length=0.2, dm_only=False, refresh=False, sleep=1,
+            FOF_link_length=0.2, dm_only=False, refresh=False,
             index=True):
         r"""Build a merger tree of halos over a time-ordered set of snapshots.
         This will run a halo finder to find the halos first if it hasn't already
@@ -140,12 +140,6 @@
         refresh : Boolean
             True forces the halo finder to run even if the halo data has been
             detected on disk. Default = False.
-        sleep : Float
-            Due to the nature of the SQLite database and network file systems,
-            it is crucial that all tasks see the database in the same state at
-            all times. This parameter specifies how long in seconds the merger
-            tree waits between checks to ensure the database is synched across
-            all tasks. Default = 1.
         index : Boolean
             SQLite databases can have added to them an index which greatly
             speeds up future queries of the database,
@@ -168,9 +162,6 @@
         self.FOF_link_length= FOF_link_length # For FOF
         self.dm_only = dm_only
         self.refresh = refresh
-        self.sleep = sleep # How long to wait between db sync checks.
-        if self.sleep <= 0.:
-            self.sleep = 5
         # MPI stuff
         self.mine = self.comm.rank
         if self.mine is None:
@@ -184,7 +175,6 @@
                 os.unlink(self.database)
             except:
                 pass
-        self.comm.barrier()
         self._open_create_database()
         self._create_halo_table()
         self._run_halo_finder_add_to_db()
@@ -247,12 +237,17 @@
                 del halos
             # Now add halo data to the db if it isn't already there by
             # checking the first halo.
-            currt = pf.unique_identifier
-            line = "SELECT GlobalHaloID from Halos where SnapHaloID=0\
-            and SnapCurrentTimeIdentifier=%d;" % currt
-            self.cursor.execute(line)
-            result = self.cursor.fetchone()
-            if result != None:
+            continue_check = False
+            if self.mine == 0:
+                currt = pf.unique_identifier
+                line = "SELECT GlobalHaloID from Halos where SnapHaloID=0\
+                and SnapCurrentTimeIdentifier=%d;" % currt
+                self.cursor.execute(line)
+                result = self.cursor.fetchone()
+                if result != None:
+                    continue_check = True
+            continue_check = self.comm.mpi_bcast_pickled(continue_check)
+            if continue_check:
                 continue
             red = pf.current_redshift
             # Read the halos off the disk using the Halo Profiler tools.
@@ -261,6 +256,7 @@
             if len(hp.all_halos) == 0:
                 mylog.info("Dataset %s has no halos." % file)
                 self.with_halos[cycle] = False
+                del hp
                 continue
             mylog.info("Entering halos into database for z=%f" % red)
             if self.mine == 0:
@@ -284,43 +280,10 @@
     
     def _open_create_database(self):
         # open the database. This creates the database file on disk if it
-        # doesn't already exist. Open it first on root, and then on the others.
+        # doesn't already exist. Open it on root only.
         if self.mine == 0:
             self.conn = sql.connect(self.database)
-        self.comm.barrier()
-        self._ensure_db_sync()
-        if self.mine != 0:
-            self.conn = sql.connect(self.database)
-        self.cursor = self.conn.cursor()
-
-    def _ensure_db_sync(self):
-        # If the database becomes out of sync for each task, ostensibly due to
-        # parallel file system funniness, things will go bad very quickly.
-        # Therefore, just to be very, very careful, we will ensure that the
-        # md5 hash of the file is identical across all tasks before proceeding.
-        self.comm.barrier()
-        for i in range(5):
-            try:
-                file = open(self.database)
-            except IOError:
-                # This is to give a little bit of time for the database creation
-                # to replicate across the file system.
-                time.sleep(self.sleep)
-                file = open(self.database)
-            hash = md5.md5(file.read()).hexdigest()
-            file.close()
-            ignore, hashes = self.comm.mpi_info_dict(hash)
-            hashes = set(hashes.values())
-            if len(hashes) == 1:
-                break
-            else:
-                # Wait a little bit for the file system to (hopefully) sync up.
-                time.sleep(self.sleep)
-        if len(hashes) == 1:
-            return
-        else:
-            mylog.error("The file system is not properly synchronizing the database.")
-            raise RunTimeError("Fatal error. Exiting.")
+            self.cursor = self.conn.cursor()
 
     def _create_halo_table(self):
         if self.mine == 0:
@@ -342,69 +305,74 @@
                 self.conn.commit()
             except sql.OperationalError:
                 pass
-        self.comm.barrier()
     
     def _find_likely_children(self, parentfile, childfile):
         # For each halo in the parent list, identify likely children in the 
         # list of children.
-        
+
         # First, read in the locations of the child halos.
         child_pf = load(childfile)
         child_t = child_pf.unique_identifier
-        line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
-        Halos WHERE SnapCurrentTimeIdentifier = %d" % child_t
-        self.cursor.execute(line)
-        
-        mylog.info("Finding likely parents for z=%1.5f child halos." % \
-            child_pf.current_redshift)
-        
-        # Build the kdtree for the children by looping over the fetched rows.
-        # Normalize the points for use only within the kdtree.
-        child_points = []
-        for row in self.cursor:
-            child_points.append([row[1] / self.period[0],
-            row[2] / self.period[1],
-            row[3] / self.period[2]])
-        # Turn it into fortran.
-        child_points = na.array(child_points)
-        fKD.pos = na.asfortranarray(child_points.T)
-        fKD.qv = na.empty(3, dtype='float64')
-        fKD.dist = na.empty(NumNeighbors, dtype='float64')
-        fKD.tags = na.empty(NumNeighbors, dtype='int64')
-        fKD.nn = NumNeighbors
-        fKD.sort = True
-        fKD.rearrange = True
-        create_tree(0)
-
+        if self.mine == 0:
+            line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
+            Halos WHERE SnapCurrentTimeIdentifier = %d" % child_t
+            self.cursor.execute(line)
+            
+            mylog.info("Finding likely parents for z=%1.5f child halos." % \
+                child_pf.current_redshift)
+            
+            # Build the kdtree for the children by looping over the fetched rows.
+            # Normalize the points for use only within the kdtree.
+            child_points = []
+            for row in self.cursor:
+                child_points.append([row[1] / self.period[0],
+                row[2] / self.period[1],
+                row[3] / self.period[2]])
+            # Turn it into fortran.
+            child_points = na.array(child_points)
+            fKD.pos = na.asfortranarray(child_points.T)
+            fKD.qv = na.empty(3, dtype='float64')
+            fKD.dist = na.empty(NumNeighbors, dtype='float64')
+            fKD.tags = na.empty(NumNeighbors, dtype='int64')
+            fKD.nn = NumNeighbors
+            fKD.sort = True
+            fKD.rearrange = True
+            create_tree(0)
+    
         # Find the parent points from the database.
         parent_pf = load(parentfile)
         parent_t = parent_pf.unique_identifier
-        line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
-        Halos WHERE SnapCurrentTimeIdentifier = %d" % parent_t
-        self.cursor.execute(line)
+        if self.mine == 0:
+            line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
+            Halos WHERE SnapCurrentTimeIdentifier = %d" % parent_t
+            self.cursor.execute(line)
+    
+            # Loop over the returned rows, and find the likely neighbors for the
+            # parents.
+            candidates = {}
+            for row in self.cursor:
+                # Normalize positions for use within the kdtree.
+                fKD.qv = na.array([row[1] / self.period[0],
+                row[2] / self.period[1],
+                row[3] / self.period[2]])
+                find_nn_nearest_neighbors()
+                NNtags = fKD.tags[:] - 1
+                nIDs = []
+                for n in NNtags:
+                    nIDs.append(n)
+                # We need to fill in fake halos if there aren't enough halos,
+                # which can happen at high redshifts.
+                while len(nIDs) < NumNeighbors:
+                    nIDs.append(-1)
+                candidates[row[0]] = nIDs
+            
+            del fKD.pos, fKD.tags, fKD.dist
+            free_tree(0) # Frees the kdtree object.
+        else:
+            candidates = None
 
-        # Loop over the returned rows, and find the likely neighbors for the
-        # parents.
-        candidates = {}
-        for row in self.cursor:
-            # Normalize positions for use within the kdtree.
-            fKD.qv = na.array([row[1] / self.period[0],
-            row[2] / self.period[1],
-            row[3] / self.period[2]])
-            find_nn_nearest_neighbors()
-            NNtags = fKD.tags[:] - 1
-            nIDs = []
-            for n in NNtags:
-                nIDs.append(n)
-            # We need to fill in fake halos if there aren't enough halos,
-            # which can happen at high redshifts.
-            while len(nIDs) < NumNeighbors:
-                nIDs.append(-1)
-            candidates[row[0]] = nIDs
-        
-        del fKD.pos, fKD.tags, fKD.dist
-        free_tree(0) # Frees the kdtree object.
-        
+        # Sync across tasks.
+        candidates = self.comm.mpi_bcast_pickled(candidates)
         self.candidates = candidates
         
         # This stores the masses contributed to each child candidate.
@@ -613,24 +581,31 @@
         # Now we sum up the contributions globally.
         self.child_mass_arr = self.comm.mpi_allreduce(self.child_mass_arr)
         
-        # Turn these Msol masses into percentages of the parent.
-        line = "SELECT HaloMass FROM Halos WHERE SnapCurrentTimeIdentifier=%d \
-        ORDER BY SnapHaloID ASC;" % parent_currt
-        self.cursor.execute(line)
-        mark = 0
-        result = self.cursor.fetchone()
-        while result:
-            mass = result[0]
-            self.child_mass_arr[mark:mark+NumNeighbors] /= mass
-            mark += NumNeighbors
+        if self.mine == 0:
+            # Turn these Msol masses into percentages of the parent.
+            line = "SELECT HaloMass FROM Halos WHERE SnapCurrentTimeIdentifier=%d \
+            ORDER BY SnapHaloID ASC;" % parent_currt
+            self.cursor.execute(line)
+            mark = 0
             result = self.cursor.fetchone()
+            while result:
+                mass = result[0]
+                self.child_mass_arr[mark:mark+NumNeighbors] /= mass
+                mark += NumNeighbors
+                result = self.cursor.fetchone()
+            
+            # Get the global ID for the SnapHaloID=0 from the child, this will
+            # be used to prevent unnecessary SQL reads.
+            line = "SELECT GlobalHaloID FROM Halos WHERE SnapCurrentTimeIdentifier=%d \
+            AND SnapHaloID=0;" % child_currt
+            self.cursor.execute(line)
+            baseChildID = self.cursor.fetchone()[0]
+        else:
+            baseChildID = None
         
-        # Get the global ID for the SnapHaloID=0 from the child, this will
-        # be used to prevent unnecessary SQL reads.
-        line = "SELECT GlobalHaloID FROM Halos WHERE SnapCurrentTimeIdentifier=%d \
-        AND SnapHaloID=0;" % child_currt
-        self.cursor.execute(line)
-        baseChildID = self.cursor.fetchone()[0]
+        # Sync up data on all tasks.
+        self.child_mass_arr = self.comm.mpi_Bcast_array(self.child_mass_arr)
+        baseChildID = self.comm.mpi_bcast_pickled(baseChildID)
         
         # Now we prepare a big list of writes to put in the database.
         for i,parent_halo in enumerate(sorted(self.candidates)):
@@ -663,6 +638,7 @@
         del parent_IDs, parent_masses, parent_halos
         del parent_IDs_tosend, parent_masses_tosend
         del parent_halos_tosend, child_IDs_tosend, child_halos_tosend
+        gc.collect()
         
         return (child_IDs, child_masses, child_halos)
 
@@ -729,7 +705,8 @@
             temp_cursor.execute(line)
             temp_cursor.close()
             temp_conn.close()
-        self._close_database()
+        if self.mine == 0:
+            self._close_database()
         self.comm.barrier()
         if self.mine == 0:
             os.rename(temp_name, self.database)


diff -r d82e3fe882046290ae7a21d1f2fb1272d500ddd5 -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 yt/utilities/parallel_tools/parallel_analysis_interface.py
--- a/yt/utilities/parallel_tools/parallel_analysis_interface.py
+++ b/yt/utilities/parallel_tools/parallel_analysis_interface.py
@@ -507,6 +507,19 @@
         data = self.comm.bcast(data, root=0)
         return data
 
+    @parallel_passthrough
+    def mpi_Bcast_array(self, data):
+        if self.comm.rank == 0:
+            info = (data.shape, data.dtype)
+        else:
+            info = ()
+        info = self.comm.bcast(info, root=0)
+        if self.comm.rank != 0:
+            data = na.empty(info[0], dtype=info[1])
+        mpi_type = get_mpi_type(info[1])
+        self.comm.Bcast([data, mpi_type], root=0)
+        return data
+
     def preload(self, grids, fields, io_handler):
         # This will preload if it detects we are parallel capable and
         # if so, we load *everything* that we need.  Use with some care.



https://bitbucket.org/yt_analysis/yt/changeset/fcc157c8cf79/
changeset:   fcc157c8cf79
branch:      yt
user:        sskory
date:        2012-03-20 17:08:27
summary:     There is now only one mpi_bcast, which auto-senses whether it
should use pickled or non-pickled (buffer-based) communication.

Removed self.mine/self.size in favor of self.comm.rank/self.comm.size.

Combined one bit of mostly repeated code into a new _match function.
affected #:  6 files
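
The auto-sensing broadcast is sketched standalone below. This is not the
yt implementation: it relies on mpi4py's automatic MPI-type discovery for
NumPy buffers, whereas the yt method in the diff maps dtypes explicitly
through get_mpi_type(); the call sites at the end are illustrative only.

# Sketch: one broadcast entry point that uses the fast buffer-based Bcast
# for NumPy arrays and falls back to the pickled bcast for everything else.
# NB: as in yt's call sites, every rank must pass the same kind of object.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

def mpi_bcast(data, root=0):
    if isinstance(data, np.ndarray):
        # Non-root ranks need the shape and dtype before allocating a buffer.
        info = comm.bcast(
            (data.shape, data.dtype) if comm.rank == root else None, root=root)
        if comm.rank != root:
            data = np.empty(info[0], dtype=info[1])
        comm.Bcast(data, root=root)  # mpi4py infers the MPI type from the buffer
        return data
    # Generic Python objects (dicts, tuples, None, ...) go through pickling.
    return comm.bcast(data, root=root)

# Array path: the contents only matter on root, but every rank passes an array.
arr = mpi_bcast(np.arange(5, dtype="float64") if comm.rank == 0
                else np.empty(5, dtype="float64"))
# Pickled path: a plain Python object on every rank.
status = mpi_bcast({"stop": False} if comm.rank == 0 else None)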

diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/analysis_modules/halo_finding/halo_objects.py
--- a/yt/analysis_modules/halo_finding/halo_objects.py
+++ b/yt/analysis_modules/halo_finding/halo_objects.py
@@ -1580,7 +1580,7 @@
             if self.comm.rank == 0:
                 self._recursive_divide(root_points, topbounds, 0, cut_list)
             self.bucket_bounds = \
-                self.comm.mpi_bcast_pickled(self.bucket_bounds)
+                self.comm.mpi_bcast(self.bucket_bounds)
             my_bounds = self.bucket_bounds[self.comm.rank]
             LE, RE = my_bounds[0], my_bounds[1]
             self._data_source = self.hierarchy.region_strict([0.] * 3, LE, RE)


diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/analysis_modules/halo_finding/rockstar/rockstar.py
--- a/yt/analysis_modules/halo_finding/rockstar/rockstar.py
+++ b/yt/analysis_modules/halo_finding/rockstar/rockstar.py
@@ -76,7 +76,7 @@
             del sock
         else:
             server_address, port = None, None
-        self.server_address, self.port = self.comm.mpi_bcast_pickled(
+        self.server_address, self.port = self.comm.mpi_bcast(
             (server_address, port))
         self.port = str(self.port)
 


diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/analysis_modules/halo_merger_tree/merger_tree.py
--- a/yt/analysis_modules/halo_merger_tree/merger_tree.py
+++ b/yt/analysis_modules/halo_merger_tree/merger_tree.py
@@ -163,14 +163,12 @@
         self.dm_only = dm_only
         self.refresh = refresh
         # MPI stuff
-        self.mine = self.comm.rank
-        if self.mine is None:
-            self.mine = 0
-        self.size = self.comm.size
-        if self.size is None:
-            self.size = 1
+        if self.comm.rank is None:
+            self.comm.rank = 0
+        if self.comm.size is None:
+            self.comm.size = 1
         # Get to work.
-        if self.refresh and self.mine == 0:
+        if self.refresh and self.comm.rank == 0:
             try:
                 os.unlink(self.database)
             except:
@@ -196,7 +194,9 @@
         del last
         # Now update the database with all the writes.
         mylog.info("Updating database with parent-child relationships.")
-        self._copy_and_update_db()
+        if self.comm.rank == 0:
+            self._copy_and_update_db()
+        self.comm.barrier()
         self.comm.barrier()
         mylog.info("Done!")
         
@@ -238,7 +238,7 @@
             # Now add halo data to the db if it isn't already there by
             # checking the first halo.
             continue_check = False
-            if self.mine == 0:
+            if self.comm.rank == 0:
                 currt = pf.unique_identifier
                 line = "SELECT GlobalHaloID from Halos where SnapHaloID=0\
                 and SnapCurrentTimeIdentifier=%d;" % currt
@@ -246,7 +246,7 @@
                 result = self.cursor.fetchone()
                 if result != None:
                     continue_check = True
-            continue_check = self.comm.mpi_bcast_pickled(continue_check)
+            continue_check = self.comm.mpi_bcast(continue_check)
             if continue_check:
                 continue
             red = pf.current_redshift
@@ -259,7 +259,7 @@
                 del hp
                 continue
             mylog.info("Entering halos into database for z=%f" % red)
-            if self.mine == 0:
+            if self.comm.rank == 0:
                 for ID,halo in enumerate(hp.all_halos):
                     numpart = int(halo['numpart'])
                     values = (None, currt, red, ID, halo['mass'], numpart,
@@ -281,12 +281,12 @@
     def _open_create_database(self):
         # open the database. This creates the database file on disk if it
         # doesn't already exist. Open it on root only.
-        if self.mine == 0:
+        if self.comm.rank == 0:
             self.conn = sql.connect(self.database)
             self.cursor = self.conn.cursor()
 
     def _create_halo_table(self):
-        if self.mine == 0:
+        if self.comm.rank == 0:
             # Handle the error if it already exists.
             try:
                 # Create the table that will store the halo data.
@@ -313,7 +313,7 @@
         # First, read in the locations of the child halos.
         child_pf = load(childfile)
         child_t = child_pf.unique_identifier
-        if self.mine == 0:
+        if self.comm.rank == 0:
             line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
             Halos WHERE SnapCurrentTimeIdentifier = %d" % child_t
             self.cursor.execute(line)
@@ -342,7 +342,7 @@
         # Find the parent points from the database.
         parent_pf = load(parentfile)
         parent_t = parent_pf.unique_identifier
-        if self.mine == 0:
+        if self.comm.rank == 0:
             line = "SELECT SnapHaloID, CenMassX, CenMassY, CenMassZ FROM \
             Halos WHERE SnapCurrentTimeIdentifier = %d" % parent_t
             self.cursor.execute(line)
@@ -372,7 +372,7 @@
             candidates = None
 
         # Sync across tasks.
-        candidates = self.comm.mpi_bcast_pickled(candidates)
+        candidates = self.comm.mpi_bcast(candidates)
         self.candidates = candidates
         
         # This stores the masses contributed to each child candidate.
@@ -425,7 +425,7 @@
             parent_masses = na.array([], dtype='float64')
             parent_halos = na.array([], dtype='int32')
             for i,pname in enumerate(parent_names):
-                if i>=self.mine and i%self.size==self.mine:
+                if i>=self.comm.rank and i%self.comm.size==self.comm.rank:
                     h5fp = h5py.File(pname)
                     for group in h5fp:
                         gID = int(group[4:])
@@ -457,7 +457,7 @@
         child_masses = na.array([], dtype='float64')
         child_halos = na.array([], dtype='int32')
         for i,cname in enumerate(child_names):
-            if i>=self.mine and i%self.size==self.mine:
+            if i>=self.comm.rank and i%self.comm.size==self.comm.rank:
                 h5fp = h5py.File(cname)
                 for group in h5fp:
                     gID = int(group[4:])
@@ -478,39 +478,9 @@
         child_send = na.ones(child_IDs.size, dtype='bool')
         del sort
         
-        # Parent IDs on the left, child IDs on the right. We skip down both
-        # columns matching IDs. If they are out of synch, the index(es) is/are
-        # advanced until they match up again.
-        left = 0
-        right = 0
-        while left < parent_IDs.size and right < child_IDs.size:
-            if parent_IDs[left] == child_IDs[right]:
-                # They match up, add this relationship.
-                try:
-                    loc = self.child_mass_loc[parent_halos[left]][child_halos[right]]
-                except KeyError:
-                    # This happens when a child halo contains a particle from
-                    # a parent halo, but the child is not identified as a 
-                    # candidate child halo. So we do nothing and move on with
-                    # our lives.
-                    left += 1
-                    right += 1
-                    continue
-                self.child_mass_arr[loc] += parent_masses[left]
-                # Mark this pair so we don't send them later.
-                parent_send[left] = False
-                child_send[right] = False
-                left += 1
-                right += 1
-                continue
-            if parent_IDs[left] < child_IDs[right]:
-                # The left is too small, so we need to increase it.
-                left += 1
-                continue
-            if parent_IDs[left] > child_IDs[right]:
-                # Right too small.
-                right += 1
-                continue
+        # Match particles in halos.
+        self._match(parent_IDs, child_IDs, parent_halos, child_halos,
+        parent_masses, parent_send, child_send)
 
         # Now we send all the un-matched particles to the root task for one more
         # pass. This depends on the assumption that most of the particles do
@@ -544,44 +514,15 @@
         child_halos_tosend = child_halos_tosend[Csort]
         del Psort, Csort
 
-        # Now Again.
-        if self.mine == 0:
-            matched = 0
-            left = 0
-            right = 0
-            while left < parent_IDs_tosend.size and right < child_IDs_tosend.size:
-                if parent_IDs_tosend[left] == child_IDs_tosend[right]:
-                    # They match up, add this relationship.
-                    try:
-                        loc = self.child_mass_loc[parent_halos_tosend[left]][child_halos_tosend[right]]
-                    except KeyError:
-                        # This happens when a child halo contains a particle from
-                        # a parent halo, but the child is not identified as a 
-                        # candidate child halo. So we do nothing and move on with
-                        # our lives.
-                        left += 1
-                        right += 1
-                        continue
-                    self.child_mass_arr[loc] += parent_masses_tosend[left]
-                    matched += 1
-                    left += 1
-                    right += 1
-                    continue
-                if parent_IDs_tosend[left] < child_IDs_tosend[right]:
-                    # The left is too small, so we need to increase it.
-                    left += 1
-                    continue
-                if parent_IDs_tosend[left] > child_IDs_tosend[right]:
-                    # Right too small.
-                    right += 1
-                    continue
-            mylog.info("Clean-up round matched %d of %d parents and %d children." % \
-            (matched, parent_IDs_tosend.size, child_IDs_tosend.size))
+        # Now again, but only on the root task.
+        if self.comm.rank == 0:
+            self._match(parent_IDs_tosend, child_IDs_tosend,
+            parent_halos_tosend, child_halos_tosend, parent_masses_tosend)
 
         # Now we sum up the contributions globally.
         self.child_mass_arr = self.comm.mpi_allreduce(self.child_mass_arr)
         
-        if self.mine == 0:
+        if self.comm.rank == 0:
             # Turn these Msol masses into percentages of the parent.
             line = "SELECT HaloMass FROM Halos WHERE SnapCurrentTimeIdentifier=%d \
             ORDER BY SnapHaloID ASC;" % parent_currt
@@ -604,8 +545,8 @@
             baseChildID = None
         
         # Sync up data on all tasks.
-        self.child_mass_arr = self.comm.mpi_Bcast_array(self.child_mass_arr)
-        baseChildID = self.comm.mpi_bcast_pickled(baseChildID)
+        self.child_mass_arr = self.comm.mpi_bcast(self.child_mass_arr)
+        baseChildID = self.comm.mpi_bcast(baseChildID)
         
         # Now we prepare a big list of writes to put in the database.
         for i,parent_halo in enumerate(sorted(self.candidates)):
@@ -642,74 +583,113 @@
         
         return (child_IDs, child_masses, child_halos)
 
+    def _match(self, parent_IDs, child_IDs, parent_halos, child_halos,
+            parent_masses, parent_send = None, child_send = None):
+        # Parent IDs on the left, child IDs on the right. We skip down both
+        # columns matching IDs. If they are out of synch, the index(es) is/are
+        # advanced until they match up again.
+        matched = 0
+        left = 0
+        right = 0
+        while left < parent_IDs.size and right < child_IDs.size:
+            if parent_IDs[left] == child_IDs[right]:
+                # They match up, add this relationship.
+                try:
+                    loc = self.child_mass_loc[parent_halos[left]][child_halos[right]]
+                except KeyError:
+                    # This happens when a child halo contains a particle from
+                    # a parent halo, but the child is not identified as a 
+                    # candidate child halo. So we do nothing and move on with
+                    # our lives.
+                    left += 1
+                    right += 1
+                    continue
+                self.child_mass_arr[loc] += parent_masses[left]
+                # If needed, mark this pair so we don't send them later.
+                if parent_send is not None:
+                    parent_send[left] = False
+                    child_send[right] = False
+                matched += 1
+                left += 1
+                right += 1
+                continue
+            if parent_IDs[left] < child_IDs[right]:
+                # The left is too small, so we need to increase it.
+                left += 1
+                continue
+            if parent_IDs[left] > child_IDs[right]:
+                # Right too small.
+                right += 1
+                continue
+        if parent_send is None:
+            mylog.info("Clean-up round matched %d of %d parents and %d children." % \
+            (matched, parent_IDs.size, child_IDs.size))
+
     def _copy_and_update_db(self):
         """
         Because doing an UPDATE of a SQLite database is really slow, what we'll
         do here is basically read in lines from the database, and then insert
         the parent-child relationships, writing to a new DB.
         """
+        # All of this happens only on the root task!
         temp_name = self.database + '-tmp'
-        if self.mine == 0:
-            to_write = []
-            # Open the temporary database.
+        to_write = []
+        # Open the temporary database.
+        try:
+            os.remove(temp_name)
+        except OSError:
+            pass
+        temp_conn = sql.connect(temp_name)
+        temp_cursor = temp_conn.cursor()
+        line = "CREATE TABLE Halos (GlobalHaloID INTEGER PRIMARY KEY,\
+                SnapCurrentTimeIdentifier INTEGER, SnapZ FLOAT, SnapHaloID INTEGER, \
+                HaloMass FLOAT,\
+                NumPart INTEGER, CenMassX FLOAT, CenMassY FLOAT,\
+                CenMassZ FLOAT, BulkVelX FLOAT, BulkVelY FLOAT, BulkVelZ FLOAT,\
+                MaxRad FLOAT,\
+                ChildHaloID0 INTEGER, ChildHaloFrac0 FLOAT, \
+                ChildHaloID1 INTEGER, ChildHaloFrac1 FLOAT, \
+                ChildHaloID2 INTEGER, ChildHaloFrac2 FLOAT, \
+                ChildHaloID3 INTEGER, ChildHaloFrac3 FLOAT, \
+                ChildHaloID4 INTEGER, ChildHaloFrac4 FLOAT);"
+        temp_cursor.execute(line)
+        temp_conn.commit()
+        # Get all the data!
+        self.cursor.execute("SELECT * FROM Halos;")
+        results = self.cursor.fetchone()
+        while results:
+            results = list(results)
+            currt = results[1]
+            hid = results[3]
+            # If for some reason this halo doesn't have relationships,
+            # we'll just keep the old results the same.
             try:
-                os.remove(temp_name)
-            except OSError:
-                pass
-            temp_conn = sql.connect(temp_name)
-            temp_cursor = temp_conn.cursor()
-            line = "CREATE TABLE Halos (GlobalHaloID INTEGER PRIMARY KEY,\
-                    SnapCurrentTimeIdentifier INTEGER, SnapZ FLOAT, SnapHaloID INTEGER, \
-                    HaloMass FLOAT,\
-                    NumPart INTEGER, CenMassX FLOAT, CenMassY FLOAT,\
-                    CenMassZ FLOAT, BulkVelX FLOAT, BulkVelY FLOAT, BulkVelZ FLOAT,\
-                    MaxRad FLOAT,\
-                    ChildHaloID0 INTEGER, ChildHaloFrac0 FLOAT, \
-                    ChildHaloID1 INTEGER, ChildHaloFrac1 FLOAT, \
-                    ChildHaloID2 INTEGER, ChildHaloFrac2 FLOAT, \
-                    ChildHaloID3 INTEGER, ChildHaloFrac3 FLOAT, \
-                    ChildHaloID4 INTEGER, ChildHaloFrac4 FLOAT);"
-            temp_cursor.execute(line)
-            temp_conn.commit()
-            # Get all the data!
-            self.cursor.execute("SELECT * FROM Halos;")
+                lookup = self.write_values_dict[currt][hid]
+                new = tuple(results[:-10] + lookup)
+            except KeyError:
+                new = tuple(results)
+            to_write.append(new)
             results = self.cursor.fetchone()
-            while results:
-                results = list(results)
-                currt = results[1]
-                hid = results[3]
-                # If for some reason this halo doesn't have relationships,
-                # we'll just keep the old results the same.
-                try:
-                    lookup = self.write_values_dict[currt][hid]
-                    new = tuple(results[:-10] + lookup)
-                except KeyError:
-                    new = tuple(results)
-                to_write.append(new)
-                results = self.cursor.fetchone()
-            # Now write to the temp database.
-            # 23 question marks for 23 data columns.
-            line = ''
-            for i in range(23):
-                line += '?,'
-            # Pull off the last comma.
-            line = 'INSERT into Halos VALUES (' + line[:-1] + ')'
-            for insert in to_write:
-                temp_cursor.execute(line, insert)
-            temp_conn.commit()
-            mylog.info("Creating database index.")
-            line = "CREATE INDEX IF NOT EXISTS HalosIndex ON Halos ("
-            for name in columns:
-                line += name +","
-            line = line[:-1] + ");"
-            temp_cursor.execute(line)
-            temp_cursor.close()
-            temp_conn.close()
-        if self.mine == 0:
-            self._close_database()
-        self.comm.barrier()
-        if self.mine == 0:
-            os.rename(temp_name, self.database)
+        # Now write to the temp database.
+        # 23 question marks for 23 data columns.
+        line = ''
+        for i in range(23):
+            line += '?,'
+        # Pull off the last comma.
+        line = 'INSERT into Halos VALUES (' + line[:-1] + ')'
+        for insert in to_write:
+            temp_cursor.execute(line, insert)
+        temp_conn.commit()
+        mylog.info("Creating database index.")
+        line = "CREATE INDEX IF NOT EXISTS HalosIndex ON Halos ("
+        for name in columns:
+            line += name +","
+        line = line[:-1] + ");"
+        temp_cursor.execute(line)
+        temp_cursor.close()
+        temp_conn.close()
+        self._close_database()
+        os.rename(temp_name, self.database)
 
 class MergerTreeConnect(DatabaseFunctions):
     def __init__(self, database='halos.db'):


diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/analysis_modules/two_point_functions/two_point_functions.py
--- a/yt/analysis_modules/two_point_functions/two_point_functions.py
+++ b/yt/analysis_modules/two_point_functions/two_point_functions.py
@@ -403,7 +403,7 @@
             status = 0
         # Broadcast the status from root - we stop only if root thinks we should
         # stop.
-        status = self.comm.mpi_bcast_pickled(status)
+        status = self.comm.mpi_bcast(status)
         if status == 0: return True
         if self.comm_cycle_count < status:
             return True


diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/frontends/enzo/data_structures.py
--- a/yt/frontends/enzo/data_structures.py
+++ b/yt/frontends/enzo/data_structures.py
@@ -462,7 +462,7 @@
                     field_list = field_list.union(gf)
         else:
             field_list = None
-        field_list = self.comm.mpi_bcast_pickled(field_list)
+        field_list = self.comm.mpi_bcast(field_list)
         self.save_data(list(field_list),"/","DataFields",passthrough=True)
         self.field_list = list(field_list)
 


diff -r b03443ed6d88bb3a2d6b8d7af272f393123777e6 -r fcc157c8cf796edc172191549830973eeba3f270 yt/utilities/parallel_tools/parallel_analysis_interface.py
--- a/yt/utilities/parallel_tools/parallel_analysis_interface.py
+++ b/yt/utilities/parallel_tools/parallel_analysis_interface.py
@@ -259,7 +259,7 @@
                 all_clear = 0
         else:
             all_clear = None
-        all_clear = comm.mpi_bcast_pickled(all_clear)
+        all_clear = comm.mpi_bcast(all_clear)
         if not all_clear: raise RuntimeError
     if parallel_capable: return root_only
     return func
@@ -503,22 +503,25 @@
         raise NotImplementedError
 
     @parallel_passthrough
-    def mpi_bcast_pickled(self, data):
-        data = self.comm.bcast(data, root=0)
-        return data
-
-    @parallel_passthrough
-    def mpi_Bcast_array(self, data):
-        if self.comm.rank == 0:
-            info = (data.shape, data.dtype)
+    def mpi_bcast(self, data):
+        # The second check below makes sure that we know how to communicate
+        # this type of array. Otherwise, we'll pickle it.
+        if isinstance(data, na.ndarray) and \
+                get_mpi_type(data.dtype) is not None:
+            if self.comm.rank == 0:
+                info = (data.shape, data.dtype)
+            else:
+                info = ()
+            info = self.comm.bcast(info, root=0)
+            if self.comm.rank != 0:
+                data = na.empty(info[0], dtype=info[1])
+            mpi_type = get_mpi_type(info[1])
+            self.comm.Bcast([data, mpi_type], root = 0)
+            return data
         else:
-            info = ()
-        info = self.comm.bcast(info, root=0)
-        if self.comm.rank != 0:
-            data = na.empty(info[0], dtype=info[1])
-        mpi_type = get_mpi_type(info[1])
-        self.comm.Bcast([data, mpi_type], root=0)
-        return data
+            # Use pickled methods.
+            data = self.comm.bcast(data, root = 0)
+            return data
 
     def preload(self, grids, fields, io_handler):
         # This will preload if it detects we are parallel capable and

Repository URL: https://bitbucket.org/yt_analysis/yt/



