-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnrm2_asum.zig
More file actions
63 lines (50 loc) · 1.99 KB
/
nrm2_asum.zig
File metadata and controls
63 lines (50 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
//! cuBLAS NRM2 + ASUM Example
//!
//! Demonstrates vector norm operations:
//! - SNRM2: Euclidean norm ||x||₂ = √(Σ xᵢ²)
//! - SASUM: Sum of absolute values ||x||₁ = Σ |xᵢ|
//!
//! Reference: CUDALibrarySamples/cuBLAS/Level-1/nrm2 + Level-1/asum
const std = @import("std");
const cuda = @import("zcuda");

/// Maximum absolute difference allowed between a cuBLAS result and the
/// host-computed reference before validation fails.
const tolerance: f32 = 1e-4;

/// Print a device-computed norm next to its host reference and return
/// `error.ValidationFailed` if they differ by more than `tolerance`.
///
/// `symbol` is the comptime label printed for the result (e.g. "||x||₂").
fn verify(comptime symbol: []const u8, actual: f32, expected: f32) !void {
    std.debug.print(" " ++ symbol ++ " = {d:.6}\n", .{actual});
    std.debug.print(" Expected: {d:.6}\n", .{expected});
    // Written as a negated `<=` so that a NaN result (for which every
    // ordered comparison is false) is reported as a failure, not a pass.
    if (!(@abs(actual - expected) <= tolerance)) {
        std.debug.print(" ✗ FAILED\n", .{});
        return error.ValidationFailed;
    }
}

/// Copy a small host vector to the device, compute its L2 norm (SNRM2) and
/// L1 norm (SASUM) with cuBLAS, and check both against host-side references.
/// Returns `error.ValidationFailed` if either result is out of tolerance.
pub fn main() !void {
    std.debug.print("=== cuBLAS NRM2 + ASUM Example ===\n\n", .{});

    const ctx = try cuda.driver.CudaContext.new(0);
    defer ctx.deinit();
    const stream = ctx.defaultStream();

    const blas = try cuda.cublas.CublasContext.init(ctx);
    defer blas.deinit();

    const x_data = [_]f32{ 3.0, -4.0, 5.0, -12.0, 8.0 };
    // Derive the element count from the data so the two cannot drift apart.
    const n: i32 = @intCast(x_data.len);

    const d_x = try stream.cloneHtoD(f32, &x_data);
    defer d_x.deinit();

    std.debug.print("x = [ ", .{});
    for (&x_data) |v| std.debug.print("{d:.1} ", .{v});
    std.debug.print("]\n\n", .{});

    // SNRM2: L2 norm
    const l2_norm = try blas.snrm2(n, d_x);
    var sum_sq: f32 = 0.0;
    for (&x_data) |v| sum_sq += v * v;
    const expected_l2 = @sqrt(sum_sq);
    std.debug.print("─── L2 Norm (SNRM2) ───\n", .{});
    try verify("||x||₂", l2_norm, expected_l2);
    std.debug.print(" ✓ Verified\n\n", .{});

    // SASUM: L1 norm
    const l1_norm = try blas.sasum(n, d_x);
    var expected_l1: f32 = 0.0;
    for (&x_data) |v| expected_l1 += @abs(v);
    std.debug.print("─── L1 Norm (SASUM) ───\n", .{});
    try verify("||x||₁", l1_norm, expected_l1);
    std.debug.print(" ✓ Verified\n", .{});

    std.debug.print("\n✓ cuBLAS NRM2 + ASUM complete\n", .{});
}