[Yt-svn] yt-commit r1560 - in trunk/yt: . extensions extensions/kdtree lagos lagos/parallelHOP

sskory at wrangler.dreamhost.com sskory at wrangler.dreamhost.com
Sun Dec 20 12:42:19 PST 2009


Author: sskory
Date: Sun Dec 20 12:42:09 2009
New Revision: 1560
URL: http://yt.enzotools.org/changeset/1560

Log:
Adding my patches as provided by mturk to the svn-trunk. In particular, Parallel HOP, and star and halo mass analysis routines.

Added:
   trunk/yt/extensions/HaloMassFcn.py
   trunk/yt/extensions/StarAnalysis.py
   trunk/yt/extensions/kdtree/
   trunk/yt/extensions/kdtree/Makefile
   trunk/yt/extensions/kdtree/__init__.py
   trunk/yt/extensions/kdtree/fKD.f90
   trunk/yt/extensions/kdtree/fKD.v
   trunk/yt/extensions/kdtree/fKD_source.f90
   trunk/yt/extensions/kdtree/test.py
   trunk/yt/lagos/parallelHOP/
   trunk/yt/lagos/parallelHOP/__init__.py
   trunk/yt/lagos/parallelHOP/parallelHOP.py
   trunk/yt/lagos/parallelHOP/run.py
   trunk/yt/math_utils.py
Modified:
   trunk/yt/lagos/BaseDataTypes.py
   trunk/yt/lagos/EnzoFields.py
   trunk/yt/lagos/HaloFinding.py
   trunk/yt/lagos/ParallelTools.py
   trunk/yt/lagos/setup.py

Added: trunk/yt/extensions/HaloMassFcn.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/HaloMassFcn.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,719 @@
+"""
+HaloMassFcn - Halo Mass Function and supporting functions.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UC San Diego / CASS
+Homepage: http://yt.enzotools.org/
+License:
+  Copyright (C) 2008-2009 Stephen Skory (and others).  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import yt.lagos as lagos
+from yt.logger import lagosLogger as mylog
+import numpy as na
+import math, time
+
+class HaloMassFcn(object):
+    """
+    Analytic halo mass function fits (Press-Schechter, Jenkins, Sheth-Tormen
+    or Warren) for a given cosmology, optionally compared against the
+    cumulative mass function of haloes read from a Halo Profiler output file.
+    """
+    def __init__(self, pf, halo_file=None, omega_matter0=None, omega_lambda0=None,
+    omega_baryon0=0.05, hubble0=None, sigma8input=0.86, primordial_index=1.0,
+    this_redshift=None, log_mass_min=None, log_mass_max=None, num_sigma_bins=360,
+    fitting_function=4):
+        """
+        Initialize a HaloMassFcn object to analyze the distribution of haloes
+        as a function of mass. The fit (and, in 'haloes' mode, the binned
+        halo counts) is computed immediately.
+
+        If *halo_file* is None ('single' mode) the cosmology, redshift and
+        mass range must all be given by hand. Otherwise ('haloes' mode)
+        omega_matter0, omega_lambda0, hubble0 and this_redshift are read from
+        *pf*, and the mass range is set by the lightest and heaviest halo in
+        *halo_file*; those keyword arguments are then ignored.
+        :param pf: The dataset. Its Cosmology* parameters are used in
+        'haloes' mode.
+        :param halo_file (str): The filename of the output of the Halo Profiler.
+        Default=None.
+        :param omega_matter0 (float): The fraction of the universe made up of
+        matter (dark and baryonic). Default=None.
+        :param omega_lambda0 (float): The fraction of the universe made up of
+        dark energy. Default=None.
+        :param omega_baryon0 (float): The fraction of the universe made up of
+        ordinary baryonic matter. This should match the value
+        used to create the initial conditions, using 'inits'. This is 
+        *not* stored in the enzo dataset so it must be checked by hand.
+        Default=0.05.
+        :param hubble0 (float): The expansion rate of the universe in units of
+        100 km/s/Mpc. Default=None.
+        :param sigma8input (float): The amplitude of the linear power
+        spectrum at z=0 as specified by the rms amplitude of mass-fluctuations
+        in a top-hat sphere of radius 8 Mpc/h. This should match the value
+        used to create the initial conditions, using 'inits'. This is 
+        *not* stored in the enzo dataset so it must be checked by hand.
+        Default=0.86.
+        :param primordial_index (float): This is the index of the mass power
+        spectrum before modification by the transfer function. A value of 1
+        corresponds to the scale-free primordial spectrum. This should match
+        the value used to make the initial conditions using 'inits'. This is 
+        *not* stored in the enzo dataset so it must be checked by hand.
+        Default=1.0.
+        :param this_redshift (float): The current redshift. Default=None.
+        :param log_mass_min (float): The log10 of the mass of the minimum of the
+        halo mass range. Default=None.
+        :param log_mass_max (float): The log10 of the mass of the maximum of the
+        halo mass range. Default=None.
+        :param num_sigma_bins (int): The number of bins (points) to use for
+        the calculations and generated fit. Default=360.
+        :param fitting_function (int): Which fitting function to use.
+        1 = Press-schechter, 2 = Jenkins, 3 = Sheth-Tormen, 4 = Warren fit
+        Default=4.
+        :raises ValueError: in 'single' mode if any of omega_matter0,
+        omega_lambda0, hubble0, this_redshift, log_mass_min or log_mass_max
+        is None; in 'haloes' mode if *halo_file* contains no halo with a
+        positive mass.
+        """
+        self.pf = pf
+        self.halo_file = halo_file
+        self.omega_matter0 = omega_matter0
+        self.omega_lambda0 = omega_lambda0
+        self.omega_baryon0 = omega_baryon0
+        self.hubble0 = hubble0
+        self.sigma8input = sigma8input
+        self.primordial_index = primordial_index
+        self.this_redshift = this_redshift
+        self.log_mass_min = log_mass_min
+        self.log_mass_max = log_mass_max
+        self.num_sigma_bins = num_sigma_bins
+        self.fitting_function = fitting_function
+
+        # Determine the run mode.
+        if halo_file is None:
+            # We are hand-picking our various cosmological parameters.
+            self.mode = 'single'
+            required = (omega_matter0, omega_lambda0, hubble0, this_redshift,
+                        log_mass_min, log_mass_max)
+            if any(param is None for param in required):
+                mylog.error("All of these parameters need to be set:")
+                mylog.error("[omega_matter0, omega_lambda0, \
+                hubble0, this_redshift, log_mass_min, log_mass_max]")
+                mylog.error("[%s,%s,%s,%s,%s,%s]" % (omega_matter0,\
+                omega_lambda0, hubble0, this_redshift,\
+                log_mass_min, log_mass_max))
+                # Returning None from __init__ would leave a half-built
+                # object that fails later; fail here instead.
+                raise ValueError("HaloMassFcn: all of omega_matter0, "
+                    "omega_lambda0, hubble0, this_redshift, log_mass_min and "
+                    "log_mass_max must be set when halo_file is None.")
+        else:
+            # Make the fit using the same cosmological parameters as the dataset.
+            self.mode = 'haloes'
+            self.omega_matter0 = self.pf['CosmologyOmegaMatterNow']
+            self.omega_lambda0 = self.pf['CosmologyOmegaLambdaNow']
+            self.hubble0 = self.pf['CosmologyHubbleConstantNow']
+            self.this_redshift = self.pf['CosmologyCurrentRedshift']
+            self.read_haloes()
+            if self.haloes.size == 0:
+                raise ValueError("HaloMassFcn: no haloes with positive mass "
+                    "found in %s" % self.halo_file)
+            self.log_mass_min = math.log10(self.haloes.min())
+            self.log_mass_max = math.log10(self.haloes.max())
+
+        # Poke the user to make sure they're doing it right.
+        mylog.info(
+        """
+        Please make sure these are the correct values! They are
+        not stored in enzo datasets, so must be entered by hand.
+        sigma8input=%f primordial_index=%f omega_baryon0=%f
+        """ % (self.sigma8input, self.primordial_index, self.omega_baryon0))
+        time.sleep(1)
+
+        # Do the calculations.
+        self.sigmaM()
+        self.dndm()
+
+        if self.mode == 'haloes':
+            self.bin_haloes()
+
+    def write_out(self, prefix='HMF', fit=True, haloes=True):
+        """
+        Write the halo mass function(s) to tab-separated text files.
+
+        :param prefix (str): Filename prefix. Default='HMF'.
+        :param fit (bool): Write the analytic fit to '<prefix>-fit.dat'.
+        Default=True.
+        :param haloes (bool): In 'haloes' mode, also write the cumulative
+        halo number density to '<prefix>-haloes.dat'. Default=True.
+
+        One row is written per mass grid point except the last, for which
+        dndm() computes no value.
+        """
+        # First the fit file.
+        if fit:
+            fitname = prefix + '-fit.dat'
+            fp = open(fitname, 'w')
+            try:
+                fp.write("""#Columns:
+#1. log10 of mass (Msolar, NOT Msolar/h)
+#2. mass (Msolar/h)
+#3. (dn/dM)*dM (differential number density of halos, per Mpc^3 (NOT h^3/Mpc^3)
+#4. cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+""")
+                for i in xrange(self.logmassarray.size - 1):
+                    fp.write("%e\t%e\t%e\t%e\n" % (self.logmassarray[i],
+                        self.massarray[i], self.dn_M_z[i], self.nofmz_cum[i]))
+            finally:
+                fp.close()
+        if self.mode == 'haloes' and haloes:
+            haloname = prefix + '-haloes.dat'
+            fp = open(haloname, 'w')
+            try:
+                # The header used to be built but never written.
+                fp.write("""#Columns:
+#1. log10 of mass (Msolar, NOT Msolar/h)
+#2. mass (Msolar/h)
+#3. cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+""")
+                for i in xrange(self.logmassarray.size - 1):
+                    fp.write("%e\t%e\t%e\n" % (self.logmassarray[i],
+                        self.massarray[i], self.dis[i]))
+            finally:
+                fp.close()
+
+    def read_haloes(self):
+        """
+        Read the virial masses of the haloes from self.halo_file.
+
+        The first line is skipped as a header. The mass is taken from the
+        6th whitespace-separated column. Blank lines and non-positive masses
+        are ignored. Sets self.haloes to a numpy array of the masses.
+        """
+        mylog.info("Reading halo masses from %s" % self.halo_file)
+        masses = []
+        f = open(self.halo_file, 'r')
+        try:
+            f.readline() # burn the top header line.
+            for line in f:
+                columns = line.split()
+                if not columns:
+                    continue
+                # Mass is in the 6th column (ord. 5)
+                mass = float(columns[5])
+                if mass > 0:
+                    masses.append(mass)
+        finally:
+            f.close()
+        self.haloes = na.array(masses)
+
+    def bin_haloes(self):
+        """
+        With the list of virial masses, find the cumulative halo mass function
+        on the same log-mass grid as the fit (self.logmassarray).
+
+        Sets self.dis, where self.dis[i] is the number of haloes with
+        log10(mass) >= self.logmassarray[i], divided by the comoving box
+        volume in Mpc^3. Using the fit's grid keeps each value aligned with
+        the mass columns that write_out() pairs it with.
+        """
+        # With the log masses sorted, searchsorted gives, for every grid
+        # point, the number of haloes strictly below it.
+        log_masses = na.sort(na.log10(self.haloes))
+        below = na.searchsorted(log_masses, self.logmassarray, side='left')
+        counts = log_masses.size - below
+        # Enzo's CosmologyComovingBoxSize is in comoving Mpc/h; convert to Mpc
+        # so the density is per Mpc^3 as documented in write_out().
+        box_size = self.pf['CosmologyComovingBoxSize'] / self.hubble0
+        self.dis = counts / box_size**3.0
+
+    def sigmaM(self):
+        """
+         Written by BWO, 2006 (updated 25 January 2007).
+         Converted to Python by Stephen Skory December 2009.
+
+         This routine takes in cosmological parameters and creates arrays
+         with sigma(M) in them, which are needed for the various
+         Press-Schechter type fits. Computing sigma(M) once up front is far
+         faster than recomputing it every time it is needed.
+
+         Inputs: cosmology, user must set parameters
+
+         Outputs: four arrays of length num_sigma_bins:
+
+         1) self.logmassarray: log10 mass (Msolar)
+         2) self.massarray: mass (Msolar/h)
+         3) self.Rarray: radius (comoving Mpc/h)
+         4) self.sigmaarray: sigma (normalized to sigma8input) using Msun/h
+            as the input
+
+         The grid starts at log_mass_min and is spaced by
+         (log_mass_max - log_mass_min)/num_sigma_bins, so its last point is
+         one step below log_mass_max. Also sets self.TF.
+        """
+        # Set up the transfer function object.
+        self.TF = TransferFunction(self.omega_matter0, self.omega_baryon0, 0.0, 0,
+            self.omega_lambda0, self.hubble0, self.this_redshift)
+
+        if self.TF.qwarn:
+            mylog.error("You should probably fix your cosmology parameters!")
+
+        # output arrays
+        # radius of the top-hat sphere containing each mass (comoving Mpc/h)
+        self.Rarray = na.empty(self.num_sigma_bins, dtype='float64')
+        # log10 of mass (Msolar, NOT Msolar/h)
+        self.logmassarray = na.empty(self.num_sigma_bins, dtype='float64')
+        # mass (Msolar/h)
+        self.massarray = na.empty(self.num_sigma_bins, dtype='float64')
+        # sigma(M, z=0, where mass is in Msun/h)
+        self.sigmaarray = na.empty(self.num_sigma_bins, dtype='float64')
+
+        # get sigma_8 normalization at R = 8 Mpc/h (comoving)
+        sigma8_unnorm = math.sqrt(self.sigma_squared_of_R(8.0))
+        sigma_normalization = self.sigma8input / sigma8_unnorm
+
+        rho0 = self.omega_matter0 * 2.78e+11 # in units of h^2 Msolar/Mpc^3
+
+        # spacing in log10(mass) of our sigma calculation
+        dm = (float(self.log_mass_max) - self.log_mass_min)/self.num_sigma_bins
+
+        # For each bin, calculate mass and equivalent radius, call
+        # sigma_squared_of_R to get sigma(R) (equivalent to sigma(M)) and
+        # normalize by the user-specified sigma_8.
+        for i in xrange(self.num_sigma_bins):
+            # thislogmass is in units of Msolar, NOT Msolar/h
+            thislogmass = self.log_mass_min + i*dm
+            # mass in units of h^-1 Msolar
+            thismass = math.pow(10.0, thislogmass) * self.hubble0
+            # radius is in units of h^-1 Mpc (comoving)
+            thisradius = math.pow(3.0*thismass / 4.0 / math.pi / rho0, 1.0/3.0)
+
+            self.Rarray[i] = thisradius
+            self.logmassarray[i] = thislogmass
+            self.massarray[i] = thismass
+            self.sigmaarray[i] = math.sqrt(self.sigma_squared_of_R(thisradius)) * \
+                sigma_normalization
+
+    def dndm(self):
+        """
+        Compute the differential and cumulative halo number densities on the
+        mass grid built by sigmaM().
+
+        The grid is walked from high to low mass. For each point i except the
+        last, dn/dM*dM over [M_i, M_{i+1}] comes from the finite difference
+        of sigma(M, z) and the chosen multiplicity function. It is stored in
+        self.dn_M_z[i]. self.nofmz_cum[i] holds the running sum of those values
+        from the top of the grid down to i; values at or below 1e-20 are
+        stored but not added to the sum. The last entry of both arrays is left
+        at zero. Also sets self.delta_c0.
+        """
+        # constants - set these before calling any functions!
+        rho0 = self.omega_matter0 * 2.78e+11 # in units of h^2 Msolar/Mpc^3
+        self.delta_c0 = 1.69  # critical linear overdensity for collapse (Press-Schechter)
+
+        # output arrays
+        # (dn/dM)*dM (differential number density of halos, per Mpc^3 (NOT h^3/Mpc^3))
+        self.dn_M_z = na.zeros(self.num_sigma_bins, dtype='float64')
+        # cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+        self.nofmz_cum = na.zeros(self.num_sigma_bins, dtype='float64')
+
+        nofmz_cum = 0.0  # keep track of cumulative number density
+
+        # Loop over masses, going BACKWARD from the second-to-last grid point.
+        for i in xrange(self.num_sigma_bins - 2, -1, -1):
+            thissigma = self.sigmaof_M_z(i, self.this_redshift)
+            nextsigma = self.sigmaof_M_z(i+1, self.this_redshift)
+            dmass = self.massarray[i+1] - self.massarray[i]
+
+            # calc dsigmadm - has units of h (since massarray has units of h^-1)
+            dsigmadm = (nextsigma - thissigma) / dmass
+
+            # calculate dn(M,z) (dn/dM * dM)
+            # this has units of h^3 since rho0 has units of h^2, dsigmadm
+            # has units of h, and massarray has units of h^-1
+            dn_M_z = -1.0 / thissigma * dsigmadm * rho0 / self.massarray[i] * \
+                self.multiplicityfunction(thissigma) * dmass
+
+            # scale by h^4 to get rid of all factors of h
+            dn_M_z *= math.pow(self.hubble0, 4.0)
+
+            # keep track of cumulative number density
+            if dn_M_z > 1.0e-20:
+                nofmz_cum += dn_M_z
+
+            # Store this.
+            self.nofmz_cum[i] = nofmz_cum
+            self.dn_M_z[i] = dn_M_z
+
+    def sigma_squared_of_R(self, R):
+        """
+        Return the (unnormalized) sigma^2(R) for a top-hat of comoving
+        radius R (h^-1 Mpc).
+
+        Integrates sigma_squared_integrand over k with integrate_inf() and
+        divides by 2*pi^2. R is stored in self.R for the integrand to use.
+        """
+        self.R = R
+        result = integrate_inf(self.sigma_squared_integrand)
+        return result / 2.0 / math.pi / math.pi
+
+    def sigma_squared_integrand(self, k):
+        """
+        Integrand k^2 P(k) |W(kR)|^2 for sigma^2(R), with R = self.R
+        (comoving Mpc/h).
+        """
+        Rcom = self.R  # this is R in comoving Mpc/h
+        return k*k*self.PofK(k)*na.power(abs(self.WofK(Rcom, k)), 2.0)
+
+    def PofK(self, k):
+        """
+        Return the (unnormalized) power spectrum k^primordial_index * T(k)^2.
+        """
+        return na.power(k, self.primordial_index) * na.power(self.TofK(k), 2.0)
+
+    def TofK(self, k):
+        """
+        Return the transfer function at wavenumber k (h Mpc^-1).
+        """
+        return self.TF.TFmdm_onek_hmpc(k)
+
+    def WofK(self, R, k):
+        """
+        Return W(kR), the Fourier transform of the top-hat window:
+        3 (sin x - x cos x) / x^3 with x = R*k.
+        """
+        x = R*k
+        return 3.0 * (na.sin(x) - x*na.cos(x)) / (x*x*x)
+
+    def multiplicityfunction(self, sigma):
+        """
+        Multiplicity function - this is where the various fitting
+        functions/analytic theories differ, selected by self.fitting_function.
+        The source of each fitting function is listed below.
+
+        :raises ValueError: if self.fitting_function is not 1, 2, 3 or 4.
+        """
+        nu = self.delta_c0 / sigma
+
+        if self.fitting_function == 1:
+            # Press-Schechter (This form from Jenkins et al. 2001, MNRAS 321, 372-384, eqtn. 5)
+            thismult = math.sqrt(2.0/math.pi) * nu * math.exp(-0.5*nu*nu)
+
+        elif self.fitting_function == 2:
+            # Jenkins et al. 2001, MNRAS 321, 372-384, eqtn. 9
+            thismult = 0.315 * math.exp(-1.0 * math.pow(abs(math.log(1.0/sigma) + 0.61), 3.8))
+
+        elif self.fitting_function == 3:
+            # Sheth-Tormen 1999, eqtn 10, using expression from Jenkins et al. 2001, eqtn. 7
+            A = 0.3222
+            a = 0.707
+            p = 0.3
+            thismult = A*math.sqrt(2.0*a/math.pi)*(1.0 + math.pow(1.0/(nu*nu*a), p))* \
+                nu * math.exp(-0.5*a*nu*nu)
+
+        elif self.fitting_function == 4:
+            # LANL fitting function - Warren et al. 2005, astro-ph/0506395, eqtn. 5
+            A = 0.7234
+            a = 1.625
+            b = 0.2538
+            c = 1.1982
+            thismult = A*(math.pow(sigma, -1.0*a) + b)*math.exp(-1.0*c / sigma / sigma)
+
+        else:
+            mylog.error("Don't understand this.  Fitting function requested is %d\n",
+            self.fitting_function)
+            # Returning None here only led to a TypeError inside dndm().
+            raise ValueError("Unknown fitting_function %s; expected 1, 2, 3 or 4."
+                % self.fitting_function)
+
+        return thismult
+
+    def sigmaof_M_z(self, sigmabin, redshift):
+        """
+        Return sigma(M, z) for mass-grid index *sigmabin*: the normalized
+        z=0 sigma scaled by the growth factor D(z).
+        """
+        return self.Dofz(redshift) * self.sigmaarray[sigmabin]
+
+    def Dofz(self, redshift):
+        """
+        Growth function D(z) = g(z) / g(0) / (1 + z), so that D(0) = 1.
+        """
+        return self.gofz(redshift) / self.gofz(0.0) / (1.0+redshift)
+
+    def gofz(self, redshift):
+        """
+        Growth suppression factor g(z) = 2.5 Om(z) / (Om(z)^(4/7) - OL(z) +
+        (1 + Om(z)/2)(1 + OL(z)/70)), the Carroll, Press & Turner (1992) fit.
+        """
+        om = self.omega_matter_of_z(redshift)
+        ol = self.omega_lambda_of_z(redshift)
+        return 2.5 * om / (math.pow(om, 4.0/7.0) - ol + \
+            ((1.0 + om / 2.0) * (1.0 + ol / 70.0)))
+
+    def omega_matter_of_z(self, redshift):
+        """
+        Omega matter as a function of redshift: Om0 (1+z)^3 / E(z)^2.
+        """
+        return self.omega_matter0 * math.pow(1.0+redshift, 3.0) / \
+            math.pow(self.Eofz(redshift), 2.0)
+
+    def omega_lambda_of_z(self, redshift):
+        """
+        Omega lambda as a function of redshift: OL0 / E(z)^2.
+        """
+        return self.omega_lambda0 / math.pow(self.Eofz(redshift), 2.0)
+
+    def Eofz(self, redshift):
+        """
+        E(z) = H(z)/H0 = sqrt(OL0 + (1 - OL0 - Om0)(1+z)^2 + Om0 (1+z)^3).
+        """
+        return math.sqrt(self.omega_lambda0 \
+            + (1.0 - self.omega_lambda0 - self.omega_matter0)*math.pow(1.0+redshift, 2.0) \
+            + self.omega_matter0 * math.pow(1.0+redshift, 3.0))
+
+""" 
+/* Fitting Formulae for CDM + Baryon + Massive Neutrino (MDM) cosmologies. */
+/* Daniel J. Eisenstein & Wayne Hu, Institute for Advanced Study */
+
+/* There are two primary routines here, one to set the cosmology, the
+other to construct the transfer function for a single wavenumber k. 
+You should call the former once (per cosmology) and the latter as 
+many times as you want. */
+
+/* TFmdm_set_cosm() -- (TransferFunction.__init__() in this Python port.)
+	User passes all the cosmological parameters as arguments; the routine
+	sets up all of the scalar quantities needed for computation of the
+	fitting formula.  The input parameters are:
+	1) omega_matter -- Density of CDM, baryons, and massive neutrinos,
+				in units of the critical density. 
+	2) omega_baryon -- Density of baryons, in units of critical. 
+	3) omega_hdm    -- Density of massive neutrinos, in units of critical 
+	4) degen_hdm    -- (Int) Number of degenerate massive neutrino species 
+	5) omega_lambda -- Cosmological constant 
+	6) hubble       -- Hubble constant, in units of 100 km/s/Mpc 
+	7) redshift     -- The redshift at which to evaluate */
+
+/* TFmdm_onek_mpc() -- User passes a single wavenumber, in units of Mpc^-1.
+	Routine returns the transfer function from the Eisenstein & Hu
+	fitting formula, based on the cosmology currently held in the 
+	internal variables.  The routine returns T_cb (the CDM+Baryon
+	density-weighted transfer function), although T_cbn (the CDM+
+	Baryon+Neutrino density-weighted transfer function) is also
+	computed, as tf_cbnu. */
+
+/* We also supply TFmdm_onek_hmpc(), which is identical to the previous
+	routine, but takes the wavenumber in units of h Mpc^-1. */
+
+/* In this Python port the internal scalar quantities are held as attributes
+of the TransferFunction instance, so the user may access them directly. */
+
+/* Please note that all internal length scales are in Mpc, not h^-1 Mpc! */
+"""
+
+class TransferFunction(object):
+    """
+    Python port of the Eisenstein & Hu fitting formulae for the linear
+    transfer function of CDM + baryon + massive neutrino (MDM) cosmologies.
+
+    Construct once per cosmology (this plays the role of TFmdm_set_cosm() in
+    the original C code), then call TFmdm_onek_mpc() / TFmdm_onek_hmpc() as
+    many times as needed. All internal length scales are in Mpc, not h^-1 Mpc.
+    """
+    def __init__(self, omega_matter, omega_baryon, omega_hdm,
+            degen_hdm, omega_lambda, hubble, redshift):
+        """
+        Take cosmological parameters and a redshift and set up all the
+        internal scalar quantities needed to compute the transfer function.
+
+        :param omega_matter (float): Density of CDM, baryons, and massive
+        neutrinos, in units of the critical density. Must be positive.
+        :param omega_baryon (float): Density of baryons, in units of critical.
+        Non-positive values are replaced by a trace amount (1e-5).
+        :param omega_hdm (float): Density of massive neutrinos, in units of
+        critical. Non-positive values are replaced by a trace amount (1e-5).
+        :param degen_hdm (int): Number of degenerate massive neutrino species.
+        Values below 1 are treated as 1.
+        :param omega_lambda (float): Cosmological constant.
+        :param hubble (float): Hubble constant, in units of 100 km/s/Mpc.
+        :param redshift (float): The redshift at which to evaluate.
+        :raises ValueError: if omega_matter <= 0, hubble <= 0 or
+        redshift <= -1.
+
+        Sets self.qwarn to 1 if a warning about suspicious (but usable) input
+        was issued, 0 otherwise.
+        """
+        self.qwarn = 0
+        self.theta_cmb = 2.728/2.7 # Assuming T_cmb = 2.728 K
+
+        # Look for strange input
+        if omega_baryon < 0.0:
+            mylog.error("TFmdm_set_cosm(): Negative omega_baryon set to trace amount.\n")
+            self.qwarn = 1
+        if omega_hdm < 0.0:
+            mylog.error("TFmdm_set_cosm(): Negative omega_hdm set to trace amount.\n")
+            self.qwarn = 1
+        # Returning None from __init__ left an object with none of the
+        # attributes TFmdm_onek_mpc() needs; raise instead.
+        if hubble <= 0.0:
+            mylog.error("TFmdm_set_cosm(): Negative Hubble constant illegal.\n")
+            raise ValueError("TFmdm_set_cosm(): Negative Hubble constant illegal.")
+        elif hubble > 2.0:
+            mylog.error("TFmdm_set_cosm(): Hubble constant should be in units of 100 km/s/Mpc.\n")
+            self.qwarn = 1
+        if redshift <= -1.0:
+            mylog.error("TFmdm_set_cosm(): Redshift < -1 is illegal.\n")
+            raise ValueError("TFmdm_set_cosm(): Redshift < -1 is illegal.")
+        elif redshift > 99.0:
+            mylog.error("TFmdm_set_cosm(): Large redshift entered.  TF may be inaccurate.\n")
+            self.qwarn = 1
+        # omega_matter is a divisor and the base of fractional powers below.
+        if omega_matter <= 0.0:
+            mylog.error("TFmdm_set_cosm(): Non-positive omega_matter illegal.\n")
+            raise ValueError("TFmdm_set_cosm(): Non-positive omega_matter illegal.")
+
+        if degen_hdm < 1:
+            degen_hdm = 1
+        # Have to save this for TFmdm_onek_mpc()
+        self.num_degen_hdm = degen_hdm
+        # This routine would crash if baryons or neutrinos were zero,
+        # so don't allow that.
+        if omega_baryon <= 0:
+            omega_baryon = 1e-5
+        if omega_hdm <= 0:
+            omega_hdm = 1e-5
+
+        self.omega_curv = 1.0 - omega_matter - omega_lambda
+        self.omhh = omega_matter*SQR(hubble)
+        self.obhh = omega_baryon*SQR(hubble)
+        self.onhh = omega_hdm*SQR(hubble)
+        self.f_baryon = omega_baryon/omega_matter
+        self.f_hdm = omega_hdm/omega_matter
+        self.f_cdm = 1.0 - self.f_baryon - self.f_hdm
+        self.f_cb = self.f_cdm + self.f_baryon
+        self.f_bnu = self.f_baryon + self.f_hdm
+
+        # Compute the equality scale.
+        self.z_equality = 25000.0*self.omhh/SQR(SQR(self.theta_cmb)) # Actually 1+z_eq
+        self.k_equality = 0.0746*self.omhh/SQR(self.theta_cmb)
+
+        # Compute the drag epoch and sound horizon
+        z_drag_b1 = 0.313*math.pow(self.omhh, -0.419)*(1 + 0.607*math.pow(self.omhh, 0.674))
+        z_drag_b2 = 0.238*math.pow(self.omhh, 0.223)
+        self.z_drag = 1291*math.pow(self.omhh, 0.251)/(1.0 + 0.659*math.pow(self.omhh, 0.828))* \
+            (1.0 + z_drag_b1*math.pow(self.obhh, z_drag_b2))
+        self.y_drag = self.z_equality/(1.0 + self.z_drag)
+
+        self.sound_horizon_fit = 44.5*math.log(9.83/self.omhh)/ \
+            math.sqrt(1.0 + 10.0*math.pow(self.obhh, 0.75))
+
+        # Set up for the free-streaming & infall growth function
+        self.p_c = 0.25*(5.0 - math.sqrt(1 + 24.0*self.f_cdm))
+        self.p_cb = 0.25*(5.0 - math.sqrt(1 + 24.0*self.f_cb))
+
+        omega_denom = omega_lambda + SQR(1.0 + redshift)*(self.omega_curv + \
+            omega_matter*(1.0 + redshift))
+        self.omega_lambda_z = omega_lambda/omega_denom
+        self.omega_matter_z = omega_matter*SQR(1.0 + redshift)*(1.0 + redshift)/omega_denom
+        self.growth_k0 = self.z_equality/(1.0 + redshift)*2.5*self.omega_matter_z/ \
+            (math.pow(self.omega_matter_z, 4.0/7.0) - self.omega_lambda_z + \
+            (1.0 + self.omega_matter_z/2.0)*(1.0 + self.omega_lambda_z/70.0))
+        self.growth_to_z0 = self.z_equality*2.5*omega_matter/(math.pow(omega_matter, 4.0/7.0) \
+            - omega_lambda + (1.0 + omega_matter/2.0)*(1.0 + omega_lambda/70.0))
+        self.growth_to_z0 = self.growth_k0/self.growth_to_z0
+
+        # Compute small-scale suppression
+        self.alpha_nu = self.f_cdm/self.f_cb*(5.0 - 2.*(self.p_c + self.p_cb))/(5. - 4.*self.p_cb)* \
+            math.pow(1 + self.y_drag, self.p_cb - self.p_c)* \
+            (1 + self.f_bnu*(-0.553 + 0.126*self.f_bnu*self.f_bnu))/ \
+            (1 - 0.193*math.sqrt(self.f_hdm*self.num_degen_hdm) + \
+            0.169*self.f_hdm*math.pow(self.num_degen_hdm, 0.2))* \
+            (1 + (self.p_c - self.p_cb)/2*(1 + 1/(3. - 4.*self.p_c)/(7. - 4.*self.p_cb))/(1 + self.y_drag))
+        self.alpha_gamma = math.sqrt(self.alpha_nu)
+        self.beta_c = 1/(1 - 0.949*self.f_bnu)
+        # Done setting scalar variables
+        self.hhubble = hubble # Need to pass Hubble constant to TFmdm_onek_hmpc()
+
+    def TFmdm_onek_mpc(self, kk):
+        """
+        Return T_cb, the density-weighted CDM + baryon transfer function, at
+        wavenumber *kk* (in Mpc^-1) for the cosmology set up in __init__().
+
+        *kk* may be a scalar or a numpy array. Intermediate quantities are
+        stored on the instance, among them growth_cb and growth_cbnu (the
+        scale-dependent growth functions for CDM + baryon and CDM + baryon +
+        massive neutrino perturbations), tf_cb (the returned value) and
+        tf_cbnu (the CDM + baryon + neutrino transfer function).
+        """
+        self.qq = kk/self.omhh*SQR(self.theta_cmb)
+
+        # Compute the scale-dependent growth functions
+        self.y_freestream = 17.2*self.f_hdm*(1 + 0.488*math.pow(self.f_hdm, -7.0/6.0))* \
+            SQR(self.num_degen_hdm*self.qq/self.f_hdm)
+        temp1 = math.pow(self.growth_k0, 1.0 - self.p_cb)
+        temp2 = na.power(self.growth_k0/(1 + self.y_freestream), 0.7)
+        self.growth_cb = na.power(1.0 + temp2, self.p_cb/0.7)*temp1
+        self.growth_cbnu = na.power(na.power(self.f_cb, 0.7/self.p_cb) + temp2, self.p_cb/0.7)*temp1
+
+        # Compute the master function
+        self.gamma_eff = self.omhh*(self.alpha_gamma + (1 - self.alpha_gamma)/ \
+            (1 + SQR(SQR(kk*self.sound_horizon_fit*0.43))))
+        self.qq_eff = self.qq*self.omhh/self.gamma_eff
+
+        tf_sup_L = na.log(2.71828 + 1.84*self.beta_c*self.alpha_gamma*self.qq_eff)
+        tf_sup_C = 14.4 + 325/(1 + 60.5*na.power(self.qq_eff, 1.11))
+        self.tf_sup = tf_sup_L/(tf_sup_L + tf_sup_C*SQR(self.qq_eff))
+
+        self.qq_nu = 3.92*self.qq*math.sqrt(self.num_degen_hdm/self.f_hdm)
+        self.max_fs_correction = 1 + 1.2*math.pow(self.f_hdm, 0.64)* \
+            math.pow(self.num_degen_hdm, 0.3 + 0.6*self.f_hdm)/ \
+            (na.power(self.qq_nu, -1.6) + na.power(self.qq_nu, 0.8))
+        self.tf_master = self.tf_sup*self.max_fs_correction
+
+        # Now compute the CDM+HDM+baryon transfer functions
+        self.tf_cb = self.tf_master*self.growth_cb/self.growth_k0
+        self.tf_cbnu = self.tf_master*self.growth_cbnu/self.growth_k0
+        return self.tf_cb
+
+    def TFmdm_onek_hmpc(self, kk):
+        """
+        Same as TFmdm_onek_mpc(), but *kk* is in h Mpc^-1. It is converted
+        to Mpc^-1 using the Hubble constant given to __init__().
+        """
+        return self.TFmdm_onek_mpc(kk*self.hhubble)
+
+def SQR(a):
+    """Return *a* multiplied by itself (scalars or numpy arrays)."""
+    squared = a * a
+    return squared
+
+def integrate_inf(fcn, error=1e-7, initial_guess=10, max_refinements=None):
+    """
+    Integrate a function *fcn* from (nearly) zero to infinity, stopping when
+    the answer changes by less than *error*. Hopefully someday we can do
+    something better than this!
+
+    The trapezoid rule is applied on log-spaced grids of growing extent and
+    resolution. The first grid spans [0.1, initial_guess - 0.9]. Each later
+    grid raises the upper limit tenfold, moves the lower limit closer to zero
+    (0.01, 0.001, ...) and uses more points. Refinement stops once two
+    successive estimates differ by less than *error* in absolute value.
+
+    :param fcn: Vectorised integrand. It is called with a 1-D numpy array of
+    abscissae and must return an array of the same length.
+    :param error (float): Absolute convergence tolerance. Default=1e-7.
+    :param initial_guess (int): Upper scale of the first grid and its number
+    of intervals. Default=10.
+    :param max_refinements (int): If not None, do at most this many
+    refinements after the second estimate, even if not converged. The grid
+    size grows by roughly a factor of *initial_guess* per refinement.
+    Default=None (refine until converged).
+    :returns: The last integral estimate.
+    """
+    def trapezoid(xvals):
+        # Trapezoid rule on a non-uniformly spaced grid.
+        yvals = fcn(xvals)
+        xdiffs = xvals[1:] - xvals[:-1]
+        return na.sum((yvals[1:] + yvals[:-1]) * xdiffs / 2.0)
+
+    area0 = trapezoid(
+        na.logspace(0, na.log10(initial_guess), initial_guess + 1) - .9)
+    # Next guess.
+    next_guess = 10 * initial_guess
+    area_last = trapezoid(
+        na.logspace(0, na.log10(next_guess), 2*initial_guess**2 + 1) - .99)
+    # Now we refine until the change is smaller than *error*. Compare the
+    # magnitude: a decreasing estimate is not a converged one.
+    diff = area_last - area0
+    one_pow = 3
+    refinements = 0
+    while abs(diff) > error:
+        if max_refinements is not None and refinements >= max_refinements:
+            break
+        next_guess *= 10
+        xvals = na.logspace(0, na.log10(next_guess),
+            one_pow*initial_guess**one_pow + 1) - (1 - 0.1**one_pow)
+        area_next = trapezoid(xvals)
+        diff = area_next - area_last
+        area_last = area_next
+        one_pow += 1
+        refinements += 1
+    return area_last

