Skip to content

Commit 5c9847e

Browse files
committed
Add row-wise streaming for large tensors (Maverick OOM fix)
Maverick shard 1 OOM'd allocating 20 GB for a single tensor — the original stream_index_gguf loads entire tensors as f32, which fails on Maverick's massive embedding/expert tensors (5B+ elements). New stream_index_gguf_large function switches to row-by-row streaming for tensors exceeding 512M f32 elements (2 GB). Each row is read, dequantized, projected to Base17, and discarded — peak RAM per large tensor drops from 20+ GB to ~40 KB (one row). Small tensors still use the original bulk-load path. Also makes gguf::f16_to_f32 public for the F16 row reader. https://claude.ai/code/session_01HmdXNPit7QsTCfhJFef3Ee
1 parent aa9f0af commit 5c9847e

2 files changed

Lines changed: 245 additions & 3 deletions

File tree

src/hpc/gguf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ fn dequantize_q4_k<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, St
413413
}
414414

415415
/// Convert f16 bit pattern to f32.
416-
fn f16_to_f32(bits: u16) -> f32 {
416+
pub fn f16_to_f32(bits: u16) -> f32 {
417417
let sign = ((bits >> 15) & 1) as u32;
418418
let exp = ((bits >> 10) & 0x1F) as u32;
419419
let mantissa = (bits & 0x3FF) as u32;

src/hpc/gguf_indexer.rs

Lines changed: 244 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
use super::bgz17_bridge::Base17;
2222
use super::gguf::{self, GgufFile, TensorInfo, GgmlType};
23-
use std::io::{Read, Seek, Write};
23+
use std::io::{Read, Seek, SeekFrom, Write};
2424

2525
// ============================================================================
2626
// Layer classification
@@ -342,6 +342,248 @@ pub fn stream_index_gguf<R: Read + Seek, W: Write>(
342342
Ok(stats)
343343
}
344344

345+
/// Element-count cutoff above which a tensor is streamed row by row instead of
/// being loaded whole: 2^29 elements (512 Mi), i.e. 2 GiB when held as f32.
const LARGE_TENSOR_THRESHOLD: usize = 1 << 29;
347+
348+
/// Read one row of a BF16 tensor from `reader` and widen it to f32.
///
/// Seeks to `abs_offset` (the absolute file offset of the row's first
/// element), reads `n_cols` little-endian BF16 values into `buf`, and stores
/// the converted values in `row_f32`. Both buffers are resized to exactly one
/// row, so callers can reuse them from row to row without reallocating.
///
/// # Errors
/// Returns the I/O error text if the seek or the read fails (including a
/// short read).
fn read_bf16_row_f32<R: Read + Seek>(
    reader: &mut R,
    abs_offset: u64,
    n_cols: usize,
    buf: &mut Vec<u8>,
    row_f32: &mut Vec<f32>,
) -> Result<(), String> {
    let row_bytes = n_cols * 2;
    buf.resize(row_bytes, 0);
    row_f32.resize(n_cols, 0.0);

    reader.seek(SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
    reader.read_exact(&mut buf[..row_bytes]).map_err(|e| e.to_string())?;

    // Decode byte pairs explicitly instead of reinterpreting the byte buffer as
    // a slice of 16-bit values: a `Vec<u8>` allocation is only guaranteed to be
    // 1-byte aligned, so such a cast is not sound, and it would also read the
    // pairs in native rather than little-endian order.
    for (dst, pair) in row_f32.iter_mut().zip(buf.chunks_exact(2)) {
        let bits = u16::from_le_bytes([pair[0], pair[1]]);
        // BF16 is the upper 16 bits of an IEEE-754 binary32 value.
        *dst = f32::from_bits(u32::from(bits) << 16);
    }
    Ok(())
}
371+
372+
/// Read one row of an F16 tensor directly, dequantizing in-place.
373+
fn read_f16_row_f32<R: Read + Seek>(
374+
reader: &mut R,
375+
abs_offset: u64,
376+
n_cols: usize,
377+
buf: &mut Vec<u8>,
378+
row_f32: &mut Vec<f32>,
379+
) -> Result<(), String> {
380+
let row_bytes = n_cols * 2;
381+
buf.resize(row_bytes, 0);
382+
row_f32.resize(n_cols, 0.0);
383+
384+
reader.seek(SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
385+
reader.read_exact(&mut buf[..row_bytes]).map_err(|e| e.to_string())?;
386+
387+
for (i, c) in buf[..row_bytes].chunks_exact(2).enumerate() {
388+
let bits = u16::from_le_bytes([c[0], c[1]]);
389+
row_f32[i] = gguf::f16_to_f32(bits);
390+
}
391+
Ok(())
392+
}
393+
394+
/// Fetch a single F32 row from `reader`.
///
/// The row starts at absolute file offset `abs_offset` and holds `n_cols`
/// little-endian 32-bit floats. `buf` receives the raw bytes and `row_f32`
/// the decoded values; both are resized to exactly one row. Seek and read
/// failures are returned as the I/O error's text.
fn read_f32_row<R: Read + Seek>(
    reader: &mut R,
    abs_offset: u64,
    n_cols: usize,
    buf: &mut Vec<u8>,
    row_f32: &mut Vec<f32>,
) -> Result<(), String> {
    const F32_SIZE: usize = 4;

    let byte_len = n_cols * F32_SIZE;
    row_f32.resize(n_cols, 0.0);
    buf.resize(byte_len, 0);

    if let Err(e) = reader.seek(SeekFrom::Start(abs_offset)) {
        return Err(e.to_string());
    }
    if let Err(e) = reader.read_exact(&mut buf[..byte_len]) {
        return Err(e.to_string());
    }

    // One output slot per four-byte input group.
    let raw = &buf[..byte_len];
    for (out, quad) in row_f32.iter_mut().zip(raw.chunks_exact(F32_SIZE)) {
        *out = f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]);
    }
    Ok(())
}
414+
415+
/// Stream-index a GGUF file with row-wise streaming for large tensors.
416+
///
417+
/// Identical to `stream_index_gguf` for tensors under `LARGE_TENSOR_THRESHOLD`,
418+
/// but processes oversized tensors (e.g. Maverick's 20 GB embeddings) one row
419+
/// at a time — peak RAM per large tensor = one row (~20 KB–55 KB) instead of
420+
/// the full tensor.
421+
///
422+
/// Supports row-wise streaming for F32, F16, and BF16 dtypes.
423+
/// Quantized large tensors are skipped (rare — quantized blocks don't align to rows).
424+
pub fn stream_index_gguf_large<R: Read + Seek, W: Write>(
425+
reader: &mut R,
426+
writer: &mut W,
427+
callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>,
428+
) -> Result<IndexStats, String> {
429+
let gguf = gguf::read_gguf_header(reader)?;
430+
let mut stats = IndexStats::default();
431+
stats.tensors_total = gguf.tensors.len();
432+
433+
// Write file header: magic + tensor count
434+
writer.write_all(b"BGZ7").map_err(|e| e.to_string())?;
435+
writer.write_all(&(gguf.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;
436+
437+
// Reusable row buffers for large-tensor streaming
438+
let mut row_buf: Vec<u8> = Vec::new();
439+
let mut row_f32: Vec<f32> = Vec::new();
440+
441+
for tensor in &gguf.tensors {
442+
let layer_type = classify_tensor(&tensor.name, &tensor.dimensions);
443+
444+
// Skip norms and tiny tensors
445+
if matches!(layer_type, LayerType::Skip | LayerType::Norm) {
446+
stats.tensors_skipped += 1;
447+
continue;
448+
}
449+
450+
let n_elements = tensor.element_count() as usize;
451+
let is_large = n_elements > LARGE_TENSOR_THRESHOLD;
452+
453+
if is_large {
454+
// ── Row-wise streaming path for large tensors ──
455+
// Only supported for unquantized types where rows align to file offsets.
456+
let elem_size = match tensor.dtype {
457+
GgmlType::BF16 => 2usize,
458+
GgmlType::F16 => 2,
459+
GgmlType::F32 => 4,
460+
_ => {
461+
// Quantized large tensors: skip (block structure doesn't align to rows)
462+
eprintln!(" SKIP large quantized tensor: {} ({:?}, {} elements)",
463+
tensor.name, tensor.dtype, n_elements);
464+
stats.tensors_skipped += 1;
465+
continue;
466+
}
467+
};
468+
469+
// Determine rows × cols
470+
let (n_rows, n_cols) = if tensor.dimensions.len() >= 2 {
471+
let rows = tensor.dimensions[0] as usize;
472+
let cols: usize = tensor.dimensions[1..].iter().map(|&d| d as usize).product();
473+
(rows, cols)
474+
} else {
475+
(1, n_elements)
476+
};
477+
478+
let tensor_f32_bytes = (n_rows as u64) * (n_cols as u64) * 4;
479+
if tensor_f32_bytes > stats.peak_tensor_bytes {
480+
// Record the logical size, even though we never allocate it all
481+
stats.peak_tensor_bytes = tensor_f32_bytes;
482+
}
483+
484+
let abs_base = gguf.tensor_data_offset + tensor.offset;
485+
486+
// Project each row one at a time
487+
let mut rows = Vec::with_capacity(n_rows);
488+
for r in 0..n_rows {
489+
let row_offset = abs_base + (r as u64) * (n_cols as u64) * (elem_size as u64);
490+
match tensor.dtype {
491+
GgmlType::BF16 => read_bf16_row_f32(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?,
492+
GgmlType::F16 => read_f16_row_f32(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?,
493+
GgmlType::F32 => read_f32_row(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?,
494+
_ => unreachable!(), // guarded above
495+
};
496+
rows.push(project_row_to_base17(&row_f32[..n_cols]));
497+
}
498+
499+
let ct = CompressedTensor {
500+
name: tensor.name.clone(),
501+
layer_type: layer_type.clone(),
502+
original_shape: tensor.dimensions.clone(),
503+
n_rows,
504+
n_cols,
505+
rows,
506+
};
507+
508+
let orig = ct.original_bytes() as u64;
509+
let comp = ct.compressed_bytes() as u64;
510+
stats.tensors_indexed += 1;
511+
stats.original_bytes += orig;
512+
stats.compressed_bytes += comp;
513+
514+
let lt_idx = match &ct.layer_type {
515+
LayerType::Attention => 0,
516+
LayerType::FeedForward => 1,
517+
LayerType::Conv2D => 2,
518+
LayerType::Norm => 3,
519+
LayerType::Embedding => 4,
520+
LayerType::Skip => 5,
521+
};
522+
stats.by_type[lt_idx].0 += 1;
523+
stats.by_type[lt_idx].1 += orig;
524+
stats.by_type[lt_idx].2 += comp;
525+
526+
if let Some(cb) = callback {
527+
cb(&ct.name, &ct.layer_type, ct.original_bytes(), ct.compressed_bytes());
528+
}
529+
530+
ct.write_to(writer)?;
531+
} else {
532+
// ── Standard path: load full tensor (same as stream_index_gguf) ──
533+
let data = gguf::read_tensor_f32(reader, &gguf, tensor)?;
534+
535+
let tensor_bytes = data.len() as u64 * 4;
536+
if tensor_bytes > stats.peak_tensor_bytes {
537+
stats.peak_tensor_bytes = tensor_bytes;
538+
}
539+
540+
let (n_rows, n_cols) = tensor_to_rows(&data, &tensor.dimensions, &layer_type);
541+
542+
let mut rows = Vec::with_capacity(n_rows);
543+
for r in 0..n_rows {
544+
let start = r * n_cols;
545+
let end = (start + n_cols).min(data.len());
546+
rows.push(project_row_to_base17(&data[start..end]));
547+
}
548+
549+
let ct = CompressedTensor {
550+
name: tensor.name.clone(),
551+
layer_type: layer_type.clone(),
552+
original_shape: tensor.dimensions.clone(),
553+
n_rows,
554+
n_cols,
555+
rows,
556+
};
557+
558+
let orig = ct.original_bytes() as u64;
559+
let comp = ct.compressed_bytes() as u64;
560+
stats.tensors_indexed += 1;
561+
stats.original_bytes += orig;
562+
stats.compressed_bytes += comp;
563+
564+
let lt_idx = match &ct.layer_type {
565+
LayerType::Attention => 0,
566+
LayerType::FeedForward => 1,
567+
LayerType::Conv2D => 2,
568+
LayerType::Norm => 3,
569+
LayerType::Embedding => 4,
570+
LayerType::Skip => 5,
571+
};
572+
stats.by_type[lt_idx].0 += 1;
573+
stats.by_type[lt_idx].1 += orig;
574+
stats.by_type[lt_idx].2 += comp;
575+
576+
if let Some(cb) = callback {
577+
cb(&ct.name, &ct.layer_type, ct.original_bytes(), ct.compressed_bytes());
578+
}
579+
580+
ct.write_to(writer)?;
581+
}
582+
}
583+
584+
Ok(stats)
585+
}
586+
345587
// ============================================================================
346588
// Tests
347589
// ============================================================================
@@ -762,7 +1004,7 @@ mod tests {
7621004
let out = std::fs::File::create(&out_path).expect("create output");
7631005
let mut writer = BufWriter::new(out);
7641006

765-
let stats = stream_index_gguf(
1007+
let stats = stream_index_gguf_large(
7661008
&mut reader,
7671009
&mut writer,
7681010
Some(&|name, layer_type, orig, comp| {

0 commit comments

Comments
 (0)