Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 85 additions & 19 deletions src/merge.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,77 @@
use pyo3::prelude::*;
use pyo3::exceptions::PyValueError;
use numpy::{PyArray1, PyReadonlyArray1};
use ndarray::{Array1, Array2, Axis, s};
use std::collections::HashSet;

const DEFAULT_BINS: usize = 10;

/// Selects how the stopping constraints of the merge loop are combined.
///
/// Parsed from the user-facing `constraint_mode` string by `ConstraintMode::parse`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ConstraintMode {
/// Spelled `"any"`: stop as soon as a single configured constraint is met.
Any,
/// Spelled `"all"`: stop only once every configured `n_bins` / `min_samples`
/// constraint is met at the same time.
All,
}

impl ConstraintMode {
    /// Converts the user-facing `constraint_mode` string into a `ConstraintMode`.
    ///
    /// Only the exact lowercase spellings `"any"` and `"all"` are accepted.
    ///
    /// # Errors
    ///
    /// Returns a `PyValueError` for every other input.
    fn parse(mode: &str) -> PyResult<Self> {
        if mode == "any" {
            return Ok(Self::Any);
        }
        if mode == "all" {
            return Ok(Self::All);
        }

        Err(PyValueError::new_err(
            "`constraint_mode` must be either 'any' or 'all'",
        ))
    }
}

/// Returns the smallest row total of `grouped`.
///
/// Every row is summed along axis 1 and the minimum of those sums is returned.
/// An array without rows yields `f64::INFINITY`, the starting value of the fold.
fn min_row_count(grouped: &Array2<f64>) -> f64 {
    grouped
        .sum_axis(Axis(1))
        .iter()
        .copied()
        .fold(f64::INFINITY, f64::min)
}

/// Decides whether the merge loop may stop for the current `grouped` table.
///
/// * `n_bins` - when set, the constraint is met once `grouped` has at most this many rows.
/// * `min_samples_val` - when set, the constraint is evaluated against the smallest
///   row total of `grouped`.
/// * `constraint_mode` - how the two constraints are combined:
///   * `Any`: returns `true` when either configured constraint is met. The sample
///     check is strict here (smallest row total `>` `min_samples_val`).
///   * `All`: returns `true` only when at least one constraint is configured and
///     every configured one is met. The sample check is inclusive here (`>=`).
///
/// NOTE(review): when the configured constraints can never be met (for example a
/// `min_samples_val` larger than the grand total) this never returns `true`;
/// confirm the caller stops on its own once nothing is left to merge.
fn should_break(
    grouped: &Array2<f64>,
    n_bins: Option<usize>,
    min_samples_val: Option<f64>,
    constraint_mode: ConstraintMode,
) -> bool {
    match constraint_mode {
        ConstraintMode::Any => {
            let bins_reached = n_bins.map_or(false, |limit| grouped.nrows() <= limit);

            bins_reached
                || min_samples_val.map_or(false, |floor| min_row_count(grouped) > floor)
        }
        ConstraintMode::All => {
            // Without any hard constraint there is nothing that could be satisfied.
            if n_bins.is_none() && min_samples_val.is_none() {
                return false;
            }

            // An unset constraint counts as met; a set one must actually hold.
            let bins_ok = n_bins.map_or(true, |limit| grouped.nrows() <= limit);
            let samples_ok =
                min_samples_val.map_or(true, |floor| min_row_count(grouped) >= floor);

            bins_ok && samples_ok
        }
    }
}

