[Yt-svn] yt-commit r1713 - trunk/yt/extensions

sskory at wrangler.dreamhost.com sskory at wrangler.dreamhost.com
Fri Apr 30 11:03:12 PDT 2010


Author: sskory
Date: Fri Apr 30 11:03:12 2010
New Revision: 1713
URL: http://yt.enzotools.org/changeset/1713

Log:
A modification to the Merger Tree that substantially speeds up the database writing steps.

Modified:
   trunk/yt/extensions/MergerTree.py

Modified: trunk/yt/extensions/MergerTree.py
==============================================================================
--- trunk/yt/extensions/MergerTree.py	(original)
+++ trunk/yt/extensions/MergerTree.py	Fri Apr 30 11:03:12 2010
@@ -127,6 +127,8 @@
         # Loop over the pairs of snapshots to locate likely neighbors, and
         # then use those likely neighbors to compute fractional contributions.
         last = None
+        self.write_values = []
+        self.write_values_dict = defaultdict(dict)
         for snap, pair in enumerate(zip(self.restart_files[:-1], self.restart_files[1:])):
             if not self.with_halos[snap] or not self.with_halos[snap+1]:
                 continue
@@ -135,15 +137,10 @@
             # as the child from the previous round for all but the first loop.
             last = self._compute_child_fraction(pair[0], pair[1], last)
         del last
-        if self.mine == 0 and index:
-            mylog.info("Creating database index.")
-            line = "CREATE INDEX IF NOT EXISTS HalosIndex ON Halos ("
-            for name in columns:
-                line += name +","
-            line = line[:-1] + ");"
-            self.cursor.execute(line)
+        # Now update the database with all the writes.
+        mylog.info("Updating database with parent-child relationships.")
+        self._copy_and_update_db()
         self._barrier()
-        self._close_database()
         mylog.info("Done!")
         
     def _read_halo_lists(self):
@@ -555,7 +552,6 @@
         baseChildID = self.cursor.fetchone()[0]
         
         # Now we prepare a big list of writes to put in the database.
-        write_values = []
         for i,parent_halo in enumerate(sorted(self.candidates)):
             child_indexes = []
             child_per = []
@@ -574,29 +570,79 @@
             values = []
             for pair in zip(child_indexes, child_per):
                 values.extend([int(pair[0]), float(pair[1])])
-            values.extend([parent_currt, parent_halo])
+            #values.extend([parent_currt, parent_halo])
             # This has the child ID, child percent listed five times, followed
             # by the currt and this parent halo ID (SnapHaloID).
-            values = tuple(values)
-            write_values.append(values)
-        
-        # Now we do the actual writing, but only by task 0.
-        line = 'UPDATE Halos SET ChildHaloID0=?, ChildHaloFrac0=?,\
-        ChildHaloID1=?, ChildHaloFrac1=?,\
-        ChildHaloID2=?, ChildHaloFrac2=?,\
-        ChildHaloID3=?, ChildHaloFrac3=?,\
-        ChildHaloID4=?, ChildHaloFrac4=?\
-        WHERE SnapCurrentTimeIdentifier=? AND SnapHaloID=?;'
-        if self.mine == 0:
-            for values in write_values:
-                self.cursor.execute(line, values)
-            self.conn.commit()
-        
-        # This has a barrier in it, which ensures the disk isn't lagging.
-        self._ensure_db_sync()
+            #values = tuple(values)
+            self.write_values.append(values)
+            self.write_values_dict[parent_currt][parent_halo] = values
         
         return (child_IDs, child_masses, child_halos)
 
+    def _copy_and_update_db(self):
+        """
+        Because doing an UPDATE of a SQLite database is really slow, what we'll
+        do here is basically read in lines from the database, and then insert
+        the parent-child relationships, writing to a new DB.
+        """
+        temp_name = self.database + '-tmp'
+        if self.mine == 0:
+            to_write = []
+            # Open the temporary database.
+            temp_conn = sql.connect(temp_name)
+            temp_cursor = temp_conn.cursor()
+            line = "CREATE TABLE Halos (GlobalHaloID INTEGER PRIMARY KEY,\
+                    SnapCurrentTimeIdentifier INTEGER, SnapZ FLOAT, SnapHaloID INTEGER, \
+                    HaloMass FLOAT,\
+                    NumPart INTEGER, CenMassX FLOAT, CenMassY FLOAT,\
+                    CenMassZ FLOAT, BulkVelX FLOAT, BulkVelY FLOAT, BulkVelZ FLOAT,\
+                    MaxRad FLOAT,\
+                    ChildHaloID0 INTEGER, ChildHaloFrac0 FLOAT, \
+                    ChildHaloID1 INTEGER, ChildHaloFrac1 FLOAT, \
+                    ChildHaloID2 INTEGER, ChildHaloFrac2 FLOAT, \
+                    ChildHaloID3 INTEGER, ChildHaloFrac3 FLOAT, \
+                    ChildHaloID4 INTEGER, ChildHaloFrac4 FLOAT);"
+            temp_cursor.execute(line)
+            temp_conn.commit()
+            # Get all the data!
+            self.cursor.execute("SELECT * FROM Halos;")
+            results = self.cursor.fetchone()
+            while results:
+                results = list(results)
+                currt = results[1]
+                hid = results[3]
+                # If for some reason this halo doesn't have relationships,
+                # we'll just keep the old results the same.
+                try:
+                    lookup = self.write_values_dict[currt][hid]
+                    new = tuple(results[:-10] + lookup)
+                except KeyError:
+                    new = tuple(results)
+                to_write.append(new)
+                results = self.cursor.fetchone()
+            # Now write to the temp database.
+            # 23 question marks for 23 data columns.
+            line = ''
+            for i in range(23):
+                line += '?,'
+            # Pull off the last comma.
+            line = 'INSERT into Halos VALUES (' + line[:-1] + ')'
+            for insert in to_write:
+                temp_cursor.execute(line, insert)
+            temp_conn.commit()
+            mylog.info("Creating database index.")
+            line = "CREATE INDEX IF NOT EXISTS HalosIndex ON Halos ("
+            for name in columns:
+                line += name +","
+            line = line[:-1] + ");"
+            temp_cursor.execute(line)
+            temp_cursor.close()
+            temp_conn.close()
+        self._close_database()
+        self._barrier()
+        if self.mine == 0:
+            os.rename(temp_name, self.database)
+
 class MergerTreeConnect(DatabaseFunctions):
     def __init__(self, database='halos.db'):
         self.database = database



More information about the yt-svn mailing list