Added: trunk/yt/extensions/StarAnalysis.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/StarAnalysis.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,384 @@
+"""
+StarAnalysis - Functions to analyze stars.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UC San Diego / CASS
+Homepage: http://yt.enzotools.org/
+License:
+  Copyright (C) 2008-2009 Stephen Skory (and others).  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import yt.lagos as lagos
+from yt.logger import lagosLogger as mylog
+
+import numpy as na
+import h5py
+
+import math, itertools
+
+# Unit-conversion constants in CGS units.
+YEAR = 3.155693e7 # sec / year
+LIGHT = 2.997925e10 # cm / s (speed of light)
+
+class StarFormationRate(object):
+    def __init__(self, pf, data_source=None, star_mass=None,
+            star_creation_time=None, volume=None, bins=300):
+        self._pf = pf
+        self._data_source = data_source
+        self.star_mass = star_mass
+        self.star_creation_time = star_creation_time
+        self.volume = volume
+        self.bin_count = bins
+        # Check to make sure we have the right set of informations.
+        if data_source is None:
+            if self.star_mass is None or self.star_creation_time is None or \
+            self.volume is None:
+                mylog.error(
+                """
+                If data_source is not provided, all of these paramters need to be set:
+                star_mass (array, Msun),
+                star_creation_time (array, code units),
+                volume (float, Mpc**3).
+                """)
+                return None
+            self.mode = 'provided'
+        else:
+            self.mode = 'data_source'
+        # Set up for time conversion.
+        self.cosm = lagos.EnzoCosmology(HubbleConstantNow = 
+             (100.0 * self._pf['CosmologyHubbleConstantNow']),
+             OmegaMatterNow = self._pf['CosmologyOmegaMatterNow'],
+             OmegaLambdaNow = self._pf['CosmologyOmegaLambdaNow'],
+             InitialRedshift = self._pf['CosmologyInitialRedshift'])
+        # Find the time right now.
+        self.time_now = self.cosm.ComputeTimeFromRedshift(
+            self._pf["CosmologyCurrentRedshift"]) # seconds
+        # Build the distribution.
+        self.build_dist()
+
+    def build_dist(self):
+        """
+        Build the data for plotting.
+        """
+        # Pick out the stars.
+        if self.mode == 'data_source':
+            ct = self._data_source["creation_time"]
+            ct_stars = ct[ct > 0]
+            mass_stars = self._data_source["ParticleMassMsun"][ct > 0]
+        elif self.mode == 'provided':
+            ct_stars = self.star_creation_time
+            mass_stars = self.star_mass
+        # Find the oldest stars in units of code time.
+        tmin= min(ct_stars)
+        # Multiply the end to prevent numerical issues.
+        self.time_bins = na.linspace(tmin*0.99, self._pf['InitialTime'],
+            num = self.bin_count + 1)
+        # Figure out which bins the stars go into.
+        inds = na.digitize(ct_stars, self.time_bins) - 1
+        # Sum up the stars created in each time bin.
+        self.mass_bins = na.zeros(self.bin_count + 1, dtype='float64')
+        for index in na.unique(inds):
+            self.mass_bins[index] += sum(mass_stars[inds == index])
+        # Calculate the cumulative mass sum over time by forward adding.
+        self.cum_mass_bins = self.mass_bins.copy()
+        for index in xrange(self.bin_count):
+            self.cum_mass_bins[index+1] += self.cum_mass_bins[index]
+        # We will want the time taken between bins.
+        self.time_bins_dt = self.time_bins[1:] - self.time_bins[:-1]
+    
+    def write_out(self, name="StarFormationRate.out"):
+        """
+        Write out the star analysis to a text file *name*. The columns are in
+        order:
+        1) Time (yrs)
+        2) Look-back time (yrs)
+        3) Redshift
+        4) Star formation rate in this bin per year (Msol/yr)
+        5) Star formation rate in this bin per year per Mpc**3 (Msol/yr/Mpc**3)
+        6) Stars formed in this time bin (Msol)
+        7) Cumulative stars formed up to this time bin (Msol)
+        """
+        fp = open(name, "w")
+        if self.mode == 'data_source':
+            vol = self._data_source.volume('mpc')
+        elif self.mode == 'provided':
+            vol = self.volume
+        tc = self._pf["Time"]
+        # Use the center of the time_bin, not the left edge.
+        for i, time in enumerate((self.time_bins[1:] + self.time_bins[:-1])/2.):
+            line = "%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\n" % \
+            (time * tc / YEAR, # Time
+            (self.time_now - time * tc)/YEAR, # Lookback time
+            self.cosm.ComputeRedshiftFromTime(time * tc), # Redshift
+            self.mass_bins[i] / (self.time_bins_dt[i] * tc / YEAR), # Msol/yr
+            self.mass_bins[i] / (self.time_bins_dt[i] * tc / YEAR) / vol, # Msol/yr/vol
+            self.mass_bins[i], # Msol in bin
+            self.cum_mass_bins[i]) # cumulative
+            fp.write(line)
+        fp.close()
+
+# Bruzual & Charlot (2003) simple-stellar-population tables (HDF5 versions
+# of the .ised files), keyed by metallicity label. The label spells out the
+# absolute metallicity Z (e.g. "Z0004" is Z = 0.0004); the trailing comments
+# give the same value as a fraction of solar (Z = 0.02).
+# Tables for the Chabrier initial mass function:
+CHABRIER = {
+"Z0001" : "bc2003_hr_m22_chab_ssp.ised.h5", #/* 0.5% */
+"Z0004" : "bc2003_hr_m32_chab_ssp.ised.h5", #/* 2% */
+"Z004" : "bc2003_hr_m42_chab_ssp.ised.h5", #/* 20% */
+"Z008" : "bc2003_hr_m52_chab_ssp.ised.h5", #/* 40% */
+"Z02" : "bc2003_hr_m62_chab_ssp.ised.h5", #/* solar; 0.02 */
+"Z05" : "bc2003_hr_m72_chab_ssp.ised.h5" #/* 250% */
+}
+
+# The same metallicities for the Salpeter initial mass function:
+SALPETER = {
+"Z0001" : "bc2003_hr_m22_salp_ssp.ised.h5", #/* 0.5% */
+"Z0004" : "bc2003_hr_m32_salp_ssp.ised.h5", #/* 2% */
+"Z004" : "bc2003_hr_m42_salp_ssp.ised.h5", #/* 20% */
+"Z008" : "bc2003_hr_m52_salp_ssp.ised.h5", #/* 40% */
+"Z02" : "bc2003_hr_m62_salp_ssp.ised.h5", #/* solar; 0.02 */
+"Z05" : "bc2003_hr_m72_salp_ssp.ised.h5" #/* 250% */
+}
+
+# Solar metallicity (mass fraction); metallicity fractions are divided by
+# this to express them in units of Z/Zsun.
+Zsun = 0.02
+
+#/* dividing line of metallicity; linear in log(Z/Zsun) */
+# Each divider is the geometric mean of two neighbouring table
+# metallicities, i.e. their midpoint in log(Z/Zsun); for example
+# METAL1 = sqrt(0.005 * 0.02) = 0.01 and METAL2 = sqrt(0.02 * 0.2) = 0.0632.
+METAL1 = 0.01  # /* in units of Z/Zsun */
+METAL2 = 0.0632
+METAL3 = 0.2828
+METAL4 = 0.6325
+METAL5 = 1.5811
+METALS = na.array([METAL1, METAL2, METAL3, METAL4, METAL5])
+
+# Translate METALS array digitize to the table dicts
+# na.digitize(Z/Zsun, METALS) yields an index 0..5; MtoD maps it to the
+# matching key of CHABRIER / SALPETER.
+MtoD = na.array(["Z0001", "Z0004", "Z004", "Z008", "Z02",  "Z05"])
+
+"""
+This spectrum code is based on code from Ken Nagamine, converted from C to Python.
+I've also reversed the order of elements in the flux arrays to be in C-ordering,
+for faster memory access.
+"""
+
+class SpectrumBuilder(object):
+    def __init__(self, pf, bcdir="", model="chabrier"):
+        """
+        Initialize the data to build a summed flux spectrum for a
+        collection of stars using the models of Bruzual & Charlot (2003).
+        :param pf (object): Yt pf object.
+        :param bcdir (string): Path to directory containing Bruzual &
+        Charlot h5 fit files.
+        :param model (string): Choice of Initial Metalicity Function model,
+        'chabrier' or 'salpeter'.
+        """
+        self._pf = pf
+        self.bcdir = bcdir
+        
+        if model == "chabrier":
+            self.model = CHABRIER
+        elif model == "salpeter":
+            self.model = SALPETER
+        # Set up for time conversion.
+        self.cosm = lagos.EnzoCosmology(HubbleConstantNow = 
+             (100.0 * self._pf['CosmologyHubbleConstantNow']),
+             OmegaMatterNow = self._pf['CosmologyOmegaMatterNow'],
+             OmegaLambdaNow = self._pf['CosmologyOmegaLambdaNow'],
+             InitialRedshift = self._pf['CosmologyInitialRedshift'])
+        # Find the time right now.
+        self.time_now = self.cosm.ComputeTimeFromRedshift(
+            self._pf["CosmologyCurrentRedshift"]) # seconds
+        
+        # Read the tables.
+        self.read_bclib()
+
+    def read_bclib(self):
+        """
+        Read in the age and wavelength bins, and the flux bins for each
+        metallicity.
+        """
+        self.flux = {}
+        for file in self.model:
+            fname = self.bcdir + "/" + self.model[file]
+            fp = h5py.File(fname, 'r')
+            self.age = fp["agebins"][:] # 1D floats
+            self.wavelength = fp["wavebins"][:] # 1D floats
+            self.flux[file] = fp["flam"][:,:] # 2D floats, [agebin, wavebin]
+            fp.close()
+    
+    def calculate_spectrum(self, data_source=None, star_mass=None,
+            star_creation_time=None, star_metallicity_fraction=None,
+            star_metallicity_constant=None):
+        """
+        For the set of stars, calculate the collective spectrum.
+        Attached to the output are several useful objects:
+        final_spec: The collective spectrum in units of flux binned in wavelength.
+        wavelength: The wavelength for the spectrum bins, in Angstroms.
+        total_mass: Total mass of all the stars.
+        avg_mass: Average mass of all the stars.
+        avg_metal: Average metallicity of all the stars.
+        :param data_source (object): A yt data_source that defines a portion
+        of the volume from which to extract stars.
+        :param star_mass (array, float): An array of star masses in Msun units.
+        :param star_creation_time (array, float): An array of star creation
+        times in code units.
+        :param star_metallicity_fraction (array, float): An array of star
+        metallicity fractions, in code units (which is not Z/Zsun).
+        :param star_metallicity_constant (float): If desired, override the star
+        metallicity fraction of all the stars to the given value.
+        """
+        # Initialize values
+        self.final_spec = na.zeros(self.wavelength.size, dtype='float64')
+        self._data_source = data_source
+        self.star_mass = star_mass
+        self.star_creation_time = star_creation_time
+        self.star_metal = star_metallicity_fraction
+        
+        # Check to make sure we have the right set of data.
+        if data_source is None:
+            if self.star_mass is None or self.star_creation_time is None or \
+            (star_metallicity_fraction is None and star_metallicity_constant is None):
+                mylog.error(
+                """
+                If data_source is not provided, all of these paramters need to be set:
+                star_mass (array, Msun),
+                star_creation_time (array, code units),
+                And one of:
+                star_metallicity_fraction (array, code units).
+                --OR--
+                star_metallicity_constant (float, code units).
+                """)
+                return None
+            if star_metallicity_constant is not None:
+                self.star_metal = na.ones(self.star_mass.size, dtype='float64') * \
+                    star_metallicity_constant
+        else:
+            # Get the data we need.
+            ct = self._data_source["creation_time"]
+            self.star_creation_time = ct[ct > 0]
+            self.star_mass = self._data_source["ParticleMassMsun"][ct > 0]
+            if star_metallicity_constant is not None:
+                self.star_metal = na.ones(self.star_mass.size, dtype='float64') * \
+                    star_metallicity_constant
+            else:
+                self.star_metal = self._data_source["metallicity_fraction"][ct > 0]
+        # Fix metallicity to units of Zsun.
+        self.star_metal /= Zsun
+        # Age of star in years.
+        dt = (self.time_now - self.star_creation_time * self._pf['Time']) / YEAR
+        # Figure out which METALS bin the star goes into.
+        Mindex = na.digitize(dt, METALS)
+        # Replace the indices with strings.
+        Mname = MtoD[Mindex]
+        # Figure out which age bin this star goes into.
+        Aindex = na.digitize(dt, self.age)
+        # Ratios used for the interpolation.
+        ratio1 = (dt - self.age[Aindex-1]) / (self.age[Aindex] - self.age[Aindex-1])
+        ratio2 = (self.age[Aindex] - dt) / (self.age[Aindex] - self.age[Aindex-1])
+        # Sort the stars by metallicity and then by age, which should reduce
+        # memory access time by a little bit in the loop.
+        sort = na.lexsort((Aindex, Mname))
+        Mname = Mname[sort]
+        Aindex = Aindex[sort]
+        ratio1 = ratio1[sort]
+        ratio2 = ratio2[sort]
+        self.star_mass = self.star_mass[sort]
+        self.star_creation_time = self.star_creation_time[sort]
+        self.star_metal = self.star_metal[sort]
+        
+        # Interpolate the flux for each star, adding to the total by weight.
+        for star in itertools.izip(Mname, Aindex, ratio1, ratio2, self.star_mass):
+            # Pick the right age bin for the right flux array.
+            flux = self.flux[star[0]][star[1],:]
+            # Get the one just before the one above.
+            flux_1 = self.flux[star[0]][star[1]-1,:]
+            # interpolate in log(flux), linear in time.
+            int_flux = star[3] * na.log10(flux_1) + star[2] * na.log10(flux)
+            # Add this flux to the total, weighted by mass.
+            self.final_spec += na.power(10., int_flux) * star[4]
+        # Normalize.
+        self.total_mass = sum(self.star_mass)
+        self.avg_mass = na.mean(self.star_mass)
+        tot_metal = sum(self.star_metal * self.star_mass)
+        self.avg_metal = math.log10(tot_metal / self.total_mass / Zsun)
+
+        # Below is an attempt to do the loop using vectors and matrices,
+        # however it doesn't appear to be much faster, probably due to all
+        # the gymnastics that have to be done to do element-by-element
+        # multiplication for matricies.
+        # I'm keeping it in here in case I come up with a more
+        # elegant way that actually is faster.
+#         for metal_name in MtoD:
+#             # Pick out our stars in this metallicity bin.
+#             select = (Mname == metal_name)
+#             A = Aindex[select]
+#             if A.size == 0: continue
+#             r1 = ratio1[select]
+#             r2 = ratio2[select]
+#             sm = self.star_mass[select]
+#             # From the flux array for this metal, and our selection, build
+#             # a new flux array just for the ages of these stars, in the 
+#             # same order as the selection of stars.
+#             this_flux = na.matrix(self.flux[metal_name][A])
+#             # Make one for the last time step for each star in the same fashion
+#             # as above.
+#             this_flux_1 = na.matrix(self.flux[metal_name][A-1])
+#             # This is kind of messy, but we're going to multiply this_fluxes
+#             # by the appropriate ratios and add it together to do the 
+#             # interpolation in log(flux) and linear in time.
+#             print r1.size
+#             r1 = na.matrix(r1.tolist()*self.wavelength.size).reshape(self.wavelength.size,r1.size).T
+#             r2 = na.matrix(r2.tolist()*self.wavelength.size).reshape(self.wavelength.size,r2.size).T
+#             print this_flux_1.shape, r1.shape
+#             int_flux = na.multiply(na.log10(this_flux_1),r1) \
+#                 + na.multiply(na.log10(this_flux),r2)
+#             # Weight the fluxes by mass.
+#             sm = na.matrix(sm.tolist()*self.wavelength.size).reshape(self.wavelength.size,sm.size).T
+#             int_flux = na.multiply(na.power(10., int_flux), sm)
+#             # Sum along the columns, converting back to an array, adding
+#             # to the full spectrum.
+#             self.final_spec += na.array(int_flux.sum(axis=0))[0,:]
+
+    
+    def write_out(self, name="sum_flux.out"):
+        """
+        Write out the summed flux to a file. The file has two columns:
+        1) Wavelength (Angstrom)
+        2) Flux (Luminosity per unit wavelength, L_sun Ang^-1,
+        L_sun = 3.826 * 10^33 ergs s^-1.)
+        :param name (string): Name of file to write to.
+        """
+        fp = open(name, 'w')
+        for i, wave in enumerate(self.wavelength):
+            fp.write("%1.5e\t%1.5e\n" % (wave, self.final_spec[i]))
+        fp.close()
+
+    def write_out_SED(self, name="sum_SED.out", flux_norm=5200.):
+        """
+        Write out the summed SED to a file. The file has two columns:
+        1) Wavelength (Angstrom)
+        2) Relative flux normalized to the flux at *flux_norm*.
+        It also will attach an array *f_nu* which is the normalized flux,
+        identical to the disk output.
+        :param name (string): Name of file to write to.
+        :param flux_norm (float): Wavelength of the flux to normalize the
+        distribution against.
+        """
+        # find the f_nu closest to flux_norm
+        fn_wavelength = na.argmin(abs(self.wavelength - flux_norm))
+        f_nu = self.final_spec * na.power(self.wavelength, 2.) / LIGHT
+        # Normalize f_nu
+        self.f_nu = f_nu / f_nu[fn_wavelength]
+        # Write out.
+        fp = open(name, 'w')
+        for i, wave in enumerate(self.wavelength):
+            fp.write("%1.5e\t%1.5e\n" % (wave, self.f_nu[i]))
+        fp.close()
+

Added: trunk/yt/extensions/kdtree/Makefile
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/Makefile	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,11 @@
+# To build a static library, fKD.f90 must also be on the Forthon compile
+# line: use the first (commented-out) recipe below. For a shared object,
+# leave it off, as the active second recipe does.
+
+
+fKD: fKD.f90 fKD.v fKD_source.f90
+#	Forthon --compile_first fKD_source --no2underscores --with-numpy -g fKD fKD.f90 fKD_source.f90
+	Forthon -F gfortran --compile_first fKD_source --no2underscores --with-numpy --fopt "-O3" fKD fKD_source.f90
+
+clean:
+	rm -rf build fKDpy.a fKDpy.so

Added: trunk/yt/extensions/kdtree/__init__.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/__init__.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,3 @@
+from yt.lagos import *
+
+from fKDpy import *
\ No newline at end of file

Added: trunk/yt/extensions/kdtree/fKD.f90
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD.f90	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,162 @@
+subroutine create_tree()
+    ! Build a kd-tree over the points in the module array pos and keep a
+    ! pointer to it in the module variable tree2. The module flags sort and
+    ! rearrange are passed straight through to kdtree2_create.
+    use kdtree2_module
+    use fKD_module
+    use kdtree2module
+    use tree_nodemodule
+    use intervalmodule
+    
+    ! create a kd tree object
+
+     tree2 => kdtree2_create(pos,sort=sort,rearrange=rearrange)  ! this is how you create a tree. 
+     return
+
+end subroutine create_tree
+
+
+subroutine find_nn_nearest_neighbors()
+     ! Find the nn nearest neighbors in tree2 of the single query point qv.
+     ! On return dist holds their squared distances and tags their indices
+     ! into pos. dist and tags are assumed to be allocated with nn elements
+     ! by the caller -- TODO confirm.
+     use kdtree2_module
+     use fKD_module
+     use kdtree2module
+     use tree_nodemodule
+     use intervalmodule
+
+     ! k is only used by the commented-out debug loop below.
+     integer :: k
+     type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+     !integer, parameter  :: nn ! number of nearest neighbors found
+
+
+     allocate(results(nn)) 
+
+     call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results) 
+
+     ! Copy the (squared distance, index) pairs out to the module arrays.
+     dist = results%dis
+     tags = results%idx
+
+  !do k=1,nn
+  !   print *, "k = ", k, " idx = ", tags(k)," dis = ", dist(k)
+  !   print *, "x y z", pos(1,results(k)%idx), pos(2,results(k)%idx), pos(3,results(k)%idx)
+  !enddo
+
+
+     deallocate(results)
+     return
+
+end subroutine find_nn_nearest_neighbors
+
+subroutine find_all_nn_nearest_neighbors()
+    ! for all particles in pos, find their nearest neighbors and return the
+    ! indexes and distances as big arrays
+    ! Column k of nn_tags / nn_dist holds the neighbor indices / squared
+    ! distances of particle k; each particle appears as its own neighbor.
+    ! nn_tags and nn_dist are assumed to be allocated as (nn, nparts) by the
+    ! caller -- TODO confirm. The module variable qv is overwritten.
+    use kdtree2_module
+    use fKD_module
+    use kdtree2module
+    use tree_nodemodule
+    use intervalmodule
+
+    integer :: k
+    type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+    allocate(results(nn))
+    
+    do k=1,nparts
+        ! Use particle k itself as the query point.
+        qv(:) = pos(:,k)
+        call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+        nn_dist(:,k) = results%dis
+        nn_tags(:,k) = results%idx
+    end do
+    
+    deallocate(results)
+    return
+
+end subroutine find_all_nn_nearest_neighbors
+
+subroutine find_chunk_nearest_neighbors()
+    ! for a chunk of the full number of particles, find their nearest neighbors
+    ! The chunk is particles start..finish (inclusive); the neighbor indices
+    ! of particle k go in column k - start + 1 of chunk_tags, so chunk_tags
+    ! is assumed to be allocated as (nn, finish - start + 1) -- TODO confirm.
+    ! Distances are not returned. The module variable qv is overwritten.
+    use kdtree2_module
+    use fKD_module
+    use kdtree2module
+    use tree_nodemodule
+    use intervalmodule
+
+    integer :: k
+    type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+    allocate(results(nn))
+    do k=start,finish
+        qv(:) = pos(:,k)
+        call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+        chunk_tags(:,k - start + 1) = results%idx
+
+    end do
+    
+    deallocate(results)
+    return
+
+end subroutine find_chunk_nearest_neighbors
+
+subroutine chainHOP_tags_dens()
+    ! for all particles in pos, find their nearest neighbors, and calculate
+    ! their density. Return only nMerge nearest neighbors.
+    ! NOTE(review): the nMerge neighbor recording at the end of the loop is
+    ! commented out, so at present only dens is filled in.
+    !
+    ! The density uses a cubic-spline kernel whose smoothing length h is half
+    ! the distance to the farthest of the nn neighbors. Each neighbor pair
+    ! adds to the density of both particles, so dens is accumulated rather
+    ! than assigned and is assumed to be zeroed by the caller -- TODO confirm.
+    use kdtree2_module
+    use fKD_module
+    use kdtree2module
+    use tree_nodemodule
+    use intervalmodule
+
+    integer :: k, pj, i
+    real :: ih2, fNorm, r2, rs
+    integer, allocatable :: temp_tags(:)
+    real, allocatable :: temp_dist(:)
+    type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+    allocate(results(nn))
+    allocate(temp_tags(nn))
+    allocate(temp_dist(nn))
+    
+    do k=1,nparts
+        qv(:) = pos(:,k)
+        
+        call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+        temp_tags(:) = results%idx
+        temp_dist(:) = results%dis
+        
+        ! calculate the density for this particle
+        ! results%dis are squared distances, so ih2 = 1/h**2 with
+        ! h = (max neighbor distance)/2, and fNorm = 0.5/(pi*h**3); the 0.5
+        ! presumably compensates for each pair being counted for both
+        ! particles.
+        ih2 = 4.0/maxval(results%dis)
+        fNorm = 0.5*sqrt(ih2)*ih2/3.1415926535897931
+        do i=1,nn
+            pj = temp_tags(i)
+            ! r2 = (r/h)**2, and rs starts as 2 - r/h.
+            r2 = temp_dist(i) * ih2
+            rs = 2.0 - sqrt(r2)
+            if (r2 < 1.0) then
+                ! inner part: 1 - 1.5 q**2 + 0.75 q**3 with q = r/h
+                rs = (1.0 - 0.75*rs*r2)
+            else
+                ! outer part: 0.25 (2 - q)**3
+                rs = 0.25*rs*rs*rs
+            end if
+            rs = rs * fNorm
+            ! Symmetric update: particle k gains from pj and vice versa.
+            dens(k) = dens(k) + rs * mass(pj)
+            dens(pj) = dens(pj) + rs * mass(k)
+        end do
+
+        ! record only nMerge nearest neighbors, but skip the first one which
+        ! is always the self-same particle
+        ! nn_tags(:,k) = temp_tags(2:nMerge)
+    end do
+    
+    deallocate(results)
+    deallocate(temp_dist)
+    deallocate(temp_tags)
+    return
+
+end subroutine chainHOP_tags_dens
+
+subroutine free_tree()
+    ! Destroy the kd-tree held in the module variable tree2.
+    use kdtree2_module
+    use fKD_module
+    use kdtree2module
+    use tree_nodemodule
+    use intervalmodule
+    
+    ! this releases memory for the tree BUT NOT THE ARRAY OF DATA YOU PASSED
+    ! TO MAKE THE TREE.  
+    call kdtree2_destroy(tree2)
+    
+    ! The data to make the tree has to be deleted in python BEFORE calling
+    ! this!
+end subroutine free_tree
+

Added: trunk/yt/extensions/kdtree/fKD.v
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD.v	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,101 @@
+fKD
+
+****** fKD_module vars:
+# Not all of these are being used, but they only take memory
+# if they're initialized in python.
+tags(:) _integer # indices (into pos) of the nn neighbors of qv
+dist(:) _real # squared distances to the nn neighbors of qv
+nn_tags(:,:) _integer # for all particles at once, [nth neighbor, index]
+chunk_tags(:,:) _integer # for finding only a chunk of the nearest neighbors
+# squared neighbor distances for all particles, [nth neighbor, particle]
+nn_dist(:,:) _real 
+# particle positions, three coordinates per particle
+pos(3,:) _real
+# per-particle density, accumulated by chainHOP_tags_dens
+dens(:) _real
+# per-particle mass, used by chainHOP_tags_dens
+mass(:) _real
+# query point; also overwritten by the all/chunk/density loops
+qv(3) real
+# number of particles in pos to loop over
+nparts integer
+# number of nearest neighbors to find per query point
+nn integer
+nMerge integer # number of nearest neighbors used in chain merging
+# inclusive particle range used by find_chunk_nearest_neighbors
+start integer
+finish integer
+# the tree built by create_tree and released by free_tree
+tree2 _kdtree2
+# options passed to kdtree2_create by create_tree
+sort logical /.false./
+rearrange logical /.true./
+
+# The derived types below describe the kd-tree internals to Forthon. The
+# commented Fortran lines show the corresponding kdtree2 declarations, which
+# these entries presumably need to mirror -- TODO confirm against
+# fKD_source.f90.
+%%%% interval:
+lower real
+upper real
+#real(kdkind) :: lower,upper
+
+
+%%%% tree_node:
+# an internal tree node
+cut_dim integer
+#integer :: cut_dim
+# the dimension to cut
+cut_val real
+#real(kdkind) :: cut_val
+# where to cut the dimension
+cut_val_left real
+cut_val_right real
+#real(kdkind) :: cut_val_left, cut_val_right  
+# improved cutoffs knowing the spread in child boxes.
+u integer
+l integer
+#integer :: l, u
+left _tree_node
+right _tree_node
+#type(tree_node), pointer :: left, right
+box(:) _interval
+#type(interval), pointer :: box(:) => null()
+# child pointers
+# Points included in this node are indexes[k] with k \in [l,u] 
+
+
+%%%% kdtree2:
+# Global information about the tree, one per tree
+dimen integer /0/
+n integer /0/
+# dimensionality and total # of points
+the_data(:,:) _real
+#real(kdkind), pointer :: the_data(:,:) => null()
+# pointer to the actual data array 
+# 
+#  IMPORTANT NOTE:  IT IS DIMENSIONED   the_data(1:d,1:N)
+#  which may be opposite of what may be conventional.
+#  This is, because in Fortran, the memory layout is such that
+#  the first dimension is in sequential order.  Hence, with
+#  (1:d,1:N), all components of the vector will be in consecutive
+#  memory locations.  The search time is dominated by the
+#  evaluation of distances in the terminal nodes.  Putting all
+#  vector components in consecutive memory location improves
+#  memory cache locality, and hence search speed, and may enable 
+#  vectorization on some processors and compilers. 
+ind(:) _integer
+#integer, pointer :: ind(:) => null()
+# permuted index into the data, so that indexes[l..u] of some
+# bucket represent the indexes of the actual points in that
+# bucket.
+# do we always sort output results?
+sort logical /.false./
+#logical       :: sort = .false.
+rearrange logical /.false./
+#logical       :: rearrange = .false. 
+rearranged_data(:,:) _real
+#real(kdkind), pointer :: rearranged_data(:,:) => null()
+# if (rearrange .eqv. .true.) then rearranged_data has been
+# created so that rearranged_data(:,i) = the_data(:,ind(i)),
+# permitting search to use more cache-friendly rearranged_data, at
+# some initial computation and storage cost.
+root _tree_node
+#type(tree_node), pointer :: root => null()
+# root pointer of the tree
+
+
+# Routines exposed to Python; each is defined in fKD.f90.
+***** Subroutines:
+find_nn_nearest_neighbors subroutine
+create_tree() subroutine
+free_tree() subroutine
+find_all_nn_nearest_neighbors subroutine
+find_chunk_nearest_neighbors subroutine
+chainHOP_tags_dens subroutine

Added: trunk/yt/extensions/kdtree/fKD_source.f90
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD_source.f90	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,1955 @@
+!
+!(c) Matthew Kennel, Institute for Nonlinear Science (2004)
+!
+! Licensed under the Academic Free License version 1.1 found in file LICENSE
+! with additional provisions found in that same file.
+!
+! Modified by Stephen Skory, CASS/UCSD (2009), adding periodicity and changes
+! so this can be used by as a Python module using Forthon
+
+module kdtree2_precision_module
+  
+  integer, parameter :: sp = kind(0.0)
+  integer, parameter :: dp = kind(0.0d0)
+
+  private :: sp, dp
+
+  !
+  ! Exactly one of the two kdkind lines below must be active; comment out
+  ! the other. Keeping kdkind = sp (commenting out kdkind = dp) gives single
+  ! precision; keeping kdkind = dp (commenting out kdkind = sp) gives double
+  ! precision. As shipped, single precision is selected.
+  ! NOTE(review): real arrays handed over from fKD.v (e.g. pos) are passed
+  ! to kdtree2 routines, so their kind presumably has to match kdkind --
+  ! confirm against the Forthon real size in use.
+  !
+
+  integer, parameter :: kdkind = sp  
+  !integer, parameter :: kdkind = dp  
+  public :: kdkind
+
+end module kdtree2_precision_module
+
+module kdtree2_priority_queue_module
+  use kdtree2_precision_module
+  !
+  ! maintain a priority queue (PQ) of data, pairs of 'priority/payload', 
+  ! implemented with a binary heap.  This is the type, and the 'dis' field
+  ! is the priority.
+  !
+  type kdtree2_result
+      ! a pair of distances, indexes
+      real(kdkind)    :: dis!=0.0
+      integer :: idx!=-1   Initializers cause some bugs in compilers.
+  end type kdtree2_result
+  !
+  ! A heap-based priority queue lets one efficiently implement the following
+  ! operations, each in log(N) time, as opposed to linear time.
+  !
+  ! 1)  add a datum (push a datum onto the queue, increasing its length) 
+  ! 2)  return the priority value of the maximum priority element 
+  ! 3)  pop-off (and delete) the element with the maximum priority, decreasing
+  !     the size of the queue. 
+  ! 4)  replace the datum with the maximum priority with a supplied datum
+  !     (of either higher or lower priority), maintaining the size of the
+  !     queue. 
+  !
+  !
+  ! In the k-d tree case, the 'priority' is the square distance of a point in
+  ! the data set to a reference point.   The goal is to keep the smallest M
+  ! distances to a reference point.  The tree algorithm searches terminal
+  ! nodes to decide whether to add points under consideration.
+  !
+  ! A priority queue is useful here because it lets one quickly return the
+  ! largest distance currently existing in the list.  If a new candidate
+  ! distance is smaller than this, then the new candidate ought to replace
+  ! the old candidate.  In priority queue terms, this means removing the
+  ! highest priority element, and inserting the new one.
+  !
+  ! Algorithms based on Cormen, Leiserson, Rivest, _Introduction
+  ! to Algorithms_, 1990, with further optimization by the author.
+  !
+  ! Originally informed by a C implementation by Sriranga Veeraraghavan.
+  !
+  ! This module is not written in the most clear way, but is implemented such
+  ! for speed, as it its operations will be called many times during searches
+  ! of large numbers of neighbors.
+  !
+  type pq
+      !
+      ! The priority queue consists of elements
+      ! priority(1:heap_size), with associated payload(:).
+      !
+      ! There are heap_size active elements. 
+      ! Assumes the allocation is always sufficient.  Will NOT increase it
+      ! to match.
+      integer :: heap_size = 0
+      type(kdtree2_result), pointer :: elems(:) 
+  end type pq
+
+  public :: kdtree2_result
+
+  public :: pq
+  public :: pq_create
+  public :: pq_delete, pq_insert
+  public :: pq_extract_max, pq_max, pq_replace_max, pq_maxpri
+  private
+
+contains
+
+
+  function pq_create(results_in) result(res)
+    !
+    ! Create a priority queue from ALREADY allocated
+    ! array pointers for storage.  NOTE! It will NOT
+    ! add any alements to the heap, i.e. any existing
+    ! data in the input arrays will NOT be used and may
+    ! be overwritten.
+    ! 
+    ! usage:
+    !    real(kdkind), pointer :: x(:)
+    !    integer, pointer :: k(:)
+    !    allocate(x(1000),k(1000))
+    !    pq => pq_create(x,k)
+    !
+    type(kdtree2_result), target:: results_in(:) 
+    type(pq) :: res
+    !
+    !
+    integer :: nalloc
+
+    nalloc = size(results_in,1)
+    if (nalloc .lt. 1) then
+       write (*,*) 'PQ_CREATE: error, input arrays must be allocated.'
+    end if
+    res%elems => results_in
+    res%heap_size = 0
+    return
+  end function pq_create
+
+  !
+  ! operations for getting parents and left + right children
+  ! of elements in a binary heap.
+  !
+
+!
+! These are written inline for speed.
+!    
+!  integer function parent(i)
+!    integer, intent(in) :: i
+!    parent = (i/2)
+!    return
+!  end function parent
+
+!  integer function left(i)
+!    integer, intent(in) ::i
+!    left = (2*i)
+!    return
+!  end function left
+
+!  integer function right(i)
+!    integer, intent(in) :: i
+!    right = (2*i)+1
+!    return
+!  end function right
+
+!  logical function compare_priority(p1,p2)
+!    real(kdkind), intent(in) :: p1, p2
+!
+!    compare_priority = (p1 .gt. p2)
+!    return
+!  end function compare_priority
+
+  subroutine heapify(a,i_in)
+    !
+    ! Sift the element at index 'i_in' down until the subtree rooted
+    ! there is a max-heap on the 'dis' field (each parent's dis >= its
+    ! children's).  Both child subtrees of 'i_in' are assumed to be
+    ! heaps already.  This is performance critical and has been
+    ! tweaked a little to reflect this.
+    !
+    ! a    : queue to repair (elements are swapped in place)
+    ! i_in : 1-based index of the subtree root; children of i are 2i, 2i+1
+    !
+    type(pq),pointer   :: a
+    integer, intent(in) :: i_in
+    !
+    integer :: i, l, r, largest
+
+    real(kdkind)    :: pri_i, pri_l, pri_r, pri_largest
+
+
+    type(kdtree2_result) :: temp
+
+    i = i_in
+
+bigloop:  do
+       l = 2*i ! left(i)
+       r = l+1 ! right(i)
+       ! 
+       ! set 'largest' to the index of either i, l, r
+       ! depending on whose priority is largest.
+       !
+       ! note that l or r can be larger than the heap size
+       ! in which case they do not count.
+
+
+       ! does left child have higher priority? 
+       if (l .gt. a%heap_size) then
+          ! we know that i is the largest as both l and r are invalid.
+          exit 
+       else
+          pri_i = a%elems(i)%dis
+          pri_l = a%elems(l)%dis 
+          if (pri_l .gt. pri_i) then
+             largest = l
+             pri_largest = pri_l
+          else
+             largest = i
+             pri_largest = pri_i
+          endif
+
+          !
+          ! between i and l we have a winner
+          ! now choose between that and r.
+          !
+          if (r .le. a%heap_size) then
+             pri_r = a%elems(r)%dis
+             if (pri_r .gt. pri_largest) then
+                largest = r
+             endif
+          endif
+       endif
+
+       if (largest .ne. i) then
+          ! swap data in nodes largest and i, then heapify
+
+          temp = a%elems(i)
+          a%elems(i) = a%elems(largest)
+          a%elems(largest) = temp 
+          ! 
+          ! Canonical heapify() algorithm has tail-recursive call: 
+          !
+          !        call heapify(a,largest)   
+          ! we will simulate with cycle
+          !
+          i = largest
+          cycle bigloop ! continue the loop 
+       else
+          return   ! heap property holds at i: done
+       end if
+    enddo bigloop
+    return
+  end subroutine heapify
+
+  subroutine pq_max(a,e) 
+    !
+    ! Copy the top (maximum-'dis') element of the heap into 'e' without
+    ! removing it.  Stops the program if the queue is empty.
+    !
+    type(pq),pointer :: a
+    type(kdtree2_result),intent(out)  :: e
+
+    if (a%heap_size .le. 0) then
+       write (*,*) 'PQ_MAX: ERROR, heap_size < 1'
+       stop
+    end if
+    e = a%elems(1)
+  end subroutine pq_max
+  
+  real(kdkind) function pq_maxpri(a)
+    !
+    ! Return the priority ('dis') of the top element of the heap.
+    ! Stops the program if the queue is empty.
+    !
+    type(pq), pointer :: a
+
+    if (a%heap_size .lt. 1) then
+       write (*,*) 'PQ_MAX_PRI: ERROR, heapsize < 1'
+       stop
+    end if
+    pq_maxpri = a%elems(1)%dis
+  end function pq_maxpri
+
+  subroutine pq_extract_max(a,e)
+    !
+    ! Pop the top (maximum-'dis') element: copy it into 'e', move the
+    ! last heap element into the root slot, shrink the heap by one and
+    ! sift the new root down.  Stops the program if the queue is empty.
+    !
+    type(pq),pointer :: a
+    type(kdtree2_result), intent(out) :: e
+    integer :: last
+
+    if (a%heap_size .lt. 1) then
+       write (*,*) 'PQ_EXTRACT_MAX: error, attempted to pop non-positive PQ'
+       stop
+    end if
+
+    last = a%heap_size
+    e = a%elems(1)
+    a%elems(1) = a%elems(last)
+    a%heap_size = last - 1
+    call heapify(a,1)
+  end subroutine pq_extract_max
+
+
+  real(kdkind) function pq_insert(a,dis,idx) 
+    !
+    ! Insert the pair (dis,idx) into the max-heap and return the new
+    ! maximum priority, which may or may not be the same as the old one.
+    !
+    ! A hole is opened at the end of the heap and moved up past every
+    ! parent whose 'dis' is smaller; the new pair is written once into
+    ! the hole's final position.
+    !
+    ! Stops the program if the backing array handed to pq_create has no
+    ! free slot left, instead of writing past its end.
+    !
+    type(pq),pointer  :: a
+    real(kdkind), intent(in) :: dis
+    integer, intent(in) :: idx
+    !
+    integer :: i, isparent
+    real(kdkind)    :: parentdis
+    !
+
+    if (a%heap_size .ge. size(a%elems,1)) then
+       write (*,*) 'PQ_INSERT: error, attempt made to insert element on full PQ'
+       stop
+    endif
+
+    a%heap_size = a%heap_size + 1
+    i = a%heap_size
+
+    do while (i .gt. 1)
+       isparent = i/2
+       parentdis = a%elems(isparent)%dis
+       if (dis .gt. parentdis) then
+          ! move what was in i's parent into i.
+          a%elems(i)%dis = parentdis
+          a%elems(i)%idx = a%elems(isparent)%idx
+          i = isparent
+       else
+          exit
+       endif
+    end do
+
+    ! insert the element at the determined position
+    a%elems(i)%dis = dis
+    a%elems(i)%idx = idx
+
+    pq_insert = a%elems(1)%dis 
+    return
+
+  end function pq_insert
+
+  subroutine pq_adjust_heap(a,i)
+    !
+    ! Sift the element at index 'i' down into its correct place.
+    ! The subtrees rooted at the children of 'i' (2i and 2i+1) must
+    ! already be heaps; when the procedure completes, the subtree rooted
+    ! at 'i' is a heap.  Unlike heapify, the element is held in 'e' and
+    ! written once at its final slot instead of being swapped repeatedly.
+    !
+    ! NOTE(review): this routine is not in the module's public list and
+    ! no caller appears in this part of the file; confirm it is used.
+    !
+    type(pq),pointer  :: a
+    integer, intent(in) :: i
+    !
+    real(kdkind) :: prichild
+    integer :: parent, child, N
+
+    type(kdtree2_result) :: e
+
+    e = a%elems(i) 
+
+    parent = i
+    child = 2*i
+    N = a%heap_size
+    
+    do while (child .le. N)
+       ! pick the larger of the two children (if the right one exists)
+       if (child .lt. N) then
+          if (a%elems(child)%dis .lt. a%elems(child+1)%dis) then
+             child = child+1
+          endif
+       endif
+       prichild = a%elems(child)%dis
+       if (e%dis .ge. prichild) then
+          exit 
+       else
+          ! move child into parent.
+          a%elems(parent) = a%elems(child) 
+          parent = child
+          child = 2*parent
+       end if
+    end do
+    a%elems(parent) = e
+    return
+  end subroutine pq_adjust_heap
+    
+
+  real(kdkind) function pq_replace_max(a,dis,idx) 
+    !
+    ! Replace the extant maximum priority element
+    ! in the PQ with (dis,idx).  Return
+    ! the new maximum priority, which may be larger
+    ! or smaller than the old one.
+    !
+    ! The new pair conceptually starts at the root; the larger child is
+    ! moved up while it beats 'dis', and the pair is written once into
+    ! the final hole.  The heap size is unchanged.
+    !
+    ! NOTE(review): when heap_size < 1 the pair is written to elems(1)
+    ! and returned, but heap_size stays 0, so it is NOT part of the queue.
+    ! Callers presumably only use this on a non-empty queue; confirm.
+    !
+    type(pq),pointer         :: a
+    real(kdkind), intent(in) :: dis
+    integer, intent(in) :: idx
+!    type(kdtree2_result), intent(in) :: e
+    ! not tested as well!
+
+    integer :: parent, child, N
+    real(kdkind)    :: prichild, prichildp1
+
+    type(kdtree2_result) :: etmp
+    
+    ! the .true. branch is always taken; the else branch is a reference
+    ! implementation that is never executed.
+    if (.true.) then
+       N=a%heap_size
+       if (N .ge. 1) then
+          parent =1
+          child=2
+
+          loop: do while (child .le. N)
+             prichild = a%elems(child)%dis
+
+             !
+             ! possibly child+1 has higher priority, and if
+             ! so, get it, and increment child.
+             !
+
+             if (child .lt. N) then
+                prichildp1 = a%elems(child+1)%dis
+                if (prichild .lt. prichildp1) then
+                   child = child+1
+                   prichild = prichildp1
+                endif
+             endif
+
+             if (dis .ge. prichild) then
+                exit loop  
+                ! we have a proper place for our new element, 
+                ! bigger than either children's priority.
+             else
+                ! move child into parent.
+                a%elems(parent) = a%elems(child) 
+                parent = child
+                child = 2*parent
+             end if
+          end do loop
+          a%elems(parent)%dis = dis
+          a%elems(parent)%idx = idx
+          pq_replace_max = a%elems(1)%dis
+       else
+          a%elems(1)%dis = dis
+          a%elems(1)%idx = idx
+          pq_replace_max = dis
+       endif
+    else
+       !
+       ! slower version using elementary pop and push operations.
+       !
+       call pq_extract_max(a,etmp) 
+       etmp%dis = dis
+       etmp%idx = idx
+       pq_replace_max = pq_insert(a,dis,idx)
+    endif
+    return
+  end function pq_replace_max
+
+  subroutine pq_delete(a,i)
+    ! 
+    ! Delete the heap element stored at slot 'i' (a 1-based position in
+    ! a%elems, not a data-point index).
+    !
+    ! The last heap element is moved into slot 'i' and the heap is
+    ! shortened by one.  The moved element may be larger than its new
+    ! parent OR smaller than its new children, so it is first sifted up
+    ! and then sifted down; sifting down alone can leave the heap
+    ! out of order.
+    !
+    ! Stops the program if 'i' is outside 1..heap_size.
+    !
+    type(pq),pointer :: a
+    integer           :: i
+    !
+    integer :: j, jparent
+    type(kdtree2_result) :: temp
+
+    if ((i .lt. 1) .or. (i .gt. a%heap_size)) then
+       write (*,*) 'PQ_DELETE: error, attempt to remove out of bounds element.'
+       stop
+    endif
+
+    ! overwrite slot i with the last element and shorten heap by one.
+    a%elems(i) = a%elems(a%heap_size) 
+    a%heap_size = a%heap_size - 1
+
+    ! slot i was the last one: nothing left to reorder.
+    if (i .gt. a%heap_size) return
+
+    ! sift up while the moved element has higher priority than its parent.
+    j = i
+    do while (j .gt. 1)
+       jparent = j/2
+       if (a%elems(j)%dis .le. a%elems(jparent)%dis) exit
+       temp = a%elems(j)
+       a%elems(j) = a%elems(jparent)
+       a%elems(jparent) = temp
+       j = jparent
+    end do
+
+    ! sift down.  If the element moved up, slot i now holds its former
+    ! parent, which already dominates that subtree, so this is a no-op.
+    call heapify(a,i)
+
+  end subroutine pq_delete
+
+end module kdtree2_priority_queue_module
+
+
+module kdtree2_module
+  use kdtree2_precision_module
+  use kdtree2_priority_queue_module
+  use kdtree2module
+  use tree_nodemodule
+  use intervalmodule
+  ! K-D tree routines in Fortran 90 by Matt Kennel.
+  ! Original program was written in Sather by Steve Omohundro and
+  ! Matt Kennel.  Only the Euclidean metric is supported. 
+  !
+  !
+  ! This module is identical to 'kd_tree', except that the order
+  ! of subscripts is reversed in the data file.
+  ! In other words, for an embedding of N D-dimensional vectors, the
+  ! data file is here, in natural Fortran order  data(1:D, 1:N)
+  ! because Fortran lays out columns first,
+  !
+  ! whereas conventionally (C-style) it is data(1:N,1:D)
+  ! as in the original kd_tree module. 
+  !
+  !-------------DATA TYPE, CREATION, DELETION---------------------
+  public :: kdkind
+  public :: kdtree2, kdtree2_result, tree_node, kdtree2_create, kdtree2_destroy
+  !---------------------------------------------------------------
+  !-------------------SEARCH ROUTINES-----------------------------
+  public :: kdtree2_n_nearest,kdtree2_n_nearest_around_point
+  ! Return fixed number of nearest neighbors around arbitrary vector,
+  ! or extant point in dataset, with decorrelation window. 
+  !
+  public :: kdtree2_r_nearest, kdtree2_r_nearest_around_point
+  ! Return points within a fixed ball of arb vector/extant point 
+  !
+  public :: kdtree2_sort_results
+  ! Sort, in order of increasing distance, results from above.
+  !
+  public :: kdtree2_r_count, kdtree2_r_count_around_point 
+  ! Count points within a fixed ball of arb vector/extant point 
+  !
+  public :: kdtree2_n_nearest_brute_force, kdtree2_r_nearest_brute_force
+  ! brute force of kdtree2_[n|r]_nearest
+  !----------------------------------------------------------------
+
+
+  integer, parameter :: bucket_size = 65
+  ! Terminal-node size limit: a range of points is split only when u-l > bucket_size, so a terminal node holds at most bucket_size+1 points.
+
+!  type interval
+!      real(kdkind) :: lower,upper
+!  end type interval
+!
+!  type :: tree_node
+!      ! an internal tree node
+!      private
+!      integer :: cut_dim
+!      ! the dimension to cut
+!      real(kdkind) :: cut_val
+!      ! where to cut the dimension
+!      real(kdkind) :: cut_val_left, cut_val_right  
+!      ! improved cutoffs knowing the spread in child boxes.
+!      integer :: l, u
+!      type(tree_node), pointer :: left, right
+!      type(interval), pointer :: box(:) => null()
+!      ! child pointers
+!      ! Points included in this node are indexes[k] with k \in [l,u] 
+!
+!
+!  end type tree_node
+!
+!  type :: kdtree2
+!      ! Global information about the tree, one per tree
+!      integer :: dimen=0, n=0
+!      ! dimensionality and total # of points
+!      real(kdkind), pointer :: the_data(:,:) => null()
+!      ! pointer to the actual data array 
+!      ! 
+!      !  IMPORTANT NOTE:  IT IS DIMENSIONED   the_data(1:d,1:N)
+!      !  which may be opposite of what may be conventional.
+!      !  This is, because in Fortran, the memory layout is such that
+!      !  the first dimension is in sequential order.  Hence, with
+!      !  (1:d,1:N), all components of the vector will be in consecutive
+!      !  memory locations.  The search time is dominated by the
+!      !  evaluation of distances in the terminal nodes.  Putting all
+!      !  vector components in consecutive memory location improves
+!      !  memory cache locality, and hence search speed, and may enable 
+!      !  vectorization on some processors and compilers. 
+!
+!      integer, pointer :: ind(:) => null()
+!      ! permuted index into the data, so that indexes[l..u] of some
+!      ! bucket represent the indexes of the actual points in that
+!      ! bucket.
+!      logical       :: sort = .false.
+!      ! do we always sort output results?
+!      logical       :: rearrange = .false. 
+!      real(kdkind), pointer :: rearranged_data(:,:) => null()
+!      ! if (rearrange .eqv. .true.) then rearranged_data has been
+!      ! created so that rearranged_data(:,i) = the_data(:,ind(i)),
+!      ! permitting search to use more cache-friendly rearranged_data, at
+!      ! some initial computation and storage cost.
+!      type(tree_node), pointer :: root => null()
+!      ! root pointer of the tree
+!  end type kdtree2
+
+  type :: tree_search_record
+      !
+      ! State for one search.  The module keeps a single SAVEd instance
+      ! of this type ('sr'), which each kdtree2_* query routine fills in
+      ! before walking the tree; overlapping searches would therefore
+      ! share (and clobber) this state.
+      !
+      private
+      ! 
+      ! Many fields are copied from the tree structure, in order to
+      ! speed up the search.
+      !
+      integer           :: dimen   ! number of coordinates compared
+      integer           :: nn, nfound
+      ! nn: neighbours requested (0 flags a fixed-radius search);
+      ! nfound: neighbours found so far
+      real(kdkind)      :: ballsize
+      ! search radius; a SQUARED distance (r2) in the fixed-radius searches
+      integer           :: centeridx=999, correltime=9999
+      ! exclude points within 'correltime' of 'centeridx', iff centeridx >= 0
+      ! (query routines set centeridx=-1 for an arbitrary query vector)
+      integer           :: nalloc  ! how much allocated for results(:)?
+      logical           :: rearrange  ! are the data rearranged or original? 
+      ! did the # of points found overflow the storage provided?
+      logical           :: overflow
+      real(kdkind), pointer :: qv(:)  ! query vector
+      type(kdtree2_result), pointer :: results(:) ! results (caller's array)
+      type(pq) :: pq   ! priority queue whose storage is results(:)
+      real(kdkind), pointer :: data(:,:)  ! temp pointer to data
+      integer, pointer      :: ind(:)     ! temp pointer to indexes
+  end type tree_search_record
+
+  private
+  ! everything else is private.
+
+  type(tree_search_record), save, target :: sr   ! A GLOBAL VARIABLE for search
+
+contains
+
+  function kdtree2_create(input_data,dim,sort,rearrange) result (mr)
+    !
+    ! Create the actual tree structure, given an input array of data.
+    !
+    ! Note, input data is input_data(1:d,1:N), NOT the other way around.
+    ! THIS IS THE REVERSE OF THE PREVIOUS VERSION OF THIS MODULE.
+    ! The reason for it is cache friendliness, improving performance.
+    !
+    ! The tree stores a POINTER to input_data (no copy is made), so the
+    ! actual argument should have the TARGET or POINTER attribute and
+    ! must outlive the tree.
+    !
+    ! Optional arguments:  If 'dim' is specified, then the tree
+    !                      will only search the first 'dim' components
+    !                      of input_data (1 <= dim <= SIZE(input_data,1)),
+    !                      otherwise, dim is inferred from
+    !                      SIZE(input_data,1).
+    !
+    !                      if sort .eqv. .true. then output results
+    !                      will be sorted by increasing distance.
+    !                      default=.false., as it is faster to not sort.
+    !                      
+    !                      if rearrange .eqv. .true. then an internal
+    !                      copy of the data, rearranged by terminal node,
+    !                      will be made for cache friendliness. 
+    !                      default=.true., as it speeds searches, but
+    !                      building takes longer, and extra memory is used.
+    !
+    ! Stops the program if 'dim' is out of range or if D > N.
+    !
+    type(kdtree2), pointer :: mr
+    integer, intent(in), optional      :: dim
+    logical, intent(in), optional      :: sort
+    logical, intent(in), optional      :: rearrange
+    real(kdkind), target :: input_data(:,:)
+    !
+    integer :: i
+
+    allocate (mr)
+    mr%the_data => input_data   ! pointer assignment, no copy
+
+    if (present(dim)) then
+       if ((dim .lt. 1) .or. (dim .gt. size(input_data,1))) then
+          write (*,*) 'KD_TREE_TRANS: dim=',dim,' must be between 1 and ', &
+               size(input_data,1)
+          stop
+       end if
+       mr%dimen = dim
+    else
+       mr%dimen = size(input_data,1)
+    end if
+    mr%n = size(input_data,2)
+
+    if (mr%dimen > mr%n) then
+       !  unlikely to be correct
+       write (*,*) 'KD_TREE_TRANS: likely user error.'
+       write (*,*) 'KD_TREE_TRANS: You passed in matrix with D=',mr%dimen
+       write (*,*) 'KD_TREE_TRANS: and N=',mr%n
+       write (*,*) 'KD_TREE_TRANS: note, that new format is data(1:D,1:N)'
+       write (*,*) 'KD_TREE_TRANS: with usually N >> D.   If N =approx= D, then a k-d tree'
+       write (*,*) 'KD_TREE_TRANS: is not an appropriate data structure.'
+       stop
+    end if
+
+    call build_tree(mr)
+
+    if (present(sort)) then
+       mr%sort = sort
+    else
+       mr%sort = .false.
+    endif
+
+    if (present(rearrange)) then
+       mr%rearrange = rearrange
+    else
+       mr%rearrange = .true.
+    endif
+
+    if (mr%rearrange) then
+       ! copy only the 'dimen' searched coordinates, in leaf order, so the
+       ! shapes conform even when dim < SIZE(input_data,1).
+       allocate(mr%rearranged_data(mr%dimen,mr%n))
+       do i=1,mr%n
+          mr%rearranged_data(:,i) = mr%the_data(1:mr%dimen, mr%ind(i))
+       enddo
+    else
+       nullify(mr%rearranged_data)
+    endif
+
+  end function kdtree2_create
+
+    subroutine build_tree(tp)
+      !
+      ! Set tp%ind to the identity permutation (1,2,...,n) and build the
+      ! whole tree over index range 1..n, storing its root in tp%root.
+      ! The root has no parent, so a null parent pointer is passed.
+      !
+      type(kdtree2), pointer :: tp
+      ! ..
+      integer :: k
+      type(tree_node), pointer :: no_parent
+      ! ..
+      nullify(no_parent)
+      allocate (tp%ind(tp%n))
+      tp%ind = (/ (k, k=1,tp%n) /)
+      tp%root => build_tree_for_range(tp,1,tp%n,no_parent)
+    end subroutine build_tree
+
+    recursive function build_tree_for_range(tp,l,u,parent) result (res)
+      !
+      ! Recursively build the subtree holding points tp%ind(l:u) and
+      ! return a pointer to its root node (a null pointer if u < l).
+      !
+      ! A range with u-l <= bucket_size becomes a terminal node with an
+      ! exact bounding box.  A larger range is split on the coordinate c
+      ! with the widest (approximate) spread, at the arithmetic mean of
+      ! coordinate c.  tp%ind(l:u) is permuted in place so the left child
+      ! gets the points <= the cut.
+      !
+      ! If the mean split leaves every point on one side (e.g. many
+      ! duplicate points, or an inherited box that overestimates the real
+      ! spread), the range is split at its median instead.  Each child
+      ! range is then strictly smaller, so the recursion terminates.
+      !
+      ! tp     : tree being built (tp%ind is permuted in place)
+      ! l, u   : inclusive index range into tp%ind
+      ! parent : node whose child this is, or a null pointer for the root
+      !
+      type(tree_node), pointer :: res
+      type(kdtree2), pointer :: tp
+      type(tree_node),pointer           :: parent
+      integer, intent (In) :: l, u
+      !
+      integer :: i, c, m, dimen
+      logical :: recompute
+      real(kdkind)    :: average
+
+      ! no points in this range.  Checked before any allocation so that
+      ! nothing is leaked.
+      if ( u < l ) then
+         nullify(res)
+         return
+      end if
+
+      dimen = tp%dimen
+      allocate (res)
+      allocate(res%box(dimen))
+
+      if ((u-l)<=bucket_size) then
+         !
+         ! always compute true bounding box for terminal nodes.
+         !
+         do i=1,dimen
+            call spread_in_coordinate(tp,i,l,u,res%box(i))
+         end do
+         res%cut_dim = 0
+         res%cut_val = 0.0
+         res%l = l
+         res%u = u
+         res%left =>null()
+         res%right => null() 
+      else
+         ! 
+         ! Approximate bounding box: only the dimension the parent split
+         ! on is recomputed; the others are inherited from the parent, so
+         ! the box may overestimate the true one.  The box is not used
+         ! for searching, only for choosing the split coordinate, and an
+         ! exact box would make building much slower.
+         !
+         do i=1,dimen
+            recompute=.true.
+            if (associated(parent)) then
+               if (i .ne. parent%cut_dim) then
+                  recompute=.false.
+               end if
+            endif
+            if (recompute) then
+               call spread_in_coordinate(tp,i,l,u,res%box(i))
+            else
+               res%box(i) = parent%box(i)
+            endif
+         end do
+
+         ! c is the coordinate with the greatest (approximate) spread.
+         c = maxloc(res%box(1:dimen)%upper-res%box(1:dimen)%lower,1)
+
+         ! split at the actual arithmetic average of coordinate c.
+         average = sum(tp%the_data(c,tp%ind(l:u))) / real(u-l+1,kdkind)
+         res%cut_val = average
+         m = select_on_coordinate_value(tp%the_data,tp%ind,c,average,l,u)
+
+         if ((m .lt. l) .or. (m .ge. u)) then
+            ! degenerate split (one side empty): recursing on the whole
+            ! range again would never terminate, so split at the median.
+            m = (l+u)/2
+            call select_on_coordinate(tp%the_data,tp%ind,c,m,l,u)
+         endif
+
+         res%cut_dim = c
+         res%l = l
+         res%u = u
+
+         res%left => build_tree_for_range(tp,l,m,res)
+         res%right => build_tree_for_range(tp,m+1,u,res)
+
+         if (associated(res%right) .eqv. .false.) then
+            res%box = res%left%box
+            res%cut_val_left = res%left%box(c)%upper
+            res%cut_val = res%cut_val_left
+         elseif (associated(res%left) .eqv. .false.) then
+            res%box = res%right%box
+            res%cut_val_right = res%right%box(c)%lower
+            res%cut_val = res%cut_val_right
+         else
+            res%cut_val_right = res%right%box(c)%lower
+            res%cut_val_left = res%left%box(c)%upper
+            res%cut_val = (res%cut_val_left + res%cut_val_right)/2
+
+            ! now remake the true bounding box for self as the union of
+            ! the children's boxes, which is much faster than an
+            ! exhaustive pass over all points.
+            res%box%upper = max(res%left%box%upper,res%right%box%upper)
+            res%box%lower = min(res%left%box%lower,res%right%box%lower) 
+         endif
+      end if
+    end function build_tree_for_range
+
+    integer function select_on_coordinate_value(v,ind,c,alpha,li,ui) &
+     result(res)
+      ! Partition ind(li:ui) in place by coordinate c: afterwards every
+      ! point with v(c,.) <= alpha comes first, then every point with
+      ! v(c,.) > alpha (a NaN coordinate compares false and so counts as
+      ! "> alpha").  Returns the position of the last point <= alpha:
+      ! li-1 if there is none, ui if every point qualifies.
+      ! Assumes li <= ui (the visible caller only passes ranges with u > l).
+
+      !
+      ! Algorithm (matt kennel). 
+      !
+      ! Consider the list as having three parts: on the left,
+      ! the points known to be <= alpha.  On the right, the points
+      ! known to be > alpha, and in the middle, the currently unknown
+      ! points.   The algorithm is to scan the unknown points, starting
+      ! from the left, and swapping them so that they are added to
+      ! the left stack or the right stack, as appropriate.
+      ! 
+      ! The algorithm finishes when the unknown stack is empty. 
+      !
+      ! .. Scalar Arguments ..
+      integer, intent (In) :: c, li, ui
+      real(kdkind), intent(in) :: alpha
+      ! ..
+      real(kdkind) :: v(1:,1:)
+      integer :: ind(1:)
+      integer :: tmp  
+      ! ..
+      integer :: lb, rb
+      !
+      ! The points known to be <= alpha are in
+      ! [l,lb-1]
+      !
+      ! The points known to be > alpha are in
+      ! [rb+1,u].  
+      !
+      ! Therefore we add new points into lb or
+      ! rb as appropriate.  When lb=rb
+      ! we are done.  We return the location of the last point <= alpha.
+      !
+      ! 
+      lb = li; rb = ui
+
+      do while (lb < rb)
+         if ( v(c,ind(lb)) <= alpha ) then
+            ! it is good where it is.
+            lb = lb+1
+         else
+            ! swap it with rb; the swapped-in point at lb is examined on
+            ! the next iteration, so lb does not advance.
+            tmp = ind(lb); ind(lb) = ind(rb); ind(rb) = tmp
+            rb = rb-1
+         endif
+      end do
+      
+      ! now lb .eq. rb: classify the single remaining unknown point.
+      if (v(c,ind(lb)) <= alpha) then
+         res = lb
+      else
+         res = lb-1
+      endif
+      
+    end function select_on_coordinate_value
+
+    subroutine select_on_coordinate(v,ind,c,k,li,ui)
+      ! Partial sort (quickselect) of ind(li:ui) by coordinate c: on
+      ! return ind(k) refers to the point whose v(c,.) would be at
+      ! position k in sorted order; entries li..k-1 are <= it and
+      ! entries k+1..ui are >= it.  The first element of each subrange
+      ! is used as the pivot (Lomuto-style partition).
+      ! .. Scalar Arguments ..
+      integer, intent (In) :: c, k, li, ui
+      ! ..
+      integer :: i, l, m, s, t, u
+      ! ..
+      real(kdkind) :: v(:,:)
+      integer :: ind(:)
+      ! ..
+      l = li
+      u = ui
+      do while (l<u)
+         ! partition ind(l:u) around pivot t = ind(l)
+         t = ind(l)
+         m = l
+         do i = l + 1, u
+            if (v(c,ind(i))<v(c,t)) then
+               m = m + 1
+               s = ind(m)
+               ind(m) = ind(i)
+               ind(i) = s
+            end if
+         end do
+         ! put the pivot at its final position m
+         s = ind(l)
+         ind(l) = ind(m)
+         ind(m) = s
+         ! keep only the side that contains position k
+         if (m<=k) l = m + 1
+         if (m>=k) u = m - 1
+      end do
+    end subroutine select_on_coordinate
+
+   subroutine spread_in_coordinate(tp,c,l,u,interv) 
+      ! Exact range of coordinate 'c' over the points tp%ind(l:u):
+      ! the minimum is returned in interv%lower and the maximum in
+      ! interv%upper.  Points are examined in pairs (one comparison
+      ! within the pair, then one against each running extreme).
+      ! Requires l <= u.
+      ! ..
+      ! .. Structure Arguments ..
+      type(kdtree2), pointer :: tp
+      type(interval), intent(out) :: interv
+      ! ..
+      ! .. Scalar Arguments ..
+      integer, intent (In) :: c, l, u
+      ! ..
+      ! .. Local Scalars ..
+      real(kdkind) :: last, lmax, lmin, t, smin,smax
+      integer :: i, ulocal
+      ! ..
+      ! .. Local Arrays ..
+      real(kdkind), pointer :: v(:,:)
+      integer, pointer :: ind(:)
+      ! ..
+      v => tp%the_data(1:,1:)
+      ind => tp%ind(1:)
+      smin = v(c,ind(l))
+      smax = smin
+
+      ulocal = u
+
+      ! pairs (i-1,i) for i = l+2, l+4, ... cover positions l+1..ulocal
+      do i = l + 2, ulocal, 2
+         lmin = v(c,ind(i-1))
+         lmax = v(c,ind(i))
+         if (lmin>lmax) then
+            t = lmin
+            lmin = lmax
+            lmax = t
+         end if
+         if (smin>lmin) smin = lmin
+         if (smax<lmax) smax = lmax
+      end do
+      ! After the loop, i is the first value past ulocal.  i == ulocal+1
+      ! means position ulocal was left unpaired; check it on its own.
+      if (i==ulocal+1) then
+         last = v(c,ind(ulocal))
+         if (smin>last) smin = last
+         if (smax<last) smax = last
+      end if
+
+      interv%lower = smin
+      interv%upper = smax
+
+    end subroutine spread_in_coordinate
+
+
+  subroutine kdtree2_destroy(tp)
+    !
+    ! Free everything the tree allocated: its nodes, the index
+    ! permutation, the rearranged copy of the data (if one was made) and
+    ! finally the kdtree2 object itself, leaving 'tp' disassociated.
+    ! The caller's input data matrix is NOT freed.
+    !
+    type(kdtree2), pointer :: tp
+
+    call destroy_node(tp%root)
+
+    deallocate (tp%ind)
+    nullify (tp%ind)
+
+    if (tp%rearrange) then
+       deallocate(tp%rearranged_data)
+       nullify(tp%rearranged_data)
+    endif
+
+    deallocate(tp)
+
+  contains
+
+    recursive subroutine destroy_node(node)
+      ! Post-order release of the subtree rooted at 'node': both
+      ! children first, then the node's bounding box, then the node.
+      type(tree_node), pointer :: node
+
+      if (associated(node%left)) then
+         call destroy_node(node%left)
+         nullify (node%left)
+      end if
+      if (associated(node%right)) then
+         call destroy_node(node%right)
+         nullify (node%right)
+      end if
+      if (associated(node%box)) deallocate(node%box)
+      deallocate(node)
+    end subroutine destroy_node
+
+  end subroutine kdtree2_destroy
+
+  subroutine kdtree2_n_nearest(tp,qv,nn,results)
+    ! Find the 'nn' vectors in the tree nearest to 'qv' in euclidean
+    ! norm, returning their indexes and distances in results(:), which
+    ! must be allocated by the caller with at least nn elements.
+    !
+    ! Results are NOT sorted unless tree was created with sort option.
+    !
+    ! All query parameters are copied into the module-global search
+    ! record 'sr' before search() walks the tree.
+    type(kdtree2), pointer      :: tp
+    real(kdkind), target, intent (In)    :: qv(:)
+    integer, intent (In)         :: nn
+    type(kdtree2_result), target :: results(:)
+
+
+    sr%ballsize = huge(1.0)   ! effectively unbounded initial radius
+    sr%qv => qv
+    sr%nn = nn
+    sr%nfound = 0
+    sr%centeridx = -1         ! arbitrary query vector: no exclusion window
+    sr%correltime = 0
+    sr%overflow = .false. 
+
+    sr%results => results
+
+    sr%nalloc = nn   ! will be checked
+
+    sr%ind => tp%ind
+    sr%rearrange = tp%rearrange
+    if (tp%rearrange) then
+       sr%Data => tp%rearranged_data
+    else
+       sr%Data => tp%the_data
+    endif
+    sr%dimen = tp%dimen
+
+    call validate_query_storage(nn) 
+    sr%pq = pq_create(results)   ! results(:) doubles as the heap storage
+
+    call search(tp%root)
+
+    if (tp%sort) then
+       call kdtree2_sort_results(nn, results)
+    endif
+!    deallocate(sr%pqp)
+    return
+  end subroutine kdtree2_n_nearest
+
+  subroutine kdtree2_n_nearest_around_point(tp,idxin,correltime,nn,results)
+    ! Find the 'nn' vectors in the tree nearest to point 'idxin',
+    ! with correlation window 'correltime', returning results in
+    ! results(:), which must be pre-allocated upon entry.
+    !
+    ! Only the first tp%dimen coordinates of point 'idxin' form the
+    ! query vector, matching the coordinates the tree searches (fewer
+    ! than SIZE(the_data,1) when the tree was created with 'dim').
+    !
+    ! Results are NOT sorted unless tree was created with sort option.
+    type(kdtree2), pointer        :: tp
+    integer, intent (In)           :: idxin, correltime, nn
+    type(kdtree2_result), target   :: results(:)
+
+    allocate (sr%qv(tp%dimen))
+    sr%qv = tp%the_data(1:tp%dimen,idxin) ! copy the searched coordinates
+    sr%ballsize = huge(1.0)       ! effectively unbounded initial radius
+    sr%centeridx = idxin
+    sr%correltime = correltime
+
+    sr%nn = nn
+    sr%nfound = 0
+    sr%overflow = .false.   ! do not inherit a flag from a previous search
+
+    sr%dimen = tp%dimen
+    sr%nalloc = nn
+
+    sr%results => results
+
+    sr%ind => tp%ind
+    sr%rearrange = tp%rearrange
+
+    if (sr%rearrange) then
+       sr%Data => tp%rearranged_data
+    else
+       sr%Data => tp%the_data
+    endif
+
+    call validate_query_storage(nn)
+    sr%pq = pq_create(results)   ! results(:) doubles as the heap storage
+
+    call search(tp%root)
+
+    if (tp%sort) then
+       call kdtree2_sort_results(nn, results)
+    endif
+    deallocate (sr%qv)
+    return
+  end subroutine kdtree2_n_nearest_around_point
+
+  subroutine kdtree2_r_nearest(tp,qv,r2,nfound,nalloc,results) 
+    ! find the nearest neighbors to point 'qv', within SQUARED
+    ! Euclidean distance 'r2'.   Upon ENTRY, nalloc must be the
+    ! size of memory allocated for results(1:nalloc).  Upon
+    ! EXIT, nfound is the number actually found within the ball. 
+    !
+    !  Note that if nfound .gt. nalloc then more neighbors were found
+    !  than there were storage to store.  The resulting list is NOT
+    !  the smallest ball inside norm r^2 
+    !
+    ! Results are NOT sorted unless tree was created with sort option.
+    !
+    ! All query parameters are copied into the module-global search
+    ! record 'sr' before search() walks the tree.
+    type(kdtree2), pointer      :: tp
+    real(kdkind), target, intent (In)    :: qv(:)
+    real(kdkind), intent(in)             :: r2
+    integer, intent(out)         :: nfound
+    integer, intent (In)         :: nalloc
+    type(kdtree2_result), target :: results(:)
+
+    !
+    sr%qv => qv
+    sr%ballsize = r2
+    sr%nn = 0      ! flag for fixed ball search
+    sr%nfound = 0
+    sr%centeridx = -1   ! arbitrary query vector: no exclusion window
+    sr%correltime = 0
+
+    sr%results => results
+
+    call validate_query_storage(nalloc)
+    sr%nalloc = nalloc
+    sr%overflow = .false. 
+    sr%ind => tp%ind
+    sr%rearrange= tp%rearrange
+
+    if (tp%rearrange) then
+       sr%Data => tp%rearranged_data
+    else
+       sr%Data => tp%the_data
+    endif
+    sr%dimen = tp%dimen
+
+    !
+    !sr%dsl = Huge(sr%dsl)    ! set to huge positive values
+    !sr%il = -1               ! set to invalid indexes
+    !
+
+    call search(tp%root)
+    nfound = sr%nfound
+    if (tp%sort) then
+       call kdtree2_sort_results(nfound, results)
+    endif
+
+    if (sr%overflow) then
+       write (*,*) 'KD_TREE_TRANS: warning! return from kdtree2_r_nearest found more neighbors'
+       write (*,*) 'KD_TREE_TRANS: than storage was provided for.  Answer is NOT smallest ball'
+       write (*,*) 'KD_TREE_TRANS: with that number of neighbors!  I.e. it is wrong.'
+    endif
+
+    return
+  end subroutine kdtree2_r_nearest
+
+  subroutine kdtree2_r_nearest_around_point(tp,idxin,correltime,r2,&
+   nfound,nalloc,results)
+    !
+    ! Like kdtree2_r_nearest, but around a point 'idxin' already existing
+    ! in the data set, with correlation window 'correltime'.  Finds the
+    ! points within SQUARED Euclidean distance 'r2'.  Upon ENTRY nalloc
+    ! is the size of results(1:nalloc); upon EXIT nfound is the number
+    ! actually found within the ball (it may exceed nalloc, in which
+    ! case a warning is printed and the stored list is incomplete).
+    !
+    ! Only the first tp%dimen coordinates of point 'idxin' form the
+    ! query vector, matching the coordinates the tree searches.
+    ! 
+    ! Results are NOT sorted unless tree was created with sort option.
+    !
+    type(kdtree2), pointer      :: tp
+    integer, intent (In)         :: idxin, correltime, nalloc
+    real(kdkind), intent(in)             :: r2
+    integer, intent(out)         :: nfound
+    type(kdtree2_result), target :: results(:)
+
+    allocate (sr%qv(tp%dimen))
+    sr%qv = tp%the_data(1:tp%dimen,idxin) ! copy the searched coordinates
+    sr%ballsize = r2
+    sr%nn = 0    ! flag for fixed r search
+    sr%nfound = 0
+    sr%centeridx = idxin
+    sr%correltime = correltime
+
+    sr%results => results
+
+    sr%nalloc = nalloc
+    sr%overflow = .false.
+
+    call validate_query_storage(nalloc)
+
+    sr%ind => tp%ind
+    sr%rearrange = tp%rearrange
+
+    if (tp%rearrange) then
+       sr%Data => tp%rearranged_data
+    else
+       sr%Data => tp%the_data
+    endif
+    sr%dimen = tp%dimen
+
+    call search(tp%root)
+    nfound = sr%nfound
+    if (tp%sort) then
+       call kdtree2_sort_results(nfound,results)
+    endif
+
+    if (sr%overflow) then
+       write (*,*) 'KD_TREE_TRANS: warning! return from kdtree2_r_nearest found more neighbors'
+       write (*,*) 'KD_TREE_TRANS: than storage was provided for.  Answer is NOT smallest ball'
+       write (*,*) 'KD_TREE_TRANS: with that number of neighbors!  I.e. it is wrong.'
+    endif
+
+    deallocate (sr%qv)
+    return
+  end subroutine kdtree2_r_nearest_around_point
+
+  function kdtree2_r_count(tp,qv,r2) result(nfound)
+    !
+    ! Return how many tree points lie within (periodic) squared distance
+    ! 'r2' of the query vector 'qv'.  No result storage is needed: the
+    ! fixed-ball search runs with zero capacity and only its running
+    ! count is kept.
+    !
+    type(kdtree2), pointer   :: tp
+    real(kdkind), target, intent (In) :: qv(:)
+    real(kdkind), intent(in)          :: r2
+    integer                   :: nfound
+
+    ! attach the shared search record to this tree's data
+    sr%dimen = tp%dimen
+    sr%ind => tp%ind
+    sr%rearrange = tp%rearrange
+    if (.not. tp%rearrange) then
+       sr%Data => tp%the_data
+    else
+       sr%Data => tp%rearranged_data
+    endif
+
+    ! fixed ball (nn = 0) of squared radius r2 around qv, no
+    ! decorrelation window
+    sr%qv => qv
+    sr%ballsize = r2
+    sr%nn = 0
+    sr%centeridx = -1
+    sr%correltime = 0
+
+    ! zero capacity: every hit is counted and only raises sr%overflow.
+    ! nullify() is used because FTN 95 chokes on '=> null()' here.
+    nullify(sr%results)
+    sr%nalloc = 0
+    sr%nfound = 0
+    sr%overflow = .false.
+
+    call search(tp%root)
+    nfound = sr%nfound
+  end function kdtree2_r_count
+
+  function kdtree2_r_count_around_point(tp,idxin,correltime,r2) &
+   result(nfound)
+    !
+    ! Count the points within (periodic) squared distance 'r2' of the
+    ! existing data point 'idxin', skipping points whose index differs
+    ! from idxin by less than 'correltime' (decorrelation window).
+    ! No result storage is used; only the count is returned.
+    !
+    type(kdtree2), pointer :: tp
+    integer, intent (In)    :: correltime, idxin
+    real(kdkind), intent(in)        :: r2
+    integer                 :: nfound
+
+    ! private copy of the query point; released before returning
+    allocate (sr%qv(tp%dimen))
+    sr%qv = tp%the_data(:,idxin)
+    sr%ballsize = r2
+
+    sr%nn = 0       ! flag for fixed r search
+    sr%nfound = 0
+    sr%centeridx = idxin
+    sr%correltime = correltime
+    nullify(sr%results)
+
+    sr%nalloc = 0            ! we do not allocate any storage but that's OK
+                             ! for counting.
+
+    sr%ind => tp%ind
+    sr%rearrange = tp%rearrange
+
+    if (sr%rearrange) then
+       sr%Data => tp%rearranged_data
+    else
+       sr%Data => tp%the_data
+    endif
+    sr%dimen = tp%dimen
+    sr%overflow = .false.
+
+    call search(tp%root)
+
+    nfound = sr%nfound
+
+    ! free the query copy so repeated calls do not leak memory
+    deallocate (sr%qv)
+    return
+  end function kdtree2_r_count_around_point
+
+
+  subroutine validate_query_storage(n)
+    !
+    ! Stop the program if the result buffer attached to the search record
+    ! (sr%results) has fewer than n entries.  sr%results must already be
+    ! associated with the caller's array when this is called.
+    !
+    integer, intent(in) :: n
+
+    if (size(sr%results,1) .lt. n) then
+       write (*,*) 'KD_TREE_TRANS:  you did not provide enough storage for results(1:n)'
+       stop
+    endif
+  end subroutine validate_query_storage
+
+  function square_distance(d, iv,qv) result (res)
+    !
+    ! Squared distance between iv(1:d) and qv(1:d) in a periodic unit box.
+    ! Along each axis the separation is the shorter of the direct distance
+    ! and the distance across the wrap-around at 1.0.
+    ! (Periodicity added by S Skory; the non-periodic form was
+    !  res = sum( (iv(1:d)-qv(1:d))**2 ).)
+    !
+    ! The per-axis temporaries are real(kdkind).  A default REAL here
+    ! silently rounded double-precision separations to single precision.
+    !
+    real(kdkind) :: res
+    integer :: d
+    real(kdkind) :: iv(:),qv(:)
+    real(kdkind) :: d_min(d)
+
+    d_min = abs(iv(1:d) - qv(1:d))
+    d_min = min(d_min, 1.0_kdkind - d_min)
+    res = sum(d_min**2)
+  end function square_distance
+  
+  recursive subroutine search(node)
+    !
+    ! This is the innermost core routine of the kd-tree search.  Along
+    ! with "process_terminal_node", it is the performance bottleneck.
+    !
+    ! Recursively descends from 'node'.  At a terminal node (one missing
+    ! either child) the bucket is scanned by the fixed-ball routine when
+    ! sr%nn == 0, otherwise by the n-nearest routine.  At an interior
+    ! node the child on qv's side of the cut is searched first.  The
+    ! other child is searched only if a lower bound on its squared
+    ! distance from qv, summed over all dimensions, stays within
+    ! sr%ballsize.
+    !
+    ! Distances are periodic with period 1.0 on every axis (the "1 - ..."
+    ! terms below), i.e. a unit box is assumed.
+    !
+    type(Tree_node), pointer          :: node
+    ! ..
+    type(tree_node),pointer            :: ncloser, nfarther
+    !
+    integer                            :: cut_dim, i
+    ! ..
+    real(kdkind)                               :: qval, dis, dis_right, dis_left, dis_node
+    real(kdkind)                               :: ballsize
+    real(kdkind), pointer           :: qv(:)
+    type(interval), pointer :: box(:) 
+
+    if ((associated(node%left) .and. associated(node%right)) .eqv. .false.) then
+       ! we are on a terminal node
+       if (sr%nn .eq. 0) then
+          call process_terminal_node_fixedball(node)
+       else
+          call process_terminal_node(node)
+       endif
+    else
+       ! we are not on a terminal node
+       qv => sr%qv(1:)
+       cut_dim = node%cut_dim
+       qval = qv(cut_dim)
+
+
+       ! Periodic stuff added by S Skory. We have to test to see which
+       ! node edge is closer, rather than just doing a simple less than or
+       ! greater than to the cut in this dimension.
+       ! dis_left / dis_right: squared distance along cut_dim from qval to
+       ! the lower / upper edge of this node's box, taking the shorter
+       ! of the direct and the wrapped-around separation.
+       dis_left = min( (node%box(cut_dim)%lower - qval)**2, &
+          (1 - abs(node%box(cut_dim)%lower - qval))**2)
+       dis_right = min( (node%box(cut_dim)%upper - qval)**2, &
+          (1 - abs(node%box(cut_dim)%upper - qval))**2)
+
+       
+       ! The closer child is chosen by a plain (non-periodic) comparison
+       ! with the cut value; a periodic choice (below) is disabled.
+       ! For the farther child, 'dis' is the smaller of:
+       !  - dis_node, the direct squared gap to it along cut_dim
+       !    (cut_val_right/left presumably bound the children's extents
+       !    along cut_dim -- TODO confirm against the tree build code), and
+       !  - the squared distance to the far edge of this node's box,
+       !    which covers reaching that child by wrapping around the period.
+       if (qval < node%cut_val) then
+       !if (dis_left <= dis_right) then
+          ncloser => node%left
+          nfarther => node%right
+          dis_node = (node%cut_val_right - qval)**2
+          dis = min(dis_right, dis_node)
+       else
+          ncloser => node%right
+          nfarther => node%left
+          dis_node = (node%cut_val_left - qval)**2
+          dis = min(dis_left, dis_node)
+       endif
+
+       if (associated(ncloser)) call search(ncloser)
+
+       ! we may need to search the second node.
+       if (associated(nfarther)) then
+          ! re-read: searching ncloser may have shrunk sr%ballsize
+          ballsize = sr%ballsize
+
+          if (dis <= ballsize) then
+             !
+             ! The cut dimension alone does not exclude the farther node.
+             ! Add the periodic distance from qv to this node's box in every
+             ! other dimension, and give up as soon as the sum exceeds the
+             ! ball.  Returning here only skips the farther child, which is
+             ! the last work at this level.
+             !
+             box => node%box(1:)
+             do i=1,sr%dimen
+                if (i .ne. cut_dim) then
+                   dis = dis + dis2_from_bnd(qv(i),box(i)%lower,box(i)%upper)
+                   if (dis > ballsize) then
+                      return
+                   endif
+                endif
+             end do
+             
+             !
+             ! if we are still here then we need to search more.
+             !
+             call search(nfarther)
+          endif
+       endif
+    end if
+  end subroutine search
+
+
+  real(kdkind) function dis2_from_bnd(x,amin,amax) result (res)
+    !
+    ! Squared distance from coordinate x to the interval [amin,amax] along
+    ! one axis of a periodic unit box.  The result is 0 when x lies strictly
+    ! inside the interval.  Otherwise it is the squared distance to the
+    ! nearer end point, each separation measured the short way around the
+    ! wrap at 1.0.
+    ! (Periodicity added by S Skory; replaces the non-periodic version
+    !  that returned (x-amax)**2, (amin-x)**2 or 0.)
+    !
+    ! The temporaries are real(kdkind).  Default REAL truncated
+    ! double-precision distances to single precision.
+    !
+    real(kdkind), intent(in) :: x, amin,amax
+    real(kdkind) :: dxmax, dxmin
+
+    if ((x < amax) .and. (x > amin)) then
+      res = 0.0_kdkind
+    else
+      dxmax = abs(x - amax)
+      dxmax = min(dxmax, 1.0_kdkind - dxmax)**2
+      dxmin = abs(x - amin)
+      dxmin = min(dxmin, 1.0_kdkind - dxmin)**2
+      res = min(dxmax, dxmin)
+    endif
+  end function dis2_from_bnd
+
+  logical function box_in_search_range(node, sr) result(res)
+    !
+    ! Return .true. if the bounding box of 'node' may hold points inside
+    ! the search ball of 'sr'.  The squared distance from sr%qv to the box
+    ! is accumulated one dimension at a time via dis2_from_bnd; coordinates
+    ! of qv inside the box contribute nothing.  The function returns
+    ! .false. as soon as the partial sum exceeds sr%ballsize.
+    !
+    type(tree_node), pointer :: node
+    type(tree_search_record), pointer :: sr
+
+    integer :: i
+    real(kdkind)    :: dis, ballsize
+
+    ballsize = sr%ballsize
+    dis = 0.0
+    res = .true.
+    do i=1,sr%dimen
+       dis = dis + dis2_from_bnd(sr%qv(i),node%box(i)%lower,node%box(i)%upper)
+       if (dis > ballsize) then
+          res = .false.
+          return
+       endif
+    end do
+  end function box_in_search_range
+
+
+  subroutine process_terminal_node(node)
+    !
+    ! n-nearest search over one terminal bucket.  Every point of 'node' is
+    ! tested against sr%qv, and the best sr%nn candidates are kept in the
+    ! priority queue sr%pq.  Once sr%nn points are held, sr%ballsize is
+    ! tightened to the largest retained squared distance.
+    ! Distances are periodic with period 1.0 per axis.
+    !
+    type(tree_node), pointer          :: node
+    !
+    real(kdkind), pointer          :: qv(:)
+    integer, pointer       :: ind(:)
+    real(kdkind), pointer          :: data(:,:)
+    !
+    integer                :: dimen, i, indexofi, k, centeridx, correltime
+    real(kdkind)                   :: ballsize, sd, newpri
+    logical                :: rearrange
+    type(pq), pointer      :: pqp
+    ! per-axis separation; real(kdkind) so double-precision distances are
+    ! not rounded to single precision before squaring
+    real(kdkind) :: sdtemp
+    !
+    ! copy values from sr to local variables
+    !
+    ! Notice, making local pointers with an EXPLICIT lower bound
+    ! seems to generate faster code.
+    ! why?  I don't know.
+    qv => sr%qv(1:) 
+    pqp => sr%pq
+    dimen = sr%dimen
+    ballsize = sr%ballsize 
+    rearrange = sr%rearrange
+    ind => sr%ind(1:)
+    data => sr%Data(1:,1:)     
+    centeridx = sr%centeridx
+    correltime = sr%correltime
+
+    ! search through terminal bucket.
+    mainloop: do i = node%l, node%u
+       ! Accumulate the squared distance one axis at a time and skip the
+       ! point as soon as it exceeds the current ballsize.
+       if (rearrange) then
+          sd = 0.0
+          do k = 1,dimen
+             ! Periodicity by S Skory: shorter of direct and wrapped separation
+             sdtemp = min( abs(data(k,i) - qv(k)) , 1. - abs(data(k,i) - qv(k)))
+             sd = sd + sdtemp**2
+             if (sd>ballsize) cycle mainloop
+          end do
+          indexofi = ind(i)  ! only read it if we have not broken out
+       else
+          indexofi = ind(i)
+          sd = 0.0
+          do k = 1,dimen
+             sdtemp = min( abs(data(k,indexofi) - qv(k)), 1. - abs(data(k,indexofi) - qv(k)))
+             sd = sd + sdtemp**2
+             if (sd>ballsize) cycle mainloop
+          end do
+       endif
+
+       if (centeridx > 0) then ! doing correlation interval?
+          if (abs(indexofi-centeridx) < correltime) cycle mainloop
+       endif
+
+       ! The point lies within ballsize.  While the list is not yet full,
+       ! add it unconditionally; the insert that fills the list sets
+       ! ballsize to the largest distance held.  Once the list is full, any
+       ! point reaching here is closer than the current worst, so it
+       ! replaces that entry and ballsize becomes the new maximum.
+       if (sr%nfound .lt. sr%nn) then
+          sr%nfound = sr%nfound +1 
+          newpri = pq_insert(pqp,sd,indexofi)
+          if (sr%nfound .eq. sr%nn) ballsize = newpri
+       else
+          ballsize = pq_replace_max(pqp,sd,indexofi)
+       endif
+    end do mainloop
+    !
+    ! Reset sr variables which may have changed during loop
+    !
+    sr%ballsize = ballsize 
+
+  end subroutine process_terminal_node
+
+  subroutine process_terminal_node_fixedball(node)
+    !
+    ! Fixed-ball search over one terminal bucket.  Each point of 'node'
+    ! whose periodic squared distance to sr%qv is <= sr%ballsize is counted
+    ! in sr%nfound, unless it falls inside the decorrelation window (when
+    ! centeridx > 0).  The first sr%nalloc hits are stored in sr%results;
+    ! later hits only set sr%overflow.
+    !
+    type(tree_node), pointer          :: node
+    !
+    real(kdkind), pointer          :: qv(:)
+    integer, pointer       :: ind(:)
+    real(kdkind), pointer          :: data(:,:)
+    !
+    integer                :: nfound
+    integer                :: dimen, i, indexofi, k
+    integer                :: centeridx, correltime
+    real(kdkind)                   :: ballsize, sd
+    logical                :: rearrange
+    ! per-axis separation; real(kdkind) so double-precision distances are
+    ! not rounded to single precision before squaring
+    real(kdkind) :: sdtemp
+    !
+    ! copy values from sr to local variables
+    !
+    qv => sr%qv(1:)
+    dimen = sr%dimen
+    ballsize = sr%ballsize 
+    rearrange = sr%rearrange
+    ind => sr%ind(1:)
+    data => sr%Data(1:,1:)
+    centeridx = sr%centeridx
+    correltime = sr%correltime
+    nfound = sr%nfound
+
+    ! search through terminal bucket.
+    mainloop: do i = node%l, node%u
+       ! Accumulate the squared distance one axis at a time and skip the
+       ! point once it leaves the ball.  A rearranged tree is indexed by i
+       ! directly; otherwise the point is reached through ind(i).
+       if (rearrange) then
+          sd = 0.0
+          do k = 1,dimen
+             ! Periodicity S Skory: shorter of direct and wrapped separation
+             sdtemp = min( abs(data(k,i) - qv(k)) , 1. - abs(data(k,i) - qv(k)))
+             sd = sd + sdtemp**2
+             if (sd>ballsize) cycle mainloop
+          end do
+          indexofi = ind(i)  ! only read it if we have not broken out
+       else
+          indexofi = ind(i)
+          sd = 0.0
+          do k = 1,dimen
+             ! Periodicity S Skory
+             sdtemp = min( abs(data(k,indexofi) - qv(k)), 1. - abs(data(k,indexofi) - qv(k)))
+             sd = sd + sdtemp**2
+             if (sd>ballsize) cycle mainloop
+          end do
+       endif
+
+       if (centeridx > 0) then ! doing correlation interval?
+          if (abs(indexofi-centeridx)<correltime) cycle mainloop
+       endif
+
+       nfound = nfound+1
+       if (nfound .gt. sr%nalloc) then
+          ! oh nuts, we have to add another one to the tree but
+          ! there isn't enough room.
+          sr%overflow = .true.
+       else
+          sr%results(nfound)%dis = sd
+          sr%results(nfound)%idx = indexofi
+       endif
+    end do mainloop
+    !
+    ! Reset sr variables which may have changed during loop
+    !
+    sr%nfound = nfound
+  end subroutine process_terminal_node_fixedball
+
+  subroutine kdtree2_n_nearest_brute_force(tp,qv,nn,results)
+    !
+    ! Reference implementation: find the 'nn' points nearest to 'qv' by
+    ! computing the (periodic) squared distance to every tree point and
+    ! maintaining a sorted insertion list.  For testing the tree search
+    ! only -- it does exactly the exhaustive work a k-d tree avoids.
+    !
+    ! On return results(1:nn) is ordered by increasing distance; slots
+    ! that were never filled keep dis = huge(1.0) and idx = -1.
+    !
+    type(kdtree2), pointer :: tp
+    real(kdkind), intent (In)       :: qv(:)
+    integer, intent (In)    :: nn
+    type(kdtree2_result)    :: results(:)
+
+    integer :: ipt, islot, ishift
+    real(kdkind) :: d2
+    real(kdkind), allocatable :: dist2(:)
+
+    allocate (dist2(tp%n))
+    do ipt = 1, tp%n
+       dist2(ipt) = square_distance(tp%dimen,qv,tp%the_data(:,ipt))
+    end do
+
+    ! start from an empty list: infinite distances, invalid indices
+    results(1:nn)%dis = huge(1.0)
+    results(1:nn)%idx = -1
+
+    do ipt = 1, tp%n
+       d2 = dist2(ipt)
+       ! not better than the current worst entry: ignore
+       if (.not. (d2 < results(nn)%dis)) cycle
+       ! first slot holding a larger distance
+       do islot = 1, nn
+          if (d2 < results(islot)%dis) exit
+       end do
+       ! shift the tail down one place, dropping the last entry
+       do ishift = nn, islot + 1, -1
+          results(ishift) = results(ishift-1)
+       end do
+       results(islot)%dis = d2
+       results(islot)%idx = ipt
+    end do
+    deallocate (dist2)
+  end subroutine kdtree2_n_nearest_brute_force
+  
+
+  subroutine kdtree2_r_nearest_brute_force(tp,qv,r2,nfound,results)
+    !
+    ! Reference implementation: find the points whose (periodic) squared
+    ! distance to 'qv' is <= r2 by exhaustive search.  They are returned in
+    ! results(1:nfound), sorted by increasing distance, keeping at most
+    ! size(results) points.  Use only for testing; it is slow.
+    !
+    ! The test is '<=' to match the tree's fixed-ball search, which keeps
+    ! points at exactly r2 (it rejects only sd > ballsize).  With '<' the
+    ! reference and the tree disagreed on boundary points.
+    !
+    type(kdtree2), pointer :: tp
+    real(kdkind), intent (In)       :: qv(:)
+    real(kdkind), intent (In)       :: r2
+    integer, intent(out)    :: nfound
+    type(kdtree2_result)    :: results(:)
+
+    integer :: i, nalloc
+    real(kdkind), allocatable :: all_distances(:)
+
+    allocate (all_distances(tp%n))
+    do i = 1, tp%n
+       all_distances(i) = square_distance(tp%dimen,qv,tp%the_data(:,i))
+    end do
+
+    nfound = 0
+    nalloc = size(results,1)
+
+    do i = 1, tp%n
+       if (all_distances(i) <= r2) then
+          ! append while storage remains; extra hits are dropped
+          if (nfound .lt. nalloc) then
+             nfound = nfound+1
+             results(nfound)%dis = all_distances(i)
+             results(nfound)%idx = i
+          endif
+       end if
+    enddo
+    deallocate (all_distances)
+
+    call kdtree2_sort_results(nfound,results)
+
+  end subroutine kdtree2_r_nearest_brute_force
+
+  subroutine kdtree2_sort_results(nfound,results)
+    !
+    ! Sort results(1:nfound) in place by increasing distance (%dis).
+    ! Lists of length 0 or 1 are left untouched.
+    !
+    ! Whole kdtree2_result records are heapsorted.  Sorting the %dis and
+    ! %idx component arrays through heapsort() was reported to be buggy
+    ! with Intel Fortran.
+    !
+    integer, intent(in)          :: nfound
+    type(kdtree2_result), target :: results(:)
+
+    if (nfound <= 1) return
+    call heapsort_struct(results,nfound)
+  end subroutine kdtree2_sort_results
+
+  subroutine heapsort(a,ind,n)
+    !
+    ! Sort a(1:n) in ascending order, permuting ind(1:n) similarly.
+    ! 
+    ! If ind(k) = k upon input, then it will give a sort index upon output.
+    !
+    ! n <= 1 returns immediately.  Testing only n == 1 let n == 0 fall
+    ! through to read a(0), out of bounds.
+    !
+    integer,intent(in)          :: n
+    real(kdkind), intent(inout)         :: a(:) 
+    integer, intent(inout)      :: ind(:)
+
+    real(kdkind)        :: value   ! temporary for a value from a()
+    integer     :: ivalue  ! temporary for a value from ind()
+
+    integer     :: i,j
+    integer     :: ileft,iright
+
+    ileft=n/2+1
+    iright=n
+
+    if(n.le.1) return
+
+    do 
+       ! While ileft > 1 the max-heap is being built by sifting a(ileft)
+       ! down.  Afterwards the maximum a(1) is swapped to a(iright), the
+       ! heap shrinks by one, and the displaced element is sifted down
+       ! from the top.
+       if(ileft > 1)then
+          ileft=ileft-1
+          value=a(ileft); ivalue=ind(ileft)
+       else
+          value=a(iright); ivalue=ind(iright)
+          a(iright)=a(1); ind(iright)=ind(1)
+          iright=iright-1
+          if (iright == 1) then
+             a(1)=value;ind(1)=ivalue
+             return
+          endif
+       endif
+       ! sift 'value' down from position ileft within a(1:iright)
+       i=ileft
+       j=2*ileft
+       do while (j <= iright) 
+          if(j < iright) then
+             if(a(j) < a(j+1)) j=j+1
+          endif
+          if(value < a(j)) then
+             a(i)=a(j); ind(i)=ind(j)
+             i=j
+             j=j+j
+          else
+             j=iright+1
+          endif
+       end do
+       a(i)=value; ind(i)=ivalue
+    end do
+  end subroutine heapsort
+
+  subroutine heapsort_struct(a,n)
+    !
+    ! Sort the records a(1:n) in ascending order of their %dis component.
+    !
+    ! n <= 1 returns immediately.  Testing only n == 1 let n == 0 fall
+    ! through to read a(0), out of bounds.
+    !
+    integer,intent(in)                 :: n
+    type(kdtree2_result),intent(inout) :: a(:)
+
+    type(kdtree2_result) :: value ! temporary value
+
+    integer     :: i,j
+    integer     :: ileft,iright
+
+    ileft=n/2+1
+    iright=n
+
+    if(n.le.1) return
+
+    do 
+       ! While ileft > 1 the max-heap is being built by sifting a(ileft)
+       ! down.  Afterwards the maximum a(1) is swapped to a(iright), the
+       ! heap shrinks by one, and the displaced record is sifted down
+       ! from the top.
+       if(ileft > 1)then
+          ileft=ileft-1
+          value=a(ileft)
+       else
+          value=a(iright)
+          a(iright)=a(1)
+          iright=iright-1
+          if (iright == 1) then
+             a(1) = value
+             return
+          endif
+       endif
+       ! sift 'value' down from position ileft within a(1:iright)
+       i=ileft
+       j=2*ileft
+       do while (j <= iright) 
+          if(j < iright) then
+             if(a(j)%dis < a(j+1)%dis) j=j+1
+          endif
+          if(value%dis < a(j)%dis) then
+             a(i)=a(j)
+             i=j
+             j=j+j
+          else
+             j=iright+1
+          endif
+       end do
+       a(i)=value
+    end do
+  end subroutine heapsort_struct
+
+end module kdtree2_module
+

Added: trunk/yt/extensions/kdtree/test.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/test.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,58 @@
+from Forthon import *
+from fKDpy import *
+import numpy,random
+
+n = 32768
+
+
+fKD.tags = fzeros((64),'i')
+fKD.dist = fzeros((64),'d')
+fKD.pos = fzeros((3,n),'d')
+fKD.nn = 64
+fKD.nparts = n
+fKD.sort = True
+fKD.rearrange = True
+fKD.qv = numpy.array([16./32, 16./32, 16./32])
+
+fp = open('parts.txt','r')
+xpos = []
+ypos = []
+zpos = []
+line = fp.readline()
+while line:
+    line = line.split()
+    xpos.append(float(line[0]))
+    ypos.append(float(line[1]))
+    zpos.append(float(line[2]))
+    line= fp.readline()
+
+fp.close()
+
+
+for k in range(32):
+    for j in range(32):
+        for i in range(32):
+            fKD.pos[0][i + j*32 + k*1024] = float(i)/32 + 1./64 + 0.0001*random.random()
+            fKD.pos[1][i + j*32 + k*1024] = float(j)/32 + 1./64 + 0.0001*random.random()
+            fKD.pos[2][i + j*32 + k*1024] = float(k)/32 + 1./64 + 0.0001*random.random()
+
+            
+
+#print fKD.pos[0][0],fKD.pos[1][0],fKD.pos[2][0]
+
+create_tree()
+
+
+find_nn_nearest_neighbors()
+
+#print 'next'
+
+#fKD.qv = numpy.array([0., 0., 0.])
+
+#find_nn_nearest_neighbors()
+
+
+#print (fKD.tags - 1)
+#print fKD.dist
+
+free_tree()

Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py	(original)
+++ trunk/yt/lagos/BaseDataTypes.py	Sun Dec 20 12:42:09 2009
@@ -1990,6 +1990,12 @@
                  & (r < self._radius))
         return cm
 
+    def volume(self, unit="unitary"):
+        """
+        Return the volume of the cylinder in units of *unit*.
+        """
+        return math.pi * (self._radius)**2. * self._height * pf[unit]**3
+
 class AMRRegionBase(AMR3DData):
     """
     AMRRegions are rectangular prisms of data.
