Skip to content

Commit 8738990

Browse files
Atwood706Secbone
andauthored
feat: add ChiMerge constraint_mode parameter (#158)
* Add ChiMerge constraint mode * Fix imputer string dtype encoding --------- Co-authored-by: Secbone <secbone@gmail.com>
1 parent 72f3fbc commit 8738990

7 files changed

Lines changed: 440 additions & 56 deletions

File tree

src/merge.rs

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,77 @@
11
use pyo3::prelude::*;
2+
use pyo3::exceptions::PyValueError;
23
use numpy::{PyArray1, PyReadonlyArray1};
34
use ndarray::{Array1, Array2, Axis, s};
45
use std::collections::HashSet;
56

67
const DEFAULT_BINS: usize = 10;
78

9+
#[derive(Clone, Copy, PartialEq, Eq)]
10+
enum ConstraintMode {
11+
Any,
12+
All,
13+
}
14+
15+
impl ConstraintMode {
16+
fn parse(mode: &str) -> PyResult<Self> {
17+
match mode {
18+
"any" => Ok(Self::Any),
19+
"all" => Ok(Self::All),
20+
_ => Err(PyValueError::new_err(
21+
"`constraint_mode` must be either 'any' or 'all'",
22+
)),
23+
}
24+
}
25+
}
26+
27+
fn min_row_count(grouped: &Array2<f64>) -> f64 {
28+
let row_sums = grouped.sum_axis(Axis(1));
29+
row_sums
30+
.iter()
31+
.fold(f64::INFINITY, |current_min, &count| current_min.min(count))
32+
}
33+
34+
fn should_break(
35+
grouped: &Array2<f64>,
36+
n_bins: Option<usize>,
37+
min_samples_val: Option<f64>,
38+
constraint_mode: ConstraintMode,
39+
) -> bool {
40+
match constraint_mode {
41+
ConstraintMode::Any => {
42+
if let Some(nb) = n_bins {
43+
if grouped.nrows() <= nb {
44+
return true;
45+
}
46+
}
47+
48+
if let Some(ms_val) = min_samples_val {
49+
if min_row_count(grouped) > ms_val {
50+
return true;
51+
}
52+
}
53+
54+
false
55+
}
56+
ConstraintMode::All => {
57+
let mut has_hard_constraint = false;
58+
let mut satisfied = true;
59+
60+
if let Some(nb) = n_bins {
61+
has_hard_constraint = true;
62+
satisfied &= grouped.nrows() <= nb;
63+
}
64+
65+
if let Some(ms_val) = min_samples_val {
66+
has_hard_constraint = true;
67+
satisfied &= min_row_count(grouped) >= ms_val;
68+
}
69+
70+
has_hard_constraint && satisfied
71+
}
72+
}
73+
}
74+
875
/// Helper function to fill NaN values
976
fn fill_nan(arr: &Array1<f64>, fill_value: f64) -> Array1<f64> {
1077
arr.mapv(|x| if x.is_nan() { fill_value } else { x })
@@ -20,7 +87,7 @@ fn unique_sorted_f64(arr: &Array1<f64>) -> Vec<f64> {
2087

2188
/// ChiMerge - Chi-square based merging (core algorithm in Rust)
2289
#[pyfunction]
23-
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=-1.0, balance=true))]
90+
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=-1.0, balance=true, constraint_mode="any"))]
2491
fn chi_merge<'py>(
2592
py: Python<'py>,
2693
feature: PyReadonlyArray1<f64>,
@@ -30,7 +97,16 @@ fn chi_merge<'py>(
3097
min_threshold: Option<f64>,
3198
nan: f64,
3299
balance: bool,
100+
constraint_mode: &str,
33101
) -> PyResult<Bound<'py, PyArray1<f64>>> {
102+
let constraint_mode = ConstraintMode::parse(constraint_mode)?;
103+
104+
if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
105+
return Err(PyValueError::new_err(
106+
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
107+
));
108+
}
109+
34110
// Set default break condition
35111
let n_bins = if n_bins.is_none() && min_samples.is_none() && min_threshold.is_none() {
36112
Some(DEFAULT_BINS)
@@ -81,20 +157,8 @@ fn chi_merge<'py>(
81157

82158
// Merge loop
83159
loop {
84-
// Break if n_bins reached
85-
if let Some(nb) = n_bins {
86-
if grouped.nrows() <= nb {
87-
break;
88-
}
89-
}
90-
91-
// Break if min_samples reached
92-
if let Some(ms_val) = min_samples_val {
93-
let row_sums = grouped.sum_axis(Axis(1));
94-
let min_count = row_sums.iter().fold(f64::INFINITY, |a, &b| a.min(b));
95-
if min_count > ms_val {
96-
break;
97-
}
160+
if should_break(&grouped, n_bins, min_samples_val, constraint_mode) {
161+
break;
98162
}
99163

100164
// Calculate chi-square for each adjacent group pair
@@ -131,10 +195,12 @@ fn chi_merge<'py>(
131195
}
132196
}
133197

134-
// Break if min_threshold reached
135-
if let Some(mt) = min_threshold {
136-
if chi_min > mt {
137-
break;
198+
// Legacy mode keeps min_threshold as an independent stopping condition.
199+
if constraint_mode == ConstraintMode::Any {
200+
if let Some(mt) = min_threshold {
201+
if chi_min > mt {
202+
break;
203+
}
138204
}
139205
}
140206

src/merge_generic.rs

Lines changed: 139 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,76 @@
11
use pyo3::prelude::*;
2+
use pyo3::exceptions::PyValueError;
23
use numpy::{PyArray1, PyReadonlyArray1, Element};
34
use ndarray::{Array1, Array2, Axis, s};
45
use std::collections::HashSet;
56
use num_traits::{Num, NumCast};
67

8+
#[derive(Clone, Copy, PartialEq, Eq)]
9+
enum ConstraintMode {
10+
Any,
11+
All,
12+
}
13+
14+
impl ConstraintMode {
15+
fn parse(mode: &str) -> PyResult<Self> {
16+
match mode {
17+
"any" => Ok(Self::Any),
18+
"all" => Ok(Self::All),
19+
_ => Err(PyValueError::new_err(
20+
"`constraint_mode` must be either 'any' or 'all'",
21+
)),
22+
}
23+
}
24+
}
25+
26+
fn min_row_count(grouped: &Array2<f64>) -> f64 {
27+
let row_sums = grouped.sum_axis(Axis(1));
28+
row_sums
29+
.iter()
30+
.fold(f64::INFINITY, |current_min, &count| current_min.min(count))
31+
}
32+
33+
fn should_break(
34+
grouped: &Array2<f64>,
35+
n_bins: Option<usize>,
36+
min_samples_val: Option<f64>,
37+
constraint_mode: ConstraintMode,
38+
) -> bool {
39+
match constraint_mode {
40+
ConstraintMode::Any => {
41+
if let Some(nb) = n_bins {
42+
if grouped.nrows() <= nb {
43+
return true;
44+
}
45+
}
46+
47+
if let Some(ms_val) = min_samples_val {
48+
if min_row_count(grouped) > ms_val {
49+
return true;
50+
}
51+
}
52+
53+
false
54+
}
55+
ConstraintMode::All => {
56+
let mut has_hard_constraint = false;
57+
let mut satisfied = true;
58+
59+
if let Some(nb) = n_bins {
60+
has_hard_constraint = true;
61+
satisfied &= grouped.nrows() <= nb;
62+
}
63+
64+
if let Some(ms_val) = min_samples_val {
65+
has_hard_constraint = true;
66+
satisfied &= min_row_count(grouped) >= ms_val;
67+
}
68+
69+
has_hard_constraint && satisfied
70+
}
71+
}
72+
}
73+
774
/// Helper function to fill NaN values for floating point types
875
fn fill_nan_f64(arr: &Array1<f64>, fill_value: f64) -> Array1<f64> {
976
arr.mapv(|x| if x.is_nan() { fill_value } else { x })
@@ -34,6 +101,7 @@ fn chi_merge_generic<T>(
34101
min_threshold: Option<f64>,
35102
_nan: f64,
36103
balance: bool,
104+
constraint_mode: ConstraintMode,
37105
) -> PyResult<Vec<T>>
38106
where
39107
T: Num + NumCast + Copy + PartialOrd + std::fmt::Display + Element + 'static,
@@ -88,26 +156,16 @@ where
88156

89157
// Merge loop
90158
loop {
91-
// Break if n_bins reached
92-
if let Some(nb) = n_bins {
93-
if grouped.nrows() <= nb {
94-
break;
95-
}
96-
}
97-
98-
// Break if min_samples reached
99-
if let Some(ms_val) = min_samples.map(|ms| {
159+
let min_samples_val = min_samples.map(|ms| {
100160
if ms < 1.0 {
101161
(feature_filled.len() as f64) * ms
102162
} else {
103163
ms as f64
104164
}
105-
}) {
106-
let row_sums = grouped.sum_axis(Axis(1));
107-
let min_count = row_sums.iter().fold(f64::INFINITY, |a, &b| a.min(b));
108-
if min_count > ms_val {
109-
break;
110-
}
165+
});
166+
167+
if should_break(&grouped, n_bins, min_samples_val, constraint_mode) {
168+
break;
111169
}
112170

113171
// Calculate chi-square for each adjacent group pair
@@ -144,10 +202,12 @@ where
144202
}
145203
}
146204

147-
// Break if min_threshold reached
148-
if let Some(mt) = min_threshold {
149-
if chi_min > mt {
150-
break;
205+
// Legacy mode keeps min_threshold as an independent stopping condition.
206+
if constraint_mode == ConstraintMode::Any {
207+
if let Some(mt) = min_threshold {
208+
if chi_min > mt {
209+
break;
210+
}
151211
}
152212
}
153213

@@ -192,7 +252,7 @@ where
192252

193253
/// ChiMerge for f64 (floating point)
194254
#[pyfunction]
195-
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0.0, balance=false))]
255+
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0.0, balance=false, constraint_mode="any"))]
196256
pub fn chi_merge_f64<'py>(
197257
py: Python<'py>,
198258
feature: PyReadonlyArray1<f64>,
@@ -202,14 +262,32 @@ pub fn chi_merge_f64<'py>(
202262
min_threshold: Option<f64>,
203263
nan: f64,
204264
balance: bool,
265+
constraint_mode: &str,
205266
) -> PyResult<Bound<'py, PyArray1<f64>>> {
206-
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan, balance)?;
267+
let constraint_mode = ConstraintMode::parse(constraint_mode)?;
268+
269+
if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
270+
return Err(PyValueError::new_err(
271+
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
272+
));
273+
}
274+
275+
let splits = chi_merge_generic(
276+
feature,
277+
target,
278+
n_bins,
279+
min_samples,
280+
min_threshold,
281+
nan,
282+
balance,
283+
constraint_mode,
284+
)?;
207285
Ok(PyArray1::from_vec_bound(py, splits))
208286
}
209287

