Skip to content

Commit c07bb50

Browse files
committed
Lasso: Implementations for constructor, fitting and create output struct
1 parent 8a1a534 commit c07bb50

1 file changed

Lines changed: 102 additions & 0 deletions

File tree

crates/RustQuant_ml/src/lasso.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,105 @@ pub struct LassoOutput<T> {
4747
/// The coefficients of the lasso regression,
4848
pub coefficients: DVector<T>,
4949
}
50+
51+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
52+
// IMPLEMENTATIONS
53+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
54+
55+
impl LassoInput<f64> {
56+
/// Create a new `LassoInput` struct.
57+
#[must_use]
58+
pub fn new(
59+
x: DMatrix<f64>,
60+
y: DVector<f64>,
61+
lambda: f64,
62+
fit_intercept: bool,
63+
max_iter: usize,
64+
tolerance: f64,
65+
) -> Self {
66+
Self { x, y, lambda, fit_intercept, max_iter, tolerance }
67+
}
68+
69+
/// Fits a Lasso regression to the input data.
70+
/// Returns the intercept and coefficients.
71+
/// The intercept is the first value of the coefficients.
72+
pub fn fit(&self) -> Result<LassoOutput<f64>, RustQuantError> {
73+
let n_cols = self.x.ncols();
74+
let n_rows = self.x.nrows() as f64;
75+
let mut features_matrix = self.x.clone();
76+
let mut response_vec = self.y.clone();
77+
let feature_means = DVector::from_iterator(
78+
self.x.ncols(),
79+
(0..self.x.ncols()).map(|j| self.x.column(j).mean())
80+
);
81+
82+
if self.fit_intercept {
83+
84+
features_matrix = self.x.clone();
85+
for j in 0..self.x.ncols() {
86+
let mean = feature_means[j];
87+
for i in 0..self.x.nrows() {
88+
features_matrix[(i, j)] -= mean;
89+
}
90+
}
91+
response_vec = &self.y - DVector::from_element(self.x.nrows(), self.y.mean());
92+
}
93+
94+
let mut residual = response_vec;
95+
let mut coefficients = DVector::<f64>::zeros(n_cols);
96+
97+
for _ in 0..self.max_iter {
98+
let mut max_delta: f64 = 0.0;
99+
for j in 0..n_cols {
100+
101+
let feature_vals_col_j = features_matrix.column(j);
102+
let col_norm: f64 = feature_vals_col_j.dot(&feature_vals_col_j);
103+
let rho: f64 = (residual.dot(&feature_vals_col_j) + coefficients[j] * col_norm) / n_rows;
104+
105+
let new_coefficient_j: f64 = if rho < -self.lambda {
106+
(rho + self.lambda) / (col_norm / n_rows)
107+
} else if rho > self.lambda {
108+
(rho - self.lambda) / (col_norm / n_rows)
109+
} else {
110+
0.0
111+
};
112+
113+
let delta: f64 = new_coefficient_j - coefficients[j];
114+
if delta.abs() > 0.0 {
115+
residual -= &feature_vals_col_j * delta;
116+
}
117+
coefficients[j] = new_coefficient_j;
118+
max_delta = max_delta.max(delta.abs());
119+
}
120+
121+
if max_delta < self.tolerance {
122+
break;
123+
}
124+
}
125+
126+
let intercept: f64 = if self.fit_intercept {
127+
self.y.mean() - feature_means.dot(&coefficients)
128+
} else {
129+
0.0
130+
};
131+
coefficients = coefficients.insert_row(0, intercept);
132+
133+
Ok(LassoOutput {
134+
intercept,
135+
coefficients,
136+
})
137+
}
138+
}
139+
140+
impl LassoOutput<f64> {
141+
/// Predicts the output for the given input data.
142+
pub fn predict(&self, input: DMatrix<f64>) -> Result<DVector<f64>, RustQuantError> {
143+
let intercept = DVector::from_element(
144+
input.nrows(),
145+
self.intercept
146+
);
147+
let coefficients = self.coefficients.clone().remove_row(0);
148+
let predictions = input * coefficients + intercept;
149+
Ok(predictions)
150+
}
151+
}

0 commit comments

Comments
 (0)