1 /**
2 	This module is based on https://github.com/Mihail-K/kdtree.
3 */
4 module kdtree;
5 
6 import std.algorithm : map, sum, sort;
7 import std.range : iota, zip;
8 
/**
	One node of a k-dimensional tree, holding a single point with `k`
	coordinates of type `T`. The child links are filled in by the tree
	building/insertion functions of this module.
*/
struct KDNode(size_t k, T) if (k > 0) {
	// The stored point; const, so it cannot change after construction.
	private const(T[k]) state;

	// Child subtrees; pointers default to null, so a fresh node is a leaf.
	private typeof(this)* left, right;

	/**
		Creates a leaf node for the given point. Accepts either a `T[k]` array
		or the `k` coordinates as separate arguments.
	*/
	this(T[k] state...) pure nothrow {
		this.state = state;
	}
}
17 
/**
	Counts the number of points stored in the kd tree rooted at `node`.
	An empty tree (`null`) has size 0.
*/
size_t size(size_t k, T)(in KDNode!(k, T)* node) pure nothrow @nogc {
	// Each non-null node contributes itself plus both of its subtrees.
	return node is null ? 0 : 1 + size(node.left) + size(node.right);
}
28 
/**
	Creates a slice of all points in the kd tree rooted at `node`, in in-order
	sequence (left subtree, node, right subtree).

	Returns: a newly allocated slice; empty for an empty (`null`) tree.
*/
T[k][] elements(size_t k, T)(in KDNode!(k, T)* node) pure nothrow {
	T[k][] result;

	// Append every point to one buffer. Concatenating the slices returned by
	// recursive calls would copy each subtree again at every level above it,
	// which is quadratic for the list-shaped trees that repeated `add` calls
	// can produce (exactly the trees `rebalance` feeds through here).
	void collect(in KDNode!(k, T)* current) {
		if (current is null) {
			return;
		}
		collect(current.left);
		result ~= current.state; // Preserve in-order sequence
		collect(current.right);
	}

	collect(node);
	return result;
}
39 
/**
	Creates a new kd tree.

	Called without arguments it returns an empty tree (`null`). Called with a
	slice of points it builds a tree by recursively splitting on the median
	along axis `depth % k`.
*/
template kdTree(size_t k, T) {
	/// Returns an empty kd tree (`null`).
	KDNode!(k, T)* kdTree() pure nothrow @nogc {
		return null;
	}

	/**
		Builds a kd tree containing all of `points`.

		Params:
			points = the points to store; this slice is reordered in place
				while the tree is built.
			depth = depth of the subtree being built; selects the split axis.
		Returns: the root of the new tree, or `null` if `points` is empty.
	*/
	KDNode!(k, T)* kdTree(T[k][] points, size_t depth = 0) pure nothrow {
		if (points.length == 0) {
			return null;
		}
		if (points.length == 1) {
			return new KDNode!(k, T)(points[0]);
		}

		immutable axis = depth % k, md = points.length / 2;

		// Afterwards points[0 .. md] <= points[md] <= points[md + 1 .. $] on
		// `axis`; `nearest` relies on this ordering when it prunes subtrees.
		quickSelect(points, axis, md);

		auto node = new KDNode!(k, T)(points[md]);
		node.left = kdTree(points[0 .. md], depth + 1);
		node.right = kdTree(points[md + 1 .. $], depth + 1);

		return node;
	}

	/**
		Reorders `points` in place so that `points[k]` holds the element that
		would be at index `k` if the slice were sorted on coordinate `axis`.
		Every element before it is not greater on `axis`, and every element
		after it is not smaller.

		Note: inside this function the parameter `k` is the target index; it
		shadows the template's dimension parameter.

		Three-way quickselect. Elements equal to the pivot are grouped together,
		so a target index that lands among duplicates finishes immediately
		without breaking the ordering guarantee above.
	*/
	void quickSelect(T[k][] points, in size_t axis, in size_t k) pure nothrow @nogc {
		import std.algorithm : swap;

		size_t start = 0, end = points.length;
		if (end < 2)
			return;

		// Invariant: points[0 .. start] <= points[start .. end] <= points[end .. $]
		// on `axis`, and start <= k < end.
		while (end - start > 1) {
			immutable pivot = points[k][axis];

			// Partition points[start .. end] into
			// [start, lt) < pivot, [lt, i) == pivot, [i, gt) unvisited, [gt, end) > pivot.
			size_t lt = start, i = start, gt = end;
			while (i < gt) {
				if (points[i][axis] < pivot) {
					if (lt != i)
						swap(points[lt], points[i]);
					lt++;
					i++;
				} else if (pivot < points[i][axis]) {
					gt--;
					swap(points[i], points[gt]);
				} else {
					i++;
				}
			}

			// The equal run [lt, gt) always contains the pivot element, so the
			// range strictly shrinks and the loop terminates.
			if (k < lt)
				end = lt;
			else if (k >= gt)
				start = gt;
			else
				return; // points[k] equals the pivot and is in its final place.
		}
	}
}

