33from typing import Any , List
44from sqlalchemy .sql import operators
55
6+ from ..helpers .ordered_set import OrderedSet
7+
68
79class IndexManager :
10+ __slots__ = ('hash_index' , 'range_index' , 'table_indexes' , 'columns_mapping' , )
11+
812 def __init__ (self ):
913 self .hash_index = HashIndex ()
1014 self .range_index = RangeIndex ()
1115
1216 self .table_indexes = {}
1317 self .columns_mapping = {}
1418
19+
1520 def get_indexes (self , obj ):
1621 """
1722 Retrieve index from object's table as dict: indexname => list of column name
@@ -21,18 +26,27 @@ def get_indexes(self, obj):
2126 if tablename not in self .table_indexes :
2227 self .table_indexes [tablename ] = {}
2328
29+ pk_col_name = obj .__table__ .primary_key .columns [0 ].name
30+
2431 for index in obj .__table__ .indexes :
2532 if len (index .expressions ) > 1 :
2633 # Ignoring compound indexes for now ...
2734 continue
2835
36+ if index .name == pk_col_name :
37+ pk_col_name = None
38+
2939 self .table_indexes [tablename ][index .name ] = [
3040 col .name
3141 for col in index .expressions
3242 ]
3343
44+ if pk_col_name :
45+ self .table_indexes [tablename ][pk_col_name ] = [pk_col_name ]
46+
3447 return self .table_indexes [tablename ]
3548
49+
3650 def _column_to_index (self , tablename , colname ):
3751 """
3852 Get index name from tablename & column name
@@ -51,6 +65,7 @@ def _column_to_index(self, tablename, colname):
5165
5266 return self .columns_mapping [tablename ][colname ]
5367
68+
5469 def _get_index_key (self , obj , columns ):
5570 if len (columns ) == 1 :
5671 return getattr (obj , columns [0 ])
@@ -65,7 +80,7 @@ def on_insert(self, obj):
6580
6681 self .hash_index .add (tablename , indexname , value , obj )
6782 self .range_index .add (tablename , indexname , value , obj )
68-
83+
6984 def on_delete (self , obj ):
7085 tablename = obj .__tablename__
7186 indexes = self .get_indexes (obj )
@@ -145,6 +160,7 @@ def query(self, collection, tablename, colname, operator, value):
145160 in_range = self .range_index .query (tablename , indexname , gte = value [0 ], lte = value [1 ])
146161 return list (set (collection ) - set (in_range ))
147162
163+
148164 def get_selectivity (self , tablename , colname , operator , value , total_count ):
149165 """
150166 Estimate selectivity: higher means worst filtering.
@@ -187,23 +203,24 @@ class HashIndex:
187203 Maintains insertion order of objects.
188204 """
189205
206+ __slots__ = ('index' ,)
207+
190208 def __init__ (self ):
191- self .index = defaultdict (lambda : defaultdict (lambda : defaultdict (list )))
209+ self .index = defaultdict (lambda : defaultdict (lambda : defaultdict (OrderedSet )))
210+
192211
193212 def add (self , tablename : str , indexname : str , value : Any , obj : Any ):
194- self .index [tablename ][indexname ][value ].append (obj )
213+ self .index [tablename ][indexname ][value ].add (obj )
214+
195215
196216 def remove (self , tablename : str , indexname : str , value : Any , obj : Any ):
197- lst = self .index [tablename ][indexname ][value ]
198- try :
199- lst .remove (obj )
200- if not lst :
201- del self .index [tablename ][indexname ][value ]
202- except ValueError :
203- pass
217+ s = self .index [tablename ][indexname ][value ]
218+ s .discard (obj )
219+ if not s :
220+ del self .index [tablename ][indexname ][value ]
204221
205222 def query (self , tablename : str , indexname : str , value : Any ) -> List [Any ]:
206- return self .index [tablename ][indexname ].get (value , [])
223+ return list ( self .index [tablename ][indexname ].get (value , []) )
207224
208225
209226class RangeIndex :
@@ -215,12 +232,19 @@ class RangeIndex:
215232 index[tablename][indexname] = SortedDict { value: [obj1, obj2, ...] }
216233 """
217234
235+ __slots__ = ('index' ,)
236+
218237 def __init__ (self ):
219238 self .index = defaultdict (lambda : defaultdict (SortedDict ))
220239
221240 def add (self , tablename : str , indexname : str , value : Any , obj : Any ):
222- self .index [tablename ][indexname ].setdefault (value , []).append (obj )
241+ index = self .index [tablename ][indexname ]
242+ if value in index :
243+ index [value ].append (obj )
244+ else :
245+ index [value ] = [obj ]
223246
247+
224248 def remove (self , tablename : str , indexname : str , value : Any , obj : Any ):
225249 col = self .index [tablename ][indexname ]
226250 if value in col :
0 commit comments