Skip to content

Commit 5f2c4ad

Browse files
committed
auth: support model alias claims with body rewrite
The platform's MODEL_HF_PATH_OVERRIDES alias was previously one-way: the JWT model claim and the orchestrator's TOML model.name both held the canonical HF path, but envs that hard-code the user-facing name (e.g. emilschmitz/curriculum-oversight's LLM judge sending sprints/Llama-3.2-1B-Instruct) hit 403 because the canonical-only claim required exact-string match. RftClaims gains an optional model_aliases list. allows_model accepts alias hits, and pin_and_check_model variants (typed, string, JSON) plus the transparent-proxy fallback rewrite the request body's model field to self.model after an alias-only match, so vLLM only ever sees the canonical served name. Base-model and LoRA matches pass through unchanged.
1 parent 3b76934 commit 5f2c4ad

2 files changed

Lines changed: 117 additions & 2 deletions

File tree

src/auth.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ pub struct RftClaims {
1717
/// Allowed LoRA adapter name
1818
#[serde(default)]
1919
pub lora: Option<String>,
20+
/// Alternate names that should resolve to the base `model`. Used when
21+
/// the platform exposes a user-facing model identifier (e.g.
22+
/// `sprints/Llama-3.2-1B-Instruct`) that differs from the canonical
23+
/// HF path vLLM serves (`meta-llama/Llama-3.2-1B-Instruct`). A
24+
/// request hitting an alias is authorized and the request body's
25+
/// `model` field is rewritten to `self.model` before dispatch.
26+
#[serde(default)]
27+
pub model_aliases: Vec<String>,
2028
}
2129

2230
/// Verifier for RS256 JWTs signed by the platform.
@@ -73,6 +81,10 @@ impl RftClaims {
7381
}
7482
}
7583

84+
if self.is_model_alias(requested) {
85+
return true;
86+
}
87+
7688
if let Some(lora) = self.lora.as_deref() {
7789
// Empty lora claim must never authorize anything; an empty
7890
// string is a prefix of every other string, which would let
@@ -96,19 +108,52 @@ impl RftClaims {
96108

97109
false
98110
}
111+
112+
/// Whether `requested` is one of the alternate names declared in
113+
/// `model_aliases`. Empty entries are ignored so a misconfigured
114+
/// claim never authorizes the empty model.
115+
fn is_model_alias(&self, requested: &str) -> bool {
116+
if requested.is_empty() {
117+
return false;
118+
}
119+
self.model_aliases
120+
.iter()
121+
.any(|a| !a.is_empty() && a == requested)
122+
}
123+
124+
/// If `requested` matched the JWT only via `model_aliases`, return
125+
/// the canonical model name (`self.model`) so the request body can
126+
/// be rewritten before forwarding to vLLM. Returns `None` if the
127+
/// request already targets the base model or a LoRA — those need to
128+
/// pass through unchanged.
129+
pub fn canonical_for_alias(&self, requested: &str) -> Option<String> {
130+
if !self.is_model_alias(requested) {
131+
return None;
132+
}
133+
let base = self.model.as_deref()?;
134+
if base.is_empty() || base == requested {
135+
return None;
136+
}
137+
Some(base.to_string())
138+
}
99139
}
100140

