Skip to content

Commit 0168de9

Browse files
authored
Merge pull request #55 from AdaWorldAPI/claude/bf16-chunked-review
fix: chunked BF16, buffer cap, drop fake FMA
2 parents 6d5087e + 2213ce9 commit 0168de9

1 file changed

Lines changed: 67 additions & 21 deletions

File tree

src/hpc/gguf_indexer.rs

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,7 @@ pub fn project_8rows_bf16_simd(
345345

346346
for bin in 0..BASE_DIM {
347347
let c = counts[bin].max(1) as f64;
348-
let scaled = sums[bin].mul_add(
349-
F64x8::splat(FP_SCALE / c),
350-
F64x8::splat(0.0),
351-
);
348+
let scaled = sums[bin] * F64x8::splat(FP_SCALE / c);
352349
let clamped = scaled.round().simd_clamp(lo, hi);
353350
let vals = clamped.to_array();
354351
for lane in 0..8 {
@@ -529,7 +526,9 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
529526
writer.write_all(b"BGZ7").map_err(|e| e.to_string())?;
530527
writer.write_all(&(gguf_header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;
531528

532-
// ONE reusable buffer — grows to largest tensor, never shrinks
529+
// Reusable buffer — capped at 128 MB (64M u16 elements).
530+
// Tensors larger than this are read in row batches.
531+
const MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // 128 MB of u16
533532
let mut bf16_buf: Vec<u16> = Vec::new();
534533

535534
for tensor in &gguf_header.tensors {
@@ -543,23 +542,64 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
543542
let is_bf16 = matches!(tensor.dtype, gguf::GgmlType::BF16);
544543

545544
if is_bf16 {
546-
// FAST PATH: BF16 direct — no f32 intermediate
547-
let n_elements = read_tensor_bf16_raw(reader, &gguf_header, tensor, &mut bf16_buf)?;
545+
// FAST PATH: BF16 direct — chunked row-batch reading.
546+
// Caps memory at MAX_BUF_ELEMS unless the 8-row minimum batch is itself larger than the cap.
547+
// A 10.7 GB ffn_gate_exps tensor reads in ~128 MB batches.
548548
let (n_rows, n_cols) = tensor_to_rows_dims(&tensor.dimensions, &layer_type);
549-
550-
// F64x8: 8 rows parallel, SIMD accumulation per halftone bin
551-
let rows = if octave_stride > 1 {
552-
project_tensor_bf16_simd(&bf16_buf[..n_elements], n_rows, n_cols, octave_stride)
549+
let chunk_rows = if n_cols > 0 {
550+
(MAX_BUF_ELEMS / n_cols).max(8).min(n_rows) // at least 8 rows (SIMD batch)
553551
} else {
554-
// Full precision: scalar per-row (stride=1 doesn't benefit from SIMD halftone)
555-
let mut rows = Vec::with_capacity(n_rows);
556-
for r in 0..n_rows {
557-
let start = r * n_cols;
558-
let end = (start + n_cols).min(n_elements);
559-
rows.push(project_row_bf16_direct(&bf16_buf[start..end]));
560-
}
561-
rows
552+
n_rows
562553
};
554+
let chunk_elems = chunk_rows * n_cols;
555+
556+
// Grow buffer to chunk size (not full tensor size)
557+
if bf16_buf.len() < chunk_elems {
558+
bf16_buf.resize(chunk_elems, 0);
559+
}
560+
561+
// Seek to tensor start
562+
let abs_offset = gguf_header.tensor_data_offset + tensor.offset;
563+
reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
564+
565+
let mut rows: Vec<Base17> = Vec::with_capacity(n_rows);
566+
let mut rows_done: usize = 0;
567+
let is_large = n_rows > chunk_rows;
568+
569+
while rows_done < n_rows {
570+
let batch_n = (n_rows - rows_done).min(chunk_rows);
571+
let batch_elems = batch_n * n_cols;
572+
573+
// Read batch bytes into reusable buffer
574+
let byte_slice = unsafe {
575+
std::slice::from_raw_parts_mut(
576+
bf16_buf.as_mut_ptr() as *mut u8,
577+
batch_elems * 2,
578+
)
579+
};
580+
reader.read_exact(byte_slice).map_err(|e| e.to_string())?;
581+
582+
// Project this batch
583+
if octave_stride > 1 {
584+
let batch_b17 = project_tensor_bf16_simd(
585+
&bf16_buf[..batch_elems], batch_n, n_cols, octave_stride
586+
);
587+
rows.extend_from_slice(&batch_b17);
588+
} else {
589+
for r in 0..batch_n {
590+
let start = r * n_cols;
591+
rows.push(project_row_bf16_direct(&bf16_buf[start..start + n_cols]));
592+
}
593+
}
594+
595+
rows_done += batch_n;
596+
597+
// Progress for large tensors (every chunk)
598+
if is_large && rows_done < n_rows {
599+
eprintln!(" ... {}/{} rows ({:.0}%)",
600+
rows_done, n_rows, rows_done as f64 / n_rows as f64 * 100.0);
601+
}
602+
}
563603

564604
let orig_bytes = (n_rows * n_cols * 4) as u64;
565605
let comp_bytes = (rows.len() * Base17::BYTE_SIZE) as u64;
@@ -582,8 +622,14 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
582622
stats.compressed_bytes += comp_bytes;
583623
stats.tensors_indexed += 1;
584624

585-
let peak = n_elements as u64 * 2;
586-
if peak > stats.peak_tensor_bytes { stats.peak_tensor_bytes = peak; }
625+
let buf_bytes = chunk_elems as u64 * 2;
626+
if buf_bytes > stats.peak_tensor_bytes { stats.peak_tensor_bytes = buf_bytes; }
627+
628+
// Shrink buffer if it grew past the cap (happens only when the 8-row minimum batch exceeds MAX_BUF_ELEMS)
629+
if bf16_buf.len() > MAX_BUF_ELEMS {
630+
bf16_buf.truncate(MAX_BUF_ELEMS);
631+
bf16_buf.shrink_to(MAX_BUF_ELEMS);
632+
}
587633

588634
if let Some(cb) = callback {
589635
cb(&tensor.name, &layer_type, orig_bytes as usize, comp_bytes as usize);

0 commit comments

Comments
 (0)