xref: /haiku/headers/private/kernel/util/SplayTree.h (revision b671e9bbdbd10268a042b4f4cc4317ccd03d105e)
1 /*
2  * Copyright 2008, Ingo Weinhold <ingo_weinhold@gmx.de>.
3  * Distributed under the terms of the MIT License.
4  *
5  * Original Java implementation:
6  * Available at http://www.link.cs.cmu.edu/splay/
7  * Author: Danny Sleator <sleator@cs.cmu.edu>
8  * This code is in the public domain.
9  */
10 #ifndef KERNEL_UTIL_SPLAY_TREE_H
11 #define KERNEL_UTIL_SPLAY_TREE_H
12 
13 /*!	Implements two classes:
14 
15 	SplayTree: A top-down splay tree.
16 
17 	IteratableSplayTree: Extends SplayTree by a singly-linked list to make it
18 	cheaply iteratable (requires another pointer per node).
19 
20 	Both classes are templatized over a definition parameter with the following
21 	(or a compatible) interface:
22 
23 	struct SplayTreeDefinition {
24 		typedef xxx KeyType;
25 		typedef	yyy NodeType;
26 
27 		static const KeyType& GetKey(const NodeType* node);
28 		static SplayTreeLink<NodeType>* GetLink(NodeType* node);
29 
30 		static int Compare(const KeyType& key, const NodeType* node);
31 
32 		// for IteratableSplayTree only
33 		static NodeType** GetListLink(NodeType* node);
34 	};
35 */
36 
37 
/*!	The pair of child pointers a SplayTree needs for every node.
	Definition::GetLink() maps a node to its SplayTreeLink. A NULL pointer
	means the node has no child on that side.
*/
template<class Node>
struct SplayTreeLink {
	Node*	left;	// root of the subtree holding smaller keys
	Node*	right;	// root of the subtree holding larger keys
};
43 
44 
45 template<typename Definition>
46 class SplayTree {
47 protected:
48 	typedef typename Definition::KeyType	Key;
49 	typedef typename Definition::NodeType	Node;
50 	typedef SplayTreeLink<Node>				Link;
51 
52 public:
53 	SplayTree()
54 		:
55 		fRoot(NULL)
56 	{
57 	}
58 
59 	/*!
60 		Insert into the tree.
61 		\param node the item to insert.
62 	*/
63 	bool Insert(Node* node)
64 	{
65 		Link* nodeLink = Definition::GetLink(node);
66 
67 		if (fRoot == NULL) {
68 			fRoot = node;
69 			nodeLink->left = NULL;
70 			nodeLink->right = NULL;
71 			return true;
72 		}
73 
74 		Key key = Definition::GetKey(node);
75 		_Splay(key);
76 
77 		int c = Definition::Compare(key, fRoot);
78 		if (c == 0)
79 			return false;
80 
81 		Link* rootLink = Definition::GetLink(fRoot);
82 
83 		if (c < 0) {
84 			nodeLink->left = rootLink->left;
85 			nodeLink->right = fRoot;
86 			rootLink->left = NULL;
87 		} else {
88 			nodeLink->right = rootLink->right;
89 			nodeLink->left = fRoot;
90 			rootLink->right = NULL;
91 		}
92 
93 		fRoot = node;
94 		return true;
95 	}
96 
97 	Node* Remove(const Key& key)
98 	{
99 		if (fRoot == NULL)
100 			return NULL;
101 
102 		_Splay(key);
103 
104 		if (Definition::Compare(key, fRoot) != 0)
105 			return NULL;
106 
107 		// Now delete the root
108 		Node* node = fRoot;
109 		Link* rootLink = Definition::GetLink(fRoot);
110 		if (rootLink->left == NULL) {
111 			fRoot = rootLink->right;
112 		} else {
113 			Node* temp = rootLink->right;
114 			fRoot = rootLink->left;
115 			_Splay(key);
116 			Definition::GetLink(fRoot)->right = temp;
117 		}
118 
119 		return node;
120     }
121 
122 	/*!
123 		Remove from the tree.
124 		\param node the item to remove.
125 	*/
126 	bool Remove(Node* node)
127 	{
128 		Key key = Definition::GetKey(node);
129 		_Splay(key);
130 
131 		if (node != fRoot)
132 			return false;
133 
134 		// Now delete the root
135 		Link* rootLink = Definition::GetLink(fRoot);
136 		if (rootLink->left == NULL) {
137 			fRoot = rootLink->right;
138 		} else {
139 			Node* temp = rootLink->right;
140 			fRoot = rootLink->left;
141 			_Splay(key);
142 			Definition::GetLink(fRoot)->right = temp;
143 		}
144 
145 		return true;
146     }
147 
148 	/*!
149 		Find the smallest item in the tree.
150 	*/
151 	Node* FindMin()
152 	{
153 		if (fRoot == NULL)
154 			return NULL;
155 
156 		Node* node = fRoot;
157 
158 		while (Node* left = Definition::GetLink(node)->left)
159 			node = left;
160 
161 		_Splay(Definition::GetKey(node));
162 
163 		return node;
164 	}
165 
166     /*!
167 		Find the largest item in the tree.
168      */
169 	Node* FindMax()
170 	{
171 		if (fRoot == NULL)
172 			return NULL;
173 
174 		Node* node = fRoot;
175 
176 		while (Node* right = Definition::GetLink(node)->right)
177 			node = right;
178 
179 		_Splay(Definition::GetKey(node));
180 
181 		return node;
182     }
183 
184     /*!
185 		Find an item in the tree.
186 	*/
187 	Node* Lookup(const Key& key)
188 	{
189 		if (fRoot == NULL)
190 			return NULL;
191 
192 		_Splay(key);
193 
194 		return Definition::Compare(key, fRoot) == 0 ? fRoot : NULL;
195     }
196 
197 	Node* Root() const
198 	{
199 		return fRoot;
200 	}
201 
202     /*!
203 		Test if the tree is logically empty.
204 		\return true if empty, false otherwise.
205 	*/
206 	bool IsEmpty() const
207 	{
208 		return fRoot == NULL;
209 	}
210 
211 	Node* PreviousDontSplay(const Key& key) const
212 	{
213 		Node* closestNode = NULL;
214 		Node* node = fRoot;
215 		while (node != NULL) {
216 			if (Definition::Compare(key, node) > 0) {
217 				closestNode = node;
218 				node = Definition::GetLink(node)->right;
219 			} else
220 				node = Definition::GetLink(node)->left;
221 		}
222 
223 		return closestNode;
224 	}
225 
226 	Node* FindClosest(const Key& key, bool greater, bool orEqual)
227 	{
228 		if (fRoot == NULL)
229 			return NULL;
230 
231 		_Splay(key);
232 
233 		Node* closestNode = NULL;
234 		Node* node = fRoot;
235 		while (node != NULL) {
236 			int compare = Definition::Compare(key, node);
237 			if (compare == 0 && orEqual)
238 				return node;
239 
240 			if (greater) {
241 				if (compare < 0) {
242 					closestNode = node;
243 					node = Definition::GetLink(node)->left;
244 				} else
245 					node = Definition::GetLink(node)->right;
246 			} else {
247 				if (compare > 0) {
248 					closestNode = node;
249 					node = Definition::GetLink(node)->right;
250 				} else
251 					node = Definition::GetLink(node)->left;
252 			}
253 		}
254 
255 		return closestNode;
256 	}
257 
258 private:
259 	/*!
260 		Internal method to perform a top-down splay.
261 
262 		_Splay(key) does the splay operation on the given key.
263 		If key is in the tree, then the node containing
264 		that key becomes the root.  If key is not in the tree,
265 		then after the splay, key.root is either the greatest key
266 		< key in the tree, or the least key > key in the tree.
267 
268 		This means, among other things, that if you splay with
269 		a key that's larger than any in the tree, the rightmost
270 		node of the tree becomes the root.  This property is used
271 		in the Remove() method.
272 	*/
273     void _Splay(const Key& key) {
274 		Link headerLink;
275 		headerLink.left = headerLink.right = NULL;
276 
277 		Link* lLink = &headerLink;
278 		Link* rLink = &headerLink;
279 
280 		Node* l = NULL;
281 		Node* r = NULL;
282 		Node* t = fRoot;
283 
284 		for (;;) {
285 			int c = Definition::Compare(key, t);
286 			if (c < 0) {
287 				Node*& left = Definition::GetLink(t)->left;
288 				if (left == NULL)
289 					break;
290 
291 				if (Definition::Compare(key, left) < 0) {
292 					// rotate right
293 					Node* y = left;
294 					Link* yLink = Definition::GetLink(y);
295 					left = yLink->right;
296 					yLink->right = t;
297 					t = y;
298 					if (yLink->left == NULL)
299 						break;
300 				}
301 
302 				// link right
303 				rLink->left = t;
304 				r = t;
305 				rLink = Definition::GetLink(r);
306 				t = rLink->left;
307 			} else if (c > 0) {
308 				Node*& right = Definition::GetLink(t)->right;
309 				if (right == NULL)
310 					break;
311 
312 				if (Definition::Compare(key, right) > 0) {
313 					// rotate left
314 					Node* y = right;
315 					Link* yLink = Definition::GetLink(y);
316 					right = yLink->left;
317 					yLink->left = t;
318 					t = y;
319 					if (yLink->right == NULL)
320 						break;
321 				}
322 
323 				// link left
324 				lLink->right = t;
325 				l = t;
326 				lLink = Definition::GetLink(l);
327 				t = lLink->right;
328 			} else
329 				break;
330 		}
331 
332 		// assemble
333 		Link* tLink = Definition::GetLink(t);
334 		lLink->right = tLink->left;
335 		rLink->left = tLink->right;
336 		tLink->left = headerLink.right;
337 		tLink->right = headerLink.left;
338 		fRoot = t;
339 	}
340 
341 protected:
342 	Node*	fRoot;
343 };
344 
345 
346 template<typename Definition>
347 class IteratableSplayTree {
348 protected:
349 	typedef typename Definition::KeyType	Key;
350 	typedef typename Definition::NodeType	Node;
351 	typedef SplayTreeLink<Node>				Link;
352 	typedef IteratableSplayTree<Definition>	Tree;
353 
354 public:
355 	class Iterator {
356 	public:
357 		Iterator()
358 		{
359 		}
360 
361 		Iterator(const Iterator& other)
362 		{
363 			*this = other;
364 		}
365 
366 		Iterator(Tree* tree)
367 			:
368 			fTree(tree)
369 		{
370 			Rewind();
371 		}
372 
373 		Iterator(Tree* tree, Node* next)
374 			:
375 			fTree(tree),
376 			fCurrent(NULL),
377 			fNext(next)
378 		{
379 		}
380 
381 		bool HasNext() const
382 		{
383 			return fNext != NULL;
384 		}
385 
386 		Node* Next()
387 		{
388 			fCurrent = fNext;
389 			if (fNext != NULL)
390 				fNext = *Definition::GetListLink(fNext);
391 			return fCurrent;
392 		}
393 
394 		Node* Current()
395 		{
396 			return fCurrent;
397 		}
398 
399 		Node* Remove()
400 		{
401 			Node* element = fCurrent;
402 			if (fCurrent) {
403 				fTree->Remove(fCurrent);
404 				fCurrent = NULL;
405 			}
406 			return element;
407 		}
408 
409 		Iterator &operator=(const Iterator &other)
410 		{
411 			fTree = other.fTree;
412 			fCurrent = other.fCurrent;
413 			fNext = other.fNext;
414 			return *this;
415 		}
416 
417 		void Rewind()
418 		{
419 			fCurrent = NULL;
420 			fNext = fTree->fFirst;
421 		}
422 
423 	private:
424 		Tree*	fTree;
425 		Node*	fCurrent;
426 		Node*	fNext;
427 	};
428 
429 	class ConstIterator {
430 	public:
431 		ConstIterator()
432 		{
433 		}
434 
435 		ConstIterator(const ConstIterator& other)
436 		{
437 			*this = other;
438 		}
439 
440 		ConstIterator(Tree* tree)
441 			:
442 			fTree(tree)
443 		{
444 			Rewind();
445 		}
446 
447 		ConstIterator(Tree* tree, Node* next)
448 			:
449 			fTree(tree),
450 			fNext(next)
451 		{
452 		}
453 
454 		bool HasNext() const
455 		{
456 			return fNext != NULL;
457 		}
458 
459 		Node* Next()
460 		{
461 			Node* node = fNext;
462 			if (fNext != NULL)
463 				fNext = *Definition::GetListLink(fNext);
464 			return node;
465 		}
466 
467 		ConstIterator &operator=(const ConstIterator &other)
468 		{
469 			fTree = other.fTree;
470 			fNext = other.fNext;
471 			return *this;
472 		}
473 
474 		void Rewind()
475 		{
476 			fNext = fTree->fFirst;
477 		}
478 
479 	private:
480 		Tree*	fTree;
481 		Node*	fNext;
482 	};
483 
484 	IteratableSplayTree()
485 		:
486 		fTree(),
487 		fFirst(NULL)
488 	{
489 	}
490 
491 	bool Insert(Node* node)
492 	{
493 		if (!fTree.Insert(node))
494 			return false;
495 
496 		Node** previousNext;
497 		if (Node* previous = fTree.PreviousDontSplay(Definition::GetKey(node)))
498 			previousNext = Definition::GetListLink(previous);
499 		else
500 			previousNext = &fFirst;
501 
502 		*Definition::GetListLink(node) = *previousNext;
503 		*previousNext = node;
504 
505 		return true;
506 	}
507 
508 	Node* Remove(const Key& key)
509 	{
510 		Node* node = fTree.Remove(key);
511 		if (node == NULL)
512 			return NULL;
513 
514 		Node** previousNext;
515 		if (Node* previous = fTree.PreviousDontSplay(key))
516 			previousNext = Definition::GetListLink(previous);
517 		else
518 			previousNext = &fFirst;
519 
520 		*previousNext = *Definition::GetListLink(node);
521 
522 		return node;
523 	}
524 
525 	bool Remove(Node* node)
526 	{
527 		if (!fTree.Remove(node))
528 			return false;
529 
530 		Node** previousNext;
531 		if (Node* previous = fTree.PreviousDontSplay(Definition::GetKey(node)))
532 			previousNext = Definition::GetListLink(previous);
533 		else
534 			previousNext = &fFirst;
535 
536 		*previousNext = *Definition::GetListLink(node);
537 
538 		return true;
539 	}
540 
541 	Node* Lookup(const Key& key)
542 	{
543 		return fTree.Lookup(key);
544 	}
545 
546 	Node* Root() const
547 	{
548 		return fTree.Root();
549 	}
550 
551     /*!
552 		Test if the tree is logically empty.
553 		\return true if empty, false otherwise.
554 	*/
555 	bool IsEmpty() const
556 	{
557 		return fTree.IsEmpty();
558 	}
559 
560 	Node* FindMin()
561 	{
562 		return fTree.FindMin();
563 	}
564 
565 	Node* FindMax()
566 	{
567 		return fTree.FindMax();
568 	}
569 
570 	Iterator GetIterator()
571 	{
572 		return Iterator(this);
573 	}
574 
575 	ConstIterator GetIterator() const
576 	{
577 		return ConstIterator(this);
578 	}
579 
580 	Iterator GetIterator(const Key& key, bool greater, bool orEqual)
581 	{
582 		return Iterator(this, fTree.FindClosest(key, greater, orEqual));
583 	}
584 
585 	ConstIterator GetIterator(const Key& key, bool greater, bool orEqual) const
586 	{
587 		return ConstIterator(this, FindClosest(key, greater, orEqual));
588 	}
589 
590 protected:
591 	friend class Iterator;
592 	friend class ConstIterator;
593 		// needed for gcc 2.95.3 only
594 
595 	SplayTree<Definition>	fTree;
596 	Node*					fFirst;
597 };
598 
599 
600 #endif	// KERNEL_UTIL_SPLAY_TREE_H
601