101141
#[cfg(test)]
102142
mod tests {
103143
use super::*;
104144

105145
fn claims(model: Option<&str>, lora: Option<&str>) -> RftClaims {
146+
claims_with_aliases(model, lora, &[])
147+
}
148+
149+
fn claims_with_aliases(model: Option<&str>, lora: Option<&str>, aliases: &[&str]) -> RftClaims {
106150
RftClaims {
107151
sub: "user".into(),
108152
run_id: "abc".into(),
109153
team_id: String::new(),
110154
model: model.map(String::from),
111155
lora: lora.map(String::from),
156+
model_aliases: aliases.iter().map(|s| (*s).to_string()).collect(),
112157
}
113158
}
114159

@@ -165,4 +210,51 @@ mod tests {
165210
let c = claims(Some(""), Some("rft-abc"));
166211
assert!(!c.allows_model(""));
167212
}
213+
214+
#[test]
215+
fn allows_model_alias() {
216+
let c = claims_with_aliases(
217+
Some("meta-llama/Llama-3.2-1B-Instruct"),
218+
Some("rft-abc"),
219+
&["sprints/Llama-3.2-1B-Instruct"],
220+
);
221+
assert!(c.allows_model("sprints/Llama-3.2-1B-Instruct"));
222+
assert!(c.allows_model("meta-llama/Llama-3.2-1B-Instruct"));
223+
assert!(!c.allows_model("other/model"));
224+
}
225+
226+
#[test]
227+
fn canonical_for_alias_rewrites_alias_only() {
228+
let c = claims_with_aliases(
229+
Some("meta-llama/Llama-3.2-1B-Instruct"),
230+
Some("rft-abc"),
231+
&["sprints/Llama-3.2-1B-Instruct"],
232+
);
233+
assert_eq!(
234+
c.canonical_for_alias("sprints/Llama-3.2-1B-Instruct").as_deref(),
235+
Some("meta-llama/Llama-3.2-1B-Instruct"),
236+
);
237+
// Base model and lora must NOT be rewritten — lora dispatch
238+
// depends on the original name reaching vLLM.
239+
assert_eq!(c.canonical_for_alias("meta-llama/Llama-3.2-1B-Instruct"), None);
240+
assert_eq!(c.canonical_for_alias("rft-abc"), None);
241+
assert_eq!(c.canonical_for_alias("rft-abc-step-42"), None);
242+
assert_eq!(c.canonical_for_alias("unrelated"), None);
243+
}
244+
245+
#[test]
246+
fn empty_alias_entry_authorizes_nothing() {
247+
let c = claims_with_aliases(Some("meta-llama/Llama-3.2-1B-Instruct"), Some("rft-abc"), &[""]);
248+
assert!(!c.allows_model(""));
249+
assert_eq!(c.canonical_for_alias(""), None);
250+
}
251+
252+
#[test]
253+
fn alias_matching_base_does_not_rewrite() {
254+
// Pathological config: alias equals the base. Should still
255+
// authorize, but not produce a rewrite (would be a no-op anyway).
256+
let c = claims_with_aliases(Some("meta-llama/Llama-3.2-1B-Instruct"), None, &["meta-llama/Llama-3.2-1B-Instruct"]);
257+
assert!(c.allows_model("meta-llama/Llama-3.2-1B-Instruct"));
258+
assert_eq!(c.canonical_for_alias("meta-llama/Llama-3.2-1B-Instruct"), None);
259+
}
168260
}

src/server.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
170170
};
171171

172172
// Parse body as JSON
173-
let body_json: serde_json::Value = if body_bytes.is_empty() {
173+
let mut body_json: serde_json::Value = if body_bytes.is_empty() {
174174
serde_json::Value::Null
175175
} else {
176176
match serde_json::from_slice(&body_bytes) {
@@ -193,8 +193,9 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
193193
.get("model")
194194
.and_then(|v| v.as_str())
195195
.filter(|s| !s.is_empty())
196+
.map(str::to_string)
196197
{
197-
if !claims_ref.allows_model(model) {
198+
if !claims_ref.allows_model(&model) {
198199
warn!(
199200
run_id = %claims_ref.run_id,
200201
requested_model = %model,
@@ -209,6 +210,14 @@ async fn transparent_proxy_handler(State(state): State<Arc<AppState>>, req: Requ
209210
)
210211
.into_response();
211212
}
213+
if let Some(canonical) = claims_ref.canonical_for_alias(&model) {
214+
if let Some(obj) = body_json.as_object_mut() {
215+
obj.insert(
216+
"model".to_string(),
217+
serde_json::Value::String(canonical),
218+
);
219+
}
220+
}
212221
}
213222
if let Err(response) = enforce_no_lora_path_override_json(&claims, &body_json) {
214223
return response;
@@ -445,6 +454,9 @@ fn pin_and_check_model(
445454

446455
let resolved = model.as_deref().unwrap_or("");
447456
if claims.allows_model(resolved) {
457+
if let Some(canonical) = claims.canonical_for_alias(resolved) {
458+
*model = Some(canonical);
459+
}
448460
return Ok(());
449461
}
450462

@@ -500,6 +512,9 @@ fn pin_and_check_model_string(
500512
}
501513

502514
if claims.allows_model(model) {
515+
if let Some(canonical) = claims.canonical_for_alias(model) {
516+
*model = canonical;
517+
}
503518
return Ok(());
504519
}
505520

@@ -558,6 +573,14 @@ fn pin_and_check_model_json(
558573
};
559574

560575
if claims.allows_model(&resolved) {
576+
if let Some(canonical) = claims.canonical_for_alias(&resolved) {
577+
if let Some(obj) = body.as_object_mut() {
578+
obj.insert(
579+
"model".to_string(),
580+
serde_json::Value::String(canonical),
581+
);
582+
}
583+
}
561584
return Ok(());
562585
}
563586

0 commit comments

Comments
 (0)