Skip to content

Commit 7166f2f

Browse files
committed
Merge branch 'features/O3' into dev
2 parents cfdbe2e + 7db7610 commit 7166f2f

7 files changed

Lines changed: 336 additions & 14 deletions

File tree

src/fuga/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ pub use crate::statistics::stat::QType::{
170170
pub use crate::structure::matrix::{
171171
Form::{Diagonal, Identity},
172172
SolveKind::{LU, WAZ},
173+
UPLO::{Upper, Lower}
173174
};
174175
pub use crate::structure::dataframe::DType::*;
175176
pub use crate::structure::ad::AD::*;

src/prelude/simpler.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ pub trait SimplerLinearAlgebra {
2727
fn waz_diag(&self) -> Option<matrix::WAZD>;
2828
fn waz(&self) -> Option<matrix::WAZD>;
2929
fn qr(&self) -> matrix::QR;
30+
fn cholesky(&self) -> Matrix;
3031
fn rref(&self) -> Matrix;
3132
fn det(&self) -> f64;
3233
fn block(&self) -> (Matrix, Matrix, Matrix, Matrix);
3334
fn inv(&self) -> Matrix;
3435
fn pseudo_inv(&self) -> Matrix;
3536
fn solve(&self, b: &Vec<f64>) -> Vec<f64>;
3637
fn solve_mat(&self, m: &Matrix) -> Matrix;
38+
fn is_symmetric(&self) -> bool;
3739
}
3840

3941
/// Simple Eigenpair
@@ -115,6 +117,14 @@ impl SimplerLinearAlgebra for Matrix {
115117
fn solve_mat(&self, m: &Matrix) -> Matrix {
116118
matrix::LinearAlgebra::solve_mat(self, m, matrix::SolveKind::LU)
117119
}
120+
121+
fn cholesky(&self) -> Matrix {
122+
matrix::LinearAlgebra::cholesky(self, matrix::UPLO::Lower)
123+
}
124+
125+
fn is_symmetric(&self) -> bool {
126+
matrix::LinearAlgebra::is_symmetric(self)
127+
}
118128
}
119129

120130
/// Simple solve

src/structure/matrix.rs

Lines changed: 173 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ use ::matrixmultiply;
594594
#[cfg(feature = "O3")]
595595
use blas::{daxpy, dgemm, dgemv};
596596
#[cfg(feature = "O3")]
597-
use lapack::{dgecon, dgeqrf, dgetrf, dgetri, dgetrs, dorgqr, dgesvd};
597+
use lapack::{dgecon, dgeqrf, dgetrf, dgetri, dgetrs, dorgqr, dgesvd, dpotrf};
598598
#[cfg(feature = "O3")]
599599
use std::f64::NAN;
600600

@@ -617,7 +617,7 @@ use std::convert;
617617
pub use std::error::Error;
618618
use std::fmt;
619619
use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
620-
use crate::traits::sugar::{Scalable, ScalableMut};
620+
use crate::traits::sugar::ScalableMut;
621621

622622
pub type Perms = Vec<(usize, usize)>;
623623

@@ -2770,13 +2770,15 @@ pub trait LinearAlgebra {
27702770
fn waz(&self, d_form: Form) -> Option<WAZD>;
27712771
fn qr(&self) -> QR;
27722772
fn svd(&self) -> SVD;
2773+
fn cholesky(&self, uplo: UPLO) -> Matrix;
27732774
fn rref(&self) -> Matrix;
27742775
fn det(&self) -> f64;
27752776
fn block(&self) -> (Matrix, Matrix, Matrix, Matrix);
27762777
fn inv(&self) -> Matrix;
27772778
fn pseudo_inv(&self) -> Matrix;
27782779
fn solve(&self, b: &Vec<f64>, sk: SolveKind) -> Vec<f64>;
27792780
fn solve_mat(&self, m: &Matrix, sk: SolveKind) -> Matrix;
2781+
fn is_symmetric(&self) -> bool;
27802782
}
27812783

27822784
pub fn diag(n: usize) -> Matrix {
@@ -3221,6 +3223,55 @@ impl LinearAlgebra for Matrix {
32213223
}
32223224
}
32233225

3226+
/// Cholesky Decomposition
3227+
///
3228+
/// # Examples
3229+
/// ```
3230+
/// extern crate peroxide;
3231+
/// use peroxide::fuga::*;
3232+
///
3233+
/// fn main() {
3234+
/// let a = ml_matrix("1 2;2 5");
3235+
/// #[cfg(feature = "O3")]
3236+
/// {
3237+
/// let u = a.cholesky(Upper);
3238+
/// let l = a.cholesky(Lower);
3239+
///
3240+
/// assert_eq!(u, ml_matrix("1 2;0 1"));
3241+
/// assert_eq!(l, ml_matrix("1 0;2 1"));
3242+
/// }
3243+
/// a.print();
3244+
/// }
3245+
/// ```
3246+
fn cholesky(&self, uplo: UPLO) -> Matrix {
3247+
match () {
3248+
#[cfg(feature = "O3")]
3249+
() => {
3250+
if !self.is_symmetric() {
3251+
panic!("Cholesky Error: Matrix is not symmetric!");
3252+
}
3253+
let dpotrf = lapack_dpotrf(self, uplo);
3254+
match dpotrf {
3255+
None => panic!("Cholesky Error: Not symmetric or not positive definite."),
3256+
Some(x) => {
3257+
match x.status {
3258+
POSITIVE_STATUS::Failed(i) => panic!("Cholesky Error: the leading minor of order {} is not positive definite", i),
3259+
POSITIVE_STATUS::Success => {
3260+
match uplo {
3261+
UPLO::Upper => x.get_U().unwrap(),
3262+
UPLO::Lower => x.get_L().unwrap()
3263+
}
3264+
}
3265+
}
3266+
}
3267+
}
3268+
}
3269+
_ => {
3270+
unimplemented!()
3271+
}
3272+
}
3273+
}
3274+
32243275
/// Reduced Row Echelon Form
32253276
///
32263277
/// Implementation of [RosettaCode](https://rosettacode.org/wiki/Reduced_row_echelon_form)
@@ -3551,6 +3602,21 @@ impl LinearAlgebra for Matrix {
35513602
}
35523603
}
35533604
}
3605+
3606+
fn is_symmetric(&self) -> bool {
3607+
if self.row != self.col {
3608+
return false;
3609+
}
3610+
3611+
for i in 0 .. self.row {
3612+
for j in i .. self.col {
3613+
if !nearly_eq(self[(i,j)], self[(j,i)]) {
3614+
return false;
3615+
}
3616+
}
3617+
}
3618+
true
3619+
}
35543620
}
35553621

