@@ -334,6 +334,48 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
334334 return acceleration_lines > 0 and fallback_lines == 0
335335
336336
337+ def validate_estimator_params (estimator_class , estimator_params : Dict ) -> Dict :
338+ """Validates parameters and returns only those supported by the estimator.
339+
340+ Args:
341+ estimator_class: The estimator class to validate against
342+ estimator_params: Dictionary of parameters to validate
343+
344+ Returns:
345+ Dictionary with only valid parameters
346+ """
347+ try :
348+ init_signature = inspect .signature (estimator_class .__init__ )
349+ valid_params = set (init_signature .parameters .keys ()) - {"self" }
350+
351+ # Check if estimator accepts **kwargs
352+ has_var_keyword = any (
353+ param .kind == inspect .Parameter .VAR_KEYWORD
354+ for param in init_signature .parameters .values ()
355+ )
356+
357+ # If accepts **kwargs, return all params
358+ if has_var_keyword :
359+ return estimator_params
360+
361+ # Filter out invalid params and warn
362+ filtered_params = {}
363+ for param_name , param_value in estimator_params .items ():
364+ if param_name in valid_params :
365+ filtered_params [param_name ] = param_value
366+ else :
367+ logger .warning (
368+ f"Parameter '{ param_name } ' is not supported by "
369+ f"{ estimator_class .__name__ } and will be ignored"
370+ )
371+
372+ return filtered_params
373+
374+ except Exception as e :
375+ logger .debug (f"Could not validate parameters for { estimator_class .__name__ } : { e } " )
376+ return estimator_params
377+
378+
337379def create_online_function (method_instance , data_args , batch_size ):
338380 n_batches = data_args [0 ].shape [0 ] // batch_size
339381
@@ -491,6 +533,9 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
491533 bench_case , "algorithm:estimator_params" , dict ()
492534 )
493535
536+ # validate and filter estimator parameters
537+ estimator_params = validate_estimator_params (estimator_class , estimator_params )
538+
494539 # get estimator methods for measurement
495540 estimator_methods = get_estimator_methods (bench_case )
496541
0 commit comments