|
| 1 | +// Copyright 2016 - 2018 Ulrik Sverdrup "bluss" |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
| 4 | +// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
| 5 | +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
| 6 | +// option. This file may not be copied, modified, or distributed |
| 7 | +// except according to those terms. |
| 8 | + |
| 9 | +use kernel::GemmKernel; |
| 10 | +use kernel::Element; |
| 11 | +use archparam; |
| 12 | + |
| 13 | + |
| 14 | +#[cfg(target_arch="x86")] |
| 15 | +use std::arch::x86::*; |
| 16 | +#[cfg(target_arch="x86_64")] |
| 17 | +use std::arch::x86_64::*; |
| 18 | + |
/// Marker type that carries the `GemmKernel` implementation for `i32`
/// elements. It has no variants, so no value of it can ever be created; it is
/// only used at the type level.
pub enum Gemm { }

/// Element type this kernel operates on.
pub type T = i32;

/// Number of rows of the C tile produced by one kernel call (reported by
/// `mr()` and used as the outer dimension of the accumulator).
const MR: usize = 8;
/// Number of columns of the C tile produced by one kernel call (reported by
/// `nr()` and used as the inner dimension of the accumulator).
const NR: usize = 4;

// Iterate `$e` over the MR row indices / NR column indices of the tile.
// These forward to `loop8!` / `loop4!`, matching MR = 8 and NR = 4; if either
// constant changes, the forwarded macro has to change with it.
// NOTE(review): `loop8!`/`loop4!` are defined elsewhere; presumably they
// expand `$e` once per index with `$i`/`$j` bound to 0..8 / 0..4 — confirm.
macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
| 28 | + |
| 29 | +impl GemmKernel for Gemm { |
| 30 | + type Elem = T; |
| 31 | + |
| 32 | + #[inline(always)] |
| 33 | + fn align_to() -> usize { 16 } |
| 34 | + |
| 35 | + #[inline(always)] |
| 36 | + fn mr() -> usize { MR } |
| 37 | + #[inline(always)] |
| 38 | + fn nr() -> usize { NR } |
| 39 | + |
| 40 | + #[inline(always)] |
| 41 | + fn always_masked() -> bool { true } |
| 42 | + |
| 43 | + #[inline(always)] |
| 44 | + fn nc() -> usize { archparam::S_NC } |
| 45 | + #[inline(always)] |
| 46 | + fn kc() -> usize { archparam::S_KC } |
| 47 | + #[inline(always)] |
| 48 | + fn mc() -> usize { archparam::S_MC } |
| 49 | + |
| 50 | + #[inline(always)] |
| 51 | + unsafe fn kernel( |
| 52 | + k: usize, |
| 53 | + alpha: T, |
| 54 | + a: *const T, |
| 55 | + b: *const T, |
| 56 | + beta: T, |
| 57 | + c: *mut T, rsc: isize, csc: isize) { |
| 58 | + kernel(k, alpha, a, b, beta, c, rsc, csc) |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +/// matrix multiplication kernel |
| 63 | +/// |
| 64 | +/// This does the matrix multiplication: |
| 65 | +/// |
| 66 | +/// C ← α A B + β C |
| 67 | +/// |
| 68 | +/// + k: length of data in a, b |
| 69 | +/// + a, b are packed |
| 70 | +/// + c has general strides |
| 71 | +/// + rsc: row stride of c |
| 72 | +/// + csc: col stride of c |
| 73 | +/// + if beta is 0, then c does not need to be initialized |
| 74 | +#[inline(never)] |
| 75 | +pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T, |
| 76 | + beta: T, c: *mut T, rsc: isize, csc: isize) |
| 77 | +{ |
| 78 | + // dispatch to specific compiled versions |
| 79 | + #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
| 80 | + { |
| 81 | + if is_x86_feature_detected_!("avx") { |
| 82 | + return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); |
| 83 | + } else if is_x86_feature_detected_!("sse2") { |
| 84 | + return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); |
| 85 | + } |
| 86 | + } |
| 87 | + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); |
| 88 | +} |
| 89 | + |
| 90 | +#[inline] |
| 91 | +#[target_feature(enable="avx")] |
| 92 | +#[cfg(any(target_arch="x86", target_arch="x86_64"))] |
| 93 | +unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T, |
| 94 | + beta: T, c: *mut T, rsc: isize, csc: isize) |
| 95 | +{ |
| 96 | + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) |
| 97 | +} |
| 98 | + |
| 99 | +#[inline] |
| 100 | +#[target_feature(enable="sse2")] |
| 101 | +#[cfg(any(target_arch="x86", target_arch="x86_64"))] |
| 102 | +unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T, |
| 103 | + beta: T, c: *mut T, rsc: isize, csc: isize) |
| 104 | +{ |
| 105 | + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) |
| 106 | +} |
| 107 | + |
| 108 | + |
| 109 | +#[inline(always)] |
| 110 | +unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T, |
| 111 | + beta: T, c: *mut T, rsc: isize, csc: isize) |
| 112 | +{ |
| 113 | + let mut ab: [[T; NR]; MR] = [[0; NR]; MR]; |
| 114 | + let mut a = a; |
| 115 | + let mut b = b; |
| 116 | + debug_assert_eq!(beta, 0); |
| 117 | + |
| 118 | + // Compute A B into ab[i][j] |
| 119 | + unroll_by!(4 => k, { |
| 120 | + loop_m!(i, loop_n!(j, { |
| 121 | + ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j))); |
| 122 | + })); |
| 123 | + |
| 124 | + a = a.offset(MR as isize); |
| 125 | + b = b.offset(NR as isize); |
| 126 | + }); |
| 127 | + |
| 128 | + macro_rules! c { |
| 129 | + ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); |
| 130 | + } |
| 131 | + |
| 132 | + // set C = α A B + β C |
| 133 | + loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); |
| 134 | +} |
| 135 | + |
| 136 | +#[inline(always)] |
| 137 | +unsafe fn at(ptr: *const T, i: usize) -> T { |
| 138 | + *ptr.offset(i as isize) |
| 139 | +} |
| 140 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use aligned_alloc::Alloc;

    /// Allocates `n` elements aligned to `Gemm::align_to()`, each set to `elt`.
    fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T>
    where
        T: Copy,
    {
        unsafe { Alloc::new(n, Gemm::align_to()).init_with(elt) }
    }

    use super::T;

    /// Signature shared by every kernel variant under test.
    type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);

    /// Runs `kernel_fn` with alpha = 1, beta = 0 on A = 0, 1, 2, … and a B
    /// that has ones on its diagonal, then checks that the column-major
    /// output equals A.
    fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
        const K: usize = 4;

        let mut lhs = aligned_alloc(1, MR * K);
        let mut rhs = aligned_alloc(0, NR * K);

        // A: consecutive integers in packed order.
        for (position, slot) in lhs.iter_mut().enumerate() {
            *slot = position as _;
        }

        // B: a one at (step, step) for every step, i.e. index step + step * NR.
        for step in 0..K {
            rhs[step * (NR + 1)] = 1;
        }

        let mut out = [0; MR * NR];
        // NOTE(review): the raw pointers are derived from references to the
        // first element only; `as_ptr()` / `as_mut_ptr()` would express the
        // whole-buffer access more precisely.
        unsafe {
            // rsc = 1, csc = MR: column-major C.
            kernel_fn(K, 1, &lhs[0], &rhs[0], 0, &mut out[0], 1, MR as isize);
        }

        assert_eq!(&lhs[..], &out[..lhs.len()]);
    }

    #[test]
    fn test_native_kernel() {
        test_a_kernel("kernel", kernel);
    }

    #[test]
    fn test_kernel_fallback_impl() {
        test_a_kernel("kernel", kernel_fallback_impl);
    }

    /// `loop_m!` × `loop_n!` must visit every (i, j) of the tile exactly once.
    #[test]
    fn test_loop_m_n() {
        let mut visits = [[0; NR]; MR];
        loop_m!(i, loop_n!(j, visits[i][j] += 1));

        for row in visits.iter() {
            for &count in row.iter() {
                assert_eq!(count, 1);
            }
        }
    }

    mod test_arch_kernels {
        use super::test_a_kernel;

        // Generates one #[test] per (feature, kernel function) pair; the test
        // body only runs when the host CPU actually has the feature.
        macro_rules! test_arch_kernels_x86 {
            ($($feature_name:tt, $function_name:ident),*) => {
                $(
                    #[test]
                    fn $function_name() {
                        if is_x86_feature_detected_!($feature_name) {
                            test_a_kernel(
                                stringify!($function_name),
                                super::super::$function_name,
                            );
                        } else {
                            println!("Skipping, host does not have feature: {:?}", $feature_name);
                        }
                    }
                )*
            }
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        test_arch_kernels_x86! {
            "avx", kernel_target_avx,
            "sse2", kernel_target_sse2
        }
    }
}
0 commit comments