Skip to content

Commit 1664e7c

Browse files
committed
For fun, an experiment with i32 matrix multiply
There are no 64-bit integer multiplies in AVX SIMD, so don't try that. i32 will autovectorize on AVX, so it's a bit of fun.
1 parent 52c6abc commit 1664e7c

5 files changed

Lines changed: 283 additions & 6 deletions

File tree

benches/benchmarks.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
extern crate matrixmultiply;
22
pub use matrixmultiply::sgemm;
33
pub use matrixmultiply::dgemm;
4+
pub use matrixmultiply::igemm;
45

56
#[macro_use]
67
extern crate bencher;
@@ -10,7 +11,13 @@ extern crate bencher;
1011
// by flop / s = 2 M N K / time
1112

1213

13-
benchmark_main!(mat_mul_f32, mat_mul_f64, layout_f32_032, layout_f64_032);
14+
// Benchmark groups handed to `benchmark_main!`; `mat_mul_i32` is the
// i32 group generated by the `mat_mul!{mat_mul_i32, igemm, ...}` invocation
// further down in this file.
benchmark_main!(
15+
mat_mul_f32,
16+
mat_mul_f64,
17+
layout_f32_032,
18+
layout_f64_032,
19+
mat_mul_i32
20+
);
1421

1522
macro_rules! mat_mul {
1623
($modname:ident, $gemm:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
@@ -20,17 +27,17 @@ macro_rules! mat_mul {
2027
$(
2128
pub fn $name(bench: &mut Bencher)
2229
{
23-
let a = vec![0.; $m * $n];
24-
let b = vec![0.; $n * $k];
25-
let mut c = vec![0.; $m * $k];
30+
let a = vec![0 as _; $m * $n];
31+
let b = vec![0 as _; $n * $k];
32+
let mut c = vec![0 as _; $m * $k];
2633
bench.iter(|| {
2734
unsafe {
2835
$gemm(
2936
$m, $n, $k,
30-
1.,
37+
1 as _,
3138
a.as_ptr(), $n, 1,
3239
b.as_ptr(), $k, 1,
33-
0.,
40+
0 as _,
3441
c.as_mut_ptr(), $k, 1,
3542
)
3643
}
@@ -219,3 +226,22 @@ ref_mat_mul!{ref_mat_mul_f32, f32,
219226
(m032, 32, 32, 32)
220227
(m064, 64, 64, 64)
221228
}
229+
230+
// i32 benchmarks: instantiates `mat_mul!` with `igemm` as the gemm function.
// Each tuple matches the macro pattern `(name, m, n, k)`.
mat_mul!{mat_mul_i32, igemm,
231+
(m004, 4, 4, 4)
232+
(m006, 6, 6, 6)
233+
(m008, 8, 8, 8)
234+
(m012, 12, 12, 12)
235+
(m016, 16, 16, 16)
236+
(m032, 32, 32, 32)
237+
(m064, 64, 64, 64)
238+
(m127, 127, 127, 127)
239+
// The larger and mixed-size cases below are currently disabled.
/*
240+
(m256, 256, 256, 256)
241+
(m512, 512, 512, 512)
242+
(mix16x4, 32, 4, 32)
243+
(mix32x2, 32, 2, 32)
244+
(mix97, 97, 97, 125)
245+
(mix128x10000x128, 128, 10000, 128)
246+
*/
247+
}

src/gemm.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use kernel::GemmKernel;
1919
use kernel::Element;
2020
use sgemm_kernel;
2121
use dgemm_kernel;
22+
use igemm_kernel;
2223
use rawpointer::PointerExt;
2324

2425
/// General matrix multiplication (f32)
@@ -87,6 +88,23 @@ pub unsafe fn dgemm(
8788
c, rsc, csc)
8889
}
8990

91+
/// General matrix multiplication (i32)
///
/// Forwards every argument unchanged to `gemm_loop`, instantiated with the
/// i32 kernel `igemm_kernel::Gemm`.
///
/// The i32 element operations and the i32 kernel in this crate are written
/// with `wrapping_mul` / `wrapping_add`, so products and sums that overflow
/// wrap around.
///
/// NOTE(review): presumably the same contract as `sgemm` / `dgemm`
/// (C ← α A B + β C, with `rsX` / `csX` the row and column strides of each
/// matrix) — confirm against their documentation.
///
/// # Safety
///
/// `a`, `b` and `c` are raw pointers that are passed straight through.
/// NOTE(review): the caller presumably has to guarantee they are valid for
/// the given dimensions and strides — confirm against `gemm_loop`.
pub unsafe fn igemm(
92+
m: usize, k: usize, n: usize,
93+
alpha: i32,
94+
a: *const i32, rsa: isize, csa: isize,
95+
b: *const i32, rsb: isize, csb: isize,
96+
beta: i32,
97+
c: *mut i32, rsc: isize, csc: isize)
98+
{
99+
gemm_loop::<igemm_kernel::Gemm>(
100+
m, k, n,
101+
alpha,
102+
a, rsa, csa,
103+
b, rsb, csb,
104+
beta,
105+
c, rsc, csc)
106+
}
107+
90108
/// Ensure that GemmKernel parameters are supported
91109
/// (alignment, microkernel size).
92110
///

src/igemm_kernel.rs

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
// Copyright 2016 - 2018 Ulrik Sverdrup "bluss"
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use kernel::GemmKernel;
10+
use kernel::Element;
11+
use archparam;
12+
13+
14+
#[cfg(target_arch="x86")]
15+
use std::arch::x86::*;
16+
#[cfg(target_arch="x86_64")]
17+
use std::arch::x86_64::*;
18+
19+
/// Marker type for the i32 kernel: an enum without variants, used only to
/// carry the `GemmKernel` implementation below.
pub enum Gemm { }
20+
21+
/// Element type this kernel operates on.
pub type T = i32;
22+
23+
// Microkernel tile size: the kernel accumulates an MR x NR block of results
// (see the `ab` array in `kernel_fallback_impl`).
const MR: usize = 8;
24+
const NR: usize = 4;
25+
26+
// Run `$e` for each of the MR row indices / NR column indices by delegating
// to `loop8!` / `loop4!` (defined elsewhere in the crate). The numbers in the
// macro names must stay in sync with MR and NR.
macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
27+
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
28+
29+
// `GemmKernel` parameters and entry point for the i32 kernel.
impl GemmKernel for Gemm {
30+
type Elem = T;
31+
32+
// Alignment requested for buffers used with this kernel; the tests in this
// file allocate their packed inputs with it.
// NOTE(review): unit is presumably bytes — confirm.
#[inline(always)]
33+
fn align_to() -> usize { 16 }
34+
35+
// Microkernel tile dimensions.
#[inline(always)]
36+
fn mr() -> usize { MR }
37+
#[inline(always)]
38+
fn nr() -> usize { NR }
39+
40+
// NOTE(review): returning true presumably makes `gemm_loop` always use its
// masked path, which would explain why the kernel below only supports
// beta == 0 — confirm in gemm.rs.
#[inline(always)]
41+
fn always_masked() -> bool { true }
42+
43+
// Blocking sizes are taken from the `S_*` constants in `archparam`.
// NOTE(review): presumably the same ones used for f32 — confirm.
#[inline(always)]
44+
fn nc() -> usize { archparam::S_NC }
45+
#[inline(always)]
46+
fn kc() -> usize { archparam::S_KC }
47+
#[inline(always)]
48+
fn mc() -> usize { archparam::S_MC }
49+
50+
// Delegates to the free function `kernel` defined in this module.
#[inline(always)]
51+
unsafe fn kernel(
52+
k: usize,
53+
alpha: T,
54+
a: *const T,
55+
b: *const T,
56+
beta: T,
57+
c: *mut T, rsc: isize, csc: isize) {
58+
kernel(k, alpha, a, b, beta, c, rsc, csc)
59+
}
60+
}
61+
62+
/// matrix multiplication kernel
63+
///
64+
/// This does the matrix multiplication:
65+
///
66+
/// C ← α A B + β C
67+
///
68+
/// + k: length of data in a, b
69+
/// + a, b are packed
70+
/// + c has general strides
71+
/// + rsc: row stride of c
72+
/// + csc: col stride of c
73+
/// + if beta is 0, then c does not need to be initialized
/// + beta: every path below ends in `kernel_fallback_impl`, which only
///   writes α A B and checks `beta == 0` with a `debug_assert_eq!`; a
///   nonzero beta is therefore not applied
74+
#[inline(never)]
75+
pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
76+
beta: T, c: *mut T, rsc: isize, csc: isize)
77+
{
78+
// dispatch to specific compiled versions
// (feature detection happens at run time; "avx" is preferred over "sse2",
// and the plain fallback is used when neither is reported or on other
// architectures)
79+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
80+
{
81+
if is_x86_feature_detected_!("avx") {
82+
return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc);
83+
} else if is_x86_feature_detected_!("sse2") {
84+
return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc);
85+
}
86+
}
87+
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc);
88+
}
89+
90+
/// `kernel_fallback_impl` compiled with the "avx" target feature enabled.
///
/// The body is the generic implementation (which is `#[inline(always)]`);
/// only the enabled target feature differs.
///
/// # Safety
///
/// Must only be called when the running CPU supports AVX, as `kernel`
/// checks before dispatching here. The pointer requirements are those of
/// `kernel_fallback_impl`.
#[inline]
91+
#[target_feature(enable="avx")]
92+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
93+
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
94+
beta: T, c: *mut T, rsc: isize, csc: isize)
95+
{
96+
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
97+
}
98+
99+
/// `kernel_fallback_impl` compiled with the "sse2" target feature enabled.
///
/// The body is the generic implementation (which is `#[inline(always)]`);
/// only the enabled target feature differs.
///
/// # Safety
///
/// Must only be called when the running CPU supports SSE2, as `kernel`
/// checks before dispatching here. The pointer requirements are those of
/// `kernel_fallback_impl`.
#[inline]
100+
#[target_feature(enable="sse2")]
101+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
102+
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
103+
beta: T, c: *mut T, rsc: isize, csc: isize)
104+
{
105+
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
106+
}
107+
108+
109+
/// Portable i32 microkernel: computes an MR x NR tile of α A B and stores it
/// into `c`.
///
/// + k: number of steps; each step reads MR elements from `a` and NR
///   elements from `b`, then advances both pointers by that amount
/// + alpha: scale applied (wrapping) to every accumulated product sum
/// + beta: must be 0; this is only checked by a `debug_assert_eq!`, and the
///   previous contents of `c` are never read
/// + c, rsc, csc: element (i, j) of the tile is written to
///   `c.offset(rsc * i + csc * j)`
///
/// All arithmetic uses `wrapping_mul` / `wrapping_add`.
#[inline(always)]
110+
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
111+
beta: T, c: *mut T, rsc: isize, csc: isize)
112+
{
113+
let mut ab: [[T; NR]; MR] = [[0; NR]; MR];
114+
let mut a = a;
115+
let mut b = b;
// Only beta == 0 is supported: C is overwritten below, never accumulated into.
116+
debug_assert_eq!(beta, 0);
117+
118+
// Compute A B into ab[i][j]
// NOTE(review): `unroll_by!(4 => k, ..)` presumably runs the body k times,
// unrolled in groups of 4 — confirm against the macro definition.
119+
unroll_by!(4 => k, {
120+
loop_m!(i, loop_n!(j, {
121+
ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j)));
122+
}));
123+
124+
a = a.offset(MR as isize);
125+
b = b.offset(NR as isize);
126+
});
127+
128+
// Pointer to element (i, j) of C using the caller-supplied strides.
macro_rules! c {
129+
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
130+
}
131+
132+
// set C = α A B (the β C term is not added; beta is required to be 0, see above)
133+
loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j])));
134+
}
135+
136+
/// Reads the element located `i` positions after `ptr`.
///
/// # Safety
///
/// `ptr.offset(i)` must point to a valid, readable `T`.
#[inline(always)]
137+
unsafe fn at(ptr: *const T, i: usize) -> T {
138+
*ptr.offset(i as isize)
139+
}
140+
141+
// Unit tests for the i32 kernel and its per-target-feature variants.
#[cfg(test)]
142+
mod tests {
143+
use super::*;
144+
use aligned_alloc::Alloc;
145+
146+
// Allocates `n` elements with the kernel's `align_to()` alignment and fills
// them with `elt`.
fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
147+
{
148+
unsafe {
149+
Alloc::new(n, Gemm::align_to()).init_with(elt)
150+
}
151+
}
152+
153+
use super::T;
154+
// Common signature of `kernel`, `kernel_fallback_impl` and the
// target-feature variants, so they can all be run through the same test.
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);
155+
156+
// Runs `kernel_fn` with alpha = 1, beta = 0 on a packed A holding
// 0, 1, 2, ... and a packed B that selects one column per step, then checks
// that the column-major C it wrote equals A.
fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
157+
const K: usize = 4;
158+
let mut a = aligned_alloc(1, MR * K);
159+
let mut b = aligned_alloc(0, NR * K);
160+
for (i, x) in a.iter_mut().enumerate() {
161+
*x = i as _;
162+
}
163+
164+
// Step i contributes only to column i of C (B acts as an identity here).
for i in 0..K {
165+
b[i + i * NR] = 1;
166+
}
167+
let mut c = [0; MR * NR];
168+
unsafe {
169+
kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize);
170+
// col major C
171+
}
172+
assert_eq!(&a[..], &c[..a.len()]);
173+
}
174+
175+
// The dispatching entry point (whichever variant the host selects).
#[test]
176+
fn test_native_kernel() {
177+
test_a_kernel("kernel", kernel);
178+
}
179+
180+
// The portable implementation, called directly.
#[test]
181+
fn test_kernel_fallback_impl() {
182+
test_a_kernel("kernel", kernel_fallback_impl);
183+
}
184+
185+
// loop_m! / loop_n! together must visit every (i, j) of an MR x NR array
// exactly once.
#[test]
186+
fn test_loop_m_n() {
187+
let mut m = [[0; NR]; MR];
188+
loop_m!(i, loop_n!(j, m[i][j] += 1));
189+
for arr in &m[..] {
190+
for elt in &arr[..] {
191+
assert_eq!(*elt, 1);
192+
}
193+
}
194+
}
195+
196+
// One test per target-feature-specific kernel; each is skipped at run time
// when the host CPU does not report the feature.
mod test_arch_kernels {
197+
use super::test_a_kernel;
198+
// Takes `"feature", function_name` pairs and generates a #[test] named
// after the function for each pair.
macro_rules! test_arch_kernels_x86 {
199+
($($feature_name:tt, $function_name:ident),*) => {
200+
$(
201+
#[test]
202+
fn $function_name() {
203+
if is_x86_feature_detected_!($feature_name) {
204+
test_a_kernel(stringify!($function_name), super::super::$function_name);
205+
} else {
206+
println!("Skipping, host does not have feature: {:?}", $feature_name);
207+
}
208+
}
209+
)*
210+
}
211+
}
212+
213+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
214+
test_arch_kernels_x86! {
215+
"avx", kernel_target_avx,
216+
"sse2", kernel_target_sse2
217+
}
218+
}
219+
}

src/kernel.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,15 @@ impl Element for f64 {
7979
*self += alpha * a;
8080
}
8181
}
82+
83+
// `Element` for i32. Multiplication and addition use the wrapping variants,
// so overflowing results wrap around.
impl Element for i32 {
84+
fn zero() -> Self { 0 }
85+
fn one() -> Self { 1 }
86+
fn is_zero(&self) -> bool { *self == 0 }
87+
// self = self * x (wrapping)
fn scale_by(&mut self, x: Self) {
88+
*self = self.wrapping_mul(x);
89+
}
90+
// self = self + alpha * a (wrapping)
fn scaled_add(&mut self, alpha: Self, a: Self) {
91+
*self = self.wrapping_add(alpha.wrapping_mul(a));
92+
}
93+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ mod kernel;
6262
mod gemm;
6363
mod sgemm_kernel;
6464
mod dgemm_kernel;
65+
mod igemm_kernel;
6566
mod util;
6667
mod aligned_alloc;
6768

6869
pub use gemm::sgemm;
6970
pub use gemm::dgemm;
71+
pub use gemm::igemm;

0 commit comments

Comments
 (0)