/**
This module is based on https://github.com/Mihail-K/kdtree.
*/
module kdtree;

import std.algorithm : map, sum, sort;
import std.range : iota, zip;

/**
A node of a k-dimensional tree: one stored point plus links to its two subtrees.

The split axis of a node is `depth % k`, where `depth` is the node's distance from
the root. Points in `left` have a coordinate on that axis that is not greater than
the node's own; points in `right` have one that is not smaller.
*/
struct KDNode(size_t k, T) if (k > 0) {
    private const(T[k]) state;
    private KDNode!(k, T)* left = null, right = null;

    /// Creates a leaf node holding `state`; the coordinates may also be passed as `k` separate arguments.
    this(T[k] state...) pure nothrow {
        this.state = state;
    }
}

/**
Counts the number of elements in the kd tree.

Returns: the number of nodes reachable from `node`; `0` for an empty (`null`) tree.
*/
size_t size(size_t k, T)(in KDNode!(k, T)* node) pure nothrow @nogc {
    if (node is null) {
        return 0;
    }

    return node.left.size + node.right.size + 1;
}

/**
Creates a slice of all elements in the kd tree.

The points are listed in in-order traversal order (left subtree, node, right subtree).
The result is reserved once up front and filled by appending, so collecting the
elements stays linear in the size of the tree even when the tree is degenerate.

Returns: a new array holding a copy of every stored point; empty for a `null` tree.
*/
T[k][] elements(size_t k, T)(in KDNode!(k, T)* node) pure nothrow {
    T[k][] result;
    result.reserve(node.size);
    appendInOrder(node, result);
    return result;
}

/// Appends the points of the subtree rooted at `node` to `sink` in in-order traversal order.
private void appendInOrder(size_t k, T)(in KDNode!(k, T)* node, ref T[k][] sink) pure nothrow {
    if (node is null) {
        return;
    }

    appendInOrder(node.left, sink);
    sink ~= node.state;
    appendInOrder(node.right, sink);
}

/**
Creates a new kd tree.

`kdTree!(k, T)()` returns an empty tree (`null`). `kdTree(points)` builds a balanced
tree by recursively splitting `points` around the median of coordinate `depth % k`.

Note: building reorders `points` in place; pass a copy (e.g. `points.dup`) if the
caller's ordering must be preserved.
*/
template kdTree(size_t k, T) {
    /// Returns: an empty kd tree.
    KDNode!(k, T)* kdTree() pure nothrow @nogc {
        return null;
    }

    /**
    Builds a balanced kd tree from `points`.

    Params:
        points = the points to store; reordered in place while building.
        depth = depth of the subtree being built; selects the split axis `depth % k`.
    Returns: the root of the new tree, or `null` if `points` is empty.
    */
    KDNode!(k, T)* kdTree(T[k][] points, size_t depth = 0) pure nothrow {
        if (points.length == 0) {
            return null;
        }
        if (points.length == 1) {
            return new KDNode!(k, T)(points[0]);
        }

        immutable axis = depth % k, md = points.length / 2;

        // Afterwards points[0 .. md] <= points[md] <= points[md + 1 .. $] on `axis`,
        // which is the ordering `nearest` relies on when it prunes a subtree.
        quickSelect(points, axis, md);

        auto node = new KDNode!(k, T)(points[md]);
        node.left = kdTree(points[0 .. md], depth + 1);
        node.right = kdTree(points[md + 1 .. $], depth + 1);

        return node;
    }

    /**
    Quickselect on coordinate `axis`: reorders `points` so that `points[nth]` holds
    the point it would hold if the slice were fully sorted on that coordinate, every
    point before it is not greater, and every point after it is not smaller.

    A three-way (less / equal / greater) partition is used so that runs of equal
    coordinates are handled correctly; an all-equal slice is finished in one pass.

    Originally adapted from https://rosettacode.org/wiki/K-d_tree#Faster_Alternative_Version
    */
    void quickSelect(T[k][] points, in size_t axis, in size_t nth) pure nothrow @nogc {
        import std.algorithm : swap;

        if (points.length < 2)
            return;
        assert(nth < points.length, "quickSelect: nth is out of range");

        // Invariant (on `axis`): everything in [0, start) <= everything in [start, end)
        // <= everything in [end, $), and nth lies in [start, end).
        size_t start = 0, end = points.length;
        while (end - start > 1) {
            immutable pivot = points[nth][axis];

            // Partition [start, end) into [start, lt) < pivot, [lt, gt) == pivot, [gt, end) > pivot.
            size_t lt = start, i = start, gt = end;
            while (i < gt) {
                if (points[i][axis] < pivot) {
                    swap(points[i], points[lt]);
                    lt++;
                    i++;
                } else if (pivot < points[i][axis]) {
                    gt--;
                    swap(points[i], points[gt]);
                } else {
                    i++;
                }
            }

            if (nth < lt)
                end = lt;
            else if (nth >= gt)
                start = gt;
            else
                return; // nth falls inside the run equal to the pivot: it is in place.
        }
    }
}

