Skip to content

Commit 3f5daa4

Browse files
committed
working on svc implementation
1 parent 4442480 commit 3f5daa4

1 file changed

Lines changed: 97 additions & 1 deletion

File tree

src/svm/svc.rs

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
use std::collections::{HashMap, HashSet};
7676
use std::fmt::Debug;
7777
use std::marker::PhantomData;
78+
use std::ops::Mul;
7879

7980
use num::Bounded;
8081
use rand::seq::SliceRandom;
@@ -84,12 +85,107 @@ use serde::{Deserialize, Serialize};
8485

8586
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
8687
use crate::error::{Failed, FailedError};
87-
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
88+
use crate::linalg::basic::arrays::ArrayView1;
89+
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
90+
use crate::linalg::basic::matrix::DenseMatrix;
8891
use crate::numbers::basenum::Number;
8992
use crate::numbers::realnum::RealNumber;
9093
use crate::rand_custom::get_rng_impl;
9194
use crate::svm::Kernel;
9295

96+
pub struct MultiClassSVCParameters<'a> {
97+
xs_filtered: Vec<DenseMatrix<f64>>,
98+
ys_filtered: Vec<Vec<i64>>,
99+
svc_params: &'a SVCParameters<f64, i64, DenseMatrix<f64>, Vec<i64>>,
100+
classes: Vec<(i64, i64)>,
101+
}
102+
103+
impl<'a> MultiClassSVCParameters<'a> {
104+
pub fn new(
105+
x: &DenseMatrix<f64>,
106+
y: &Vec<i64>,
107+
params: &'a SVCParameters<f64, i64, DenseMatrix<f64>, Vec<i64>>,
108+
) -> Self {
109+
let y = y.unique();
110+
let mut classes = Vec::new();
111+
let mut xs_filtered = Vec::new();
112+
let mut ys_filtered = Vec::new();
113+
for i in 0..y.len() {
114+
for j in (i + 1)..y.len() {
115+
let class1 = y[i];
116+
let class2 = y[j];
117+
let mut y_filtered = Vec::new();
118+
let mut x_filtered = Vec::new();
119+
for k in 0..y.len() {
120+
let y_val = y[k];
121+
let x_val = x.get_row(k).iterator(1).map(|v| *v).collect();
122+
if y_val == class1 {
123+
y_filtered.push(1);
124+
x_filtered.push(x_val);
125+
} else if y_val == class2 {
126+
y_filtered.push(-1);
127+
x_filtered.push(x_val);
128+
}
129+
}
130+
let x_filtered = DenseMatrix::from_2d_vec(&x_filtered).unwrap();
131+
xs_filtered.push(x_filtered);
132+
ys_filtered.push(y_filtered);
133+
classes.push((class1, class2));
134+
}
135+
}
136+
Self {
137+
xs_filtered,
138+
ys_filtered,
139+
svc_params: &params,
140+
classes,
141+
}
142+
}
143+
}
144+
pub struct MulticlassSVC<'a> {
145+
parameters: &'a MultiClassSVCParameters<'a>,
146+
classifiers: Vec<SVC<'a, f64, i64, DenseMatrix<f64>, Vec<i64>>>,
147+
}
148+
149+
impl<'a> MulticlassSVC<'a> {
150+
pub fn fit(params: &'a MultiClassSVCParameters) -> Result<Self, Failed> {
151+
let mut classifiers = Vec::new();
152+
for i in 0..params.classes.len() {
153+
let y_filtered = params.ys_filtered.get(i).unwrap();
154+
let x_filtered = params.xs_filtered.get(i).unwrap();
155+
let svc = SVC::fit(x_filtered, y_filtered, params.svc_params);
156+
}
157+
Ok(Self {
158+
parameters: &params,
159+
classifiers,
160+
})
161+
}
162+
163+
pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<i64> {
164+
let mut polls = vec![HashMap::new(); x.shape().0];
165+
for i in 0..self.parameters.classes.len() {
166+
let svc = self.classifiers.get(i).unwrap();
167+
let (class1, class2) = self.parameters.classes[i];
168+
let predictions = svc.predict(x).unwrap();
169+
for (j, prediction) in predictions.iter().enumerate() {
170+
let poll = polls.get_mut(j).unwrap();
171+
let class = match prediction {
172+
1.0 => class1,
173+
_ => class2,
174+
};
175+
if let Some(count) = poll.get_mut(&class) {
176+
*count += 1
177+
} else {
178+
poll.insert(class, 1);
179+
}
180+
}
181+
}
182+
polls
183+
.iter()
184+
.map(|v| *v.iter().max_by_key(|(_, count)| **count).unwrap().0)
185+
.collect()
186+
}
187+
}
188+
93189
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
94190
#[derive(Debug)]
95191
/// SVC Parameters

0 commit comments

Comments
 (0)