// Regression test: duplicated median values must not leave larger values in
// the left subtree (here 8 used to end up left of a node holding 5, so
// nearest([8]) returned [9]).
unittest {
	int[1][] points = [[9], [8], [5], [5]];
	auto root = kdTree(points);

	assert(root.state[0] == 8);
	assert(root.left.state[0] == 5 && root.left.left.state[0] == 5);
	assert(root.right.state[0] == 9);
	assert(root.nearest([8]) == [8]);
}
100 
/**
	Inserts `point` into the kd tree rooted at `root`. If the tree is empty,
	`root` is set to a new node holding `point`.

	At each level the point goes to the left subtree when its coordinate on
	axis `depth % k` is strictly smaller than the node's, otherwise to the
	right. The tree is not rebalanced (see `rebalance`).
*/
void add(size_t k, T)(ref KDNode!(k, T)* root, in T[k] point, size_t depth = 0) pure nothrow {
	if (root !is null) {
		immutable axis = depth % k;
		immutable nextDepth = depth + 1;

		if (point[axis] < root.state[axis])
			add(root.left, point, nextDepth);
		else
			add(root.right, point, nextDepth);
		return;
	}

	// Reached an empty slot: attach the new point here.
	root = new KDNode!(k, T)(point);
}
117 
/**
	Rebalances the kd tree in place by rebuilding it from its own elements.
	`root` is replaced by a freshly built tree holding the same points; an
	empty tree stays empty.
*/
void rebalance(size_t k, T)(ref KDNode!(k, T)* root) pure nothrow {
	T[k][] points = elements(root);
	root = kdTree!(k, T)(points);
}
124 
/**
	Finds the nearest neighbor of `point` in the kd tree using the euclidean
	distance metric.

	Coordinate differences are computed in `double`. Integral coordinate types
	therefore neither overflow when squared nor wrap around (unsigned `T`).

	Params:
		root = root of the tree; must not be empty.
		point = the query point.
	Returns: a copy of the stored point closest to `point`. When several points
		are equally close, the first one encountered during the search is kept.
*/
const(T[k]) nearest(size_t k, T)(in KDNode!(k, T)* root, in auto ref T[k] point) pure nothrow @nogc
in {
	assert(root !is null, "tree is empty");
}
do {
	const(T[k])* nearest = null;
	double nearestDistance;

	// Signed difference a - b, evaluated in double precision.
	static double diff(in T a, in T b) {
		return cast(double) a - cast(double) b;
	}

	static double distanceSq(in T[k] a, in T[k] b) {
		double sum = 0;
		static foreach (i; 0 .. k) {
			sum += diff(b[i], a[i]) ^^ 2;
		}
		return sum;
	}

	void nearestImpl(in KDNode!(k, T)* current, in ref T[k] point, size_t depth = 0) {
		if (current is null) {
			return;
		}

		immutable axis = depth % k;
		immutable distance = distanceSq(current.state, point);

		if (nearest is null || distance < nearestDistance) {
			nearestDistance = distance;
			nearest = &current.state;
		}

		// An exact match cannot be improved on; stop descending.
		if (nearestDistance > 0) {
			// Positive when the query lies below this node on `axis`.
			immutable distanceAxis = diff(current.state[axis], point[axis]);

			// Search the side of the splitting plane containing the query first.
			nearestImpl(distanceAxis > 0 ? current.left : current.right, point, depth + 1);

			// The other side can only hold a closer point if the splitting
			// plane is within the current best distance.
			if (distanceAxis ^^ 2 <= nearestDistance) {
				nearestImpl(distanceAxis > 0 ? current.right : current.left, point, depth + 1);
			}
		}
	}

	nearestImpl(root, point);
	return *nearest;
}
170 
unittest {
	import fluent.asserts : should;

	// The four corners of the unit square.
	auto tree = kdTree([[0, 0], [1, 1], [1, 0], [0, 1]]);

	// Queries that coincide with a stored corner return that corner.
	tree.nearest([0, 0]).should.equal([0, 0]);
	tree.nearest([1, 1]).should.equal([1, 1]);

	// Queries outside the square resolve to the closest corner.
	tree.nearest([-3, 5]).should.equal([0, 1]);
	tree.nearest([25, -4]).should.equal([1, 0]);

	// A point inserted later wins once it is the closest one.
	tree.add([25, 0]);
	tree.nearest([25, -4]).should.equal([25, 0]);
}
184 
/// Test build and nearest
unittest {
	import fluent.asserts : should;
	import std.algorithm : minElement;
	import std.numeric : euclideanDistance;
	import std.random : uniform01;

	// 1000 random points in the unit cube.
	auto cloud = new double[3][1000];
	foreach (ref p; cloud) {
		foreach (ref coordinate; p) {
			coordinate = uniform01;
		}
	}

	auto tree = kdTree(cloud);
	tree.size.should.equal(cloud.length);

	// Every answer must match a brute-force linear scan.
	foreach (_; 0 .. 1000) {
		double[3] query = [uniform01, uniform01, uniform01];

		auto expected = cloud.minElement!(p => p[].euclideanDistance(query[]));
		tree.nearest(query).should.equal(expected);
	}
}
208 
/// Test add and nearest
unittest {
	import fluent.asserts : should;
	import std.algorithm : minElement;
	import std.numeric : euclideanDistance;
	import std.random : uniform01;

	// 1000 random points in the unit cube.
	auto cloud = new double[3][1000];
	foreach (ref p; cloud) {
		foreach (ref coordinate; p) {
			coordinate = uniform01;
		}
	}

	// Grow the tree one point at a time instead of building it in bulk.
	auto tree = kdTree!(3, double);
	foreach (ref p; cloud) {
		tree.add(p);
	}
	tree.size.should.equal(cloud.length);

	// Every answer must match a brute-force linear scan.
	foreach (_; 0 .. 1000) {
		double[3] query = [uniform01, uniform01, uniform01];

		auto expected = cloud.minElement!(p => p[].euclideanDistance(query[]));
		tree.nearest(query).should.equal(expected);
	}
}
235 
/// Test rebalance
unittest {
	import fluent.asserts : should;
	import std.algorithm : minElement;
	import std.numeric : euclideanDistance;
	import std.random : uniform01;

	// 1000 random points in the unit cube.
	auto cloud = new double[3][1000];
	foreach (ref p; cloud) {
		foreach (ref coordinate; p) {
			coordinate = uniform01;
		}
	}

	auto tree = kdTree!(3, double);
	foreach (ref p; cloud) {
		tree.add(p);
	}
	tree.size.should.equal(cloud.length);

	// Rebuilding must keep every point.
	tree.rebalance();
	tree.size.should.equal(cloud.length);

	// The rebuilt tree must still answer like a brute-force linear scan.
	foreach (_; 0 .. 1000) {
		double[3] query = [uniform01, uniform01, uniform01];

		auto expected = cloud.minElement!(p => p[].euclideanDistance(query[]));
		tree.nearest(query).should.equal(expected);
	}
}
265 
/// Test nearest on empty tree
unittest {
	import core.exception : AssertError;
	import fluent.asserts : should;

	// The `in` contract of `nearest` rejects an empty tree.
	auto emptyTree = kdTree!(3, double)();
	emptyTree.nearest([0.0, 0.0, 0.0]).should.throwException!AssertError.withMessage.equal("tree is empty");
}