Skip to content

Commit 93e86a4

Browse files
committed
feat: shared model primitives + Stable Diffusion scaffold
Extract shared code into hpc/models/: - safetensors.rs: generic file loader (used by GPT-2, SD, BERT) - layers.rs: SIMD ops (layer_norm, gelu, silu, group_norm, softmax, matmul_vec, dot_product) — all via crate::simd::F32x16 - api_types.rs: OpenAI-compatible envelope (Usage, FinishReason, etc.) Add hpc/stable_diffusion/ scaffold (code only, no weights): - clip.rs: CLIP text encoder (same transformer as GPT-2, shared layers) - unet.rs: UNet denoiser with Conv2D, GroupNorm, SiLU, timestep embedding - vae.rs: VAE decoder (latent→RGB) - scheduler.rs: DDIM noise scheduler with precomputed alpha schedule - weights.rs: safetensors loader for SD CLIP weights - api.rs: /v1/images/generations with full pipeline 52 tests passing. Zero weight files — disk space conscious. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent a4d0b56 commit 93e86a4

12 files changed

Lines changed: 1562 additions & 0 deletions

File tree

src/hpc/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,18 @@ pub mod gguf;
170170
#[allow(missing_docs)]
171171
pub mod jina;
172172

173+
/// Shared model primitives — safetensors, SIMD layers, API types.
174+
#[allow(missing_docs)]
175+
pub mod models;
176+
173177
/// GPT-2 inference engine — full forward pass + OpenAI-compatible API types.
174178
#[allow(missing_docs)]
175179
pub mod gpt2;
176180

181+
/// Stable Diffusion inference — CLIP + UNet + VAE + DDIM scheduler.
182+
#[allow(missing_docs)]
183+
pub mod stable_diffusion;
184+
177185
// jitson: JSON config → scan pipeline (parser, validator, template, precompile, packed)
178186
// Always available — no Cranelift dependency.
179187
#[allow(missing_docs)]

src/hpc/models/api_types.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! OpenAI-compatible API types shared across all model endpoints.
2+
//!
3+
//! Transport-agnostic — no HTTP framework dependency.
4+
//! Used by GPT-2 (/v1/completions), Stable Diffusion (/v1/images/generations),
5+
//! BERT/Jina (/v1/embeddings).
6+
7+
/// Token usage statistics (shared by all endpoints).
8+
#[derive(Clone, Debug, Default)]
9+
pub struct Usage {
10+
pub prompt_tokens: usize,
11+
pub completion_tokens: usize,
12+
pub total_tokens: usize,
13+
}
14+
15+
/// Why generation stopped.
16+
#[derive(Clone, Debug, PartialEq, Eq)]
17+
pub enum FinishReason {
18+
/// Hit stop token or stop sequence.
19+
Stop,
20+
/// Hit max_tokens limit.
21+
Length,
22+
/// Content filter triggered.
23+
ContentFilter,
24+
}
25+
26+
/// Error response envelope.
27+
#[derive(Clone, Debug)]
28+
pub struct ApiError {
29+
pub message: String,
30+
pub error_type: String,
31+
pub code: Option<String>,
32+
}
33+
34+
impl ApiError {
35+
pub fn invalid_request(msg: impl Into<String>) -> Self {
36+
Self {
37+
message: msg.into(),
38+
error_type: "invalid_request_error".into(),
39+
code: None,
40+
}
41+
}
42+
43+
pub fn model_not_found(model: &str) -> Self {
44+
Self {
45+
message: format!("model '{}' not found", model),
46+
error_type: "invalid_request_error".into(),
47+
code: Some("model_not_found".into()),
48+
}
49+
}
50+
}
51+
52+
/// Model info for /v1/models listing.
53+
#[derive(Clone, Debug)]
54+
pub struct ModelCard {
55+
pub id: String,
56+
pub owned_by: String,
57+
pub created: u64,
58+
}
59+
60+
/// Embedding data for /v1/embeddings response.
61+
#[derive(Clone, Debug)]
62+
pub struct EmbeddingData {
63+
pub index: usize,
64+
pub embedding: Vec<f32>,
65+
}
66+
67+
/// /v1/embeddings response (shared by BERT, Jina, GPT-2 wte).
68+
#[derive(Clone, Debug)]
69+
pub struct EmbeddingResponse {
70+
pub model: String,
71+
pub data: Vec<EmbeddingData>,
72+
pub usage: Usage,
73+
}
74+
75+
/// Image data for /v1/images/generations response.
76+
#[derive(Clone, Debug)]
77+
pub struct ImageData {
78+
/// Base64-encoded PNG, or URL if hosted.
79+
pub b64_json: Option<String>,
80+
pub url: Option<String>,
81+
pub revised_prompt: Option<String>,
82+
}
83+
84+
/// /v1/images/generations response (Stable Diffusion).
85+
#[derive(Clone, Debug)]
86+
pub struct ImageResponse {
87+
pub created: u64,
88+
pub data: Vec<ImageData>,
89+
}
90+
91+
#[cfg(test)]
92+
mod tests {
93+
use super::*;
94+
95+
#[test]
96+
fn test_usage_default() {
97+
let u = Usage::default();
98+
assert_eq!(u.prompt_tokens, 0);
99+
assert_eq!(u.total_tokens, 0);
100+
}
101+
102+
#[test]
103+
fn test_api_error_invalid_request() {
104+
let e = ApiError::invalid_request("bad input");
105+
assert_eq!(e.error_type, "invalid_request_error");
106+
assert!(e.code.is_none());
107+
}
108+
109+
#[test]
110+
fn test_api_error_model_not_found() {
111+
let e = ApiError::model_not_found("gpt-5");
112+
assert!(e.message.contains("gpt-5"));
113+
assert_eq!(e.code.as_deref(), Some("model_not_found"));
114+
}
115+
116+
#[test]
117+
fn test_finish_reason_eq() {
118+
assert_eq!(FinishReason::Stop, FinishReason::Stop);
119+
assert_ne!(FinishReason::Stop, FinishReason::Length);
120+
assert_ne!(FinishReason::Length, FinishReason::ContentFilter);
121+
}
122+
}

0 commit comments

Comments
 (0)