35563622
#[allow(non_snake_case)]
@@ -4038,6 +4104,20 @@ pub enum SVD_STATUS {
40384104
Diverge(i32),
40394105
}
40404106

4107+
#[allow(non_camel_case_types)]
4108+
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
4109+
pub enum POSITIVE_STATUS {
4110+
Success,
4111+
Failed(i32),
4112+
}
4113+
4114+
#[allow(non_camel_case_types)]
4115+
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
4116+
pub enum UPLO {
4117+
Upper,
4118+
Lower
4119+
}
4120+
40414121
/// Temporary data structure from `dgetrf`
40424122
#[derive(Debug, Clone)]
40434123
pub struct DGETRF {
@@ -4062,6 +4142,13 @@ pub struct DGESVD {
40624142
pub status: SVD_STATUS,
40634143
}
40644144

4145+
#[derive(Debug, Clone)]
4146+
pub struct DPOTRF {
4147+
pub fact_mat: Matrix,
4148+
pub uplo: UPLO,
4149+
pub status: POSITIVE_STATUS
4150+
}
4151+
40654152
///// Temporary data structure from `dgeev`
40664153
//#[derive(Debug, Clone)]
40674154
//pub struct DGEEV {
@@ -4300,6 +4387,54 @@ pub fn lapack_dgesvd(mat: &Matrix) -> Option<DGESVD> {
43004387
}
43014388
}
43024389