/// Helper function to fill NaN values
fn fill_nan(arr: &Array1<f64>, fill_value: f64) -> Array1<f64> {
arr.mapv(|x| if x.is_nan() { fill_value } else { x })
Expand All @@ -20,7 +87,7 @@ fn unique_sorted_f64(arr: &Array1<f64>) -> Vec<f64> {

/// ChiMerge - Chi-square based merging (core algorithm in Rust)
#[pyfunction]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=-1.0, balance=true))]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=-1.0, balance=true, constraint_mode="any"))]
fn chi_merge<'py>(
py: Python<'py>,
feature: PyReadonlyArray1<f64>,
Expand All @@ -30,7 +97,16 @@ fn chi_merge<'py>(
min_threshold: Option<f64>,
nan: f64,
balance: bool,
constraint_mode: &str,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
let constraint_mode = ConstraintMode::parse(constraint_mode)?;

if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
return Err(PyValueError::new_err(
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
));
}

// Set default break condition
let n_bins = if n_bins.is_none() && min_samples.is_none() && min_threshold.is_none() {
Some(DEFAULT_BINS)
Expand Down Expand Up @@ -81,20 +157,8 @@ fn chi_merge<'py>(

// Merge loop
loop {
// Break if n_bins reached
if let Some(nb) = n_bins {
if grouped.nrows() <= nb {
break;
}
}

// Break if min_samples reached
if let Some(ms_val) = min_samples_val {
let row_sums = grouped.sum_axis(Axis(1));
let min_count = row_sums.iter().fold(f64::INFINITY, |a, &b| a.min(b));
if min_count > ms_val {
break;
}
if should_break(&grouped, n_bins, min_samples_val, constraint_mode) {
break;
}

// Calculate chi-square for each adjacent group pair
Expand Down Expand Up @@ -131,10 +195,12 @@ fn chi_merge<'py>(
}
}

// Break if min_threshold reached
if let Some(mt) = min_threshold {
if chi_min > mt {
break;
// Legacy mode keeps min_threshold as an independent stopping condition.
if constraint_mode == ConstraintMode::Any {
if let Some(mt) = min_threshold {
if chi_min > mt {
break;
}
}
}

Expand Down
164 changes: 139 additions & 25 deletions src/merge_generic.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,76 @@
use pyo3::prelude::*;
use pyo3::exceptions::PyValueError;
use numpy::{PyArray1, PyReadonlyArray1, Element};
use ndarray::{Array1, Array2, Axis, s};
use std::collections::HashSet;
use num_traits::{Num, NumCast};

/// Selects how the stopping constraints of the merge loop are combined.
///
/// Parsed from the user-facing `constraint_mode` string by `ConstraintMode::parse`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ConstraintMode {
/// Spelled `"any"`: stop as soon as a single configured constraint is met.
Any,
/// Spelled `"all"`: stop only once every configured `n_bins` / `min_samples`
/// constraint is met at the same time.
All,
}

impl ConstraintMode {
    /// Converts the user-facing `constraint_mode` string into a `ConstraintMode`.
    ///
    /// Only the exact lowercase spellings `"any"` and `"all"` are accepted.
    ///
    /// # Errors
    ///
    /// Returns a `PyValueError` for every other input.
    fn parse(mode: &str) -> PyResult<Self> {
        if mode == "any" {
            return Ok(Self::Any);
        }
        if mode == "all" {
            return Ok(Self::All);
        }

        Err(PyValueError::new_err(
            "`constraint_mode` must be either 'any' or 'all'",
        ))
    }
}

/// Returns the smallest row total of `grouped`.
///
/// Every row is summed along axis 1 and the minimum of those sums is returned.
/// An array without rows yields `f64::INFINITY`, the starting value of the fold.
fn min_row_count(grouped: &Array2<f64>) -> f64 {
    grouped
        .sum_axis(Axis(1))
        .iter()
        .copied()
        .fold(f64::INFINITY, f64::min)
}

/// Decides whether the merge loop may stop for the current `grouped` table.
///
/// * `n_bins` - when set, the constraint is met once `grouped` has at most this many rows.
/// * `min_samples_val` - when set, the constraint is evaluated against the smallest
///   row total of `grouped`.
/// * `constraint_mode` - how the two constraints are combined:
///   * `Any`: returns `true` when either configured constraint is met. The sample
///     check is strict here (smallest row total `>` `min_samples_val`).
///   * `All`: returns `true` only when at least one constraint is configured and
///     every configured one is met. The sample check is inclusive here (`>=`).
///
/// NOTE(review): when the configured constraints can never be met (for example a
/// `min_samples_val` larger than the grand total) this never returns `true`;
/// confirm the caller stops on its own once nothing is left to merge.
fn should_break(
    grouped: &Array2<f64>,
    n_bins: Option<usize>,
    min_samples_val: Option<f64>,
    constraint_mode: ConstraintMode,
) -> bool {
    match constraint_mode {
        ConstraintMode::Any => {
            let bins_reached = n_bins.map_or(false, |limit| grouped.nrows() <= limit);

            bins_reached
                || min_samples_val.map_or(false, |floor| min_row_count(grouped) > floor)
        }
        ConstraintMode::All => {
            // Without any hard constraint there is nothing that could be satisfied.
            if n_bins.is_none() && min_samples_val.is_none() {
                return false;
            }

            // An unset constraint counts as met; a set one must actually hold.
            let bins_ok = n_bins.map_or(true, |limit| grouped.nrows() <= limit);
            let samples_ok =
                min_samples_val.map_or(true, |floor| min_row_count(grouped) >= floor);

            bins_ok && samples_ok
        }
    }
}

/// Helper function to fill NaN values for floating point types
fn fill_nan_f64(arr: &Array1<f64>, fill_value: f64) -> Array1<f64> {
arr.mapv(|x| if x.is_nan() { fill_value } else { x })
Expand Down Expand Up @@ -34,6 +101,7 @@ fn chi_merge_generic<T>(
min_threshold: Option<f64>,
_nan: f64,
balance: bool,
constraint_mode: ConstraintMode,
) -> PyResult<Vec<T>>
where
T: Num + NumCast + Copy + PartialOrd + std::fmt::Display + Element + 'static,
Expand Down Expand Up @@ -88,26 +156,16 @@ where

// Merge loop
loop {
// Break if n_bins reached
if let Some(nb) = n_bins {
if grouped.nrows() <= nb {
break;
}
}

// Break if min_samples reached
if let Some(ms_val) = min_samples.map(|ms| {
let min_samples_val = min_samples.map(|ms| {
if ms < 1.0 {
(feature_filled.len() as f64) * ms
} else {
ms as f64
}
}) {
let row_sums = grouped.sum_axis(Axis(1));
let min_count = row_sums.iter().fold(f64::INFINITY, |a, &b| a.min(b));
if min_count > ms_val {
break;
}
});

if should_break(&grouped, n_bins, min_samples_val, constraint_mode) {
break;
}

// Calculate chi-square for each adjacent group pair
Expand Down Expand Up @@ -144,10 +202,12 @@ where
}
}

// Break if min_threshold reached
if let Some(mt) = min_threshold {
if chi_min > mt {
break;
// Legacy mode keeps min_threshold as an independent stopping condition.
if constraint_mode == ConstraintMode::Any {
if let Some(mt) = min_threshold {
if chi_min > mt {
break;
}
}
}

Expand Down Expand Up @@ -192,7 +252,7 @@ where

/// ChiMerge for f64 (floating point)
#[pyfunction]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0.0, balance=false))]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0.0, balance=false, constraint_mode="any"))]
pub fn chi_merge_f64<'py>(
py: Python<'py>,
feature: PyReadonlyArray1<f64>,
Expand All @@ -202,14 +262,32 @@ pub fn chi_merge_f64<'py>(
min_threshold: Option<f64>,
nan: f64,
balance: bool,
constraint_mode: &str,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan, balance)?;
let constraint_mode = ConstraintMode::parse(constraint_mode)?;

if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
return Err(PyValueError::new_err(
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
));
}

let splits = chi_merge_generic(
feature,
target,
n_bins,
min_samples,
min_threshold,
nan,
balance,
constraint_mode,
)?;
Ok(PyArray1::from_vec_bound(py, splits))
}

/// ChiMerge for i32 (integer)
#[pyfunction]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false))]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false, constraint_mode="any"))]
pub fn chi_merge_i32<'py>(
py: Python<'py>,
feature: PyReadonlyArray1<i32>,
Expand All @@ -219,14 +297,32 @@ pub fn chi_merge_i32<'py>(
min_threshold: Option<f64>,
nan: i32,
balance: bool,
constraint_mode: &str,
) -> PyResult<Bound<'py, PyArray1<i32>>> {
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan as f64, balance)?;
let constraint_mode = ConstraintMode::parse(constraint_mode)?;

if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
return Err(PyValueError::new_err(
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
));
}

