[Yt-svn] yt-commit r1348 - trunk/yt/lagos

sskory at wrangler.dreamhost.com sskory at wrangler.dreamhost.com
Wed Jun 17 16:04:37 PDT 2009


Author: sskory
Date: Wed Jun 17 16:04:37 2009
New Revision: 1348
URL: http://yt.spacepope.org/changeset/1348

Log:
Adding a python kd tree, and nearest neighbor stuff to HaloFinding.

Added:
   trunk/yt/lagos/kd.py
Modified:
   trunk/yt/lagos/HaloFinding.py

Modified: trunk/yt/lagos/HaloFinding.py
==============================================================================
--- trunk/yt/lagos/HaloFinding.py	(original)
+++ trunk/yt/lagos/HaloFinding.py	Wed Jun 17 16:04:37 2009
@@ -32,6 +32,9 @@
 except ImportError:
     pass
 
+from kd import *
+import math
+
 class Halo(object):
     """
     A data source that returns particle information about the members of a
@@ -239,6 +242,76 @@
     def __getitem__(self, key):
         return self._groups[key]
 
+    def nearest_neighbors_3D(self, haloID, num_neighbors=7, search_radius=.2):
+        """
+        for halo *haloID*, find up to *num_neighbors* nearest neighbors in 3D
+        using the kd tree. Search over *search_radius* in code units.
+        Returns a list of the neighbors distances and ID with format
+        [distance,haloID].
+        """
+        period = self.pf['DomainRightEdge'] - self.pf['DomainLeftEdge']
+        # Initialize the dataset of points from all the haloes
+        dataset = []
+        for group in self:
+            p = Point()
+            p.data = group.center_of_mass().tolist()
+            p.haloID = group.id
+            dataset.append(p)
+        mylog.info('Building kd tree...')
+        kd = buildKdHyperRectTree(dataset[:],2*num_neighbors)
+        # make the neighbors object
+        neighbors = Neighbors()
+        neighbors.k = num_neighbors
+        neighbors.points = []
+        neighbors.minDistanceSquared = search_radius * search_radius
+        mylog.info('Finding nearest neighbors...')
+        getKNN(self[haloID].center_of_mass().tolist(), kd, neighbors,0., period.tolist())
+        # convert the data in order to return something less perverse than a
+        # Neighbors object, also root the distances
+        n_points = []
+        for n in neighbors.points:
+            n_points.append([math.sqrt(n[0]),n[1].haloID])
+        return n_points
+
+    def nearest_neighbors_2D(self, haloID, num_neighbors=7, search_radius=.2,
+        proj_dim=0):
+        """
+        for halo *haloID*, find up to *num_neighbors* nearest neighbors in 2D
+        using the kd tree. Search over *search_radius* in code units.
+        The halo positions are projected along dimension *proj_dim*.
+        Returns a list of the neighbors distances and ID with format
+        [distance,haloID].
+        """
+        # Set up a vector to multiply other vectors by to project along proj_dim
+        vec = na.array([1.,1.,1.])
+        vec[proj_dim] = 0.
+        period = self.pf['DomainRightEdge'] - self.pf['DomainLeftEdge']
+        period = period * vec
+        # Initialize the dataset of points from all the haloes
+        dataset = []
+        for group in self:
+            p = Point()
+            cm = group.center_of_mass() * vec
+            p.data = cm.tolist()
+            p.haloID = group.id
+            dataset.append(p)
+        mylog.info('Building kd tree...')
+        kd = buildKdHyperRectTree(dataset[:],2*num_neighbors)
+        # make the neighbors object
+        neighbors = Neighbors()
+        neighbors.k = num_neighbors
+        neighbors.points = []
+        neighbors.minDistanceSquared = search_radius * search_radius
+        mylog.info('Finding nearest neighbors...')
+        cm = self[haloID].center_of_mass() * vec
+        getKNN(cm.tolist(), kd, neighbors,0., period.tolist())
+        # convert the data in order to return something less perverse than a
+        # Neighbors object, also root the distances
+        n_points = []
+        for n in neighbors.points:
+            n_points.append([math.sqrt(n[0]),n[1].haloID])
+        return n_points
+
     def write_out(self, filename):
         """
         Write out standard HOP information to *filename*.