210288
/// ChiMerge for i32 (integer)
211289
#[pyfunction]
212-
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false))]
290+
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false, constraint_mode="any"))]
213291
pub fn chi_merge_i32<'py>(
214292
py: Python<'py>,
215293
feature: PyReadonlyArray1<i32>,
@@ -219,14 +297,32 @@ pub fn chi_merge_i32<'py>(
219297
min_threshold: Option<f64>,
220298
nan: i32,
221299
balance: bool,
300+
constraint_mode: &str,
222301
) -> PyResult<Bound<'py, PyArray1<i32>>> {
223-
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan as f64, balance)?;
302+
let constraint_mode = ConstraintMode::parse(constraint_mode)?;
303+
304+
if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
305+
return Err(PyValueError::new_err(
306+
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
307+
));
308+
}
309+
310+
let splits = chi_merge_generic(
311+
feature,
312+
target,
313+
n_bins,
314+
min_samples,
315+
min_threshold,
316+
nan as f64,
317+
balance,
318+
constraint_mode,
319+
)?;
224320
Ok(PyArray1::from_vec_bound(py, splits))
225321
}
226322

227323
/// ChiMerge for i64 (integer)
228324
#[pyfunction]
229-
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false))]
325+
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false, constraint_mode="any"))]
230326
pub fn chi_merge_i64<'py>(
231327
py: Python<'py>,
232328
feature: PyReadonlyArray1<i64>,
@@ -236,7 +332,25 @@ pub fn chi_merge_i64<'py>(
236332
min_threshold: Option<f64>,
237333
nan: i64,
238334
balance: bool,
335+
constraint_mode: &str,
239336
) -> PyResult<Bound<'py, PyArray1<i64>>> {
240-
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan as f64, balance)?;
337+
let constraint_mode = ConstraintMode::parse(constraint_mode)?;
338+
339+
if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
340+
return Err(PyValueError::new_err(
341+
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
342+
));
343+
}
344+
345+
let splits = chi_merge_generic(
346+
feature,
347+
target,
348+
n_bins,
349+
min_samples,
350+
min_threshold,
351+
nan as f64,
352+
balance,
353+
constraint_mode,
354+
)?;
241355
Ok(PyArray1::from_vec_bound(py, splits))
242356
}

0 commit comments

Comments
 (0)