-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_scalar_product.zig
More file actions
44 lines (36 loc) · 1.19 KB
/
kernel_scalar_product.zig
File metadata and controls
44 lines (36 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// examples/kernel/1_Reduction/kernel_scalar_product.zig — Batched dot products
//
// Reference: cuda-samples/0_Introduction/fp16ScalarProduct (adapted here to f32)
// API exercised: SharedArray, __fmaf_rn, __syncthreads, reduceSum
const cuda = @import("zcuda_kernel");
const smem = cuda.shared_mem;
const BLOCK_SIZE = 256;
/// Batched scalar (dot) product: result[b] = dot(A[b], B[b])
/// One block handles one dot product of `vec_len` elements; block b
/// reads A[b*vec_len .. (b+1)*vec_len] and B likewise, writing results[b].
export fn batchedDotProduct(
    A: [*]const f32,
    B: [*]const f32,
    results: [*]f32,
    vec_len: u32,
    num_vectors: u32,
) callconv(.kernel) void {
    const block = cuda.blockIdx().x;
    if (block >= num_vectors) return;

    const lane = cuda.threadIdx().x;
    const base = block * vec_len;

    // Each thread strides over the vector in BLOCK_SIZE steps, fusing
    // multiply-add into a single rounded FMA per element pair.
    var partial: f32 = 0.0;
    var idx: u32 = lane;
    while (idx < vec_len) : (idx += BLOCK_SIZE) {
        partial = cuda.__fmaf_rn(A[base + idx], B[base + idx], partial);
    }

    // Stage per-thread partials in shared memory, then reduce block-wide.
    const scratch = smem.SharedArray(f32, BLOCK_SIZE);
    const buf = scratch.ptr();
    buf[lane] = partial;
    cuda.__syncthreads();

    smem.reduceSum(f32, buf, lane, BLOCK_SIZE);

    // Thread 0 holds the final sum after the reduction.
    if (lane == 0) {
        results[block] = buf[0];
    }
}