
Commit 1b3fd30

feat: unify review flows with schema-first verification
1 parent: 923c3aa

28 files changed

Lines changed: 1253 additions & 666 deletions

src/adapters/anthropic.rs

Lines changed: 2 additions & 0 deletions
@@ -354,6 +354,7 @@ mod tests {
             user_prompt: "user".to_string(),
             temperature: None,
             max_tokens: None,
+            response_schema: None,
         }
     }

@@ -593,6 +594,7 @@ mod tests {
             user_prompt: "u".to_string(),
             temperature: Some(0.9),
             max_tokens: Some(200),
+            response_schema: None,
         })
         .await;

src/adapters/llm.rs

Lines changed: 30 additions & 0 deletions
@@ -47,6 +47,30 @@ pub struct LLMRequest {
     pub user_prompt: String,
     pub temperature: Option<f32>,
     pub max_tokens: Option<usize>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response_schema: Option<StructuredOutputSchema>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct StructuredOutputSchema {
+    pub name: String,
+    pub schema: serde_json::Value,
+    #[serde(default = "default_true")]
+    pub strict: bool,
+}
+
+impl StructuredOutputSchema {
+    pub fn json_schema(name: impl Into<String>, schema: serde_json::Value) -> Self {
+        Self {
+            name: name.into(),
+            schema,
+            strict: true,
+        }
+    }
+}
+
+fn default_true() -> bool {
+    true
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]

@@ -198,6 +222,7 @@ pub trait LLMAdapter: Send + Sync {
             user_prompt,
             temperature: request.temperature,
             max_tokens: request.max_tokens,
+            response_schema: None,
         };

         let response = self.complete(llm_request).await?;

@@ -590,6 +615,10 @@ mod tests {
             user_prompt: "Review this diff.".to_string(),
             temperature: Some(0.3),
             max_tokens: Some(2000),
+            response_schema: Some(StructuredOutputSchema::json_schema(
+                "review_comments",
+                serde_json::json!({"type": "array"}),
+            )),
         };
         let tools = vec![ToolDefinition {
             name: "read_file".to_string(),

@@ -602,6 +631,7 @@
         assert_eq!(chat_req.messages[0].role, ChatRole::User);
         assert_eq!(chat_req.tools.len(), 1);
         assert_eq!(chat_req.temperature, Some(0.3));
+        assert_eq!(chat_req.max_tokens, Some(2000));
     }

     #[test]

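Note: the snippet below is an illustrative sketch, not part of this commit. It shows how a caller might populate the new response_schema field via the StructuredOutputSchema::json_schema constructor added above; the prompt text and the JSON Schema body are assumed examples.

// Illustrative sketch (not in this commit): building a schema-constrained
// request with the new field. Prompts and schema are assumed examples.
use crate::adapters::llm::{LLMRequest, StructuredOutputSchema};

fn review_request(diff: &str) -> LLMRequest {
    LLMRequest {
        system_prompt: "You are a code reviewer.".to_string(),
        user_prompt: format!("Review this diff:\n{diff}"),
        temperature: Some(0.2),
        max_tokens: Some(2000),
        // json_schema() sets strict: true, mirroring the default_true default.
        response_schema: Some(StructuredOutputSchema::json_schema(
            "review_comments",
            serde_json::json!({"type": "array"}),
        )),
    }
}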
src/adapters/ollama.rs

Lines changed: 3 additions & 0 deletions
@@ -203,6 +203,7 @@ mod tests {
             user_prompt: "user".to_string(),
             temperature: None,
             max_tokens: None,
+            response_schema: None,
         }
     }

@@ -446,6 +447,7 @@ mod tests {
             user_prompt: "review this".to_string(),
             temperature: None,
             max_tokens: None,
+            response_schema: None,
         };
         let result = adapter.complete(request).await;

@@ -473,6 +475,7 @@
             user_prompt: "user".to_string(),
             temperature: Some(0.5),
             max_tokens: Some(200),
+            response_schema: None,
         };
         let result = adapter.complete(request).await;

src/adapters/openai.rs

Lines changed: 98 additions & 3 deletions
@@ -1,7 +1,7 @@
 use crate::adapters::common;
 use crate::adapters::llm::{
     ChatRequest, ChatResponse, ChatRole, ContentBlock, LLMAdapter, LLMRequest, LLMResponse,
-    ModelConfig, StopReason, Usage,
+    ModelConfig, StopReason, StructuredOutputSchema, Usage,
 };
 use anyhow::{Context, Result};
 use async_trait::async_trait;

@@ -22,6 +22,8 @@ struct OpenAIRequest {
     messages: Vec<Message>,
     temperature: f32,
     max_tokens: usize,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    response_format: Option<OpenAIResponseFormat>,
 }

 #[derive(Serialize)]

@@ -45,6 +47,20 @@ struct Message {
     content: String,
 }

+#[derive(Serialize)]
+struct OpenAIResponseFormat {
+    #[serde(rename = "type")]
+    format_type: String,
+    json_schema: OpenAIJsonSchemaFormat,
+}
+
+#[derive(Serialize)]
+struct OpenAIJsonSchemaFormat {
+    name: String,
+    schema: serde_json::Value,
+    strict: bool,
+}
+
 #[derive(Deserialize)]
 struct OpenAIResponse {
     choices: Vec<Choice>,

@@ -221,7 +237,15 @@ impl OpenAIAdapter {

 #[async_trait]
 impl LLMAdapter for OpenAIAdapter {
-    async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
+    async fn complete(&self, mut request: LLMRequest) -> Result<LLMResponse> {
+        if request.response_schema.is_some() {
+            if self.supports_native_response_schema() {
+                return self.complete_chat_completions(request).await;
+            }
+
+            request.response_schema = None;
+        }
+
         if should_use_responses_api(&self.config) {
             return self.complete_responses(request).await;
         }

@@ -482,6 +506,12 @@ fn should_use_responses_api(config: &ModelConfig) -> bool {
 }

 impl OpenAIAdapter {
+    fn supports_native_response_schema(&self) -> bool {
+        self.base_url.contains("api.openai.com")
+            || self.base_url.contains("127.0.0.1")
+            || self.base_url.contains("localhost")
+    }
+
     async fn complete_chat_completions(&self, request: LLMRequest) -> Result<LLMResponse> {
         let messages = vec![
             Message {

@@ -499,6 +529,10 @@ impl OpenAIAdapter {
             messages,
             temperature: request.temperature.unwrap_or(self.config.temperature),
             max_tokens: request.max_tokens.unwrap_or(self.config.max_tokens),
+            response_format: request
+                .response_schema
+                .as_ref()
+                .map(to_openai_response_format),
         };

         let url = format!("{}/chat/completions", self.base_url);

@@ -576,6 +610,17 @@ impl OpenAIAdapter {
     }
 }

+fn to_openai_response_format(schema: &StructuredOutputSchema) -> OpenAIResponseFormat {
+    OpenAIResponseFormat {
+        format_type: "json_schema".to_string(),
+        json_schema: OpenAIJsonSchemaFormat {
+            name: schema.name.clone(),
+            schema: schema.schema.clone(),
+            strict: schema.strict,
+        },
+    }
+}
+
 fn extract_response_text(response: &OpenAIResponsesResponse) -> String {
     let mut combined = String::new();

@@ -603,8 +648,9 @@ mod tests {
     use super::*;
     use crate::adapters::llm::{
         ChatMessage, ChatRequest, ContentBlock as CB, LLMAdapter, LLMRequest, ModelConfig,
-        StopReason, ToolDefinition,
+        StopReason, StructuredOutputSchema, ToolDefinition,
     };
+    use mockito::Matcher;

     fn test_config(base_url: &str) -> ModelConfig {
         ModelConfig {

@@ -625,9 +671,57 @@
             user_prompt: "user".to_string(),
             temperature: None,
             max_tokens: None,
+            response_schema: None,
         }
     }

+    #[tokio::test]
+    async fn test_structured_output_schema_uses_chat_response_format() {
+        let mut server = mockito::Server::new_async().await;
+        let mock = server
+            .mock("POST", "/chat/completions")
+            .match_body(Matcher::PartialJsonString(
+                serde_json::json!({
+                    "response_format": {
+                        "type": "json_schema",
+                        "json_schema": {
+                            "name": "review_findings",
+                            "strict": true
+                        }
+                    }
+                })
+                .to_string(),
+            ))
+            .with_status(200)
+            .with_header("content-type", "application/json")
+            .with_body(
+                r#"{
+                    "choices": [{"message": {"role": "assistant", "content": "[]"}}],
+                    "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
+                    "model": "gpt-4o"
+                }"#,
+            )
+            .create_async()
+            .await;
+
+        let adapter = OpenAIAdapter::new(test_config(&server.url())).unwrap();
+        let result = adapter
+            .complete(LLMRequest {
+                system_prompt: "system".to_string(),
+                user_prompt: "user".to_string(),
+                temperature: None,
+                max_tokens: None,
+                response_schema: Some(StructuredOutputSchema::json_schema(
+                    "review_findings",
+                    serde_json::json!({"type": "array"}),
+                )),
+            })
+            .await;
+
+        assert!(result.is_ok());
+        mock.assert_async().await;
+    }
+
     #[tokio::test]
     async fn test_successful_completion() {
         let mut server = mockito::Server::new_async().await;

@@ -1103,6 +1197,7 @@ mod tests {
             user_prompt: "u".to_string(),
             temperature: Some(0.8),
             max_tokens: Some(500),
+            response_schema: None,
         })
         .await;

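For reference, the mock in the test above matches the request body only partially. Given the serialization added in this file, the schema from the test should produce a response_format member shaped roughly as sketched below (OpenAI's Chat Completions structured-output format; surrounding request fields elided). Note that supports_native_response_schema gates this path to api.openai.com and local endpoints; for other base URLs the adapter drops the schema and falls back to the plain completion path.

// Sketch (not in this commit): the wire shape that OpenAIResponseFormat and
// OpenAIJsonSchemaFormat serialize to, for the schema used in the test above.
fn main() {
    let expected = serde_json::json!({
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "review_findings",
                "schema": {"type": "array"},
                "strict": true
            }
        }
    });
    println!("{}", serde_json::to_string_pretty(&expected).unwrap());
}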
src/commands/git.rs

Lines changed: 2 additions & 0 deletions
@@ -85,6 +85,7 @@ async fn suggest_commit_message(config: config::Config) -> Result<()> {
         user_prompt,
         temperature: Some(0.3),
         max_tokens: Some(500),
+        response_schema: None,
     };

     let response = adapter.complete(request).await?;

@@ -128,6 +129,7 @@ async fn suggest_pr_title(config: config::Config) -> Result<()> {
         user_prompt,
         temperature: Some(0.3),
         max_tokens: Some(200),
+        response_schema: None,
     };

     let response = adapter.complete(request).await?;

src/commands/misc.rs

Lines changed: 2 additions & 3 deletions
@@ -211,9 +211,7 @@ pub async fn feedback_command(
     );

     let is_accepted = action == "accept";
-    for comment in &comments {
-        let _ = review::record_semantic_feedback_example(&config, comment, is_accepted).await;
-    }
+    let _ = review::record_semantic_feedback_examples(&config, &comments, is_accepted).await;

     // Also record in the convention store for learned suppression/boost patterns
     let convention_path = resolve_convention_store_path_for_feedback(&config);

@@ -497,6 +495,7 @@ async fn answer_discussion_question(
         user_prompt: prompt,
         temperature: Some(0.2),
         max_tokens: Some(1200),
+        response_schema: None,
     };

     let response = adapter.complete(request).await?;

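The feedback path now records all comments through a single batched call instead of looping at the call site. The helper's definition is not part of this diff; below is a plausible sketch only, with Config and Comment as stand-ins for the real types.

// Assumed shape of the batched helper (its real definition is not in this
// diff). One plausible implementation simply wraps the per-comment call the
// old loop used, so errors can be handled once at the call site.
pub async fn record_semantic_feedback_examples(
    config: &Config,
    comments: &[Comment],
    is_accepted: bool,
) -> anyhow::Result<()> {
    for comment in comments {
        record_semantic_feedback_example(config, comment, is_accepted).await?;
    }
    Ok(())
}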