Skip to content

Commit b1f13b6

Browse files
authored
Merge pull request #5 from gjbex/copilot/fix-de6b1cb0-cb4f-493a-81b2-b53907cdfd71
Fix critical bugs in intersection_tree.py: class variables, search algorithm, and validation
2 parents e3423bd + 6896688 commit b1f13b6

1 file changed

Lines changed: 45 additions & 16 deletions

File tree

source_code/intersection_trees/intersection_tree.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
'''
22
Implementation of an intersection tree approach for efficiently finding
33
intersections between a set of query intervals and a database of intervals.
4+
5+
This module provides:
6+
- Node: A tree node representing an interval with efficient intersection search
7+
- Utility functions for generating intervals, creating databases, and executing queries
8+
9+
Note: The tree structure is a binary search tree ordered by interval start points.
10+
In worst case scenarios (e.g., with sorted input), the tree may become unbalanced
11+
and degrade to O(n) performance instead of the optimal O(log n).
12+
13+
Example usage:
14+
>>> from intersection_tree import create_db, create_queries, execute_queries
15+
>>> db = create_db(size=100)
16+
>>> queries = create_queries(size=10)
17+
>>> results = execute_queries(queries, db)
418
'''
519

620
import random
@@ -12,19 +26,13 @@
1226
QueryResult: typing.TypeAlias = list[tuple[Interval, Interval]]
1327

1428
class Node:
15-
'''Class representting a node in an intersection tree.
29+
'''Class representing a node in an intersection tree.
1630
1731
Each node represents an interval in the dataset, so it has the start and the
1832
end of that interval. It also contains a reference to its left and right child.
1933
For indexing purposes, it also contains the maximum end value over its own interval and those of all its descendants.
2034
'''
2135

22-
_start: int
23-
_end: int
24-
_max_end: int
25-
_left: 'Node' = None
26-
_right: 'Node' = None
27-
2836
def __init__(self, interval: Interval) -> None:
2937
'''Initialize node representing the interval [start, end).
3038
@@ -33,9 +41,14 @@ def __init__(self, interval: Interval) -> None:
3341
interval: Interval
3442
the interval represented by this node
3543
'''
36-
self._start = interval[0]
37-
self._end = interval[1]
38-
self._max_end = interval[1]
44+
if interval[0] >= interval[1]:
45+
raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})")
46+
47+
self._start: int = interval[0]
48+
self._end: int = interval[1]
49+
self._max_end: int = interval[1]
50+
self._left: 'Node | None' = None
51+
self._right: 'Node | None' = None
3952

4053
def insert(self, interval: Interval) -> None:
4154
'''Insert a new interval [start, end) in the tree.
@@ -44,7 +57,15 @@ def insert(self, interval: Interval) -> None:
4457
----------
4558
interval: Interval
4659
the interval to insert
60+
61+
Raises
62+
------
63+
ValueError
64+
if interval start is not less than end
4765
'''
66+
if interval[0] >= interval[1]:
67+
raise ValueError(f"Invalid interval: start ({interval[0]}) must be less than end ({interval[1]})")
68+
4869
if interval[0] < self._start:
4970
if self._left is None:
5071
self._left = Node(interval)
@@ -72,7 +93,7 @@ def search(self, interval: Interval, results: list[Interval]) -> None:
7293
results.append((self._start, self._end))
7394
if self._left is not None and self._left._max_end >= interval[0]:
7495
self._left.search(interval, results)
75-
if self._right is not None and self._start <= interval[1]:
96+
if self._right is not None and self._right._max_end >= interval[0]:
7697
self._right.search(interval, results)
7798

7899
def to_str(self, prefix: str = '') -> str:
@@ -127,10 +148,18 @@ def generate_interval(max_end: int = 1_000_000_000) -> Interval:
127148
Returns
128149
-------
129150
Interval
130-
Tuple (start, end) such that end - start > 1
151+
Tuple (start, end) such that end - start >= 1
152+
153+
Raises
154+
------
155+
ValueError
156+
if max_end is less than 2
131157
'''
158+
if max_end < 2:
159+
raise ValueError(f"max_end must be at least 2, got {max_end}")
160+
132161
start = random.randint(0, max_end - 2)
133-
end = random.randint(start + 2, max_end)
162+
end = random.randint(start + 1, max_end)
134163
return start, end
135164

136165

@@ -142,7 +171,7 @@ def create_db(size: int, max_end: int = 1_000_000) -> Node:
142171
size: int
143172
number of intervals in the database
144173
max_end: int
145-
largest end value of the interval, default value 1_000_000_000
174+
largest end value of the interval, default value 1_000_000
146175
147176
Returns
148177
-------
@@ -162,7 +191,7 @@ def execute_queries(queries: Queries, db: Node) -> QueryResult:
162191
----------
163192
queries: Queries
164193
queries to be executed
165-
db: Db
194+
db: Node
166195
database to query
167196
168197
Returns
@@ -186,7 +215,7 @@ def create_queries(size: int = 1_000, max_end: int = 1_000_000) -> Queries:
186215
size: int
187216
number of intervals in the query, default value 1_000
188217
max_end: int
189-
largest end value of the interval, default value 1_000_000_000
218+
largest end value of the interval, default value 1_000_000
190219
191220
Returns
192221
-------

0 commit comments

Comments
 (0)