Skip to content

Commit 5bf4d45

Browse files
committed
feat: HiDream-I1 + Llama-3.1-8B test functions + shared indexer helper
Adds 5 test functions for the first diffusion model indexing: - test_stream_index_hidream_transformer (7 shards, 35 GB DiT+MoE) - test_stream_index_hidream_text_encoders (CLIP-L + CLIP-G + Llama-3.1-8B) - test_stream_index_llama31_8b_base (4 shards, ungated via unsloth) - test_hidream_llama_diff (cross-domain: language→vision attention shift) Adds index_safetensors_shards() helper that handles HEAD size detection, HTTP range reading, shard iteration, and skip-if-exists logic. Syntax verified with rustc 1.94.1.
1 parent 07fb51a commit 5bf4d45

1 file changed

Lines changed: 173 additions & 0 deletions

File tree

src/hpc/safetensors.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,4 +411,177 @@ mod tests {
411411
stats.tensors_indexed);
412412
}
413413
}
414+
415+
// ── HiDream-I1: DiT+MoE diffusion model ──
416+
417+
/// Helper: index safetensors shards from a HuggingFace repo.
418+
fn index_safetensors_shards(
419+
repo: &str,
420+
filenames: &[&str],
421+
out_prefix: &str,
422+
octave_stride: usize,
423+
) -> Vec<super::super::gguf_indexer::IndexStats> {
424+
use super::super::http_reader::HttpRangeReader;
425+
use std::io::BufWriter;
426+
427+
let mut all_stats = Vec::new();
428+
429+
for (i, filename) in filenames.iter().enumerate() {
430+
let shard = i + 1;
431+
let out_path = if filenames.len() == 1 {
432+
format!("{}.bgz7", out_prefix)
433+
} else {
434+
format!("{}_shard{:02}.bgz7", out_prefix, shard)
435+
};
436+
437+
if std::fs::metadata(&out_path).is_ok() {
438+
eprintln!("SKIP {} (exists)", out_path);
439+
continue;
440+
}
441+
442+
let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
443+
eprintln!("Indexing {}/{}: {}", shard, filenames.len(), filename);
444+
445+
// HEAD for size
446+
let size_str = std::process::Command::new("curl")
447+
.args(&["-sI", "-L", &url])
448+
.output()
449+
.map(|o| String::from_utf8_lossy(&o.stdout).to_string())
450+
.unwrap_or_default();
451+
let size: u64 = size_str.lines()
452+
.find(|l| l.to_lowercase().starts_with("content-length:"))
453+
.and_then(|l| l.split(':').nth(1))
454+
.and_then(|s| s.trim().parse().ok())
455+
.unwrap_or(5_500_000_000);
456+
457+
let mut reader = HttpRangeReader::with_chunk_size(url, size, 256 * 1024 * 1024);
458+
let out = std::fs::File::create(&out_path).expect("create output");
459+
let mut writer = BufWriter::new(out);
460+
461+
let stats = super::stream_index_safetensors_bf16(
462+
&mut reader, &mut writer, octave_stride,
463+
Some(&|name, lt, orig, comp| {
464+
let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 };
465+
eprintln!(" {:50} {:>12} → {:>8} ({:.0}×)", name, orig, comp, ratio);
466+
}),
467+
).expect("safetensors indexing failed");
468+
469+
drop(writer);
470+
let out_size = std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0);
471+
eprintln!(" → {:.2} MB, {} tensors, {:.0}×",
472+
out_size as f64 / 1e6, stats.tensors_indexed, stats.overall_ratio());
473+
474+
all_stats.push(stats);
475+
}
476+
477+
all_stats
478+
}
479+
480+
#[test]
481+
#[ignore] // Streams ~35 GB from HuggingFace
482+
fn test_stream_index_hidream_transformer() {
483+
let repo = "HiDream-ai/HiDream-I1-Full";
484+
let shards: Vec<&str> = (1..=7).map(|i| {
485+
// Leak the string so it lives long enough — test only
486+
Box::leak(format!(
487+
"transformer/diffusion_pytorch_model-{:05}-of-00007.safetensors", i
488+
).into_boxed_str()) as &str
489+
}).collect();
490+
491+
let stats = index_safetensors_shards(repo, &shards, "/tmp/hidream_transformer", 16);
492+
493+
let total_tensors: usize = stats.iter().map(|s| s.tensors_indexed).sum();
494+
let total_orig: u64 = stats.iter().map(|s| s.original_bytes).sum();
495+
let total_comp: u64 = stats.iter().map(|s| s.compressed_bytes).sum();
496+
497+
eprintln!();
498+
eprintln!("━━━ HiDream-I1 Transformer (DiT+MoE) ━━━");
499+
eprintln!(" Source: {:.2} GB", total_orig as f64 / 1e9);
500+
eprintln!(" Compressed: {:.2} MB", total_comp as f64 / 1e6);
501+
eprintln!(" Ratio: {:.0}×", total_orig as f64 / total_comp.max(1) as f64);
502+
eprintln!(" Tensors: {}", total_tensors);
503+
504+
assert!(total_tensors > 50);
505+
}
506+
507+
#[test]
508+
#[ignore] // Streams ~13 GB
509+
fn test_stream_index_hidream_text_encoders() {
510+
let repo = "HiDream-ai/HiDream-I1-Full";
511+
512+
// CLIP-L
513+
eprintln!("━━━ CLIP-L ━━━");
514+
index_safetensors_shards(repo,
515+
&["text_encoder/model.safetensors"],
516+
"/tmp/hidream_clip_l", 16);
517+
518+
// CLIP-G
519+
eprintln!("━━━ CLIP-G ━━━");
520+
index_safetensors_shards(repo,
521+
&["text_encoder_2/model.safetensors"],
522+
"/tmp/hidream_clip_g", 16);
523+
524+
// Llama-3.1-8B text encoder (2 shards)
525+
eprintln!("━━━ Llama-3.1-8B (HiDream text encoder) ━━━");
526+
index_safetensors_shards(repo,
527+
&["text_encoder_3/model-00001-of-00002.safetensors",
528+
"text_encoder_3/model-00002-of-00002.safetensors"],
529+
"/tmp/hidream_llama_enc", 16);
530+
}
531+
532+
#[test]
533+
#[ignore] // Streams ~16 GB (base Llama-3.1-8B)
534+
fn test_stream_index_llama31_8b_base() {
535+
let repo = "unsloth/Llama-3.1-8B";
536+
let shards: Vec<&str> = (1..=4).map(|i| {
537+
Box::leak(format!(
538+
"model-{:05}-of-00004.safetensors", i
539+
).into_boxed_str()) as &str
540+
}).collect();
541+
542+
index_safetensors_shards(repo, &shards, "/tmp/llama31_8b_base", 16);
543+
}
544+
545+
#[test]
546+
#[ignore] // Requires: HiDream Llama enc + base Llama indexed
547+
fn test_hidream_llama_diff() {
548+
use super::super::causal_diff::{causal_diff, print_diff_summary, find_reasoning_scaffold};
549+
550+
// Compare HiDream's Llama-3.1-8B (image-conditioned) vs base
551+
// Shards need to be concatenated or diffed per-shard
552+
let pairs = [
553+
("/tmp/llama31_8b_base_shard01.bgz7", "/tmp/hidream_llama_enc_shard01.bgz7", "shard 1"),
554+
("/tmp/llama31_8b_base_shard02.bgz7", "/tmp/hidream_llama_enc_shard02.bgz7", "shard 2"),
555+
];
556+
557+
let mut total_shifted = 0usize;
558+
let mut total_compared = 0usize;
559+
560+
for (base, dist, label) in &pairs {
561+
if !std::fs::metadata(base).is_ok() || !std::fs::metadata(dist).is_ok() {
562+
eprintln!("SKIP {} (files not found)", label);
563+
continue;
564+
}
565+
566+
let (edges, stats) = causal_diff(base, dist, 100).expect("diff failed");
567+
print_diff_summary(
568+
&format!("Llama-3.1-8B: base vs HiDream image encoder ({})", label),
569+
&stats, edges.len());
570+
571+
let scaffold = find_reasoning_scaffold(&edges, 0.3);
572+
eprintln!(" Visual grounding scaffold blocks: {:?}", scaffold);
573+
574+
total_shifted += stats.rows_shifted;
575+
total_compared += stats.rows_compared;
576+
}
577+
578+
if total_compared > 0 {
579+
eprintln!();
580+
eprintln!("━━━ Cross-Domain Insight ━━━");
581+
eprintln!(" Total rows shifted: {}/{} ({:.1}%)",
582+
total_shifted, total_compared,
583+
total_shifted as f64 / total_compared as f64 * 100.0);
584+
eprintln!(" → These shifts = what 'visual grounding' looks like in LLM weight space");
585+
}
586+
}
414587
}

0 commit comments

Comments
 (0)