@@ -2031,6 +2037,15 @@
                  & (grid['z'] + dzp > self.left_edge[2]) )
         return cm
 
+    def volume(self, unit = "unitary"):
+        """
+        Return the volume of the region in units *unit*.
+        """
+        diff = na.array(self.right_edge) - na.array(self.left_edge)
+        # Find the full volume
+        vol = na.prod(diff * self.pf[unit])
+        return vol
+
 class AMRRegionStrictBase(AMRRegionBase):
     """
     AMRRegion without any dx padding for cell selection
@@ -2092,6 +2107,20 @@
                           & (grid['z'] + dzp + off_z > self.left_edge[2]) )
             return cm
 
+    def volume(self, unit = "unitary"):
+        """
+        Return the volume of the region in units *unit*.
+        """
+        period = self.pf["DomainRightEdge"] - self.pf["DomainLeftEdge"]
+        diff = na.array(self.right_edge) - na.array(self.left_edge)
+        # Correct for wrap-arounds.
+        tofix = (diff < 0)
+        toadd = period[tofix]
+        diff += toadd
+        # Find the full volume
+        vol = na.prod(diff * self.pf[unit])
+        return vol
+        
 
 class AMRPeriodicRegionStrictBase(AMRPeriodicRegionBase):
     """
@@ -2178,6 +2207,12 @@
             self._cut_masks[grid.id] = cm
         return cm
 