let splits = chi_merge_generic(
feature,
target,
n_bins,
min_samples,
min_threshold,
nan as f64,
balance,
constraint_mode,
)?;
Ok(PyArray1::from_vec_bound(py, splits))
}

/// ChiMerge for i64 (integer)
#[pyfunction]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false))]
#[pyo3(signature = (feature, target, n_bins=None, min_samples=None, min_threshold=None, nan=0, balance=false, constraint_mode="any"))]
pub fn chi_merge_i64<'py>(
py: Python<'py>,
feature: PyReadonlyArray1<i64>,
Expand All @@ -236,7 +332,25 @@ pub fn chi_merge_i64<'py>(
min_threshold: Option<f64>,
nan: i64,
balance: bool,
constraint_mode: &str,
) -> PyResult<Bound<'py, PyArray1<i64>>> {
let splits = chi_merge_generic(feature, target, n_bins, min_samples, min_threshold, nan as f64, balance)?;
let constraint_mode = ConstraintMode::parse(constraint_mode)?;

if constraint_mode == ConstraintMode::All && n_bins.is_none() && min_samples.is_none() {
return Err(PyValueError::new_err(
"`constraint_mode='all'` requires `n_bins` and/or `min_samples`; `min_threshold` is ignored in this mode",
));
}

let splits = chi_merge_generic(
feature,
target,
n_bins,
min_samples,
min_threshold,
nan as f64,
balance,
constraint_mode,
)?;
Ok(PyArray1::from_vec_bound(py, splits))
}
Loading
Loading