6262//! let svc = SVC::fit(&x, &y, parameters, None).unwrap();
6363//!
6464//! let y_hat = svc.predict(&x).unwrap();
65+ //!
6566//! ```
6667//!
6768//! ## References:
@@ -92,20 +93,43 @@ use crate::svm::Kernel;
9293
9394#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
9495#[ derive( Debug ) ]
96+ /// Configuration for a multi-class Support Vector Machine (SVM) classifier.
97+ ///
98+ /// This struct holds the indices of the data points relevant to a specific binary
99+ /// classification problem within a multi-class context, and the two classes
100+ /// being discriminated.
95101pub struct MultiClassConfig < TY : Number + Ord > {
102+ /// The indices of the data points from the original dataset that belong to the two `classes`.
96103 indices : Vec < usize > ,
104+ /// A tuple representing the two classes that this configuration is designed to distinguish.
97105 classes : ( TY , TY ) ,
98106}
99107
100108impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
101109 SupervisedEstimatorBorrow < ' a , X , Y , SVCParameters < TX , TY , X , Y > >
102110 for MultiClassSVC < ' a , TX , TY , X , Y >
103111{
112+ /// Creates a new, empty `MultiClassSVC` instance.
113+ ///
114+ /// The `classifiers` field is initialized to `Option::None`, indicating that
115+ /// the model has not yet been fitted.
104116 fn new ( ) -> Self {
105117 Self {
106118 classifiers : Option :: None ,
107119 }
108120 }
121+
122+ /// Fits the `MultiClassSVC` model to the provided data and parameters.
123+ ///
124+ /// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
125+ ///
126+ /// # Arguments
127+ /// * `x` - A reference to the input features (2D array).
128+ /// * `y` - A reference to the target labels (1D array).
129+ /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
130+ ///
131+ /// # Returns
132+ /// A `Result` indicating success (`Self`) or failure (`Failed`).
109133 fn fit (
110134 x : & ' a X ,
111135 y : & ' a Y ,
@@ -118,50 +142,95 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
118142impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
119143 PredictorBorrow < ' a , X , TX > for MultiClassSVC < ' a , TX , TY , X , Y >
120144{
145+ /// Predicts the class labels for new data points.
146+ ///
147+ /// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
148+ /// It unwraps the inner `Result` from `MultiClassSVC::predict`, assuming that
149+ /// the prediction will always succeed after a successful fit.
150+ ///
151+ /// # Arguments
152+ /// * `x` - A reference to the input features (2D array) for which to make predictions.
153+ ///
154+ /// # Returns
155+ /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
121156 fn predict ( & self , x : & ' a X ) -> Result < Vec < TX > , Failed > {
122157 Ok ( self . predict ( x) . unwrap ( ) )
123158 }
124159}
125160
161+ /// A multi-class Support Vector Machine (SVM) classifier.
162+ ///
163+ /// This struct implements a multi-class SVM using the "one-vs-one" strategy,
164+ /// where a separate binary SVC classifier is trained for every pair of classes.
165+ ///
166+ /// # Type Parameters
167+ /// * `'a` - Lifetime parameter for borrowed data.
168+ /// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
169+ /// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
170+ /// * `X` - The type representing the 2D array of input features (e.g., a matrix).
171+ /// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
126172pub struct MultiClassSVC <
127173 ' a ,
128174 TX : Number + RealNumber ,
129175 TY : Number + Ord ,
130176 X : Array2 < TX > ,
131177 Y : Array1 < TY > ,
132178> {
179+ /// An optional vector of binary `SVC` classifiers.
180+ ///
181+ /// This will be `Some` after the model has been fitted, containing one `SVC`
182+ /// for each pair of unique classes.
133183 classifiers : Option < Vec < SVC < ' a , TX , TY , X , Y > > > ,
134184}
135185
136186impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
137187 MultiClassSVC < ' a , TX , TY , X , Y >
138188{
189+ /// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
190+ ///
191+ /// This method identifies all unique classes in the target labels `y` and then
192+ /// trains a binary `SVC` for every unique pair of classes. For each pair, it
193+ /// extracts the relevant data points and their labels, and then trains a
194+ /// specialized `SVC` for that binary classification task.
195+ ///
196+ /// # Arguments
197+ /// * `x` - A reference to the input features (2D array).
198+ /// * `y` - A reference to the target labels (1D array).
199+ /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
200+ ///
201+ ///
202+ /// # Returns
203+ /// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
139204 pub fn fit (
140205 x : & ' a X ,
141206 y : & ' a Y ,
142207 parameters : & ' a SVCParameters < TX , TY , X , Y > ,
143208 ) -> Result < MultiClassSVC < ' a , TX , TY , X , Y > , Failed > {
144209 let unique_classes = y. unique ( ) ;
145210 let mut classifiers = Vec :: new ( ) ;
211+ // Iterate through all unique pairs of classes (one-vs-one strategy)
146212 for i in 0 ..unique_classes. len ( ) {
147213 for j in i..unique_classes. len ( ) {
148214 if i == j {
149- continue ;
215+ continue ; // Skip comparing a class to itself
150216 }
151217 let class0 = unique_classes[ j] ;
152218 let class1 = unique_classes[ i] ;
219+
153220 let mut indices = Vec :: new ( ) ;
221+ // Collect indices of data points belonging to the current pair of classes
154222 for ( index, v) in y. iterator ( 0 ) . enumerate ( ) {
155223 if * v == class0 || * v == class1 {
156224 indices. push ( index)
157225 }
158226 }
159227 let classes = ( class0, class1) ;
160228 let multiclass_config = MultiClassConfig {
161- classes : classes . clone ( ) ,
229+ classes,
162230 indices,
163231 } ;
164- let svc = SVC :: fit ( x, y, parameters, Some ( multiclass_config) ) . unwrap ( ) ;
232+ // Fit a binary SVC for the current pair of classes
233+ let svc = SVC :: fit ( x, y, parameters, Some ( multiclass_config) ) . unwrap ( ) ; // .unwrap() might panic if SVC::fit fails
165234 classifiers. push ( svc) ;
166235 }
167236 }
@@ -170,25 +239,50 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
170239 } )
171240 }
172241
242+ /// Predicts the class labels for new data points using the trained multi-class SVM.
243+ ///
244+ /// This method uses a "voting" scheme (majority vote) among all the binary
245+ /// classifiers to determine the final prediction for each data point.
246+ ///
247+ /// # Arguments
248+ /// * `x` - A reference to the input features (2D array) for which to make predictions.
249+ ///
250+ /// # Returns
251+ /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
252+ ///
253+ /// # Panics
254+ /// Panics if the model has not been fitted (`self.classifiers` is `None`).
173255 pub fn predict ( & self , x : & X ) -> Result < Vec < TX > , Failed > {
256+ // Initialize a HashMap for each data point to store votes for each class
174257 let mut polls = vec ! [ HashMap :: new( ) ; x. shape( ) . 0 ] ;
258+ // Retrieve the trained binary classifiers; panics if not fitted
175259 let classifiers = self . classifiers . as_ref ( ) . unwrap ( ) ;
260+
261+ // Iterate through each binary classifier
176262 for i in 0 ..classifiers. len ( ) {
177- let svc = classifiers. get ( i) . unwrap ( ) ;
178- let predictions = svc. predict ( x) . unwrap ( ) ;
263+ let svc = classifiers. get ( i) . unwrap ( ) ; // .unwrap() might panic if index is out of bounds
264+ let predictions = svc. predict ( x) . unwrap ( ) ; // .unwrap() might panic if SVC::predict fails
265+
266+ // For each prediction from the current binary classifier
179267 for ( j, prediction) in predictions. iter ( ) . enumerate ( ) {
180- let prediction = prediction. to_i32 ( ) . unwrap ( ) ;
181- let poll = polls. get_mut ( j) . unwrap ( ) ;
268+ let prediction = prediction. to_i32 ( ) . unwrap ( ) ; // Convert prediction to i32 for HashMap key
269+ let poll = polls. get_mut ( j) . unwrap ( ) ; // Get the poll for the current data point
270+ // Increment the vote for the predicted class
182271 if let Some ( count) = poll. get_mut ( & prediction) {
183272 * count += 1
184273 } else {
185274 poll. insert ( prediction, 1 ) ;
186275 }
187276 }
188277 }
278+
279+ // Determine the final prediction for each data point based on majority vote
189280 Ok ( polls
190281 . iter ( )
191- . map ( |v| TX :: from ( * v. iter ( ) . max_by_key ( |( _, class) | * class) . unwrap ( ) . 0 ) . unwrap ( ) )
282+ . map ( |v| {
283+ // Find the class with the maximum votes for each data point
284+ TX :: from ( * v. iter ( ) . max_by_key ( |( _, class) | * class) . unwrap ( ) . 0 ) . unwrap ( )
285+ } )
192286 . collect ( ) )
193287 }
194288}
0 commit comments