A kd-tree is a binary tree in which every node stores [feature coordinates, split axis, pointer to the left branch, pointer to the right branch]. The feature coordinates are a point $(x_1, x_2, \ldots, x_n)$ in the linear space $\mathbb{R}^n$.
The split axis is an integer $r$ with $1 \leq r \leq n$, meaning that the node splits the $n$-dimensional space along the $r$-th dimension. The left and right branches of a node are themselves kd-trees and satisfy: if $y$ is a feature coordinate in the left branch, then $y_r \leq x_r$; and if $z$ is a feature coordinate in the right branch, then $z_r \geq x_r$.
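To make the node layout concrete, here is a minimal Python sketch of such a node together with a check of the ordering condition. The names `KDNode`, `points_in` and `is_valid` are illustrative assumptions, the axis is stored 0-based, and the three sample points are made up.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Point = Tuple[float, ...]


@dataclass
class KDNode:
    point: Point                       # feature coordinates (x_1, ..., x_n)
    axis: int                          # split axis r, stored 0-based in this sketch
    left: Optional["KDNode"] = None    # subtree whose points satisfy y_r <= x_r
    right: Optional["KDNode"] = None   # subtree whose points satisfy z_r >= x_r


def points_in(node):
    """Yield every feature coordinate stored in the subtree rooted at node."""
    if node is not None:
        yield node.point
        yield from points_in(node.left)
        yield from points_in(node.right)


def is_valid(node):
    """Check the ordering condition at node and, recursively, below it."""
    if node is None:
        return True
    r, x = node.axis, node.point
    left_ok = all(y[r] <= x[r] for y in points_in(node.left))
    right_ok = all(z[r] >= x[r] for z in points_in(node.right))
    return left_ok and right_ok and is_valid(node.left) and is_valid(node.right)


# Hypothetical three-point tree whose root splits on axis 0.
root = KDNode((2.0, 3.0), 0, KDNode((1.0, 5.0), 1), KDNode((4.0, 0.0), 1))
good = is_valid(root)
root.left.point = (3.0, 5.0)   # the left child now violates y_0 <= x_0
bad = is_valid(root)
```

The check is quadratic and only meant for testing small trees; a real implementation would rely on the construction to guarantee the condition.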
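Given a data set $S \subseteq \mathbb{R}^n$ and a split axis $r$, the following recursive algorithm builds a kd-tree on that data set; each call creates one node:
- If $|S| = 1$, record the single point of $S$ as the feature data of the current node and give it no left or right branch. ($|S|$ denotes the number of elements of $S$.)
- If $|S| > 1$: sort the points of $S$ by their $r$-th coordinate, store the median point in the current node together with $r$, set $r \leftarrow (r + 1) \bmod n$, and build the left branch from the points before the median and the right branch from the points after it.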
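The recursion can be sketched in Python as below: sort on the current axis, take the median, then recurse on each side with the next axis. The function name `build_kdtree`, the nested-dict representation, the 0-based axis and the six sample points are assumptions made for illustration.

```python
def build_kdtree(points, r=0):
    """Return a nested dict describing the kd-tree of points, or None if empty."""
    if not points:
        return None
    if len(points) == 1:
        # |S| = 1: a bottom node without branches
        return {"point": points[0], "axis": r, "left": None, "right": None}
    n = len(points[0])                      # dimension of the space
    ordered = sorted(points, key=lambda pt: pt[r])
    median = len(ordered) // 2              # for even sizes, the upper median
    next_r = (r + 1) % n                    # rotate the split axis
    return {
        "point": ordered[median],
        "axis": r,
        "left": build_kdtree(ordered[:median], next_r),
        "right": build_kdtree(ordered[median + 1:], next_r),
    }


# Hypothetical sample data in the plane.
sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build_kdtree(sample)
```

Sorting at every level costs O(N log N) per level; picking the median with a selection algorithm would be cheaper, but sorting keeps the sketch short.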
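The abstract definition and algorithm above are admittedly hard to digest, and an example makes things much clearer. First, randomly generate 13 points in $\mathbb{R}^2$ as our data set. The starting split axis is $r = 0$; in this example the axes are counted from 0, so $r = 0$ corresponds to the $x$-axis and $r = 1$ to the $y$-axis.
We first split along the x coordinate: we pick the point with the median x coordinate, which gives the coordinates of the root node.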
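The space is then split at that point's x coordinate: all data with an x coordinate less than 6.27 are used to build the left branch, and the points with an x coordinate greater than 6.27 are used to build the right branch.
In the next step, $r = 0 + 1 = 1 \bmod 2$ corresponds to the $y$-axis, so each of the two sides is sorted and split along the $y$-axis, and the median points are recorded in the nodes of the left and right branches. This gives the tree below, where the x at the left indicates that the nodes on that level all split along the x-axis.
The space is partitioned as follows:
In the next step, $r \equiv 1 + 1 \equiv 0 \pmod 2$, which corresponds to the $x$-axis, so we sort and split by the $x$ coordinate again, which gives:
Finally every part contains a single point, and these points are recorded in the bottom nodes. Since no unrecorded points remain, no further splits are made.
This completes the construction of the kd-tree.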
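As a quick sanity check on the shape of the example tree, the sketch below counts how many nodes each level receives when 13 points are split at the median over and over. It works on set sizes only, so no coordinates are needed. The function name `level_sizes` and the convention of taking index `m // 2` as the median are assumptions.

```python
def level_sizes(n):
    """Number of nodes on each level when n points are split at the median."""
    sizes = []
    current = [n]                  # sizes of the point sets handled on this level
    while current:
        sizes.append(len(current))  # one node is created per point set
        nxt = []
        for m in current:
            left = m // 2           # points before the median
            right = m - left - 1    # points after the median
            nxt.extend(s for s in (left, right) if s > 0)
        current = nxt
    return sizes


# Pair each level with the axis it splits on in the plane (x, y, x, ...).
summary = [(depth, "xy"[depth % 2], size)
           for depth, size in enumerate(level_sizes(13))]
# summary == [(0, 'x', 1), (1, 'y', 2), (2, 'x', 4), (3, 'y', 6)]
```

Thirteen points therefore fill four levels with 1, 2, 4 and 6 nodes, which matches the three rounds of splitting followed by a row of single-point bottom nodes in the example.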
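Given a kd-tree built on a sample set, the following algorithm finds the $k$ samples nearest to a point $p$.
Let $L$ be a list with $k$ slots that holds the nearest points found so far.
1. Search downward according to the coordinates of $p$ and the split at each node: if a node splits on $x_r = a$ and the $r$-th coordinate of $p$ is less than $a$, continue into the left branch; otherwise take the right branch.
2. On reaching a bottom node, mark it as visited. If $L$ holds fewer than $k$ points, add the feature coordinates of the current node to $L$; if $L$ is full and the distance between the current node's feature and $p$ is less than the largest distance in $L$, replace the point of $L$ farthest from $p$ with the current feature.
3. If the current node is not the topmost node of the whole tree, execute (1) below; otherwise output $L$ and the algorithm is finished.
(1) Climb up one node. If the current node (after climbing) has not been visited, mark it as visited and then execute sub-steps 1 and 2; if it has already been visited, execute (1) again.
    Sub-step 1. Update $L$ with the current node by the same rule as in step 2: add it if $L$ holds fewer than $k$ points, or let it replace the farthest point of $L$ if it is closer to $p$.
    Sub-step 2. Compute the distance between $p$ and the splitting line of the current node. If $L$ is full and this distance is not less than the largest distance in $L$, the other side of the line cannot contain a closer point, so go to step 3. Otherwise the other side may contain a closer point, so execute step 1 starting from the other branch of the current node.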
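The search can be sketched compactly in Python. This is a recursive reformulation rather than a literal transcription: the call stack plays the role of climbing back up and of the visited marks, and a heap plays the role of $L$. The names `build` and `knn`, the tuple node layout, the Euclidean distance and the sample points are all assumptions. `math.dist` requires Python 3.8 or later.

```python
import heapq
import math


def build(points, r=0):
    """Median-split construction; a node is (point, axis, left, right) or None."""
    if not points:
        return None
    ordered = sorted(points, key=lambda pt: pt[r])
    m = len(ordered) // 2
    nxt = (r + 1) % len(points[0])
    return (ordered[m], r, build(ordered[:m], nxt), build(ordered[m + 1:], nxt))


def knn(root, p, k):
    """Return the k stored points nearest to p, closest first."""
    heap = []  # max-heap via negated distances: heap[0] is the farthest kept point

    def visit(node):
        if node is None:
            return
        point, r, left, right = node
        near, far = (left, right) if p[r] < point[r] else (right, left)
        visit(near)                      # go down on p's side of the split first
        d = math.dist(p, point)          # then update L with the current node
        if len(heap) < k:
            heapq.heappush(heap, (-d, point))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, point))
        # cross the splitting line only if a closer point could lie beyond it
        if len(heap) < k or abs(p[r] - point[r]) < -heap[0][0]:
            visit(far)

    visit(root)
    return [pt for _, pt in sorted(heap, reverse=True)]


# Hypothetical data and query point.
points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
result = knn(build(points), (9, 2), 3)
```

For this data the search prunes nothing at the root, because the splitting line is closer to the query than the farthest kept point, but it does prune below the node (2, 3).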
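Consider the following example:
We start with step 1 and follow the splits down to a bottom node. We begin at the top,
compare with this node along the x-axis,
and find that the x coordinate of p is smaller. We therefore search the left branch:
This time we compare along the y-axis,
and the y value of p is smaller, so we again search the left branch:
This node has only one child branch, so no comparison is needed. This brings us to the bottom node (−4.6, −10.55).
The two-dimensional plot shows the same position.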
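The descent just described can be sketched as a small loop. The tree below is hypothetical (it is not the 13-point tree of the example, whose full coordinates are not listed in the text), and the names `make` and `descend` are illustrative.

```python
def make(point, axis, left=None, right=None):
    """Build one node as a plain dict (illustrative representation)."""
    return {"point": point, "axis": axis, "left": left, "right": right}


def descend(node, p):
    """Follow the splits from node down to a bottom node; return the path taken."""
    path = [node["point"]]
    while node["left"] or node["right"]:
        if node["left"] is None:            # single child: nothing to compare
            node = node["right"]
        elif node["right"] is None:
            node = node["left"]
        elif p[node["axis"]] < node["point"][node["axis"]]:
            node = node["left"]
        else:
            node = node["right"]
        path.append(node["point"])
    return path


# Hypothetical tree: the root splits on x, its children on y, and so on.
tree = make((7, 2), 0,
            make((5, 4), 1, make((2, 3), 0, right=make((3, 1), 1))),
            make((9, 6), 1))
path = descend(tree, (1, 2))
```

Nodes with a single child are passed through without a comparison, as in the walkthrough.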
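Now we execute step 2. We mark the current node as visited and record $L = [(-4.6, -10.55)]$. Visited nodes are shown crossed out on the binary tree.
Then comes step 3: this is not the topmost node, so we execute (1) and climb. The node above is (−6.88, −5.4).
Sub-step 1: we have recorded only one point so far, fewer than k = 3, so the current node is recorded as well, giving L = [(−4.6, −10.55), (−6.88, −5.4)]. Sub-step 2: the left branch of the current node is empty, so it is skipped and we return to step 3. Step 3 takes a look, sees that this is not the top, and hands over to (1), which climbs one more node.
Sub-step 1 says that there are still fewer than three points, so the current point is recorded too, giving L = [(−4.6, −10.55), (−6.88, −5.4), (1.24, −2.86)]. The current node is, of course, now marked as visited.
Sub-step 2 finds that the current node has another branch. The distances from p to the three points in L are 6.62, 5.89 and 3.10, whereas the distance from p to the splitting line of the current node is only 2.14, which is less than the largest distance in L:
So there may be a closer point on the other side of the splitting line. We therefore run step 1 from the start on the other branch of the current node. We are now at the red line: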
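The decision made here is a single comparison. The sketch below replays it with the distances quoted above; the function name `other_side_may_be_closer` is an assumption.

```python
def other_side_may_be_closer(split_distance, distances_in_L, k):
    """True when the far side of the splitting line still has to be searched."""
    if len(distances_in_L) < k:
        return True                      # L is not full yet: every point is welcome
    return split_distance < max(distances_in_L)


distances_in_L = [6.62, 5.89, 3.10]      # distances from p quoted in the text
decision = other_side_may_be_closer(2.14, distances_in_L, k=3)
# 2.14 < 6.62, so the other branch must be searched
```

Had the splitting line been farther from p than 6.62, the whole branch could have been skipped.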
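We compare the x coordinate of p with this node:
The x coordinate of p is larger, so we explore the right branch (1.75, 12.26), find that it is already a bottom node, and start step 2.
The computed distance between (1.75, 12.26) and p is 7.48, which is greater than the distances between p and the points in L, so we do not add it to the record.
Step 3 then determines that this is not the top node and calls (1): climb.
Sub-step 1 works out that this node's distance to p is 4.91, less than 6.62, the largest distance between p and the points in L.
We therefore replace (−4.6, −10.55), the point of L farthest from p, with this new node.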
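The bookkeeping on L up to this point can be replayed with a few lines of Python. The distances are the ones quoted in the walkthrough. The function name `offer` is an assumption, and the coordinates (−2.96, −2.5) for the new node are inferred from the final list given at the end of the example.

```python
def offer(L, k, point, dist):
    """Apply the update rule to L, a list of (distance, point) pairs.

    Returns True if the point was stored, False if it was rejected.
    """
    if len(L) < k:
        L.append((dist, point))
        return True
    farthest = max(range(len(L)), key=lambda i: L[i][0])
    if dist < L[farthest][0]:
        L[farthest] = (dist, point)      # replace the point farthest from p
        return True
    return False


L = []
visits = [((-4.6, -10.55), 6.62), ((-6.88, -5.4), 5.89), ((1.24, -2.86), 3.10),
          ((1.75, 12.26), 7.48),         # rejected: farther than everything in L
          ((-2.96, -2.5), 4.91)]         # replaces (-4.6, -10.55)
history = [offer(L, 3, pt, d) for pt, d in visits]
kept = sorted(pt for _, pt in L)
```

A linear scan for the farthest entry is fine for small k; a max-heap would do the same job in O(log k) per update.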
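Then sub-step 2 comes around again, and we check the distance between p and the splitting line of the current node.
This distance is less than even the smallest distance between p and the points in L, and hence less than the largest, so we have to run step 1 on the other branch of the current node. That branch holds only a single point, so we go straight to step 2.
Computing the distance shows that this point is farther from p than the points in L, so no replacement is made.
Step 3 finds that this is not the top, so it calls (1). We climb up,
find that this node has already been visited, so (1) runs again,
and (1) climbs once more.
And we have reached the top. So are we done? Of course not: it is not step 3's turn yet. It is sub-step 1's turn now.
Computing and comparing shows that the top node is farther from p than the points in L, so L is not updated.
Then comes sub-step 2: the distance between p and the splitting line turns out to be larger as well.
So the other branch does not need to be checked either.
Then step 3 determines that the current node is the top, so the computation is complete. The three samples nearest to p are L = [(−6.88, −5.4), (1.24, −2.86), (−2.96, −2.5)].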
C implementation
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <float.h>
#include <math.h>

#include "kdtree.h"

static inline int is_leaf(struct kdnode *node)
{
    return node->left == node->right;
}

static inline void swap(long *a, long *b)
{
    long tmp = *a;
    *a = *b;
    *b = tmp;
}

static inline double square(double d)
{
    return d * d;
}

static inline double distance(double *c1, double *c2, int dim)
{
    double distance = 0;
    while (dim-- > 0) {
        distance += square(*c1++ - *c2++);
    }
    return distance;
}

static inline double knn_max(struct kdtree *tree)
{
    return tree->knn_list_head.prev->distance;
}

static inline double D(struct kdtree *tree, long index, int r)
{
    return tree->coord_table[index][r];
}

static inline int kdnode_passed(struct kdtree *tree, struct kdnode *node)
{
    return node != NULL ? tree->coord_passed[node->coord_index] : 1;
}

static inline int knn_search_on(struct kdtree *tree, int k, double value, double target)
{
    return tree->knn_num < k || square(target - value) < knn_max(tree);
}

static inline void coord_index_reset(struct kdtree *tree)
{
    long i;
    for (i = 0; i < tree->capacity; i++) {
        tree->coord_indexes[i] = i;
    }
}

static inline void coord_table_reset(struct kdtree *tree)
{
    long i;
    for (i = 0; i < tree->capacity; i++) {
        tree->coord_table[i] = tree->coords + i * tree->dim;
    }
}

static inline void coord_deleted_reset(struct kdtree *tree)
{
    memset(tree->coord_deleted, 0, tree->capacity);
}

static inline void coord_passed_reset(struct kdtree *tree)
{
    memset(tree->coord_passed, 0, tree->capacity);
}

static void coord_dump_all(struct kdtree *tree)
{
    long i, j;
    for (i = 0; i < tree->count; i++) {
        long index = tree->coord_indexes[i];
        double *coord = tree->coord_table[index];
        printf("(");
        for (j = 0; j < tree->dim; j++) {
            if (j != tree->dim - 1) {
                printf("%.2f,", coord[j]);
            } else {
                printf("%.2f)\n", coord[j]);
            }
        }
    }
}

static void coord_dump_by_indexes(struct kdtree *tree, long low, long high, int r)
{
    long i;
    printf("r=%d:", r);
    for (i = 0; i <= high; i++) {
        if (i < low) {
            printf("%8s", " ");
        } else {
            long index = tree->coord_indexes[i];
            printf("%8.2f", tree->coord_table[index][r]);
        }
    }
    printf("\n");
}

static void bubble_sort(struct kdtree *tree, long low, long high, int r)
{
    long i, flag = high + 1;
    long *indexes = tree->coord_indexes;
    while (flag > 0) {
        long len = flag;
        flag = 0;
        for (i = low + 1; i < len; i++) {
            if (D(tree, indexes[i], r) < D(tree, indexes[i - 1], r)) {
                swap(indexes + i - 1, indexes + i);
                flag = i;
            }
        }
    }
}

static void insert_sort(struct kdtree *tree, long low, long high, int r)
{
    long i, j;
    long *indexes = tree->coord_indexes;
    for (i = low + 1; i <= high; i++) {
        long tmp_idx = indexes[i];
        double tmp_value = D(tree, indexes[i], r);
        j = i - 1;
        for (; j >= low && D(tree, indexes[j], r) > tmp_value; j--) {
            indexes[j + 1] = indexes[j];
        }
        indexes[j + 1] = tmp_idx;
    }
}

static void quicksort(struct kdtree *tree, long low, long high, int r)
{
    if (high - low <= 32) {
        insert_sort(tree, low, high, r);
        //bubble_sort(tree, low, high, r);
        return;
    }

    long *indexes = tree->coord_indexes;
    /* median of 3 */
    long mid = low + (high - low) / 2;
    if (D(tree, indexes[low], r) > D(tree, indexes[mid], r)) {
        swap(indexes + low, indexes + mid);
    }
    if (D(tree, indexes[low], r) > D(tree, indexes[high], r)) {
        swap(indexes + low, indexes + high);
    }
    if (D(tree, indexes[high], r) > D(tree, indexes[mid], r)) {
        swap(indexes + high, indexes + mid);
    }

    /* D(indexes[low]) <= D(indexes[high]) <= D(indexes[mid]) */
    double pivot = D(tree, indexes[high], r);

    /* 3-way partition
     * +---------+-----------+---------+-------------+---------+
     * |  pivot  |  <=pivot  |    ?    |   >=pivot   |  pivot  |
     * +---------+-----------+---------+-------------+---------+
     * low       lt          i         j             gt     high
     */
    long i = low - 1;
    long lt = i;
    long j = high;
    long gt = j;
    for (; ;) {
        while (D(tree, indexes[++i], r) < pivot) { }
        while (D(tree, indexes[--j], r) > pivot && j > low) { }
        if (i >= j) break;
        swap(indexes + i, indexes + j);
        if (D(tree, indexes[i], r) == pivot) swap(&indexes[++lt], &indexes[i]);
        if (D(tree, indexes[j], r) == pivot) swap(&indexes[--gt], &indexes[j]);
    }
    /* i == j or j + 1 == i */
    swap(indexes + i, indexes + high);

    /* Move equal elements to the middle of array */
    long x, y;
    for (x = low, j = i - 1; x <= lt && j > lt; x++, j--) swap(indexes + x, indexes + j);
    for (y = high, i = i + 1; y >= gt && i < gt; y--, i++) swap(indexes + y, indexes + i);

    quicksort(tree, low, j - lt + x - 1, r);
    quicksort(tree, i + y - gt, high, r);
}

static struct kdnode *kdnode_alloc(double *coord, long index, int r)
{
    struct kdnode *node = malloc(sizeof(*node));
    if (node != NULL) {
        memset(node, 0, sizeof(*node));
        node->coord = coord;
        node->coord_index = index;
        node->r = r;
    }
    return node;
}

static void kdnode_free(struct kdnode *node)
{
    free(node);
}

static int coord_cmp(double *c1, double *c2, int dim)
{
    int i;
    double ret;
    for (i = 0; i < dim; i++) {
        ret = *c1++ - *c2++;
        if (fabs(ret) >= DBL_EPSILON) {
            return ret > 0 ? 1 : -1;
        }
    }

    if (fabs(ret) < DBL_EPSILON) {
        return 0;
    } else {
        return ret > 0 ? 1 : -1;
    }
}

static void knn_list_add(struct kdtree *tree, struct kdnode *node, double distance)
{
    if (node == NULL) return;

    struct knn_list *head = &tree->knn_list_head;
    struct knn_list *p = head->prev;
    if (tree->knn_num == 1) {
        if (p->distance > distance) {
            p = p->prev;
        }
    } else {
        while (p != head && p->distance > distance) {
            p = p->prev;
        }
    }

    if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
        struct knn_list *log = malloc(sizeof(*log));
        if (log != NULL) {
            log->node = node;
            log->distance = distance;
            log->prev = p;
            log->next = p->next;
            p->next->prev = log;
            p->next = log;
            tree->knn_num++;
        }
    }
}

static void knn_list_adjust(struct kdtree *tree, struct kdnode *node, double distance)
{
    if (node == NULL) return;

    struct knn_list *head = &tree->knn_list_head;
    struct knn_list *p = head->prev;
    if (tree->knn_num == 1) {
        if (p->distance > distance) {
            p = p->prev;
        }
    } else {
        while (p != head && p->distance > distance) {
            p = p->prev;
        }
    }

    if (p == head || coord_cmp(p->node->coord, node->coord, tree->dim)) {
        struct knn_list *log = head->prev;
        /* Replace the original max one */
        log->node = node;
        log->distance = distance;
        /* Remove from the max position */
        head->prev = log->prev;
        log->prev->next = head;
        /* insert as a new one */
        log->prev = p;
        log->next = p->next;
        p->next->prev = log;
        p->next = log;
    }
}

static void knn_list_clear(struct kdtree *tree)
{
    struct knn_list *head = &tree->knn_list_head;
    struct knn_list *p = head->next;
    while (p != head) {
        struct knn_list *prev = p;
        p = p->next;
        free(prev);
    }
    tree->knn_num = 0;
}

static void resize(struct kdtree *tree)
{
    tree->capacity *= 2;
    tree->coords = realloc(tree->coords, tree->dim * sizeof(double) * tree->capacity);
    tree->coord_table = realloc(tree->coord_table, sizeof(double *) * tree->capacity);
    tree->coord_indexes = realloc(tree->coord_indexes, sizeof(long) * tree->capacity);
    tree->coord_deleted = realloc(tree->coord_deleted, sizeof(char) * tree->capacity);
    tree->coord_passed = realloc(tree->coord_passed, sizeof(char) * tree->capacity);
    coord_table_reset(tree);
    coord_index_reset(tree);
    coord_deleted_reset(tree);
    coord_passed_reset(tree);
}

static void kdnode_dump(struct kdnode *node, int dim)
{
    int i;
    if (node->coord != NULL) {
        printf("(");
        for (i = 0; i < dim; i++) {
            if (i != dim - 1) {
                printf("%.2f,", node->coord[i]);
            } else {
                printf("%.2f)\n", node->coord[i]);
            }
        }
    } else {
        printf("(none)\n");
    }
}

void kdtree_insert(struct kdtree *tree, double *coord)
{
    if (tree->count + 1 > tree->capacity) {
        resize(tree);
    }
    memcpy(tree->coord_table[tree->count++], coord, tree->dim * sizeof(double));
}

static void knn_pickup(struct kdtree *tree, struct kdnode *node, double *target, int k)
{
    double dist = distance(node->coord, target, tree->dim);
    if (tree->knn_num < k) {
        knn_list_add(tree, node, dist);
    } else {
        if (dist < knn_max(tree)) {
            knn_list_adjust(tree, node, dist);
        } else if (fabs(dist - knn_max(tree)) < DBL_EPSILON) {
            knn_list_add(tree, node, dist);
        }
    }
}

static void kdtree_search_recursive(struct kdtree *tree, struct kdnode *node,
                                    double *target, int k, int *pickup)
{
    if (node == NULL || kdnode_passed(tree, node)) {
        return;
    }

    int r = node->r;
    if (!knn_search_on(tree, k, node->coord[r], target[r])) {
        return;
    }

    if (*pickup) {
        tree->coord_passed[node->coord_index] = 1;
        knn_pickup(tree, node, target, k);
        kdtree_search_recursive(tree, node->left, target, k, pickup);
        kdtree_search_recursive(tree, node->right, target, k, pickup);
    } else {
        if (is_leaf(node)) {
            *pickup = 1;
        } else {
            if (target[r] <= node->coord[r]) {
                kdtree_search_recursive(tree, node->left, target, k, pickup);
                kdtree_search_recursive(tree, node->right, target, k, pickup);
            } else {
                kdtree_search_recursive(tree, node->right, target, k, pickup);
                kdtree_search_recursive(tree, node->left, target, k, pickup);
            }
        }
        /* back track and pick up */
        if (*pickup) {
            tree->coord_passed[node->coord_index] = 1;
            knn_pickup(tree, node, target, k);
        }
    }
}

void kdtree_knn_search(struct kdtree *tree, double *target, int k)
{
    if (k > 0) {
        int pickup = 0;
        kdtree_search_recursive(tree, tree->root, target, k, &pickup);
    }
}

void kdtree_delete(struct kdtree *tree, double *coord)
{
    int r = 0;
    struct kdnode *node = tree->root;
    struct kdnode *parent = node;

    while (node != NULL) {
        if (node->coord == NULL) {
            if (parent->right->coord == NULL) {
                break;
            } else {
                node = parent->right;
                continue;
            }
        }

        if (coord[r] < node->coord[r]) {
            parent = node;
            node = node->left;
        } else if (coord[r] > node->coord[r]) {
            parent = node;
            node = node->right;
        } else {
            int ret = coord_cmp(coord, node->coord, tree->dim);
            if (ret < 0) {
                parent = node;
                node = node->left;
            } else if (ret > 0) {
                parent = node;
                node = node->right;
            } else {
                node->coord = NULL;
                break;
            }
        }
        r = (r + 1) % tree->dim;
    }
}

static void kdnode_build(struct kdtree *tree, struct kdnode **nptr, int r, long low, long high)
{
    if (low == high) {
        long index = tree->coord_indexes[low];
        *nptr = kdnode_alloc(tree->coord_table[index], index, r);
    } else if (low < high) {
        /* Sort and fetch the median to build a balanced BST */
        quicksort(tree, low, high, r);
        long median = low + (high - low) / 2;
        long median_index = tree->coord_indexes[median];
        struct kdnode *node = *nptr = kdnode_alloc(tree->coord_table[median_index], median_index, r);
        r = (r + 1) % tree->dim;
        kdnode_build(tree, &node->left, r, low, median - 1);
        kdnode_build(tree, &node->right, r, median + 1, high);
    }
}

static void kdtree_build(struct kdtree *tree)
{
    kdnode_build(tree, &tree->root, 0, 0, tree->count - 1);
}

void kdtree_rebuild(struct kdtree *tree)
{
    long i, j;
    size_t size_of_coord = tree->dim * sizeof(double);
    for (i = 0, j = 0; j < tree->count; i++, j++) {
        while (j < tree->count && tree->coord_deleted[j]) {
            j++;
        }
        if (i != j && j < tree->count) {
            memcpy(tree->coord_table[i], tree->coord_table[j], size_of_coord);
            tree->coord_deleted[i] = 0;
        }
    }
    tree->count = i;
    coord_index_reset(tree);
    kdtree_build(tree);
}

struct kdtree *kdtree_init(int dim)
{
    struct kdtree *tree = malloc(sizeof(*tree));
    if (tree != NULL) {
        tree->root = NULL;
        tree->dim = dim;
        tree->count = 0;
        tree->capacity = 65536;
        tree->knn_list_head.next = &tree->knn_list_head;
        tree->knn_list_head.prev = &tree->knn_list_head;
        tree->knn_list_head.node = NULL;
        tree->knn_list_head.distance = 0;
        tree->knn_num = 0;
        tree->coords = malloc(dim * sizeof(double) * tree->capacity);
        tree->coord_table = malloc(sizeof(double *) * tree->capacity);
        tree->coord_indexes = malloc(sizeof(long) * tree->capacity);
        tree->coord_deleted = malloc(sizeof(char) * tree->capacity);
        tree->coord_passed = malloc(sizeof(char) * tree->capacity);
        coord_index_reset(tree);
        coord_table_reset(tree);
        coord_deleted_reset(tree);
        coord_passed_reset(tree);
    }
    return tree;
}

static void kdnode_destroy(struct kdnode *node)
{
    if (node == NULL) return;
    kdnode_destroy(node->left);
    kdnode_destroy(node->right);
    kdnode_free(node);
}

void kdtree_destroy(struct kdtree *tree)
{
    kdnode_destroy(tree->root);
    knn_list_clear(tree);
    free(tree->coords);
    free(tree->coord_table);
    free(tree->coord_indexes);
    free(tree->coord_deleted);
    free(tree->coord_passed);
    free(tree);
}

#define _KDTREE_DEBUG

#ifdef _KDTREE_DEBUG
struct kdnode_backlog {
    struct kdnode *node;
    int next_sub_idx;
};

void kdtree_dump(struct kdtree *tree)
{
    int level = 0;
    struct kdnode *node = tree->root;
    struct kdnode_backlog nbl, *p_nbl = NULL;
    struct kdnode_backlog nbl_stack[KDTREE_MAX_LEVEL];
    struct kdnode_backlog *top = nbl_stack;

    for (; ;) {
        if (node != NULL) {
            /* Fetch the pop-up backlogged node's sub-id.
             * If not backlogged, fetch the first sub-id. */
            int sub_idx = p_nbl != NULL ? p_nbl->next_sub_idx : KDTREE_RIGHT_INDEX;

            /* Backlog should be left in next loop */
            p_nbl = NULL;

            /* Backlog the node */
            if (is_leaf(node) || sub_idx == KDTREE_LEFT_INDEX) {
                top->node = NULL;
                top->next_sub_idx = KDTREE_RIGHT_INDEX;
            } else {
                top->node = node;
                top->next_sub_idx = KDTREE_LEFT_INDEX;
            }
            top++;
            level++;

            /* Draw lines as long as sub_idx is the first one */
            if (sub_idx == KDTREE_RIGHT_INDEX) {
                int i;
                for (i = 1; i < level; i++) {
                    if (i == level - 1) {
                        printf("%-8s", "+-------");
                    } else {
                        if (nbl_stack[i - 1].node != NULL) {
                            printf("%-8s", "|");
                        } else {
                            printf("%-8s", " ");
                        }
                    }
                }
                kdnode_dump(node, tree->dim);
            }

            /* Move down according to sub_idx */
            node = sub_idx == KDTREE_LEFT_INDEX ? node->left : node->right;
        } else {
            p_nbl = top == nbl_stack ? NULL : --top;
            if (p_nbl == NULL) {
                /* End of traversal */
                break;
            }
            node = p_nbl->node;
            level--;
        }
    }
}
#endif
Python implementation
import operator


class kdtree(object):

    # Build the kd-tree.
    # point_list is a list of pairs: pair[0] is a tuple of features, pair[1] is the class label.
    def __init__(self, point_list, depth=0, root=None):
        if len(point_list) > 0:
            # Cycle through the coordinate axes according to the tree depth
            k = len(point_list[0][0])
            axis = depth % k

            # Pick the median and split there
            point_list.sort(key=lambda x: x[0][axis])
            median = len(point_list) // 2

            # Record the split axis and the parent node
            self.axis = axis
            self.root = root
            self.size = len(point_list)

            # Create the node
            self.node = point_list[median]
            # Recursively build the left and right branches
            if len(point_list[:median]) > 0:
                self.left = kdtree(point_list[:median], depth + 1, self)
            else:
                self.left = None
            if len(point_list[median + 1:]) > 0:
                self.right = kdtree(point_list[median + 1:], depth + 1, self)
            else:
                self.right = None
        else:
            return None

    # Add a point (a (features, label) pair) to the tree
    def insert(self, point):
        self.size += 1

        # Decide between left and right, and recursively attach the point at a leaf
        if point[0][self.axis] < self.node[0][self.axis]:
            if self.left is not None:
                self.left.insert(point)
            else:
                self.left = kdtree([point], self.axis + 1, self)
        else:
            if self.right is not None:
                self.right.insert(point)
            else:
                self.right = kdtree([point], self.axis + 1, self)

    # Given a point (a bare feature tuple), follow the splits down to a leaf
    def find_leaf(self, point):
        if self.left is None and self.right is None:
            return self
        elif self.left is None:
            return self.right.find_leaf(point)
        elif self.right is None:
            return self.left.find_leaf(point)
        elif point[self.axis] < self.node[0][self.axis]:
            return self.left.find_leaf(point)
        else:
            return self.right.find_leaf(point)

    # Find the k nearest points; typical cost is about O(D log N),
    # where D is the dimension and N the size of the tree.
    # Takes a point, k and a distance function (squared L_2 by default).
    def knearest(self, point, k=1, dist=lambda x, y: sum(map(lambda u, v: (u - v) ** 2, x, y))):
        # Go down to the bottom leaf
        leaf = self.find_leaf(point)
        # Climb back up from the leaf
        return leaf.k_down_up(point, k, dist, result=[], stop=self, visited=None)

    # Climb from the bottom up; stop is where to end, visited is where we came from
    def k_down_up(self, point, k, dist, result=None, stop=None, visited=None):
        if result is None:
            result = []

        # Largest distance recorded so far
        if result == []:
            max_dist = 0
        else:
            max_dist = max([x[1] for x in result])

        other_result = []

        # Distance from the query to the splitting line, measured with the same
        # dist function (the default is squared, so a raw coordinate difference
        # would not be comparable with max_dist).
        projection = list(point)
        projection[self.axis] = self.node[0][self.axis]
        split_dist = dist(point, projection)

        # If the splitting line is closer than the current largest distance, or there
        # are not enough points yet, search the other subtree starting from its root
        if (self.left == visited and split_dist < max_dist and self.right is not None) \
                or (len(result) < k and self.left == visited and self.right is not None):
            other_result = self.right.knearest(point, k, dist)

        if (self.right == visited and split_dist < max_dist and self.left is not None) \
                or (len(result) < k and self.right == visited and self.left is not None):
            other_result = self.left.knearest(point, k, dist)

        # Pool the points found and keep the k closest
        result.append((self.node, dist(point, self.node[0])))
        result = sorted(result + other_result, key=lambda pair: pair[1])[:k]

        # Return the result on reaching the stop node
        if self == stop:
            return result
        # Otherwise keep climbing with the current result
        else:
            return self.root.k_down_up(point, k, dist, result, stop, self)

    # Takes features, a label, k and a distance function.
    # Returns the probability that the point belongs to that label.
    def kNN_prob(self, point, label, k, dist=lambda x, y: sum(map(lambda u, v: (u - v) ** 2, x, y))):
        nearests = self.knearest(point, k, dist)
        return float(len([pair for pair in nearests if pair[0][1] == label])) / float(len(nearests))

    # Takes features, k and a distance function.
    # Returns the most probable label for the point and the corresponding probability.
    def kNN(self, point, k, dist=lambda x, y: sum(map(lambda u, v: (u - v) ** 2, x, y))):
        nearests = self.knearest(point, k, dist)

        statistics = {}
        for data in nearests:
            label = data[0][1]
            if label not in statistics:
                statistics[label] = 1
            else:
                statistics[label] += 1

        max_label = max(statistics.items(), key=operator.itemgetter(1))[0]
        return max_label, float(statistics[max_label]) / float(len(nearests))
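The voting step at the end of the class can also be written with `collections.Counter`. The sketch below is self-contained and takes a list shaped like the return value of `knearest`, that is, `((features, label), distance)` pairs. The function name `vote` and the sample neighbours are assumptions.

```python
from collections import Counter


def vote(neighbours):
    """Return (most common label, its share) among the given neighbours.

    neighbours is a list of ((features, label), distance) pairs.
    """
    counts = Counter(pair[0][1] for pair in neighbours)
    label, votes = counts.most_common(1)[0]
    return label, votes / len(neighbours)


# Hypothetical neighbours: two of class "a", one of class "b".
neighbours = [(((1.0, 2.0), "a"), 0.5),
              (((0.0, 1.0), "b"), 0.8),
              (((2.0, 2.0), "a"), 1.1)]
winner = vote(neighbours)
```

Every neighbour counts equally here; weighting votes by inverse distance is a common variant but is not what the class above does.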
Reference: JoinQuant 量化课堂 (JoinQuant Quant Classroom)