-
-
Notifications
You must be signed in to change notification settings - Fork 323
Expand file tree
/
Copy pathhyperparams.rs
More file actions
119 lines (105 loc) · 3.95 KB
/
hyperparams.rs
File metadata and controls
119 lines (105 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use linfa::ParamGuard;
use ndarray::{Array, Array1, Dimension};
use crate::error::Error;
use crate::float::Float;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// A generalized logistic regression type that specializes as either binomial logistic regression
/// or multinomial logistic regression.
///
/// This is the *unchecked* builder form: a newtype over
/// `LogisticRegressionValidParams` whose settings have not yet been validated.
/// Validation happens through the `ParamGuard` impl (`check` / `check_ref`),
/// which yields the inner valid-params struct on success.
///
/// The dimension parameter `D` selects the specialization: a 1-D parameter
/// array corresponds to binomial regression, 2-D to multinomial.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct LogisticRegressionParams<F: Float, D: Dimension>(LogisticRegressionValidParams<F, D>);
/// The validated hyperparameter set for logistic regression, produced by
/// running `LogisticRegressionParams` through `ParamGuard::check`.
///
/// Fields are `pub(crate)` so the fitting code in this crate can read them
/// directly; external users go through the builder methods.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
// L2 regularization strength; validated finite and non-negative.
pub(crate) alpha: F,
// Whether an intercept term is fitted in addition to the feature weights.
pub(crate) fit_intercept: bool,
// Upper bound on solver iterations.
pub(crate) max_iterations: u64,
// Convergence threshold on the gradient; validated finite and strictly positive.
pub(crate) gradient_tolerance: F,
// Optional warm-start parameters; validated element-wise finite when present.
pub(crate) initial_params: Option<Array<F, D>>,
// Optional per-sample offset added to the linear predictor; validated
// element-wise finite when present.
pub(crate) offset: Option<Array1<F>>,
}
impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
    type Checked = LogisticRegressionValidParams<F, D>;
    type Error = Error;

    /// Validates every configured hyperparameter and, on success, hands back a
    /// reference to the inner checked parameter set.
    ///
    /// Checks are performed in a fixed order so the reported error is stable:
    /// `alpha` first, then `gradient_tolerance`, then the optional
    /// `initial_params` and `offset` arrays (element-wise finiteness).
    ///
    /// # Errors
    ///
    /// * `Error::InvalidAlpha` — `alpha` is NaN, infinite, or negative.
    /// * `Error::InvalidGradientTolerance` — tolerance is NaN, infinite, or not
    ///   strictly positive.
    /// * `Error::InvalidInitialParameters` — any warm-start value is non-finite.
    /// * `Error::InvalidOffset` — any offset value is non-finite.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let inner = &self.0;

        // Regularization must be a finite, non-negative value (NaN fails the
        // `is_finite` test, so it is rejected here as well).
        let alpha_valid = inner.alpha.is_finite() && inner.alpha >= F::zero();
        if !alpha_valid {
            return Err(Error::InvalidAlpha);
        }

        // The stopping tolerance must be finite and strictly greater than zero.
        let tol_valid =
            inner.gradient_tolerance.is_finite() && inner.gradient_tolerance > F::zero();
        if !tol_valid {
            return Err(Error::InvalidGradientTolerance);
        }

        // Warm-start parameters, when provided, must be finite in every element.
        if let Some(params) = inner.initial_params.as_ref() {
            if !params.iter().all(|p| p.is_finite()) {
                return Err(Error::InvalidInitialParameters);
            }
        }

        // The optional offset vector must likewise be element-wise finite.
        if let Some(offset) = inner.offset.as_ref() {
            if !offset.iter().all(|o| o.is_finite()) {
                return Err(Error::InvalidOffset);
            }
        }

        Ok(inner)
    }

    /// Consumes the builder, returning the validated parameter set or the
    /// first validation error encountered by [`Self::check_ref`].
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
    /// Creates a new LogisticRegression with default configuration.
    ///
    /// Defaults: `alpha = 1.0`, `fit_intercept = true`, `max_iterations = 100`,
    /// `gradient_tolerance = 1e-4`, no initial parameters, no offset.
    pub fn new() -> Self {
        Self(LogisticRegressionValidParams {
            alpha: F::cast(1.0),
            fit_intercept: true,
            max_iterations: 100,
            gradient_tolerance: F::cast(1e-4),
            initial_params: None,
            offset: None,
        })
    }

    /// Set the regularization parameter `alpha` used for L2 regularization,
    /// defaults to `1.0`.
    pub fn alpha(mut self, alpha: F) -> Self {
        self.0.alpha = alpha;
        self
    }

    /// Configure if an intercept should be fitted, defaults to `true`.
    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
        self.0.fit_intercept = fit_intercept;
        self
    }

    /// Configure the maximum number of iterations that the solver should perform,
    /// defaults to `100`.
    pub fn max_iterations(mut self, max_iterations: u64) -> Self {
        self.0.max_iterations = max_iterations;
        self
    }

    /// Configure the minimum change to the gradient to continue the solver,
    /// defaults to `1e-4`.
    pub fn gradient_tolerance(mut self, gradient_tolerance: F) -> Self {
        self.0.gradient_tolerance = gradient_tolerance;
        self
    }

    /// Configure the initial parameters from where the optimization starts. The `params` array
    /// must have the same number of rows as there are columns on the feature matrix `x` passed to
    /// the `fit` method. If `with_intercept` is set, then it needs to have one more row. For
    /// multinomial regression, `params` also must have the same number of columns as the number of
    /// distinct classes in `y`.
    pub fn initial_params(mut self, params: Array<F, D>) -> Self {
        self.0.initial_params = Some(params);
        self
    }

    /// Configure an offset added to the linear predictor for each sample,
    /// defaults to none. The array is validated element-wise for finiteness
    /// when the parameters are checked.
    // NOTE(review): the exact length requirement relative to `x` is enforced
    // by the fitting code, not visible here — confirm against the `fit` impl.
    pub fn offset(mut self, offset: Array1<F>) -> Self {
        self.0.offset = Some(offset);
        self
    }
}

/// `Default` mirrors [`LogisticRegressionParams::new`] so the parameter set
/// can also be obtained via `Default::default()`; having `new()` without a
/// `Default` impl triggers clippy's `new_without_default` lint.
impl<F: Float, D: Dimension> Default for LogisticRegressionParams<F, D> {
    fn default() -> Self {
        Self::new()
    }
}