[Yt-svn] yt-commit r1560 - in trunk/yt: . extensions extensions/kdtree lagos lagos/parallelHOP
sskory at wrangler.dreamhost.com
Sun Dec 20 12:42:19 PST 2009
Author: sskory
Date: Sun Dec 20 12:42:09 2009
New Revision: 1560
URL: http://yt.enzotools.org/changeset/1560
Log:
Adding my patches, as provided by mturk, to the svn trunk. In particular: Parallel HOP, plus the star formation and halo mass function analysis routines.
Added:
trunk/yt/extensions/HaloMassFcn.py
trunk/yt/extensions/StarAnalysis.py
trunk/yt/extensions/kdtree/
trunk/yt/extensions/kdtree/Makefile
trunk/yt/extensions/kdtree/__init__.py
trunk/yt/extensions/kdtree/fKD.f90
trunk/yt/extensions/kdtree/fKD.v
trunk/yt/extensions/kdtree/fKD_source.f90
trunk/yt/extensions/kdtree/test.py
trunk/yt/lagos/parallelHOP/
trunk/yt/lagos/parallelHOP/__init__.py
trunk/yt/lagos/parallelHOP/parallelHOP.py
trunk/yt/lagos/parallelHOP/run.py
trunk/yt/math_utils.py
Modified:
trunk/yt/lagos/BaseDataTypes.py
trunk/yt/lagos/EnzoFields.py
trunk/yt/lagos/HaloFinding.py
trunk/yt/lagos/ParallelTools.py
trunk/yt/lagos/setup.py
Added: trunk/yt/extensions/HaloMassFcn.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/HaloMassFcn.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,719 @@
+"""
+HaloMassFcn - Halo Mass Function and supporting functions.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UC San Diego / CASS
+Homepage: http://yt.enzotools.org/
+License:
+ Copyright (C) 2008-2009 Stephen Skory (and others). All Rights Reserved.
+
+ This file is part of yt.
+
+ yt is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import yt.lagos as lagos
+from yt.logger import lagosLogger as mylog
+import numpy as na
+import math, time
+
+class HaloMassFcn(object):
+ def __init__(self, pf, halo_file=None, omega_matter0=None, omega_lambda0=None,
+ omega_baryon0=0.05, hubble0=None, sigma8input=0.86, primordial_index=1.0,
+ this_redshift=None, log_mass_min=None, log_mass_max=None, num_sigma_bins=360,
+ fitting_function=4):
+ """
+        Initialize a HaloMassFcn object to analyze the distribution of haloes
+ as a function of mass.
+ :param halo_file (str): The filename of the output of the Halo Profiler.
+ Default=None.
+ :param omega_matter0 (float): The fraction of the universe made up of
+ matter (dark and baryonic). Default=None.
+ :param omega_lambda0 (float): The fraction of the universe made up of
+ dark energy. Default=None.
+ :param omega_baryon0 (float): The fraction of the universe made up of
+ ordinary baryonic matter. This should match the value
+ used to create the initial conditions, using 'inits'. This is
+        *not* stored in the enzo dataset so it must be checked by hand.
+ Default=0.05.
+ :param hubble0 (float): The expansion rate of the universe in units of
+ 100 km/s/Mpc. Default=None.
+ :param sigma8input (float): The amplitude of the linear power
+ spectrum at z=0 as specified by the rms amplitude of mass-fluctuations
+ in a top-hat sphere of radius 8 Mpc/h. This should match the value
+ used to create the initial conditions, using 'inits'. This is
+        *not* stored in the enzo dataset so it must be checked by hand.
+ Default=0.86.
+        :param primordial_index (float): This is the index of the mass power
+ spectrum before modification by the transfer function. A value of 1
+ corresponds to the scale-free primordial spectrum. This should match
+ the value used to make the initial conditions using 'inits'. This is
+        *not* stored in the enzo dataset so it must be checked by hand.
+ Default=1.0.
+ :param this_redshift (float): The current redshift. Default=None.
+ :param log_mass_min (float): The log10 of the mass of the minimum of the
+ halo mass range. Default=None.
+ :param log_mass_max (float): The log10 of the mass of the maximum of the
+ halo mass range. Default=None.
+        :param num_sigma_bins (int): The number of bins (points) to use for
+ the calculations and generated fit. Default=360.
+ :param fitting_function (int): Which fitting function to use.
+        1 = Press-Schechter, 2 = Jenkins, 3 = Sheth-Tormen, 4 = Warren fit.
+ Default=4.
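+
+        Example (a sketch; the dataset name and parameter values below are
+        hypothetical, not taken from this changeset):
+            >>> pf = lagos.EnzoStaticOutput("DD0010/data0010")
+            >>> hmf = HaloMassFcn(pf, omega_matter0=0.27, omega_lambda0=0.73,
+            ...     hubble0=0.7, this_redshift=0.0, log_mass_min=9.0,
+            ...     log_mass_max=15.0, fitting_function=4)
+            >>> hmf.write_out(prefix='HMF')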
+ """
+ self.pf = pf
+ self.halo_file = halo_file
+ self.omega_matter0 = omega_matter0
+ self.omega_lambda0 = omega_lambda0
+ self.omega_baryon0 = omega_baryon0
+ self.hubble0 = hubble0
+ self.sigma8input = sigma8input
+ self.primordial_index = primordial_index
+ self.this_redshift = this_redshift
+ self.log_mass_min = log_mass_min
+ self.log_mass_max = log_mass_max
+ self.num_sigma_bins = num_sigma_bins
+ self.fitting_function = fitting_function
+
+ # Determine the run mode.
+ if halo_file is None:
+ # We are hand-picking our various cosmological parameters
+ self.mode = 'single'
+ else:
+ # Make the fit using the same cosmological parameters as the dataset.
+ self.mode = 'haloes'
+ self.omega_matter0 = self.pf['CosmologyOmegaMatterNow']
+ self.omega_lambda0 = self.pf['CosmologyOmegaLambdaNow']
+ self.hubble0 = self.pf['CosmologyHubbleConstantNow']
+ self.this_redshift = self.pf['CosmologyCurrentRedshift']
+ self.read_haloes()
+ self.log_mass_min = math.log10(min(self.haloes))
+ self.log_mass_max = math.log10(max(self.haloes))
+
+ # Input error check.
+ if self.mode == 'single':
+            if omega_matter0 is None or omega_lambda0 is None or \
+            hubble0 is None or this_redshift is None or log_mass_min is None or\
+            log_mass_max is None:
+ mylog.error("All of these parameters need to be set:")
+ mylog.error("[omega_matter0, omega_lambda0, \
+ hubble0, this_redshift, log_mass_min, log_mass_max]")
+ mylog.error("[%s,%s,%s,%s,%s,%s]" % (omega_matter0,\
+ omega_lambda0, hubble0, this_redshift,\
+ log_mass_min, log_mass_max))
+ return None
+
+ # Poke the user to make sure they're doing it right.
+ mylog.info(
+ """
+ Please make sure these are the correct values! They are
+ not stored in enzo datasets, so must be entered by hand.
+ sigma8input=%f primordial_index=%f omega_baryon0=%f
+ """ % (self.sigma8input, self.primordial_index, self.omega_baryon0))
+ time.sleep(1)
+
+ # Do the calculations.
+ self.sigmaM()
+ self.dndm()
+
+ if self.mode == 'haloes':
+ self.bin_haloes()
+
+ def write_out(self, prefix='HMF', fit=True, haloes=True):
+ """
+ Writes out the halo mass functions to file(s) with prefix *prefix*.
+ """
+ # First the fit file.
+ if fit:
+ fitname = prefix + '-fit.dat'
+ fp = open(fitname, 'w')
+ line = \
+ """#Columns:
+#1. log10 of mass (Msolar, NOT Msolar/h)
+#2. mass (Msolar/h)
+#3. (dn/dM)*dM (differential number density of halos, per Mpc^3 (NOT h^3/Mpc^3))
+#4. cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+"""
+ fp.write(line)
+ for i in xrange(self.logmassarray.size - 1):
+ line = "%e\t%e\t%e\t%e\n" % (self.logmassarray[i], self.massarray[i],
+ self.dn_M_z[i], self.nofmz_cum[i])
+ fp.write(line)
+ fp.close()
+ if self.mode == 'haloes' and haloes:
+ haloname = prefix + '-haloes.dat'
+ fp = open(haloname, 'w')
+ line = \
+ """#Columns:
+#1. log10 of mass (Msolar, NOT Msolar/h)
+#2. mass (Msolar/h)
+#3. cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+"""
+            fp.write(line)
+            for i in xrange(self.logmassarray.size - 1):
+ line = "%e\t%e\t%e\n" % (self.logmassarray[i], self.massarray[i],
+ self.dis[i])
+ fp.write(line)
+ fp.close()
+
+ def read_haloes(self):
+ """
+ Read in the virial masses of the haloes.
+ """
+ mylog.info("Reading halo masses from %s" % self.halo_file)
+ f = open(self.halo_file,'r')
+ line = f.readline() # burn the top header line.
+ line = f.readline()
+ self.haloes = []
+ while line:
+ line = line.split()
+            # Mass is in the 6th column (index 5).
+            mass = float(line[5])
+            if mass > 0:
+                self.haloes.append(mass)
+ line = f.readline()
+ f.close()
+ self.haloes = na.array(self.haloes)
+
+ def bin_haloes(self):
+ """
+ With the list of virial masses, find the halo mass function.
+ """
+        bins = na.logspace(self.log_mass_min,
+            self.log_mass_max, self.num_sigma_bins)
+ avgs = (bins[1:]+bins[:-1])/2.
+
+ dis, bins = na.histogram(self.haloes,bins,new=True)
+
+ # add right to left
+ for i,b in enumerate(dis):
+ dis[self.num_sigma_bins-i-3] += dis[self.num_sigma_bins-i-2]
+ if i == (self.num_sigma_bins - 3): break
+
+ self.dis = dis / self.pf['CosmologyComovingBoxSize']**3.0
+
+ def sigmaM(self):
+ """
+ Written by BWO, 2006 (updated 25 January 2007).
+ Converted to Python by Stephen Skory December 2009.
+
+        This routine takes in cosmological parameters and creates an array with
+        sigma(M) in it, which is necessary for various Press-Schechter type
+        calculations. In principle one can calculate it on the fly, but it is
+        far, far faster in the long run to tabulate sigma(M) ahead of time.
+
+ Inputs: cosmology, user must set parameters
+
+ Outputs: four columns of data containing the following information:
+
+ 1) log mass (Msolar)
+ 2) mass (Msolar/h)
+ 3) Radius (comoving Mpc/h)
+ 4) sigma (normalized) using Msun/h as the input
+
+ The arrays output are used later.
+ """
+
+ # Set up the transfer function object.
+ self.TF = TransferFunction(self.omega_matter0, self.omega_baryon0, 0.0, 0,
+ self.omega_lambda0, self.hubble0, self.this_redshift);
+
+ if self.TF.qwarn:
+ mylog.error("You should probably fix your cosmology parameters!")
+
+        # output arrays
+        # 1) radius/spatial scale corresponding to each mass bin (comoving Mpc/h)
+        self.Rarray = na.empty(self.num_sigma_bins,dtype='float64')
+        # 2) log10 of mass (Msolar, NOT Msolar/h)
+        self.logmassarray = na.empty(self.num_sigma_bins, dtype='float64')
+        # 3) mass (Msolar/h)
+        self.massarray = na.empty(self.num_sigma_bins, dtype='float64')
+        # 4) sigma(M, z=0, where mass is in Msun/h)
+        self.sigmaarray = na.empty(self.num_sigma_bins, dtype='float64')
+
+ # get sigma_8 normalization
+ R = 8.0; # in units of Mpc/h (comoving)
+
+ sigma8_unnorm = math.sqrt(self.sigma_squared_of_R(R));
+ sigma_normalization = self.sigma8input / sigma8_unnorm;
+
+ rho0 = self.omega_matter0 * 2.78e+11; # in units of h^2 Msolar/Mpc^3
+
+ # spacing in mass of our sigma calculation
+ dm = (float(self.log_mass_max) - self.log_mass_min)/self.num_sigma_bins;
+
+ """
+ loop over the total number of sigma_bins the user has requested.
+ For each bin, calculate mass and equivalent radius, and call
+ sigma_squared_of_R to get the sigma(R) (equivalent to sigma(M)),
+ normalize by user-specified sigma_8, and then write out.
+ """
+ for i in xrange(self.num_sigma_bins):
+
+ # thislogmass is in units of Msolar, NOT Msolar/h
+ thislogmass = self.log_mass_min + i*dm
+
+ # mass in units of h^-1 Msolar
+ thismass = math.pow(10.0, thislogmass) * self.hubble0;
+
+ # radius is in units of h^-1 Mpc (comoving)
+ thisradius = math.pow( 3.0*thismass / 4.0 / math.pi / rho0, 1.0/3.0 );
+
+ R = thisradius; # h^-1 Mpc (comoving)
+
+ self.Rarray[i] = thisradius; # h^-1 Mpc (comoving)
+ self.logmassarray[i] = thislogmass; # Msun (NOT Msun/h)
+ self.massarray[i] = thismass; # Msun/h
+
+ # get normalized sigma(R)
+ self.sigmaarray[i] = math.sqrt(self.sigma_squared_of_R(R)) * sigma_normalization;
+ # All done!
+
+ def dndm(self):
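+        """
+        Compute (dn/dM)*dM and the cumulative number density of halos as a
+        function of mass, using the sigma(M) arrays built by sigmaM() and the
+        selected multiplicity (fitting) function.
+        """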
+
+ # constants - set these before calling any functions!
+ rho0 = self.omega_matter0 * 2.78e+11; # in units of h^2 Msolar/Mpc^3
+        self.delta_c0 = 1.69; # critical linear overdensity for collapse (Press-Schechter)
+
+ nofmz_cum = 0.0; # keep track of cumulative number density
+
+ # Loop over masses, going BACKWARD, and calculate dn/dm as well as the
+ # cumulative mass function.
+
+ # output arrays
+        # 5) (dn/dM)*dM (differential number density of halos, per Mpc^3 (NOT h^3/Mpc^3))
+ self.dn_M_z = na.empty(self.num_sigma_bins, dtype='float64')
+ # 6) cumulative number density of halos (per Mpc^3, NOT h^3/Mpc^3)
+ self.nofmz_cum = na.zeros(self.num_sigma_bins, dtype='float64')
+
+ for j in xrange(self.num_sigma_bins - 1):
+ i = (self.num_sigma_bins - 2) - j
+
+ thissigma = self.sigmaof_M_z(i, self.this_redshift);
+ nextsigma = self.sigmaof_M_z(i+1, self.this_redshift);
+
+ # calc dsigmadm - has units of h (since massarray has units of h^-1)
+ dsigmadm = (nextsigma-thissigma) / (self.massarray[i+1] - self.massarray[i]);
+
+ # calculate dn(M,z) (dn/dM * dM)
+ # this has units of h^3 since rho0 has units of h^2, dsigmadm
+ # has units of h, and massarray has units of h^-1
+ dn_M_z = -1.0 / thissigma * dsigmadm * rho0 / self.massarray[i] * \
+ self.multiplicityfunction(thissigma)*(self.massarray[i+1] - self.massarray[i]);
+
+ # scale by h^4 to get rid of all factors of h
+ dn_M_z *= math.pow(self.hubble0, 4.0);
+
+ # keep track of cumulative number density
+ if dn_M_z > 1.0e-20:
+ nofmz_cum += dn_M_z;
+
+ # Store this.
+ self.nofmz_cum[i] = nofmz_cum
+ self.dn_M_z[i] = dn_M_z
+
+
+ def sigma_squared_of_R(self, R):
+ """
+ /* calculates sigma^2(R). This is the routine where the magic happens (or
+ whatever it is that we do here). Integrates the sigma_squared_integrand
+ parameter from R to infinity. Calls GSL (gnu scientific library) to do
+ the actual integration.
+
+ Note that R is in h^-1 Mpc (comoving)
+ */
+ """
+ self.R = R
+ result = integrate_inf(self.sigma_squared_integrand)
+
+ sigmasquaredofR = result / 2.0 / math.pi / math.pi
+
+ return sigmasquaredofR;
+
+ def sigma_squared_integrand(self, k):
+ """
+ /* integrand for integral to get sigma^2(R). */
+ """
+
+ Rcom = self.R; # this is R in comoving Mpc/h
+
+ f = k*k*self.PofK(k)*na.power( abs(self.WofK(Rcom,k)), 2.0);
+
+ return f
+
+ def PofK(self, k):
+ """
+ /* returns power spectrum as a function of wavenumber k */
+ """
+
+ thisPofK = na.power(k, self.primordial_index) * na.power( self.TofK(k), 2.0);
+
+ return thisPofK;
+
+ def TofK(self, k):
+ """
+ /* returns transfer function as a function of wavenumber k. */
+ """
+
+ thisTofK = self.TF.TFmdm_onek_hmpc(k);
+
+ return thisTofK;
+
+ def WofK(self, R, k):
+ """
+ returns W(k), which is the fourier transform of the top-hat function.
+ """
+
+ x = R*k;
+
+ thisWofK = 3.0 * ( na.sin(x) - x*na.cos(x) ) / (x*x*x);
+
+ return thisWofK;
+
+ def multiplicityfunction(self, sigma):
+ """
+ /* Multiplicity function - this is where the various fitting functions/analytic
+ theories are different. The various places where I found these fitting functions
+ are listed below. */
+ """
+
+ nu = self.delta_c0 / sigma;
+
+ if self.fitting_function==1:
+ # Press-Schechter (This form from Jenkins et al. 2001, MNRAS 321, 372-384, eqtn. 5)
+            thismult = math.sqrt(2.0/math.pi) * nu * math.exp(-0.5*nu*nu);
+
+ elif self.fitting_function==2:
+ # Jenkins et al. 2001, MNRAS 321, 372-384, eqtn. 9
+ thismult = 0.315 * math.exp( -1.0 * math.pow( abs( math.log(1.0/sigma) + 0.61), 3.8 ) );
+
+ elif self.fitting_function==3:
+ # Sheth-Tormen 1999, eqtn 10, using expression from Jenkins et al. 2001, eqtn. 7
+ A=0.3222;
+ a=0.707;
+ p=0.3;
+ thismult = A*math.sqrt(2.0*a/math.pi)*(1.0+ math.pow( 1.0/(nu*nu*a), p) )*\
+ nu * math.exp(-0.5*a*nu*nu);
+
+ elif self.fitting_function==4:
+ # LANL fitting function - Warren et al. 2005, astro-ph/0506395, eqtn. 5
+ A=0.7234;
+ a=1.625;
+ b=0.2538;
+ c=1.1982;
+ thismult = A*( math.pow(sigma, -1.0*a) + b)*math.exp(-1.0*c / sigma / sigma );
+
+ else:
+ mylog.error("Don't understand this. Fitting function requested is %d\n",
+ self.fitting_function)
+ return None
+
+ return thismult
+
+ def sigmaof_M_z(self, sigmabin, redshift):
+ """
+ /* sigma(M, z) */
+ """
+
+ thissigma = self.Dofz(redshift) * self.sigmaarray[sigmabin];
+
+ return thissigma;
+
+ def Dofz(self, redshift):
+ """
+ /* Growth function */
+ """
+
+ thisDofz = self.gofz(redshift) / self.gofz(0.0) / (1.0+redshift);
+
+ return thisDofz;
+
+
+ def gofz(self, redshift):
+ """
+ /* g(z) - I don't think this has any other name*/
+ """
+
+ thisgofz = 2.5 * self.omega_matter_of_z(redshift) / \
+ ( math.pow( self.omega_matter_of_z(redshift), 4.0/7.0 ) - \
+ self.omega_lambda_of_z(redshift) + \
+ ( (1.0 + self.omega_matter_of_z(redshift) / 2.0) * \
+ (1.0 + self.omega_lambda_of_z(redshift) / 70.0) ))
+
+ return thisgofz;
+
+
+ def omega_matter_of_z(self,redshift):
+ """
+ /* Omega matter as a function of redshift */
+ """
+
+ thisomofz = self.omega_matter0 * math.pow( 1.0+redshift, 3.0) / \
+ math.pow( self.Eofz(redshift), 2.0 );
+
+ return thisomofz;
+
+ def omega_lambda_of_z(self,redshift):
+ """
+ /* Omega lambda as a function of redshift */
+ """
+
+ thisolofz = self.omega_lambda0 / math.pow( self.Eofz(redshift), 2.0 )
+
+ return thisolofz;
+
+ def Eofz(self, redshift):
+ """
+ /* E(z) - I don't think this has any other name */
+ """
+ thiseofz = math.sqrt( self.omega_lambda0 \
+ + (1.0 - self.omega_lambda0 - self.omega_matter0)*math.pow( 1.0+redshift, 2.0) \
+ + self.omega_matter0 * math.pow( 1.0+redshift, 3.0) );
+
+ return thiseofz;
+
+
+"""
+/* Fitting Formulae for CDM + Baryon + Massive Neutrino (MDM) cosmologies. */
+/* Daniel J. Eisenstein & Wayne Hu, Institute for Advanced Study */
+
+/* There are two primary routines here, one to set the cosmology, the
+other to construct the transfer function for a single wavenumber k.
+You should call the former once (per cosmology) and the latter as
+many times as you want. */
+
+/* TFmdm_set_cosm() -- User passes all the cosmological parameters as
+ arguments; the routine sets up all of the scalar quantites needed
+ computation of the fitting formula. The input parameters are:
+ 1) omega_matter -- Density of CDM, baryons, and massive neutrinos,
+ in units of the critical density.
+ 2) omega_baryon -- Density of baryons, in units of critical.
+ 3) omega_hdm -- Density of massive neutrinos, in units of critical
+ 4) degen_hdm -- (Int) Number of degenerate massive neutrino species
+ 5) omega_lambda -- Cosmological constant
+ 6) hubble -- Hubble constant, in units of 100 km/s/Mpc
+ 7) redshift -- The redshift at which to evaluate */
+
+/* TFmdm_onek_mpc() -- User passes a single wavenumber, in units of Mpc^-1.
+ Routine returns the transfer function from the Eisenstein & Hu
+ fitting formula, based on the cosmology currently held in the
+ internal variables. The routine returns T_cb (the CDM+Baryon
+ density-weighted transfer function), although T_cbn (the CDM+
+ Baryon+Neutrino density-weighted transfer function) is stored
+ in the global variable tf_cbnu. */
+
+/* We also supply TFmdm_onek_hmpc(), which is identical to the previous
+ routine, but takes the wavenumber in units of h Mpc^-1. */
+
+/* We hold the internal scalar quantities in global variables, so that
+the user may access them in an external program, via "extern" declarations. */
+
+/* Please note that all internal length scales are in Mpc, not h^-1 Mpc! */
+"""
+
+class TransferFunction(object):
+ def __init__(self, omega_matter, omega_baryon, omega_hdm,
+ degen_hdm, omega_lambda, hubble, redshift):
+ """
+ /* This routine takes cosmological parameters and a redshift and sets up
+ all the internal scalar quantities needed to compute the transfer function. */
+ /* INPUT: omega_matter -- Density of CDM, baryons, and massive neutrinos,
+ in units of the critical density. */
+ /* omega_baryon -- Density of baryons, in units of critical. */
+ /* omega_hdm -- Density of massive neutrinos, in units of critical */
+ /* degen_hdm -- (Int) Number of degenerate massive neutrino species */
+ /* omega_lambda -- Cosmological constant */
+ /* hubble -- Hubble constant, in units of 100 km/s/Mpc */
+ /* redshift -- The redshift at which to evaluate */
+ /* OUTPUT: Returns 0 if all is well, 1 if a warning was issued. Otherwise,
+ sets many global variables for use in TFmdm_onek_mpc() */
+ """
+ self.qwarn = 0;
+ self.theta_cmb = 2.728/2.7 # Assuming T_cmb = 2.728 K
+
+ # Look for strange input
+ if (omega_baryon<0.0):
+ mylog.error("TFmdm_set_cosm(): Negative omega_baryon set to trace amount.\n")
+ self.qwarn = 1
+ if (omega_hdm<0.0):
+ mylog.error("TFmdm_set_cosm(): Negative omega_hdm set to trace amount.\n")
+ self.qwarn = 1;
+ if (hubble<=0.0):
+ mylog.error("TFmdm_set_cosm(): Negative Hubble constant illegal.\n")
+ return None
+ elif (hubble>2.0):
+ mylog.error("TFmdm_set_cosm(): Hubble constant should be in units of 100 km/s/Mpc.\n");
+ self.qwarn = 1;
+ if (redshift<=-1.0):
+ mylog.error("TFmdm_set_cosm(): Redshift < -1 is illegal.\n");
+ return None
+ elif (redshift>99.0):
+ mylog.error("TFmdm_set_cosm(): Large redshift entered. TF may be inaccurate.\n");
+ self.qwarn = 1;
+
+ if (degen_hdm<1): degen_hdm=1;
+ self.num_degen_hdm = degen_hdm;
+ # Have to save this for TFmdm_onek_mpc()
+ # This routine would crash if baryons or neutrinos were zero,
+ # so don't allow that.
+ if (omega_baryon<=0): omega_baryon=1e-5;
+ if (omega_hdm<=0): omega_hdm=1e-5;
+
+ self.omega_curv = 1.0-omega_matter-omega_lambda;
+ self.omhh = omega_matter*SQR(hubble);
+ self.obhh = omega_baryon*SQR(hubble);
+ self.onhh = omega_hdm*SQR(hubble);
+ self.f_baryon = omega_baryon/omega_matter;
+ self.f_hdm = omega_hdm/omega_matter;
+ self.f_cdm = 1.0-self.f_baryon-self.f_hdm;
+ self.f_cb = self.f_cdm+self.f_baryon;
+ self.f_bnu = self.f_baryon+self.f_hdm;
+
+ # Compute the equality scale.
+ self.z_equality = 25000.0*self.omhh/SQR(SQR(self.theta_cmb)) # Actually 1+z_eq
+ self.k_equality = 0.0746*self.omhh/SQR(self.theta_cmb);
+
+ # Compute the drag epoch and sound horizon
+ z_drag_b1 = 0.313*math.pow(self.omhh,-0.419)*(1+0.607*math.pow(self.omhh,0.674));
+ z_drag_b2 = 0.238*math.pow(self.omhh,0.223);
+ self.z_drag = 1291*math.pow(self.omhh,0.251)/(1.0+0.659*math.pow(self.omhh,0.828))* \
+ (1.0+z_drag_b1*math.pow(self.obhh,z_drag_b2));
+ self.y_drag = self.z_equality/(1.0+self.z_drag);
+
+ self.sound_horizon_fit = 44.5*math.log(9.83/self.omhh)/math.sqrt(1.0+10.0*math.pow(self.obhh,0.75));
+
+ # Set up for the free-streaming & infall growth function
+ self.p_c = 0.25*(5.0-math.sqrt(1+24.0*self.f_cdm));
+ self.p_cb = 0.25*(5.0-math.sqrt(1+24.0*self.f_cb));
+
+ omega_denom = omega_lambda+SQR(1.0+redshift)*(self.omega_curv+\
+ omega_matter*(1.0+redshift));
+ self.omega_lambda_z = omega_lambda/omega_denom;
+ self.omega_matter_z = omega_matter*SQR(1.0+redshift)*(1.0+redshift)/omega_denom;
+ self.growth_k0 = self.z_equality/(1.0+redshift)*2.5*self.omega_matter_z/ \
+ (math.pow(self.omega_matter_z,4.0/7.0)-self.omega_lambda_z+ \
+ (1.0+self.omega_matter_z/2.0)*(1.0+self.omega_lambda_z/70.0));
+ self.growth_to_z0 = self.z_equality*2.5*omega_matter/(math.pow(omega_matter,4.0/7.0) \
+ -omega_lambda + (1.0+omega_matter/2.0)*(1.0+omega_lambda/70.0));
+ self.growth_to_z0 = self.growth_k0/self.growth_to_z0;
+
+ # Compute small-scale suppression
+ self.alpha_nu = self.f_cdm/self.f_cb*(5.0-2.*(self.p_c+self.p_cb))/(5.-4.*self.p_cb)* \
+ math.pow(1+self.y_drag,self.p_cb-self.p_c)* \
+ (1+self.f_bnu*(-0.553+0.126*self.f_bnu*self.f_bnu))/ \
+ (1-0.193*math.sqrt(self.f_hdm*self.num_degen_hdm)+0.169*self.f_hdm*math.pow(self.num_degen_hdm,0.2))* \
+ (1+(self.p_c-self.p_cb)/2*(1+1/(3.-4.*self.p_c)/(7.-4.*self.p_cb))/(1+self.y_drag));
+ self.alpha_gamma = math.sqrt(self.alpha_nu);
+ self.beta_c = 1/(1-0.949*self.f_bnu);
+ # Done setting scalar variables
+ self.hhubble = hubble # Need to pass Hubble constant to TFmdm_onek_hmpc()
+
+
+ def TFmdm_onek_mpc(self, kk):
+ """
+ /* Given a wavenumber in Mpc^-1, return the transfer function for the
+ cosmology held in the global variables. */
+ /* Input: kk -- Wavenumber in Mpc^-1 */
+ /* Output: The following are set as global variables:
+ growth_cb -- the transfer function for density-weighted
+ CDM + Baryon perturbations.
+ growth_cbnu -- the transfer function for density-weighted
+ CDM + Baryon + Massive Neutrino perturbations. */
+ /* The function returns growth_cb */
+ """
+
+ self.qq = kk/self.omhh*SQR(self.theta_cmb);
+
+ # Compute the scale-dependent growth functions
+ self.y_freestream = 17.2*self.f_hdm*(1+0.488*math.pow(self.f_hdm,-7.0/6.0))* \
+ SQR(self.num_degen_hdm*self.qq/self.f_hdm);
+ temp1 = math.pow(self.growth_k0, 1.0-self.p_cb);
+ temp2 = na.power(self.growth_k0/(1+self.y_freestream),0.7);
+ self.growth_cb = na.power(1.0+temp2, self.p_cb/0.7)*temp1;
+ self.growth_cbnu = na.power(na.power(self.f_cb,0.7/self.p_cb)+temp2, self.p_cb/0.7)*temp1;
+
+ # Compute the master function
+ self.gamma_eff = self.omhh*(self.alpha_gamma+(1-self.alpha_gamma)/ \
+ (1+SQR(SQR(kk*self.sound_horizon_fit*0.43))));
+ self.qq_eff = self.qq*self.omhh/self.gamma_eff;
+
+ tf_sup_L = na.log(2.71828+1.84*self.beta_c*self.alpha_gamma*self.qq_eff);
+ tf_sup_C = 14.4+325/(1+60.5*na.power(self.qq_eff,1.11));
+ self.tf_sup = tf_sup_L/(tf_sup_L+tf_sup_C*SQR(self.qq_eff));
+
+ self.qq_nu = 3.92*self.qq*math.sqrt(self.num_degen_hdm/self.f_hdm);
+ self.max_fs_correction = 1+1.2*math.pow(self.f_hdm,0.64)*math.pow(self.num_degen_hdm,0.3+0.6*self.f_hdm)/ \
+ (na.power(self.qq_nu,-1.6)+na.power(self.qq_nu,0.8));
+ self.tf_master = self.tf_sup*self.max_fs_correction;
+
+ # Now compute the CDM+HDM+baryon transfer functions
+ tf_cb = self.tf_master*self.growth_cb/self.growth_k0;
+ tf_cbnu = self.tf_master*self.growth_cbnu/self.growth_k0;
+ return tf_cb
+
+
+ def TFmdm_onek_hmpc(self, kk):
+ """
+ /* Given a wavenumber in h Mpc^-1, return the transfer function for the
+ cosmology held in the global variables. */
+ /* Input: kk -- Wavenumber in h Mpc^-1 */
+ /* Output: The following are set as global variables:
+ growth_cb -- the transfer function for density-weighted
+ CDM + Baryon perturbations.
+ growth_cbnu -- the transfer function for density-weighted
+ CDM + Baryon + Massive Neutrino perturbations. */
+ /* The function returns growth_cb */
+ """
+ return self.TFmdm_onek_mpc(kk*self.hhubble);
+
+def SQR(a):
+ return a*a
+
+def integrate_inf(fcn, error=1e-7, initial_guess=10):
+ """
+ Integrate a function *fcn* from zero to infinity, stopping when the answer
+ changes by less than *error*. Hopefully someday we can do something
+ better than this!
+ """
+ xvals = na.logspace(0,na.log10(initial_guess), initial_guess+1)-.9
+ yvals = fcn(xvals)
+ xdiffs = xvals[1:] - xvals[:-1]
+ # Trapezoid rule, but with different dxes between values, so na.trapz
+ # will not work.
+ areas = (yvals[1:] + yvals[:-1]) * xdiffs / 2.0
+ area0 = na.sum(areas)
+ # Next guess.
+ next_guess = 10 * initial_guess
+ xvals = na.logspace(0,na.log10(next_guess), 2*initial_guess**2+1)-.99
+ yvals = fcn(xvals)
+ xdiffs = xvals[1:] - xvals[:-1]
+ # Trapezoid rule.
+ areas = (yvals[1:] + yvals[:-1]) * xdiffs / 2.0
+ area1 = na.sum(areas)
+ # Now we refine until the error is smaller than *error*.
+ diff = area1 - area0
+ area_final = area1
+ area_last = area1
+ one_pow = 3
+ while diff > error:
+ next_guess *= 10
+ xvals = na.logspace(0,na.log10(next_guess), one_pow*initial_guess**one_pow+1) - (1 - 0.1**one_pow)
+ yvals = fcn(xvals)
+ xdiffs = xvals[1:] - xvals[:-1]
+ # Trapezoid rule.
+ areas = (yvals[1:] + yvals[:-1]) * xdiffs / 2.0
+ area_next = na.sum(areas)
+ diff = area_next - area_last
+ area_last = area_next
+ one_pow+=1
+ return area_last
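+
+# Quick sanity check for integrate_inf (a sketch, not part of the analysis
+# pipeline): the integral of exp(-x) over [0, inf) is exactly 1, so the call
+# below should return a value near 1, improving as *error* is tightened.
+#   >>> integrate_inf(lambda x: na.exp(-x))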
Added: trunk/yt/extensions/StarAnalysis.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/StarAnalysis.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,384 @@
+"""
+StarAnalysis - Functions to analyze stars.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UC San Diego / CASS
+Homepage: http://yt.enzotools.org/
+License:
+ Copyright (C) 2008-2009 Stephen Skory (and others). All Rights Reserved.
+
+ This file is part of yt.
+
+ yt is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import yt.lagos as lagos
+from yt.logger import lagosLogger as mylog
+
+import numpy as na
+import h5py
+
+import math, itertools
+
+YEAR = 3.155693e7 # sec / year
+LIGHT = 2.997925e10 # cm / s
+
+class StarFormationRate(object):
+ def __init__(self, pf, data_source=None, star_mass=None,
+ star_creation_time=None, volume=None, bins=300):
+ self._pf = pf
+ self._data_source = data_source
+ self.star_mass = star_mass
+ self.star_creation_time = star_creation_time
+ self.volume = volume
+ self.bin_count = bins
+        # Check to make sure we have the right set of information.
+ if data_source is None:
+ if self.star_mass is None or self.star_creation_time is None or \
+ self.volume is None:
+ mylog.error(
+ """
+            If data_source is not provided, all of these parameters need to be set:
+ star_mass (array, Msun),
+ star_creation_time (array, code units),
+ volume (float, Mpc**3).
+ """)
+ return None
+ self.mode = 'provided'
+ else:
+ self.mode = 'data_source'
+ # Set up for time conversion.
+ self.cosm = lagos.EnzoCosmology(HubbleConstantNow =
+ (100.0 * self._pf['CosmologyHubbleConstantNow']),
+ OmegaMatterNow = self._pf['CosmologyOmegaMatterNow'],
+ OmegaLambdaNow = self._pf['CosmologyOmegaLambdaNow'],
+ InitialRedshift = self._pf['CosmologyInitialRedshift'])
+ # Find the time right now.
+ self.time_now = self.cosm.ComputeTimeFromRedshift(
+ self._pf["CosmologyCurrentRedshift"]) # seconds
+ # Build the distribution.
+ self.build_dist()
+
+ def build_dist(self):
+ """
+ Build the data for plotting.
+ """
+ # Pick out the stars.
+ if self.mode == 'data_source':
+ ct = self._data_source["creation_time"]
+ ct_stars = ct[ct > 0]
+ mass_stars = self._data_source["ParticleMassMsun"][ct > 0]
+ elif self.mode == 'provided':
+ ct_stars = self.star_creation_time
+ mass_stars = self.star_mass
+        # Find the oldest stars in units of code time.
+        tmin = min(ct_stars)
+        # Nudge the lower bound down slightly so the oldest star lands inside the first bin.
+        self.time_bins = na.linspace(tmin*0.99, self._pf['InitialTime'],
+ num = self.bin_count + 1)
+ # Figure out which bins the stars go into.
+ inds = na.digitize(ct_stars, self.time_bins) - 1
+ # Sum up the stars created in each time bin.
+ self.mass_bins = na.zeros(self.bin_count + 1, dtype='float64')
+ for index in na.unique(inds):
+ self.mass_bins[index] += sum(mass_stars[inds == index])
+ # Calculate the cumulative mass sum over time by forward adding.
+ self.cum_mass_bins = self.mass_bins.copy()
+ for index in xrange(self.bin_count):
+ self.cum_mass_bins[index+1] += self.cum_mass_bins[index]
+ # We will want the time taken between bins.
+ self.time_bins_dt = self.time_bins[1:] - self.time_bins[:-1]
+
+ def write_out(self, name="StarFormationRate.out"):
+ """
+ Write out the star analysis to a text file *name*. The columns are in
+ order:
+ 1) Time (yrs)
+ 2) Look-back time (yrs)
+ 3) Redshift
+ 4) Star formation rate in this bin per year (Msol/yr)
+ 5) Star formation rate in this bin per year per Mpc**3 (Msol/yr/Mpc**3)
+ 6) Stars formed in this time bin (Msol)
+ 7) Cumulative stars formed up to this time bin (Msol)
+ """
+ fp = open(name, "w")
+ if self.mode == 'data_source':
+ vol = self._data_source.volume('mpc')
+ elif self.mode == 'provided':
+ vol = self.volume
+ tc = self._pf["Time"]
+ # Use the center of the time_bin, not the left edge.
+ for i, time in enumerate((self.time_bins[1:] + self.time_bins[:-1])/2.):
+ line = "%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\t%1.5e\n" % \
+ (time * tc / YEAR, # Time
+ (self.time_now - time * tc)/YEAR, # Lookback time
+ self.cosm.ComputeRedshiftFromTime(time * tc), # Redshift
+ self.mass_bins[i] / (self.time_bins_dt[i] * tc / YEAR), # Msol/yr
+ self.mass_bins[i] / (self.time_bins_dt[i] * tc / YEAR) / vol, # Msol/yr/vol
+ self.mass_bins[i], # Msol in bin
+ self.cum_mass_bins[i]) # cumulative
+ fp.write(line)
+ fp.close()
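+
+# Example usage (a sketch; the dataset name is hypothetical):
+#   >>> pf = lagos.EnzoStaticOutput("DD0010/data0010")
+#   >>> dd = pf.h.all_data()
+#   >>> sfr = StarFormationRate(pf, data_source=dd)
+#   >>> sfr.write_out(name="StarFormationRate.out")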
+
+CHABRIER = {
+"Z0001" : "bc2003_hr_m22_chab_ssp.ised.h5", #/* 0.5% */
+"Z0004" : "bc2003_hr_m32_chab_ssp.ised.h5", #/* 2% */
+"Z004" : "bc2003_hr_m42_chab_ssp.ised.h5", #/* 20% */
+"Z008" : "bc2003_hr_m52_chab_ssp.ised.h5", #/* 40% */
+"Z02" : "bc2003_hr_m62_chab_ssp.ised.h5", #/* solar; 0.02 */
+"Z05" : "bc2003_hr_m72_chab_ssp.ised.h5" #/* 250% */
+}
+
+SALPETER = {
+"Z0001" : "bc2003_hr_m22_salp_ssp.ised.h5", #/* 0.5% */
+"Z0004" : "bc2003_hr_m32_salp_ssp.ised.h5", #/* 2% */
+"Z004" : "bc2003_hr_m42_salp_ssp.ised.h5", #/* 20% */
+"Z008" : "bc2003_hr_m52_salp_ssp.ised.h5", #/* 40% */
+"Z02" : "bc2003_hr_m62_salp_ssp.ised.h5", #/* solar; 0.02 */
+"Z05" : "bc2003_hr_m72_salp_ssp.ised.h5" #/* 250% */
+}
+
+Zsun = 0.02
+
+#/* dividing line of metallicity; linear in log(Z/Zsun) */
+METAL1 = 0.01 # /* in units of Z/Zsun */
+METAL2 = 0.0632
+METAL3 = 0.2828
+METAL4 = 0.6325
+METAL5 = 1.5811
+METALS = na.array([METAL1, METAL2, METAL3, METAL4, METAL5])
+
+# Translate the indices from digitizing against METALS into the table dict keys.
+MtoD = na.array(["Z0001", "Z0004", "Z004", "Z008", "Z02", "Z05"])
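+# For example, a star with Z/Zsun = 0.5 falls between METAL3 and METAL4, so
+# na.digitize([0.5], METALS) gives index 3, and MtoD[3] == "Z008".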
+
+"""
+This spectrum code is based on code from Ken Nagamine, converted from C to Python.
+I've also reversed the order of elements in the flux arrays to be in C-ordering,
+for faster memory access.
+"""
+
+class SpectrumBuilder(object):
+ def __init__(self, pf, bcdir="", model="chabrier"):
+ """
+ Initialize the data to build a summed flux spectrum for a
+ collection of stars using the models of Bruzual & Charlot (2003).
+        :param pf (object): A yt pf (parameter file) object.
+        :param bcdir (string): Path to the directory containing the Bruzual &
+        Charlot h5 fit files.
+        :param model (string): Choice of Initial Mass Function model,
+        'chabrier' or 'salpeter'.
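+
+        Example (a sketch; the bcdir path and the data object names are
+        hypothetical):
+            >>> dd = pf.h.all_data()
+            >>> sb = SpectrumBuilder(pf, bcdir="/path/to/bc03", model="chabrier")
+            >>> sb.calculate_spectrum(data_source=dd)
+            >>> sb.write_out_SED(name="sum_SED.out")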
+ """
+ self._pf = pf
+ self.bcdir = bcdir
+
+ if model == "chabrier":
+ self.model = CHABRIER
+ elif model == "salpeter":
+ self.model = SALPETER
+ # Set up for time conversion.
+ self.cosm = lagos.EnzoCosmology(HubbleConstantNow =
+ (100.0 * self._pf['CosmologyHubbleConstantNow']),
+ OmegaMatterNow = self._pf['CosmologyOmegaMatterNow'],
+ OmegaLambdaNow = self._pf['CosmologyOmegaLambdaNow'],
+ InitialRedshift = self._pf['CosmologyInitialRedshift'])
+ # Find the time right now.
+ self.time_now = self.cosm.ComputeTimeFromRedshift(
+ self._pf["CosmologyCurrentRedshift"]) # seconds
+
+ # Read the tables.
+ self.read_bclib()
+
+ def read_bclib(self):
+ """
+ Read in the age and wavelength bins, and the flux bins for each
+ metallicity.
+ """
+        self.flux = {}
+        for Zname in self.model:
+            fname = self.bcdir + "/" + self.model[Zname]
+            fp = h5py.File(fname, 'r')
+            self.age = fp["agebins"][:] # 1D floats
+            self.wavelength = fp["wavebins"][:] # 1D floats
+            self.flux[Zname] = fp["flam"][:,:] # 2D floats, [agebin, wavebin]
+            fp.close()
+
+ def calculate_spectrum(self, data_source=None, star_mass=None,
+ star_creation_time=None, star_metallicity_fraction=None,
+ star_metallicity_constant=None):
+ """
+ For the set of stars, calculate the collective spectrum.
+ Attached to the output are several useful objects:
+ final_spec: The collective spectrum in units of flux binned in wavelength.
+ wavelength: The wavelength for the spectrum bins, in Angstroms.
+ total_mass: Total mass of all the stars.
+ avg_mass: Average mass of all the stars.
+ avg_metal: Average metallicity of all the stars.
+ :param data_source (object): A yt data_source that defines a portion
+ of the volume from which to extract stars.
+ :param star_mass (array, float): An array of star masses in Msun units.
+ :param star_creation_time (array, float): An array of star creation
+ times in code units.
+ :param star_metallicity_fraction (array, float): An array of star
+ metallicity fractions, in code units (which is not Z/Zsun).
+ :param star_metallicity_constant (float): If desired, override the star
+ metallicity fraction of all the stars to the given value.
+ """
+ # Initialize values
+ self.final_spec = na.zeros(self.wavelength.size, dtype='float64')
+ self._data_source = data_source
+ self.star_mass = star_mass
+ self.star_creation_time = star_creation_time
+ self.star_metal = star_metallicity_fraction
+
+ # Check to make sure we have the right set of data.
+ if data_source is None:
+ if self.star_mass is None or self.star_creation_time is None or \
+ (star_metallicity_fraction is None and star_metallicity_constant is None):
+ mylog.error(
+ """
+            If data_source is not provided, all of these parameters need to be set:
+ star_mass (array, Msun),
+ star_creation_time (array, code units),
+ And one of:
+ star_metallicity_fraction (array, code units).
+ --OR--
+ star_metallicity_constant (float, code units).
+ """)
+ return None
+ if star_metallicity_constant is not None:
+ self.star_metal = na.ones(self.star_mass.size, dtype='float64') * \
+ star_metallicity_constant
+ else:
+ # Get the data we need.
+ ct = self._data_source["creation_time"]
+ self.star_creation_time = ct[ct > 0]
+ self.star_mass = self._data_source["ParticleMassMsun"][ct > 0]
+ if star_metallicity_constant is not None:
+ self.star_metal = na.ones(self.star_mass.size, dtype='float64') * \
+ star_metallicity_constant
+ else:
+ self.star_metal = self._data_source["metallicity_fraction"][ct > 0]
+ # Fix metallicity to units of Zsun.
+ self.star_metal /= Zsun
+ # Age of star in years.
+ dt = (self.time_now - self.star_creation_time * self._pf['Time']) / YEAR
+        # Figure out which METALS bin the star goes into (by metallicity, not age).
+        Mindex = na.digitize(self.star_metal, METALS)
+ # Replace the indices with strings.
+ Mname = MtoD[Mindex]
+ # Figure out which age bin this star goes into.
+ Aindex = na.digitize(dt, self.age)
+ # Ratios used for the interpolation.
+ ratio1 = (dt - self.age[Aindex-1]) / (self.age[Aindex] - self.age[Aindex-1])
+ ratio2 = (self.age[Aindex] - dt) / (self.age[Aindex] - self.age[Aindex-1])
+ # Sort the stars by metallicity and then by age, which should reduce
+ # memory access time by a little bit in the loop.
+ sort = na.lexsort((Aindex, Mname))
+ Mname = Mname[sort]
+ Aindex = Aindex[sort]
+ ratio1 = ratio1[sort]
+ ratio2 = ratio2[sort]
+ self.star_mass = self.star_mass[sort]
+ self.star_creation_time = self.star_creation_time[sort]
+ self.star_metal = self.star_metal[sort]
+
+ # Interpolate the flux for each star, adding to the total by weight.
+ for star in itertools.izip(Mname, Aindex, ratio1, ratio2, self.star_mass):
+ # Pick the right age bin for the right flux array.
+ flux = self.flux[star[0]][star[1],:]
+ # Get the one just before the one above.
+ flux_1 = self.flux[star[0]][star[1]-1,:]
+ # interpolate in log(flux), linear in time.
+ int_flux = star[3] * na.log10(flux_1) + star[2] * na.log10(flux)
+ # Add this flux to the total, weighted by mass.
+ self.final_spec += na.power(10., int_flux) * star[4]
+ # Normalize.
+ self.total_mass = sum(self.star_mass)
+ self.avg_mass = na.mean(self.star_mass)
+        tot_metal = sum(self.star_metal * self.star_mass)
+        # star_metal was already converted to units of Zsun above.
+        self.avg_metal = math.log10(tot_metal / self.total_mass)
+
+ # Below is an attempt to do the loop using vectors and matrices,
+ # however it doesn't appear to be much faster, probably due to all
+ # the gymnastics that have to be done to do element-by-element
+    # multiplication for matrices.
+ # I'm keeping it in here in case I come up with a more
+ # elegant way that actually is faster.
+# for metal_name in MtoD:
+# # Pick out our stars in this metallicity bin.
+# select = (Mname == metal_name)
+# A = Aindex[select]
+# if A.size == 0: continue
+# r1 = ratio1[select]
+# r2 = ratio2[select]
+# sm = self.star_mass[select]
+# # From the flux array for this metal, and our selection, build
+# # a new flux array just for the ages of these stars, in the
+# # same order as the selection of stars.
+# this_flux = na.matrix(self.flux[metal_name][A])
+# # Make one for the last time step for each star in the same fashion
+# # as above.
+# this_flux_1 = na.matrix(self.flux[metal_name][A-1])
+# # This is kind of messy, but we're going to multiply this_fluxes
+# # by the appropriate ratios and add it together to do the
+# # interpolation in log(flux) and linear in time.
+# print r1.size
+# r1 = na.matrix(r1.tolist()*self.wavelength.size).reshape(self.wavelength.size,r1.size).T
+# r2 = na.matrix(r2.tolist()*self.wavelength.size).reshape(self.wavelength.size,r2.size).T
+# print this_flux_1.shape, r1.shape
+# int_flux = na.multiply(na.log10(this_flux_1),r1) \
+# + na.multiply(na.log10(this_flux),r2)
+# # Weight the fluxes by mass.
+# sm = na.matrix(sm.tolist()*self.wavelength.size).reshape(self.wavelength.size,sm.size).T
+# int_flux = na.multiply(na.power(10., int_flux), sm)
+# # Sum along the columns, converting back to an array, adding
+# # to the full spectrum.
+# self.final_spec += na.array(int_flux.sum(axis=0))[0,:]
+
+
+ def write_out(self, name="sum_flux.out"):
+ """
+ Write out the summed flux to a file. The file has two columns:
+ 1) Wavelength (Angstrom)
+ 2) Flux (Luminosity per unit wavelength, L_sun Ang^-1,
+ L_sun = 3.826 * 10^33 ergs s^-1.)
+ :param name (string): Name of file to write to.
+ """
+ fp = open(name, 'w')
+ for i, wave in enumerate(self.wavelength):
+ fp.write("%1.5e\t%1.5e\n" % (wave, self.final_spec[i]))
+ fp.close()
+
+ def write_out_SED(self, name="sum_SED.out", flux_norm=5200.):
+ """
+ Write out the summed SED to a file. The file has two columns:
+ 1) Wavelength (Angstrom)
+ 2) Relative flux normalized to the flux at *flux_norm*.
+ It also will attach an array *f_nu* which is the normalized flux,
+ identical to the disk output.
+ :param name (string): Name of file to write to.
+ :param flux_norm (float): Wavelength of the flux to normalize the
+ distribution against.
+ """
+ # find the f_nu closest to flux_norm
+ fn_wavelength = na.argmin(abs(self.wavelength - flux_norm))
+ f_nu = self.final_spec * na.power(self.wavelength, 2.) / LIGHT
+ # Normalize f_nu
+ self.f_nu = f_nu / f_nu[fn_wavelength]
+ # Write out.
+ fp = open(name, 'w')
+ for i, wave in enumerate(self.wavelength):
+ fp.write("%1.5e\t%1.5e\n" % (wave, self.f_nu[i]))
+ fp.close()
+
Added: trunk/yt/extensions/kdtree/Makefile
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/Makefile Sun Dec 20 12:42:09 2009
@@ -0,0 +1,11 @@
+# If you want to build this statically, add fKD.f90 to the compile line,
+# as in the first (commented-out) rule below. Otherwise leave it out,
+# as in the second, to build a shared object.
+
+
+fKD: fKD.f90 fKD.v fKD_source.f90
+# Forthon --compile_first fKD_source --no2underscores --with-numpy -g fKD fKD.f90 fKD_source.f90
+ Forthon -F gfortran --compile_first fKD_source --no2underscores --with-numpy --fopt "-O3" fKD fKD_source.f90
+
+clean:
+ rm -rf build fKDpy.a fKDpy.so
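+
+# Typical build (sketch): run "make" in this directory; Forthon should emit
+# fKDpy.so, which __init__.py imports via "from fKDpy import *".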
Added: trunk/yt/extensions/kdtree/__init__.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/__init__.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,3 @@
+from yt.lagos import *
+
+from fKDpy import *
\ No newline at end of file
Added: trunk/yt/extensions/kdtree/fKD.f90
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD.f90 Sun Dec 20 12:42:09 2009
@@ -0,0 +1,162 @@
+subroutine create_tree()
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ ! create a kd tree object
+
+ tree2 => kdtree2_create(pos,sort=sort,rearrange=rearrange) ! this is how you create a tree.
+ return
+
+end subroutine create_tree
+
+
+subroutine find_nn_nearest_neighbors()
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ integer :: k
+ type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+ !integer, parameter :: nn ! number of nearest neighbors found
+
+
+ allocate(results(nn))
+
+ call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+
+ dist = results%dis
+ tags = results%idx
+
+ !do k=1,nn
+ ! print *, "k = ", k, " idx = ", tags(k)," dis = ", dist(k)
+ ! print *, "x y z", pos(1,results(k)%idx), pos(2,results(k)%idx), pos(3,results(k)%idx)
+ !enddo
+
+
+ deallocate(results)
+ return
+
+end subroutine find_nn_nearest_neighbors
+
+subroutine find_all_nn_nearest_neighbors()
+ ! for all particles in pos, find their nearest neighbors and return the
+ ! indexes and distances as big arrays
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ integer :: k
+ type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+ allocate(results(nn))
+
+ do k=1,nparts
+ qv(:) = pos(:,k)
+ call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+ nn_dist(:,k) = results%dis
+ nn_tags(:,k) = results%idx
+ end do
+
+ deallocate(results)
+ return
+
+end subroutine find_all_nn_nearest_neighbors
+
+subroutine find_chunk_nearest_neighbors()
+ ! for a chunk of the full number of particles, find their nearest neighbors
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ integer :: k
+ type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+ allocate(results(nn))
+ do k=start,finish
+ qv(:) = pos(:,k)
+ call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+ chunk_tags(:,k - start + 1) = results%idx
+
+ end do
+
+ deallocate(results)
+ return
+
+end subroutine find_chunk_nearest_neighbors
+
+subroutine chainHOP_tags_dens()
+ ! for all particles in pos, find their nearest neighbors, and calculate
+ ! their density. Return only nMerge nearest neighbors.
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ integer :: k, pj, i
+ real :: ih2, fNorm, r2, rs
+ integer, allocatable :: temp_tags(:)
+ real, allocatable :: temp_dist(:)
+ type(kdtree2_result),allocatable :: results(:) ! nearest neighbors
+ allocate(results(nn))
+ allocate(temp_tags(nn))
+ allocate(temp_dist(nn))
+
+ do k=1,nparts
+ qv(:) = pos(:,k)
+
+ call kdtree2_n_nearest(tp=tree2,qv=qv,nn=nn,results=results)
+ temp_tags(:) = results%idx
+ temp_dist(:) = results%dis
+
+ ! calculate the density for this particle
+ ih2 = 4.0/maxval(results%dis)
+ fNorm = 0.5*sqrt(ih2)*ih2/3.1415926535897931
+ do i=1,nn
+ pj = temp_tags(i)
+ r2 = temp_dist(i) * ih2
+ rs = 2.0 - sqrt(r2)
+ if (r2 < 1.0) then
+ rs = (1.0 - 0.75*rs*r2)
+ else
+ rs = 0.25*rs*rs*rs
+ end if
+ rs = rs * fNorm
+ dens(k) = dens(k) + rs * mass(pj)
+ dens(pj) = dens(pj) + rs * mass(k)
+ end do
+
+ ! record only nMerge nearest neighbors, but skip the first one which
+ ! is always the self-same particle
+ ! nn_tags(:,k) = temp_tags(2:nMerge)
+ end do
+
+ deallocate(results)
+ deallocate(temp_dist)
+ deallocate(temp_tags)
+ return
+
+end subroutine chainHOP_tags_dens
+
+subroutine free_tree()
+ use kdtree2_module
+ use fKD_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+
+ ! this releases memory for the tree BUT NOT THE ARRAY OF DATA YOU PASSED
+ ! TO MAKE THE TREE.
+ call kdtree2_destroy(tree2)
+
+ ! The data to make the tree has to be deleted in python BEFORE calling
+ ! this!
+end subroutine free_tree
+
Added: trunk/yt/extensions/kdtree/fKD.v
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD.v Sun Dec 20 12:42:09 2009
@@ -0,0 +1,101 @@
+fKD
+
+****** fKD_module vars:
+# Not all of these are being used, but they only take memory
+# if they're initialized in python.
+tags(:) _integer # particle ID tags
+dist(:) _real # interparticle spacings
+nn_tags(:,:) _integer # for all particles at once, [nth neighbor, index]
+chunk_tags(:,:) _integer # for finding only a chunk of the nearest neighbors
+nn_dist(:,:) _real
+pos(3,:) _real
+dens(:) _real
+mass(:) _real
+qv(3) real
+nparts integer
+nn integer
+nMerge integer # number of nearest neighbors used in chain merging
+start integer
+finish integer
+tree2 _kdtree2
+sort logical /.false./
+rearrange logical /.true./
+
+%%%% interval:
+lower real
+upper real
+#real(kdkind) :: lower,upper
+
+
+%%%% tree_node:
+# an internal tree node
+cut_dim integer
+#integer :: cut_dim
+# the dimension to cut
+cut_val real
+#real(kdkind) :: cut_val
+# where to cut the dimension
+cut_val_left real
+cut_val_right real
+#real(kdkind) :: cut_val_left, cut_val_right
+# improved cutoffs knowing the spread in child boxes.
+u integer
+l integer
+#integer :: l, u
+left _tree_node
+right _tree_node
+#type(tree_node), pointer :: left, right
+box(:) _interval
+#type(interval), pointer :: box(:) => null()
+# child pointers
+# Points included in this node are indexes[k] with k \in [l,u]
+
+
+%%%% kdtree2:
+# Global information about the tree, one per tree
+dimen integer /0/
+n integer /0/
+# dimensionality and total # of points
+the_data(:,:) _real
+#real(kdkind), pointer :: the_data(:,:) => null()
+# pointer to the actual data array
+#
+# IMPORTANT NOTE: IT IS DIMENSIONED the_data(1:d,1:N)
+# which may be opposite of what may be conventional.
+# This is, because in Fortran, the memory layout is such that
+# the first dimension is in sequential order. Hence, with
+# (1:d,1:N), all components of the vector will be in consecutive
+# memory locations. The search time is dominated by the
+# evaluation of distances in the terminal nodes. Putting all
+# vector components in consecutive memory location improves
+# memory cache locality, and hence search speed, and may enable
+# vectorization on some processors and compilers.
+ind(:) _integer
+#integer, pointer :: ind(:) => null()
+# permuted index into the data, so that indexes[l..u] of some
+# bucket represent the indexes of the actual points in that
+# bucket.
+# do we always sort output results?
+sort logical /.false./
+#logical :: sort = .false.
+rearrange logical /.false./
+#logical :: rearrange = .false.
+rearranged_data(:,:) _real
+#real(kdkind), pointer :: rearranged_data(:,:) => null()
+# if (rearrange .eqv. .true.) then rearranged_data has been
+# created so that rearranged_data(:,i) = the_data(:,ind(i)),
+# permitting search to use more cache-friendly rearranged_data, at
+# some initial computation and storage cost.
+root _tree_node
+#type(tree_node), pointer :: root => null()
+# root pointer of the tree
+
+
+
+***** Subroutines:
+find_nn_nearest_neighbors subroutine
+create_tree() subroutine
+free_tree() subroutine
+find_all_nn_nearest_neighbors subroutine
+find_chunk_nearest_neighbors subroutine
+chainHOP_tags_dens subroutine
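+
+# Example of driving this interface from Python (a sketch; the exact attribute
+# access depends on the Forthon-generated fKDpy module):
+#   from yt.extensions.kdtree import *
+#   fKD.pos = positions                  # (3, nparts) array of particle positions
+#   fKD.nparts = positions.shape[1]
+#   fKD.nn = 64                          # number of nearest neighbors
+#   fKD.qv = na.array([0.5, 0.5, 0.5])   # query point
+#   fKD.tags = na.empty(fKD.nn, dtype='int32')
+#   fKD.dist = na.empty(fKD.nn, dtype='float32')
+#   create_tree()
+#   find_nn_nearest_neighbors()          # fills fKD.tags and fKD.dist
+#   free_tree()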
Added: trunk/yt/extensions/kdtree/fKD_source.f90
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/fKD_source.f90 Sun Dec 20 12:42:09 2009
@@ -0,0 +1,1955 @@
+!
+!(c) Matthew Kennel, Institute for Nonlinear Science (2004)
+!
+! Licensed under the Academic Free License version 1.1 found in file LICENSE
+! with additional provisions found in that same file.
+!
+! Modified by Stephen Skory, CASS/UCSD (2009), adding periodicity and changes
+! so this can be used by as a Python module using Forthon
+
+module kdtree2_precision_module
+
+ integer, parameter :: sp = kind(0.0)
+ integer, parameter :: dp = kind(0.0d0)
+
+ private :: sp, dp
+
+ !
+ ! You must comment out exactly one
+ ! of the two lines. If you comment
+ ! out kdkind = sp then you get single precision
+ ! and if you comment out kdkind = dp
+ ! you get double precision.
+ !
+
+ integer, parameter :: kdkind = sp
+ !integer, parameter :: kdkind = dp
+ public :: kdkind
+
+end module kdtree2_precision_module
+
+module kdtree2_priority_queue_module
+ use kdtree2_precision_module
+ !
+ ! maintain a priority queue (PQ) of data, pairs of 'priority/payload',
+ ! implemented with a binary heap. This is the type, and the 'dis' field
+ ! is the priority.
+ !
+ type kdtree2_result
+ ! a pair of distances, indexes
+ real(kdkind) :: dis!=0.0
+ integer :: idx!=-1 Initializers cause some bugs in compilers.
+ end type kdtree2_result
+ !
+ ! A heap-based priority queue lets one efficiently implement the following
+ ! operations, each in log(N) time, as opposed to linear time.
+ !
+ ! 1) add a datum (push a datum onto the queue, increasing its length)
+ ! 2) return the priority value of the maximum priority element
+ ! 3) pop-off (and delete) the element with the maximum priority, decreasing
+ ! the size of the queue.
+ ! 4) replace the datum with the maximum priority with a supplied datum
+ ! (of either higher or lower priority), maintaining the size of the
+ ! queue.
+ !
+ !
+ ! In the k-d tree case, the 'priority' is the square distance of a point in
+ ! the data set to a reference point. The goal is to keep the smallest M
+ ! distances to a reference point. The tree algorithm searches terminal
+ ! nodes to decide whether to add points under consideration.
+ !
+ ! A priority queue is useful here because it lets one quickly return the
+ ! largest distance currently existing in the list. If a new candidate
+ ! distance is smaller than this, then the new candidate ought to replace
+ ! the old candidate. In priority queue terms, this means removing the
+ ! highest priority element, and inserting the new one.
+ !
+ ! Algorithms based on Cormen, Leiserson, Rivest, _Introduction
+ ! to Algorithms_, 1990, with further optimization by the author.
+ !
+ ! Originally informed by a C implementation by Sriranga Veeraraghavan.
+ !
+  ! This module is not written in the clearest way, but is implemented this
+  ! way for speed, as its operations will be called many times during
+  ! searches over large numbers of neighbors.
+ !
+ type pq
+ !
+ ! The priority queue consists of elements
+ ! priority(1:heap_size), with associated payload(:).
+ !
+ ! There are heap_size active elements.
+ ! Assumes the allocation is always sufficient. Will NOT increase it
+ ! to match.
+ integer :: heap_size = 0
+ type(kdtree2_result), pointer :: elems(:)
+ end type pq
+
+ public :: kdtree2_result
+
+ public :: pq
+ public :: pq_create
+ public :: pq_delete, pq_insert
+ public :: pq_extract_max, pq_max, pq_replace_max, pq_maxpri
+ private
+
+contains
+
+
+ function pq_create(results_in) result(res)
+ !
+    ! Create a priority queue from an ALREADY allocated
+    ! array of results for storage. NOTE! It will NOT
+    ! add any elements to the heap, i.e. any existing
+    ! data in the input arrays will NOT be used and may
+    ! be overwritten.
+    !
+    ! usage:
+    !    type(kdtree2_result), target :: results(1000)
+    !    type(pq) :: a
+    !    a = pq_create(results)
+ !
+ type(kdtree2_result), target:: results_in(:)
+ type(pq) :: res
+ !
+ !
+ integer :: nalloc
+
+ nalloc = size(results_in,1)
+ if (nalloc .lt. 1) then
+ write (*,*) 'PQ_CREATE: error, input arrays must be allocated.'
+ end if
+ res%elems => results_in
+ res%heap_size = 0
+ return
+ end function pq_create
+
+ !
+ ! operations for getting parents and left + right children
+ ! of elements in a binary heap.
+ !
+
+!
+! These are written inline for speed.
+!
+! integer function parent(i)
+! integer, intent(in) :: i
+! parent = (i/2)
+! return
+! end function parent
+
+! integer function left(i)
+! integer, intent(in) ::i
+! left = (2*i)
+! return
+! end function left
+
+! integer function right(i)
+! integer, intent(in) :: i
+! right = (2*i)+1
+! return
+! end function right
+
+! logical function compare_priority(p1,p2)
+! real(kdkind), intent(in) :: p1, p2
+!
+! compare_priority = (p1 .gt. p2)
+! return
+! end function compare_priority
+
+ subroutine heapify(a,i_in)
+ !
+ ! take a heap rooted at 'i' and force it to be in the
+ ! heap canonical form. This is performance critical
+ ! and has been tweaked a little to reflect this.
+ !
+ type(pq),pointer :: a
+ integer, intent(in) :: i_in
+ !
+ integer :: i, l, r, largest
+
+ real(kdkind) :: pri_i, pri_l, pri_r, pri_largest
+
+
+ type(kdtree2_result) :: temp
+
+ i = i_in
+
+bigloop: do
+ l = 2*i ! left(i)
+ r = l+1 ! right(i)
+ !
+ ! set 'largest' to the index of either i, l, r
+ ! depending on whose priority is largest.
+ !
+ ! note that l or r can be larger than the heap size
+ ! in which case they do not count.
+
+
+ ! does left child have higher priority?
+ if (l .gt. a%heap_size) then
+ ! we know that i is the largest as both l and r are invalid.
+ exit
+ else
+ pri_i = a%elems(i)%dis
+ pri_l = a%elems(l)%dis
+ if (pri_l .gt. pri_i) then
+ largest = l
+ pri_largest = pri_l
+ else
+ largest = i
+ pri_largest = pri_i
+ endif
+
+ !
+ ! between i and l we have a winner
+ ! now choose between that and r.
+ !
+ if (r .le. a%heap_size) then
+ pri_r = a%elems(r)%dis
+ if (pri_r .gt. pri_largest) then
+ largest = r
+ endif
+ endif
+ endif
+
+ if (largest .ne. i) then
+ ! swap data in nodes largest and i, then heapify
+
+ temp = a%elems(i)
+ a%elems(i) = a%elems(largest)
+ a%elems(largest) = temp
+ !
+          ! Canonical heapify() algorithm has tail-recursive call:
+ !
+ ! call heapify(a,largest)
+ ! we will simulate with cycle
+ !
+ i = largest
+ cycle bigloop ! continue the loop
+ else
+ return ! break from the loop
+ end if
+ enddo bigloop
+ return
+ end subroutine heapify
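+
+  ! Illustrative sketch (hypothetical values): the queue stores a
+  ! max-heap in a flat array, so the priorities (7,4,6,2,3) form a
+  ! valid heap because each parent elems(i) is >= elems(2i) and
+  ! elems(2i+1):
+  !
+  !            7          index 1
+  !           / \
+  !          4   6        indexes 2, 3
+  !         / \
+  !        2   3          indexes 4, 5
+  !
+  ! heapify(a,1) restores this invariant after elems(1) is replaced.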
+
+ subroutine pq_max(a,e)
+ !
+ ! return the priority and its payload of the maximum priority element
+ ! on the queue, which should be the first one, if it is
+ ! in heapified form.
+ !
+ type(pq),pointer :: a
+ type(kdtree2_result),intent(out) :: e
+
+ if (a%heap_size .gt. 0) then
+ e = a%elems(1)
+ else
+ write (*,*) 'PQ_MAX: ERROR, heap_size < 1'
+ stop
+ endif
+ return
+ end subroutine pq_max
+
+ real(kdkind) function pq_maxpri(a)
+ type(pq), pointer :: a
+
+ if (a%heap_size .gt. 0) then
+ pq_maxpri = a%elems(1)%dis
+ else
+ write (*,*) 'PQ_MAX_PRI: ERROR, heapsize < 1'
+ stop
+ endif
+ return
+ end function pq_maxpri
+
+ subroutine pq_extract_max(a,e)
+ !
+ ! return the priority and payload of maximum priority
+ ! element, and remove it from the queue.
+ ! (equivalent to 'pop()' on a stack)
+ !
+ type(pq),pointer :: a
+ type(kdtree2_result), intent(out) :: e
+
+ if (a%heap_size .ge. 1) then
+ !
+ ! return max as first element
+ !
+ e = a%elems(1)
+
+ !
+ ! move last element to first
+ !
+ a%elems(1) = a%elems(a%heap_size)
+ a%heap_size = a%heap_size-1
+ call heapify(a,1)
+ return
+ else
+ write (*,*) 'PQ_EXTRACT_MAX: error, attempted to pop non-positive PQ'
+ stop
+ end if
+
+ end subroutine pq_extract_max
+
+
+ real(kdkind) function pq_insert(a,dis,idx)
+ !
+ ! Insert a new element and return the new maximum priority,
+ ! which may or may not be the same as the old maximum priority.
+ !
+ type(pq),pointer :: a
+ real(kdkind), intent(in) :: dis
+ integer, intent(in) :: idx
+ ! type(kdtree2_result), intent(in) :: e
+ !
+ integer :: i, isparent
+ real(kdkind) :: parentdis
+ !
+
+ ! if (a%heap_size .ge. a%max_elems) then
+ ! write (*,*) 'PQ_INSERT: error, attempt made to insert element on full PQ'
+ ! stop
+ ! else
+ a%heap_size = a%heap_size + 1
+ i = a%heap_size
+
+ do while (i .gt. 1)
+ isparent = int(i/2)
+ parentdis = a%elems(isparent)%dis
+ if (dis .gt. parentdis) then
+ ! move what was in i's parent into i.
+ a%elems(i)%dis = parentdis
+ a%elems(i)%idx = a%elems(isparent)%idx
+ i = isparent
+ else
+ exit
+ endif
+ end do
+
+ ! insert the element at the determined position
+ a%elems(i)%dis = dis
+ a%elems(i)%idx = idx
+
+ pq_insert = a%elems(1)%dis
+ return
+ ! end if
+
+ end function pq_insert
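+
+  ! Usage sketch for pq_insert (hypothetical names and values):
+  !
+  !   type(kdtree2_result), target :: storage(100)
+  !   type(pq), target             :: q
+  !   type(pq), pointer            :: qp
+  !   real(kdkind)                 :: maxpri
+  !   q  = pq_create(storage)
+  !   qp => q
+  !   maxpri = pq_insert(qp, 0.25_kdkind, 7)  ! heap holds (0.25); returns 0.25
+  !   maxpri = pq_insert(qp, 0.10_kdkind, 3)  ! max is still 0.25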
+
+ subroutine pq_adjust_heap(a,i)
+ type(pq),pointer :: a
+ integer, intent(in) :: i
+ !
+    ! nominally arguments (a,i), but specialized for i=1
+ !
+ ! This routine assumes that the trees with roots 2 and 3 are already heaps, i.e.
+ ! the children of '1' are heaps. When the procedure is completed, the
+ ! tree rooted at 1 is a heap.
+ real(kdkind) :: prichild
+ integer :: parent, child, N
+
+ type(kdtree2_result) :: e
+
+ e = a%elems(i)
+
+ parent = i
+ child = 2*i
+ N = a%heap_size
+
+ do while (child .le. N)
+ if (child .lt. N) then
+ if (a%elems(child)%dis .lt. a%elems(child+1)%dis) then
+ child = child+1
+ endif
+ endif
+ prichild = a%elems(child)%dis
+ if (e%dis .ge. prichild) then
+ exit
+ else
+ ! move child into parent.
+ a%elems(parent) = a%elems(child)
+ parent = child
+ child = 2*parent
+ end if
+ end do
+ a%elems(parent) = e
+ return
+ end subroutine pq_adjust_heap
+
+
+ real(kdkind) function pq_replace_max(a,dis,idx)
+ !
+ ! Replace the extant maximum priority element
+ ! in the PQ with (dis,idx). Return
+ ! the new maximum priority, which may be larger
+ ! or smaller than the old one.
+ !
+ type(pq),pointer :: a
+ real(kdkind), intent(in) :: dis
+ integer, intent(in) :: idx
+! type(kdtree2_result), intent(in) :: e
+ ! not tested as well!
+
+ integer :: parent, child, N
+ real(kdkind) :: prichild, prichildp1
+
+ type(kdtree2_result) :: etmp
+
+ if (.true.) then
+ N=a%heap_size
+ if (N .ge. 1) then
+ parent =1
+ child=2
+
+ loop: do while (child .le. N)
+ prichild = a%elems(child)%dis
+
+ !
+             ! possibly child+1 has higher priority, and if
+ ! so, get it, and increment child.
+ !
+
+ if (child .lt. N) then
+ prichildp1 = a%elems(child+1)%dis
+ if (prichild .lt. prichildp1) then
+ child = child+1
+ prichild = prichildp1
+ endif
+ endif
+
+ if (dis .ge. prichild) then
+ exit loop
+ ! we have a proper place for our new element,
+ ! bigger than either children's priority.
+ else
+ ! move child into parent.
+ a%elems(parent) = a%elems(child)
+ parent = child
+ child = 2*parent
+ end if
+ end do loop
+ a%elems(parent)%dis = dis
+ a%elems(parent)%idx = idx
+ pq_replace_max = a%elems(1)%dis
+ else
+ a%elems(1)%dis = dis
+ a%elems(1)%idx = idx
+ pq_replace_max = dis
+ endif
+ else
+ !
+ ! slower version using elementary pop and push operations.
+ !
+ call pq_extract_max(a,etmp)
+ etmp%dis = dis
+ etmp%idx = idx
+ pq_replace_max = pq_insert(a,dis,idx)
+ endif
+ return
+ end function pq_replace_max
+
+ subroutine pq_delete(a,i)
+ !
+ ! delete item with index 'i'
+ !
+ type(pq),pointer :: a
+ integer :: i
+
+ if ((i .lt. 1) .or. (i .gt. a%heap_size)) then
+ write (*,*) 'PQ_DELETE: error, attempt to remove out of bounds element.'
+ stop
+ endif
+
+ ! swap the item to be deleted with the last element
+ ! and shorten heap by one.
+ a%elems(i) = a%elems(a%heap_size)
+ a%heap_size = a%heap_size - 1
+
+ call heapify(a,i)
+
+ end subroutine pq_delete
+
+end module kdtree2_priority_queue_module
+
+
+module kdtree2_module
+ use kdtree2_precision_module
+ use kdtree2_priority_queue_module
+ use kdtree2module
+ use tree_nodemodule
+ use intervalmodule
+ ! K-D tree routines in Fortran 90 by Matt Kennel.
+ ! Original program was written in Sather by Steve Omohundro and
+ ! Matt Kennel. Only the Euclidean metric is supported.
+ !
+ !
+ ! This module is identical to 'kd_tree', except that the order
+  ! of subscripts is reversed in the data array.
+  ! In other words, for an embedding of N D-dimensional vectors, the
+  ! data array is stored here in natural Fortran order, data(1:D,1:N),
+ ! because Fortran lays out columns first,
+ !
+ ! whereas conventionally (C-style) it is data(1:N,1:D)
+ ! as in the original kd_tree module.
+ !
+ !-------------DATA TYPE, CREATION, DELETION---------------------
+ public :: kdkind
+ public :: kdtree2, kdtree2_result, tree_node, kdtree2_create, kdtree2_destroy
+ !---------------------------------------------------------------
+ !-------------------SEARCH ROUTINES-----------------------------
+ public :: kdtree2_n_nearest,kdtree2_n_nearest_around_point
+ ! Return fixed number of nearest neighbors around arbitrary vector,
+ ! or extant point in dataset, with decorrelation window.
+ !
+ public :: kdtree2_r_nearest, kdtree2_r_nearest_around_point
+ ! Return points within a fixed ball of arb vector/extant point
+ !
+ public :: kdtree2_sort_results
+  ! Sort, in order of increasing distance, results from above.
+ !
+ public :: kdtree2_r_count, kdtree2_r_count_around_point
+ ! Count points within a fixed ball of arb vector/extant point
+ !
+ public :: kdtree2_n_nearest_brute_force, kdtree2_r_nearest_brute_force
+ ! brute force of kdtree2_[n|r]_nearest
+ !----------------------------------------------------------------
+
+
+ integer, parameter :: bucket_size = 65
+ ! The maximum number of points to keep in a terminal node.
+
+! type interval
+! real(kdkind) :: lower,upper
+! end type interval
+!
+! type :: tree_node
+! ! an internal tree node
+! private
+! integer :: cut_dim
+! ! the dimension to cut
+! real(kdkind) :: cut_val
+! ! where to cut the dimension
+! real(kdkind) :: cut_val_left, cut_val_right
+! ! improved cutoffs knowing the spread in child boxes.
+! integer :: l, u
+! type(tree_node), pointer :: left, right
+! type(interval), pointer :: box(:) => null()
+! ! child pointers
+! ! Points included in this node are indexes[k] with k \in [l,u]
+!
+!
+! end type tree_node
+!
+! type :: kdtree2
+! ! Global information about the tree, one per tree
+! integer :: dimen=0, n=0
+! ! dimensionality and total # of points
+! real(kdkind), pointer :: the_data(:,:) => null()
+! ! pointer to the actual data array
+! !
+! ! IMPORTANT NOTE: IT IS DIMENSIONED the_data(1:d,1:N)
+! ! which may be opposite of what may be conventional.
+! ! This is, because in Fortran, the memory layout is such that
+! ! the first dimension is in sequential order. Hence, with
+! ! (1:d,1:N), all components of the vector will be in consecutive
+! ! memory locations. The search time is dominated by the
+! ! evaluation of distances in the terminal nodes. Putting all
+! ! vector components in consecutive memory location improves
+! ! memory cache locality, and hence search speed, and may enable
+! ! vectorization on some processors and compilers.
+!
+! integer, pointer :: ind(:) => null()
+! ! permuted index into the data, so that indexes[l..u] of some
+! ! bucket represent the indexes of the actual points in that
+! ! bucket.
+! logical :: sort = .false.
+! ! do we always sort output results?
+! logical :: rearrange = .false.
+! real(kdkind), pointer :: rearranged_data(:,:) => null()
+! ! if (rearrange .eqv. .true.) then rearranged_data has been
+! ! created so that rearranged_data(:,i) = the_data(:,ind(i)),
+! ! permitting search to use more cache-friendly rearranged_data, at
+! ! some initial computation and storage cost.
+! type(tree_node), pointer :: root => null()
+! ! root pointer of the tree
+! end type kdtree2
+
+ type :: tree_search_record
+ !
+ ! One of these is created for each search.
+ !
+ private
+ !
+ ! Many fields are copied from the tree structure, in order to
+ ! speed up the search.
+ !
+ integer :: dimen
+ integer :: nn, nfound
+ real(kdkind) :: ballsize
+ integer :: centeridx=999, correltime=9999
+     ! exclude points within 'correltime' of 'centeridx', iff centeridx > 0
+ integer :: nalloc ! how much allocated for results(:)?
+ logical :: rearrange ! are the data rearranged or original?
+ ! did the # of points found overflow the storage provided?
+ logical :: overflow
+ real(kdkind), pointer :: qv(:) ! query vector
+ type(kdtree2_result), pointer :: results(:) ! results
+ type(pq) :: pq
+ real(kdkind), pointer :: data(:,:) ! temp pointer to data
+ integer, pointer :: ind(:) ! temp pointer to indexes
+ end type tree_search_record
+
+ private
+ ! everything else is private.
+
+ type(tree_search_record), save, target :: sr ! A GLOBAL VARIABLE for search
+
+contains
+
+ function kdtree2_create(input_data,dim,sort,rearrange) result (mr)
+ !
+ ! create the actual tree structure, given an input array of data.
+ !
+ ! Note, input data is input_data(1:d,1:N), NOT the other way around.
+ ! THIS IS THE REVERSE OF THE PREVIOUS VERSION OF THIS MODULE.
+ ! The reason for it is cache friendliness, improving performance.
+ !
+ ! Optional arguments: If 'dim' is specified, then the tree
+ ! will only search the first 'dim' components
+ ! of input_data, otherwise, dim is inferred
+ ! from SIZE(input_data,1).
+ !
+ ! if sort .eqv. .true. then output results
+ ! will be sorted by increasing distance.
+ ! default=.false., as it is faster to not sort.
+ !
+ ! if rearrange .eqv. .true. then an internal
+ ! copy of the data, rearranged by terminal node,
+ ! will be made for cache friendliness.
+ ! default=.true., as it speeds searches, but
+ ! building takes longer, and extra memory is used.
+ !
+ ! .. Function Return Cut_value ..
+ type(kdtree2), pointer :: mr
+ integer, intent(in), optional :: dim
+ logical, intent(in), optional :: sort
+ logical, intent(in), optional :: rearrange
+ ! ..
+ ! .. Array Arguments ..
+ real(kdkind), target :: input_data(:,:)
+ !
+ integer :: i
+ ! ..
+ allocate (mr)
+ mr%the_data => input_data
+ ! pointer assignment
+
+ if (present(dim)) then
+ mr%dimen = dim
+ else
+ mr%dimen = size(input_data,1)
+ end if
+ mr%n = size(input_data,2)
+
+ if (mr%dimen > mr%n) then
+ ! unlikely to be correct
+ write (*,*) 'KD_TREE_TRANS: likely user error.'
+ write (*,*) 'KD_TREE_TRANS: You passed in matrix with D=',mr%dimen
+ write (*,*) 'KD_TREE_TRANS: and N=',mr%n
+ write (*,*) 'KD_TREE_TRANS: note, that new format is data(1:D,1:N)'
+ write (*,*) 'KD_TREE_TRANS: with usually N >> D. If N =approx= D, then a k-d tree'
+ write (*,*) 'KD_TREE_TRANS: is not an appropriate data structure.'
+ stop
+ end if
+
+ call build_tree(mr)
+
+ if (present(sort)) then
+ mr%sort = sort
+ else
+ mr%sort = .false.
+ endif
+
+ if (present(rearrange)) then
+ mr%rearrange = rearrange
+ else
+ mr%rearrange = .true.
+ endif
+
+ if (mr%rearrange) then
+ allocate(mr%rearranged_data(mr%dimen,mr%n))
+ do i=1,mr%n
+ mr%rearranged_data(:,i) = mr%the_data(:, &
+ mr%ind(i))
+ enddo
+ else
+ nullify(mr%rearranged_data)
+ endif
+
+ end function kdtree2_create
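+
+  ! Usage sketch (hypothetical data): build a tree on N 3-d points
+  ! stored column-wise, data(1:3,1:N), as this module requires:
+  !
+  !   real(kdkind), allocatable, target :: pts(:,:)
+  !   type(kdtree2), pointer            :: tree
+  !   allocate(pts(3,10000))
+  !   call random_number(pts)
+  !   tree => kdtree2_create(pts, sort=.false., rearrange=.true.)
+  !   ! ... queries ...
+  !   call kdtree2_destroy(tree)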
+
+ subroutine build_tree(tp)
+ type(kdtree2), pointer :: tp
+ ! ..
+ integer :: j
+ type(tree_node), pointer :: dummy => null()
+ ! ..
+ allocate (tp%ind(tp%n))
+ forall (j=1:tp%n)
+ tp%ind(j) = j
+ end forall
+ tp%root => build_tree_for_range(tp,1,tp%n, dummy)
+ end subroutine build_tree
+
+ recursive function build_tree_for_range(tp,l,u,parent) result (res)
+ ! .. Function Return Cut_value ..
+ type(tree_node), pointer :: res
+ ! ..
+ ! .. Structure Arguments ..
+ type(kdtree2), pointer :: tp
+ type(tree_node),pointer :: parent
+ ! ..
+ ! .. Scalar Arguments ..
+ integer, intent (In) :: l, u
+ ! ..
+ ! .. Local Scalars ..
+ integer :: i, c, m, dimen
+ logical :: recompute
+ real(kdkind) :: average
+
+!!$ If (.False.) Then
+!!$ If ((l .Lt. 1) .Or. (l .Gt. tp%n)) Then
+!!$ Stop 'illegal L value in build_tree_for_range'
+!!$ End If
+!!$ If ((u .Lt. 1) .Or. (u .Gt. tp%n)) Then
+!!$ Stop 'illegal u value in build_tree_for_range'
+!!$ End If
+!!$ If (u .Lt. l) Then
+!!$ Stop 'U is less than L, thats illegal.'
+!!$ End If
+!!$ Endif
+!!$
+    dimen = tp%dimen
+    if ( u < l ) then
+       ! no points in this box; return a null node.
+       ! (Checking before allocating avoids leaking 'res'.)
+       nullify(res)
+       return
+    end if
+    allocate (res)
+    allocate(res%box(dimen))
+
+    ! First, compute an APPROXIMATE bounding box of all points associated with this node.
+
+ if ((u-l)<=bucket_size) then
+ !
+ ! always compute true bounding box for terminal nodes.
+ !
+ do i=1,dimen
+ call spread_in_coordinate(tp,i,l,u,res%box(i))
+ end do
+ res%cut_dim = 0
+ res%cut_val = 0.0
+ res%l = l
+ res%u = u
+ res%left =>null()
+ res%right => null()
+ else
+ !
+ ! modify approximate bounding box. This will be an
+ ! overestimate of the true bounding box, as we are only recomputing
+ ! the bounding box for the dimension that the parent split on.
+ !
+ ! Going to a true bounding box computation would significantly
+ ! increase the time necessary to build the tree, and usually
+ ! has only a very small difference. This box is not used
+ ! for searching but only for deciding which coordinate to split on.
+ !
+ do i=1,dimen
+ recompute=.true.
+ if (associated(parent)) then
+ if (i .ne. parent%cut_dim) then
+ recompute=.false.
+ end if
+ endif
+ if (recompute) then
+ call spread_in_coordinate(tp,i,l,u,res%box(i))
+ else
+ res%box(i) = parent%box(i)
+ endif
+ end do
+
+
+ c = maxloc(res%box(1:dimen)%upper-res%box(1:dimen)%lower,1)
+ !
+ ! c is the identity of which coordinate has the greatest spread.
+ !
+
+ if (.false.) then
+ ! select exact median to have fully balanced tree.
+ m = (l+u)/2
+ call select_on_coordinate(tp%the_data,tp%ind,c,m,l,u)
+ else
+ !
+ ! select point halfway between min and max, as per A. Moore,
+ ! who says this helps in some degenerate cases, or
+ ! actual arithmetic average.
+ !
+ if (.true.) then
+ ! actually compute average
+ average = sum(tp%the_data(c,tp%ind(l:u))) / real(u-l+1,kdkind)
+ else
+ average = (res%box(c)%upper + res%box(c)%lower)/2.0
+ endif
+
+ res%cut_val = average
+ m = select_on_coordinate_value(tp%the_data,tp%ind,c,average,l,u)
+ endif
+
+ ! moves indexes around
+ res%cut_dim = c
+ res%l = l
+ res%u = u
+! res%cut_val = tp%the_data(c,tp%ind(m))
+
+ res%left => build_tree_for_range(tp,l,m,res)
+ res%right => build_tree_for_range(tp,m+1,u,res)
+
+ if (associated(res%right) .eqv. .false.) then
+ res%box = res%left%box
+ res%cut_val_left = res%left%box(c)%upper
+ res%cut_val = res%cut_val_left
+ elseif (associated(res%left) .eqv. .false.) then
+ res%box = res%right%box
+ res%cut_val_right = res%right%box(c)%lower
+ res%cut_val = res%cut_val_right
+ else
+ res%cut_val_right = res%right%box(c)%lower
+ res%cut_val_left = res%left%box(c)%upper
+ res%cut_val = (res%cut_val_left + res%cut_val_right)/2
+
+
+ ! now remake the true bounding box for self.
+ ! Since we are taking unions (in effect) of a tree structure,
+ ! this is much faster than doing an exhaustive
+ ! search over all points
+ res%box%upper = max(res%left%box%upper,res%right%box%upper)
+ res%box%lower = min(res%left%box%lower,res%right%box%lower)
+ endif
+ end if
+ end function build_tree_for_range
+
+ integer function select_on_coordinate_value(v,ind,c,alpha,li,ui) &
+ result(res)
+    ! Move elements of ind around between li and ui, so that all points
+    ! <= alpha (in coordinate c) come first, and all points
+    ! > alpha come after.
+
+ !
+ ! Algorithm (matt kennel).
+ !
+ ! Consider the list as having three parts: on the left,
+ ! the points known to be <= alpha. On the right, the points
+ ! known to be > alpha, and in the middle, the currently unknown
+ ! points. The algorithm is to scan the unknown points, starting
+ ! from the left, and swapping them so that they are added to
+ ! the left stack or the right stack, as appropriate.
+ !
+ ! The algorithm finishes when the unknown stack is empty.
+ !
+ ! .. Scalar Arguments ..
+ integer, intent (In) :: c, li, ui
+ real(kdkind), intent(in) :: alpha
+ ! ..
+ real(kdkind) :: v(1:,1:)
+ integer :: ind(1:)
+ integer :: tmp
+ ! ..
+ integer :: lb, rb
+ !
+ ! The points known to be <= alpha are in
+ ! [l,lb-1]
+ !
+ ! The points known to be > alpha are in
+ ! [rb+1,u].
+ !
+ ! Therefore we add new points into lb or
+ ! rb as appropriate. When lb=rb
+ ! we are done. We return the location of the last point <= alpha.
+ !
+ !
+ lb = li; rb = ui
+
+ do while (lb < rb)
+ if ( v(c,ind(lb)) <= alpha ) then
+ ! it is good where it is.
+ lb = lb+1
+ else
+ ! swap it with rb.
+ tmp = ind(lb); ind(lb) = ind(rb); ind(rb) = tmp
+ rb = rb-1
+ endif
+ end do
+
+    ! now lb .eq. rb
+ if (v(c,ind(lb)) <= alpha) then
+ res = lb
+ else
+ res = lb-1
+ endif
+
+ end function select_on_coordinate_value
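+
+  ! Worked example (hypothetical values): with c-coordinates
+  ! (0.9, 0.2, 0.7, 0.1) in ind order and alpha = 0.5, the scan swaps
+  ! the first point to the right end, leaving (0.1, 0.2 | 0.7, 0.9),
+  ! and returns res = 2, the position of the last point <= alpha.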
+
+ subroutine select_on_coordinate(v,ind,c,k,li,ui)
+    ! Move elements of ind around between li and ui, so that the kth
+    ! element is >= those below and <= those above, in coordinate c.
+ ! .. Scalar Arguments ..
+ integer, intent (In) :: c, k, li, ui
+ ! ..
+ integer :: i, l, m, s, t, u
+ ! ..
+ real(kdkind) :: v(:,:)
+ integer :: ind(:)
+ ! ..
+ l = li
+ u = ui
+ do while (l<u)
+ t = ind(l)
+ m = l
+ do i = l + 1, u
+ if (v(c,ind(i))<v(c,t)) then
+ m = m + 1
+ s = ind(m)
+ ind(m) = ind(i)
+ ind(i) = s
+ end if
+ end do
+ s = ind(l)
+ ind(l) = ind(m)
+ ind(m) = s
+ if (m<=k) l = m + 1
+ if (m>=k) u = m - 1
+ end do
+ end subroutine select_on_coordinate
+
+ subroutine spread_in_coordinate(tp,c,l,u,interv)
+    ! Compute the spread in coordinate 'c' of points l through u.
+    !
+    ! Return the lower bound in interv%lower and the upper bound
+    ! in interv%upper.
+    ! ..
+ ! .. Structure Arguments ..
+ type(kdtree2), pointer :: tp
+ type(interval), intent(out) :: interv
+ ! ..
+ ! .. Scalar Arguments ..
+ integer, intent (In) :: c, l, u
+ ! ..
+ ! .. Local Scalars ..
+ real(kdkind) :: last, lmax, lmin, t, smin,smax
+ integer :: i, ulocal
+ ! ..
+ ! .. Local Arrays ..
+ real(kdkind), pointer :: v(:,:)
+ integer, pointer :: ind(:)
+ ! ..
+ v => tp%the_data(1:,1:)
+ ind => tp%ind(1:)
+ smin = v(c,ind(l))
+ smax = smin
+
+ ulocal = u
+
+ do i = l + 2, ulocal, 2
+ lmin = v(c,ind(i-1))
+ lmax = v(c,ind(i))
+ if (lmin>lmax) then
+ t = lmin
+ lmin = lmax
+ lmax = t
+ end if
+ if (smin>lmin) smin = lmin
+ if (smax<lmax) smax = lmax
+ end do
+ if (i==ulocal+1) then
+ last = v(c,ind(ulocal))
+ if (smin>last) smin = last
+ if (smax<last) smax = last
+ end if
+
+ interv%lower = smin
+ interv%upper = smax
+
+ end subroutine spread_in_coordinate
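+
+  ! Note on the loop above: points are examined in pairs, ordering
+  ! each pair internally (1 comparison) before testing against the
+  ! running min and max (2 comparisons) -- about 3 comparisons per 2
+  ! points instead of 4, the classic simultaneous min/max trick.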
+
+
+ subroutine kdtree2_destroy(tp)
+ ! Deallocates all memory for the tree, except input data matrix
+ ! .. Structure Arguments ..
+ type(kdtree2), pointer :: tp
+ ! ..
+ call destroy_node(tp%root)
+
+ deallocate (tp%ind)
+ nullify (tp%ind)
+
+ if (tp%rearrange) then
+ deallocate(tp%rearranged_data)
+ nullify(tp%rearranged_data)
+ endif
+
+ deallocate(tp)
+ return
+
+ contains
+ recursive subroutine destroy_node(np)
+ ! .. Structure Arguments ..
+ type(tree_node), pointer :: np
+ ! ..
+ ! .. Intrinsic Functions ..
+ intrinsic ASSOCIATED
+ ! ..
+ if (associated(np%left)) then
+ call destroy_node(np%left)
+ nullify (np%left)
+ end if
+ if (associated(np%right)) then
+ call destroy_node(np%right)
+ nullify (np%right)
+ end if
+ if (associated(np%box)) deallocate(np%box)
+ deallocate(np)
+ return
+
+ end subroutine destroy_node
+
+ end subroutine kdtree2_destroy
+
+ subroutine kdtree2_n_nearest(tp,qv,nn,results)
+ ! Find the 'nn' vectors in the tree nearest to 'qv' in euclidean norm
+ ! returning their indexes and distances in 'indexes' and 'distances'
+ ! arrays already allocated passed to this subroutine.
+ type(kdtree2), pointer :: tp
+ real(kdkind), target, intent (In) :: qv(:)
+ integer, intent (In) :: nn
+ type(kdtree2_result), target :: results(:)
+
+
+ sr%ballsize = huge(1.0)
+ sr%qv => qv
+ sr%nn = nn
+ sr%nfound = 0
+ sr%centeridx = -1
+ sr%correltime = 0
+ sr%overflow = .false.
+
+ sr%results => results
+
+ sr%nalloc = nn ! will be checked
+
+ sr%ind => tp%ind
+ sr%rearrange = tp%rearrange
+ if (tp%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+ sr%dimen = tp%dimen
+
+ call validate_query_storage(nn)
+ sr%pq = pq_create(results)
+
+ call search(tp%root)
+
+ if (tp%sort) then
+ call kdtree2_sort_results(nn, results)
+ endif
+! deallocate(sr%pqp)
+ return
+ end subroutine kdtree2_n_nearest
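+
+  ! Usage sketch (hypothetical names), given a tree from kdtree2_create:
+  !
+  !   type(kdtree2_result), allocatable, target :: results(:)
+  !   real(kdkind), target :: qv(3)
+  !   allocate(results(10))
+  !   qv = (/ 0.5_kdkind, 0.5_kdkind, 0.5_kdkind /)
+  !   call kdtree2_n_nearest(tree, qv, 10, results)
+  !   ! results(:)%idx and results(:)%dis now hold neighbor indexes and
+  !   ! squared distances (sorted only if sort=.true. at creation).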
+
+ subroutine kdtree2_n_nearest_around_point(tp,idxin,correltime,nn,results)
+ ! Find the 'nn' vectors in the tree nearest to point 'idxin',
+    ! with correlation window 'correltime', returning results in
+ ! results(:), which must be pre-allocated upon entry.
+ type(kdtree2), pointer :: tp
+ integer, intent (In) :: idxin, correltime, nn
+ type(kdtree2_result), target :: results(:)
+
+ allocate (sr%qv(tp%dimen))
+ sr%qv = tp%the_data(:,idxin) ! copy the vector
+ sr%ballsize = huge(1.0) ! the largest real(kdkind) number
+ sr%centeridx = idxin
+ sr%correltime = correltime
+
+ sr%nn = nn
+ sr%nfound = 0
+
+ sr%dimen = tp%dimen
+ sr%nalloc = nn
+
+ sr%results => results
+
+ sr%ind => tp%ind
+ sr%rearrange = tp%rearrange
+
+ if (sr%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+
+ call validate_query_storage(nn)
+ sr%pq = pq_create(results)
+
+ call search(tp%root)
+
+ if (tp%sort) then
+ call kdtree2_sort_results(nn, results)
+ endif
+ deallocate (sr%qv)
+ return
+ end subroutine kdtree2_n_nearest_around_point
+
+ subroutine kdtree2_r_nearest(tp,qv,r2,nfound,nalloc,results)
+ ! find the nearest neighbors to point 'qv', within SQUARED
+ ! Euclidean distance 'r2'. Upon ENTRY, nalloc must be the
+ ! size of memory allocated for results(1:nalloc). Upon
+ ! EXIT, nfound is the number actually found within the ball.
+ !
+    ! Note that if nfound .gt. nalloc then more neighbors were found
+    ! than there was storage for.  In that case the returned list
+    ! is NOT guaranteed to hold the closest points within r^2.
+ !
+ ! Results are NOT sorted unless tree was created with sort option.
+ type(kdtree2), pointer :: tp
+ real(kdkind), target, intent (In) :: qv(:)
+ real(kdkind), intent(in) :: r2
+ integer, intent(out) :: nfound
+ integer, intent (In) :: nalloc
+ type(kdtree2_result), target :: results(:)
+
+ !
+ sr%qv => qv
+ sr%ballsize = r2
+ sr%nn = 0 ! flag for fixed ball search
+ sr%nfound = 0
+ sr%centeridx = -1
+ sr%correltime = 0
+
+ sr%results => results
+
+ call validate_query_storage(nalloc)
+ sr%nalloc = nalloc
+ sr%overflow = .false.
+ sr%ind => tp%ind
+ sr%rearrange= tp%rearrange
+
+ if (tp%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+ sr%dimen = tp%dimen
+
+ !
+ !sr%dsl = Huge(sr%dsl) ! set to huge positive values
+ !sr%il = -1 ! set to invalid indexes
+ !
+
+ call search(tp%root)
+ nfound = sr%nfound
+ if (tp%sort) then
+ call kdtree2_sort_results(nfound, results)
+ endif
+
+ if (sr%overflow) then
+ write (*,*) 'KD_TREE_TRANS: warning! return from kdtree2_r_nearest found more neighbors'
+ write (*,*) 'KD_TREE_TRANS: than storage was provided for. Answer is NOT smallest ball'
+ write (*,*) 'KD_TREE_TRANS: with that number of neighbors! I.e. it is wrong.'
+ endif
+
+ return
+ end subroutine kdtree2_r_nearest
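+
+  ! Usage sketch: fixed-ball search with SQUARED radius r2.  If more
+  ! neighbors exist than results(:) can hold, nfound exceeds nalloc
+  ! and the warning above is printed.
+  !
+  !   integer :: nfound
+  !   call kdtree2_r_nearest(tree, qv, 0.01_kdkind, nfound, &
+  !                          size(results), results)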
+
+ subroutine kdtree2_r_nearest_around_point(tp,idxin,correltime,r2,&
+ nfound,nalloc,results)
+ !
+ ! Like kdtree2_r_nearest, but around a point 'idxin' already existing
+ ! in the data set.
+ !
+ ! Results are NOT sorted unless tree was created with sort option.
+ !
+ type(kdtree2), pointer :: tp
+ integer, intent (In) :: idxin, correltime, nalloc
+ real(kdkind), intent(in) :: r2
+ integer, intent(out) :: nfound
+ type(kdtree2_result), target :: results(:)
+ ! ..
+ ! .. Intrinsic Functions ..
+ intrinsic HUGE
+ ! ..
+ allocate (sr%qv(tp%dimen))
+ sr%qv = tp%the_data(:,idxin) ! copy the vector
+ sr%ballsize = r2
+ sr%nn = 0 ! flag for fixed r search
+ sr%nfound = 0
+ sr%centeridx = idxin
+ sr%correltime = correltime
+
+ sr%results => results
+
+ sr%nalloc = nalloc
+ sr%overflow = .false.
+
+ call validate_query_storage(nalloc)
+
+ ! sr%dsl = HUGE(sr%dsl) ! set to huge positive values
+ ! sr%il = -1 ! set to invalid indexes
+
+ sr%ind => tp%ind
+ sr%rearrange = tp%rearrange
+
+ if (tp%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+ sr%rearrange = tp%rearrange
+ sr%dimen = tp%dimen
+
+ !
+ !sr%dsl = Huge(sr%dsl) ! set to huge positive values
+ !sr%il = -1 ! set to invalid indexes
+ !
+
+ call search(tp%root)
+ nfound = sr%nfound
+ if (tp%sort) then
+ call kdtree2_sort_results(nfound,results)
+ endif
+
+ if (sr%overflow) then
+ write (*,*) 'KD_TREE_TRANS: warning! return from kdtree2_r_nearest found more neighbors'
+ write (*,*) 'KD_TREE_TRANS: than storage was provided for. Answer is NOT smallest ball'
+ write (*,*) 'KD_TREE_TRANS: with that number of neighbors! I.e. it is wrong.'
+ endif
+
+ deallocate (sr%qv)
+ return
+ end subroutine kdtree2_r_nearest_around_point
+
+ function kdtree2_r_count(tp,qv,r2) result(nfound)
+ ! Count the number of neighbors within square distance 'r2'.
+ type(kdtree2), pointer :: tp
+ real(kdkind), target, intent (In) :: qv(:)
+ real(kdkind), intent(in) :: r2
+ integer :: nfound
+ ! ..
+ ! .. Intrinsic Functions ..
+ intrinsic HUGE
+ ! ..
+ sr%qv => qv
+ sr%ballsize = r2
+
+ sr%nn = 0 ! flag for fixed r search
+ sr%nfound = 0
+ sr%centeridx = -1
+ sr%correltime = 0
+
+ nullify(sr%results) ! for some reason, FTN 95 chokes on '=> null()'
+
+ sr%nalloc = 0 ! we do not allocate any storage but that's OK
+ ! for counting.
+ sr%ind => tp%ind
+ sr%rearrange = tp%rearrange
+ if (tp%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+ sr%dimen = tp%dimen
+
+ !
+ !sr%dsl = Huge(sr%dsl) ! set to huge positive values
+ !sr%il = -1 ! set to invalid indexes
+ !
+ sr%overflow = .false.
+
+ call search(tp%root)
+
+ nfound = sr%nfound
+
+ return
+ end function kdtree2_r_count
+
+ function kdtree2_r_count_around_point(tp,idxin,correltime,r2) &
+ result(nfound)
+ ! Count the number of neighbors within square distance 'r2' around
+ ! point 'idxin' with decorrelation time 'correltime'.
+ !
+ type(kdtree2), pointer :: tp
+ integer, intent (In) :: correltime, idxin
+ real(kdkind), intent(in) :: r2
+ integer :: nfound
+ ! ..
+ ! ..
+ ! .. Intrinsic Functions ..
+ intrinsic HUGE
+ ! ..
+ allocate (sr%qv(tp%dimen))
+ sr%qv = tp%the_data(:,idxin)
+ sr%ballsize = r2
+
+ sr%nn = 0 ! flag for fixed r search
+ sr%nfound = 0
+ sr%centeridx = idxin
+ sr%correltime = correltime
+ nullify(sr%results)
+
+ sr%nalloc = 0 ! we do not allocate any storage but that's OK
+ ! for counting.
+
+ sr%ind => tp%ind
+ sr%rearrange = tp%rearrange
+
+ if (sr%rearrange) then
+ sr%Data => tp%rearranged_data
+ else
+ sr%Data => tp%the_data
+ endif
+ sr%dimen = tp%dimen
+
+ !
+ !sr%dsl = Huge(sr%dsl) ! set to huge positive values
+ !sr%il = -1 ! set to invalid indexes
+ !
+ sr%overflow = .false.
+
+ call search(tp%root)
+
+ nfound = sr%nfound
+
+ return
+ end function kdtree2_r_count_around_point
+
+
+ subroutine validate_query_storage(n)
+ !
+ ! make sure we have enough storage for n
+ !
+ integer, intent(in) :: n
+
+ if (size(sr%results,1) .lt. n) then
+ write (*,*) 'KD_TREE_TRANS: you did not provide enough storage for results(1:n)'
+ stop
+ return
+ endif
+
+ return
+ end subroutine validate_query_storage
+
+ function square_distance(d, iv,qv) result (res)
+ ! distance between iv[1:n] and qv[1:n]
+ ! .. Function Return Value ..
+ ! re-implemented to improve vectorization.
+ real(kdkind) :: res
+ ! ..
+ ! ..
+ ! .. Scalar Arguments ..
+ integer :: d
+    real(kdkind) :: d_min(d)
+ ! ..
+ ! .. Array Arguments ..
+ real(kdkind) :: iv(:),qv(:)
+ ! ..
+ ! ..
+    ! .. Periodicity added by S Skory (assumes a unit periodic domain)
+ ! res = sum( (iv(1:d)-qv(1:d))**2 )
+ d_min = min( abs(iv(1:d) - qv(1:d)) , 1. - abs(iv(1:d) - qv(1:d)) )
+ res = sum(d_min(1:d)**2)
+ end function square_distance
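+
+  ! Worked example of the minimum-image distance above, assuming the
+  ! unit periodic box this version hard-codes: for iv(1) = 0.95 and
+  ! qv(1) = 0.05, the direct gap is 0.90 but the wrapped gap is
+  ! 1 - 0.90 = 0.10, so this component contributes 0.10**2 to res.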
+
+ recursive subroutine search(node)
+ !
+ ! This is the innermost core routine of the kd-tree search. Along
+ ! with "process_terminal_node", it is the performance bottleneck.
+ !
+    ! This version uses a logically complete secondary search of
+    ! "box in bounds", i.e. whether the search ball can still
+    ! intersect the farther node's bounding box, before descending.
+ !
+ type(Tree_node), pointer :: node
+ ! ..
+ type(tree_node),pointer :: ncloser, nfarther
+ !
+ integer :: cut_dim, i
+ ! ..
+ real(kdkind) :: qval, dis, dis_right, dis_left, dis_node
+ real(kdkind) :: ballsize
+ real(kdkind), pointer :: qv(:)
+ type(interval), pointer :: box(:)
+
+ if ((associated(node%left) .and. associated(node%right)) .eqv. .false.) then
+ ! we are on a terminal node
+ if (sr%nn .eq. 0) then
+ call process_terminal_node_fixedball(node)
+ else
+ call process_terminal_node(node)
+ endif
+ else
+ ! we are not on a terminal node
+ qv => sr%qv(1:)
+ cut_dim = node%cut_dim
+ qval = qv(cut_dim)
+
+
+       ! Periodic stuff added by S Skory.  We have to test which node
+       ! edge is closer, rather than doing a simple less-than/greater-
+       ! than comparison against the cut in this dimension.
+ dis_left = min( (node%box(cut_dim)%lower - qval)**2, &
+ (1 - abs(node%box(cut_dim)%lower - qval))**2)
+ dis_right = min( (node%box(cut_dim)%upper - qval)**2, &
+ (1 - abs(node%box(cut_dim)%upper - qval))**2)
+
+
+ if (qval < node%cut_val) then
+ !if (dis_left <= dis_right) then
+ ncloser => node%left
+ nfarther => node%right
+ dis_node = (node%cut_val_right - qval)**2
+ dis = min(dis_right, dis_node)
+! extra = node%cut_val - qval
+ else
+ ncloser => node%right
+ nfarther => node%left
+ dis_node = (node%cut_val_left - qval)**2
+ dis = min(dis_left, dis_node)
+! extra = qval- node%cut_val_left
+ endif
+
+ if (associated(ncloser)) call search(ncloser)
+
+ ! we may need to search the second node.
+ if (associated(nfarther)) then
+ ballsize = sr%ballsize
+! dis=extra**2
+
+ if (dis <= ballsize) then
+ !
+ ! we do this separately as going on the first cut dimen is often
+ ! a good idea.
+ ! note that if extra**2 < sr%ballsize, then the next
+ ! check will also be false.
+ !
+ box => node%box(1:)
+ do i=1,sr%dimen
+ if (i .ne. cut_dim) then
+ dis = dis + dis2_from_bnd(qv(i),box(i)%lower,box(i)%upper)
+ if (dis > ballsize) then
+ return
+ endif
+ endif
+ end do
+
+ !
+             ! if we are still here then we need to search more.
+ !
+ call search(nfarther)
+ endif
+ endif
+ end if
+ end subroutine search
+
+
+! real(kdkind) function dis2_from_bnd(x,amin,amax) result (res)
+! real(kdkind), intent(in) :: x, amin,amax
+!
+! if (x > amax) then
+! res = (x-amax)**2;
+! return
+! else
+! if (x < amin) then
+! res = (amin-x)**2;
+! return
+! else
+! res = 0.0
+! return
+! endif
+! endif
+! return
+! end function dis2_from_bnd
+
+ real(kdkind) function dis2_from_bnd(x,amin,amax) result (res)
+ ! Periodicity added by S Skory
+ real(kdkind), intent(in) :: x, amin,amax
+    real(kdkind) :: dxmax, dxmin
+
+ if ((x < amax) .and. (x > amin)) then
+ res = 0.0
+ return
+ else
+ dxmax = (min( abs(x - amax), 1. - abs(x - amax)))**2
+ dxmin = (min( abs(x - amin), 1. - abs(x - amin)))**2
+ res = min(dxmax, dxmin)
+ return
+ endif
+ return
+ end function dis2_from_bnd
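+
+  ! Example (hypothetical values): for x = 0.02 outside a box with
+  ! amin = 0.90 and amax = 0.95 in the unit periodic domain, the
+  ! wrapped gaps to the two faces are 0.12 and 0.07, so
+  ! res = min(0.12, 0.07)**2 = 0.0049.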
+
+ logical function box_in_search_range(node, sr) result(res)
+ !
+    ! Return .true. iff the minimum squared distance from 'qv' to the
+    ! node's bounding box is within the search ballsize.  Coordinates
+    ! inside the box contribute nothing to the distance.
+ !
+ type(tree_node), pointer :: node
+ type(tree_search_record), pointer :: sr
+
+ integer :: dimen, i
+ real(kdkind) :: dis, ballsize
+ real(kdkind) :: l, u
+
+ dimen = sr%dimen
+ ballsize = sr%ballsize
+ dis = 0.0
+ res = .true.
+ do i=1,dimen
+ l = node%box(i)%lower
+ u = node%box(i)%upper
+ dis = dis + (dis2_from_bnd(sr%qv(i),l,u))
+ if (dis > ballsize) then
+ res = .false.
+ return
+ endif
+ end do
+ res = .true.
+ return
+ end function box_in_search_range
+
+
+ subroutine process_terminal_node(node)
+ !
+ ! Look for actual near neighbors in 'node', and update
+ ! the search results on the sr data structure.
+ !
+ type(tree_node), pointer :: node
+ !
+ real(kdkind), pointer :: qv(:)
+ integer, pointer :: ind(:)
+ real(kdkind), pointer :: data(:,:)
+ !
+ integer :: dimen, i, indexofi, k, centeridx, correltime
+ real(kdkind) :: ballsize, sd, newpri
+ logical :: rearrange
+ type(pq), pointer :: pqp
+    real(kdkind) :: sdtemp
+ !
+ ! copy values from sr to local variables
+ !
+ !
+ ! Notice, making local pointers with an EXPLICIT lower bound
+ ! seems to generate faster code.
+ ! why? I don't know.
+ qv => sr%qv(1:)
+ pqp => sr%pq
+ dimen = sr%dimen
+ ballsize = sr%ballsize
+ rearrange = sr%rearrange
+ ind => sr%ind(1:)
+ data => sr%Data(1:,1:)
+ centeridx = sr%centeridx
+ correltime = sr%correltime
+
+ ! doing_correl = (centeridx >= 0) ! Do we have a decorrelation window?
+ ! include_point = .true. ! by default include all points
+ ! search through terminal bucket.
+
+ mainloop: do i = node%l, node%u
+ if (rearrange) then
+ sd = 0.0
+ do k = 1,dimen
+ !sd = sd + (data(k,i) - qv(k))**2
+ ! Periodicity by S Skory
+ sdtemp = min( abs(data(k,i) - qv(k)) , 1. - abs(data(k,i) - qv(k)))
+ sd = sd + sdtemp**2
+ if (sd>ballsize) cycle mainloop
+ end do
+ indexofi = ind(i) ! only read it if we have not broken out
+ else
+ indexofi = ind(i)
+ sd = 0.0
+ do k = 1,dimen
+ !sd = sd + (data(k,indexofi) - qv(k))**2
+ sdtemp = min( abs(data(k,indexofi) - qv(k)), 1. - abs(data(k,indexofi) - qv(k)))
+ sd = sd + sdtemp**2
+ if (sd>ballsize) cycle mainloop
+ end do
+ endif
+
+ if (centeridx > 0) then ! doing correlation interval?
+ if (abs(indexofi-centeridx) < correltime) cycle mainloop
+ endif
+
+
+ !
+ ! two choices for any point. The list so far is either undersized,
+ ! or it is not.
+ !
+ ! If it is undersized, then add the point and its distance
+ ! unconditionally. If the point added fills up the working
+ ! list then set the sr%ballsize, maximum distance bound (largest distance on
+ ! list) to be that distance, instead of the initialized +infinity.
+ !
+ ! If the running list is full size, then compute the
+ ! distance but break out immediately if it is larger
+ ! than sr%ballsize, "best squared distance" (of the largest element),
+ ! as it cannot be a good neighbor.
+ !
+ ! Once computed, compare to best_square distance.
+ ! if it is smaller, then delete the previous largest
+ ! element and add the new one.
+
+ if (sr%nfound .lt. sr%nn) then
+ !
+ ! add this point unconditionally to fill list.
+ !
+ sr%nfound = sr%nfound +1
+ newpri = pq_insert(pqp,sd,indexofi)
+ if (sr%nfound .eq. sr%nn) ballsize = newpri
+ ! we have just filled the working list.
+ ! put the best square distance to the maximum value
+ ! on the list, which is extractable from the PQ.
+
+ else
+ !
+ ! now, if we get here,
+ ! we know that the current node has a squared
+ ! distance smaller than the largest one on the list, and
+ ! belongs on the list.
+ ! Hence we replace that with the current one.
+ !
+ ballsize = pq_replace_max(pqp,sd,indexofi)
+ endif
+ end do mainloop
+ !
+ ! Reset sr variables which may have changed during loop
+ !
+ sr%ballsize = ballsize
+
+ end subroutine process_terminal_node
+
+ subroutine process_terminal_node_fixedball(node)
+ !
+ ! Look for actual near neighbors in 'node', and update
+ ! the search results on the sr data structure, i.e.
+ ! save all within a fixed ball.
+ !
+ type(tree_node), pointer :: node
+ !
+ real(kdkind), pointer :: qv(:)
+ integer, pointer :: ind(:)
+ real(kdkind), pointer :: data(:,:)
+ !
+ integer :: nfound
+ integer :: dimen, i, indexofi, k
+ integer :: centeridx, correltime, nn
+ real(kdkind) :: ballsize, sd
+ logical :: rearrange
+    real(kdkind) :: sdtemp
+ !
+ ! copy values from sr to local variables
+ !
+ qv => sr%qv(1:)
+ dimen = sr%dimen
+ ballsize = sr%ballsize
+ rearrange = sr%rearrange
+ ind => sr%ind(1:)
+ data => sr%Data(1:,1:)
+ centeridx = sr%centeridx
+ correltime = sr%correltime
+ nn = sr%nn ! number to search for
+ nfound = sr%nfound
+
+ ! search through terminal bucket.
+ mainloop: do i = node%l, node%u
+
+ !
+ ! two choices for any point. The list so far is either undersized,
+ ! or it is not.
+ !
+ ! If it is undersized, then add the point and its distance
+ ! unconditionally. If the point added fills up the working
+ ! list then set the sr%ballsize, maximum distance bound (largest distance on
+ ! list) to be that distance, instead of the initialized +infinity.
+ !
+ ! If the running list is full size, then compute the
+ ! distance but break out immediately if it is larger
+ ! than sr%ballsize, "best squared distance" (of the largest element),
+ ! as it cannot be a good neighbor.
+ !
+ ! Once computed, compare to best_square distance.
+ ! if it is smaller, then delete the previous largest
+ ! element and add the new one.
+
+ ! which index to the point do we use?
+
+ if (rearrange) then
+ sd = 0.0
+ do k = 1,dimen
+ !sd = sd + (data(k,i) - qv(k))**2
+ ! Periodicity S Skory
+ sdtemp = min( abs(data(k,i) - qv(k)) , 1. - abs(data(k,i) - qv(k)))
+ ! print *, "k ", k, " data(k,i) ", data(k,i), " qv(k) ", qv(k), " sdtemp ", sdtemp
+ sd = sd + sdtemp**2
+ if (sd>ballsize) cycle mainloop
+ end do
+ indexofi = ind(i) ! only read it if we have not broken out
+ else
+ indexofi = ind(i)
+ sd = 0.0
+ do k = 1,dimen
+ !sd = sd + (data(k,indexofi) - qv(k))**2
+ ! Periodicity S Skory
+ sdtemp = min( abs(data(k,indexofi) - qv(k)), 1. - abs(data(k,indexofi) - qv(k)))
+ ! print *, "k ", k, " data(k,indexofi) ", data(k,indexofi), " qv(k) ", qv(k), " sdtemp ", sdtemp
+ sd = sd + sdtemp**2
+ if (sd>ballsize) cycle mainloop
+ end do
+ endif
+
+ if (centeridx > 0) then ! doing correlation interval?
+ if (abs(indexofi-centeridx)<correltime) cycle mainloop
+ endif
+
+ nfound = nfound+1
+ if (nfound .gt. sr%nalloc) then
+ ! oh nuts, we have to add another one to the tree but
+ ! there isn't enough room.
+ sr%overflow = .true.
+ else
+ sr%results(nfound)%dis = sd
+ sr%results(nfound)%idx = indexofi
+ endif
+ end do mainloop
+ !
+ ! Reset sr variables which may have changed during loop
+ !
+ sr%nfound = nfound
+ end subroutine process_terminal_node_fixedball
+
+ subroutine kdtree2_n_nearest_brute_force(tp,qv,nn,results)
+    ! find the 'nn' nearest neighbors to 'qv' by exhaustive search.
+ ! only use this subroutine for testing, as it is SLOW! The
+ ! whole point of a k-d tree is to avoid doing what this subroutine
+ ! does.
+ type(kdtree2), pointer :: tp
+ real(kdkind), intent (In) :: qv(:)
+ integer, intent (In) :: nn
+ type(kdtree2_result) :: results(:)
+
+ integer :: i, j, k
+ real(kdkind), allocatable :: all_distances(:)
+ ! ..
+ allocate (all_distances(tp%n))
+ do i = 1, tp%n
+ all_distances(i) = square_distance(tp%dimen,qv,tp%the_data(:,i))
+ end do
+ ! now find 'n' smallest distances
+ do i = 1, nn
+ results(i)%dis = huge(1.0)
+ results(i)%idx = -1
+ end do
+ do i = 1, tp%n
+ if (all_distances(i)<results(nn)%dis) then
+ ! insert it somewhere on the list
+ do j = 1, nn
+ if (all_distances(i)<results(j)%dis) exit
+ end do
+ ! now we know 'j'
+ do k = nn - 1, j, -1
+ results(k+1) = results(k)
+ end do
+ results(j)%dis = all_distances(i)
+ results(j)%idx = i
+ end if
+ end do
+ deallocate (all_distances)
+ end subroutine kdtree2_n_nearest_brute_force
+
+
+ subroutine kdtree2_r_nearest_brute_force(tp,qv,r2,nfound,results)
+ ! find the nearest neighbors to 'qv' with distance**2 <= r2 by exhaustive search.
+ ! only use this subroutine for testing, as it is SLOW! The
+ ! whole point of a k-d tree is to avoid doing what this subroutine
+ ! does.
+ type(kdtree2), pointer :: tp
+ real(kdkind), intent (In) :: qv(:)
+ real(kdkind), intent (In) :: r2
+ integer, intent(out) :: nfound
+ type(kdtree2_result) :: results(:)
+
+ integer :: i, nalloc
+ real(kdkind), allocatable :: all_distances(:)
+ ! ..
+ allocate (all_distances(tp%n))
+ do i = 1, tp%n
+ all_distances(i) = square_distance(tp%dimen,qv,tp%the_data(:,i))
+ end do
+
+ nfound = 0
+ nalloc = size(results,1)
+
+ do i = 1, tp%n
+ if (all_distances(i)< r2) then
+ ! insert it somewhere on the list
+ if (nfound .lt. nalloc) then
+ nfound = nfound+1
+ results(nfound)%dis = all_distances(i)
+ results(nfound)%idx = i
+ endif
+ end if
+ enddo
+ deallocate (all_distances)
+
+ call kdtree2_sort_results(nfound,results)
+
+
+ end subroutine kdtree2_r_nearest_brute_force
+
+ subroutine kdtree2_sort_results(nfound,results)
+ ! Use after search to sort results(1:nfound) in order of increasing
+ ! distance.
+ integer, intent(in) :: nfound
+ type(kdtree2_result), target :: results(:)
+ !
+ !
+
+ !THIS IS BUGGY WITH INTEL FORTRAN
+ ! If (nfound .Gt. 1) Call heapsort(results(1:nfound)%dis,results(1:nfound)%ind,nfound)
+ !
+ if (nfound .gt. 1) call heapsort_struct(results,nfound)
+
+ return
+ end subroutine kdtree2_sort_results
+
+ subroutine heapsort(a,ind,n)
+ !
+ ! Sort a(1:n) in ascending order, permuting ind(1:n) similarly.
+ !
+ ! If ind(k) = k upon input, then it will give a sort index upon output.
+ !
+ integer,intent(in) :: n
+ real(kdkind), intent(inout) :: a(:)
+ integer, intent(inout) :: ind(:)
+
+ !
+ !
+ real(kdkind) :: value ! temporary for a value from a()
+ integer :: ivalue ! temporary for a value from ind()
+
+ integer :: i,j
+ integer :: ileft,iright
+
+ ileft=n/2+1
+ iright=n
+
+ ! do i=1,n
+ ! ind(i)=i
+ ! Generate initial idum array
+ ! end do
+
+ if(n.eq.1) return
+
+ do
+ if(ileft > 1)then
+ ileft=ileft-1
+ value=a(ileft); ivalue=ind(ileft)
+ else
+ value=a(iright); ivalue=ind(iright)
+ a(iright)=a(1); ind(iright)=ind(1)
+ iright=iright-1
+ if (iright == 1) then
+ a(1)=value;ind(1)=ivalue
+ return
+ endif
+ endif
+ i=ileft
+ j=2*ileft
+ do while (j <= iright)
+ if(j < iright) then
+ if(a(j) < a(j+1)) j=j+1
+ endif
+ if(value < a(j)) then
+ a(i)=a(j); ind(i)=ind(j)
+ i=j
+ j=j+j
+ else
+ j=iright+1
+ endif
+ end do
+ a(i)=value; ind(i)=ivalue
+ end do
+ end subroutine heapsort
+
+ subroutine heapsort_struct(a,n)
+ !
+ ! Sort a(1:n) in ascending order
+ !
+ !
+ integer,intent(in) :: n
+ type(kdtree2_result),intent(inout) :: a(:)
+
+ !
+ !
+ type(kdtree2_result) :: value ! temporary value
+
+ integer :: i,j
+ integer :: ileft,iright
+
+ ileft=n/2+1
+ iright=n
+
+ ! do i=1,n
+ ! ind(i)=i
+ ! Generate initial idum array
+ ! end do
+
+ if(n.eq.1) return
+
+ do
+ if(ileft > 1)then
+ ileft=ileft-1
+ value=a(ileft)
+ else
+ value=a(iright)
+ a(iright)=a(1)
+ iright=iright-1
+ if (iright == 1) then
+ a(1) = value
+ return
+ endif
+ endif
+ i=ileft
+ j=2*ileft
+ do while (j <= iright)
+ if(j < iright) then
+ if(a(j)%dis < a(j+1)%dis) j=j+1
+ endif
+ if(value%dis < a(j)%dis) then
+ a(i)=a(j);
+ i=j
+ j=j+j
+ else
+ j=iright+1
+ endif
+ end do
+ a(i)=value
+ end do
+ end subroutine heapsort_struct
+
+end module kdtree2_module
+
Added: trunk/yt/extensions/kdtree/test.py
==============================================================================
--- (empty file)
+++ trunk/yt/extensions/kdtree/test.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,58 @@
+from Forthon import *
+from fKDpy import *
+import numpy,random
+
+n = 32768
+
+
+fKD.tags = fzeros((64),'i')
+fKD.dist = fzeros((64),'d')
+fKD.pos = fzeros((3,n),'d')
+fKD.nn = 64
+fKD.nparts = n
+fKD.sort = True
+fKD.rearrange = True
+fKD.qv = numpy.array([16./32, 16./32, 16./32])
+
+fp = open('parts.txt','r')
+xpos = []
+ypos = []
+zpos = []
+line = fp.readline()
+while line:
+    line = line.split()
+    xpos.append(float(line[0]))
+    ypos.append(float(line[1]))
+    zpos.append(float(line[2]))
+    line = fp.readline()
+
+fp.close()
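+
+# NOTE: the positions read from parts.txt above are not used below;
+# fKD.pos is instead filled from a slightly perturbed regular grid.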
+
+
+for k in range(32):
+    for j in range(32):
+        for i in range(32):
+            fKD.pos[0][i + j*32 + k*1024] = float(i)/32 + 1./64 + 0.0001*random.random()
+            fKD.pos[1][i + j*32 + k*1024] = float(j)/32 + 1./64 + 0.0001*random.random()
+            fKD.pos[2][i + j*32 + k*1024] = float(k)/32 + 1./64 + 0.0001*random.random()
+
+
+
+#print fKD.pos[0][0],fKD.pos[1][0],fKD.pos[2][0]
+
+create_tree()
+
+
+find_nn_nearest_neighbors()
+
+#print 'next'
+
+#fKD.qv = numpy.array([0., 0., 0.])
+
+#find_nn_nearest_neighbors()
+
+
+#print (fKD.tags - 1)
+#print fKD.dist
+
+free_tree()
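+
+# The pattern above is the minimal fKD workflow: fill fKD.pos with a
+# (3, n) array and set nn, nparts and qv, call create_tree(), run
+# find_nn_nearest_neighbors() and read fKD.tags / fKD.dist, then call
+# free_tree() to release the Fortran-side tree.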
Modified: trunk/yt/lagos/BaseDataTypes.py
==============================================================================
--- trunk/yt/lagos/BaseDataTypes.py (original)
+++ trunk/yt/lagos/BaseDataTypes.py Sun Dec 20 12:42:09 2009
@@ -1990,6 +1990,12 @@
& (r < self._radius))
return cm
+ def volume(self, unit="unitary"):
+ """
+ Return the volume of the cylinder in units of *unit*.
+ """
+ return math.pi * (self._radius)**2. * self._height * pf[unit]**3
+
class AMRRegionBase(AMR3DData):
"""
AMRRegions are rectangular prisms of data.
@@ -2031,6 +2037,15 @@
& (grid['z'] + dzp > self.left_edge[2]) )
return cm
+ def volume(self, unit = "unitary"):
+ """
+ Return the volume of the region in units *unit*.
+ """
+ diff = na.array(self.right_edge) - na.array(self.left_edge)
+ # Find the full volume
+ vol = na.prod(diff * self.pf[unit])
+ return vol
+
class AMRRegionStrictBase(AMRRegionBase):
"""
AMRRegion without any dx padding for cell selection
@@ -2092,6 +2107,20 @@
& (grid['z'] + dzp + off_z > self.left_edge[2]) )
return cm
+ def volume(self, unit = "unitary"):
+ """
+ Return the volume of the region in units *unit*.
+ """
+ period = self.pf["DomainRightEdge"] - self.pf["DomainLeftEdge"]
+ diff = na.array(self.right_edge) - na.array(self.left_edge)
+ # Correct for wrap-arounds.
+ tofix = (diff < 0)
+ toadd = period[tofix]
+ diff += toadd
+ # Find the full volume
+ vol = na.prod(diff * self.pf[unit])
+ return vol
+
class AMRPeriodicRegionStrictBase(AMRPeriodicRegionBase):
"""
@@ -2178,6 +2207,12 @@
self._cut_masks[grid.id] = cm
return cm
+ def volume(self, unit = "unitary"):
+ """
+ Return the volume of the sphere in units *unit*.
+ """
+ return 4./3. * math.pi * (self.radius * self.pf[unit])**3.0
+
class AMRFloatCoveringGridBase(AMR3DData):
"""
Covering grids represent fixed-resolution data over a given region.
Modified: trunk/yt/lagos/EnzoFields.py
==============================================================================
--- trunk/yt/lagos/EnzoFields.py (original)
+++ trunk/yt/lagos/EnzoFields.py Sun Dec 20 12:42:09 2009
@@ -69,16 +69,20 @@
validators=ValidateDataField("%s_Density" % species))
def _Metallicity(field, data):
-    return data["Metal_Fraction"] / 0.0204
-add_field("Metallicity", units=r"Z_{\rm{Solar}}",
+    return data["Metal_Fraction"]
+def _ConvertMetallicity(data):
+    return 49.0196 # 1 / 0.0204
+add_field("Metallicity", units=r"Z_{\rm{\odot}}",
          function=_Metallicity,
+          convert_function=_ConvertMetallicity,
          validators=ValidateDataField("Metal_Density"),
          projection_conversion="1")
def _Metallicity3(field, data):
-    return data["SN_Colour"] / 0.0204
-add_field("Metallicity3", units=r"Z_{\rm{Solar}}",
+    return data["SN_Colour"]
+add_field("Metallicity3", units=r"Z_{\rm{\odot}}",
          function=_Metallicity3,
+          convert_function=_ConvertMetallicity,
          validators=ValidateDataField("SN_Colour"),
          projection_conversion="1")
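+# Both Metallicity fields now return the raw metal fraction; division
+# by the solar value 0.0204 happens in the shared convert_function
+# (49.0196 = 1/0.0204).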
@@ -208,6 +212,7 @@
not_in_all = True)
EnzoFieldInfo["Temperature"]._units = r"\rm{K}"
+EnzoFieldInfo["Temperature"].units = r"K"
def _convertVelocity(data):
return data.convert("x-velocity")
@@ -244,20 +249,6 @@
add_field("particle_density_pyx", function=_pdensity_pyx,
validators=[ValidateSpatial(0)], convert_function=_convertDensity)
-def _spdensity(field, data):
-    blank = na.zeros(data.ActiveDimensions, dtype='float32', order="FORTRAN")
-    if data.NumberOfParticles == 0: return blank
-    filter = data['creation_time'] > 0.0
-    if not filter.any(): return blank
-    cic_deposit.cic_deposit(data["particle_position_x"][filter],
-                            data["particle_position_y"][filter],
-                            data["particle_position_z"][filter], 3,
-                            data["particle_mass"][filter],
-                            blank, data.LeftEdge, data['dx'])
-    return blank
-add_field("star_density", function=_spdensity,
-          validators=[ValidateSpatial(0)], convert_function=_convertDensity)
-
def _spdensity_pyx(field, data):
blank = na.zeros(data.ActiveDimensions, dtype='float32')
if data.NumberOfParticles == 0: return blank
@@ -275,7 +266,81 @@
add_field("star_density_pyx", function=_spdensity_pyx,
validators=[ValidateSpatial(0)], convert_function=_convertDensity)
-EnzoFieldInfo["Temperature"].units = r"K"
+def _star_field(field, data):
+    """
+    Create a grid field for star quantities, weighted by star mass.
+    """
+    particle_field = field.name[5:]
+    top = na.zeros(data.ActiveDimensions, dtype='float32')
+    if data.NumberOfParticles == 0: return top
+    filter = data['creation_time'] > 0.0
+    if not filter.any(): return top
+    particle_field_data = data[particle_field][filter] * data['particle_mass'][filter]
+    CICDeposit_3(data["particle_position_x"][filter].astype(na.float64),
+                 data["particle_position_y"][filter].astype(na.float64),
+                 data["particle_position_z"][filter].astype(na.float64),
+                 particle_field_data.astype(na.float32),
+                 na.int64(na.where(filter)[0].size),
+                 top, na.array(data.LeftEdge).astype(na.float64),
+                 na.array(data.ActiveDimensions).astype(na.int32),
+                 na.float64(data['dx']))
+    del particle_field_data
+
+    bottom = na.zeros(data.ActiveDimensions, dtype='float32')
+    CICDeposit_3(data["particle_position_x"][filter].astype(na.float64),
+                 data["particle_position_y"][filter].astype(na.float64),
+                 data["particle_position_z"][filter].astype(na.float64),
+                 data["particle_mass"][filter].astype(na.float32),
+                 na.int64(na.where(filter)[0].size),
+                 bottom, na.array(data.LeftEdge).astype(na.float64),
+                 na.array(data.ActiveDimensions).astype(na.int32),
+                 na.float64(data['dx']))
+
+    top[bottom == 0] = 0.0
+    bnz = bottom.nonzero()
+    top[bnz] /= bottom[bnz]
+    return top
+
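+# _star_field deposits mass*field and mass separately via CIC and takes
+# their ratio, i.e. a star-mass-weighted average of the particle field
+# on the grid; cells with no deposited star mass are left at zero.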
+add_field('star_metallicity_fraction', function=_star_field,
+          validators=[ValidateSpatial(0)])
+add_field('star_creation_time', function=_star_field,
+          validators=[ValidateSpatial(0)])
+add_field('star_dynamical_time', function=_star_field,
+          validators=[ValidateSpatial(0)])
+
+def _StarMetallicity(field, data):
+    return data['star_metallicity_fraction']
+add_field('StarMetallicity', units=r"Z_{\rm{\odot}}",
+          function=_StarMetallicity,
+          convert_function=_ConvertMetallicity,
+          projection_conversion="1")
+
+def _StarCreationTime(field, data):
+    return data['star_creation_time']
+def _ConvertEnzoTimeYears(data):
+    return data.pf.time_units['years']
+add_field('StarCreationTimeYears', units=r"\rm{yr}",
+          function=_StarCreationTime,
+          convert_function=_ConvertEnzoTimeYears,
+          projection_conversion="1")
+
+def _StarDynamicalTime(field, data):
+    return data['star_dynamical_time']
+add_field('StarDynamicalTimeYears', units=r"\rm{yr}",
+          function=_StarDynamicalTime,
+          convert_function=_ConvertEnzoTimeYears,
+          projection_conversion="1")
+
+def _StarAge(field, data):
+    star_age = na.zeros(data['StarCreationTimeYears'].shape)
+    with_stars = data['StarCreationTimeYears'] > 0
+    star_age[with_stars] = data.pf.time_units['years'] * \
+        data.pf["InitialTime"] - \
+        data['StarCreationTimeYears'][with_stars]
+    return star_age
+add_field('StarAgeYears', units=r"\rm{yr}",
+          function=_StarAge,
+          projection_conversion="1")
#
# Now we do overrides for 2D fields
Modified: trunk/yt/lagos/HaloFinding.py
==============================================================================
--- trunk/yt/lagos/HaloFinding.py (original)
+++ trunk/yt/lagos/HaloFinding.py Sun Dec 20 12:42:09 2009
@@ -26,6 +26,7 @@
"""
from yt.lagos import *
+from yt.math_utils import *
from yt.lagos.hop.EnzoHop import RunHOP
try:
from yt.lagos.parallelHOP.parallelHOP import *
@@ -39,7 +40,7 @@
from yt.performance_counters import yt_counters, time_function
from kd import *
-import math, sys
+import math, sys, itertools
from collections import defaultdict
class Halo(object):
@@ -61,7 +62,10 @@
self.halo_list = halo_list
self.id = id
self.data = halo_list._data_source
- if indices is not None: self.indices = halo_list._base_indices[indices]
+        if indices is not None:
+            self.indices = halo_list._base_indices[indices]
+        else:
+            self.indices = None
# We assume that if indices = None, the instantiator has OTHER plans
# for us -- i.e., setting it somehow else
self.size = size
@@ -71,6 +75,8 @@
self.max_radius = max_radius
self.bulk_vel = bulk_vel
self.tasks = tasks
+        self.bin_count = None
+        self.overdensity = None
def center_of_mass(self):
"""
@@ -165,13 +171,107 @@
# set attributes on n
self._processing = False
+    def virial_mass(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial mass of the halo in Msun, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.mass_bins[vir_bin]
+        else:
+            return -1
+
+
+    def virial_radius(self, virial_overdensity=200., bins=300):
+        """
+        Return the virial radius of the halo in code units, using only the
+        particles in the halo (no baryonic information used).
+        Calculate using *bins* number of bins and *virial_overdensity* density
+        threshold. Returns -1 if the halo is not virialized.
+        """
+        self.virial_info(bins=bins)
+        vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+        if vir_bin != -1:
+            return self.radial_bins[vir_bin]
+        else:
+            return -1
+
+    def virial_bin(self, virial_overdensity=200., bins=300):
+        """
+        Return the bin index for the virial radius for the given halo.
+        Returns -1 if the halo is not virialized to the set
+        *virial_overdensity*.
+        """
+        self.virial_info(bins=bins)
+        over = (self.overdensity > virial_overdensity)
+        if (over == True).any():
+            vir_bin = max(na.arange(bins+1)[over])
+            return vir_bin
+        else:
+            return -1
+
+    def virial_info(self, bins=300):
+        """
+        Calculate the virial profile bins for this halo, using only the particles
+        in the halo (no baryonic information used).
+        Calculate using *bins* number of bins.
+        """
+        # Skip if we've already calculated for this number of bins.
+        if self.bin_count == bins and self.overdensity is not None:
+            return None
+        self.bin_count = bins
+        # Cosmology
+        h = self.halo_list._data_source.pf['CosmologyHubbleConstantNow']
+        Om_matter = self.halo_list._data_source.pf['CosmologyOmegaMatterNow']
+        z = self.halo_list._data_source.pf['CosmologyCurrentRedshift']
+        rho_crit_now = 1.8788e-29 * h**2.0 * Om_matter # g cm^-3
+        Msun2g = 1.989e33
+        rho_crit = rho_crit_now * ((1.0 + z)**3.0)
+
+        # Get some pertinent information about the halo.
+        self.mass_bins = na.zeros(self.bin_count+1, dtype='float64')
+        dist = na.empty(self.indices.size, dtype='float64')
+        cen = self.center_of_mass()
+        period = self.halo_list._data_source.pf["DomainRightEdge"] - \
+            self.halo_list._data_source.pf["DomainLeftEdge"]
+        mark = 0
+        # Find the distances to the particles. I don't like this much, but I
+        # can't see a way to eliminate a loop like this, either here or in
+        # yt.math.
+        for pos in itertools.izip(self["particle_position_x"],
+                self["particle_position_y"], self["particle_position_z"]):
+            dist[mark] = periodic_dist(cen, pos, period)
+            mark += 1
+        # Set up the radial bins.
+        # Multiply min and max to prevent issues with digitize below.
+        self.radial_bins = na.logspace(math.log10(min(dist)*.99),
+            math.log10(max(dist)*1.01), num=self.bin_count+1)
+        # Find out which bin each particle goes into, and add the particle
+        # mass to that bin.
+        inds = na.digitize(dist, self.radial_bins) - 1
+        for index in na.unique(inds):
+            self.mass_bins[index] += sum(self["ParticleMassMsun"][inds==index])
+        # Now forward sum the masses in the bins.
+        for i in xrange(self.bin_count):
+            self.mass_bins[i+1] += self.mass_bins[i]
+        # Calculate the over densities in the bins.
+        self.overdensity = self.mass_bins * Msun2g / \
+            (4./3. * math.pi * rho_crit * \
+            (self.radial_bins * self.halo_list._data_source.pf["cm"])**3.0)
+
+
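A minimal usage sketch of the new virial methods, not part of the patch itself; `halos` is assumed to come from a halo finder built on these classes, e.g. halos = HaloFinder(pf):

    for halo in halos:
        m_vir = halo.virial_mass(virial_overdensity=200., bins=300)
        if m_vir == -1:
            continue  # never reaches the overdensity threshold; not virialized
        r_vir = halo.virial_radius(virial_overdensity=200., bins=300)
        print "%e Msun inside r_vir = %e (code units)" % (m_vir, r_vir)
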
class HOPHalo(Halo):
pass
class parallelHOPHalo(Halo,ParallelAnalysisInterface):
dont_wrap = ["maximum_density","maximum_density_location",
"center_of_mass","total_mass","bulk_velocity","maximum_radius",
- "get_size","get_sphere", "write_particle_list","__getitem__"]
+ "get_size","get_sphere", "write_particle_list","__getitem__",
+ "virial_info", "virial_bin", "virial_mass", "virial_radius"]
def maximum_density(self):
"""
@@ -296,6 +396,118 @@
global_size = self._mpi_allsum(my_size)
return global_size
+ def __getitem__(self, key):
+ if ytcfg.getboolean("yt","inline") == False:
+ return self.data.particles[key][self.indices]
+ else:
+ return self.data[key][self.indices]
+
+ def virial_mass(self, virial_overdensity=200., bins=300):
+ """
+ Return the virial mass of the halo in Msun, using only the particles
+ in the halo (no baryonic information used).
+ Calculate using *bins* number of bins and *virial_overdensity* density
+ threshold. Returns -1 if the halo is not virialized.
+ """
+ self.virial_info(bins=bins)
+ vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+ if vir_bin != -1:
+ return self.mass_bins[vir_bin]
+ else:
+ return -1
+
+
+ def virial_radius(self, virial_overdensity=200., bins=300):
+ """
+ Return the virial radius of the halo in code units, using only the
+ particles in the halo (no baryonic information used).
+ Calculate using *bins* number of bins and *virial_overdensity* density
+ threshold. Returns -1 if the halo is not virialized.
+ """
+ self.virial_info(bins=bins)
+ vir_bin = self.virial_bin(virial_overdensity=virial_overdensity, bins=bins)
+ if vir_bin != -1:
+ return self.radial_bins[vir_bin]
+ else:
+ return -1
+
+ def virial_bin(self, virial_overdensity=200., bins=300):
+ """
+ Return the bin index for the virial radius for the given halo.
+ Returns -1 if the halo is not virialized to the set
+ *virial_overdensity*.
+ """
+ self.virial_info(bins=bins)
+ over = (self.overdensity > virial_overdensity)
+ if over.any():
+ vir_bin = max(na.arange(bins+1)[over])
+ return vir_bin
+ else:
+ return -1
+
+ def virial_info(self, bins=300):
+ """
+ Calculate the virial profile bins for this halo, using only the particles
+ in the halo (no baryonic information used).
+ Calculate using *bins* number of bins.
+ """
+ # Skip if we've already calculated for this number of bins.
+ if self.bin_count == bins and self.overdensity is not None:
+ return None
+ # Do this for all because all will use it.
+ self.bin_count = bins
+ period = self.halo_list._data_source.pf["DomainRightEdge"] - \
+ self.halo_list._data_source.pf["DomainLeftEdge"]
+ self.mass_bins = na.zeros(self.bin_count+1, dtype='float64')
+ cen = self.center_of_mass()
+ # Cosmology
+ h = self.halo_list._data_source.pf['CosmologyHubbleConstantNow']
+ Om_matter = self.halo_list._data_source.pf['CosmologyOmegaMatterNow']
+ z = self.halo_list._data_source.pf['CosmologyCurrentRedshift']
+ rho_crit_now = 1.8788e-29 * h**2.0 * Om_matter # g cm^-3
+ Msun2g = 1.989e33
+ rho_crit = rho_crit_now * ((1.0 + z)**3.0)
+ # If I own some of this halo operate on the particles.
+ if self.indices is not None:
+ # Get some pertinent information about the halo.
+ dist = na.empty(self.indices.size, dtype='float64')
+ mark = 0
+ # Find the distances to the particles. I don't like this much, but I
+ # can't see a way to eliminate a loop like this, either here or in
+ # yt.math.
+ for pos in izip(self["particle_position_x"], self["particle_position_y"],
+ self["particle_position_z"]):
+ dist[mark] = periodic_dist(cen, pos, period)
+ mark += 1
+ dist_min, dist_max = min(dist), max(dist)
+ # If I don't have this halo, make some dummy values.
+ else:
+ dist_min = max(period)
+ dist_max = 0.0
+ # In this parallel case, we're going to find the global dist extrema
+ # and build identical bins on all tasks.
+ dist_min = self._mpi_allmin(dist_min)
+ dist_max = self._mpi_allmax(dist_max)
+ # Set up the radial bins.
+ # Scale the min and max slightly so the extreme particles land inside the bins for digitize below.
+ self.radial_bins = na.logspace(math.log10(dist_min*.99),
+ math.log10(dist_max*1.01), num=self.bin_count+1)
+ if self.indices is not None:
+ # Find out which bin each particle goes into, and add the particle
+ # mass to that bin.
+ inds = na.digitize(dist, self.radial_bins) - 1
+ for index in na.unique(inds):
+ self.mass_bins[index] += sum(self["ParticleMassMsun"][inds==index])
+ # Now forward sum the masses in the bins.
+ for i in xrange(self.bin_count):
+ self.mass_bins[i+1] += self.mass_bins[i]
+ # Sum up the mass_bins globally
+ self.mass_bins = self._mpi_Allsum_double(self.mass_bins)
+ # Calculate the over densities in the bins.
+ self.overdensity = self.mass_bins * Msun2g / \
+ (4./3. * math.pi * rho_crit * \
+ (self.radial_bins * self.halo_list._data_source.pf["cm"])**3.0)
+
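The dummy-extrema trick above generalizes: tasks owning no particles contribute values guaranteed to lose the reduction. A minimal sketch of the same step written directly against mpi4py, which is an assumption; the patch itself goes through yt's _mpi_allmin/_mpi_allmax wrappers:

    from mpi4py import MPI
    import numpy as na

    comm = MPI.COMM_WORLD
    dist = na.random.random(comm.rank * 10)   # some tasks may own nothing
    if dist.size > 0:
        local_min, local_max = dist.min(), dist.max()
    else:
        local_min, local_max = 1e300, 0.0     # dummies that always lose
    dist_min = comm.allreduce(local_min, op=MPI.MIN)
    dist_max = comm.allreduce(local_max, op=MPI.MAX)
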
class FOFHalo(Halo):
@@ -735,7 +947,7 @@
self.max_dens_point[index][2], self.max_dens_point[index][3]]
index += 1
# Clean up
- del self.max_dens_point, self.Tot_M, self.max_radius, self.bulk_vel
+ del self.max_dens_point, self.max_radius, self.bulk_vel
del self.halo_taskmap, self.tags
def __len__(self):
@@ -802,7 +1014,7 @@
halo._owner = proc
id += 1
def haloCmp(h1,h2):
- c = cmp(h1.get_size(),h2.get_size())
+ c = cmp(h1.total_mass(),h2.total_mass())
if c != 0:
return -1 * c
if c == 0:
@@ -970,11 +1182,12 @@
yt_counters("Final Grouping")
def _join_halolists(self):
- gs = -self.group_sizes.copy()
+ ms = -self.Tot_M.copy()
+ del self.Tot_M
Cx = self.CoM[:,0].copy()
indexes = na.arange(self.group_count)
- sorted = na.asarray([indexes[i] for i in na.lexsort([indexes, Cx, gs])])
- del indexes, Cx, gs
+ sorted = na.asarray([indexes[i] for i in na.lexsort([indexes, Cx, ms])])
+ del indexes, Cx, ms
self._groups = self._groups[sorted]
self._max_dens = self._max_dens[sorted]
for i in xrange(self.group_count):
Modified: trunk/yt/lagos/ParallelTools.py
==============================================================================
--- trunk/yt/lagos/ParallelTools.py (original)
+++ trunk/yt/lagos/ParallelTools.py Sun Dec 20 12:42:09 2009
@@ -780,7 +780,7 @@
def _mpi_exit_test(self, data=False):
# data==True -> exit. data==False -> no exit
- statuses = self._mpi_info_dict(data)
+ mine, statuses = self._mpi_info_dict(data)
if True in statuses.values():
raise RuntimeError("Fatal error. Exiting.")
return None
Added: trunk/yt/lagos/parallelHOP/__init__.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/__init__.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,3 @@
+from yt.lagos import *
+
+from parallelHOP import *
Added: trunk/yt/lagos/parallelHOP/parallelHOP.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/parallelHOP.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,1479 @@
+"""
+An implementation of the HOP algorithm that runs in parallel.
+
+Author: Stephen Skory <sskory at physics.ucsd.edu>
+Affiliation: UCSD/CASS
+Homepage: http://yt.enzotools.org/
+License:
+ Copyright (C) 2008-2009 Stephen Skory. All Rights Reserved.
+
+ This file is part of yt.
+
+ yt is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+
+from collections import defaultdict
+import itertools, sys
+
+from yt.lagos import *
+from yt.funcs import *
+from yt.extensions.kdtree import *
+from yt.performance_counters import yt_counters, time_function
+
+class RunParallelHOP(ParallelAnalysisInterface):
+ def __init__(self,period, padding, num_neighbors, bounds,
+ xpos, ypos, zpos, index, mass, threshold=160.0, rearrange=True):
+ self.threshold = threshold
+ self.rearrange = rearrange
+ self.saddlethresh = 2.5 * threshold
+ self.peakthresh = 3 * threshold
+ self.period = period
+ self.padding = padding
+ self.num_neighbors = num_neighbors
+ self.bounds = bounds
+ self.xpos = xpos
+ self.ypos = ypos
+ self.zpos = zpos
+ self.real_size = len(self.xpos)
+ self.index = na.array(index, dtype='int64')
+ self.mass = mass
+ self.padded_particles = []
+ self.nMerge = 4
+ yt_counters("chainHOP")
+ self.max_mem = 0
+ self.__max_memory()
+ self._chain_hop()
+ yt_counters("chainHOP")
+
+ def _global_bounds_neighbors(self):
+ """
+ Build a dict of the boundaries of all the tasks, and figure out which
+ tasks are our geometric neighbors.
+ """
+ self.neighbors = set([])
+ self.mine, global_bounds = self._mpi_info_dict(self.bounds)
+ my_LE, my_RE = self.bounds
+ # Put the vertices into a big list, each row is
+ # array[x,y,z, taskID]
+ vertices = []
+ my_vertices = []
+ for taskID in global_bounds:
+ thisLE, thisRE = global_bounds[taskID]
+ if self.mine != taskID:
+ vertices.append(na.array([thisLE[0], thisLE[1], thisLE[2], taskID]))
+ vertices.append(na.array([thisLE[0], thisLE[1], thisRE[2], taskID]))
+ vertices.append(na.array([thisLE[0], thisRE[1], thisLE[2], taskID]))
+ vertices.append(na.array([thisRE[0], thisLE[1], thisLE[2], taskID]))
+ vertices.append(na.array([thisLE[0], thisRE[1], thisRE[2], taskID]))
+ vertices.append(na.array([thisRE[0], thisLE[1], thisRE[2], taskID]))
+ vertices.append(na.array([thisRE[0], thisRE[1], thisLE[2], taskID]))
+ vertices.append(na.array([thisRE[0], thisRE[1], thisRE[2], taskID]))
+ if self.mine == taskID:
+ my_vertices.append(na.array([thisLE[0], thisLE[1], thisLE[2]]))
+ my_vertices.append(na.array([thisLE[0], thisLE[1], thisRE[2]]))
+ my_vertices.append(na.array([thisLE[0], thisRE[1], thisLE[2]]))
+ my_vertices.append(na.array([thisRE[0], thisLE[1], thisLE[2]]))
+ my_vertices.append(na.array([thisLE[0], thisRE[1], thisRE[2]]))
+ my_vertices.append(na.array([thisRE[0], thisLE[1], thisRE[2]]))
+ my_vertices.append(na.array([thisRE[0], thisRE[1], thisLE[2]]))
+ my_vertices.append(na.array([thisRE[0], thisRE[1], thisRE[2]]))
+ # Find the neighbors we share corners with. Yes, this is lazy with
+ # a double loop, but it works and this is definitely not a performance
+ # bottleneck.
+ for my_vertex in my_vertices:
+ for vertex in vertices:
+ if vertex[3] in self.neighbors: continue
+ # If the corners touch, it's easy. This is the case if resizing
+ # (load-balancing) is turned off.
+ if (my_vertex % self.period == vertex[0:3] % self.period).all():
+ self.neighbors.add(int(vertex[3]))
+ continue
+ # Also test to see if the distance to this corner is within
+ # max_padding, which is more likely the case with load-balancing
+ # turned on.
+ dx = na.minimum( na.abs(my_vertex[0] - vertex[0]), \
+ self.period[0] - na.abs(my_vertex[0] - vertex[0]))
+ dy = na.minimum( na.abs(my_vertex[1] - vertex[1]), \
+ self.period[1] - na.abs(my_vertex[1] - vertex[1]))
+ dz = na.minimum( na.abs(my_vertex[2] - vertex[2]), \
+ self.period[2] - na.abs(my_vertex[2] - vertex[2]))
+ d = na.sqrt(dx*dx + dy*dy + dz*dz)
+ if d <= self.max_padding:
+ self.neighbors.add(int(vertex[3]))
+ # Faces and edges.
+ for dim in range(3):
+ dim1 = (dim + 1) % 3
+ dim2 = (dim + 2) % 3
+ left_face = my_LE[dim]
+ right_face = my_RE[dim]
+ for taskID in global_bounds:
+ if taskID == self.mine or taskID in self.neighbors: continue
+ thisLE, thisRE = global_bounds[taskID]
+ max1 = max(my_LE[dim1], thisLE[dim1])
+ max2 = max(my_LE[dim2], thisLE[dim2])
+ min1 = min(my_RE[dim1], thisRE[dim1])
+ min2 = min(my_RE[dim2], thisRE[dim2])
+ # Faces.
+ # First, faces that touch directly.
+ if (thisRE[dim] == left_face or thisRE[dim]%self.period[dim] == left_face) and \
+ max1 <= min1 and max2 <= min2:
+ self.neighbors.add(taskID)
+ continue
+ elif (thisLE[dim] == right_face or thisLE[dim] == right_face%self.period[dim]) and \
+ max1 <= min1 and max2 <= min2:
+ self.neighbors.add(taskID)
+ continue
+ # If an intervening subvolume has a width less than the padding
+ # (rare, but possible), a neighbor may not actually touch, so
+ # we need to account for that.
+ if (abs(thisRE[dim] - left_face) <= self.max_padding or \
+ abs(thisRE[dim]%self.period[dim] - left_face) <= self.max_padding) and \
+ max1 <= min1 and max2 <= min2:
+ self.neighbors.add(taskID)
+ continue
+ elif (abs(thisLE[dim] - right_face) <= self.max_padding or \
+ abs(thisLE[dim] - right_face%self.period[dim]) <= self.max_padding) and \
+ max1 <= min1 and max2 <= min2:
+ self.neighbors.add(taskID)
+ continue
+ # Edges.
+ # First, edges that touch.
+ elif (my_LE[dim] == (thisRE[dim]%self.period[dim]) and \
+ my_LE[dim1] == (thisRE[dim1]%self.period[dim1]) and \
+ max2 <= min2) or \
+ (my_LE[dim] == (thisRE[dim]%self.period[dim]) and \
+ my_LE[dim2] == (thisRE[dim2]%self.period[dim2]) and \
+ max1 <= min1):
+ self.neighbors.add(taskID)
+ continue
+ elif ((my_RE[dim]%self.period[dim]) == thisLE[dim] and \
+ (my_RE[dim1]%self.period[dim1]) == thisLE[dim1] and \
+ max2 <= min2) or \
+ ((my_RE[dim]%self.period[dim]) == thisLE[dim] and \
+ (my_RE[dim2]%self.period[dim2]) == thisLE[dim2] and \
+ max1 <= min1):
+ self.neighbors.add(taskID)
+ continue
+ # Now edges that don't touch, but are close.
+ if (abs(my_LE[dim] - thisRE[dim]%self.period[dim]) <= self.max_padding and \
+ abs(my_LE[dim1] - thisRE[dim1]%self.period[dim1]) <= self.max_padding and \
+ max2 <= min2) or \
+ (abs(my_LE[dim] - thisRE[dim]%self.period[dim]) <= self.max_padding and \
+ abs(my_LE[dim2] - thisRE[dim2]%self.period[dim2]) <= self.max_padding and \
+ max1 <= min1):
+ self.neighbors.add(taskID)
+ continue
+ elif (abs(my_RE[dim]%self.period[dim] - thisLE[dim]) <= self.max_padding and \
+ abs(my_RE[dim1]%self.period[dim1] - thisLE[dim1]) <= self.max_padding and \
+ max2 <= min2) or \
+ (abs(my_RE[dim]%self.period[dim] - thisLE[dim]) <= self.max_padding and \
+ abs(my_RE[dim2]%self.period[dim2] - thisLE[dim2]) <= self.max_padding and \
+ max1 <= min1):
+ self.neighbors.add(taskID)
+ continue
+ # Now we build a global dict of neighbor sets, and if a remote task
+ # lists us as their neighbor, we add them as our neighbor. This is
+ # probably not needed because the stuff above should be symmetric,
+ # but it isn't a big issue.
+ self.mine, global_neighbors = self._mpi_info_dict(self.neighbors)
+ for taskID in global_neighbors:
+ if taskID == self.mine: continue
+ if self.mine in global_neighbors[taskID]:
+ self.neighbors.add(taskID)
+ # We can remove ourselves from the set if it got added somehow.
+ self.neighbors.discard(self.mine)
+ # Clean up.
+ del global_neighbors, global_bounds, vertices, my_vertices
+
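A minimal standalone sketch of the minimum-image distance test used in the corner check above (the function name is hypothetical):

    import numpy as na

    def periodic_corner_distance(v1, v2, period):
        # Component-wise minimum-image separation in a periodic box.
        d = na.abs(na.asarray(v1) - na.asarray(v2))
        d = na.minimum(d, period - d)
        return na.sqrt((d * d).sum())

    # Two corners near opposite faces of a unit box are actually close:
    # periodic_corner_distance([0.05, 0.5, 0.5], [0.95, 0.5, 0.5],
    #     na.ones(3)) == 0.1
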
+ def _global_padding(self, round):
+ """
+ Find the maximum padding of all our neighbors, used to send our
+ annulus data.
+ """
+ if round == 'first':
+ max_pad = na.max(self.padding)
+ self.mine, self.global_padding = self._mpi_info_dict(max_pad)
+ self.max_padding = max(self.global_padding.itervalues())
+ elif round == 'second':
+ self.max_padding = 0.
+ for neighbor in self.neighbors:
+ self.max_padding = na.maximum(self.global_padding[neighbor], \
+ self.max_padding)
+
+ def _communicate_padding_data(self):
+ """
+ Send the particles each of my neighbors need to build up their padding.
+ """
+ yt_counters("Communicate discriminated padding")
+ # First build a global dict of the padded boundaries of all the tasks.
+ (LE, RE) = self.bounds
+ (LE_padding, RE_padding) = self.padding
+ temp_LE = LE - LE_padding
+ temp_RE = RE + RE_padding
+ expanded_bounds = (temp_LE, temp_RE)
+ self.mine, global_exp_bounds = self._mpi_info_dict(expanded_bounds)
+ send_real_indices = {}
+ send_points = {}
+ send_mass = {}
+ send_size = {}
+ # This will reduce the size of the loop over particles.
+ yt_counters("Picking padding data to send.")
+ send_count = len(na.where(self.is_inside_annulus == True)[0])
+ points = na.empty((send_count, 3), dtype='float64')
+ points[:,0] = self.xpos[self.is_inside_annulus]
+ points[:,1] = self.ypos[self.is_inside_annulus]
+ points[:,2] = self.zpos[self.is_inside_annulus]
+ real_indices = self.index[self.is_inside_annulus].astype('int64')
+ mass = self.mass[self.is_inside_annulus].astype('float64')
+ # Make the arrays to send.
+ shift_points = points.copy()
+ for neighbor in self.neighbors:
+ temp_LE, temp_RE = global_exp_bounds[neighbor]
+ for i in xrange(3):
+ left = ((points[:,i] < temp_LE[i]) * (points[:,i] < temp_RE[i])) * self.period[i]
+ right = ((points[:,i] > temp_LE[i]) * (points[:,i] > temp_RE[i])) * self.period[i]
+ shift_points[:,i] = points[:,i] + left - right
+ is_inside = ( (shift_points >= temp_LE).all(axis=1) * \
+ (shift_points < temp_RE).all(axis=1) )
+ send_real_indices[neighbor] = real_indices[is_inside].copy()
+ send_points[neighbor] = shift_points[is_inside].copy()
+ send_mass[neighbor] = mass[is_inside].copy()
+ send_size[neighbor] = len(na.where(is_inside == True)[0])
+ del points, shift_points, mass, real_indices
+ yt_counters("Picking padding data to send.")
+ # Communicate the sizes to send.
+ self.mine, global_send_count = self._mpi_info_dict(send_size)
+ # Initialize the arrays to receive data.
+ yt_counters("Initalizing recv arrays.")
+ recv_real_indices = {}
+ recv_points = {}
+ recv_mass = {}
+ recv_size = 0
+ for opp_neighbor in self.neighbors:
+ opp_size = global_send_count[opp_neighbor][self.mine]
+ recv_real_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+ recv_points[opp_neighbor] = na.empty((opp_size, 3), dtype='float64')
+ recv_mass[opp_neighbor] = na.empty(opp_size, dtype='float64')
+ recv_size += opp_size
+ yt_counters("Initalizing recv arrays.")
+ # Setup the receiving slots.
+ yt_counters("MPI stuff.")
+ hooks = []
+ for opp_neighbor in self.neighbors:
+ hooks.append(self._mpi_Irecv_long(recv_real_indices[opp_neighbor], opp_neighbor))
+ hooks.append(self._mpi_Irecv_double(recv_points[opp_neighbor], opp_neighbor))
+ hooks.append(self._mpi_Irecv_double(recv_mass[opp_neighbor], opp_neighbor))
+ # Let's wait here to be absolutely sure that all the receive buffers
+ # have been created before any sending happens!
+ self._barrier()
+ # Now we send the data.
+ for neighbor in self.neighbors:
+ hooks.append(self._mpi_Isend_long(send_real_indices[neighbor], neighbor))
+ hooks.append(self._mpi_Isend_double(send_points[neighbor], neighbor))
+ hooks.append(self._mpi_Isend_double(send_mass[neighbor], neighbor))
+ # Now we use the data, after all the comms are done.
+ self._mpi_Request_Waitall(hooks)
+ yt_counters("MPI stuff.")
+ yt_counters("Processing padded data.")
+ del send_real_indices, send_points, send_mass
+ # Now we add the data to ourselves.
+ self.index_pad = na.empty(recv_size, dtype='int64')
+ self.xpos_pad = na.empty(recv_size, dtype='float64')
+ self.ypos_pad = na.empty(recv_size, dtype='float64')
+ self.zpos_pad = na.empty(recv_size, dtype='float64')
+ self.mass_pad = na.empty(recv_size, dtype='float64')
+ so_far = 0
+ for opp_neighbor in self.neighbors:
+ opp_size = global_send_count[opp_neighbor][self.mine]
+ self.index_pad[so_far:so_far+opp_size] = recv_real_indices[opp_neighbor]
+ # Clean up immediately to reduce peak memory usage.
+ del recv_real_indices[opp_neighbor]
+ self.xpos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,0]
+ self.ypos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,1]
+ self.zpos_pad[so_far:so_far+opp_size] = recv_points[opp_neighbor][:,2]
+ del recv_points[opp_neighbor]
+ self.mass_pad[so_far:so_far+opp_size] = recv_mass[opp_neighbor]
+ del recv_mass[opp_neighbor]
+ so_far += opp_size
+ yt_counters("Processing padded data.")
+ # The KDtree node search wants the particles to be in the full box,
+ # not the expanded dimensions of shifted (<0 or >1 generally) volume,
+ # so we fix the positions of particles here.
+ yt_counters("Flipping coordinates around the periodic boundary.")
+ self.xpos_pad = self.xpos_pad % self.period[0]
+ self.ypos_pad = self.ypos_pad % self.period[1]
+ self.zpos_pad = self.zpos_pad % self.period[2]
+ yt_counters("Flipping coordinates around the periodic boundary.")
+ self.size = self.index.size + self.index_pad.size
+ # Now that we have the full size, initialize the chainID array
+ self.chainID = na.ones(self.size,dtype='int64') * -1
+ # Clean up explicitly, but these should be empty dicts by now.
+ del recv_real_indices, hooks, recv_points, recv_mass
+ yt_counters("Communicate discriminated padding")
+
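The communication pattern above (post all receives, barrier so every buffer exists, post all sends, then wait on everything) recurs throughout this file. A minimal sketch of the pattern written directly against mpi4py, which is an assumption; the patch itself uses yt's _mpi_Irecv_*/_mpi_Isend_* wrappers:

    from mpi4py import MPI
    import numpy as na

    comm = MPI.COMM_WORLD
    mine = comm.rank
    neighbors = [r for r in xrange(comm.size) if r != mine]  # all-to-all here

    send_buf = na.arange(8, dtype='int64') + 8 * mine
    recv_bufs = dict((n, na.empty(8, dtype='int64')) for n in neighbors)

    hooks = []
    for n in neighbors:                       # receives first...
        hooks.append(comm.Irecv([recv_bufs[n], MPI.LONG], source=n))
    comm.Barrier()                            # ...buffers exist everywhere...
    for n in neighbors:                       # ...then sends...
        hooks.append(comm.Isend([send_buf, MPI.LONG], dest=n))
    MPI.Request.Waitall(hooks)                # ...use data only after this.
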
+ def _init_kd_tree(self):
+ """
+ Set up the data objects that get passed to the kD-tree code.
+ """
+ yt_counters("init kd tree")
+ # Yes, we really do need to initialize this many arrays.
+ # They're deleted in _parallelHOP.
+ fKD.dens = na.asfortranarray(na.zeros(self.size, dtype='float64'))
+ fKD.mass = na.concatenate((self.mass, self.mass_pad))
+ fKD.pos = na.asfortranarray(na.empty((3, self.size), dtype='float64'))
+ # This actually copies the data into the fortran space.
+ fKD.pos[0, :] = na.concatenate((self.xpos, self.xpos_pad))
+ fKD.pos[1, :] = na.concatenate((self.ypos, self.ypos_pad))
+ fKD.pos[2, :] = na.concatenate((self.zpos, self.zpos_pad))
+ fKD.qv = na.asfortranarray(na.empty(3, dtype='float64'))
+ fKD.nn = self.num_neighbors
+ # Plus 2 because we're looking for that neighbor, but only keeping
+ # nMerge + 1 neighbor tags, skipping ourselves.
+ fKD.nMerge = self.nMerge + 2
+ fKD.nparts = self.size
+ fKD.sort = True # Slower, but needed in _connect_chains
+ fKD.rearrange = self.rearrange # True is faster, but uses more memory
+ # Now call the fortran.
+ create_tree()
+ self.__max_memory()
+ yt_counters("init kd tree")
+
+ def _is_inside(self, round):
+ """
+ There are three classes of particles.
+ 1. Particles inside the 'real' region of each subvolume.
+ 2. Particles outside, added in the 'padding' for purposes of having
+ correct particle densities in the real region.
+ 3. Particles that are one padding distance inside the edges of the
+ real region. The chainIDs of these particles are communicated
+ to the neighboring tasks so chains can be merged into groups.
+ The input *round* is either 'first' or 'second.' First is before the
+ padded particles have been communicated, and second after.
+ """
+ # Test to see if the points are in the 'real' region
+ (LE, RE) = self.bounds
+ if round == 'first':
+ points = na.empty((self.real_size, 3), dtype='float64')
+ points[:,0] = self.xpos
+ points[:,1] = self.ypos
+ points[:,2] = self.zpos
+ self.is_inside = ( (points >= LE).all(axis=1) * \
+ (points < RE).all(axis=1) )
+ elif round == 'second':
+ self.is_inside = ( (fKD.pos.T >= LE).all(axis=1) * \
+ (fKD.pos.T < RE).all(axis=1) )
+ # Below we find out which particles are in the `annulus', one padding
+ # distance inside the boundaries. First we find the particles outside
+ # this inner boundary.
+ temp_LE = LE + self.max_padding
+ temp_RE = RE - self.max_padding
+ if round == 'first':
+ inner = na.invert( (points >= temp_LE).all(axis=1) * \
+ (points < temp_RE).all(axis=1) )
+ elif round == 'second' or round == 'third':
+ inner = na.invert( (fKD.pos.T >= temp_LE).all(axis=1) * \
+ (fKD.pos.T < temp_RE).all(axis=1) )
+ if round == 'first':
+ del points
+ # After inverting the logic above, we want points that are inside
+ # the real region but within one padding of the boundary, and this
+ # will do it.
+ self.is_inside_annulus = na.bitwise_and(self.is_inside, inner)
+ # Below we make a mapping of real particle index->local ID
+ # Unfortunately, this has to be a dict, because any task can have
+ # particles of any particle_index, which means that if it were an
+ # array every task would probably end up having this array be as long
+ # as the full number of particles.
+ # We can skip this the first two times around.
+ if round == 'third':
+ temp = na.arange(self.size)
+ my_part = na.bitwise_or(na.invert(self.is_inside), self.is_inside_annulus)
+ my_part = na.bitwise_and(my_part, (self.chainID != -1))
+ catted_indices = na.concatenate(
+ (self.index, self.index_pad))[my_part]
+ self.rev_index = dict.fromkeys(catted_indices)
+ self.rev_index.update(itertools.izip(catted_indices, temp[my_part]))
+ del my_part, temp, catted_indices
+ self.__max_memory()
+
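A minimal standalone sketch of the three-way classification above, with hypothetical names; `LE`, `RE` are the subvolume edges and `pad` the padding width:

    import numpy as na

    def classify(points, LE, RE, pad):
        # Inside the 'real' region at all?
        is_inside = ( (points >= LE).all(axis=1) * \
            (points < RE).all(axis=1) )
        # Outside the shrunken inner box, i.e. within one padding of an edge.
        inner = na.invert( (points >= LE + pad).all(axis=1) * \
            (points < RE - pad).all(axis=1) )
        # Real particles close to a boundary form the annulus.
        return is_inside, na.bitwise_and(is_inside, inner)
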
+ def _densestNN(self):
+ """
+ For all particles, find their densest nearest neighbor. It is done in
+ chunks to keep the memory usage down.
+ The first search of nearest neighbors (done earlier) did not return all
+ num_neighbors neighbors, so we need to do it again, but we're not
+ keeping all of this data, just using it.
+ """
+ yt_counters("densestNN")
+ self.densestNN = na.empty(self.size,dtype='int64')
+ # We find nearest neighbors in chunks.
+ chunksize = 10000
+ fKD.chunk_tags = na.asfortranarray(na.empty((self.num_neighbors, chunksize), dtype='int64'))
+ start = 1 # Fortran counting!
+ finish = 0
+ while finish < self.size:
+ finish = min(finish+chunksize,self.size)
+ # Call the fortran. start and finish refer to the data locations
+ # in fKD.pos, and specify the range of particles to find nearest
+ # neighbors for.
+ fKD.start = start
+ fKD.finish = finish
+ find_chunk_nearest_neighbors()
+ chunk_NNtags = (fKD.chunk_tags[:,:finish-start+1] - 1).transpose()
+ # Find the densest nearest neighbors by referencing the already
+ # calculated density.
+ n_dens = na.take(self.density,chunk_NNtags)
+ max_loc = na.argmax(n_dens,axis=1)
+ for i in xrange(finish - start + 1): # +1 for fortran counting.
+ j = start + i - 1 # -1 for fortran counting.
+ self.densestNN[j] = chunk_NNtags[i,max_loc[i]]
+ start = finish + 1
+ yt_counters("densestNN")
+ self.__max_memory()
+ del chunk_NNtags, max_loc, n_dens
+
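An equivalent, fully vectorized form of the per-chunk selection above; a sketch, not a drop-in replacement, assuming the NN tags for a chunk are already zero-indexed rows of shape (chunk, num_neighbors):

    import numpy as na

    def densest_of_neighbors(chunk_NNtags, density):
        n_dens = na.take(density, chunk_NNtags)   # neighbor densities
        max_loc = na.argmax(n_dens, axis=1)       # column of the densest one
        rows = na.arange(chunk_NNtags.shape[0])
        return chunk_NNtags[rows, max_loc]        # densest NN per particle
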
+ def _build_chains(self):
+ """
+ Build the first round of particle chains. If the particle is too low in
+ density, move on.
+ """
+ yt_counters("build_chains")
+ chainIDmax = 0
+ self.densest_in_chain = na.ones(10000, dtype='float64') * -1 # chainID->density, one to one
+ self.densest_in_chain_real_index = na.ones(10000, dtype='int64') * -1 # chainID->real_index, one to one
+ for i in xrange(int(self.size)):
+ # If it's already in a group, move on, or if this particle is
+ # in the padding, move on because chains can only terminate in
+ # the padding, not begin, or if this particle is too low in
+ # density, move on.
+ if self.chainID[i] > -1 or not self.is_inside[i] or \
+ self.density[i] < self.threshold:
+ continue
+ chainIDnew = self._recurse_links(i, chainIDmax)
+ # If the new chainID returned is the same as we entered, the chain
+ # has been named chainIDmax, so we need to start a new chain
+ # in the next loop.
+ if chainIDnew == chainIDmax:
+ chainIDmax += 1
+ self.padded_particles = na.array(self.padded_particles, dtype='int64')
+ self.densest_in_chain = self.__clean_up_array(self.densest_in_chain)
+ self.densest_in_chain_real_index = self.__clean_up_array(self.densest_in_chain_real_index)
+ yt_counters("build_chains")
+ self.__max_memory()
+ return chainIDmax
+
+ def _recurse_links(self, pi, chainIDmax):
+ """
+ Recurse up the chain to either a) a self-densest particle, or
+ b) a particle that already has a chainID, then turn it back around,
+ assigning that chainID to where we came from. If we reach c) a
+ particle in the padding, terminate the chain right then and there,
+ because chains only go one particle deep into the padding.
+ """
+ nn = self.densestNN[pi]
+ inside = self.is_inside[pi]
+ nn_chainID = self.chainID[nn]
+ # Linking to an already chainID-ed particle (don't make links from
+ # padded particles!)
+ if nn_chainID > -1 and inside:
+ self.chainID[pi] = nn_chainID
+ return nn_chainID
+ # If pi is a self-most dense particle or inside the padding, end/create
+ # a new chain.
+ elif nn == pi or not inside:
+ self.chainID[pi] = chainIDmax
+ self.densest_in_chain = self.__add_to_array(self.densest_in_chain,
+ chainIDmax, self.density[pi], 'float64')
+ if pi < self.real_size:
+ self.densest_in_chain_real_index = self.__add_to_array(self.densest_in_chain_real_index,
+ chainIDmax, self.index[pi], 'int64')
+ else:
+ self.densest_in_chain_real_index = self.__add_to_array(self.densest_in_chain_real_index,
+ chainIDmax, self.index_pad[pi-self.real_size], 'int64')
+ # if this is a padded particle, record it for later
+ if not inside:
+ self.padded_particles.append(pi)
+ return chainIDmax
+ # Otherwise, recursively link to nearest neighbors.
+ else:
+ chainIDnew = self._recurse_links(nn, chainIDmax)
+ self.chainID[pi] = chainIDnew
+ return chainIDnew
+
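A serial sketch of the chain rule these two methods implement, written iteratively instead of recursively and without the padding special case; densestNN[i] is each particle's densest nearest neighbor, and density ties are assumed away:

    import numpy as na

    def build_chains(densestNN, density, threshold):
        chainID = na.ones(len(densestNN), dtype='int64') * -1
        next_id = 0
        for i in xrange(len(densestNN)):
            if chainID[i] != -1 or density[i] < threshold:
                continue
            path, p = [], i
            # Climb uphill until we hit a named chain or a self-densest peak.
            while chainID[p] == -1 and densestNN[p] != p:
                path.append(p)
                p = densestNN[p]
            if chainID[p] == -1:          # a new self-densest peak
                chainID[p] = next_id
                next_id += 1
            for q in path:                # everyone on the path joins it
                chainID[q] = chainID[p]
        return chainID
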
+ def _recurse_preconnected_links(self, chain_map, thisID):
+ if min(thisID, min(chain_map[thisID])) == thisID:
+ return thisID
+ else:
+ return self._recurse_preconnected_links(chain_map, min(chain_map[thisID]))
+
+ def _preconnect_chains(self, chain_count):
+ """
+ In each subvolume, chains that share a boundary and both have high
+ enough peak densities are prelinked in order to reduce the size of the
+ global chain objects. This is very similar to _connect_chains().
+ """
+ # First we'll sort them, which will be used below.
+ mylog.info("Locally sorting chains...")
+ yt_counters("preconnect_chains")
+ yt_counters("local chain sorting.")
+ sort = self.densest_in_chain.argsort()
+ sort = na.flipud(sort)
+ map = na.empty(sort.size,dtype='int64')
+ map[sort] = na.arange(sort.size)
+ self.densest_in_chain = self.densest_in_chain[sort]
+ self.densest_in_chain_real_index = self.densest_in_chain_real_index[sort]
+ del sort
+ for i,chID in enumerate(self.chainID):
+ if chID == -1: continue
+ self.chainID[i] = map[chID]
+ del map
+ yt_counters("local chain sorting.")
+ mylog.info("Preconnecting %d chains..." % chain_count)
+ chain_map = defaultdict(set)
+ for i in xrange(max(self.chainID)+1):
+ chain_map[i].add(i)
+ # Plus 2 because we're looking for that neighbor, but only keeping
+ # nMerge + 1 neighbor tags, skipping ourselves.
+ fKD.dist = na.empty(self.nMerge+2, dtype='float64')
+ fKD.tags = na.empty(self.nMerge+2, dtype='int64')
+ # We can change this here to make the searches faster.
+ fKD.nn = self.nMerge+2
+ yt_counters("preconnect kd tree search.")
+ for i in xrange(self.size):
+ # Don't consider this particle if it's not part of a chain.
+ if self.chainID[i] < 0: continue
+ chainID_i = self.chainID[i]
+ # If this particle is in the padding, don't make a connection.
+ if not self.is_inside[i]: continue
+ # Find this particle's chain max_dens.
+ part_max_dens = self.densest_in_chain[chainID_i]
+ # We're only connecting >= peakthresh chains now.
+ if part_max_dens < self.peakthresh: continue
+ # Loop over nMerge closest nearest neighbors.
+ fKD.qv = fKD.pos[:, i]
+ find_nn_nearest_neighbors()
+ NNtags = fKD.tags[:] - 1
+ same_count = 0
+ for j in xrange(int(self.nMerge+1)):
+ thisNN = NNtags[j+1] # Don't consider ourselves at NNtags[0]
+ thisNN_chainID = self.chainID[thisNN]
+ # If our neighbor is in the same chain, move on.
+ # Move on if these chains are already connected:
+ if chainID_i == thisNN_chainID or \
+ thisNN_chainID in chain_map[chainID_i]:
+ same_count += 1
+ continue
+ # Everything immediately below is for
+ # neighboring particles with a chainID.
+ if thisNN_chainID >= 0:
+ # Find thisNN's chain's max_dens.
+ thisNN_max_dens = self.densest_in_chain[thisNN_chainID]
+ # We're only linking peakthresh chains
+ if thisNN_max_dens < self.peakthresh: continue
+ # Calculate the two groups' boundary density.
+ boundary_density = (self.density[thisNN] + self.density[i]) / 2.
+ # Don't connect if the boundary is too low.
+ if boundary_density < self.saddlethresh: continue
+ # Mark these chains as related.
+ chain_map[thisNN_chainID].add(chainID_i)
+ chain_map[chainID_i].add(thisNN_chainID)
+ if same_count == self.nMerge + 1:
+ # All our neighbors are in the same chain already, so
+ # we don't need to search again.
+ self.search_again[i] = False
+ yt_counters("preconnect kd tree search.")
+ # Recursively jump links until we get to a chain whose densest
+ # link is to itself. At that point we've found the densest chain
+ # in this set of sets and we keep a record of that.
+ yt_counters("preconnect pregrouping.")
+ final_chain_map = na.empty(max(self.chainID)+1, dtype='int64')
+ removed = 0
+ for i in xrange(max(self.chainID)+1):
+ j = chain_count - i - 1
+ densest_link = self._recurse_preconnected_links(chain_map, j)
+ final_chain_map[j] = densest_link
+ if j != densest_link:
+ removed += 1
+ self.densest_in_chain[j] = -1
+ self.densest_in_chain_real_index[j] = -1
+ del chain_map
+ for i in xrange(self.size):
+ if self.chainID[i] != -1:
+ self.chainID[i] = final_chain_map[self.chainID[i]]
+ del final_chain_map
+ # Now make the chainID assignments consecutive.
+ map = na.empty(self.densest_in_chain.size, dtype='int64')
+ dic_new = na.empty(chain_count - removed, dtype='float64')
+ dicri_new = na.empty(chain_count - removed, dtype='int64')
+ new = 0
+ for i,dic in enumerate(self.densest_in_chain):
+ if dic > 0:
+ map[i] = new
+ dic_new[new] = dic
+ dicri_new[new] = self.densest_in_chain_real_index[i]
+ new += 1
+ else:
+ map[i] = -1
+ for i in range(self.size):
+ if self.chainID[i] != -1:
+ self.chainID[i] = map[self.chainID[i]]
+ del map
+ self.densest_in_chain = dic_new.copy()
+ self.densest_in_chain_real_index = dicri_new.copy()
+ self.__max_memory()
+ yt_counters("preconnect pregrouping.")
+ mylog.info("Preconnected %d chains." % removed)
+ yt_counters("preconnect_chains")
+
+ return chain_count - removed
+
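The 'make the chainID assignments consecutive' step above is a small, reusable pattern; a standalone sketch with toy values:

    import numpy as na

    densest = na.array([5.0, -1.0, 3.0, -1.0, 8.0])  # -1 marks deleted chains
    map = na.ones(densest.size, dtype='int64') * -1
    map[densest > 0] = na.arange((densest > 0).sum())
    # map is now [0, -1, 1, -1, 2]; surviving chains keep their relative
    # order, and chainID[i] = map[chainID[i]] wherever chainID[i] != -1.
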
+ def _globally_assign_chainIDs(self, chain_count):
+ """
+ Convert local chainIDs into globally unique chainIDs.
+ """
+ yt_counters("globally_assign_chainIDs")
+ # First find out the number of chains on each processor.
+ self.mine, chain_info = self._mpi_info_dict(chain_count)
+ self.nchains = sum(chain_info.values())
+ # Figure out our offset.
+ self.my_first_id = sum([v for k,v in chain_info.iteritems() if k < self.mine])
+ # Change particle IDs, -1 always means no chain assignment.
+ select = (self.chainID != -1)
+ select = select * self.my_first_id
+ self.chainID += select
+ del select
+ yt_counters("globally_assign_chainIDs")
+
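A toy illustration of the offset arithmetic above; each task owns a contiguous block of global chainIDs that starts after every lower-ranked task's chains:

    chain_info = {0: 5, 1: 3, 2: 7}    # taskID -> local chain count
    mine = 1                            # pretend we are task 1
    my_first_id = sum([v for k, v in chain_info.iteritems() if k < mine])
    # my_first_id == 5; our local chains 0,1,2 become global chains 5,6,7,
    # and nchains == sum(chain_info.values()) == 15 everywhere.
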
+ def _create_global_densest_in_chain(self):
+ """
+ With the globally unique chainIDs, update densest_in_chain.
+ """
+ yt_counters("create_global_densest_in_chain")
+ # Shift the values over effectively by concatenating them in the same
+ # order as the values have been shifted in _globally_assign_chainIDs()
+ yt_counters("global chain MPI stuff.")
+ self.densest_in_chain = self._mpi_concatenate_array_double(self.densest_in_chain)
+ self.densest_in_chain_real_index = self._mpi_concatenate_array_long(self.densest_in_chain_real_index)
+ yt_counters("global chain MPI stuff.")
+ # Sort the chains by density here. This is an attempt to make it such
+ # that the merging stuff in a few steps happens in the same order
+ # all the time.
+ mylog.info("Sorting chains...")
+ yt_counters("global chain sorting.")
+ sort = self.densest_in_chain.argsort()
+ sort = na.flipud(sort)
+ map = na.empty(sort.size,dtype='int64')
+ map[sort] = na.arange(sort.size)
+ self.densest_in_chain = self.densest_in_chain[sort]
+ self.densest_in_chain_real_index = self.densest_in_chain_real_index[sort]
+ del sort
+ for i,chID in enumerate(self.chainID):
+ if chID == -1: continue
+ self.chainID[i] = map[chID]
+ del map
+ yt_counters("global chain sorting.")
+ # For some reason chains that share the most-dense particle are not
+ # being linked, so we link them 'by hand' here.
+ mylog.info("Pre-linking chains 'by hand'...")
+ yt_counters("global chain hand-linking.")
+ # If there are no repeats, we can skip this mess entirely.
+ uniq = na.unique(self.densest_in_chain_real_index)
+ if uniq.size != self.densest_in_chain_real_index.size:
+ # Find only the real particle indices that are repeated to reduce
+ # the dict workload below.
+ dicri = self.densest_in_chain_real_index[self.densest_in_chain_real_index.argsort()]
+ diff = na.ediff1d(dicri)
+ diff = (diff == 0) # Picks out the places where the ids are equal
+ diff = na.concatenate((diff, [False])) # Makes it the same length
+ # This has only the repeated IDs. Sets are faster at searches than
+ # arrays.
+ dicri = set(dicri[diff])
+ reverse = defaultdict(set)
+ # Here we find a reverse mapping of real particle ID to chainID
+ for chainID, real_index in enumerate(self.densest_in_chain_real_index):
+ if real_index in dicri:
+ reverse[real_index].add(chainID)
+ del dicri, diff
+ # If the real index has len(set)>1, there are multiple chains that need
+ # to be linked
+ tolink = defaultdict(set)
+ for real in reverse:
+ if len(reverse[real]) > 1:
+ # Unf. can't slice a set, so this will have to do.
+ tolink[min(reverse[real])] = reverse[real]
+ tolink[min(reverse[real])].discard(min(reverse[real]))
+ del reverse
+ # Now we will remove the other chains from the dicts and re-assign
+ # particles to their new chainID.
+ fix_map = {}
+ for tokeep in tolink:
+ for remove in tolink[tokeep]:
+ fix_map[remove] = tokeep
+ self.densest_in_chain[remove] = -1.0
+ self.densest_in_chain_real_index[remove] = -1
+ for i, chainID in enumerate(self.chainID):
+ try:
+ new = fix_map[chainID]
+ except KeyError:
+ continue
+ self.chainID[i] = new
+ del tolink, fix_map
+ yt_counters("global chain hand-linking.")
+ yt_counters("create_global_densest_in_chain")
+
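A standalone sketch of the repeated-index detection above: sort, compare adjacent entries, and keep the first of each equal run:

    import numpy as na

    ids = na.array([4, 9, 2, 9, 7, 2], dtype='int64')
    s = ids[ids.argsort()]                  # [2, 2, 4, 7, 9, 9]
    diff = na.ediff1d(s) == 0               # True where s[i] == s[i+1]
    diff = na.concatenate((diff, [False]))  # pad back to full length
    repeated = set(s[diff])                 # set([2, 9])
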
+ def _communicate_uphill_info(self):
+ """
+ Communicate the links to the correct neighbors from uphill_info.
+ """
+ yt_counters("communicate_uphill_info")
+ # Find out how many particles we're going to receive, and make arrays
+ # of the right size and type to store them.
+ to_recv_count = 0
+ temp_indices = dict.fromkeys(self.neighbors)
+ temp_chainIDs = dict.fromkeys(self.neighbors)
+ for opp_neighbor in self.neighbors:
+ opp_size = self.global_padded_count[opp_neighbor]
+ to_recv_count += opp_size
+ temp_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+ temp_chainIDs[opp_neighbor] = na.empty(opp_size, dtype='int64')
+ # The arrays we'll actually keep around...
+ self.recv_real_indices = na.empty(to_recv_count, dtype='int64')
+ self.recv_chainIDs = na.empty(to_recv_count, dtype='int64')
+ # Set up the receives, but don't actually use them.
+ hooks = []
+ for opp_neighbor in self.neighbors:
+ hooks.append(self._mpi_Irecv_long(temp_indices[opp_neighbor], opp_neighbor))
+ hooks.append(self._mpi_Irecv_long(temp_chainIDs[opp_neighbor], opp_neighbor))
+ # Make sure all the receive buffers are set before continuing.
+ self._barrier()
+ # Send padded particles to our neighbors.
+ for neighbor in self.neighbors:
+ hooks.append(self._mpi_Isend_long(self.uphill_real_indices, neighbor))
+ hooks.append(self._mpi_Isend_long(self.uphill_chainIDs, neighbor))
+ # Now actually use the data once it's good to go.
+ self._mpi_Request_Waitall(hooks)
+ self.__max_memory()
+ so_far = 0
+ for opp_neighbor in self.neighbors:
+ opp_size = self.global_padded_count[opp_neighbor]
+ # Only save the part of the buffer that we want to the right places
+ # in the full listing.
+ self.recv_real_indices[so_far:(so_far + opp_size)] = \
+ temp_indices[opp_neighbor][0:opp_size]
+ self.recv_chainIDs[so_far:(so_far + opp_size)] = \
+ temp_chainIDs[opp_neighbor][0:opp_size]
+ so_far += opp_size
+ # Clean up.
+ del temp_indices, temp_chainIDs, hooks
+ yt_counters("communicate_uphill_info")
+
+ def _recurse_global_chain_links(self, chainID_translate_map_global, chainID, seen):
+ """
+ Step up the global chain links until we reach the self-densest chain,
+ very similarly to the recursion of particles to densest nearest
+ neighbors.
+ """
+ new_chainID = chainID_translate_map_global[chainID]
+ if new_chainID == chainID:
+ return int(chainID)
+ elif new_chainID in seen:
+ # Bad things are about to happen if this condition is met! The
+ # padding probably needs to be increased (using the safety factor).
+ mylog.info('seen %s' % str(seen))
+ for s in seen:
+ mylog.info('%d %d' % (s, chainID_translate_map_global[s]))
+ else:
+ seen.append(new_chainID)
+ return self._recurse_global_chain_links(chainID_translate_map_global, new_chainID, seen)
+
+ def _connect_chains_across_tasks(self):
+ """
+ Using the uphill links of chains, chains are linked across boundaries.
+ Chains that link to a remote chain are recorded, and a complete dict
+ of chain connections is created, globally. Then chainIDs are
+ reassigned recursively, assigning the ID of the most dense chainID
+ to every chain that links to it.
+ """
+ yt_counters("connect_chains_across_tasks")
+ # Remote (lower dens) chain -> local (higher) chain.
+ chainID_translate_map_local = na.arange(self.nchains)
+ # Build the stuff to send.
+ self.uphill_real_indices = na.concatenate((
+ self.index, self.index_pad))[self.padded_particles]
+ self.uphill_chainIDs = self.chainID[self.padded_particles]
+ del self.padded_particles
+ # Now we make a global dict of how many particles each task is
+ # sending.
+ self.global_padded_count = {self.mine:self.uphill_chainIDs.size}
+ self.global_padded_count = self._mpi_joindict(self.global_padded_count)
+ # Send/receive 'em.
+ self._communicate_uphill_info()
+ self.__max_memory()
+ # Fix the IDs to localIDs.
+ for i,real_index in enumerate(self.recv_real_indices):
+ try:
+ localID = self.rev_index[real_index]
+ # We don't want to update the chainIDs of my padded particles.
+ # Remember we are supposed to be only considering particles
+ # in my *real* region, that are padded in my neighbor.
+ if not self.is_inside[localID]:
+ # Make it negative so we can skip it below.
+ self.recv_real_indices[i] = -1
+ continue
+ self.recv_real_indices[i] = localID
+ except KeyError:
+ # This is probably a particle we don't even own, so we want
+ # to ignore it.
+ self.recv_real_indices[i] = -1
+ continue
+ # Now relate the local chainIDs to the received chainIDs
+ for i,localID in enumerate(self.recv_real_indices):
+ # If the 'new' chainID is different than what we already have,
+ # we need to record it, but we skip particles that were assigned
+ # -1 above. Also, since links are supposed to go only uphill,
+ # ensure that they are being recorded that way below.
+ if localID != -1 and self.chainID[localID] != -1:
+ if self.recv_chainIDs[i] != self.chainID[localID] and \
+ self.densest_in_chain[self.chainID[localID]] >= self.densest_in_chain[self.recv_chainIDs[i]] and \
+ self.densest_in_chain[self.chainID[localID]] != -1.0 and \
+ self.densest_in_chain[self.recv_chainIDs[i]] != -1.0:
+ chainID_translate_map_local[self.recv_chainIDs[i]] = \
+ self.chainID[localID]
+ self.__max_memory()
+ # In chainID_translate_map_local, chains may
+ # 'point' to only one chain, but a chain may have many that point to
+ # it. Therefore each key (a chain) in this mapping is unique, but the items
+ # the keys point to are not necessarily unique.
+ chainID_translate_map_global = \
+ self._mpi_minimum_array_long(chainID_translate_map_local)
+ # Loop over chains, smallest to largest density, recursively until
+ # we reach a self-assigned chain. Then we assign that final chainID to
+ # the *current* one only.
+ seen = []
+ for key, density in enumerate(self.densest_in_chain):
+ if density == -1: continue # Skip 'deleted' chains
+ seen = []
+ seen.append(key)
+ new_chainID = \
+ self._recurse_global_chain_links(chainID_translate_map_global, key, seen)
+ chainID_translate_map_global[key] = new_chainID
+ # At the same time, remove chains from densest_in_chain that have
+ # been reassigned.
+ if key != new_chainID:
+ self.densest_in_chain[key] = -1.0
+ self.densest_in_chain_real_index[key] = -1
+ # Also fix nchains to keep up.
+ self.nchains -= 1
+ self.__max_memory()
+ # Convert local particles to their new chainID
+ for i in xrange(int(self.size)):
+ old_chainID = self.chainID[i]
+ if old_chainID == -1: continue
+ new_chainID = chainID_translate_map_global[old_chainID]
+ self.chainID[i] = new_chainID
+ del chainID_translate_map_local, self.recv_chainIDs
+ del self.recv_real_indices, self.uphill_real_indices, self.uphill_chainIDs
+ del seen, chainID_translate_map_global
+ yt_counters("connect_chains_across_tasks")
+
+ def _communicate_annulus_chainIDs(self):
+ """
+ Transmit all of our chainID-ed particles that are within self.padding
+ of the boundaries to all of our neighbors. Tests show that this is
+ faster than trying to figure out which of the neighbors to send the data
+ to.
+ """
+ yt_counters("communicate_annulus_chainIDs")
+ # Pick the particles in the annulus.
+ real_indices = na.concatenate(
+ (self.index, self.index_pad))[self.is_inside_annulus]
+ chainIDs = self.chainID[self.is_inside_annulus]
+ # We're done with this here.
+ del self.is_inside_annulus
+ # Eliminate un-assigned particles.
+ select = (chainIDs != -1)
+ real_indices = real_indices[select]
+ chainIDs = chainIDs[select]
+ send_count = real_indices.size
+ # Here we distribute the counts globally. Unfortunately, it's a barrier(),
+ # but there are so many places in this that need to be globally synched
+ # that it's not worth the effort right now to make this one spot better.
+ global_annulus_count = {self.mine:send_count}
+ global_annulus_count = self._mpi_joindict(global_annulus_count)
+ # Set up the receiving arrays.
+ recv_real_indices = dict.fromkeys(self.neighbors)
+ recv_chainIDs = dict.fromkeys(self.neighbors)
+ for opp_neighbor in self.neighbors:
+ opp_size = global_annulus_count[opp_neighbor]
+ recv_real_indices[opp_neighbor] = na.empty(opp_size, dtype='int64')
+ recv_chainIDs[opp_neighbor] = na.empty(opp_size, dtype='int64')
+ # Set up the receiving hooks.
+ hooks = []
+ for opp_neighbor in self.neighbors:
+ hooks.append(self._mpi_Irecv_long(recv_real_indices[opp_neighbor], opp_neighbor))
+ hooks.append(self._mpi_Irecv_long(recv_chainIDs[opp_neighbor], opp_neighbor))
+ # Make sure the recv buffers are set before continuing.
+ self._barrier()
+ # Now we send them.
+ for neighbor in self.neighbors:
+ hooks.append(self._mpi_Isend_long(real_indices, neighbor))
+ hooks.append(self._mpi_Isend_long(chainIDs, neighbor))
+ # Now we use them when they're nice and ripe.
+ self._mpi_Request_Waitall(hooks)
+ self.__max_memory()
+ for opp_neighbor in self.neighbors:
+ opp_size = global_annulus_count[opp_neighbor]
+ # Update our local data.
+ for i,real_index in enumerate(recv_real_indices[opp_neighbor][0:opp_size]):
+ try:
+ localID = self.rev_index[real_index]
+ # We are only updating our particles that are in our
+ # padding, so to be rigorous we will skip particles
+ # that are in our real region.
+ if self.is_inside[localID]:
+ continue
+ self.chainID[localID] = recv_chainIDs[opp_neighbor][i]
+ except KeyError:
+ # We ignore data that's not for us.
+ continue
+ # Clean up.
+ del recv_real_indices, recv_chainIDs, real_indices, chainIDs, select
+ del hooks
+ # We're done with this here.
+ del self.rev_index
+ yt_counters("communicate_annulus_chainIDs")
+
+
+ def _connect_chains(self):
+ """
+ With the set of particle chains, build a mapping of connected chainIDs
+ by finding the highest boundary density neighbor for each chain. Some
+ chains will have no neighbors!
+ """
+ yt_counters("connect_chains")
+ self.chain_densest_n = {} # chainID -> {chainIDs->boundary dens}
+ # Plus 2 because we're looking for that neighbor, but only keeping
+ # nMerge + 1 neighbor tags, skipping ourselves.
+ fKD.dist = na.empty(self.nMerge+2, dtype='float64')
+ fKD.tags = na.empty(self.nMerge+2, dtype='int64')
+ # We can change this here to make the searches faster.
+ fKD.nn = self.nMerge+2
+ for i in xrange(int(self.size)):
+ # Don't consider this particle if it's not part of a chain.
+ if self.chainID[i] < 0: continue
+ # If this particle is in the padding, don't make a connection.
+ if not self.is_inside[i]: continue
+ # Make sure that we should search this particle again.
+ if not self.search_again[i]: continue
+ # Find this particle's chain max_dens.
+ part_max_dens = self.densest_in_chain[self.chainID[i]]
+ # Make sure we're skipping deleted chains.
+ if part_max_dens == -1.0: continue
+ # Loop over nMerge closest nearest neighbors.
+ fKD.qv = fKD.pos[:, i]
+ find_nn_nearest_neighbors()
+ NNtags = fKD.tags[:] - 1
+ for j in xrange(int(self.nMerge+1)):
+ thisNN = NNtags[j+1] # Don't consider ourselves at NNtags[0]
+ thisNN_chainID = self.chainID[thisNN]
+ # If our neighbor is in the same chain, move on.
+ if self.chainID[i] == thisNN_chainID: continue
+ # Everything immediately below is for
+ # neighboring particles with a chainID.
+ if thisNN_chainID >= 0:
+ # Find thisNN's chain's max_dens.
+ thisNN_max_dens = self.densest_in_chain[thisNN_chainID]
+ if thisNN_max_dens == -1.0: continue
+ # Calculate the two groups' boundary density.
+ boundary_density = (self.density[thisNN] + self.density[i]) / 2.
+ # Find out who's denser.
+ if thisNN_max_dens >= part_max_dens:
+ higher_chain = thisNN_chainID
+ lower_chain = self.chainID[i]
+ else:
+ higher_chain = self.chainID[i]
+ lower_chain = thisNN_chainID
+ # Make sure that the higher density chain has an entry.
+ try:
+ test = self.chain_densest_n[int(higher_chain)]
+ except KeyError:
+ self.chain_densest_n[int(higher_chain)] = {}
+ # See if this boundary density is higher than
+ # previously recorded for this pair of chains.
+ # Links only go one direction.
+ try:
+ old = self.chain_densest_n[int(higher_chain)][int(lower_chain)]
+ if old < boundary_density:
+ # make this the new densest boundary between this pair
+ self.chain_densest_n[int(higher_chain)][int(lower_chain)] = \
+ boundary_density
+ except KeyError:
+ # we haven't seen this pairing before, record this as the
+ # new densest boundary between chains
+ self.chain_densest_n[int(higher_chain)][int(lower_chain)] = \
+ boundary_density
+ else:
+ continue
+ self.__max_memory()
+ yt_counters("connect_chains")
+
+ def _make_global_chain_densest_n(self):
+ """
+ We want to record the maximum boundary density between all chains on
+ all tasks.
+ """
+ yt_counters("make_global_chain_densest_n")
+ (self.top_keys, self.bot_keys, self.vals) = \
+ self._mpi_maxdict_dict(self.chain_densest_n)
+ self.__max_memory()
+ del self.chain_densest_n
+ yt_counters("make_global_chain_densest_n")
+
+ def _build_groups(self):
+ """
+ With the collection of possible chain links, build groups.
+ """
+ yt_counters("build_groups")
+ # We need to find out which pairs of self.top_keys, self.bot_keys are
+ # both < self.peakthresh, and create arrays that will store this
+ # relationship.
+ both = na.bitwise_and((self.densest_in_chain[self.top_keys] < self.peakthresh),
+ (self.densest_in_chain[self.bot_keys] < self.peakthresh))
+ g_high = self.top_keys[both]
+ g_low = self.bot_keys[both]
+ g_dens = self.vals[both]
+ del both
+ self.reverse_map = na.ones(self.densest_in_chain.size) * -1
+ densestbound = na.ones(self.densest_in_chain.size) * -1.0
+ for i, gl in enumerate(g_low):
+ if g_dens[i] > densestbound[gl]:
+ densestbound[gl] = g_dens[i]
+ groupID = 0
+ # First assign a group to all chains with max_dens above peakthresh.
+ # The initial groupIDs will be assigned with descending peak density.
+ # This guarantees that the group with the smaller groupID is the
+ # higher chain, as in chain_high below.
+ for chainID,density in enumerate(self.densest_in_chain):
+ if density == -1.0: continue
+ if self.densest_in_chain[chainID] >= self.peakthresh:
+ self.reverse_map[chainID] = groupID
+ groupID += 1
+ group_equivalancy_map = na.empty(groupID, dtype='object')
+ for i in xrange(groupID):
+ group_equivalancy_map[i] = set([])
+ # Loop over all of the chain linkages.
+ for i,chain_high in enumerate(self.top_keys):
+ chain_low = self.bot_keys[i]
+ dens = self.vals[i]
+ max_dens_high = self.densest_in_chain[chain_high]
+ max_dens_low = self.densest_in_chain[chain_low]
+ if max_dens_high == -1.0 or max_dens_low == -1.0: continue
+ # If neither are peak density groups, mark them for later
+ # consideration.
+ if max_dens_high < self.peakthresh and \
+ max_dens_low < self.peakthresh:
+ # This step is now done vectorized above, with the g_dens
+ # stuff.
+ continue
+ # If both are peak density groups, and have a boundary density
+ # that is high enough, make them into a group, otherwise
+ # move onto another linkage.
+ if max_dens_high >= self.peakthresh and \
+ max_dens_low >= self.peakthresh:
+ if dens < self.saddlethresh:
+ continue
+ else:
+ group_high = self.reverse_map[chain_high]
+ group_low = self.reverse_map[chain_low]
+ if group_high == -1 or group_low == -1: continue
+ # Both are already identified as groups, so we need
+ # to re-assign the less dense group to the denser
+ # groupID.
+ if group_low != group_high:
+ group_equivalancy_map[group_low].add(group_high)
+ group_equivalancy_map[group_high].add(group_low)
+ continue
+ # Else, one is above peakthresh, the other below
+ # find out if this is the densest boundary seen so far for
+ # the lower chain.
+ group_high = self.reverse_map[chain_high]
+ if group_high == -1: continue
+ if dens >= densestbound[chain_low]:
+ densestbound[chain_low] = dens
+ self.reverse_map[chain_low] = group_high
+ self.__max_memory()
+ del self.top_keys, self.bot_keys, self.vals
+ # Now refactor group_equivalancy_map back into reverse_map. The group
+ # mapping may be more than one link long, so we need to do it
+ # recursively. The best way to think about this is a field full of
+ # rabbit holes. The holes are connected at nexuses at the surface.
+ # Each groupID (key) in group_equivalancy_map represents a hole, and
+ # the values are the nexuses its tunnels lead to. The tunnels are two-way,
+ # and when you go through one, you block the passage through that
+ # tunnel in that direction, so you don't repeat yourself later. You can
+ # go back through that tunnel, but your search ends there because all
+ # the other tunnels have been closed at the old nexus. In this fashion your search
+ # spreads out like the water shooting out of the ground in 'Caddy
+ # Shack.'
+ Set_list = []
+ # We only want the holes that are modulo mine.
+ keys = na.arange(groupID, dtype='int64')
+ if self._mpi_get_size() is None:
+ size = 1
+ else:
+ size = self._mpi_get_size()
+ select = (keys % size == self.mine)
+ groupIDs = keys[select]
+ mine_groupIDs = set([]) # Records only ones modulo mine.
+ not_mine_groupIDs = set([]) # All the others.
+ # Declare these to prevent Errors when they're del-ed below, in case
+ # this task doesn't create them in the loop, for whatever reason.
+ current_sets, new_mine, new_other = [], [], []
+ new_set, final_set, to_add_set, liter = set([]), set([]), set([]), set([])
+ for groupID in groupIDs:
+ if groupID in mine_groupIDs:
+ continue
+ mine_groupIDs.add(groupID)
+ current_sets = []
+ new_set = group_equivalancy_map[groupID]
+ final_set = new_set.copy()
+ while len(new_set) > 0:
+ to_add_set = set([])
+ liter = new_set.difference(mine_groupIDs).difference(not_mine_groupIDs)
+ new_mine, new_other = [], []
+ for link_gID in liter:
+ to_add_set.update(group_equivalancy_map[link_gID])
+ if link_gID % size == self.mine:
+ new_mine.append(link_gID)
+ else:
+ new_other.append(link_gID)
+ mine_groupIDs.update(new_mine)
+ not_mine_groupIDs.update(new_other)
+ final_set.update(to_add_set)
+ new_set = to_add_set
+ # Make sure it's not empty
+ final_set.add(groupID)
+ Set_list.append(final_set)
+ self.__max_memory()
+ del group_equivalancy_map, final_set, keys, select, groupIDs, current_sets
+ del mine_groupIDs, not_mine_groupIDs, new_set, to_add_set, liter
+ # Convert this list of sets into a look-up table
+ lookup = na.ones(self.densest_in_chain.size, dtype='int64') * (self.densest_in_chain.size + 2)
+ for i,item in enumerate(Set_list):
+ item_min = min(item)
+ for groupID in item:
+ lookup[groupID] = item_min
+ self.__max_memory()
+ del Set_list
+ # To bring it all together, find the minimum values at each entry
+ # globally.
+ lookup = self._mpi_minimum_array_long(lookup)
+ # Now apply this to reverse_map
+ for chainID,groupID in enumerate(self.reverse_map):
+ if groupID == -1:
+ continue
+ if lookup[groupID] != (self.densest_in_chain.size + 2):
+ self.reverse_map[chainID] = lookup[groupID]
+ del lookup
+ """
+ Now the fringe chains are connected to the proper group
+ (>peakthresh) with the largest boundary. But we want to look
+ through the boundaries between fringe groups to propagate this
+ along. Connections are only as good as their smallest boundary
+ """
+ changes = 1
+ while changes:
+ changes = 0
+ for j,dens in enumerate(g_dens):
+ chain_high = g_high[j]
+ chain_low = g_low[j]
+                # If both this boundary's density and the densestbound of
+                # the higher chain exceed the lower chain's densestbound,
+                # replace the lower chain's entry. We also don't want to
+                # link to un-assigned neighbors, and we can skip neighbors
+                # we're already assigned to.
+ if dens >= densestbound[chain_low] and \
+ densestbound[chain_high] > densestbound[chain_low] and \
+ self.reverse_map[chain_high] != -1 and \
+ self.reverse_map[chain_low] != self.reverse_map[chain_high]:
+ changes += 1
+ if dens < densestbound[chain_high]:
+ densestbound[chain_low] = dens
+ else:
+ densestbound[chain_low] = densestbound[chain_high]
+ self.reverse_map[chain_low] = self.reverse_map[chain_high]
+ self.__max_memory()
+ del g_high, g_low, g_dens, densestbound
+ # Now we have to find the unique groupIDs, since they may have been
+ # merged.
+ temp = list(set(self.reverse_map))
+ # Remove -1 from the list.
+ try:
+ temp.pop(temp.index(-1))
+ except ValueError:
+ # There are no groups, probably.
+ pass
+ # Make a secondary map to make the IDs consecutive.
+ values = na.arange(len(temp))
+ secondary_map = dict(itertools.izip(temp, values))
+ del values
+ # Update reverse_map
+        for chain, groupID in enumerate(self.reverse_map):
+            # Don't attempt to fix non-assigned chains.
+            if groupID == -1: continue
+            self.reverse_map[chain] = secondary_map[groupID]
+ group_count = len(temp)
+ del secondary_map, temp
+ yt_counters("build_groups")
+ self.__max_memory()
+ return group_count
+
+ def _translate_groupIDs(self, group_count):
+ """
+ Using the maps, convert the particle chainIDs into their locally-final
+ groupIDs.
+ """
+ yt_counters("translate_groupIDs")
+ self.I_own = set([])
+ for i in xrange(int(self.size)):
+ # Don't translate non-affiliated particles.
+ if self.chainID[i] == -1: continue
+ # We want to remove the group tag from padded particles,
+ # so when we return it to HaloFinding, there is no duplication.
+ if self.is_inside[i]:
+ self.chainID[i] = self.reverse_map[self.chainID[i]]
+ self.I_own.add(self.chainID[i])
+ else:
+ self.chainID[i] = -1
+ del self.is_inside
+ # Create a densest_in_group, analogous to densest_in_chain.
+ keys = na.arange(group_count)
+ vals = na.zeros(group_count)
+ self.densest_in_group = dict(itertools.izip(keys,vals))
+ self.densest_in_group_real_index = self.densest_in_group.copy()
+ del keys, vals
+ for chainID,max_dens in enumerate(self.densest_in_chain):
+ if max_dens == -1.0: continue
+ groupID = self.reverse_map[chainID]
+ if groupID == -1: continue
+ if self.densest_in_group[groupID] < max_dens:
+ self.densest_in_group[groupID] = max_dens
+ self.densest_in_group_real_index[groupID] = self.densest_in_chain_real_index[chainID]
+ del self.densest_in_chain, self.densest_in_chain_real_index
+ yt_counters("translate_groupIDs")
+
+ def _precompute_group_info(self):
+        """
+        For all groups, compute the various global properties, except bulk
+        velocity, to save time in HaloFinding.py (fewer barriers!).
+        """
+        yt_counters("Precomp.")
+ select = (self.chainID != -1)
+ calc = len(na.where(select == True)[0])
+ loc = na.empty((calc, 3), dtype='float64')
+ loc[:, 0] = na.concatenate((self.xpos, self.xpos_pad))[select]
+ loc[:, 1] = na.concatenate((self.ypos, self.ypos_pad))[select]
+ loc[:, 2] = na.concatenate((self.zpos, self.zpos_pad))[select]
+ self.__max_memory()
+ del self.xpos_pad, self.ypos_pad, self.zpos_pad
+ subchain = self.chainID[select]
+ # First we need to find the maximum density point for all groups.
+ # I think this will be faster than several vector operations that need
+ # to pull the entire chainID array out of memory several times.
+ yt_counters("max dens point")
+ max_dens_point = na.zeros((self.group_count,4),dtype='float64')
+ for i,part in enumerate(na.arange(self.size)[select]):
+ groupID = self.chainID[part]
+ if part < self.real_size:
+ real_index = self.index[part]
+ else:
+ real_index = self.index_pad[part - self.real_size]
+ if real_index == self.densest_in_group_real_index[groupID]:
+ max_dens_point[groupID] = na.array([self.density[part], \
+ loc[i, 0], loc[i, 1], loc[i, 2]])
+ del self.index, self.index_pad, self.densest_in_group_real_index
+ # Now we broadcast this, effectively, with an allsum. Even though
+ # some groups are on multiple tasks, there is only one densest_in_chain
+ # and only that task contributed above.
+ self.max_dens_point = self._mpi_Allsum_double(max_dens_point)
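+        # (A raw mpi4py equivalent of the wrapper above would be, roughly:
+        #   from mpi4py import MPI
+        #   summed = na.empty_like(max_dens_point)
+        #   MPI.COMM_WORLD.Allreduce(max_dens_point, summed, op=MPI.SUM)
+        # Since every task holds zeros except in the rows it filled, the
+        # elementwise sum doubles as a broadcast.)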
+ yt_counters("max dens point")
+ # Now CoM.
+ yt_counters("CoM")
+ CoM_M = na.zeros((self.group_count,3),dtype='float64')
+ Tot_M = na.zeros(self.group_count, dtype='float64')
+ if calc:
+ c_vec = self.max_dens_point[:,1:4][subchain] - na.array([0.5,0.5,0.5])
+ size = na.bincount(self.chainID[select]).astype('int64')
+ else:
+ # This task has no particles in groups!
+ size = na.zeros(self.group_count, dtype='int64')
+ # In case this task doesn't have all the groups, add trailing zeros.
+ if size.size != self.group_count:
+ size = na.concatenate((size, na.zeros(self.group_count - size.size, dtype='int64')))
+ if calc:
+ cc = loc - c_vec
+ cc = cc - na.floor(cc)
+ ms = na.concatenate((self.mass, self.mass_pad))[select]
+ # Most of the time, the masses will be all the same, and we can try
+ # to save some effort.
+ ms_u = na.unique(ms)
+ if ms_u.size == 1:
+ single = True
+ Tot_M = size.astype('float64') * ms_u
+ else:
+ single = False
+ del ms_u
+ cc[:,0] = cc[:,0] * ms
+ cc[:,1] = cc[:,1] * ms
+ cc[:,2] = cc[:,2] * ms
+ sort = subchain.argsort()
+ cc = cc[sort]
+ sort_subchain = subchain[sort]
+ uniq_subchain = na.unique(sort_subchain)
+ diff_subchain = na.ediff1d(sort_subchain)
+ marks = (diff_subchain > 0)
+            marks = na.arange(calc - 1)[marks] + 1
+ marks = na.concatenate(([0], marks, [calc]))
+ for i, u in enumerate(uniq_subchain):
+ CoM_M[u] = na.sum(cc[marks[i]:marks[i+1]], axis=0)
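+            # The sort-and-mark trick above, in isolation: it computes
+            # per-group sums without a Python loop over particles. E.g.
+            #   ids = na.array([1, 0, 1, 2, 0])
+            #   vals = na.array([10., 1., 20., 5., 2.])
+            #   s = ids.argsort(); sids = ids[s]; svals = vals[s]
+            #   m = na.arange(ids.size - 1)[na.ediff1d(sids) > 0] + 1
+            #   m = na.concatenate(([0], m, [ids.size]))
+            #   [svals[m[i]:m[i+1]].sum() for i in range(m.size - 1)]
+            #   # -> [3.0, 30.0, 5.0] for IDs 0, 1, 2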
+ if not single:
+ for i,groupID in enumerate(subchain):
+ Tot_M[groupID] += ms[i]
+ del cc, ms
+ for groupID in xrange(int(self.group_count)):
+ # Don't divide by zero.
+ if groupID in self.I_own:
+ CoM_M[groupID] /= Tot_M[groupID]
+ CoM_M[groupID] += self.max_dens_point[groupID,1:4] - na.array([0.5,0.5,0.5])
+ CoM_M[groupID] *= Tot_M[groupID]
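+        # The recentering above handles groups that straddle the periodic
+        # boundary. In miniature, for equal-mass particles at x = 0.98 and
+        # x = 0.02 with the density peak at 0.98:
+        #   cc = na.array([0.98, 0.02]) - (0.98 - 0.5)  # -> [0.50, -0.46]
+        #   cc = cc - na.floor(cc)                      # -> [0.50, 0.54]
+        #   cc.mean() + (0.98 - 0.5)                    # -> 1.0, i.e. 0.0 mod 1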
+ # Now we find their global values
+ self.group_sizes = self._mpi_Allsum_long(size)
+ CoM_M = self._mpi_Allsum_double(CoM_M)
+ self.Tot_M = self._mpi_Allsum_double(Tot_M)
+ self.CoM = na.empty((self.group_count,3), dtype='float64')
+ for groupID in xrange(int(self.group_count)):
+ self.CoM[groupID] = CoM_M[groupID] / self.Tot_M[groupID]
+ yt_counters("CoM")
+ self.__max_memory()
+ # Now we find the maximum radius for all groups.
+ yt_counters("max radius")
+ max_radius = na.zeros(self.group_count, dtype='float64')
+ if calc:
+ com = self.CoM[subchain]
+ rad = na.abs(com - loc)
+ dist = (na.minimum(rad, self.period - rad)**2.).sum(axis=1)
+ dist = dist[sort]
+ for i, u in enumerate(uniq_subchain):
+ max_radius[u] = na.max(dist[marks[i]:marks[i+1]])
+ # Find the maximum across all tasks.
+ mylog.info('Fraction of particles in this region in groups: %f' % (float(calc)/self.size))
+ self.max_radius = self._mpi_double_array_max(max_radius)
+ self.max_radius = na.sqrt(self.max_radius)
+ yt_counters("max radius")
+ yt_counters("Precomp.")
+ self.__max_memory()
+ if calc:
+ del loc, subchain, CoM_M, Tot_M, c_vec, max_radius, select
+ del sort_subchain, uniq_subchain, diff_subchain, marks, dist, sort
+ del rad, com
+
+ def _chain_hop(self):
+ self._global_padding('first')
+ self._global_bounds_neighbors()
+ self._global_padding('second')
+ self._is_inside('first')
+ mylog.info('Distributing padded particles...')
+ self._communicate_padding_data()
+ mylog.info('Building kd tree for %d particles...' % \
+ self.size)
+ self._init_kd_tree()
+ # Mark particles in as being in/out of the domain.
+ self._is_inside('second')
+ # Loop over the particles to find NN for each.
+ mylog.info('Finding nearest neighbors/density...')
+ yt_counters("chainHOP_tags_dens")
+ chainHOP_tags_dens()
+ yt_counters("chainHOP_tags_dens")
+ self.density = fKD.dens
+ # Now each particle has NNtags, and a local self density.
+ # Let's find densest NN
+ mylog.info('Finding densest nearest neighbors...')
+ self._densestNN()
+ # Build the chain of links.
+ mylog.info('Building particle chains...')
+ chain_count = self._build_chains()
+ # This array tracks whether or not relationships for this particle
+ # need to be examined twice, in preconnect_chains and in connect_chains
+ self.search_again = na.ones(self.size, dtype='bool')
+ chain_count = self._preconnect_chains(chain_count)
+        mylog.info('Globally assigning chainIDs...')
+ self._globally_assign_chainIDs(chain_count)
+ mylog.info('Globally finding densest in chains...')
+ self._create_global_densest_in_chain()
+ mylog.info('Building chain connections across tasks...')
+ self._is_inside('third')
+ self._connect_chains_across_tasks()
+ mylog.info('Communicating connected chains...')
+ self._communicate_annulus_chainIDs()
+ mylog.info('Connecting %d chains into groups...' % self.nchains)
+ self._connect_chains()
+        del fKD.dens, fKD.mass
+ del fKD.pos, fKD.chunk_tags
+ free_tree() # Frees the kdtree object.
+ del self.densestNN
+ mylog.info('Communicating group links globally...')
+ self._make_global_chain_densest_n()
+ mylog.info('Building final groups...')
+ group_count = self._build_groups()
+ self.group_count = group_count
+ mylog.info('Remapping particles to final groups...')
+ self._translate_groupIDs(group_count)
+ mylog.info('Precomputing info for %d groups...' % group_count)
+ self._precompute_group_info()
+ mylog.info("All done! Max Memory = %d MB" % self.max_mem)
+ # We need to fix chainID and density because HaloFinding is expecting
+ # an array only as long as the real data.
+ self.chainID = self.chainID[:self.real_size]
+ self.density = self.density[:self.real_size]
+        # We'll make this a global object, which can be used to write a text
+        # file giving the names of the hdf5 files that hold the particles
+        # for each halo.
+ self.mine, self.I_own = self._mpi_info_dict(self.I_own)
+ self.halo_taskmap = defaultdict(set)
+ for taskID in self.I_own:
+ for groupID in self.I_own[taskID]:
+ self.halo_taskmap[groupID].add(taskID)
+ del self.I_own
+ del self.mass, self.xpos, self.ypos, self.zpos
+
+    def __add_to_array(self, arr, key, value, dtype):
+        """
+        In an effort to replace the functionality of a dict with an array,
+        to save memory, this function adds items to an array. If the array
+        is not long enough, it is extended, with the new entries filled
+        with 'bad' (-1) values. The caller must re-bind the returned array.
+        """
+        try:
+            arr[key] = value
+        except IndexError:
+            arr = na.concatenate((arr, na.ones(10000, dtype=dtype)*-1))
+            arr[key] = value
+        return arr
+
+    def __clean_up_array(self, arr):
+        """
+        Strip the 'bad' (-1) padding entries from an array grown by
+        __add_to_array.
+        """
+        good = (arr != -1)
+        return arr[good]
+
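+    # A sketch of how the two helpers above are meant to be used together
+    # (hypothetical values; the caller must re-bind the result, since the
+    # resize reallocates the array):
+    #   arr = na.ones(5, dtype='int64') * -1
+    #   arr = self.__add_to_array(arr, 9000, 7, 'int64')  # grows in 10000-slot steps
+    #   arr = self.__clean_up_array(arr)                  # -> array([7])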
+ def __max_memory(self):
+ my_mem = get_memory_usage()
+ self.max_mem = max(my_mem, self.max_mem)
\ No newline at end of file
Added: trunk/yt/lagos/parallelHOP/run.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/parallelHOP/run.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,24 @@
+from yt.config import ytcfg
+
+ytcfg["yt","time_functions"] = "True"
+
+from yt.mods import *
+
+yt_counters("Full Time")
+
+yt_counters("yt Hierarchy")
+pf = load('data0005')
+
+pf.h # Touching .h forces the hierarchy to be built, which is what we time.
+yt_counters("yt Hierarchy")
+
+h = yt.lagos.HaloFinding.parallelHF(pf, threshold=160.0, safety=2.5,
+    dm_only=False, resize=True, fancy_padding=True, rearrange=True)
+
+yt_counters("Writing Data")
+h.write_out('dist-chain.out')
+h.write_particle_lists_txt("chain")
+h.write_particle_lists("chain")
+yt_counters("Writing Data")
+
+yt_counters("Full Time")
Modified: trunk/yt/lagos/setup.py
==============================================================================
--- trunk/yt/lagos/setup.py (original)
+++ trunk/yt/lagos/setup.py Sun Dec 20 12:42:09 2009
@@ -22,6 +22,7 @@
config.add_extension("PointCombine", "yt/lagos/PointCombine.c", libraries=["m"])
config.add_subpackage("hop")
config.add_subpackage("fof")
+ config.add_subpackage("parallelHOP")
H5dir = check_for_hdf5()
if H5dir is not None:
include_dirs=[os.path.join(H5dir,"include")]
Added: trunk/yt/math_utils.py
==============================================================================
--- (empty file)
+++ trunk/yt/math_utils.py Sun Dec 20 12:42:09 2009
@@ -0,0 +1,46 @@
+"""
+Commonly used mathematical functions.
+
+Author: Matthew Turk <matthewturk at gmail.com>
+Affiliation: UCSD Physics/CASS
+Author: Stephen Skory <stephenskory at yahoo.com>
+Affiliation: UCSD Physics/CASS
+Homepage: http://yt.enzotools.org/
+License:
+ Copyright (C) 2008-2009 Matthew Turk. All Rights Reserved.
+
+ This file is part of yt.
+
+ yt is free software; you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+
+import numpy as na
+import math
+
+def periodic_dist(a, b, period):
+ """
+    Find the Euclidean periodic distance between two points.
+ *a* and *b* are array-like vectors, and *period* is a float or
+ array-like value for the periodic size of the volume.
+ """
+ a = na.array(a)
+ b = na.array(b)
+    if a.size != b.size: raise RuntimeError("Arrays must be the same shape.")
+    c = na.empty((2, a.size), dtype="float64")
+    c[0,:] = na.abs(a - b)
+    c[1,:] = period - c[0,:]
+ d = na.amin(c, axis=0)**2
+ return math.sqrt(d.sum())
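+# A quick sanity check of the wrap-around (illustrative): in a unit box,
+# points at x = 0.1 and x = 0.9 are separated by 0.2 through the periodic
+# boundary, not 0.8:
+#   periodic_dist([0.1], [0.9], 1.0)  # -> 0.2 (up to rounding)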
+
+
\ No newline at end of file