|
6 | 6 | const std = @import("std"); |
7 | 7 | const simd_matmul = @import("simd_ternary_matmul.zig"); |
8 | 8 | const trinity_format = @import("trinity_format.zig"); |
| 9 | +const flash_attn = @import("flash_attention.zig"); |
9 | 10 |
|
10 | 11 | // ═══════════════════════════════════════════════════════════════════════════════ |
11 | 12 | // SIMD TYPES AND HELPERS |
@@ -500,8 +501,74 @@ pub const Attention = struct { |
500 | 501 | // Output projection |
501 | 502 | ternaryMatmul(output, self.w_o, attn_out, cfg.hidden_size, q_size); |
502 | 503 | } |
| 504 | + |
| 505 | + /// Forward pass using Flash Attention (O(N) memory instead of O(N²)) |
| 506 | + /// Use this for long sequences (>256 tokens) for better memory efficiency |
| 507 | + pub fn forwardFlash( |
| 508 | + self: *const Attention, |
| 509 | + allocator: std.mem.Allocator, |
| 510 | + output: []f32, |
| 511 | + input: []const f32, |
| 512 | + kv_cache: *KVCache, |
| 513 | + rope: *const RoPE, |
| 514 | + pos: usize, |
| 515 | + ) !void { |
| 516 | + const cfg = self.config; |
| 517 | + const q_size = cfg.num_heads * cfg.head_dim; |
| 518 | + const kv_size = cfg.num_kv_heads * cfg.head_dim; |
| 519 | + |
| 520 | + // Allocate Q, K, V |
| 521 | + const q = try allocator.alloc(f32, q_size); |
| 522 | + defer allocator.free(q); |
| 523 | + const k = try allocator.alloc(f32, kv_size); |
| 524 | + defer allocator.free(k); |
| 525 | + const v = try allocator.alloc(f32, kv_size); |
| 526 | + defer allocator.free(v); |
| 527 | + |
| 528 | + // Project Q, K, V using ternary matmul |
| 529 | + ternaryMatmul(q, self.w_q, input, q_size, cfg.hidden_size); |
| 530 | + ternaryMatmul(k, self.w_k, input, kv_size, cfg.hidden_size); |
| 531 | + ternaryMatmul(v, self.w_v, input, kv_size, cfg.hidden_size); |
| 532 | + |
| 533 | + // Apply RoPE to Q and K |
| 534 | + for (0..cfg.num_heads) |h| { |
| 535 | + rope.apply(q[h * cfg.head_dim ..][0..cfg.head_dim], pos); |
| 536 | + } |
| 537 | + for (0..cfg.num_kv_heads) |h| { |
| 538 | + rope.apply(k[h * cfg.head_dim ..][0..cfg.head_dim], pos); |
| 539 | + } |
| 540 | + |
| 541 | + // Append K, V to cache |
| 542 | + kv_cache.append(k, v); |
| 543 | + |
| 544 | + // Use Flash Attention for O(N) memory complexity |
| 545 | + const attn_out = try allocator.alloc(f32, q_size); |
| 546 | + defer allocator.free(attn_out); |
| 547 | + |
| 548 | + const scale = 1.0 / @sqrt(@as(f32, @floatFromInt(cfg.head_dim))); |
| 549 | + |
| 550 | + // Flash Attention with GQA support |
| 551 | + try flash_attn.flashAttentionGQA( |
| 552 | + allocator, |
| 553 | + attn_out, |
| 554 | + q, |
| 555 | + kv_cache.k, |
| 556 | + kv_cache.v, |
| 557 | + cfg.num_heads, |
| 558 | + cfg.num_kv_heads, |
| 559 | + cfg.head_dim, |
| 560 | + kv_cache.len, |
| 561 | + scale, |
| 562 | + ); |
| 563 | + |
| 564 | + // Output projection |
| 565 | + ternaryMatmul(output, self.w_o, attn_out, cfg.hidden_size, q_size); |
| 566 | + } |
503 | 567 | }; |
504 | 568 |
|
| 569 | +/// Use Flash Attention for sequences longer than this threshold |
| 570 | +pub const FLASH_ATTENTION_THRESHOLD: usize = 256; |
| 571 | + |
505 | 572 | // ═══════════════════════════════════════════════════════════════════════════════ |
506 | 573 | // MLP - Feed Forward Network with SiLU |
507 | 574 | // ═══════════════════════════════════════════════════════════════════════════════ |
@@ -562,10 +629,14 @@ pub const BitNetLayer = struct { |
562 | 629 | defer allocator.free(normed); |
563 | 630 | self.input_norm.forward(normed, input); |
564 | 631 |
|
565 | | - // Attention |
| 632 | + // Attention (use Flash Attention for long sequences) |
566 | 633 | const attn_out = try allocator.alloc(f32, hidden_size); |
567 | 634 | defer allocator.free(attn_out); |
568 | | - try self.attention.forward(allocator, attn_out, normed, kv_cache, rope, pos); |
| 635 | + if (kv_cache.len > FLASH_ATTENTION_THRESHOLD) { |
| 636 | + try self.attention.forwardFlash(allocator, attn_out, normed, kv_cache, rope, pos); |
| 637 | + } else { |
| 638 | + try self.attention.forward(allocator, attn_out, normed, kv_cache, rope, pos); |
| 639 | + } |
569 | 640 |
|
570 | 641 | // Residual |
571 | 642 | const post_attn = try allocator.alloc(f32, hidden_size); |
|
0 commit comments