11'''
22Implementation of an intersection tree approach for efficiently finding
33intersections among a set of query intervals and a database of intervals.
4+
5+ This module provides:
6+ - Node: A tree node representing an interval with efficient intersection search
7+ - Utility functions for generating intervals, creating databases, and executing queries
8+
9+ Note: The tree structure is a binary search tree ordered by interval start points.
10+ In worst-case scenarios (e.g., with sorted input), the tree may become unbalanced
11+ and degrade to O(n) performance instead of the optimal O(log n).
12+
13+ Example usage:
14+ >>> from intersection_tree import create_db, create_queries, execute_queries
15+ >>> db = create_db(size=100)
16+ >>> queries = create_queries(size=10)
17+ >>> results = execute_queries(queries, db)
418'''
519
620import random
1226QueryResult : typing .TypeAlias = list [tuple [Interval , Interval ]]
1327
1428class Node :
15- '''Class representting a node in an intersection tree.
29+ '''Class representing a node in an intersection tree.
1630
1731 Each node represents an interval in the dataset, so it has the start and the
1832 end of that interval. It also contains a reference to its left and right child.
1933 For indexing purposes, it also contains the maximum end value over all its children.
2034 '''
2135
22- _start : int
23- _end : int
24- _max_end : int
25- _left : 'Node' = None
26- _right : 'Node' = None
27-
2836 def __init__ (self , interval : Interval ) -> None :
2937 '''Initialize node representing the interval [start, end).
3038
@@ -33,9 +41,14 @@ def __init__(self, interval: Interval) -> None:
3341 interval: Interval
3442 the interval represented by this node
3543 '''
36- self ._start = interval [0 ]
37- self ._end = interval [1 ]
38- self ._max_end = interval [1 ]
44+ if interval [0 ] >= interval [1 ]:
45+ raise ValueError (f"Invalid interval: start ({ interval [0 ]} ) must be less than end ({ interval [1 ]} )" )
46+
47+ self ._start : int = interval [0 ]
48+ self ._end : int = interval [1 ]
49+ self ._max_end : int = interval [1 ]
50+ self ._left : 'Node | None' = None
51+ self ._right : 'Node | None' = None
3952
4053 def insert (self , interval : Interval ) -> None :
4154 '''Insert a new interval [start, end) in the tree.
@@ -44,7 +57,15 @@ def insert(self, interval: Interval) -> None:
4457 ----------
4558 interval: Interval
4659 the interval to insert
60+
61+ Raises
62+ ------
63+ ValueError
64+ if interval start is not less than end
4765 '''
66+ if interval [0 ] >= interval [1 ]:
67+ raise ValueError (f"Invalid interval: start ({ interval [0 ]} ) must be less than end ({ interval [1 ]} )" )
68+
4869 if interval [0 ] < self ._start :
4970 if self ._left is None :
5071 self ._left = Node (interval )
@@ -72,7 +93,7 @@ def search(self, interval: Interval, results: list[Interval]) -> None:
7293 results .append ((self ._start , self ._end ))
7394 if self ._left is not None and self ._left ._max_end >= interval [0 ]:
7495 self ._left .search (interval , results )
75- if self ._right is not None and self ._start < = interval [1 ]:
96+ if self ._right is not None and self ._right . _max_end > = interval [0 ]:
7697 self ._right .search (interval , results )
7798
7899 def to_str (self , prefix : str = '' ) -> str :
@@ -127,10 +148,18 @@ def generate_interval(max_end: int = 1_000_000_000) -> Interval:
127148 Returns
128149 -------
129150 Interval
130- Tuple (start, end) such that end - start > 1
151+ Tuple (start, end) such that end - start >= 1
152+
153+ Raises
154+ ------
155+ ValueError
156+ if max_end is less than 2
131157 '''
158+ if max_end < 2 :
159+ raise ValueError (f"max_end must be at least 2, got { max_end } " )
160+
132161 start = random .randint (0 , max_end - 2 )
133- end = random .randint (start + 2 , max_end )
162+ end = random .randint (start + 1 , max_end )
134163 return start , end
135164
136165
@@ -142,7 +171,7 @@ def create_db(size: int, max_end: int = 1_000_000) -> Node:
142171 size: int
143172 number of intervals in the database
144173 max_end: int
145- largest end value of the interval, default value 1_000_000_000
174+ largest end value of the interval, default value 1_000_000
146175
147176 Returns
148177 -------
@@ -162,7 +191,7 @@ def execute_queries(queries: Queries, db: Node) -> QueryResult:
162191 ----------
163192 queries: Queries
164193 queries to be executed
165- db: Db
194+ db: Node
166195 database to query
167196
168197 Returns
@@ -186,7 +215,7 @@ def create_queries(size: int = 1_000, max_end: int = 1_000_000) -> Queries:
186215 size: int
187216 number of intervals in the query, default value 1_000
188217 max_end: int
189- largest end value of the interval, default value 1_000_000_000
218+ largest end value of the interval, default value 1_000_000
190219
191220 Returns
192221 -------
0 commit comments