查找到该点属于的区域之后回溯
import heapq import numpy as np from sklearn.preprocessing import StandardScaler class Node(): # KD 树节点 def __init__(self): self.father = None self.left = None self.right = None self.feature = None self.split = None @property def brother(self): """ 获取兄弟节点 """ if self.father is None: ret = None else: if self.father.left is self: ret = self.father.right else: ret = self.father.left return ret def __str__(self): return "feature: %s, split: %s" % (str(self.feature), str(self.split)) class KDTree(): #KD树 def __init__(self): self.root = Node() self.scaler = None def build_tree(self,X,y): """ 根据给定的数据集构建KD树 """ #标准化X self.scaler = StandardScaler().fit(X) X = self.scaler.transform(X) nd = self.root # 当前需要确定的节点 idxs = range(len(X)) # 当前点需要分开的区域包含的数据集下标 # BFS构建KD树 que = [(nd,idxs)] # 队列节点里是当前搜到的点和他包含的区域 while que: nd, idxs = que.pop(0) # 弹出队头 n = len(idxs) # 如果是叶节点,没啥能分了就返回 if(n == 1): nd.split = (X[idxs[0]],y[idxs[0]]) continue #不是叶节点 # (1)选择特征 if(nd.father == None): nd.feature = 0 else: nd.feature = (nd.father.feature+1)%(np.shape(X)[1]) # (2)根据特征选出中位数,获取他的下标 k = n//2 col = map(lambda i:(i,X[i][nd.feature]),idxs) # 把序列号与特征抽出来 sorted_idxs = map(lambda x:x[0],sorted(col,key = lambda x:x[1])) #col按照特征值排序,并返回排序后的下标数组 median_idx = list(sorted_idxs)[k] #拿出来中位数对应下标 nd.split = (X[median_idx],y[median_idx]) # (3)根据中位数将点分到左右儿子上 idxs_left = [] idxs_right = [] split_val = X[median_idx][nd.feature] for idx in idxs: xi = X[idx][nd.feature] if idx == median_idx: continue # 就是你让我改了一下午??? if xi < split_val: idxs_left.append(idx) else: idxs_right.append(idx) #(4) 如果左右儿子还能分,将他们加到队列中 if idxs_left != []: nd.left = Node() nd.left.father = nd que.append((nd.left,idxs_left)) if idxs_right != []: nd.right = Node() nd.right.father = nd que.append((nd.right,idxs_right)) def dfs(self,Xi,nd): """ 从nd开始dfs直到叶节点,返回叶节点(可能的最近点) """ while nd.right or nd.left: if nd.right is None: nd = nd.left elif nd.right is None: nd = nd.left else: if Xi[nd.feature] <= nd.split[0][nd.feature]: nd = nd.left else: nd = nd.right return nd def n_n_search(self,Xi,k=1): """ 返回与Xi最邻近的K个元素 """ # 标准化 Xi = self.scaler.transform([Xi]) Xi = Xi[0] # 新建最小堆 h = [] #(0) 从根DFS到叶子节点找到第一个可能的最近点,初始化最优解和搜索队列 nd_cur= self.dfs(Xi,self.root) que = [(self.root, nd_cur)] # 向上搜索 while que: nd_root, nd_cur = que.pop(0) while 1: dist = np.linalg.norm(nd_cur.split[0]-Xi)**2 # 当前节点到Xi的欧氏距离,更新最优解和判断相交都要用 # (1) 如果比堆顶更优,更新堆 if len(h) < k: heapq.heappush(h,(-dist,nd_cur.split)) else: tmp = heapq.heappop(h) if tmp[0] < -dist: heapq.heappush(h,(-dist,nd_cur.split)) else: heapq.heappush(h,tmp) # (2) 如果是根节点,继续搜索下一个可能的最近点 if nd_cur is nd_root: break # (3) 如果不是根节点,检查兄弟节点区域是否相交,相交的话DFS兄弟节点,并将新的可能的最近点加到队列中,然后接着向上搜索 nd_bro = nd_cur.brother if nd_bro is not None: dist_hyper = (Xi[nd_bro.father.feature]-nd_bro.split[0][nd_bro.father.feature]) **2 #到超平面的距离 #就是你让我改了一下午??? if dist > dist_hyper: _nd_best = self.dfs(Xi,nd_bro) que.append((nd_bro,_nd_best)) nd_cur = nd_cur.father return h X = [[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]] y = [1,2,3,4,5,6] kdtree = KDTree() kdtree.build_tree(X,y) test = list(kdtree.n_n_search([3,6],3)) test = list(map(lambda x:(-x[0],x[1][1]),test)) print(test)