|
1 | 1 | extern crate itertools; |
2 | 2 | extern crate matrixmultiply; |
3 | 3 |
|
4 | | -use matrixmultiply::{sgemm, dgemm}; |
| 4 | +use matrixmultiply::{sgemm, dgemm, igemm}; |
5 | 5 |
|
6 | 6 | use itertools::Itertools; |
7 | 7 | use itertools::{ |
@@ -35,6 +35,13 @@ impl Float for f64 { |
35 | 35 | fn is_nan(self) -> bool { self.is_nan() } |
36 | 36 | } |
37 | 37 |
|
| 38 | +impl Float for i32 { |
| 39 | + fn zero() -> Self { 0 } |
| 40 | + fn one() -> Self { 1 } |
| 41 | + fn from(x: i64) -> Self { x as Self } |
| 42 | + fn nan() -> Self { i32::min_value() } // hack |
| 43 | + fn is_nan(self) -> bool { self == i32::min_value() } |
| 44 | +} |
38 | 45 |
|
39 | 46 | trait Gemm : Sized { |
40 | 47 | unsafe fn gemm( |
@@ -64,6 +71,24 @@ impl Gemm for f32 { |
64 | 71 | } |
65 | 72 | } |
66 | 73 |
|
| 74 | +impl Gemm for i32 { |
| 75 | + unsafe fn gemm( |
| 76 | + m: usize, k: usize, n: usize, |
| 77 | + alpha: Self, |
| 78 | + a: *const Self, rsa: isize, csa: isize, |
| 79 | + b: *const Self, rsb: isize, csb: isize, |
| 80 | + beta: Self, |
| 81 | + c: *mut Self, rsc: isize, csc: isize) { |
| 82 | + igemm( |
| 83 | + m, k, n, |
| 84 | + alpha, |
| 85 | + a, rsa, csa, |
| 86 | + b, rsb, csb, |
| 87 | + beta, |
| 88 | + c, rsc, csc) |
| 89 | + } |
| 90 | +} |
| 91 | + |
67 | 92 | impl Gemm for f64 { |
68 | 93 | unsafe fn gemm( |
69 | 94 | m: usize, k: usize, n: usize, |
@@ -99,6 +124,11 @@ fn test_dgemm_strides() { |
99 | 124 | test_gemm_strides::<f64>(); |
100 | 125 | } |
101 | 126 |
|
| 127 | +#[test] |
| 128 | +fn test_i32gemm_strides() { |
| 129 | + test_gemm_strides::<i32>(); |
| 130 | +} |
| 131 | + |
102 | 132 | fn test_gemm_strides<F>() where F: Gemm + Float { |
103 | 133 | for n in 0..10 { |
104 | 134 | test_strides::<F>(n, n, n); |
|
0 commit comments