+    def volume(self, unit = "unitary"):
+        """
+        Return the volume of the sphere in units *unit*.
+        """
+        return 4./3. * math.pi * (self.radius * self.pf[unit])**3.0
+
 class AMRFloatCoveringGridBase(AMR3DData):
     """
     Covering grids represent fixed-resolution data over a given region.

Modified: trunk/yt/lagos/EnzoFields.py
==============================================================================
--- trunk/yt/lagos/EnzoFields.py	(original)
+++ trunk/yt/lagos/EnzoFields.py	Sun Dec 20 12:42:09 2009
@@ -69,16 +69,20 @@
                   validators=ValidateDataField("%s_Density" % species))
 
 def _Metallicity(field, data):
-    return data["Metal_Fraction"] / 0.0204
-add_field("Metallicity", units=r"Z_{\rm{Solar}}",
+    return data["Metal_Fraction"]
+def _ConvertMetallicity(data):
+    return 49.0196 # 1 / 0.0204
+add_field("Metallicity", units=r"Z_{\rm{\odot}}",
           function=_Metallicity,
+          convert_function=_ConvertMetallicity,
           validators=ValidateDataField("Metal_Density"),
           projection_conversion="1")
 
 def _Metallicity3(field, data):
-    return data["SN_Colour"] / 0.0204
-add_field("Metallicity3", units=r"Z_{\rm{Solar}}",
+    return data["SN_Colour"]
+add_field("Metallicity3", units=r"Z_{\rm{\odot}}",
           function=_Metallicity3,
+          convert_function=_ConvertMetallicity,
           validators=ValidateDataField("SN_Colour"),
           projection_conversion="1")
 
@@ -208,6 +212,7 @@
           not_in_all = True)
 
 EnzoFieldInfo["Temperature"]._units = r"\rm{K}"
+# Plain-text units label; the LaTeX form is kept in the _units attribute.
+EnzoFieldInfo["Temperature"].units = r"K"
 
 def _convertVelocity(data):
     return data.convert("x-velocity")
@@ -244,20 +249,6 @@
 add_field("particle_density_pyx", function=_pdensity_pyx,
           validators=[ValidateSpatial(0)], convert_function=_convertDensity)
 
-def _spdensity(field, data):
-    blank = na.zeros(data.ActiveDimensions, dtype='float32', order="FORTRAN")
-    if data.NumberOfParticles == 0: return blank
-    filter = data['creation_time'] > 0.0
-    if not filter.any(): return blank
-    cic_deposit.cic_deposit(data["particle_position_x"][filter],
-                            data["particle_position_y"][filter],
-                            data["particle_position_z"][filter], 3,
-                            data["particle_mass"][filter],
-                            blank, data.LeftEdge, data['dx'])
-    return blank
-add_field("star_density", function=_spdensity,
-          validators=[ValidateSpatial(0)], convert_function=_convertDensity)
-
 def _spdensity_pyx(field, data):
     blank = na.zeros(data.ActiveDimensions, dtype='float32')
     if data.NumberOfParticles == 0: return blank
@@ -275,7 +266,81 @@
 add_field("star_density_pyx", function=_spdensity_pyx,
           validators=[ValidateSpatial(0)], convert_function=_convertDensity)
 
-EnzoFieldInfo["Temperature"].units = r"K"
+def _star_field(field, data):
+    """
+    Create a grid field for star quantities, weighted by star mass.
+    """
+    particle_field = field.name[5:]
+    top = na.zeros(data.ActiveDimensions, dtype='float32')
+    if data.NumberOfParticles == 0: return top
+    filter = data['creation_time'] > 0.0
+    if not filter.any(): return top
+    particle_field_data = data[particle_field][filter] * data['particle_mass'][filter]
+    CICDeposit_3(data["particle_position_x"][filter].astype(na.float64),
+                 data["particle_position_y"][filter].astype(na.float64),
+                 data["particle_position_z"][filter].astype(na.float64),
+                 particle_field_data.astype(na.float32),
+                 na.int64(na.where(filter)[0].size),
+                 top, na.array(data.LeftEdge).astype(na.float64),
+                 na.array(data.ActiveDimensions).astype(na.int32), 
+                 na.float64(data['dx']))
+    del particle_field_data
+
+    bottom = na.zeros(data.ActiveDimensions, dtype='float32')
+    CICDeposit_3(data["particle_position_x"][filter].astype(na.float64),
+                 data["particle_position_y"][filter].astype(na.float64),
+                 data["particle_position_z"][filter].astype(na.float64),
+                 data["particle_mass"][filter].astype(na.float32),
+                 na.int64(na.where(filter)[0].size),
+                 bottom, na.array(data.LeftEdge).astype(na.float64),
+                 na.array(data.ActiveDimensions).astype(na.int32), 
+                 na.float64(data['dx']))
+
+    top[bottom == 0] = 0.0
+    bnz = bottom.nonzero()
+    top[bnz] /= bottom[bnz]
+    return top
+
+add_field('star_metallicity_fraction', function=_star_field,
+          validators=[ValidateSpatial(0)])
+add_field('star_creation_time', function=_star_field,
+          validators=[ValidateSpatial(0)])
+add_field('star_dynamical_time', function=_star_field,
+          validators=[ValidateSpatial(0)])
+
+def _StarMetallicity(field, data):
+    """
+    Star-mass-weighted metallicity fraction per cell, converted to solar
+    units by the convert_function registered below.
+    """
+    return data['star_metallicity_fraction']
+add_field('StarMetallicity', units=r"Z_{\rm{\odot}}",
+          function=_StarMetallicity,
+          convert_function=_ConvertMetallicity,
+          projection_conversion="1")
+
+def _StarCreationTime(field, data):
+    return data['star_creation_time']
+def _ConvertEnzoTimeYears(data):
+    return data.pf.time_units['years']
+add_field('StarCreationTimeYears', units="\rm{yr}",
+          function=_StarCreationTime,
+          convert_function=_ConvertEnzoTimeYears,
+          projection_conversion="1")
+
+def _StarDynamicalTime(field, data):
+    return data['star_dynamical_time']
+add_field('StarDynamicalTimeYears', units="\rm{yr}",
+          function=_StarDynamicalTime,
+          convert_function=_ConvertEnzoTimeYears,
+          projection_conversion="1")
+
+def _StarAge(field, data):
+    star_age = na.zeros(data['StarCreationTimeYears'].shape)
+    with_stars = data['StarCreationTimeYears'] > 0
+    star_age[with_stars] = data.pf.time_units['years'] * \
+        data.pf["InitialTime"] - \
+        data['StarCreationTimeYears'][with_stars]
+    return star_age
+add_field('StarAgeYears', units="\rm{yr}",
+          function=_StarAge,
+          projection_conversion="1")
 
 #
 # Now we do overrides for 2D fields

Modified: trunk/yt/lagos/HaloFinding.py
==============================================================================
--- trunk/yt/lagos/HaloFinding.py	(original)
+++ trunk/yt/lagos/HaloFinding.py	Sun Dec 20 12:42:09 2009
@@ -26,6 +26,7 @@
 """
 
 from yt.lagos import *
