7575use std:: collections:: { HashMap , HashSet } ;
7676use std:: fmt:: Debug ;
7777use std:: marker:: PhantomData ;
78+ use std:: ops:: Mul ;
7879
7980use num:: Bounded ;
8081use rand:: seq:: SliceRandom ;
@@ -84,12 +85,107 @@ use serde::{Deserialize, Serialize};
8485
8586use crate :: api:: { PredictorBorrow , SupervisedEstimatorBorrow } ;
8687use crate :: error:: { Failed , FailedError } ;
87- use crate :: linalg:: basic:: arrays:: { Array1 , Array2 , MutArray } ;
88+ use crate :: linalg:: basic:: arrays:: ArrayView1 ;
89+ use crate :: linalg:: basic:: arrays:: { Array , Array1 , Array2 , MutArray } ;
90+ use crate :: linalg:: basic:: matrix:: DenseMatrix ;
8891use crate :: numbers:: basenum:: Number ;
8992use crate :: numbers:: realnum:: RealNumber ;
9093use crate :: rand_custom:: get_rng_impl;
9194use crate :: svm:: Kernel ;
9295
96+ pub struct MultiClassSVCParameters < ' a > {
97+ xs_filtered : Vec < DenseMatrix < f64 > > ,
98+ ys_filtered : Vec < Vec < i64 > > ,
99+ svc_params : & ' a SVCParameters < f64 , i64 , DenseMatrix < f64 > , Vec < i64 > > ,
100+ classes : Vec < ( i64 , i64 ) > ,
101+ }
102+
103+ impl < ' a > MultiClassSVCParameters < ' a > {
104+ pub fn new (
105+ x : & DenseMatrix < f64 > ,
106+ y : & Vec < i64 > ,
107+ params : & ' a SVCParameters < f64 , i64 , DenseMatrix < f64 > , Vec < i64 > > ,
108+ ) -> Self {
109+ let y = y. unique ( ) ;
110+ let mut classes = Vec :: new ( ) ;
111+ let mut xs_filtered = Vec :: new ( ) ;
112+ let mut ys_filtered = Vec :: new ( ) ;
113+ for i in 0 ..y. len ( ) {
114+ for j in ( i + 1 ) ..y. len ( ) {
115+ let class1 = y[ i] ;
116+ let class2 = y[ j] ;
117+ let mut y_filtered = Vec :: new ( ) ;
118+ let mut x_filtered = Vec :: new ( ) ;
119+ for k in 0 ..y. len ( ) {
120+ let y_val = y[ k] ;
121+ let x_val = x. get_row ( k) . iterator ( 1 ) . map ( |v| * v) . collect ( ) ;
122+ if y_val == class1 {
123+ y_filtered. push ( 1 ) ;
124+ x_filtered. push ( x_val) ;
125+ } else if y_val == class2 {
126+ y_filtered. push ( -1 ) ;
127+ x_filtered. push ( x_val) ;
128+ }
129+ }
130+ let x_filtered = DenseMatrix :: from_2d_vec ( & x_filtered) . unwrap ( ) ;
131+ xs_filtered. push ( x_filtered) ;
132+ ys_filtered. push ( y_filtered) ;
133+ classes. push ( ( class1, class2) ) ;
134+ }
135+ }
136+ Self {
137+ xs_filtered,
138+ ys_filtered,
139+ svc_params : & params,
140+ classes,
141+ }
142+ }
143+ }
144+ pub struct MulticlassSVC < ' a > {
145+ parameters : & ' a MultiClassSVCParameters < ' a > ,
146+ classifiers : Vec < SVC < ' a , f64 , i64 , DenseMatrix < f64 > , Vec < i64 > > > ,
147+ }
148+
149+ impl < ' a > MulticlassSVC < ' a > {
150+ pub fn fit ( params : & ' a MultiClassSVCParameters ) -> Result < Self , Failed > {
151+ let mut classifiers = Vec :: new ( ) ;
152+ for i in 0 ..params. classes . len ( ) {
153+ let y_filtered = params. ys_filtered . get ( i) . unwrap ( ) ;
154+ let x_filtered = params. xs_filtered . get ( i) . unwrap ( ) ;
155+ let svc = SVC :: fit ( x_filtered, y_filtered, params. svc_params ) ;
156+ }
157+ Ok ( Self {
158+ parameters : & params,
159+ classifiers,
160+ } )
161+ }
162+
163+ pub fn predict ( & self , x : & DenseMatrix < f64 > ) -> Vec < i64 > {
164+ let mut polls = vec ! [ HashMap :: new( ) ; x. shape( ) . 0 ] ;
165+ for i in 0 ..self . parameters . classes . len ( ) {
166+ let svc = self . classifiers . get ( i) . unwrap ( ) ;
167+ let ( class1, class2) = self . parameters . classes [ i] ;
168+ let predictions = svc. predict ( x) . unwrap ( ) ;
169+ for ( j, prediction) in predictions. iter ( ) . enumerate ( ) {
170+ let poll = polls. get_mut ( j) . unwrap ( ) ;
171+ let class = match prediction {
172+ 1.0 => class1,
173+ _ => class2,
174+ } ;
175+ if let Some ( count) = poll. get_mut ( & class) {
176+ * count += 1
177+ } else {
178+ poll. insert ( class, 1 ) ;
179+ }
180+ }
181+ }
182+ polls
183+ . iter ( )
184+ . map ( |v| * v. iter ( ) . max_by_key ( |( _, count) | * * count) . unwrap ( ) . 0 )
185+ . collect ( )
186+ }
187+ }
188+
93189#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
94190#[ derive( Debug ) ]
95191/// SVC Parameters
0 commit comments