Skip to content

Commit 5536db7

Browse files
committed
Add estimator parameters filter
1 parent 7df7818 commit 5536db7

2 files changed

Lines changed: 46 additions & 4 deletions

File tree

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,48 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
334334
return acceleration_lines > 0 and fallback_lines == 0
335335

336336

337+
def validate_estimator_params(estimator_class, estimator_params: Dict) -> Dict:
338+
"""Validates parameters and returns only those supported by the estimator.
339+
340+
Args:
341+
estimator_class: The estimator class to validate against
342+
estimator_params: Dictionary of parameters to validate
343+
344+
Returns:
345+
Dictionary with only valid parameters
346+
"""
347+
try:
348+
init_signature = inspect.signature(estimator_class.__init__)
349+
valid_params = set(init_signature.parameters.keys()) - {"self"}
350+
351+
# Check if estimator accepts **kwargs
352+
has_var_keyword = any(
353+
param.kind == inspect.Parameter.VAR_KEYWORD
354+
for param in init_signature.parameters.values()
355+
)
356+
357+
# If accepts **kwargs, return all params
358+
if has_var_keyword:
359+
return estimator_params
360+
361+
# Filter out invalid params and warn
362+
filtered_params = {}
363+
for param_name, param_value in estimator_params.items():
364+
if param_name in valid_params:
365+
filtered_params[param_name] = param_value
366+
else:
367+
logger.warning(
368+
f"Parameter '{param_name}' is not supported by "
369+
f"{estimator_class.__name__} and will be ignored"
370+
)
371+
372+
return filtered_params
373+
374+
except Exception as e:
375+
logger.debug(f"Could not validate parameters for {estimator_class.__name__}: {e}")
376+
return estimator_params
377+
378+
337379
def create_online_function(method_instance, data_args, batch_size):
338380
n_batches = data_args[0].shape[0] // batch_size
339381

@@ -491,6 +533,9 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
491533
bench_case, "algorithm:estimator_params", dict()
492534
)
493535

536+
# validate and filter estimator parameters
537+
estimator_params = validate_estimator_params(estimator_class, estimator_params)
538+
494539
# get estimator methods for measurement
495540
estimator_methods = get_estimator_methods(bench_case)
496541

sklbench/datasets/common.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def preprocess_x(
168168
x: Array,
169169
replace_nan="auto",
170170
category_encoding="ordinal",
171-
normalize=None, # None, "standard", "minmax"
171+
normalize=None,
172172
force_for_sparse=True,
173173
**kwargs,
174174
) -> Array:
@@ -223,13 +223,10 @@ def preprocess_x(
223223
# Normalization
224224
if normalize:
225225
if normalize == "standard":
226-
#x = (x - x.mean()) / x.std()
227226
scaler = StandardScaler(with_mean=True, with_std=True)
228227
elif normalize == "mean":
229-
#x = x - x.mean()
230228
scaler = StandardScaler(with_mean=True, with_std=False)
231229
elif normalize == "minmax":
232-
#x = (x - x.min()) / (x.max() - x.min())
233230
scaler = MinMaxScaler(feature_range=(0, 1))
234231
else:
235232
logger.warning(f'Unknown "{normalize}" normalization type.')

0 commit comments

Comments
 (0)