Skip to content

Commit ece4f28

Browse files
committed
two encounters of a bad pattern is_none() + unwrap(). FIXED.
1 parent 9fdfd93 commit ece4f28

2 files changed

Lines changed: 44 additions & 31 deletions

File tree

src/ensemble/base_forest_regressor.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,25 +161,31 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
161161
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
162162
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
163163
let (n, _) = x.shape();
164-
if self.samples.is_none() {
165-
Err(Failed::because(
166-
FailedError::PredictFailed,
167-
"Need samples=true for OOB predictions.",
168-
))
169-
} else if self.samples.as_ref().unwrap()[0].len() != n {
170-
Err(Failed::because(
164+
165+
let samples = match &self.samples {
166+
Some(s) => s,
167+
None => {
168+
return Err(Failed::because(
169+
FailedError::PredictFailed,
170+
"Need samples=true for OOB predictions.",
171+
))
172+
}
173+
};
174+
175+
if samples[0].len() != n {
176+
return Err(Failed::because(
171177
FailedError::PredictFailed,
172178
"Prediction matrix must match matrix used in training for OOB predictions.",
173-
))
174-
} else {
175-
let mut result = Y::zeros(n);
179+
));
180+
}
176181

177-
for i in 0..n {
178-
result.set(i, self.predict_for_row_oob(x, i));
179-
}
182+
let mut result = Y::zeros(n);
180183

181-
Ok(result)
184+
for i in 0..n {
185+
result.set(i, self.predict_for_row_oob(x, i));
182186
}
187+
188+
Ok(result)
183189
}
184190

185191
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {

src/ensemble/random_forest_classifier.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -539,27 +539,34 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
539539
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
540540
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
541541
let (n, _) = x.shape();
542-
if self.samples.is_none() {
543-
Err(Failed::because(
544-
FailedError::PredictFailed,
545-
"Need samples=true for OOB predictions.",
546-
))
547-
} else if self.samples.as_ref().unwrap()[0].len() != n {
548-
Err(Failed::because(
542+
543+
let samples = match &self.samples {
544+
Some(s) => s,
545+
None => {
546+
return Err(Failed::because(
547+
FailedError::PredictFailed,
548+
"Need samples=true for OOB predictions.",
549+
));
550+
}
551+
};
552+
553+
if samples[0].len() != n {
554+
return Err(Failed::because(
549555
FailedError::PredictFailed,
550556
"Prediction matrix must match matrix used in training for OOB predictions.",
551-
))
552-
} else {
553-
let mut result = Y::zeros(n);
557+
));
558+
}
554559

555-
for i in 0..n {
556-
result.set(
557-
i,
558-
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
559-
);
560-
}
561-
Ok(result)
560+
let mut result = Y::zeros(n);
561+
562+
for i in 0..n {
563+
result.set(
564+
i,
565+
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
566+
);
562567
}
568+
569+
Ok(result)
563570
}
564571

565572
fn predict_for_row_oob(&self, x: &X, row: usize) -> usize {

0 commit comments

Comments
 (0)