@@ -47,3 +47,105 @@ pub struct LassoOutput<T> {
4747 /// The coefficients of the lasso regression,
4848 pub coefficients : DVector < T > ,
4949}
50+
51+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
52+ // IMPLEMENTATIONS
53+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
54+
55+ impl LassoInput < f64 > {
56+ /// Create a new `LassoInput` struct.
57+ #[ must_use]
58+ pub fn new (
59+ x : DMatrix < f64 > ,
60+ y : DVector < f64 > ,
61+ lambda : f64 ,
62+ fit_intercept : bool ,
63+ max_iter : usize ,
64+ tolerance : f64 ,
65+ ) -> Self {
66+ Self { x, y, lambda, fit_intercept, max_iter, tolerance }
67+ }
68+
69+ /// Fits a Lasso regression to the input data.
70+ /// Returns the intercept and coefficients.
71+ /// The intercept is the first value of the coefficients.
72+ pub fn fit ( & self ) -> Result < LassoOutput < f64 > , RustQuantError > {
73+ let n_cols = self . x . ncols ( ) ;
74+ let n_rows = self . x . nrows ( ) as f64 ;
75+ let mut features_matrix = self . x . clone ( ) ;
76+ let mut response_vec = self . y . clone ( ) ;
77+ let feature_means = DVector :: from_iterator (
78+ self . x . ncols ( ) ,
79+ ( 0 ..self . x . ncols ( ) ) . map ( |j| self . x . column ( j) . mean ( ) )
80+ ) ;
81+
82+ if self . fit_intercept {
83+
84+ features_matrix = self . x . clone ( ) ;
85+ for j in 0 ..self . x . ncols ( ) {
86+ let mean = feature_means[ j] ;
87+ for i in 0 ..self . x . nrows ( ) {
88+ features_matrix[ ( i, j) ] -= mean;
89+ }
90+ }
91+ response_vec = & self . y - DVector :: from_element ( self . x . nrows ( ) , self . y . mean ( ) ) ;
92+ }
93+
94+ let mut residual = response_vec;
95+ let mut coefficients = DVector :: < f64 > :: zeros ( n_cols) ;
96+
97+ for _ in 0 ..self . max_iter {
98+ let mut max_delta: f64 = 0.0 ;
99+ for j in 0 ..n_cols {
100+
101+ let feature_vals_col_j = features_matrix. column ( j) ;
102+ let col_norm: f64 = feature_vals_col_j. dot ( & feature_vals_col_j) ;
103+ let rho: f64 = ( residual. dot ( & feature_vals_col_j) + coefficients[ j] * col_norm) / n_rows;
104+
105+ let new_coefficient_j: f64 = if rho < -self . lambda {
106+ ( rho + self . lambda ) / ( col_norm / n_rows)
107+ } else if rho > self . lambda {
108+ ( rho - self . lambda ) / ( col_norm / n_rows)
109+ } else {
110+ 0.0
111+ } ;
112+
113+ let delta: f64 = new_coefficient_j - coefficients[ j] ;
114+ if delta. abs ( ) > 0.0 {
115+ residual -= & feature_vals_col_j * delta;
116+ }
117+ coefficients[ j] = new_coefficient_j;
118+ max_delta = max_delta. max ( delta. abs ( ) ) ;
119+ }
120+
121+ if max_delta < self . tolerance {
122+ break ;
123+ }
124+ }
125+
126+ let intercept: f64 = if self . fit_intercept {
127+ self . y . mean ( ) - feature_means. dot ( & coefficients)
128+ } else {
129+ 0.0
130+ } ;
131+ coefficients = coefficients. insert_row ( 0 , intercept) ;
132+
133+ Ok ( LassoOutput {
134+ intercept,
135+ coefficients,
136+ } )
137+ }
138+ }
139+
140+ impl LassoOutput < f64 > {
141+ /// Predicts the output for the given input data.
142+ pub fn predict ( & self , input : DMatrix < f64 > ) -> Result < DVector < f64 > , RustQuantError > {
143+ let intercept = DVector :: from_element (
144+ input. nrows ( ) ,
145+ self . intercept
146+ ) ;
147+ let coefficients = self . coefficients . clone ( ) . remove_row ( 0 ) ;
148+ let predictions = input * coefficients + intercept;
149+ Ok ( predictions)
150+ }
151+ }
0 commit comments