+from yt.math_utils import *
 from yt.lagos.hop.EnzoHop import RunHOP
 try:
     from yt.lagos.parallelHOP.parallelHOP import *
@@ -39,7 +40,7 @@
 from yt.performance_counters import yt_counters, time_function
 
 from kd import *
-import math, sys
+import math, sys, itertools
 from collections import defaultdict
 
 class Halo(object):
@@ -61,7 +62,10 @@
         self.halo_list = halo_list
         self.id = id
         self.data = halo_list._data_source
-        if indices is not None: self.indices = halo_list._base_indices[indices]
+        if indices is not None:
+            self.indices = halo_list._base_indices[indices]
+        else:
+            self.indices = None
         # We assume that if indices = None, the instantiator has OTHER plans
         # for us -- i.e., setting it somehow else
         self.size = size
@@ -71,6 +75,8 @@
         self.max_radius = max_radius
         self.bulk_vel = bulk_vel
         self.tasks = tasks
+        self.bin_count = None
+        self.overdensity = None
 
     def center_of_mass(self):
         """
@@ -165,13 +171,107 @@
         # set attributes on n
         self._processing = False
 
+    def virial_mass(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial mass of the halo in Msun, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.mass_bins[vir_bin]
+        else:
+            return -1
+        
+    
+    def virial_radius(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial radius of the halo in code units, using only the
+        particles in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.radial_bins[vir_bin]
+        else:
+            return -1
+
+    def virial_bin(self, virial_overdensity=200., bins=300):
+        """
+        Return the bin index for the virial radius for the given halo.
+        Returns -1 if the halo is not virialized to the set
+        *virial_overdensity*. 
+        """
+        self.virial_info(bins=bins)
+        over = (self.overdensity > virial_overdensity)
+        if (over == True).any():
+            vir_bin = max(na.arange(bins+1)[over])
+            return vir_bin
+        else:
+            return -1
+    
+    def virial_info(self, bins=300):
+        """
+        Calculate the virial profile bins for this halo, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins.
+        """
+        # Skip if we've already calculated for this number of bins.
+        if self.bin_count == bins and self.overdensity is not None:
+            return None
+        self.bin_count = bins
+        # Cosmology
+        h = self.halo_list._data_source.pf['CosmologyHubbleConstantNow']
+        Om_matter = self.halo_list._data_source.pf['CosmologyOmegaMatterNow']
+        z = self.halo_list._data_source.pf['CosmologyCurrentRedshift']
+        rho_crit_now = 1.8788e-29 * h**2.0 * Om_matter # g cm^-3
+        Msun2g = 1.989e33
+        rho_crit = rho_crit_now * ((1.0 + z)**3.0)
+        
+        # Get some pertinent information about the halo.
+        self.mass_bins = na.zeros(self.bin_count+1, dtype='float64')
+        dist = na.empty(self.indices.size, dtype='float64')
+        cen = self.center_of_mass()
+        period = self.halo_list._data_source.pf["DomainRightEdge"] - \
+            self.halo_list._data_source.pf["DomainLeftEdge"]
+        mark = 0
+        # Find the distances to the particles. I don't like this much, but I
+        # can't see a way to eliminate a loop like this, either here or in
+        # yt.math.
+        for pos in izip(self["particle_position_x"], self["particle_position_y"],
+                self["particle_position_z"]):
+            dist[mark] = periodic_dist(cen, pos, period)
+            mark += 1
+        # Set up the radial bins.
+        # Multiply min and max to prevent issues with digitize below.
+        self.radial_bins = na.logspace(math.log10(min(dist)*.99), 
+            math.log10(max(dist)*1.01), num=self.bin_count+1)
+        # Find out which bin each particle goes into, and add the particle
+        # mass to that bin.
+        inds = na.digitize(dist, self.radial_bins) - 1
+        for index in na.unique(inds):
+            self.mass_bins[index] += sum(self["ParticleMassMsun"][inds==index])
+        # Now forward sum the masses in the bins.
+        for i in xrange(self.bin_count):
+            self.mass_bins[i+1] += self.mass_bins[i]
+        # Calculate the over densities in the bins.
+        self.overdensity = self.mass_bins * Msun2g / \
+        (4./3. * math.pi * rho_crit * \
+        (self.radial_bins * self.halo_list._data_source.pf["cm"])**3.0)
+        
+
 class HOPHalo(Halo):
     pass
 
 class parallelHOPHalo(Halo,ParallelAnalysisInterface):
     dont_wrap = ["maximum_density","maximum_density_location",
         "center_of_mass","total_mass","bulk_velocity","maximum_radius",
-        "get_size","get_sphere", "write_particle_list","__getitem__"]
+        "get_size","get_sphere", "write_particle_list","__getitem__", 
+        "virial_info", "virial_bin", "virial_mass", "virial_radius"]
 
     def maximum_density(self):
         """