Added: trunk/yt/lagos/kd.py
==============================================================================
--- (empty file)
+++ trunk/yt/lagos/kd.py	Wed Jun 17 16:04:37 2009
@@ -0,0 +1,243 @@
+"""
+Python kD Tree
+
+Author: Michael Knight
+Affiliation: ?
+Note: Code is based on <http://sites.google.com/site/mikescoderama/Home/kd-tree-knn>,
+with periodicity added and a few other cosmetic changes. There is no contact
+infomation on that page, and this code's license is a bit uncertain.
+Author: Stephen Skory <stephenskory at yahoo.com>
+Affiliation: UCSD Physics/CASS
+Homepage: http://yt.enzotools.org/
+License:
+  Copyright (C) 2008-2009 Matthew Turk.  All Rights Reserved.
+
+  This file is part of yt.
+
+  yt is free software; you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation; either version 3 of the License, or
+  (at your option) any later version.
+
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+
+from yt.lagos import *
+
+from bisect import insort
+
+class Point:pass
+
+class Node:
+
+    def _printOut(self,depth = 0):
+        """
+        iteratively print out the kd tree nodes, depth really isn't a user-set
+        parameter.
+        """
+        for i in range(0,depth):
+            print "....",
+        print "point count =", self.pointCount, "rect =", \
+            self.hyperRect._nodeToString(), '\n' #"points =",self.points
+        if (self.leftChild is not None):
+            for i in range(0,depth):
+                print "....",
+            print "left: "
+            self.leftChild._printOut(depth+1)
+        if (self.rightChild is not None):
+            for i in range(0,depth):
+                print "....",
+            print "right: "
+            self.rightChild._printOut(depth+1)
+    
+    def _buildBoundingHyperRect(self,points):
+        self.hyperRect = HyperRect()
+        self.hyperRect._buildBoundingHyperRect(points)
+ 
+def getFastDistance(a,b,period):
+    """
+    returns the square of the distance between points a and b,
+    using periodic boundary conditions 'period'.
+    """
+    dim = len(b);
+    total = 0;
+    for i in range(0,dim):
+        delta = min(abs(a[i] - b[i]), period[i] - abs(a[i] - b[i]));
+        total = total + (delta *delta)
+    return total
+ 
+class Neighbors:
+    
+    def _addNeighbors(self,node,query,period):
+        """
+        for each point in this node, calculate the distance to the query point,
+        if it's close enough add it to the list .points using insort
+        """
+        for i in range(0,node.pointCount):
+            dist = getFastDistance(node.points[i].data,query,period)
+            
+            if (dist < self.minDistanceSquared):
+                item = [dist,node.points[i]]
+                insort(self.points,item)
+                if (len(self.points) > self.k):
+                    self.points = self.points[0:self.k]
+ 
+        if (len(self.points) == self.k):
+            self.minDistanceSquared = self.points[self.k-1][0]
+        return;
+ 
+class HyperRect:
+    def _buildBoundingHyperRect(self,points):
+        """
+        find the extremities of the hypercube.
+        """
+        self.k = len(points[0].data)
+        self.dims  = range(0,self.k) 
+        high = points[0].data[:]
+        low = points[0].data[:]
+        for i in range(0,len(points)):
+            for j in self.dims:
+                point = points[i].data[j]
+                if (high[j]  < point):
+                    high[j] = point
+                if (low[j] > point):
+                    low[j] = point
+        self.high = high
+        self.low = low
+        return 
+ 
+    def _getWidestDimension(self):
+        """
+        get the widest dimension of the points in order to find the dimension
+        to bisect.
+        """
+        widest =0
+        widestDim =-1
+        for i in self.dims:
+            width = self.high[i] -  self.low[i]
+            if (width > widest):
+                widestDim =i
+                widest = width
+        self.widest = widest;
+        self.widestDim = widestDim;    
+        return self.widestDim
+    
+    def _getWidestDimensionWidth(self): # I don't know why this is here.
+        return self.widest
+    
+    def _nodeToString(self):
+        return "high =",self.high,"low =",self.low,
+    
+    def _getMinDistance(self,query,period):
+        """
+        find the minimum distance squared to a corner of the hypercube
+        from the query point using periodicity.
+        """
+        total = 0.0
+        for i in self.dims:
+            delta = 0.0
+            min_high = min(abs(query[i] - self.high[i]), \
+                period[i] - abs(query[i] - self.high[i]))
+            min_low = min(abs(query[i] - self.low[i]), \
+                period[i] - abs(query[i] - self.low[i]))
+            delta = min(min_low,min_high)
+            total = total + (delta*delta)
+        return total;
+
+def buildKdHyperRectTree(points,rootMin=3):
+    """
+    Recursively build the kdTree, adding nodes as needed until all have no more
+    than rootMin points. The final nodes are called leafs, which contain the
+    point data.
+    """
+    if (points is None or len(points) ==0):
+        return None
+    n = Node() # make a new node
+    n._buildBoundingHyperRect(points)  # build the hyper rect for these points
+                        # this will find the top left and botom
+                        # right of all given points.
+
+    leaf = False
+    # If the size of points is small enough, this node is a leaf
+    if len(points) <= rootMin:
+        leaf = True
+    splitDim  = -1
+ 
+    if (not leaf):
+        # get the widest dimension to split n to maximize splitting affect
+        splitDim = n.hyperRect._getWidestDimension()
+        # do we have a bunch of children at the same point?
+        if (n.hyperRect._getWidestDimensionWidth() == 0.0):
+            left = True 
+    #init the node
+    n.pointCount = len(points)
+    n.points = None
+    n.leftChild = None
+    n.rightChild = None
+    n.points = None
+ 
+    if (leaf or len(points)==0):
+        n.points = points # we are a leaf so just store all points in the rect
+    else:
+        # sort by the best split dimension
+        temp = []
+        for index,p in enumerate(points):
+            insort(temp,[p.data[splitDim],index])
+        temp2 = points[:]
+        for index,t in enumerate(temp):
+            temp2[index] = points[t[1]]
+        points = temp2[:]
+        del temp, temp2
+        #points.sort(key=lambda points: points.data[splitDim])
+        median = len(points)/2     # get the median
+        # and split left for smaller values in splitDim, right for larger
+        n.leftChild = buildKdHyperRectTree(points[0:(median+1)], rootMin)
+        if (median+1 < len(points)):
+            n.rightChild = buildKdHyperRectTree(points[median+1:], rootMin)
+    return n;
+ 
+def getKNN(query,node, neighbors,distanceSquared,period):
+    """
+    Recursively walk the kd tree, limited by *distanceSquared* to the extrema of
+    the hypercubes, only finding distances to the query point in leaf nodes.
+    *neighbors* is a Neighbors object, and needs to be initialized bofore
+    calling this. *period* is a list or array of the period for each dimension
+    of the hypercube.
+    """
+    
+    # test to see if the query point is inside this node
+    # <= and >= on both ends is okay, it's better to be inclusive and it
+    # prevents problems with particles on boundaries
+    for i in node.hyperRect.dims:
+        if query[i] <= node.hyperRect.high[i] and \
+                query[i] >= node.hyperRect.low[i]:
+            inside = True
+        else:
+            inside = False
+            break
+    
+    # if this node is close enough (the distances are calculated in the previous
+    # iteration), or if the query point is inside the node, continue on
+    if (neighbors.minDistanceSquared > distanceSquared) or inside:
+        # leafs don't have children, so this tests to see if this node
+        # is a leaf, and if it is a leaf, calculate distances to the query point
+        if (node.leftChild is None):
+            # add to neighbors.points
+            neighbors._addNeighbors(node,query,period)
+        # if this node is not a leaf, find out the distance to its children,
+        # and then continue the iteration down the kd tree.
+        else:
+            distLeft = node.leftChild.hyperRect._getMinDistance(query,period)
+            distRight = node.rightChild.hyperRect._getMinDistance(query,period)
+            if (distLeft < distRight):
+                getKNN(query,node.leftChild,neighbors,distLeft,period)
+                getKNN(query,node.rightChild,neighbors,distRight,period)
+            else:
+                getKNN(query,node.rightChild,neighbors,distRight,period)
+                getKNN(query,node.leftChild,neighbors,distLeft,period)



More information about the yt-svn mailing list