@@ -594,7 +594,7 @@ use ::matrixmultiply;
594594#[ cfg( feature = "O3" ) ]
595595use blas:: { daxpy, dgemm, dgemv} ;
596596#[ cfg( feature = "O3" ) ]
597- use lapack:: { dgecon, dgeqrf, dgetrf, dgetri, dgetrs, dorgqr, dgesvd} ;
597+ use lapack:: { dgecon, dgeqrf, dgetrf, dgetri, dgetrs, dorgqr, dgesvd, dpotrf } ;
598598#[ cfg( feature = "O3" ) ]
599599use std:: f64:: NAN ;
600600
@@ -617,7 +617,7 @@ use std::convert;
617617pub use std:: error:: Error ;
618618use std:: fmt;
619619use std:: ops:: { Add , Div , Index , IndexMut , Mul , Neg , Sub } ;
620- use crate :: traits:: sugar:: { Scalable , ScalableMut } ;
620+ use crate :: traits:: sugar:: ScalableMut ;
621621
622622pub type Perms = Vec < ( usize , usize ) > ;
623623
@@ -2770,13 +2770,15 @@ pub trait LinearAlgebra {
27702770 fn waz ( & self , d_form : Form ) -> Option < WAZD > ;
27712771 fn qr ( & self ) -> QR ;
27722772 fn svd ( & self ) -> SVD ;
2773+ fn cholesky ( & self , uplo : UPLO ) -> Matrix ;
27732774 fn rref ( & self ) -> Matrix ;
27742775 fn det ( & self ) -> f64 ;
27752776 fn block ( & self ) -> ( Matrix , Matrix , Matrix , Matrix ) ;
27762777 fn inv ( & self ) -> Matrix ;
27772778 fn pseudo_inv ( & self ) -> Matrix ;
27782779 fn solve ( & self , b : & Vec < f64 > , sk : SolveKind ) -> Vec < f64 > ;
27792780 fn solve_mat ( & self , m : & Matrix , sk : SolveKind ) -> Matrix ;
2781+ fn is_symmetric ( & self ) -> bool ;
27802782}
27812783
27822784pub fn diag ( n : usize ) -> Matrix {
@@ -3221,6 +3223,55 @@ impl LinearAlgebra for Matrix {
32213223 }
32223224 }
32233225
3226+ /// Cholesky Decomposition
3227+ ///
3228+ /// # Examples
3229+ /// ```
3230+ /// extern crate peroxide;
3231+ /// use peroxide::fuga::*;
3232+ ///
3233+ /// fn main() {
3234+ /// let a = ml_matrix("1 2;2 5");
3235+ /// #[cfg(feature = "O3")]
3236+ /// {
3237+ /// let u = a.cholesky(Upper);
3238+ /// let l = a.cholesky(Lower);
3239+ ///
3240+ /// assert_eq!(u, ml_matrix("1 2;0 1"));
3241+ /// assert_eq!(l, ml_matrix("1 0;2 1"));
3242+ /// }
3243+ /// a.print();
3244+ /// }
3245+ /// ```
3246+ fn cholesky ( & self , uplo : UPLO ) -> Matrix {
3247+ match ( ) {
3248+ #[ cfg( feature = "O3" ) ]
3249+ ( ) => {
3250+ if !self . is_symmetric ( ) {
3251+ panic ! ( "Cholesky Error: Matrix is not symmetric!" ) ;
3252+ }
3253+ let dpotrf = lapack_dpotrf ( self , uplo) ;
3254+ match dpotrf {
3255+ None => panic ! ( "Cholesky Error: Not symmetric or not positive definite." ) ,
3256+ Some ( x) => {
3257+ match x. status {
3258+ POSITIVE_STATUS :: Failed ( i) => panic ! ( "Cholesky Error: the leading minor of order {} is not positive definite" , i) ,
3259+ POSITIVE_STATUS :: Success => {
3260+ match uplo {
3261+ UPLO :: Upper => x. get_U ( ) . unwrap ( ) ,
3262+ UPLO :: Lower => x. get_L ( ) . unwrap ( )
3263+ }
3264+ }
3265+ }
3266+ }
3267+ }
3268+ }
3269+ _ => {
3270+ unimplemented ! ( )
3271+ }
3272+ }
3273+ }
3274+
32243275 /// Reduced Row Echelon Form
32253276 ///
32263277 /// Implementation of [RosettaCode](https://rosettacode.org/wiki/Reduced_row_echelon_form)
@@ -3551,6 +3602,21 @@ impl LinearAlgebra for Matrix {
35513602 }
35523603 }
35533604 }
3605+
3606+ fn is_symmetric ( & self ) -> bool {
3607+ if self . row != self . col {
3608+ return false ;
3609+ }
3610+
3611+ for i in 0 .. self . row {
3612+ for j in i .. self . col {
3613+ if !nearly_eq ( self [ ( i, j) ] , self [ ( j, i) ] ) {
3614+ return false ;
3615+ }
3616+ }
3617+ }
3618+ true
3619+ }
35543620}
35553621
35563622#[ allow( non_snake_case) ]
@@ -4038,6 +4104,20 @@ pub enum SVD_STATUS {
40384104 Diverge ( i32 ) ,
40394105}
40404106
4107+ #[ allow( non_camel_case_types) ]
4108+ #[ derive( Debug , Copy , Clone , Eq , PartialEq ) ]
4109+ pub enum POSITIVE_STATUS {
4110+ Success ,
4111+ Failed ( i32 ) ,
4112+ }
4113+
4114+ #[ allow( non_camel_case_types) ]
4115+ #[ derive( Debug , Copy , Clone , Eq , PartialEq ) ]
4116+ pub enum UPLO {
4117+ Upper ,
4118+ Lower
4119+ }
4120+
40414121/// Temporary data structure from `dgetrf`
40424122#[ derive( Debug , Clone ) ]
40434123pub struct DGETRF {
@@ -4062,6 +4142,13 @@ pub struct DGESVD {
40624142 pub status : SVD_STATUS ,
40634143}
40644144
4145+ #[ derive( Debug , Clone ) ]
4146+ pub struct DPOTRF {
4147+ pub fact_mat : Matrix ,
4148+ pub uplo : UPLO ,
4149+ pub status : POSITIVE_STATUS
4150+ }
4151+
40654152///// Temporary data structure from `dgeev`
40664153//#[derive(Debug, Clone)]
40674154//pub struct DGEEV {
@@ -4300,6 +4387,54 @@ pub fn lapack_dgesvd(mat: &Matrix) -> Option<DGESVD> {
43004387 }
43014388}
43024389
4390+ #[ allow( non_snake_case) ]
4391+ #[ cfg( feature = "O3" ) ]
4392+ pub fn lapack_dpotrf ( mat : & Matrix , UPLO : UPLO ) -> Option < DPOTRF > {
4393+ match mat. shape {
4394+ Row => lapack_dpotrf ( & mat. change_shape ( ) , UPLO ) ,
4395+ Col => {
4396+ let lda = mat. row as i32 ;
4397+ let N = mat. col as i32 ;
4398+ let mut A = mat. clone ( ) ;
4399+ let mut info = 0i32 ;
4400+ let uplo = match UPLO {
4401+ UPLO :: Upper => b'U' ,
4402+ UPLO :: Lower => b'L'
4403+ } ;
4404+
4405+ unsafe {
4406+ dpotrf (
4407+ uplo,
4408+ N ,
4409+ & mut A . data ,
4410+ lda,
4411+ & mut info,
4412+ )
4413+ }
4414+
4415+ if info == 0 {
4416+ Some (
4417+ DPOTRF {
4418+ fact_mat : matrix ( A . data , mat. row , mat. col , Col ) ,
4419+ uplo : UPLO ,
4420+ status : POSITIVE_STATUS :: Success
4421+ }
4422+ )
4423+ } else if info > 0 {
4424+ Some (
4425+ DPOTRF {
4426+ fact_mat : matrix ( A . data , mat. row , mat. col , Col ) ,
4427+ uplo : UPLO ,
4428+ status : POSITIVE_STATUS :: Failed ( info)
4429+ }
4430+ )
4431+ } else {
4432+ None
4433+ }
4434+ }
4435+ }
4436+ }
4437+
43034438#[ allow( non_snake_case) ]
43044439#[ cfg( feature = "O3" ) ]
43054440impl DGETRF {
@@ -4396,6 +4531,42 @@ impl DGEQRF {
43964531 }
43974532}
43984533
4534+ #[ allow( non_snake_case) ]
4535+ impl DPOTRF {
4536+ pub fn get_U ( & self ) -> Option < Matrix > {
4537+ if self . uplo == UPLO :: Lower {
4538+ return None ;
4539+ }
4540+
4541+ let mat = & self . fact_mat ;
4542+ let n = mat. col ;
4543+ let mut result = matrix ( vec ! [ 0f64 ; n. pow( 2 ) ] , n, n, mat. shape ) ;
4544+ for i in 0 .. n {
4545+ for j in i .. n {
4546+ result[ ( i, j) ] = mat[ ( i, j) ] ;
4547+ }
4548+ }
4549+ Some ( result)
4550+ }
4551+
4552+ pub fn get_L ( & self ) -> Option < Matrix > {
4553+ if self . uplo == UPLO :: Upper {
4554+ return None ;
4555+ }
4556+
4557+ let mat = & self . fact_mat ;
4558+ let n = mat. col ;
4559+ let mut result = matrix ( vec ! [ 0f64 ; n. pow( 2 ) ] , n, n, mat. shape ) ;
4560+
4561+ for i in 0 .. n {
4562+ for j in 0 .. i+1 {
4563+ result[ ( i, j) ] = mat[ ( i, j) ] ;
4564+ }
4565+ }
4566+ Some ( result)
4567+ }
4568+ }
4569+
43994570#[ allow( non_snake_case) ]
44004571pub fn gen_householder ( a : & Vec < f64 > ) -> Matrix {
44014572 let mut v = a. fmap ( |t| t / ( a[ 0 ] + a. norm ( Norm :: L2 ) * a[ 0 ] . signum ( ) ) ) ;
0 commit comments