/**
Adds a new point to the kd tree.

The point descends from `root`, going left when its coordinate on the node's split
axis (`depth % k`) is smaller than the node's and right otherwise, and is stored as a
new leaf. No rebalancing is done; see `rebalance`.

Params:
    root = the tree to add to; a `null` root is replaced by a new single-node tree.
    point = the point to store.
    depth = depth of `root` within the whole tree; selects the split axis.
*/
void add(size_t k, T)(ref KDNode!(k, T)* root, in T[k] point, size_t depth = 0) pure nothrow {
    if (root is null) {
        root = new KDNode!(k, T)(point);
        return;
    }

    immutable axis = depth % k;
    if (point[axis] < root.state[axis]) {
        root.left.add(point, depth + 1);
    } else {
        root.right.add(point, depth + 1);
    }
}

/**
Rebalances the kd tree by creating a new tree with the same elements.

`root` is replaced by a freshly built balanced tree holding copies of the same
points; an empty tree stays empty.
*/
void rebalance(size_t k, T)(ref KDNode!(k, T)* root) pure nothrow {
    root = kdTree!(k, T)(root.elements);
}

/**
Finds the nearest neighbor in the kd tree using the Euclidean distance metric.

Squared distances are accumulated in `double`, so integral coordinate types neither
overflow nor (for unsigned `T`) wrap around while distances are measured.

Params:
    root = the tree to search; must not be empty.
    point = the query point.
Returns: a copy of a stored point closest to `point`.
*/
const(T[k]) nearest(size_t k, T)(in KDNode!(k, T)* root, in auto ref T[k] point) pure nothrow @nogc
in {
    assert(root !is null, "tree is empty");
}
do {
    const(T[k])* best = null;
    double bestDistance;

    // Squared Euclidean distance between `a` and `b`, computed entirely in double.
    static double distanceSq(in T[k] a, in T[k] b) {
        double total = 0;
        foreach (i; 0 .. k) {
            immutable delta = cast(double) b[i] - cast(double) a[i];
            total += delta * delta;
        }
        return total;
    }

    void search(in KDNode!(k, T)* current, in ref T[k] query, size_t depth) {
        if (current is null) {
            return;
        }

        immutable axis = depth % k;
        immutable distance = distanceSq(current.state, query);
        if (best is null || distance < bestDistance) {
            bestDistance = distance;
            best = &current.state;
        }

        // An exact match cannot be beaten; skip the remaining subtrees.
        if (!(bestDistance > 0)) {
            return;
        }

        // Signed gap to the splitting plane, in double so unsigned T cannot wrap.
        immutable axisGap = cast(double) current.state[axis] - cast(double) query[axis];

        // Visit the side of the plane that contains the query first ...
        search(axisGap > 0 ? current.left : current.right, query, depth + 1);

        // ... and the other side only if the plane is within the best distance so far.
        if (axisGap * axisGap <= bestDistance) {
            search(axisGap > 0 ? current.right : current.left, query, depth + 1);
        }
    }

    search(root, point, 0);
    return *best;
}

