Skip to content

Commit f89e79c

Browse files
authored
Merge pull request #60 from RustedBytes/copilot/add-schema-json-support
2 parents 5f27406 + 6fecf84 commit f89e79c

4 files changed

Lines changed: 295 additions & 7 deletions

File tree

README.md

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ with a prompt and input file, and writes the response to an output file. The
3333
basic usage is as follows:
3434

3535
```bash
36-
invoke-llm --endpoint <endpoint> --model <model> --tokens <tokens> --prompt <prompt_file> --input <input_file> [--output <output_file>]
36+
invoke-llm --endpoint <endpoint> --model <model> --tokens <tokens> --prompt <prompt_file> --input <input_file> [--output <output_file>] [--schema <schema_file>]
3737
```
3838

3939
### Command-Line Arguments
@@ -50,6 +50,8 @@ The following command-line arguments are supported:
5050
provided).
5151
* `--reasoning` (optional): Whether to use reasoning models instead of regular
5252
ones.
53+
* `--schema` (optional): Path to a JSON schema file for structured output (using
54+
OpenAI's structured output format).
5355

5456
### Environment Variables
5557

@@ -70,6 +72,47 @@ The following endpoints are currently supported:
7072
* "hf": Hugging Face API endpoint
7173
* Custom endpoints: Any custom URL can be used as an endpoint
7274

75+
### Structured Output with JSON Schema
76+
77+
The `--schema` option enables you to specify a JSON schema file that defines the
78+
expected structure of the response. This is useful when you need the model to
79+
return data in a specific format.
80+
81+
Example schema file (`schema.json`):
82+
83+
```json
84+
{
85+
"name": "translation_response",
86+
"description": "Response format for translation",
87+
"strict": true,
88+
"schema": {
89+
"type": "object",
90+
"properties": {
91+
"translated_text": {
92+
"type": "string",
93+
"description": "The translated text"
94+
},
95+
"source_language": {
96+
"type": "string",
97+
"description": "The detected source language"
98+
},
99+
"target_language": {
100+
"type": "string",
101+
"description": "The target language"
102+
}
103+
},
104+
"required": ["translated_text", "source_language", "target_language"],
105+
"additionalProperties": false
106+
}
107+
}
108+
```
109+
110+
Usage with schema:
111+
112+
```bash
113+
invoke-llm --endpoint openai --model gpt-4o --tokens 500 --prompt prompt.txt --input input.txt --schema schema.json --output response.json
114+
```
115+
73116
### Examples
74117

75118
To run `invoke-llm` with some pre-defined prompts, use the following commands:

examples/schema.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"name": "translation_response",
3+
"description": "Response format for translation",
4+
"strict": true,
5+
"schema": {
6+
"type": "object",
7+
"properties": {
8+
"translated_text": {
9+
"type": "string",
10+
"description": "The translated text"
11+
},
12+
"source_language": {
13+
"type": "string",
14+
"description": "The detected source language"
15+
},
16+
"target_language": {
17+
"type": "string",
18+
"description": "The target language"
19+
}
20+
},
21+
"required": ["translated_text", "source_language", "target_language"],
22+
"additionalProperties": false
23+
}
24+
}

src/main.rs

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ struct Args {
8686
/// Optional `API_TOKEN` to use to access the API endpoint
8787
#[arg(short, long, required = false)]
8888
api_token: Option<String>,
89+
90+
/// Optional path to a JSON schema file for structured output
91+
#[arg(long, value_parser, required = false)]
92+
schema: Option<PathBuf>,
8993
}
9094

9195
/// Represents a single message in the chat completion request.
@@ -98,6 +102,15 @@ struct RequestMessage {
98102
content: String,
99103
}
100104

105+
/// Represents the response format configuration for structured output.
106+
///
107+
/// This structure is used to specify the JSON schema for the expected response format.
108+
#[derive(Serialize, Debug, Clone)]
109+
struct ResponseFormat {
110+
r#type: String,
111+
json_schema: serde_json::Value,
112+
}
113+
101114
/// The complete request payload sent to the chat completion API.
102115
///
103116
/// This structure contains all the necessary parameters for making a chat
@@ -111,6 +124,7 @@ struct RequestMessage {
111124
/// models)
112125
/// * `max_completion_tokens` - Maximum number of completion tokens (used for
113126
/// reasoning models)
127+
/// * `response_format` - Optional structured output schema
114128
#[derive(Serialize, Debug)]
115129
struct RequestPayload<'a> {
116130
messages: Vec<RequestMessage>,
@@ -119,6 +133,8 @@ struct RequestPayload<'a> {
119133
max_tokens: Option<u32>,
120134
#[serde(skip_serializing_if = "Option::is_none")]
121135
max_completion_tokens: Option<u32>,
136+
#[serde(skip_serializing_if = "Option::is_none")]
137+
response_format: Option<ResponseFormat>,
122138
}
123139

124140
/// Reads the entire contents of a file into a string.
@@ -142,6 +158,26 @@ fn read_file_content(file_path: impl AsRef<Path>) -> Result<String> {
142158
fs::read_to_string(file_path).context("Failed to read file content")
143159
}
144160

161+
/// Reads and parses a JSON schema file.
162+
///
163+
/// This function reads a JSON file containing a schema definition and parses it
164+
/// into a serde_json::Value. The schema is validated to ensure it's valid JSON.
165+
///
166+
/// # Arguments
167+
/// * `schema_path` - Path to the JSON schema file
168+
///
169+
/// # Returns
170+
/// * `Ok(serde_json::Value)` - The parsed JSON schema
171+
/// * `Err` - An error if the file could not be read or parsed as JSON
172+
fn read_schema_file(schema_path: impl AsRef<Path>) -> Result<serde_json::Value> {
173+
let schema_content = fs::read_to_string(schema_path).context("Failed to read schema file")?;
174+
175+
let schema: serde_json::Value =
176+
serde_json::from_str(&schema_content).context("Failed to parse JSON schema file")?;
177+
178+
Ok(schema)
179+
}
180+
145181
/// Maps known endpoint names to their corresponding API URLs.
146182
///
147183
/// This function takes a string identifier for a known service and returns
@@ -202,11 +238,14 @@ fn init_sentry() -> Option<sentry::ClientInitGuard> {
202238
};
203239

204240
info!("Sentry DSN detected. Enabling error monitoring.");
205-
let guard = sentry::init((dsn, sentry::ClientOptions {
206-
release: sentry::release_name!(),
207-
attach_stacktrace: true,
208-
..Default::default()
209-
}));
241+
let guard = sentry::init((
242+
dsn,
243+
sentry::ClientOptions {
244+
release: sentry::release_name!(),
245+
attach_stacktrace: true,
246+
..Default::default()
247+
},
248+
));
210249

211250
Some(guard)
212251
}
@@ -305,11 +344,23 @@ async fn run() -> Result<()> {
305344
content: input_content,
306345
});
307346

347+
// Read and parse the schema file if provided
348+
let response_format = if let Some(schema_path) = args.schema {
349+
let schema_json = read_schema_file(&schema_path)?;
350+
Some(ResponseFormat {
351+
r#type: "json_schema".to_owned(),
352+
json_schema: schema_json,
353+
})
354+
} else {
355+
None
356+
};
357+
308358
let payload = RequestPayload {
309359
messages,
310360
model: &args.model,
311361
max_tokens: if args.reasoning { None } else { Some(args.tokens) },
312362
max_completion_tokens: if args.reasoning { Some(args.tokens) } else { None },
363+
response_format,
313364
};
314365

315366
let client = Client::builder()

0 commit comments

Comments
 (0)