Skip to content

Commit 2946204

Browse files
committed
TEST: Add i32 gemm tests
1 parent 2ea01de commit 2946204

1 file changed

Lines changed: 31 additions & 1 deletion

File tree

tests/sgemm.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
extern crate itertools;
22
extern crate matrixmultiply;
33

4-
use matrixmultiply::{sgemm, dgemm};
4+
use matrixmultiply::{sgemm, dgemm, igemm};
55

66
use itertools::Itertools;
77
use itertools::{
@@ -35,6 +35,13 @@ impl Float for f64 {
3535
fn is_nan(self) -> bool { self.is_nan() }
3636
}
3737

38+
impl Float for i32 {
39+
fn zero() -> Self { 0 }
40+
fn one() -> Self { 1 }
41+
fn from(x: i64) -> Self { x as Self }
42+
fn nan() -> Self { i32::min_value() } // hack
43+
fn is_nan(self) -> bool { self == i32::min_value() }
44+
}
3845

3946
trait Gemm : Sized {
4047
unsafe fn gemm(
@@ -64,6 +71,24 @@ impl Gemm for f32 {
6471
}
6572
}
6673

74+
impl Gemm for i32 {
75+
unsafe fn gemm(
76+
m: usize, k: usize, n: usize,
77+
alpha: Self,
78+
a: *const Self, rsa: isize, csa: isize,
79+
b: *const Self, rsb: isize, csb: isize,
80+
beta: Self,
81+
c: *mut Self, rsc: isize, csc: isize) {
82+
igemm(
83+
m, k, n,
84+
alpha,
85+
a, rsa, csa,
86+
b, rsb, csb,
87+
beta,
88+
c, rsc, csc)
89+
}
90+
}
91+
6792
impl Gemm for f64 {
6893
unsafe fn gemm(
6994
m: usize, k: usize, n: usize,
@@ -99,6 +124,11 @@ fn test_dgemm_strides() {
99124
test_gemm_strides::<f64>();
100125
}
101126

127+
#[test]
128+
fn test_i32gemm_strides() {
129+
test_gemm_strides::<i32>();
130+
}
131+
102132
fn test_gemm_strides<F>() where F: Gemm + Float {
103133
for n in 0..10 {
104134
test_strides::<F>(n, n, n);

0 commit comments

Comments
 (0)