|
| 1 | +// <complete_code> |
| 2 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +// Licensed under the MIT License. |
| 4 | + |
| 5 | +//! Responses API web-service sample. |
| 6 | +//! |
| 7 | +//! Demonstrates how to use the Rust SDK for Foundry Local setup, model |
| 8 | +//! lifecycle, and local web-service lifecycle, then call `/v1/responses` with a |
| 9 | +//! standard HTTP client. |
| 10 | +
|
| 11 | +// <imports> |
| 12 | +use std::error::Error; |
| 13 | +use std::io::{self, Write}; |
| 14 | + |
| 15 | +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; |
| 16 | +use serde_json::{json, Value}; |
| 17 | +// </imports> |
| 18 | + |
/// Alias of the model this sample looks up in the Foundry Local catalog.
const MODEL_ALIAS: &str = "qwen2.5-0.5b";

/// Result type shared by the fallible functions in this sample.
///
/// The error is boxed so that SDK, HTTP, JSON, and plain string errors can all
/// be propagated with `?`.
type SampleResult<T> = Result<T, Box<dyn Error + Send + Sync + 'static>>;
| 21 | + |
| 22 | +#[tokio::main] |
| 23 | +async fn main() -> SampleResult<()> { |
| 24 | + println!("Responses Web Service"); |
| 25 | + println!("=====================\n"); |
| 26 | + |
| 27 | + // ── 1. Initialise the SDK ──────────────────────────────────────────── |
| 28 | + // <init> |
| 29 | + println!("Initializing Foundry Local SDK..."); |
| 30 | + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; |
| 31 | + println!("SDK initialized successfully"); |
| 32 | + |
| 33 | + manager |
| 34 | + .download_and_register_eps_with_progress(None, { |
| 35 | + let mut current_ep = String::new(); |
| 36 | + move |ep_name: &str, percent: f64| { |
| 37 | + if ep_name != current_ep { |
| 38 | + if !current_ep.is_empty() { |
| 39 | + println!(); |
| 40 | + } |
| 41 | + current_ep = ep_name.to_string(); |
| 42 | + } |
| 43 | + print!("\r {:<30} {:5.1}%", ep_name, percent); |
| 44 | + io::stdout().flush().ok(); |
| 45 | + } |
| 46 | + }) |
| 47 | + .await?; |
| 48 | + println!(); |
| 49 | + // </init> |
| 50 | + |
| 51 | + // ── 2. Download and load a model ───────────────────────────────────── |
| 52 | + // <model_setup> |
| 53 | + let model = manager.catalog().get_model(MODEL_ALIAS).await?; |
| 54 | + |
| 55 | + if !model.is_cached().await? { |
| 56 | + println!("Downloading model {MODEL_ALIAS}..."); |
| 57 | + model |
| 58 | + .download(Some(|progress: f64| { |
| 59 | + print!("\rDownloading model... {progress:.1}%"); |
| 60 | + io::stdout().flush().ok(); |
| 61 | + })) |
| 62 | + .await?; |
| 63 | + println!(); |
| 64 | + } |
| 65 | + |
| 66 | + println!("Loading model {MODEL_ALIAS}..."); |
| 67 | + model.load().await?; |
| 68 | + println!("Model loaded"); |
| 69 | + // </model_setup> |
| 70 | + |
| 71 | + // ── 3. Start the OpenAI-compatible web service ─────────────────────── |
| 72 | + // <server_setup> |
| 73 | + println!("Starting web service..."); |
| 74 | + manager.start_web_service().await?; |
| 75 | + println!("Web service started"); |
| 76 | + |
| 77 | + let endpoint = manager |
| 78 | + .urls()? |
| 79 | + .first() |
| 80 | + .expect("Web service did not return an endpoint") |
| 81 | + .trim_end_matches('/') |
| 82 | + .to_string(); |
| 83 | + let base_url = format!("{endpoint}/v1"); |
| 84 | + println!("Using base URL: {base_url}"); |
| 85 | + // </server_setup> |
| 86 | + |
| 87 | + let result = run_responses_flow(&base_url, model.id()).await; |
| 88 | + |
| 89 | + // ── 4. Clean up ────────────────────────────────────────────────────── |
| 90 | + manager.stop_web_service().await.ok(); |
| 91 | + model.unload().await.ok(); |
| 92 | + |
| 93 | + result |
| 94 | +} |
| 95 | + |
| 96 | +async fn run_responses_flow(base_url: &str, model_id: &str) -> SampleResult<()> { |
| 97 | + let http = reqwest::Client::new(); |
| 98 | + |
| 99 | + println!("\nTesting a non-streaming Responses call..."); |
| 100 | + let response = post_response_json( |
| 101 | + &http, |
| 102 | + base_url, |
| 103 | + json!({ |
| 104 | + "model": model_id, |
| 105 | + "input": "Reply with one short sentence about local AI.", |
| 106 | + "temperature": 0.0, |
| 107 | + "max_output_tokens": 64, |
| 108 | + "store": false |
| 109 | + }), |
| 110 | + ) |
| 111 | + .await?; |
| 112 | + println!("[ASSISTANT]: {}", output_text(&response)); |
| 113 | + |
| 114 | + println!("\nTesting a streaming Responses call..."); |
| 115 | + print!("[ASSISTANT STREAM]: "); |
| 116 | + io::stdout().flush().ok(); |
| 117 | + let streaming_response = http |
| 118 | + .post(format!("{base_url}/responses")) |
| 119 | + .header(reqwest::header::ACCEPT, "text/event-stream") |
| 120 | + .json(&json!({ |
| 121 | + "model": model_id, |
| 122 | + "input": "Count from one to three.", |
| 123 | + "temperature": 0.0, |
| 124 | + "max_output_tokens": 64, |
| 125 | + "store": false, |
| 126 | + "stream": true |
| 127 | + })) |
| 128 | + .send() |
| 129 | + .await?; |
| 130 | + let streamed = read_responses_sse(streaming_response).await?; |
| 131 | + println!(); |
| 132 | + if !streamed.created || streamed.delta_count == 0 || !streamed.completed { |
| 133 | + return Err( |
| 134 | + "stream did not include response.created, text delta, and completion events".into(), |
| 135 | + ); |
| 136 | + } |
| 137 | + |
| 138 | + println!("\nTesting Responses tool calling..."); |
| 139 | + let tools = [get_weather_tool()]; |
| 140 | + let tool_response = post_response_json( |
| 141 | + &http, |
| 142 | + base_url, |
| 143 | + json!({ |
| 144 | + "model": model_id, |
| 145 | + "input": "Use the get_weather tool and then answer with the weather.", |
| 146 | + "tools": tools, |
| 147 | + "tool_choice": "required", |
| 148 | + "temperature": 0.0, |
| 149 | + "max_output_tokens": 64, |
| 150 | + "store": true |
| 151 | + }), |
| 152 | + ) |
| 153 | + .await?; |
| 154 | + |
| 155 | + let (call_id, name) = |
| 156 | + find_function_call(&tool_response).ok_or("expected a function_call item")?; |
| 157 | + println!("[TOOL CALL]: {name} ({call_id})"); |
| 158 | + |
| 159 | + let final_response = post_response_json( |
| 160 | + &http, |
| 161 | + base_url, |
| 162 | + json!({ |
| 163 | + "model": model_id, |
| 164 | + "previous_response_id": tool_response["id"].clone(), |
| 165 | + "input": [{ |
| 166 | + "type": "function_call_output", |
| 167 | + "call_id": call_id, |
| 168 | + "output": "{\"location\":\"Seattle\",\"weather\":\"72 degrees F and sunny\"}" |
| 169 | + }], |
| 170 | + "tools": [get_weather_tool()], |
| 171 | + "temperature": 0.0, |
| 172 | + "max_output_tokens": 64, |
| 173 | + "store": false |
| 174 | + }), |
| 175 | + ) |
| 176 | + .await?; |
| 177 | + println!("[ASSISTANT FINAL]: {}", output_text(&final_response)); |
| 178 | + |
| 179 | + Ok(()) |
| 180 | +} |
| 181 | + |
| 182 | +async fn post_response_json( |
| 183 | + http: &reqwest::Client, |
| 184 | + base_url: &str, |
| 185 | + body: Value, |
| 186 | +) -> SampleResult<Value> { |
| 187 | + let response = http |
| 188 | + .post(format!("{base_url}/responses")) |
| 189 | + .json(&body) |
| 190 | + .send() |
| 191 | + .await?; |
| 192 | + let status = response.status(); |
| 193 | + let text = response.text().await?; |
| 194 | + if !status.is_success() { |
| 195 | + return Err(format!("Responses API returned {status}: {text}").into()); |
| 196 | + } |
| 197 | + Ok(serde_json::from_str(&text)?) |
| 198 | +} |
| 199 | + |
| 200 | +fn output_text(response: &Value) -> String { |
| 201 | + if let Some(text) = response.get("output_text").and_then(Value::as_str) { |
| 202 | + return text.to_string(); |
| 203 | + } |
| 204 | + |
| 205 | + response |
| 206 | + .get("output") |
| 207 | + .and_then(Value::as_array) |
| 208 | + .into_iter() |
| 209 | + .flatten() |
| 210 | + .find_map(|item| { |
| 211 | + if item.get("type").and_then(Value::as_str) != Some("message") { |
| 212 | + return None; |
| 213 | + } |
| 214 | + match item.get("content") { |
| 215 | + Some(Value::String(text)) => Some(text.clone()), |
| 216 | + Some(Value::Array(parts)) => Some( |
| 217 | + parts |
| 218 | + .iter() |
| 219 | + .filter_map(|part| { |
| 220 | + (part.get("type").and_then(Value::as_str) == Some("output_text")) |
| 221 | + .then(|| part.get("text").and_then(Value::as_str)) |
| 222 | + .flatten() |
| 223 | + }) |
| 224 | + .collect::<String>(), |
| 225 | + ), |
| 226 | + _ => None, |
| 227 | + } |
| 228 | + }) |
| 229 | + .unwrap_or_default() |
| 230 | +} |
| 231 | + |
| 232 | +fn find_function_call(response: &Value) -> Option<(String, String)> { |
| 233 | + response.get("output")?.as_array()?.iter().find_map(|item| { |
| 234 | + if item.get("type").and_then(Value::as_str) != Some("function_call") { |
| 235 | + return None; |
| 236 | + } |
| 237 | + let call_id = item.get("call_id")?.as_str()?.to_string(); |
| 238 | + let name = item.get("name")?.as_str()?.to_string(); |
| 239 | + Some((call_id, name)) |
| 240 | + }) |
| 241 | +} |
| 242 | + |
| 243 | +fn get_weather_tool() -> Value { |
| 244 | + json!({ |
| 245 | + "type": "function", |
| 246 | + "name": "get_weather", |
| 247 | + "description": "Get the current weather. This sample always returns Seattle weather.", |
| 248 | + "parameters": { |
| 249 | + "type": "object", |
| 250 | + "properties": {}, |
| 251 | + "additionalProperties": false |
| 252 | + } |
| 253 | + }) |
| 254 | +} |
| 255 | + |
/// Tally of the server-sent events observed while reading one streamed reply.
#[derive(Default)]
struct StreamSummary {
    /// Set once a `response.created` event has been seen.
    created: bool,
    /// Number of `response.output_text.delta` events seen.
    delta_count: usize,
    /// Set once a `response.completed` event has been seen.
    completed: bool,
}
| 262 | + |
| 263 | +async fn read_responses_sse(mut response: reqwest::Response) -> SampleResult<StreamSummary> { |
| 264 | + let status = response.status(); |
| 265 | + if !status.is_success() { |
| 266 | + let text = response.text().await?; |
| 267 | + return Err(format!("Responses API returned {status}: {text}").into()); |
| 268 | + } |
| 269 | + |
| 270 | + let mut buffer = String::new(); |
| 271 | + let mut summary = StreamSummary::default(); |
| 272 | + |
| 273 | + while let Some(chunk) = response.chunk().await? { |
| 274 | + buffer.push_str(&String::from_utf8_lossy(&chunk).replace("\r\n", "\n")); |
| 275 | + while let Some(block_end) = buffer.find("\n\n") { |
| 276 | + let block = buffer[..block_end].to_string(); |
| 277 | + buffer = buffer[block_end + 2..].to_string(); |
| 278 | + if handle_sse_block(&block, &mut summary) { |
| 279 | + return Ok(summary); |
| 280 | + } |
| 281 | + } |
| 282 | + } |
| 283 | + |
| 284 | + if !buffer.trim().is_empty() { |
| 285 | + handle_sse_block(&buffer, &mut summary); |
| 286 | + } |
| 287 | + |
| 288 | + Ok(summary) |
| 289 | +} |
| 290 | + |
| 291 | +fn handle_sse_block(block: &str, summary: &mut StreamSummary) -> bool { |
| 292 | + let data = block |
| 293 | + .lines() |
| 294 | + .filter_map(|line| line.trim().strip_prefix("data:").map(str::trim_start)) |
| 295 | + .collect::<Vec<_>>() |
| 296 | + .join("\n"); |
| 297 | + |
| 298 | + if data.is_empty() { |
| 299 | + return false; |
| 300 | + } |
| 301 | + if data == "[DONE]" { |
| 302 | + return true; |
| 303 | + } |
| 304 | + |
| 305 | + if let Ok(event) = serde_json::from_str::<Value>(&data) { |
| 306 | + match event.get("type").and_then(Value::as_str) { |
| 307 | + Some("response.created") => summary.created = true, |
| 308 | + Some("response.output_text.delta") => { |
| 309 | + summary.delta_count += 1; |
| 310 | + if let Some(delta) = event.get("delta").and_then(Value::as_str) { |
| 311 | + print!("{delta}"); |
| 312 | + io::stdout().flush().ok(); |
| 313 | + } |
| 314 | + } |
| 315 | + Some("response.completed") => summary.completed = true, |
| 316 | + _ => {} |
| 317 | + } |
| 318 | + } |
| 319 | + |
| 320 | + false |
| 321 | +} |
| 322 | +// </complete_code> |
0 commit comments