11use pyo3:: prelude:: * ;
2+ use pyo3:: exceptions:: PyValueError ;
23use numpy:: { PyArray1 , PyReadonlyArray1 , Element } ;
34use ndarray:: { Array1 , Array2 , Axis , s} ;
45use std:: collections:: HashSet ;
56use num_traits:: { Num , NumCast } ;
67
8+ #[ derive( Clone , Copy , PartialEq , Eq ) ]
9+ enum ConstraintMode {
10+ Any ,
11+ All ,
12+ }
13+
14+ impl ConstraintMode {
15+ fn parse ( mode : & str ) -> PyResult < Self > {
16+ match mode {
17+ "any" => Ok ( Self :: Any ) ,
18+ "all" => Ok ( Self :: All ) ,
19+ _ => Err ( PyValueError :: new_err (
20+ "`constraint_mode` must be either 'any' or 'all'" ,
21+ ) ) ,
22+ }
23+ }
24+ }
25+
26+ fn min_row_count ( grouped : & Array2 < f64 > ) -> f64 {
27+ let row_sums = grouped. sum_axis ( Axis ( 1 ) ) ;
28+ row_sums
29+ . iter ( )
30+ . fold ( f64:: INFINITY , |current_min, & count| current_min. min ( count) )
31+ }
32+
33+ fn should_break (
34+ grouped : & Array2 < f64 > ,
35+ n_bins : Option < usize > ,
36+ min_samples_val : Option < f64 > ,
37+ constraint_mode : ConstraintMode ,
38+ ) -> bool {
39+ match constraint_mode {
40+ ConstraintMode :: Any => {
41+ if let Some ( nb) = n_bins {
42+ if grouped. nrows ( ) <= nb {
43+ return true ;
44+ }
45+ }
46+
47+ if let Some ( ms_val) = min_samples_val {
48+ if min_row_count ( grouped) > ms_val {
49+ return true ;
50+ }
51+ }
52+
53+ false
54+ }
55+ ConstraintMode :: All => {
56+ let mut has_hard_constraint = false ;
57+ let mut satisfied = true ;
58+
59+ if let Some ( nb) = n_bins {
60+ has_hard_constraint = true ;
61+ satisfied &= grouped. nrows ( ) <= nb;
62+ }
63+
64+ if let Some ( ms_val) = min_samples_val {
65+ has_hard_constraint = true ;
66+ satisfied &= min_row_count ( grouped) >= ms_val;
67+ }
68+
69+ has_hard_constraint && satisfied
70+ }
71+ }
72+ }
73+
774/// Helper function to fill NaN values for floating point types
875fn fill_nan_f64 ( arr : & Array1 < f64 > , fill_value : f64 ) -> Array1 < f64 > {
976 arr. mapv ( |x| if x. is_nan ( ) { fill_value } else { x } )
@@ -34,6 +101,7 @@ fn chi_merge_generic<T>(
34101 min_threshold : Option < f64 > ,
35102 _nan : f64 ,
36103 balance : bool ,
104+ constraint_mode : ConstraintMode ,
37105) -> PyResult < Vec < T > >
38106where
39107 T : Num + NumCast + Copy + PartialOrd + std:: fmt:: Display + Element + ' static ,
@@ -88,26 +156,16 @@ where
88156
89157 // Merge loop
90158 loop {
91- // Break if n_bins reached
92- if let Some ( nb) = n_bins {
93- if grouped. nrows ( ) <= nb {
94- break ;
95- }
96- }
97-
98- // Break if min_samples reached
99- if let Some ( ms_val) = min_samples. map ( |ms| {
159+ let min_samples_val = min_samples. map ( |ms| {
100160 if ms < 1.0 {
101161 ( feature_filled. len ( ) as f64 ) * ms
102162 } else {
103163 ms as f64
104164 }
105- } ) {
106- let row_sums = grouped. sum_axis ( Axis ( 1 ) ) ;
107- let min_count = row_sums. iter ( ) . fold ( f64:: INFINITY , |a, & b| a. min ( b) ) ;
108- if min_count > ms_val {
109- break ;
110- }
165+ } ) ;
166+
167+ if should_break ( & grouped, n_bins, min_samples_val, constraint_mode) {
168+ break ;
111169 }
112170
113171 // Calculate chi-square for each adjacent group pair
@@ -144,10 +202,12 @@ where
144202 }
145203 }
146204
147- // Break if min_threshold reached
148- if let Some ( mt) = min_threshold {
149- if chi_min > mt {
150- break ;
205+ // Legacy mode keeps min_threshold as an independent stopping condition.
206+ if constraint_mode == ConstraintMode :: Any {
207+ if let Some ( mt) = min_threshold {
208+ if chi_min > mt {
209+ break ;
210+ }
151211 }
152212 }
153213
@@ -192,7 +252,7 @@ where
192252
193253/// ChiMerge for f64 (floating point)
194254#[ pyfunction]
195- #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0.0 , balance=false ) ) ]
255+ #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0.0 , balance=false , constraint_mode= "any" ) ) ]
196256pub fn chi_merge_f64 < ' py > (
197257 py : Python < ' py > ,
198258 feature : PyReadonlyArray1 < f64 > ,
@@ -202,14 +262,32 @@ pub fn chi_merge_f64<'py>(
202262 min_threshold : Option < f64 > ,
203263 nan : f64 ,
204264 balance : bool ,
265+ constraint_mode : & str ,
205266) -> PyResult < Bound < ' py , PyArray1 < f64 > > > {
206- let splits = chi_merge_generic ( feature, target, n_bins, min_samples, min_threshold, nan, balance) ?;
267+ let constraint_mode = ConstraintMode :: parse ( constraint_mode) ?;
268+
269+ if constraint_mode == ConstraintMode :: All && n_bins. is_none ( ) && min_samples. is_none ( ) {
270+ return Err ( PyValueError :: new_err (
271+ "`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode" ,
272+ ) ) ;
273+ }
274+
275+ let splits = chi_merge_generic (
276+ feature,
277+ target,
278+ n_bins,
279+ min_samples,
280+ min_threshold,
281+ nan,
282+ balance,
283+ constraint_mode,
284+ ) ?;
207285 Ok ( PyArray1 :: from_vec_bound ( py, splits) )
208286}
209287
210288/// ChiMerge for i32 (integer)
211289#[ pyfunction]
212- #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0 , balance=false ) ) ]
290+ #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0 , balance=false , constraint_mode= "any" ) ) ]
213291pub fn chi_merge_i32 < ' py > (
214292 py : Python < ' py > ,
215293 feature : PyReadonlyArray1 < i32 > ,
@@ -219,14 +297,32 @@ pub fn chi_merge_i32<'py>(
219297 min_threshold : Option < f64 > ,
220298 nan : i32 ,
221299 balance : bool ,
300+ constraint_mode : & str ,
222301) -> PyResult < Bound < ' py , PyArray1 < i32 > > > {
223- let splits = chi_merge_generic ( feature, target, n_bins, min_samples, min_threshold, nan as f64 , balance) ?;
302+ let constraint_mode = ConstraintMode :: parse ( constraint_mode) ?;
303+
304+ if constraint_mode == ConstraintMode :: All && n_bins. is_none ( ) && min_samples. is_none ( ) {
305+ return Err ( PyValueError :: new_err (
306+ "`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode" ,
307+ ) ) ;
308+ }
309+
310+ let splits = chi_merge_generic (
311+ feature,
312+ target,
313+ n_bins,
314+ min_samples,
315+ min_threshold,
316+ nan as f64 ,
317+ balance,
318+ constraint_mode,
319+ ) ?;
224320 Ok ( PyArray1 :: from_vec_bound ( py, splits) )
225321}
226322
227323/// ChiMerge for i64 (integer)
228324#[ pyfunction]
229- #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0 , balance=false ) ) ]
325+ #[ pyo3( signature = ( feature, target, n_bins=None , min_samples=None , min_threshold=None , nan=0 , balance=false , constraint_mode= "any" ) ) ]
230326pub fn chi_merge_i64 < ' py > (
231327 py : Python < ' py > ,
232328 feature : PyReadonlyArray1 < i64 > ,
@@ -236,7 +332,25 @@ pub fn chi_merge_i64<'py>(
236332 min_threshold : Option < f64 > ,
237333 nan : i64 ,
238334 balance : bool ,
335+ constraint_mode : & str ,
239336) -> PyResult < Bound < ' py , PyArray1 < i64 > > > {
240- let splits = chi_merge_generic ( feature, target, n_bins, min_samples, min_threshold, nan as f64 , balance) ?;
337+ let constraint_mode = ConstraintMode :: parse ( constraint_mode) ?;
338+
339+ if constraint_mode == ConstraintMode :: All && n_bins. is_none ( ) && min_samples. is_none ( ) {
340+ return Err ( PyValueError :: new_err (
341+ "`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode" ,
342+ ) ) ;
343+ }
344+
345+ let splits = chi_merge_generic (
346+ feature,
347+ target,
348+ n_bins,
349+ min_samples,
350+ min_threshold,
351+ nan as f64 ,
352+ balance,
353+ constraint_mode,
354+ ) ?;
241355 Ok ( PyArray1 :: from_vec_bound ( py, splits) )
242356}
0 commit comments