@@ -296,6 +396,118 @@
         global_size = self._mpi_allsum(my_size)
         return global_size
 
+    def __getitem__(self, key):
+        if ytcfg.getboolean("yt","inline") == False:
+            return self.data.particles[key][self.indices]
+        else:
+            return self.data[key][self.indices]
+
+    def virial_mass(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial mass of the halo in Msun, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.mass_bins[vir_bin]
+        else:
+            return -1
+        
+    
+    def virial_radius(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial radius of the halo in code units, using only the
+        particles in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.radial_bins[vir_bin]
+        else:
+            return -1
+
+    def virial_bin(self, virial_overdensity=200., bins=300):
+        """
+        Return the bin index for the virial radius for the given halo.
+        Returns -1 if the halo is not virialized to the set
+        *virial_overdensity*. 
+        """
+        self.virial_info(bins=bins)
+        over = (self.overdensity > virial_overdensity)
+        if (over == True).any():
+            vir_bin = max(na.arange(bins+1)[over])
+            return vir_bin
+        else:
+            return -1
+
+    def virial_info(self, bins=300):
+        """
+        Calculate the virial profile bins for this halo, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins.
+        """
+        # Skip if we've already calculated for this number of bins.
+        if self.bin_count == bins and self.overdensity is not None:
+            return None
+        # Do this for all because all will use it.
+        self.bin_count = bins
+        period = self.halo_list._data_source.pf["DomainRightEdge"] - \
+            self.halo_list._data_source.pf["DomainLeftEdge"]
+        self.mass_bins = na.zeros(self.bin_count+1, dtype='float64')
+        cen = self.center_of_mass()
+        # Cosmology
+        h = self.halo_list._data_source.pf['CosmologyHubbleConstantNow']
+        Om_matter = self.halo_list._data_source.pf['CosmologyOmegaMatterNow']
+        z = self.halo_list._data_source.pf['CosmologyCurrentRedshift']
+        rho_crit_now = 1.8788e-29 * h**2.0 * Om_matter # g cm^-3
+        Msun2g = 1.989e33
+        rho_crit = rho_crit_now * ((1.0 + z)**3.0)
+        # If I own some of this halo operate on the particles.
+        if self.indices is not None:
+            # Get some pertinent information about the halo.
+            dist = na.empty(self.indices.size, dtype='float64')
+            mark = 0
+            # Find the distances to the particles. I don't like this much, but I
+            # can't see a way to eliminate a loop like this, either here or in
+            # yt.math.
+            for pos in izip(self["particle_position_x"], self["particle_position_y"],
+                    self["particle_position_z"]):
+                dist[mark] = periodic_dist(cen, pos, period)
+                mark += 1
+            dist_min, dist_max = min(dist), max(dist)
+        # If I don't have this halo, make some dummy values.
+        else:
+            dist_min = max(period)
+            dist_max = 0.0
+        # In this parallel case, we're going to find the global dist extrema
+        # and built identical bins on all tasks.
+        dist_min = self._mpi_allmin(dist_min)
+        dist_max = self._mpi_allmax(dist_max)
+        # Set up the radial bins.
+        # Multiply min and max to prevent issues with digitize below.
+        self.radial_bins = na.logspace(math.log10(dist_min*.99), 
+            math.log10(dist_max*1.01), num=self.bin_count+1)
+        if self.indices is not None:
+            # Find out which bin each particle goes into, and add the particle
+            # mass to that bin.
+            inds = na.digitize(dist, self.radial_bins) - 1
+            for index in na.unique(inds):
+                self.mass_bins[index] += sum(self["ParticleMassMsun"][inds==index])
+            # Now forward sum the masses in the bins.
+            for i in xrange(self.bin_count):
+                self.mass_bins[i+1] += self.mass_bins[i]
+        # Sum up the mass_bins globally
+        self.mass_bins = self._mpi_Allsum_double(self.mass_bins)
+        # Calculate the over densities in the bins.
+        self.overdensity = self.mass_bins * Msun2g / \
+        (4./3. * math.pi * rho_crit * \
+        (self.radial_bins * self.halo_list._data_source.pf["cm"])**3.0)
+
 
 class FOFHalo(Halo):
 
@@ -735,7 +947,7 @@
                 self.max_dens_point[index][2], self.max_dens_point[index][3]]
             index += 1
         # Clean up
-        del self.max_dens_point, self.Tot_M, self.max_radius, self.bulk_vel
+        del self.max_dens_point, self.max_radius, self.bulk_vel
         del self.halo_taskmap, self.tags
 
     def __len__(self):
@@ -802,7 +1014,7 @@
                 halo._owner = proc
                 id += 1
         def haloCmp(h1,h2):
-            c = cmp(h1.get_size(),h2.get_size())
+            c = cmp(h1.total_mass(),h2.total_mass())
             if c != 0:
                 return -1 * c
             if c == 0:
@@ -970,11 +1182,12 @@
         yt_counters("Final Grouping")
 
     def _join_halolists(self):
-        gs = -self.group_sizes.copy()
+        ms = -self.Tot_M.copy()
+        del self.Tot_M
         Cx = self.CoM[:,0].copy()
         indexes = na.arange(self.group_count)
-        sorted = na.asarray([indexes[i] for i in na.lexsort([indexes, Cx, gs])])
-        del indexes, Cx, gs
+        sorted = na.asarray([indexes[i] for i in na.lexsort([indexes, Cx, ms])])
+        del indexes, Cx, ms
         self._groups = self._groups[sorted]
         self._max_dens = self._max_dens[sorted]
         for i in xrange(self.group_count):

Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py	(original)
+++ trunk/yt/lagos/ParallelTools.py	Sun Dec 20 12:42:09 2009
@@ -780,7 +780,7 @@
 
     def _mpi_exit_test(self, data=False):
         # data==True -> exit. data==False -> no exit
-        statuses = self._mpi_info_dict(data)
+        mine, statuses = self._mpi_info_dict(data)
         if True in statuses.values():
             raise RunTimeError("Fatal error. Exiting.")
         return None

Added: trunk/yt/lagos/parallelHOP/__init__.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/__init__.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,3 @@
+from yt.lagos import *
+
+from parallelHOP import *

Added: trunk/yt/lagos/parallelHOP/parallelHOP.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/parallelHOP.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,1479 @@
+"""
+An implementation of the HOP algorithm that runs in parallel.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UCSD/CASS
+Homepage: http://yt.enzotools.org/
+License:
+  Copyright (C) 2008-2009 Stephen Skory.  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+from collections import defaultdict
+import itertools, sys
+
+from yt.lagos import *
+from yt.funcs import *
+from yt.extensions.kdtree import *
+from yt.performance_counters import yt_counters, time_function
+
+class RunParallelHOP(ParallelAnalysisInterface):
+    def __init__(self,period, padding, num_neighbors, bounds,
+            xpos, ypos, zpos, index, mass, threshold=160.0, rearrange=True):
+        self.threshold = threshold
+        self.rearrange = rearrange
+        self.saddlethresh = 2.5 * threshold
+        self.peakthresh = 3 * threshold
+        self.period = period
+        self.padding = padding
+        self.num_neighbors = num_neighbors
+        self.bounds = bounds
+        self.xpos = xpos
+        self.ypos = ypos
+        self.zpos = zpos
+        self.real_size = len(self.xpos)
+        self.index = na.array(index, dtype='int64')
+        self.mass = mass
+        self.padded_particles = []
+        self.nMerge = 4
+        yt_counters("chainHOP")
+        self.max_mem = 0
+        self.__max_memory()
+        self._chain_hop()
+        yt_counters("chainHOP")
+
+    def _global_bounds_neighbors(self):
+        """
+        Build a dict of the boundaries of all the tasks, and figure out which
+        tasks are our geometric neighbors.
+        """
+        self.neighbors = set([])
+        self.mine, global_bounds = self._mpi_info_dict(self.bounds)
+        my_LE, my_RE = self.bounds
+        # Put the vertices into a big list, each row is
+        # array[x,y,z, taskID]
+        vertices = []
+        my_vertices = []
+        for taskID in global_bounds:
+            thisLE, thisRE = global_bounds[taskID]
+            if self.mine != taskID:
+                vertices.append(na.array([thisLE[0], thisLE[1], thisLE[2], taskID]))
+                vertices.append(na.array([thisLE[0], thisLE[1], thisRE[2], taskID]))
+                vertices.append(na.array([thisLE[0], thisRE[1], thisLE[2], taskID]))
+                vertices.append(na.array([thisRE[0], thisLE[1], thisLE[2], taskID]))
+                vertices.append(na.array([thisLE[0], thisRE[1], thisRE[2], taskID]))
+                vertices.append(na.array([thisRE[0], thisLE[1], thisRE[2], taskID]))
+                vertices.append(na.array([thisRE[0], thisRE[1], thisLE[2], taskID]))
+                vertices.append(na.array([thisRE[0], thisRE[1], thisRE[2], taskID]))
+            if self.mine == taskID:
+                my_vertices.append(na.array([thisLE[0], thisLE[1], thisLE[2]]))
+                my_vertices.append(na.array([thisLE[0], thisLE[1], thisRE[2]]))
+                my_vertices.append(na.array([thisLE[0], thisRE[1], thisLE[2]]))
+                my_vertices.append(na.array([thisRE[0], thisLE[1], thisLE[2]]))
+                my_vertices.append(na.array([thisLE[0], thisRE[1], thisRE[2]]))
+                my_vertices.append(na.array([thisRE[0], thisLE[1], thisRE[2]]))
+                my_vertices.append(na.array([thisRE[0], thisRE[1], thisLE[2]]))
+                my_vertices.append(na.array([thisRE[0], thisRE[1], thisRE[2]]))
+        # Find the neighbors we share corners with. Yes, this is lazy with
+        # a double loop, but it works and this is definitely not a performance
+        # bottleneck.
+        for my_vertex in my_vertices:
+            for vertex in vertices:
+                if vertex[3] in self.neighbors: continue
+                # If the corners touch, it's easy. This is the case if resizing
+                # (load-balancing) is turned off.
+                if (my_vertex % self.period == vertex[0:3] % self.period).all():
+                    self.neighbors.add(int(vertex[3]))
+                    continue
+                # Also test to see if the distance to this corner is within
+                # max_padding, which is more likely the case with load-balancing
+                # turned on.
+                dx = na.min( na.abs(my_vertex[0] - vertex[0]), \
+                    self.period[0] - na.abs(my_vertex[0] - vertex[0]))
+                dy = na.min( na.abs(my_vertex[1] - vertex[1]), \
+                    self.period[1] - na.abs(my_vertex[1] - vertex[1]))
+                dz = na.min( na.abs(my_vertex[2] - vertex[2]), \
+                    self.period[2] - na.abs(my_vertex[2] - vertex[2]))
+                d = na.sqrt(dx*dx + dy*dy + dz*dz)
+                if d <= self.max_padding:
+                    self.neighbors.add(int(vertex[3]))
+        # Faces and edges.
+        for dim in range(3):
+            dim1 = (dim + 1) % 3
+            dim2 = (dim + 2) % 3
+            left_face = my_LE[dim]
+            right_face = my_RE[dim]
+            for taskID in global_bounds:
+                if taskID == self.mine or taskID in self.neighbors: continue
+                thisLE, thisRE = global_bounds[taskID]
+                max1 = max(my_LE[dim1], thisLE[dim1])
+                max2 = max(my_LE[dim2], thisLE[dim2])
+                min1 = min(my_RE[dim1], thisRE[dim1])
+                min2 = min(my_RE[dim2], thisRE[dim2])
+                # Faces.
+                # First, faces that touch directly.
+                if (thisRE[dim] == left_face or thisRE[dim]%self.period[dim] == left_face) and \
+                        max1 <= min1 and max2 <= min2:
+                    self.neighbors.add(taskID)
+                    continue
+                elif (thisLE[dim] == right_face or thisLE[dim] == right_face%self.period[dim]) and \
+                        max1 <= min1 and max2 <= min2:
+                    self.neighbors.add(taskID)
+                    continue
+                # If an intervening subvolume has a width less than the padding
+                # (rare, but possible), a neighbor may not actually touch, so
+                # we need to account for that.
+                if (abs(thisRE[dim] - left_face) <= self.max_padding or \
+                        abs(thisRE[dim]%self.period[dim] - left_face) <= self.max_padding) and \
+                        max1 <= min1 and max2 <= min2:
+                    self.neighbors.add(taskID)
+                    continue
+                elif (abs(thisLE[dim] - right_face) <= self.max_padding or \
+                        abs(thisLE[dim] - right_face%self.period[dim]) <= self.max_padding) and \
+                        max1 <= min1 and max2 <= min2:
+                    self.neighbors.add(taskID)
+                    continue
+                # Edges.
+                # First, edges that touch.
+                elif (my_LE[dim] == (thisRE[dim]%self.period[dim]) and \
+                        my_LE[dim1] == (thisRE[dim1]%self.period[dim1]) and \
+                        max2 <= min2) or \
+                        (my_LE[dim] == (thisRE[dim]%self.period[dim]) and \
+                        my_LE[dim2] == (thisRE[dim2]%self.period[dim2]) and \
+                        max1 <= min1):
+                    self.neighbors.add(taskID)
+                    continue
+                elif ((my_RE[dim]%self.period[dim]) == thisLE[dim] and \
+                        (my_RE[dim1]%self.period[dim1]) == thisLE[dim1] and \
+                        max2 <= min2) or \
+                        ((my_RE[dim]%self.period[dim]) == thisLE[dim] and \
+                        (my_RE[dim2]%self.period[dim2]) == thisLE[dim2] and \
+                        max1 <= min1):
+                    self.neighbors.add(taskID)
+                    continue
+                # Now edges that don't touch, but are close.
+                if (abs(my_LE[dim] - thisRE[dim]%self.period[dim]) <= self.max_padding and \
+                        abs(my_LE[dim1] - thisRE[dim1]%self.period[dim1]) <= self.max_padding and \
+                        max2 <= min2) or \
+                        (abs(my_LE[dim] - thisRE[dim]%self.period[dim]) <= self.max_padding and \
+                        abs(my_LE[dim2] - thisRE[dim2]%self.period[dim2]) <= self.max_padding and \
+                        max1 <= min1):
+                    self.neighbors.add(taskID)
+                    continue
+                elif (abs(my_RE[dim]%self.period[dim] - thisLE[dim]) <= self.max_padding and \
+                        abs(my_RE[dim1]%self.period[dim1] - thisLE[dim1]) <= self.max_padding and \
+                        max2 <= min2) or \
+                        (abs(my_RE[dim]%self.period[dim] - thisLE[dim]) <= self.max_padding and \
+                        abs(my_RE[dim2]%self.period[dim2] - thisLE[dim2]) <= self.max_padding and \
+                        max1 <= min1):
+                    self.neighbors.add(taskID)
+                    continue
+        # Now we build a global dict of neighbor sets, and if a remote task
+        # lists us as their neighbor, we add them as our neighbor. This is 
+        # probably not needed because the stuff above should be symmetric,
+        # but it isn't a big issue.
+        self.mine, global_neighbors = self._mpi_info_dict(self.neighbors)
+        for taskID in global_neighbors:
+            if taskID == self.mine: continue
+            if self.mine in global_neighbors[taskID]:
+                self.neighbors.add(taskID)
+        # We can remove ourselves from the set if it got added somehow.
+        self.neighbors.discard(self.mine)
+        # Clean up.
+        del global_neighbors, global_bounds, vertices, my_vertices
+        
+    def _global_padding(self, round):
+        """
+        Find the maximum padding of all our neighbors, used to send our
+        annulus data.
+        """
+        if round == 'first':
+            max_pad = na.max(self.padding)
+            self.mine, self.global_padding = self._mpi_info_dict(max_pad)
+            self.max_padding = max(self.global_padding.itervalues())
+        elif round == 'second':
+            self.max_padding = 0.
+            for neighbor in self.neighbors:
+                self.max_padding = na.maximum(self.global_padding[neighbor], \
+                    self.max_padding)
+
+    def _communicate_padding_data(self):
+        """
+        Send the particles each of my neighbors need to build up their padding.
+
+        Every task shares its padded (expanded) bounds. Each annulus particle
+        (selected by the boolean mask self.is_inside_annulus, presumably set
+        by _is_inside) is shifted by one period where needed and sent to
+        every neighbor whose padded bounds contain it. Received particles are
+        stored in self.index_pad, self.xpos_pad, self.ypos_pad,
+        self.zpos_pad and self.mass_pad, with positions wrapped back into the
+        periodic box. Also sets self.size (real + padded count) and
+        initializes self.chainID to -1. Collective over all tasks.
+        """
+        yt_counters("Communicate discriminated padding")
+        # First build a global dict of the padded boundaries of all the tasks.
+        (LE, RE) = self.bounds
+        (LE_padding, RE_padding) = self.padding
+        temp_LE = LE - LE_padding
+        temp_RE = RE + RE_padding
+        expanded_bounds = (temp_LE, temp_RE)
+        self.mine, global_exp_bounds = self._mpi_info_dict(expanded_bounds)
+        # Per-neighbor outgoing buffers, keyed by neighbor task ID.
+        send_real_indices = {}
+        send_points = {}
+        send_mass = {}
+        send_size = {}
+        # This will reduce the size of the loop over particles.
+        yt_counters("Picking padding data to send.")
+        send_count = len(na.where(self.is_inside_annulus == True)[0])
+        points = na.empty((send_count, 3), dtype='float64')
+        points[:,0] = self.xpos[self.is_inside_annulus]
+        points[:,1] = self.ypos[self.is_inside_annulus]
+        points[:,2] = self.zpos[self.is_inside_annulus]
+        real_indices = self.index[self.is_inside_annulus].astype('int64')
+        mass = self.mass[self.is_inside_annulus].astype('float64')
+        # Make the arrays to send.
+        shift_points = points.copy()
+        for neighbor in self.neighbors:
+            temp_LE, temp_RE = global_exp_bounds[neighbor]
+            # Periodic shift per axis: a coordinate below both of the
+            # neighbor's padded edges moves up by one period, one above both
+            # moves down by one period, anything else is left unchanged.
+            for i in xrange(3):
+                left = ((points[:,i] < temp_LE[i]) * (points[:,i] < temp_RE[i])) * self.period[i]
+                right = ((points[:,i] > temp_LE[i]) * (points[:,i] > temp_RE[i])) * self.period[i]
+                shift_points[:,i] = points[:,i] + left - right
+            # Keep only the (shifted) points inside the neighbor's padded box.
+            is_inside = ( (shift_points >= temp_LE).all(axis=1) * \
+                (shift_points < temp_RE).all(axis=1) )
+            send_real_indices[neighbor] = real_indices[is_inside].copy()
+            send_points[neighbor] = shift_points[is_inside].copy()
+            send_mass[neighbor] = mass[is_inside].copy()
+            send_size[neighbor] = len(na.where(is_inside == True)[0])
+        del points, shift_points, mass, real_indices
+        yt_counters("Picking padding data to send.")
+        # Communicate the sizes to send.
+        # global_send_count[task][other] is how many particles task sends
+        # to other.
+        self.mine, global_send_count = self._mpi_info_dict(send_size)
+        # Initialize the arrays to receive data.
+        yt_counters("Initalizing recv arrays.")
+        recv_real_indices = {}
+        recv_points = {}
+        recv_mass = {}
+        recv_size = 0
+        for opp_neighbor in self.neighbors:
+            opp_size = global_send_count[opp_neighbor][self.mine]
+            recv_real_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+            recv_points[opp_neighbor] = na.empty((opp_size, 3), dtype='float64')
+            recv_mass[opp_neighbor] = na.empty(opp_size, dtype='float64')
+            recv_size += opp_size
+        yt_counters("Initalizing recv arrays.")
+        # Setup the receiving slots.
+        # NOTE(review): the three messages from one neighbor (indices, points,
+        # mass) are posted in the same order on both sides and no tag is
+        # visible here; this appears to rely on MPI's in-order matching
+        # between a pair of ranks. Confirm how _mpi_Irecv_*/_mpi_Isend_* tag.
+        yt_counters("MPI stuff.")
+        hooks = []
+        for opp_neighbor in self.neighbors:
+            hooks.append(self._mpi_Irecv_long(recv_real_indices[opp_neighbor], opp_neighbor))
+            hooks.append(self._mpi_Irecv_double(recv_points[opp_neighbor], opp_neighbor))
+            hooks.append(self._mpi_Irecv_double(recv_mass[opp_neighbor], opp_neighbor))
+        # Let's wait here to be absolutely sure that all the receive buffers
+        # have been created before any sending happens!
+        self._barrier()
+        # Now we send the data.
+        for neighbor in self.neighbors:
+            hooks.append(self._mpi_Isend_long(send_real_indices[neighbor], neighbor))
+            hooks.append(self._mpi_Isend_double(send_points[neighbor], neighbor))
+            hooks.append(self._mpi_Isend_double(send_mass[neighbor], neighbor))
+        # Now we use the data, after all the comms are done.
+        self._mpi_Request_Waitall(hooks)
+        yt_counters("MPI stuff.")
+        yt_counters("Processing padded data.")
+        del send_real_indices, send_points, send_mass
+        # Now we add the data to ourselves.
+        # Received data is concatenated neighbor by neighbor, in the same
+        # iteration order used to size recv_size above.
+        self.index_pad = na.empty(recv_size, dtype='int64')
+        self.xpos_pad = na.empty(recv_size, dtype='float64')
+        self.ypos_pad = na.empty(recv_size, dtype='float64')
+        self.zpos_pad = na.empty(recv_size, dtype='float64')
+        self.mass_pad = na.empty(recv_size, dtype='float64')
+        so_far = 0
+        for opp_neighbor in self.neighbors:
+            opp_size = global_send_count[opp_neighbor][self.mine]
+            self.index_pad[so_far:so_far+opp_size] = recv_real_indices[opp_neighbor]
+            # Clean up immediately to reduce peak memory usage.
+            del recv_real_indices[opp_neighbor]
+            self.xpos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,0]
+            self.ypos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,1]
+            self.zpos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,2]
+            del recv_points[opp_neighbor]
+            self.mass_pad[so_far:so_far+opp_size] = recv_mass[opp_neighbor]
+            del recv_mass[opp_neighbor]
+            so_far += opp_size
+        yt_counters("Processing padded data.")
+        # The KDtree node search wants the particles to be in the full box,
+        # not the expanded dimensions of shifted (<0 or >1 generally) volume,
+        # so we fix the positions of particles here.
+        yt_counters("Flipping coordinates around the periodic boundary.")
+        self.xpos_pad = self.xpos_pad % self.period[0]
+        self.ypos_pad = self.ypos_pad % self.period[1]
+        self.zpos_pad = self.zpos_pad % self.period[2]
+        yt_counters("Flipping coordinates around the periodic boundary.")
+        self.size = self.index.size + self.index_pad.size
+        # Now that we have the full size, initialize the chainID array
+        self.chainID = na.ones(self.size,dtype='int64') * -1
+        # Clean up explicitly, but these should be empty dicts by now.
+        del recv_real_indices, hooks, recv_points, recv_mass
+        yt_counters("Communicate discriminated padding")
+
+    def _init_kd_tree(self):
+        """
+        Set up the data objects that get passed to the kD-tree code.
+
+        Fills the fKD module's arrays with the real particles followed by the
+        padded particles (positions as a (3, self.size) Fortran-ordered
+        array), sets the search parameters (nn, nMerge, nparts, sort,
+        rearrange), builds the tree with create_tree() and then records
+        memory use via __max_memory.
+        """
+        yt_counters("init kd tree")
+        # Yes, we really do need to initialize this many arrays.
+        # They're deleted in _parallelHOP.
+        fKD.dens = na.asfortranarray(na.zeros(self.size, dtype='float64'))
+        fKD.mass = na.concatenate((self.mass, self.mass_pad))
+        fKD.pos = na.asfortranarray(na.empty((3, self.size), dtype='float64'))
+        # This actually copies the data into the fortran space.
+        # Real particles come first, padded particles after them.
+        fKD.pos[0, :] = na.concatenate((self.xpos, self.xpos_pad))
+        fKD.pos[1, :] = na.concatenate((self.ypos, self.ypos_pad))
+        fKD.pos[2, :] = na.concatenate((self.zpos, self.zpos_pad))
+        # Query-point buffer of length 3.
+        fKD.qv = na.asfortranarray(na.empty(3, dtype='float64'))
+        fKD.nn = self.num_neighbors
+        # Plus 2 because we're looking for that neighbor, but only keeping 
+        # nMerge + 1 neighbor tags, skipping ourselves.
+        fKD.nMerge = self.nMerge + 2
+        fKD.nparts = self.size
+        fKD.sort = True # Slower, but needed in _connect_chains
+        fKD.rearrange = self.rearrange # True is faster, but uses more memory
+        # Now call the fortran.
+        create_tree()
+        self.__max_memory()
+        yt_counters("init kd tree")
+
+    def _is_inside(self, round):
+        """
+        There are three classes of particles.
+        1. Particles inside the 'real' region of each subvolume.
+        2. Particles ouside, added in the 'padding' for purposes of having 
+           correct particle densities in the real region.
+        3. Particles that are one padding distance inside the edges of the
+           real region. The chainIDs of these particles are communicated
+           to the neighboring tasks so chains can be merged into groups.
+        The input *round* is either 'first' or 'second.' First is before the
+        padded particles have been communicated, and second after.
+        """
+        # Test to see if the points are in the 'real' region
+        (LE, RE) = self.bounds
+        if round == 'first':
+            points = na.empty((self.real_size, 3), dtype='float64')
+            points[:,0] = self.xpos
+            points[:,1] = self.ypos
+            points[:,2] = self.zpos
+            self.is_inside = ( (points >= LE).all(axis=1) * \
+                (points < RE).all(axis=1) )
+        elif round == 'second':
+            self.is_inside = ( (fKD.pos.T >= LE).all(axis=1) * \
+                (fKD.pos.T < RE).all(axis=1) )
+        # Below we find out which particles are in the `annulus', one padding
+        # distance inside the boundaries. First we find the particles outside
+        # this inner boundary.
+        temp_LE = LE + self.max_padding
+        temp_RE = RE - self.max_padding
+        if round == 'first':
+            inner = na.invert( (points >= temp_LE).all(axis=1) * \
+                (points < temp_RE).all(axis=1) )
+        elif round == 'second' or round == 'third':
+            inner = na.invert( (fKD.pos.T >= temp_LE).all(axis=1) * \
+                (fKD.pos.T < temp_RE).all(axis=1) )
+        if round == 'first':
+            del points
+        # After inverting the logic above, we want points that are both
+        # inside the real region, but within one padding of the boundary,
+        # and this will do it.
+        self.is_inside_annulus = na.bitwise_and(self.is_inside, inner)
+        # Below we make a mapping of real particle index->local ID
+        # Unf. this has to be a dict, because any task can have
+        # particles of any particle_index, which means that if it were an
+        # array every task would probably end up having this array be as long
+        # as the full number of particles.
+        # We can skip this the first two times around.
+        if round == 'third':
+            temp = na.arange(self.size)
+            my_part = na.bitwise_or(na.invert(self.is_inside), self.is_inside_annulus)
+            my_part = na.bitwise_and(my_part, (self.chainID != -1))
+            catted_indices = na.concatenate(
+                (self.index, self.index_pad))[my_part]
+            self.rev_index = dict.fromkeys(catted_indices)
+            self.rev_index.update(itertools.izip(catted_indices, temp[my_part]))
+            del my_part, temp, catted_indices
+        self.__max_memory()
+
+    def _densestNN(self):
+        """
+        For all particles, find their densest nearest neighbor. It is done in
+        chunks to keep the memory usage down.
+        The first search of nearest neighbors (done earlier) did not return all 
+        num_neighbor neighbors, so we need to do it again, but we're not
+        keeping the all of this data, just using it.
+        """
+        yt_counters("densestNN")
+        self.densestNN = na.empty(self.size,dtype='int64')
+        # We find nearest neighbors in chunks.
+        chunksize = 10000
+        fKD.chunk_tags = na.asfortranarray(na.empty((self.num_neighbors, chunksize), dtype='int64'))
+        start = 1 # Fortran counting!
+        finish = 0
+        while finish < self.size:
+            finish = min(finish+chunksize,self.size)
+            # Call the fortran. start and finish refer to the data locations
+            # in fKD.pos, and specify the range of particles to find nearest
+            # neighbors
+            fKD.start = start
+            fKD.finish = finish
+            find_chunk_nearest_neighbors()
+            chunk_NNtags = (fKD.chunk_tags[:,:finish-start+1] - 1).transpose()
+            # Find the densest nearest neighbors by referencing the already
+            # calculated density.
+            n_dens = na.take(self.density,chunk_NNtags)
+            max_loc = na.argmax(n_dens,axis=1)
+            for i in xrange(finish - start + 1): # +1 for fortran counting.
+                j = start + i - 1 # -1 for fortran counting.
+                self.densestNN[j] = chunk_NNtags[i,max_loc[i]]
+            start = finish + 1
+        yt_counters("densestNN")
+        self.__max_memory()
+        del chunk_NNtags, max_loc, n_dens
+    
+    def _build_chains(self):
+        """
+        Build the first round of particle chains. If the particle is too low in
+        density, move on.
+        """
+        yt_counters("build_chains")
+        chainIDmax = 0
+        self.densest_in_chain = na.ones(10000, dtype='float64') * -1 # chainID->density, one to one
+        self.densest_in_chain_real_index = na.ones(10000, dtype='int64') * -1 # chainID->real_index, one to one
+        for i in xrange(int(self.size)):
+            # If it's already in a group, move on, or if this particle is
+            # in the padding, move on because chains can only terminate in
+            # the padding, not begin, or if this particle is too low in
+            # density, move on.
+            if self.chainID[i] > -1 or not self.is_inside[i] or \
+                    self.density[i] < self.threshold:
+                continue
+            chainIDnew = self._recurse_links(i, chainIDmax)
+            # If the new chainID returned is the same as we entered, the chain
+            # has been named chainIDmax, so we need to start a new chain
+            # in the next loop.
+            if chainIDnew == chainIDmax:
+                chainIDmax += 1
+        self.padded_particles = na.array(self.padded_particles, dtype='int64')
+        self.densest_in_chain = self.__clean_up_array(self.densest_in_chain)
+        self.densest_in_chain_real_index = self.__clean_up_array(self.densest_in_chain_real_index)
+        yt_counters("build_chains")
+        self.__max_memory()
+        return chainIDmax
+    
+    def _recurse_links(self, pi, chainIDmax):
+        """
+        Recurse up the chain to a) a self-highest density particle,
+        b) a particle that already has a chainID, then turn it back around
+        assigning that chainID to where we came from. If c) which
+        is a particle in the padding, terminate the chain right then
+        and there, because chains only go one particle deep into the padding.
+        """
+        nn = self.densestNN[pi]
+        inside = self.is_inside[pi]
+        nn_chainID = self.chainID[nn]
+        # Linking to an already chainID-ed particle (don't make links from 
+        # padded particles!)
+        if nn_chainID > -1 and inside:
+            self.chainID[pi] = nn_chainID
+            return nn_chainID
+        # If pi is a self-most dense particle or inside the padding, end/create
+        # a new chain.
+        elif nn == pi or not inside:
+            self.chainID[pi] = chainIDmax
+            self.densest_in_chain = self.__add_to_array(self.densest_in_chain,
+                chainIDmax, self.density[pi], 'float64')
+            if pi < self.real_size:
+                self.densest_in_chain_real_index = self.__add_to_array(self.densest_in_chain_real_index,
+                chainIDmax, self.index[pi], 'int64')
+            else:
+                self.densest_in_chain_real_index = self.__add_to_array(self.densest_in_chain_real_index,
+                chainIDmax, self.index_pad[pi-self.real_size], 'int64')
+            # if this is a padded particle, record it for later
+            if not inside:
+                self.padded_particles.append(pi)
+            return chainIDmax
+        # Otherwise, recursively link to nearest neighbors.
+        else:
+            chainIDnew = self._recurse_links(nn, chainIDmax)
+            self.chainID[pi] = chainIDnew
+            return chainIDnew
+
+    def _recurse_preconnected_links(self, chain_map, thisID):
+        if min(thisID, min(chain_map[thisID])) == thisID:
+            return thisID
+        else:
+            return self._recurse_preconnected_links(chain_map, min(chain_map[thisID]))
+
+    def _preconnect_chains(self, chain_count):
+        """
+        In each subvolume, chains that share a boundary that both have high
+        enough peak densities are prelinked in order to reduce the size of the
+        global chain objects. This is very similar to _connect_chains().
+        """
+        # First we'll sort them, which will be used below.
+        mylog.info("Locally sorting chains...")
+        yt_counters("preconnect_chains")
+        yt_counters("local chain sorting.")
+        sort = self.densest_in_chain.argsort()
+        sort = na.flipud(sort)
+        map = na.empty(sort.size,dtype='int64')
+        map[sort] = na.arange(sort.size)
+        self.densest_in_chain = self.densest_in_chain[sort]
+        self.densest_in_chain_real_index = self.densest_in_chain_real_index[sort]
+        del sort
+        for i,chID in enumerate(self.chainID):
+            if chID == -1: continue
+            self.chainID[i] = map[chID]
+        del map
+        yt_counters("local chain sorting.")
+        mylog.info("Preconnecting %d chains..." % chain_count)
+        chain_map = defaultdict(set)
+        for i in xrange(max(self.chainID)+1):
+            chain_map[i].add(i)
+        # Plus 2 because we're looking for that neighbor, but only keeping 
+        # nMerge + 1 neighbor tags, skipping ourselves.
+        fKD.dist = na.empty(self.nMerge+2, dtype='float64')
+        fKD.tags = na.empty(self.nMerge+2, dtype='int64')
+        # We can change this here to make the searches faster.
+        fKD.nn = self.nMerge+2
+        yt_counters("preconnect kd tree search.")
+        for i in xrange(self.size):
+            # Don't consider this particle if it's not part of a chain.
+            if self.chainID[i] < 0: continue
+            chainID_i = self.chainID[i]
+            # If this particle is in the padding, don't make a connection.
+            if not self.is_inside[i]: continue
+            # Find this particle's chain max_dens.
+            part_max_dens = self.densest_in_chain[chainID_i]
+            # We're only connecting >= peakthresh chains now.
+            if part_max_dens < self.peakthresh: continue
+            # Loop over nMerge closest nearest neighbors.
+            fKD.qv = fKD.pos[:, i]
+            find_nn_nearest_neighbors()
+            NNtags = fKD.tags[:] - 1
+            same_count = 0
+            for j in xrange(int(self.nMerge+1)):
+                thisNN = NNtags[j+1] # Don't consider ourselves at NNtags[0]
+                thisNN_chainID = self.chainID[thisNN]
+                # If our neighbor is in the same chain, move on.
+                # Move on if these chains are already connected:
+                if chainID_i == thisNN_chainID or \
+                        thisNN_chainID in chain_map[chainID_i]:
+                    same_count += 1
+                    continue
+                # Everything immediately below is for
+                # neighboring particles with a chainID. 
+                if thisNN_chainID >= 0:
+                    # Find thisNN's chain's max_dens.
+                    thisNN_max_dens = self.densest_in_chain[thisNN_chainID]
+                    # We're only linking peakthresh chains
+                    if thisNN_max_dens < self.peakthresh: continue
+                    # Calculate the two groups boundary density.
+                    boundary_density = (self.density[thisNN] + self.density[i]) / 2.
+                    # Don't connect if the boundary is too low.
+                    if boundary_density < self.saddlethresh: continue
+                    # Mark these chains as related.
+                    chain_map[thisNN_chainID].add(chainID_i)
+                    chain_map[chainID_i].add(thisNN_chainID)
+            if same_count == self.nMerge + 1:
+                # All our neighbors are in the same chain already, so 
+                # we don't need to search again.
+                self.search_again[i] = False
+        yt_counters("preconnect kd tree search.")
+        # Recursively jump links until we get to a chain whose densest
+        # link is to itself. At that point we've found the densest chain
+        # in this set of sets and we keep a record of that.
+        yt_counters("preconnect pregrouping.")
+        final_chain_map = na.empty(max(self.chainID)+1, dtype='int64')
+        removed = 0
+        for i in xrange(max(self.chainID)+1):
+            j = chain_count - i - 1
+            densest_link = self._recurse_preconnected_links(chain_map, j)
+            final_chain_map[j] = densest_link
+            if j != densest_link:
+                removed += 1
+                self.densest_in_chain[j] = -1
+                self.densest_in_chain_real_index[j] = -1
+        del chain_map
+        for i in xrange(self.size):
+            if self.chainID[i] != -1:
+                self.chainID[i] = final_chain_map[self.chainID[i]]
+        del final_chain_map
+        # Now make the chainID assignments consecutive.
+        map = na.empty(self.densest_in_chain.size, dtype='int64')
+        dic_new = na.empty(chain_count - removed, dtype='float64')
+        dicri_new = na.empty(chain_count - removed, dtype='int64')
+        new = 0
+        for i,dic in enumerate(self.densest_in_chain):
+            if dic > 0:
+                map[i] = new
+                dic_new[new] = dic
+                dicri_new[new] = self.densest_in_chain_real_index[i]
+                new += 1
+            else:
+                map[i] = -1
+        for i in range(self.size):
+            if self.chainID[i] != -1:
+                self.chainID[i] = map[self.chainID[i]]
+        del map
+        self.densest_in_chain = dic_new.copy()
+        self.densest_in_chain_real_index = dicri_new.copy()
+        self.__max_memory()
+        yt_counters("preconnect pregrouping.")
+        mylog.info("Preconnected %d chains." % removed)
+        yt_counters("preconnect_chains")
+
+        return chain_count - removed
+
+    def _globally_assign_chainIDs(self, chain_count):
+        """
+        Convert local chainIDs into globally unique chainIDs.
+        """
+        yt_counters("globally_assign_chainIDs")
+        # First find out the number of chains on each processor.
+        self.mine, chain_info = self._mpi_info_dict(chain_count)
+        self.nchains = sum(chain_info.values())
+        # Figure out our offset.
+        self.my_first_id = sum([v for k,v in chain_info.iteritems() if k < self.mine])
+        # Change particle IDs, -1 always means no chain assignment.
+        select = (self.chainID != -1)
+        select = select * self.my_first_id
+        self.chainID += select
+        del select
+        yt_counters("globally_assign_chainIDs")
+
+    def _create_global_densest_in_chain(self):
+        """
+        With the globally unique chainIDs, update densest_in_chain.
+        """
+        yt_counters("create_global_densest_in_chain")
+        # Shift the values over effectively by concatenating them in the same
+        # order as the values have been shifted in _globally_assign_chainIDs()
+        yt_counters("global chain MPI stuff.")
+        self.densest_in_chain = self._mpi_concatenate_array_double(self.densest_in_chain)
+        self.densest_in_chain_real_index = self._mpi_concatenate_array_long(self.densest_in_chain_real_index)
+        yt_counters("global chain MPI stuff.")
+        # Sort the chains by density here. This is an attempt to make it such
+        # that the merging stuff in a few steps happens in the same order
+        # all the time.
+        mylog.info("Sorting chains...")
+        yt_counters("global chain sorting.")
+        sort = self.densest_in_chain.argsort()
+        sort = na.flipud(sort)
+        map = na.empty(sort.size,dtype='int64')
+        map[sort] =na.arange(sort.size)
+        self.densest_in_chain = self.densest_in_chain[sort]
+        self.densest_in_chain_real_index = self.densest_in_chain_real_index[sort]
+        del sort
+        for i,chID in enumerate(self.chainID):
+            if chID == -1: continue
+            self.chainID[i] = map[chID]
+        del map
+        yt_counters("global chain sorting.")
+        # For some reason chains that share the most-dense particle are not
+        # being linked, so we link them 'by hand' here.
+        mylog.info("Pre-linking chains 'by hand'...")
+        yt_counters("global chain hand-linking.")
+        # If there are no repeats, we can skip this mess entirely.
+        uniq = na.unique(self.densest_in_chain_real_index)
+        if uniq.size != self.densest_in_chain_real_index.size:
+            # Find only the real particle indices that are repeated to reduce
+            # the dict workload below.
+            dicri = self.densest_in_chain_real_index[self.densest_in_chain_real_index.argsort()]
+            diff = na.ediff1d(dicri)
+            diff = (diff == 0) # Picks out the places where the ids are equal
+            diff = na.concatenate((diff, [False])) # Makes it the same length
+            # This has only the repeated IDs. Sets are faster at searches than
+            # arrays.
+            dicri = set(dicri[diff])
+            reverse = defaultdict(set)
+            # Here we find a reverse mapping of real particle ID to chainID
+            for chainID, real_index in enumerate(self.densest_in_chain_real_index):
+                if real_index in dicri:
+                    reverse[real_index].add(chainID)
+            del dicri, diff
+            # If the real index has len(set)>1, there are multiple chains that need
+            # to be linked
+            tolink = defaultdict(set)
+            for real in reverse:
+                if len(reverse[real]) > 1:
+                    # Unf. can't slice a set, so this will have to do.
+                    tolink[min(reverse[real])] = reverse[real]
+                    tolink[min(reverse[real])].discard(min(reverse[real]))
+            del reverse
+            # Now we will remove the other chains from the dicts and re-assign
+            # particles to their new chainID.
+            fix_map = {}
+            for tokeep in tolink:
+                for remove in tolink[tokeep]:
+                    fix_map[remove] = tokeep
+                    self.densest_in_chain[remove] = -1.0
+                    self.densest_in_chain_real_index[remove] = -1
+            for i, chainID in enumerate(self.chainID):
+                try:
+                    new = fix_map[chainID]
+                except KeyError:
+                    continue
+                self.chainID[i] = new
+            del tolink, fix_map
+        yt_counters("global chain hand-linking.")
+        yt_counters("create_global_densest_in_chain")
+
+    def _communicate_uphill_info(self):
+        """
+        Communicate the links to the correct neighbors from uphill_info.
+
+        Every task sends its whole self.uphill_real_indices and
+        self.uphill_chainIDs arrays to each task in self.neighbors, and
+        receives the neighbors' arrays into buffers sized from
+        self.global_padded_count. The received data is concatenated, in
+        self.neighbors order, into self.recv_real_indices and
+        self.recv_chainIDs.
+        """
+        yt_counters("communicate_uphill_info")
+        # Find out how many particles we're going to receive, and make arrays
+        # of the right size and type to store them.
+        to_recv_count = 0
+        temp_indices = dict.fromkeys(self.neighbors)
+        temp_chainIDs = dict.fromkeys(self.neighbors)
+        for opp_neighbor in self.neighbors:
+            opp_size = self.global_padded_count[opp_neighbor]
+            to_recv_count += opp_size
+            temp_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+            temp_chainIDs[opp_neighbor] = na.empty(opp_size, dtype='int64')
+        # The arrays we'll actually keep around...
+        self.recv_real_indices = na.empty(to_recv_count, dtype='int64')
+        self.recv_chainIDs = na.empty(to_recv_count, dtype='int64')
+        # Set up the receives, but don't actually use them.
+        # NOTE(review): two messages arrive from each neighbor. This relies on
+        # them being matched in posting order (indices first, then chainIDs),
+        # which MPI guarantees only when both use the same tag and
+        # communicator -- presumably true for _mpi_Irecv_long/_mpi_Isend_long;
+        # confirm. Do not reorder these calls or the sends below.
+        hooks = []
+        for opp_neighbor in self.neighbors:
+            hooks.append(self._mpi_Irecv_long(temp_indices[opp_neighbor], opp_neighbor))
+            hooks.append(self._mpi_Irecv_long(temp_chainIDs[opp_neighbor], opp_neighbor))
+        # Make sure all the receive buffers are set before continuing.
+        self._barrier()
+        # Send padded particles to our neighbors.
+        for neighbor in self.neighbors:
+            hooks.append(self._mpi_Isend_long(self.uphill_real_indices, neighbor))
+            hooks.append(self._mpi_Isend_long(self.uphill_chainIDs, neighbor))
+        # Now actually use the data once it's good to go.
+        self._mpi_Request_Waitall(hooks)
+        self.__max_memory()
+        so_far = 0
+        for opp_neighbor in self.neighbors:
+            opp_size = self.global_padded_count[opp_neighbor]
+            # Only save the part of the buffer that we want to the right places
+            # in the full listing. (Each temp buffer is exactly opp_size long,
+            # so the slice covers all of it.)
+            self.recv_real_indices[so_far:(so_far + opp_size)] = \
+                temp_indices[opp_neighbor][0:opp_size]
+            self.recv_chainIDs[so_far:(so_far + opp_size)] = \
+                temp_chainIDs[opp_neighbor][0:opp_size]
+            so_far += opp_size
+        # Clean up.
+        del temp_indices, temp_chainIDs, hooks
+        yt_counters("communicate_uphill_info")
+
+    def _recurse_global_chain_links(self, chainID_translate_map_global, chainID, seen):
+        """
+        Step up the global chain links until we reach the self-densest chain,
+        very similarly to the recursion of particles to densest nearest
+        neighbors.
+        """
+        new_chainID = chainID_translate_map_global[chainID]
+        if  new_chainID == chainID:
+            return int(chainID)
+        elif new_chainID in seen:
+            # Bad things are about to happen if this condition is met! The
+            # padding probably needs to be increased (using the safety factor).
+            mylog.info('seen %s' % str(seen))
+            for s in seen:
+                mylog.info('%d %d' % (s, chainID_translate_map_global[s]))
+        else:
+            seen.append(new_chainID)
+            return self._recurse_global_chain_links(chainID_translate_map_global, new_chainID, seen)
+
+    def _connect_chains_across_tasks(self):
+        """
+        Using the uphill links of chains, chains are linked across boundaries.
+        Chains that link to a remote chain are recorded, and a complete dict
+        of chain connections is created, globally. Then chainIDs are
+        reassigned recursively, assigning the ID of the most dense chainID
+        to every chain that links to it.
+        """
+        yt_counters("connect_chains_across_tasks")
+        # Remote (lower dens) chain -> local (higher) chain.
+        chainID_translate_map_local = na.arange(self.nchains)
+        # Build the stuff to send.
+        self.uphill_real_indices = na.concatenate((
+            self.index, self.index_pad))[self.padded_particles]
+        self.uphill_chainIDs = self.chainID[self.padded_particles]
+        del self.padded_particles
+        # Now we make a global dict of how many particles each task is
+        # sending.
+        self.global_padded_count = {self.mine:self.uphill_chainIDs.size}
+        self.global_padded_count = self._mpi_joindict(self.global_padded_count)
+        # Send/receive 'em.
+        self._communicate_uphill_info()
+        self.__max_memory()
+        # Fix the IDs to localIDs.
+        for i,real_index in enumerate(self.recv_real_indices):
+            try:
+                localID = self.rev_index[real_index]
+                # We don't want to update the chainIDs of my padded particles.
+                # Remember we are supposed to be only considering particles
+                # in my *real* region, that are padded in my neighbor.
+                if not self.is_inside[localID]:
+                    # Make it negative so we can skip it below.
+                    self.recv_real_indices[i] = -1
+                    continue
+                self.recv_real_indices[i] = localID
+            except KeyError:
+                # This is probably a particle we don't even own, so we want
+                # to ignore it.
+                self.recv_real_indices[i] = -1
+                continue
+        # Now relate the local chainIDs to the received chainIDs
+        for i,localID in enumerate(self.recv_real_indices):
+            # If the 'new' chainID is different that what we already have,
+            # we need to record it, but we skip particles that were assigned
+            # -1 above. Also, since links are supposed to go only uphill,
+            # ensure that they are being recorded that way below.
+            if localID != -1 and self.chainID[localID] != -1:
+                if self.recv_chainIDs[i] != self.chainID[localID] and \
+                        self.densest_in_chain[self.chainID[localID]] >= self.densest_in_chain[self.recv_chainIDs[i]] and \
+                        self.densest_in_chain[self.chainID[localID]] != -1.0 and \
+                        self.densest_in_chain[self.recv_chainIDs[i]] != -1.0:
+                    chainID_translate_map_local[self.recv_chainIDs[i]] = \
+                        self.chainID[localID]
+        self.__max_memory()
+        # In chainID_translate_map_local, chains may
+        # 'point' to only one chain, but a chain may have many that point to
+        # it. Therefore each key (a chain) in this dict is unique, but the items
+        # the keys point to are not necessarily unique.
+        chainID_translate_map_global = \
+            self._mpi_minimum_array_long(chainID_translate_map_local)
+        # Loop over chains, smallest to largest density, recursively until
+        # we reach a self-assigned chain. Then we assign that final chainID to
+        # the *current* one only.
+        seen = []
+        for key, density in enumerate(self.densest_in_chain):
+            if density == -1: continue # Skip 'deleted' chains
+            seen = []
+            seen.append(key)
+            new_chainID = \
+                self._recurse_global_chain_links(chainID_translate_map_global, key, seen)
+            chainID_translate_map_global[key] = new_chainID
+            # At the same time, remove chains from densest_in_chain that have
+            # been reassigned.
+            if key != new_chainID:
+                self.densest_in_chain[key] = -1.0
+                self.densest_in_chain_real_index[key] = -1
+                # Also fix nchains to keep up.
+                self.nchains -= 1
+        self.__max_memory()
+        # Convert local particles to their new chainID
+        for i in xrange(int(self.size)):
+            old_chainID = self.chainID[i]
+            if old_chainID == -1: continue
+            new_chainID = chainID_translate_map_global[old_chainID]
+            self.chainID[i] = new_chainID
+        del chainID_translate_map_local, self.recv_chainIDs
+        del self.recv_real_indices, self.uphill_real_indices, self.uphill_chainIDs
+        del seen, chainID_translate_map_global
+        yt_counters("connect_chains_across_tasks")
+
+    def _communicate_annulus_chainIDs(self):
+        """
+        Transmit all of our chainID-ed particles that are within one padding
+        distance (self.max_padding, see _is_inside) of the boundaries of our
+        real region to all of our neighbors. Tests show that this is
+        faster than trying to figure out which of the neighbors to send the data
+        to.
+
+        Received chainIDs overwrite self.chainID only for particles we hold in
+        our padding (looked up through self.rev_index); data for particles we
+        don't have, or that are in our real region, is ignored. Deletes
+        self.is_inside_annulus and self.rev_index when done.
+        """
+        yt_counters("communicate_annulus_chainIDs")
+        # Pick the particles in the annulus.
+        real_indices = na.concatenate(
+            (self.index, self.index_pad))[self.is_inside_annulus]
+        chainIDs = self.chainID[self.is_inside_annulus]
+        # We're done with this here.
+        del self.is_inside_annulus
+        # Eliminate un-assigned particles.
+        select = (chainIDs != -1)
+        real_indices = real_indices[select]
+        chainIDs = chainIDs[select]
+        send_count = real_indices.size
+        # Here distribute the counts globally. Unfortunately, it's a barrier(),
+        # but there's so many places in this that need to be globally synched
+        # that it's not worth the effort right now to make this one spot better.
+        global_annulus_count = {self.mine:send_count}
+        global_annulus_count = self._mpi_joindict(global_annulus_count)
+        # Set up the receiving arrays, sized from each neighbor's send count.
+        recv_real_indices = dict.fromkeys(self.neighbors)
+        recv_chainIDs = dict.fromkeys(self.neighbors)
+        for opp_neighbor in self.neighbors:
+            opp_size = global_annulus_count[opp_neighbor]
+            recv_real_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+            recv_chainIDs[opp_neighbor] = na.empty(opp_size, dtype='int64')
+        # Set up the receiving hooks.
+        # NOTE(review): two messages arrive from each neighbor; this relies on
+        # MPI matching them in posting order (indices, then chainIDs), which
+        # holds only if both use the same tag and communicator -- presumably
+        # true for _mpi_Irecv_long/_mpi_Isend_long; confirm. Keep this order.
+        hooks = []
+        for opp_neighbor in self.neighbors:
+            hooks.append(self._mpi_Irecv_long(recv_real_indices[opp_neighbor], opp_neighbor))
+            hooks.append(self._mpi_Irecv_long(recv_chainIDs[opp_neighbor], opp_neighbor))
+        # Make sure the recv buffers are set before continuing.
+        self._barrier()
+        # Now we send them.
+        for neighbor in self.neighbors:
+            hooks.append(self._mpi_Isend_long(real_indices, neighbor))
+            hooks.append(self._mpi_Isend_long(chainIDs, neighbor))
+        # Now we use them when they're nice and ripe.
+        self._mpi_Request_Waitall(hooks)
+        self.__max_memory()
+        for opp_neighbor in self.neighbors:
+            opp_size = global_annulus_count[opp_neighbor]
+            # Update our local data.
+            for i,real_index in enumerate(recv_real_indices[opp_neighbor][0:opp_size]):
+                try:
+                    localID = self.rev_index[real_index]
+                    # We are only updating our particles that are in our
+                    # padding, so to be rigorous we will skip particles
+                    # that are in our real region.
+                    if self.is_inside[localID]:
+                        continue
+                    self.chainID[localID] = recv_chainIDs[opp_neighbor][i]
+                except KeyError:
+                    # We ignore data that's not for us.
+                    continue
+        # Clean up.
+        del recv_real_indices, recv_chainIDs, real_indices, chainIDs, select
+        del hooks
+        # We're done with this here.
+        del self.rev_index
+        yt_counters("communicate_annulus_chainIDs")
+
+
+    def _connect_chains(self):
+        """
+        With the set of particle chains, build a mapping of connected chainIDs
+        by finding the highest boundary density neighbor for each chain. Some
+        chains will have no neighbors!
+        """
+        yt_counters("connect_chains")
+        self.chain_densest_n = {} # chainID -> {chainIDs->boundary dens}
+        # Plus 2 because we're looking for that neighbor, but only keeping 
+        # nMerge + 1 neighbor tags, skipping ourselves.
+        fKD.dist = na.empty(self.nMerge+2, dtype='float64')
+        fKD.tags = na.empty(self.nMerge+2, dtype='int64')
+        # We can change this here to make the searches faster.
+        fKD.nn = self.nMerge+2
+        for i in xrange(int(self.size)):
+            # Don't consider this particle if it's not part of a chain.
+            if self.chainID[i] < 0: continue
+            # If this particle is in the padding, don't make a connection.
+            if not self.is_inside[i]: continue
+            # Make sure that we should search this particle again.
+            if not self.search_again[i]: continue
+            # Find this particle's chain max_dens.
+            part_max_dens = self.densest_in_chain[self.chainID[i]]
+            # Make sure we're skipping deleted chains.
+            if part_max_dens == -1.0: continue
+            # Loop over nMerge closest nearest neighbors.
+            fKD.qv = fKD.pos[:, i]
+            find_nn_nearest_neighbors()
+            NNtags = fKD.tags[:] - 1
+            for j in xrange(int(self.nMerge+1)):
+                thisNN = NNtags[j+1] # Don't consider ourselves at NNtags[0]
+                thisNN_chainID = self.chainID[thisNN]
+                # If our neighbor is in the same chain, move on.
+                if self.chainID[i] == thisNN_chainID: continue
+                # Everything immediately below is for
+                # neighboring particles with a chainID. 
+                if thisNN_chainID >= 0:
+                    # Find thisNN's chain's max_dens.
+                    thisNN_max_dens = self.densest_in_chain[thisNN_chainID]
+                    if thisNN_max_dens == -1.0: continue
+                    # Calculate the two groups boundary density.
+                    boundary_density = (self.density[thisNN] + self.density[i]) / 2.
+                    # Find out who's denser.
+                    if thisNN_max_dens >= part_max_dens:
+                        higher_chain = thisNN_chainID
+                        lower_chain = self.chainID[i]
+                    else:
+                        higher_chain = self.chainID[i]
+                        lower_chain = thisNN_chainID
+                    # Make sure that the higher density chain has an entry.
+                    try:
+                        test = self.chain_densest_n[int(higher_chain)]
+                    except KeyError:
+                        self.chain_densest_n[int(higher_chain)] = {}
+                    # See if this boundary density is higher than
+                    # previously recorded for this pair of chains.
+                    # Links only go one direction.
+                    try:
+                        old = self.chain_densest_n[int(higher_chain)][int(lower_chain)]
+                        if old < boundary_density:
+                            # make this the new densest boundary between this pair
+                            self.chain_densest_n[int(higher_chain)][int(lower_chain)] = \
+                                boundary_density
+                    except KeyError:
+                        # we haven't seen this pairing before, record this as the
+                        # new densest boundary between chains
+                        self.chain_densest_n[int(higher_chain)][int(lower_chain)] = \
+                            boundary_density
+                else:
+                    continue
+        self.__max_memory()
+        yt_counters("connect_chains")
+
+    def _make_global_chain_densest_n(self):
+        """
+        We want to record the maximum boundary density between all chains on
+        all tasks.
+        """
+        yt_counters("make_global_chain_densest_n")
+        (self.top_keys, self.bot_keys, self.vals) = \
+            self._mpi_maxdict_dict(self.chain_densest_n)
+        self.__max_memory()
+        del self.chain_densest_n
+        yt_counters("make_global_chain_densest_n")
+    
+    def _build_groups(self):
+        """
+        With the collection of possible chain links, build groups.
+
+        Reads the globally merged chain links in self.top_keys,
+        self.bot_keys and self.vals (higher chainID, lower chainID, boundary
+        density) together with self.densest_in_chain, self.peakthresh and
+        self.saddlethresh. Fills self.reverse_map, an array indexed by
+        chainID holding the final consecutive groupID of each chain, or -1
+        for chains not assigned to any group.
+
+        Grouping rules:
+        - chains with peak density >= peakthresh seed groups;
+        - two seed groups merge when their boundary density is
+          >= saddlethresh;
+        - sub-peak chains attach to a group through their densest boundary.
+
+        self.top_keys, self.bot_keys and self.vals are deleted on the way
+        out. Returns the number of distinct groups.
+        """
+        yt_counters("build_groups")
+        # We need to find out which pairs of self.top_keys, self.bot_keys are
+        # both < self.peakthresh, and create arrays that will store this
+        # relationship. (bitwise_and on two boolean arrays is an element-wise
+        # logical AND.)
+        both = na.bitwise_and((self.densest_in_chain[self.top_keys] < self.peakthresh),
+            (self.densest_in_chain[self.bot_keys] < self.peakthresh))
+        g_high = self.top_keys[both]
+        g_low = self.bot_keys[both]
+        g_dens = self.vals[both]
+        del both
+        # reverse_map: chainID -> groupID, -1 meaning "not assigned".
+        # NOTE: na.ones yields float64, so groupIDs are stored as floats here.
+        self.reverse_map = na.ones(self.densest_in_chain.size) * -1
+        # densestbound: the densest boundary recorded so far for each chain
+        # when it is the lower chain of a link; -1.0 if none yet.
+        densestbound = na.ones(self.densest_in_chain.size) * -1.0
+        for i, gl in enumerate(g_low):
+            if g_dens[i] > densestbound[gl]:
+                densestbound[gl] = g_dens[i]
+        groupID = 0
+        # First assign a group to all chains with max_dens above peakthresh.
+        # The initial groupIDs will be assigned with descending peak density.
+        # This guarantees that the group with the smaller groupID is the
+        # higher chain, as in chain_high below.
+        # NOTE(review): groupIDs are handed out in chainID order, so the
+        # "descending peak density" claim assumes chainIDs are already ordered
+        # that way -- confirm against the chainID assignment code.
+        for chainID,density in enumerate(self.densest_in_chain):
+            if density == -1.0: continue
+            if self.densest_in_chain[chainID] >= self.peakthresh:
+                self.reverse_map[chainID] = groupID
+                groupID += 1
+        # For each peak groupID, the set of other peak groupIDs it must be
+        # merged with (an undirected adjacency list).
+        group_equivalancy_map = na.empty(groupID, dtype='object')
+        for i in xrange(groupID):
+            group_equivalancy_map[i] = set([])
+        # Loop over all of the chain linkages.
+        for i,chain_high in enumerate(self.top_keys):
+            chain_low = self.bot_keys[i]
+            dens = self.vals[i]
+            max_dens_high = self.densest_in_chain[chain_high]
+            max_dens_low = self.densest_in_chain[chain_low]
+            # Skip links that involve a deleted chain.
+            if max_dens_high == -1.0 or max_dens_low == -1.0: continue
+            # If neither are peak density groups, mark them for later
+            # consideration.
+            if max_dens_high < self.peakthresh and \
+                max_dens_low < self.peakthresh:
+                    # This step is now done vectorized above, with the g_dens
+                    # stuff.
+                    continue
+            # If both are peak density groups, and have a boundary density
+            # that is high enough, make them into a group, otherwise
+            # move onto another linkage.
+            if max_dens_high >= self.peakthresh and \
+                    max_dens_low >= self.peakthresh:
+                if dens < self.saddlethresh:
+                    continue
+                else:
+                    group_high = self.reverse_map[chain_high]
+                    group_low = self.reverse_map[chain_low]
+                    if group_high == -1 or group_low == -1: continue
+                    # Both are already identified as groups. Record that
+                    # they are equivalent (in both directions); each
+                    # connected set is collapsed onto its smallest groupID
+                    # further down.
+                    if group_low != group_high:
+                        group_equivalancy_map[group_low].add(group_high)
+                        group_equivalancy_map[group_high].add(group_low)
+                    continue
+            # Else, one is above peakthresh, the other below
+            # find out if this is the densest boundary seen so far for
+            # the lower chain.
+            group_high = self.reverse_map[chain_high]
+            if group_high == -1: continue
+            if dens >= densestbound[chain_low]:
+                densestbound[chain_low] = dens
+                self.reverse_map[chain_low] = group_high
+        self.__max_memory()
+        del self.top_keys, self.bot_keys, self.vals
+        # Now refactor group_equivalancy_map back into reverse_map. The group
+        # mapping may be more than one link long, so it must be followed
+        # transitively. group_equivalancy_map is an undirected graph over peak
+        # groupIDs. This task starts an iterative breadth-first expansion
+        # from each groupID it is responsible for and collects everything
+        # reachable into one set.
+        # Visited IDs are kept in mine_groupIDs / not_mine_groupIDs, so no
+        # ID is expanded twice on this task.
+        # (The author's analogy: rabbit holes joined by two-way tunnels that
+        # close behind you, so the search spreads out like the water in
+        # 'Caddy Shack'.)
+        Set_list = []
+        # This task only starts searches from groupIDs congruent to
+        # self.mine modulo the number of tasks.
+        keys = na.arange(groupID, dtype='int64')
+        # A None size is treated as a single task (presumably a serial run).
+        if self._mpi_get_size() == None:
+            size = 1
+        else:
+            size = self._mpi_get_size()
+        select = (keys % size == self.mine)
+        groupIDs = keys[select]
+        mine_groupIDs = set([]) # Records only ones modulo mine.
+        not_mine_groupIDs = set([]) # All the others.
+        # Declare these to prevent Errors when they're del-ed below, in case
+        # this task doesn't create them in the loop, for whatever reason.
+        current_sets, new_mine, new_other = [], [], []
+        new_set, final_set, to_add_set, liter = set([]), set([]), set([]), set([])
+        to_add_set = set([])
+        # NOTE: groupID is re-bound here; it previously held the number of
+        # peak groups.
+        for groupID in groupIDs:
+            if groupID in mine_groupIDs:
+                continue
+            mine_groupIDs.add(groupID)
+            current_sets = []
+            new_set = group_equivalancy_map[groupID]
+            final_set = new_set.copy()
+            # Expand one BFS frontier per pass until nothing new is reached.
+            while len(new_set) > 0:
+                to_add_set = set([])
+                liter = new_set.difference(mine_groupIDs).difference(not_mine_groupIDs)
+                new_mine, new_other = [], []
+                for link_gID in liter:
+                    to_add_set.update(group_equivalancy_map[link_gID])
+                    if link_gID % size == self.mine:
+                        new_mine.append(link_gID)
+                    else:
+                        new_other.append(link_gID)
+                mine_groupIDs.update(new_mine)
+                not_mine_groupIDs.update(new_other)
+                final_set.update(to_add_set)
+                new_set = to_add_set
+            # Make sure it's not empty
+            final_set.add(groupID)
+            Set_list.append(final_set)
+        self.__max_memory()
+        del group_equivalancy_map, final_set, keys, select, groupIDs, current_sets
+        del mine_groupIDs, not_mine_groupIDs, new_set, to_add_set, liter
+        # Convert this list of sets into a look-up table:
+        # lookup[g] = smallest groupID in g's connected set. Untouched
+        # entries hold the sentinel densest_in_chain.size + 2, which is larger
+        # than any groupID.
+        lookup = na.ones(self.densest_in_chain.size, dtype='int64') * (self.densest_in_chain.size + 2)
+        for i,item in enumerate(Set_list):
+            item_min = min(item)
+            for groupID in item:
+                lookup[groupID] = item_min
+        self.__max_memory()
+        del Set_list
+        # To bring it all together, find the minimum values at each entry
+        # globally (presumably an element-wise minimum across all tasks).
+        lookup = self._mpi_minimum_array_long(lookup)
+        # Now apply this to reverse_map, leaving sentinel entries alone.
+        for chainID,groupID in enumerate(self.reverse_map):
+            if groupID == -1:
+                continue
+            if lookup[groupID] != (self.densest_in_chain.size + 2):
+                self.reverse_map[chainID] = lookup[groupID]
+        del lookup
+        # (The string below is a no-op expression used as a block comment.)
+        """
+        Now the fringe chains are connected to the proper group
+        (>peakthresh) with the largest boundary.  But we want to look
+        through the boundaries between fringe groups to propagate this
+        along.  Connections are only as good as their smallest boundary
+        """
+        # Repeat passes over the sub-peak links until a full pass makes no
+        # reassignment (a fixed point).
+        changes = 1
+        while changes:
+            changes = 0
+            for j,dens in enumerate(g_dens):
+                chain_high = g_high[j]
+                chain_low = g_low[j]
+                # If the density of this boundary and the densestbound of
+                # the other chain is higher than a chain's densestbound, then
+                # replace it. We also don't want to link to un-assigned
+                # neighbors, and we can skip neighbors we're already assigned to.
+                if dens >= densestbound[chain_low] and \
+                        densestbound[chain_high] > densestbound[chain_low] and \
+                        self.reverse_map[chain_high] != -1 and \
+                        self.reverse_map[chain_low] != self.reverse_map[chain_high]:
+                    changes += 1
+                    # The new bound is min(dens, densestbound[chain_high]):
+                    # a path is only as strong as its weakest boundary.
+                    if dens < densestbound[chain_high]:
+                        densestbound[chain_low] = dens
+                    else:
+                        densestbound[chain_low] = densestbound[chain_high]
+                    self.reverse_map[chain_low] = self.reverse_map[chain_high]
+        self.__max_memory()
+        del g_high, g_low, g_dens, densestbound
+        # Now we have to find the unique groupIDs, since they may have been
+        # merged.
+        temp = list(set(self.reverse_map))
+        # Remove -1 from the list.
+        try:
+            temp.pop(temp.index(-1))
+        except ValueError:
+            # -1 is absent: every chain was assigned to some group.
+            pass
+        # Make a secondary map to make the IDs consecutive (0..len(temp)-1,
+        # in the order list(set(...)) produced, which is not guaranteed to be
+        # sorted).
+        values = na.arange(len(temp))
+        secondary_map = dict(itertools.izip(temp, values))
+        del values
+        # Update reverse_map
+        for chain, map in enumerate(self.reverse_map):
+            # Don't attempt to fix non-assigned chains.
+            if map == -1: continue
+            self.reverse_map[chain] = secondary_map[map]
+        group_count = len(temp)
+        del secondary_map, temp
+        yt_counters("build_groups")
+        self.__max_memory()
+        return group_count
+
+    def _translate_groupIDs(self, group_count):
+        """
+        Using the maps, convert the particle chainIDs into their locally-final
+        groupIDs.
+        """
+        yt_counters("translate_groupIDs")
+        self.I_own = set([])
+        for i in xrange(int(self.size)):
+            # Don't translate non-affiliated particles.
+            if self.chainID[i] == -1: continue
+            # We want to remove the group tag from padded particles,
+            # so when we return it to HaloFinding, there is no duplication.
+            if self.is_inside[i]:
+                self.chainID[i] = self.reverse_map[self.chainID[i]]
+                self.I_own.add(self.chainID[i])
+            else:
+                self.chainID[i] = -1
+        del self.is_inside
+        # Create a densest_in_group, analogous to densest_in_chain.
+        keys = na.arange(group_count)
+        vals = na.zeros(group_count)
+        self.densest_in_group = dict(itertools.izip(keys,vals))
+        self.densest_in_group_real_index = self.densest_in_group.copy()
+        del keys, vals
+        for chainID,max_dens in enumerate(self.densest_in_chain):
+            if max_dens == -1.0: continue
+            groupID = self.reverse_map[chainID]
+            if groupID == -1: continue
+            if self.densest_in_group[groupID] < max_dens:
+                self.densest_in_group[groupID] = max_dens
+                self.densest_in_group_real_index[groupID] = self.densest_in_chain_real_index[chainID]
+        del self.densest_in_chain, self.densest_in_chain_real_index
+        yt_counters("translate_groupIDs")
+
+    def _precompute_group_info(self):
+        yt_counters("Precomp.")
+        """
+        For all groups, compute the various global properties, except bulk
+        velocity, to save time in HaloFinding.py (fewer barriers!).
+        """
+        select = (self.chainID != -1)
+        calc = len(na.where(select == True)[0])
+        loc = na.empty((calc, 3), dtype='float64')
+        loc[:, 0] = na.concatenate((self.xpos, self.xpos_pad))[select]
+        loc[:, 1] = na.concatenate((self.ypos, self.ypos_pad))[select]
+        loc[:, 2] = na.concatenate((self.zpos, self.zpos_pad))[select]
+        self.__max_memory()
+        del self.xpos_pad, self.ypos_pad, self.zpos_pad
+        subchain = self.chainID[select]
+        # First we need to find the maximum density point for all groups.
+        # I think this will be faster than several vector operations that need
+        # to pull the entire chainID array out of memory several times.
+        yt_counters("max dens point")
+        max_dens_point = na.zeros((self.group_count,4),dtype='float64')
+        for i,part in enumerate(na.arange(self.size)[select]):
+            groupID = self.chainID[part]
+            if part < self.real_size:
+                real_index = self.index[part]
+            else:
+                real_index = self.index_pad[part - self.real_size]
+            if real_index == self.densest_in_group_real_index[groupID]:
+                max_dens_point[groupID] = na.array([self.density[part], \
+                loc[i, 0], loc[i, 1], loc[i, 2]])
+        del self.index, self.index_pad, self.densest_in_group_real_index
+        # Now we broadcast this, effectively, with an allsum. Even though
+        # some groups are on multiple tasks, there is only one densest_in_chain
+        # and only that task contributed above.
+        self.max_dens_point = self._mpi_Allsum_double(max_dens_point)
+        yt_counters("max dens point")
+        # Now CoM.
+        yt_counters("CoM")
+        CoM_M = na.zeros((self.group_count,3),dtype='float64')
+        Tot_M = na.zeros(self.group_count, dtype='float64')
+        #c_vec = self.max_dens_point[:,1:4][subchain] - na.array([0.5,0.5,0.5])
+        if calc:
+            c_vec = self.max_dens_point[:,1:4][subchain] - na.array([0.5,0.5,0.5])
+            size = na.bincount(self.chainID[select]).astype('int64')
+        else:
+            # This task has no particles in groups!
+            size = na.zeros(self.group_count, dtype='int64')
+        # In case this task doesn't have all the groups, add trailing zeros.
+        if size.size != self.group_count:
+            size = na.concatenate((size, na.zeros(self.group_count - size.size, dtype='int64')))
+        if calc:
+            cc = loc - c_vec
+            cc = cc - na.floor(cc)
+            ms = na.concatenate((self.mass, self.mass_pad))[select]
+            # Most of the time, the masses will be all the same, and we can try
+            # to save some effort.
+            ms_u = na.unique(ms)
+            if ms_u.size == 1:
+                single = True
+                Tot_M = size.astype('float64') * ms_u
+            else:
+                single = False
+                del ms_u
+            cc[:,0] = cc[:,0] * ms
+            cc[:,1] = cc[:,1] * ms
+            cc[:,2] = cc[:,2] * ms
+            sort = subchain.argsort()
+            cc = cc[sort]
+            sort_subchain = subchain[sort]
+            uniq_subchain = na.unique(sort_subchain)
+            diff_subchain = na.ediff1d(sort_subchain)
+            marks = (diff_subchain > 0)
+            marks = na.arange(calc)[marks] + 1
+            marks = na.concatenate(([0], marks, [calc]))
+            for i, u in enumerate(uniq_subchain):
+                CoM_M[u] = na.sum(cc[marks[i]:marks[i+1]], axis=0)
+            if not single:
+                for i,groupID in enumerate(subchain):
+                    Tot_M[groupID] += ms[i]
+            del cc, ms
+            for groupID in xrange(int(self.group_count)):
+                # Don't divide by zero.
+                if groupID in self.I_own:
+                    CoM_M[groupID] /= Tot_M[groupID]
+                    CoM_M[groupID] += self.max_dens_point[groupID,1:4] - na.array([0.5,0.5,0.5])
+                    CoM_M[groupID] *= Tot_M[groupID]
+        # Now we find their global values
+        self.group_sizes = self._mpi_Allsum_long(size)
+        CoM_M = self._mpi_Allsum_double(CoM_M)
+        self.Tot_M = self._mpi_Allsum_double(Tot_M)
+        self.CoM = na.empty((self.group_count,3), dtype='float64')
+        for groupID in xrange(int(self.group_count)):
+            self.CoM[groupID] = CoM_M[groupID] / self.Tot_M[groupID]
+        yt_counters("CoM")
+        self.__max_memory()
+        # Now we find the maximum radius for all groups.
+        yt_counters("max radius")
+        max_radius = na.zeros(self.group_count, dtype='float64')
+        if calc:
+            com = self.CoM[subchain]
+            rad = na.abs(com - loc)
+            dist = (na.minimum(rad, self.period - rad)**2.).sum(axis=1)
+            dist = dist[sort]
+            for i, u in enumerate(uniq_subchain):
+                max_radius[u] = na.max(dist[marks[i]:marks[i+1]])
+        # Find the maximum across all tasks.
+        mylog.info('Fraction of particles in this region in groups: %f' % (float(calc)/self.size))
+        self.max_radius = self._mpi_double_array_max(max_radius)
+        self.max_radius = na.sqrt(self.max_radius)
+        yt_counters("max radius")
+        yt_counters("Precomp.")
+        self.__max_memory()
+        if calc:
+            del loc, subchain, CoM_M, Tot_M, c_vec, max_radius, select
+            del sort_subchain, uniq_subchain, diff_subchain, marks, dist, sort
+            del rad, com
+
+    def _chain_hop(self):
+        self._global_padding('first')
+        self._global_bounds_neighbors()
+        self._global_padding('second')
+        self._is_inside('first')
+        mylog.info('Distributing padded particles...')
+        self._communicate_padding_data()
+        mylog.info('Building kd tree for %d particles...' % \
+            self.size)
+        self._init_kd_tree()
+        # Mark particles in as being in/out of the domain.
+        self._is_inside('second')
+        # Loop over the particles to find NN for each.
+        mylog.info('Finding nearest neighbors/density...')
+        yt_counters("chainHOP_tags_dens")
+        chainHOP_tags_dens()
+        yt_counters("chainHOP_tags_dens")
+        self.density = fKD.dens
+        # Now each particle has NNtags, and a local self density.
+        # Let's find densest NN
+        mylog.info('Finding densest nearest neighbors...')
+        self._densestNN()
+        # Build the chain of links.
+        mylog.info('Building particle chains...')
+        chain_count = self._build_chains()
+        # This array tracks whether or not relationships for this particle
+        # need to be examined twice, in preconnect_chains and in connect_chains
+        self.search_again = na.ones(self.size, dtype='bool')
+        chain_count = self._preconnect_chains(chain_count)
+        mylog.info('Gobally assigning chainIDs...')
+        self._globally_assign_chainIDs(chain_count)
+        mylog.info('Globally finding densest in chains...')
+        self._create_global_densest_in_chain()
+        mylog.info('Building chain connections across tasks...')
+        self._is_inside('third')
+        self._connect_chains_across_tasks()
+        mylog.info('Communicating connected chains...')
+        self._communicate_annulus_chainIDs()
+        mylog.info('Connecting %d chains into groups...' % self.nchains)
+        self._connect_chains()
+        del fKD.dens, fKD.mass, fKD.dens
+        del fKD.pos, fKD.chunk_tags
+        free_tree() # Frees the kdtree object.
+        del self.densestNN
+        mylog.info('Communicating group links globally...')
+        self._make_global_chain_densest_n()
+        mylog.info('Building final groups...')
+        group_count = self._build_groups()
+        self.group_count = group_count
+        mylog.info('Remapping particles to final groups...')
+        self._translate_groupIDs(group_count)
+        mylog.info('Precomputing info for %d groups...' % group_count)
+        self._precompute_group_info()
+        mylog.info("All done! Max Memory = %d MB" % self.max_mem)
+        # We need to fix chainID and density because HaloFinding is expecting
+        # an array only as long as the real data.
+        self.chainID = self.chainID[:self.real_size]
+        self.density = self.density[:self.real_size]
+        # We'll make this a global object, which can be used to write a text
+        # file giving the names of hdf5 files the particles for each halo.
+        self.mine, self.I_own = self._mpi_info_dict(self.I_own)
+        self.halo_taskmap = defaultdict(set)
+        for taskID in self.I_own:
+            for groupID in self.I_own[taskID]:
+                self.halo_taskmap[groupID].add(taskID)
+        del self.I_own
+        del self.mass, self.xpos, self.ypos, self.zpos
+
+    def __add_to_array(self, arr, key, value, type):
+        """
+        In an effort to replace the functionality of a dict with an array, in
+        order to save memory, this function adds items to an array. If the
+        array is not long enough, it is resized and filled with 'bad' values."""
+        
+        try:
+            arr[key] = value
+        except IndexError:
+            arr = na.concatenate((arr, na.ones(10000, dtype=type)*-1))
+            arr[key] = value
+        return arr
+    
+    def __clean_up_array(self, arr):
+        good = (arr != -1)
+        return arr[good]
+    
+    def __max_memory(self):
+        my_mem = get_memory_usage()
+        self.max_mem = max(my_mem, self.max_mem)
\ No newline at end of file

Added: trunk/yt/lagos/parallelHOP/run.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/run.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,24 @@
+# Example driver: run parallel HOP on the 'data0005' dataset and write out the
+# resulting halo catalog and particle lists, timing each phase.
+from yt.config import ytcfg
+
+# Enable yt_counters timing. Set before importing yt.mods, presumably so the
+# setting is already in place at import time.
+ytcfg["yt","time_functions"] = "True"
+
+from yt.mods import *
+
+# Each yt_counters(label) below is called twice, bracketing the work it
+# times (presumably start/stop of a named timer).
+yt_counters("Full Time")
+
+yt_counters("yt Hierarchy")
+pf = load('data0005')
+
+# Access pf.h inside the timer (presumably forces the hierarchy to load now).
+pf.h
+yt_counters("yt Hierarchy")
+
+# Run the parallel HOP halo finder with these settings.
+h = yt.lagos.HaloFinding.parallelHF(pf, threshold=160.0, safety=2.5, \
+dm_only=False,resize=True, fancy_padding=True, rearrange=True)
+
+# Write the halo catalog and the per-halo particle lists ("chain" basename).
+yt_counters("Writing Data")
+h.write_out('dist-chain.out')
+h.write_particle_lists_txt("chain")
+h.write_particle_lists("chain")
+yt_counters("Writing Data")
+
+yt_counters("Full Time")

Modified: trunk/yt/lagos/setup.py
==============================================================================
--- trunk/yt/lagos/setup.py	(original)
+++ trunk/yt/lagos/setup.py	Sun Dec 20 12:42:09 2009
@@ -22,6 +22,7 @@
     config.add_extension("PointCombine", "yt/lagos/PointCombine.c", libraries=["m"])
     config.add_subpackage("hop")
     config.add_subpackage("fof")
+    config.add_subpackage("parallelHOP")
     H5dir = check_for_hdf5()
     if H5dir is not None:
         include_dirs=[os.path.join(H5dir,"include")]

Added: trunk/yt/math_utils.py
==============================================================================
--- (empty file)
+++ trunk/yt/math_utils.py	Sun Dec 20 12:42:09 2009
@@ -0,0 +1,46 @@
+"""
+Commonly used mathematical functions.
+
+Author: Matthew Turk <matthewturk at gmail.com>
+Affiliation: UCSD Physics/CASS
+Author: Stephen Skory <stephenskory at yahoo.com>
+Affiliation: UCSD Physics/CASS
+Homepage: http://yt.enzotools.org/
+License:
+  Copyright (C) 2008-2009 Matthew Turk.  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import numpy as na
+import math
+
+def periodic_dist(a, b, period):
+    """
+    Find the Euclidian periodic distance between two points.
+    *a* and *b* are array-like vectors, and *period* is a float or
+    array-like value for the periodic size of the volume.
+    """
+    a = na.array(a)
+    b = na.array(b)
+    if a.size != b.size: RunTimeError("Arrays must be the same shape.")
+    c = na.empty((2, a.size), dtype="float64")
+    c[0,:] = abs(a - b)
+    c[1,:] = period - abs(a - b)
+    d = na.amin(c, axis=0)**2
+    return math.sqrt(d.sum())
+
+    
\ No newline at end of file



More information about the yt-svn mailing list