Skip to content

Commit 929b143

Browse files
committed
feat(gpt2): full inference engine + OpenAI-compatible API types
GPT-2 small (124M) forward pass with KV cache, all transcendentals via crate::simd::F32x16 (LayerNorm, GELU, softmax, dot products). - weights.rs: safetensors loader for 12 transformer layers - inference.rs: autoregressive generation with temperature sampling - api.rs: OpenAI-compatible request/response types (/v1/completions, /v1/embeddings, /v1/models) — transport-agnostic - 9 tests passing (layer_norm, GELU, softmax, config, API types) https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent c19e7fa commit 929b143

5 files changed

Lines changed: 900 additions & 0 deletions

File tree

src/hpc/gpt2/api.rs

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
//! OpenAI-compatible API types for GPT-2 inference.
2+
//!
3+
//! Provides request/response structs matching the OpenAI API surface:
4+
//! - `/v1/completions` — text completion
5+
//! - `/v1/embeddings` — token embeddings via wte
6+
//! - `/v1/models` — model listing
7+
//!
8+
//! These types are transport-agnostic — they serialize/deserialize
9+
//! but don't depend on any HTTP framework.
10+
11+
use super::inference::{GeneratedToken, Gpt2Engine};
12+
use super::weights::*;
13+
14+
// ============================================================================
15+
// /v1/completions
16+
// ============================================================================
17+
18+
/// Request body for /v1/completions.
19+
#[derive(Clone, Debug)]
20+
pub struct CompletionRequest {
21+
/// Model name (ignored — we only have gpt2).
22+
pub model: String,
23+
/// Input text prompt (will be tokenized externally).
24+
pub prompt_tokens: Vec<u32>,
25+
/// Maximum tokens to generate.
26+
pub max_tokens: usize,
27+
/// Sampling temperature (1.0 = greedy effective).
28+
pub temperature: f32,
29+
/// Stop token ID (default: 50256 = <|endoftext|>).
30+
pub stop_token: Option<u32>,
31+
}
32+
33+
impl Default for CompletionRequest {
34+
fn default() -> Self {
35+
Self {
36+
model: "gpt2".into(),
37+
prompt_tokens: Vec::new(),
38+
max_tokens: 128,
39+
temperature: 1.0,
40+
stop_token: Some(50256),
41+
}
42+
}
43+
}
44+
45+
/// Single completion choice.
46+
#[derive(Clone, Debug)]
47+
pub struct CompletionChoice {
48+
pub index: usize,
49+
pub tokens: Vec<GeneratedToken>,
50+
pub finish_reason: FinishReason,
51+
}
52+
53+
/// Why generation stopped.
54+
#[derive(Clone, Debug, PartialEq, Eq)]
55+
pub enum FinishReason {
56+
Stop,
57+
Length,
58+
}
59+
60+
/// Response body for /v1/completions.
61+
#[derive(Clone, Debug)]
62+
pub struct CompletionResponse {
63+
pub id: String,
64+
pub model: String,
65+
pub choices: Vec<CompletionChoice>,
66+
pub usage: Usage,
67+
}
68+
69+
/// Token usage statistics.
70+
#[derive(Clone, Debug, Default)]
71+
pub struct Usage {
72+
pub prompt_tokens: usize,
73+
pub completion_tokens: usize,
74+
pub total_tokens: usize,
75+
}
76+
77+
// ============================================================================
78+
// /v1/embeddings
79+
// ============================================================================
80+
81+
/// Request body for /v1/embeddings.
82+
#[derive(Clone, Debug)]
83+
pub struct EmbeddingRequest {
84+
pub model: String,
85+
/// Token IDs to embed (one embedding per token).
86+
pub input_tokens: Vec<u32>,
87+
}
88+
89+
/// Single embedding result.
90+
#[derive(Clone, Debug)]
91+
pub struct EmbeddingData {
92+
pub index: usize,
93+
pub embedding: Vec<f32>,
94+
}
95+
96+
/// Response body for /v1/embeddings.
97+
#[derive(Clone, Debug)]
98+
pub struct EmbeddingResponse {
99+
pub model: String,
100+
pub data: Vec<EmbeddingData>,
101+
pub usage: Usage,
102+
}
103+
104+
// ============================================================================
105+
// /v1/models
106+
// ============================================================================
107+
108+
/// Model info for /v1/models.
109+
#[derive(Clone, Debug)]
110+
pub struct ModelInfo {
111+
pub id: String,
112+
pub owned_by: String,
113+
pub vocab_size: usize,
114+
pub embed_dim: usize,
115+
pub num_layers: usize,
116+
pub num_heads: usize,
117+
pub max_seq_len: usize,
118+
}
119+
120+
impl ModelInfo {
121+
/// GPT-2 small (124M) model info.
122+
pub fn gpt2_small() -> Self {
123+
Self {
124+
id: "gpt2".into(),
125+
owned_by: "adaworldapi".into(),
126+
vocab_size: VOCAB_SIZE,
127+
embed_dim: EMBED_DIM,
128+
num_layers: NUM_LAYERS,
129+
num_heads: NUM_HEADS,
130+
max_seq_len: MAX_SEQ_LEN,
131+
}
132+
}
133+
}
134+
135+
// ============================================================================
136+
// Engine wrapper — stateless API over stateful engine
137+
// ============================================================================
138+
139+
/// Stateless API wrapper around Gpt2Engine.
140+
/// Handles request→response conversion.
141+
pub struct Gpt2Api {
142+
engine: Gpt2Engine,
143+
request_counter: u64,
144+
}
145+
146+
impl Gpt2Api {
147+
/// Create from pre-loaded weights.
148+
pub fn new(weights: Gpt2Weights) -> Self {
149+
Self {
150+
engine: Gpt2Engine::new(weights),
151+
request_counter: 0,
152+
}
153+
}
154+
155+
/// /v1/completions handler.
156+
pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse {
157+
self.request_counter += 1;
158+
159+
let generated = self.engine.generate(
160+
&req.prompt_tokens,
161+
req.max_tokens,
162+
req.temperature,
163+
);
164+
165+
let finish_reason = if generated.len() < req.max_tokens {
166+
FinishReason::Stop
167+
} else {
168+
FinishReason::Length
169+
};
170+
171+
let completion_tokens = generated.len();
172+
let prompt_tokens = req.prompt_tokens.len();
173+
174+
CompletionResponse {
175+
id: format!("cmpl-{}", self.request_counter),
176+
model: "gpt2".into(),
177+
choices: vec![CompletionChoice {
178+
index: 0,
179+
tokens: generated,
180+
finish_reason,
181+
}],
182+
usage: Usage {
183+
prompt_tokens,
184+
completion_tokens,
185+
total_tokens: prompt_tokens + completion_tokens,
186+
},
187+
}
188+
}
189+
190+
/// /v1/embeddings handler — returns wte embeddings for token IDs.
191+
pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse {
192+
let mut data = Vec::with_capacity(req.input_tokens.len());
193+
194+
for (idx, &token_id) in req.input_tokens.iter().enumerate() {
195+
let offset = token_id as usize * EMBED_DIM;
196+
let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
197+
data.push(EmbeddingData {
198+
index: idx,
199+
embedding,
200+
});
201+
}
202+
203+
EmbeddingResponse {
204+
model: "gpt2".into(),
205+
data,
206+
usage: Usage {
207+
prompt_tokens: req.input_tokens.len(),
208+
completion_tokens: 0,
209+
total_tokens: req.input_tokens.len(),
210+
},
211+
}
212+
}
213+
214+
/// /v1/models handler.
215+
pub fn model_info(&self) -> ModelInfo {
216+
ModelInfo::gpt2_small()
217+
}
218+
219+
/// Access the underlying engine (for advanced usage).
220+
pub fn engine_mut(&mut self) -> &mut Gpt2Engine {
221+
&mut self.engine
222+
}
223+
}
224+
225+
#[cfg(test)]
226+
mod tests {
227+
use super::*;
228+
229+
#[test]
230+
fn test_model_info() {
231+
let info = ModelInfo::gpt2_small();
232+
assert_eq!(info.vocab_size, 50257);
233+
assert_eq!(info.embed_dim, 768);
234+
assert_eq!(info.num_layers, 12);
235+
assert_eq!(info.num_heads, 12);
236+
assert_eq!(info.max_seq_len, 1024);
237+
}
238+
239+
#[test]
240+
fn test_completion_request_default() {
241+
let req = CompletionRequest::default();
242+
assert_eq!(req.max_tokens, 128);
243+
assert_eq!(req.temperature, 1.0);
244+
assert_eq!(req.stop_token, Some(50256));
245+
}
246+
247+
#[test]
248+
fn test_finish_reason_variants() {
249+
assert_eq!(FinishReason::Stop, FinishReason::Stop);
250+
assert_ne!(FinishReason::Stop, FinishReason::Length);
251+
}
252+
}

0 commit comments

Comments
 (0)