|
20 | 20 |
|
21 | 21 | use super::bgz17_bridge::Base17; |
22 | 22 | use super::gguf::{self, GgufFile, TensorInfo, GgmlType}; |
23 | | -use std::io::{Read, Seek, Write}; |
| 23 | +use std::io::{Read, Seek, SeekFrom, Write}; |
24 | 24 |
|
25 | 25 | // ============================================================================ |
26 | 26 | // Layer classification |
@@ -342,6 +342,248 @@ pub fn stream_index_gguf<R: Read + Seek, W: Write>( |
342 | 342 | Ok(stats) |
343 | 343 | } |
344 | 344 |
|
/// Element-count threshold for switching to row-wise streaming.
///
/// A tensor with strictly more elements than this is read one row at a time
/// instead of being loaded whole. 512 Mi elements is 2 GiB when held as f32
/// (4 bytes per element).
const LARGE_TENSOR_THRESHOLD: usize = 512 * 1024 * 1024;
| 347 | + |
/// Read one row of a BF16 tensor and widen it to f32.
///
/// Seeks `reader` to `abs_offset` (the absolute file offset of the row's first
/// BF16 value), reads `n_cols` little-endian 16-bit values into `buf`, and
/// stores the widened values in `row_f32`. Both buffers are resized to hold
/// exactly one row, so callers can reuse them across rows without
/// reallocating.
///
/// # Errors
///
/// Returns the stringified I/O error if the seek fails or the row cannot be
/// read in full.
fn read_bf16_row_f32<R: Read + Seek>(
    reader: &mut R,
    abs_offset: u64,
    n_cols: usize,
    buf: &mut Vec<u8>,
    row_f32: &mut Vec<f32>,
) -> Result<(), String> {
    let row_bytes = n_cols * 2;
    buf.resize(row_bytes, 0);
    row_f32.resize(n_cols, 0.0);

    reader.seek(SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
    reader.read_exact(&mut buf[..]).map_err(|e| e.to_string())?;

    // Decode each byte pair explicitly rather than reinterpreting the `u8`
    // buffer as a slice of 16-bit values: a `Vec<u8>` allocation only
    // guarantees 1-byte alignment, so such a cast is not sound, and it would
    // also read the data in host byte order instead of little-endian.
    //
    // BF16 is the upper 16 bits of an IEEE-754 binary32, so widening is an
    // exact shift into the high half of the f32 bit pattern.
    for (dst, pair) in row_f32.iter_mut().zip(buf.chunks_exact(2)) {
        let bits = u16::from_le_bytes([pair[0], pair[1]]);
        *dst = f32::from_bits(u32::from(bits) << 16);
    }
    Ok(())
}
| 371 | + |
| 372 | +/// Read one row of an F16 tensor directly, dequantizing in-place. |
| 373 | +fn read_f16_row_f32<R: Read + Seek>( |
| 374 | + reader: &mut R, |
| 375 | + abs_offset: u64, |
| 376 | + n_cols: usize, |
| 377 | + buf: &mut Vec<u8>, |
| 378 | + row_f32: &mut Vec<f32>, |
| 379 | +) -> Result<(), String> { |
| 380 | + let row_bytes = n_cols * 2; |
| 381 | + buf.resize(row_bytes, 0); |
| 382 | + row_f32.resize(n_cols, 0.0); |
| 383 | + |
| 384 | + reader.seek(SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?; |
| 385 | + reader.read_exact(&mut buf[..row_bytes]).map_err(|e| e.to_string())?; |
| 386 | + |
| 387 | + for (i, c) in buf[..row_bytes].chunks_exact(2).enumerate() { |
| 388 | + let bits = u16::from_le_bytes([c[0], c[1]]); |
| 389 | + row_f32[i] = gguf::f16_to_f32(bits); |
| 390 | + } |
| 391 | + Ok(()) |
| 392 | +} |
| 393 | + |
/// Load a single F32 row from `reader`.
///
/// The row starts at absolute file offset `abs_offset` and holds `n_cols`
/// little-endian 32-bit floats. `buf` receives the raw bytes and `row_f32`
/// the decoded values; both are resized to exactly one row so they can be
/// reused between calls.
///
/// # Errors
///
/// Returns the stringified I/O error if seeking or reading the row fails.
fn read_f32_row<R: Read + Seek>(
    reader: &mut R,
    abs_offset: u64,
    n_cols: usize,
    buf: &mut Vec<u8>,
    row_f32: &mut Vec<f32>,
) -> Result<(), String> {
    let byte_len = n_cols * 4;
    buf.resize(byte_len, 0);
    row_f32.resize(n_cols, 0.0);

    reader
        .seek(SeekFrom::Start(abs_offset))
        .map_err(|err| err.to_string())?;
    reader
        .read_exact(&mut buf[..byte_len])
        .map_err(|err| err.to_string())?;

    // Pair each output slot with the four bytes that encode it.
    let words = buf[..byte_len].chunks_exact(4);
    for (slot, word) in row_f32.iter_mut().zip(words) {
        *slot = f32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    }
    Ok(())
}
| 414 | + |
| 415 | +/// Stream-index a GGUF file with row-wise streaming for large tensors. |
| 416 | +/// |
| 417 | +/// Identical to `stream_index_gguf` for tensors under `LARGE_TENSOR_THRESHOLD`, |
| 418 | +/// but processes oversized tensors (e.g. Maverick's 20 GB embeddings) one row |
| 419 | +/// at a time — peak RAM per large tensor = one row (~20 KB–55 KB) instead of |
| 420 | +/// the full tensor. |
| 421 | +/// |
| 422 | +/// Supports row-wise streaming for F32, F16, and BF16 dtypes. |
| 423 | +/// Quantized large tensors are skipped (rare — quantized blocks don't align to rows). |
| 424 | +pub fn stream_index_gguf_large<R: Read + Seek, W: Write>( |
| 425 | + reader: &mut R, |
| 426 | + writer: &mut W, |
| 427 | + callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>, |
| 428 | +) -> Result<IndexStats, String> { |
| 429 | + let gguf = gguf::read_gguf_header(reader)?; |
| 430 | + let mut stats = IndexStats::default(); |
| 431 | + stats.tensors_total = gguf.tensors.len(); |
| 432 | + |
| 433 | + // Write file header: magic + tensor count |
| 434 | + writer.write_all(b"BGZ7").map_err(|e| e.to_string())?; |
| 435 | + writer.write_all(&(gguf.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?; |
| 436 | + |
| 437 | + // Reusable row buffers for large-tensor streaming |
| 438 | + let mut row_buf: Vec<u8> = Vec::new(); |
| 439 | + let mut row_f32: Vec<f32> = Vec::new(); |
| 440 | + |
| 441 | + for tensor in &gguf.tensors { |
| 442 | + let layer_type = classify_tensor(&tensor.name, &tensor.dimensions); |
| 443 | + |
| 444 | + // Skip norms and tiny tensors |
| 445 | + if matches!(layer_type, LayerType::Skip | LayerType::Norm) { |
| 446 | + stats.tensors_skipped += 1; |
| 447 | + continue; |
| 448 | + } |
| 449 | + |
| 450 | + let n_elements = tensor.element_count() as usize; |
| 451 | + let is_large = n_elements > LARGE_TENSOR_THRESHOLD; |
| 452 | + |
| 453 | + if is_large { |
| 454 | + // ── Row-wise streaming path for large tensors ── |
| 455 | + // Only supported for unquantized types where rows align to file offsets. |
| 456 | + let elem_size = match tensor.dtype { |
| 457 | + GgmlType::BF16 => 2usize, |
| 458 | + GgmlType::F16 => 2, |
| 459 | + GgmlType::F32 => 4, |
| 460 | + _ => { |
| 461 | + // Quantized large tensors: skip (block structure doesn't align to rows) |
| 462 | + eprintln!(" SKIP large quantized tensor: {} ({:?}, {} elements)", |
| 463 | + tensor.name, tensor.dtype, n_elements); |
| 464 | + stats.tensors_skipped += 1; |
| 465 | + continue; |
| 466 | + } |
| 467 | + }; |
| 468 | + |
| 469 | + // Determine rows × cols |
| 470 | + let (n_rows, n_cols) = if tensor.dimensions.len() >= 2 { |
| 471 | + let rows = tensor.dimensions[0] as usize; |
| 472 | + let cols: usize = tensor.dimensions[1..].iter().map(|&d| d as usize).product(); |
| 473 | + (rows, cols) |
| 474 | + } else { |
| 475 | + (1, n_elements) |
| 476 | + }; |
| 477 | + |
| 478 | + let tensor_f32_bytes = (n_rows as u64) * (n_cols as u64) * 4; |
| 479 | + if tensor_f32_bytes > stats.peak_tensor_bytes { |
| 480 | + // Record the logical size, even though we never allocate it all |
| 481 | + stats.peak_tensor_bytes = tensor_f32_bytes; |
| 482 | + } |
| 483 | + |
| 484 | + let abs_base = gguf.tensor_data_offset + tensor.offset; |
| 485 | + |
| 486 | + // Project each row one at a time |
| 487 | + let mut rows = Vec::with_capacity(n_rows); |
| 488 | + for r in 0..n_rows { |
| 489 | + let row_offset = abs_base + (r as u64) * (n_cols as u64) * (elem_size as u64); |
| 490 | + match tensor.dtype { |
| 491 | + GgmlType::BF16 => read_bf16_row_f32(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?, |
| 492 | + GgmlType::F16 => read_f16_row_f32(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?, |
| 493 | + GgmlType::F32 => read_f32_row(reader, row_offset, n_cols, &mut row_buf, &mut row_f32)?, |
| 494 | + _ => unreachable!(), // guarded above |
| 495 | + }; |
| 496 | + rows.push(project_row_to_base17(&row_f32[..n_cols])); |
| 497 | + } |
| 498 | + |
| 499 | + let ct = CompressedTensor { |
| 500 | + name: tensor.name.clone(), |
| 501 | + layer_type: layer_type.clone(), |
| 502 | + original_shape: tensor.dimensions.clone(), |
| 503 | + n_rows, |
| 504 | + n_cols, |
| 505 | + rows, |
| 506 | + }; |
| 507 | + |
| 508 | + let orig = ct.original_bytes() as u64; |
| 509 | + let comp = ct.compressed_bytes() as u64; |
| 510 | + stats.tensors_indexed += 1; |
| 511 | + stats.original_bytes += orig; |
| 512 | + stats.compressed_bytes += comp; |
| 513 | + |
| 514 | + let lt_idx = match &ct.layer_type { |
| 515 | + LayerType::Attention => 0, |
| 516 | + LayerType::FeedForward => 1, |
| 517 | + LayerType::Conv2D => 2, |
| 518 | + LayerType::Norm => 3, |
| 519 | + LayerType::Embedding => 4, |
| 520 | + LayerType::Skip => 5, |
| 521 | + }; |
| 522 | + stats.by_type[lt_idx].0 += 1; |
| 523 | + stats.by_type[lt_idx].1 += orig; |
| 524 | + stats.by_type[lt_idx].2 += comp; |
| 525 | + |
| 526 | + if let Some(cb) = callback { |
| 527 | + cb(&ct.name, &ct.layer_type, ct.original_bytes(), ct.compressed_bytes()); |
| 528 | + } |
| 529 | + |
| 530 | + ct.write_to(writer)?; |
| 531 | + } else { |
| 532 | + // ── Standard path: load full tensor (same as stream_index_gguf) ── |
| 533 | + let data = gguf::read_tensor_f32(reader, &gguf, tensor)?; |
| 534 | + |
| 535 | + let tensor_bytes = data.len() as u64 * 4; |
| 536 | + if tensor_bytes > stats.peak_tensor_bytes { |
| 537 | + stats.peak_tensor_bytes = tensor_bytes; |
| 538 | + } |
| 539 | + |
| 540 | + let (n_rows, n_cols) = tensor_to_rows(&data, &tensor.dimensions, &layer_type); |
| 541 | + |
| 542 | + let mut rows = Vec::with_capacity(n_rows); |
| 543 | + for r in 0..n_rows { |
| 544 | + let start = r * n_cols; |
| 545 | + let end = (start + n_cols).min(data.len()); |
| 546 | + rows.push(project_row_to_base17(&data[start..end])); |
| 547 | + } |
| 548 | + |
| 549 | + let ct = CompressedTensor { |
| 550 | + name: tensor.name.clone(), |
| 551 | + layer_type: layer_type.clone(), |
| 552 | + original_shape: tensor.dimensions.clone(), |
| 553 | + n_rows, |
| 554 | + n_cols, |
| 555 | + rows, |
| 556 | + }; |
| 557 | + |
| 558 | + let orig = ct.original_bytes() as u64; |
| 559 | + let comp = ct.compressed_bytes() as u64; |
| 560 | + stats.tensors_indexed += 1; |
| 561 | + stats.original_bytes += orig; |
| 562 | + stats.compressed_bytes += comp; |
| 563 | + |
| 564 | + let lt_idx = match &ct.layer_type { |
| 565 | + LayerType::Attention => 0, |
| 566 | + LayerType::FeedForward => 1, |
| 567 | + LayerType::Conv2D => 2, |
| 568 | + LayerType::Norm => 3, |
| 569 | + LayerType::Embedding => 4, |
| 570 | + LayerType::Skip => 5, |
| 571 | + }; |
| 572 | + stats.by_type[lt_idx].0 += 1; |
| 573 | + stats.by_type[lt_idx].1 += orig; |
| 574 | + stats.by_type[lt_idx].2 += comp; |
| 575 | + |
| 576 | + if let Some(cb) = callback { |
| 577 | + cb(&ct.name, &ct.layer_type, ct.original_bytes(), ct.compressed_bytes()); |
| 578 | + } |
| 579 | + |
| 580 | + ct.write_to(writer)?; |
| 581 | + } |
| 582 | + } |
| 583 | + |
| 584 | + Ok(stats) |
| 585 | +} |
| 586 | + |
345 | 587 | // ============================================================================ |
346 | 588 | // Tests |
347 | 589 | // ============================================================================ |
@@ -762,7 +1004,7 @@ mod tests { |
762 | 1004 | let out = std::fs::File::create(&out_path).expect("create output"); |
763 | 1005 | let mut writer = BufWriter::new(out); |
764 | 1006 |
|
765 | | - let stats = stream_index_gguf( |
| 1007 | + let stats = stream_index_gguf_large( |
766 | 1008 | &mut reader, |
767 | 1009 | &mut writer, |
768 | 1010 | Some(&|name, layer_type, orig, comp| { |
|
0 commit comments