@@ -10,16 +10,19 @@ use std::fmt;
1010use std:: fmt:: Debug ;
1111use std:: fmt:: Display ;
1212use std:: fmt:: Formatter ;
13+ use std:: hash:: Hash ;
1314use std:: ops:: Deref ;
1415use std:: sync:: Arc ;
1516use std:: sync:: LazyLock ;
1617use std:: sync:: OnceLock ;
1718
19+ use arc_swap:: ArcSwap ;
1820use lasso:: Spur ;
1921use lasso:: ThreadedRodeo ;
2022use parking_lot:: RwLock ;
2123use vortex_error:: VortexExpect ;
2224use vortex_utils:: aliases:: dash_map:: DashMap ;
25+ use vortex_utils:: aliases:: hash_map:: HashMap ;
2326
2427/// Global string interner for [`Id`] values.
2528static INTERNER : LazyLock < ThreadedRodeo > = LazyLock :: new ( ThreadedRodeo :: new) ;
@@ -299,8 +302,8 @@ impl<T: Clone> Context<T> {
299302/// optimizer's parent-reduce registry keys by `(parent_encoding_id, child_encoding_id)` so that
300303/// downstream crates can override the rule that would normally run from the child encoding's
301304/// static `PARENT_RULES` set.
302- #[ derive( Clone , Debug , Default ) ]
303- pub struct FnRegistry ( Arc < DashMap < ( Id , Id ) , Arc < dyn Any + Send + Sync > > > ) ;
305+ #[ derive( Debug , Default ) ]
306+ pub struct FnRegistry ( ArcSwap < HashMap < u64 , Arc < dyn Any + Send + Sync > > > ) ;
304307
305308impl FnRegistry {
306309 /// Create a new, empty registry.
@@ -309,26 +312,35 @@ impl FnRegistry {
309312 }
310313
311314 /// Register a function under `(outer, inner)`, replacing any existing entry.
312- pub fn register < F : Any + Send + Sync > ( & self , outer : Id , inner : Id , f : F ) {
313- self . 0 . insert ( ( outer, inner) , Arc :: new ( f) ) ;
315+ pub fn register < F : Any + Send + Sync > ( & self , id : u64 , f : F ) {
316+ let registry = self . 0 . load ( ) ;
317+ let mut owned_registry = registry. as_ref ( ) . clone ( ) ;
318+ owned_registry. insert ( id, Arc :: new ( f) ) ;
319+ self . 0 . store ( Arc :: new ( owned_registry) ) ;
314320 }
315321
316322 /// Look up a function registered under `(outer, inner)`, downcasting to `F`.
317323 ///
318324 /// Returns `None` if no function is registered, or if the registered value is not of type `F`.
319- pub fn find < F : Any + Send + Sync > ( & self , outer : Id , inner : Id ) -> Option < Arc < F > > {
320- let entry = self . 0 . get ( & ( outer, inner) ) ?;
321- Arc :: clone ( entry. value ( ) ) . downcast :: < F > ( ) . ok ( )
325+ pub fn find < F : Any + Send + Sync > ( & self , id : u64 ) -> Option < Arc < F > > {
326+ let map = self . 0 . load ( ) ;
327+ let entry = map. get ( & id) ?;
328+ Arc :: clone ( entry) . downcast :: < F > ( ) . ok ( )
322329 }
323330
324331 /// Return `true` if any function is registered under `(outer, inner)`.
325- pub fn contains ( & self , outer : Id , inner : Id ) -> bool {
326- self . 0 . contains_key ( & ( outer, inner) )
332+ pub fn contains ( & self , id : u64 ) -> bool {
333+ let map = self . 0 . load ( ) ;
334+ map. contains_key ( & id)
327335 }
328336}
329337
330338#[ cfg( test) ]
331339mod fn_registry_tests {
340+ use std:: hash:: BuildHasher ;
341+
342+ use vortex_utils:: aliases:: DefaultHashBuilder ;
343+
332344 use super :: FnRegistry ;
333345 use super :: Id ;
334346
@@ -343,12 +355,13 @@ mod fn_registry_tests {
343355 let registry = FnRegistry :: default ( ) ;
344356 let outer = Id :: new ( "test.double" ) ;
345357 let inner = Id :: new ( "test.int" ) ;
358+ let id = DefaultHashBuilder :: default ( ) . hash_one ( ( outer, inner) ) ;
346359
347- assert ! ( !registry. contains( outer , inner ) ) ;
348- registry. register :: < DoubleFn > ( outer , inner , double) ;
360+ assert ! ( !registry. contains( id ) ) ;
361+ registry. register :: < DoubleFn > ( id , double) ;
349362
350- assert ! ( registry. contains( outer , inner ) ) ;
351- let f = registry. find :: < DoubleFn > ( outer , inner ) . unwrap ( ) ;
363+ assert ! ( registry. contains( id ) ) ;
364+ let f = registry. find :: < DoubleFn > ( id ) . unwrap ( ) ;
352365 assert_eq ! ( f( 21 ) , 42 ) ;
353366 }
354367
@@ -357,17 +370,20 @@ mod fn_registry_tests {
357370 let registry = FnRegistry :: default ( ) ;
358371 let outer = Id :: new ( "test.double" ) ;
359372 let inner = Id :: new ( "test.int" ) ;
360- registry. register :: < DoubleFn > ( outer, inner, double) ;
373+ let id = DefaultHashBuilder :: default ( ) . hash_one ( ( outer, inner) ) ;
374+
375+ registry. register :: < DoubleFn > ( id, double) ;
361376
362377 type OtherFn = fn ( i32 ) -> i32 ;
363- assert ! ( registry. find:: <OtherFn >( outer , inner ) . is_none( ) ) ;
378+ assert ! ( registry. find:: <OtherFn >( id ) . is_none( ) ) ;
364379 }
365380
366381 #[ test]
367382 fn missing_entry_returns_none ( ) {
368383 let registry = FnRegistry :: default ( ) ;
369384 let outer = Id :: new ( "test.missing" ) ;
370385 let inner = Id :: new ( "test.int" ) ;
371- assert ! ( registry. find:: <DoubleFn >( outer, inner) . is_none( ) ) ;
386+ let id = DefaultHashBuilder :: default ( ) . hash_one ( ( outer, inner) ) ;
387+ assert ! ( registry. find:: <DoubleFn >( id) . is_none( ) ) ;
372388 }
373389}
0 commit comments