4390+
#[allow(non_snake_case)]
4391+
#[cfg(feature = "O3")]
4392+
pub fn lapack_dpotrf(mat: &Matrix, UPLO: UPLO) -> Option<DPOTRF> {
4393+
match mat.shape {
4394+
Row => lapack_dpotrf(&mat.change_shape(), UPLO),
4395+
Col => {
4396+
let lda = mat.row as i32;
4397+
let N = mat.col as i32;
4398+
let mut A = mat.clone();
4399+
let mut info = 0i32;
4400+
let uplo = match UPLO {
4401+
UPLO::Upper => b'U',
4402+
UPLO::Lower => b'L'
4403+
};
4404+
4405+
unsafe {
4406+
dpotrf(
4407+
uplo,
4408+
N,
4409+
&mut A.data,
4410+
lda,
4411+
&mut info,
4412+
)
4413+
}
4414+
4415+
if info == 0 {
4416+
Some(
4417+
DPOTRF {
4418+
fact_mat: matrix(A.data, mat.row, mat.col, Col),
4419+
uplo: UPLO,
4420+
status: POSITIVE_STATUS::Success
4421+
}
4422+
)
4423+
} else if info > 0 {
4424+
Some(
4425+
DPOTRF {
4426+
fact_mat: matrix(A.data, mat.row, mat.col, Col),
4427+
uplo: UPLO,
4428+
status: POSITIVE_STATUS::Failed(info)
4429+
}
4430+
)
4431+
} else {
4432+
None
4433+
}
4434+
}
4435+
}
4436+
}
4437+
43034438
#[allow(non_snake_case)]
43044439
#[cfg(feature = "O3")]
43054440
impl DGETRF {
@@ -4396,6 +4531,42 @@ impl DGEQRF {
43964531
}
43974532
}
43984533

4534+
#[allow(non_snake_case)]
4535+
impl DPOTRF {
4536+
pub fn get_U(&self) -> Option<Matrix> {
4537+
if self.uplo == UPLO::Lower {
4538+
return None;
4539+
}
4540+
4541+
let mat = &self.fact_mat;
4542+
let n = mat.col;
4543+
let mut result = matrix(vec![0f64; n.pow(2)], n, n, mat.shape);
4544+
for i in 0 .. n {
4545+
for j in i .. n {
4546+
result[(i, j)] = mat[(i, j)];
4547+
}
4548+
}
4549+
Some(result)
4550+
}
4551+
4552+
pub fn get_L(&self) -> Option<Matrix> {
4553+
if self.uplo == UPLO::Upper {
4554+
return None;
4555+
}
4556+
4557+
let mat = &self.fact_mat;
4558+
let n = mat.col;
4559+
let mut result = matrix(vec![0f64; n.pow(2)], n, n, mat.shape);
4560+
4561+
for i in 0 .. n {
4562+
for j in 0 .. i+1 {
4563+
result[(i, j)] = mat[(i, j)];
4564+
}
4565+
}
4566+
Some(result)
4567+
}
4568+
}
4569+
43994570
#[allow(non_snake_case)]
44004571
pub fn gen_householder(a: &Vec<f64>) -> Matrix {
44014572
let mut v = a.fmap(|t| t / (a[0] + a.norm(Norm::L2) * a[0].signum()));

src/structure/sparse.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::traits::math::LinearOp;
77
//use crate::traits::math::{InnerProduct, LinearOp, Norm, Normed, Vector};
88
use crate::util::non_macro::zeros;
99
use std::ops::Mul;
10+
use crate::fuga::UPLO;
1011

1112
#[derive(Debug, Clone)]
1213
pub struct SPMatrix {
@@ -188,6 +189,14 @@ impl LinearAlgebra for SPMatrix {
188189
fn svd(&self) -> SVD {
189190
unimplemented!()
190191
}
192+
193+
fn cholesky(&self, uplo: UPLO) -> Matrix {
194+
unimplemented!()
195+
}
196+
197+
fn is_symmetric(&self) -> bool {
198+
unimplemented!()
199+
}
191200
}
192201

193202
/// Matrix multiplication with vector

0 commit comments

Comments
 (0)