unittest {
    import fluent.asserts : should;

    auto root = kdTree([[0, 0], [1, 1], [1, 0], [0, 1]]);

    root.nearest([0, 0]).should.equal([0, 0]);
    root.nearest([1, 1]).should.equal([1, 1]);
    root.nearest([-3, 5]).should.equal([0, 1]);
    root.nearest([25, -4]).should.equal([1, 0]);

    root.add([25, 0]);
    root.nearest([25, -4]).should.equal([25, 0]);
}

/// Test build and nearest
unittest {
    import fluent.asserts : should;
    import std.algorithm : minElement;
    import std.numeric : euclideanDistance;
    import std.random : uniform01;

    auto points = new double[3][1000];
    foreach (i; 0 .. points.length) {
        foreach (j; 0 .. points[i].length) {
            points[i][j] = uniform01;
        }
    }

    auto root = kdTree(points);
    root.size.should.equal(points.length);

    foreach (_; 0 .. 1000) {
        double[3] point = [uniform01, uniform01, uniform01];

        root.nearest(point).should.equal(points.minElement!(a => a[0 .. $].euclideanDistance(point[0 .. $])));
    }
}

/// Test add and nearest
unittest {
    import fluent.asserts : should;
    import std.algorithm : minElement;
    import std.numeric : euclideanDistance;
    import std.random : uniform01;

    auto points = new double[3][1000];
    foreach (i; 0 .. points.length) {
        foreach (j; 0 .. points[i].length) {
            points[i][j] = uniform01;
        }
    }

    auto root = kdTree!(3, double);
    foreach (i; 0 .. points.length) {
        root.add(points[i]);
    }
    root.size.should.equal(points.length);

    foreach (_; 0 .. 1000) {
        double[3] point = [uniform01, uniform01, uniform01];

        root.nearest(point).should.equal(points.minElement!(a => a[].euclideanDistance(point[])));
    }
}

/// Test rebalance
unittest {
    import fluent.asserts : should;
    import std.algorithm : minElement;
    import std.numeric : euclideanDistance;
    import std.random : uniform01;

    auto points = new double[3][1000];
    foreach (i; 0 .. points.length) {
        foreach (j; 0 .. points[i].length) {
            points[i][j] = uniform01;
        }
    }

    auto root = kdTree!(3, double);
    foreach (i; 0 .. points.length) {
        root.add(points[i]);
    }
    root.size.should.equal(points.length);

    root.rebalance();
    root.size.should.equal(points.length);

    foreach (_; 0 .. 1000) {
        double[3] point = [uniform01, uniform01, uniform01];

        root.nearest(point).should.equal(points.minElement!(a => a[].euclideanDistance(point[])));
    }
}

/// Test nearest on empty tree
unittest {
    import fluent.asserts : should;
    import core.exception : AssertError;

    auto root = kdTree!(3, double);
    root.nearest([0.0, 0.0, 0.0]).should.throwException!AssertError.withMessage.equal("tree is empty");
}

/// Regression: duplicate coordinates on the split axis must stay partitioned around the median
unittest {
    int[1][] points = [[7], [9], [5], [8], [5]];
    auto root = kdTree!(1, int)(points);
    assert(root.size == 5);

    int[1] query = [10];
    int[1] expected = [9];
    assert(root.nearest(query) == expected);
}

/// Regression: distances are measured in double, so large integer coordinates do not overflow
unittest {
    int[2][] points = [[65536, 0], [1000, 0]];
    auto root = kdTree!(2, int)(points);

    int[2] query = [0, 0];
    int[2] expected = [1000, 0];
    assert(root.nearest(query) == expected);
}