Skip to content

Commit ae32ade

Browse files
haasonsaasclaude
andcommitted
Harden server: atomic writes, input validation, CORS restriction, review pruning
- Serialize reviews under read lock for consistent snapshots - Atomic file writes (tmp + rename) for reviews and config persistence - Validate diff_source, branch names, and feedback actions - Restrict CORS to localhost origins instead of wildcard - Return 404 for unmatched /api/ routes in SPA fallback - Add review pruning (MAX_REVIEWS=200) to prevent unbounded growth - Check git diff exit status and report stderr on failure - Saturating arithmetic for pagination to prevent overflow - Add Ollama sidecar security context in Helm chart Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 993d088 commit ae32ade

File tree

4 files changed

+158
-92
lines changed

4 files changed

+158
-92
lines changed

charts/diffscope/templates/deployment.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ spec:
8787
{{- range .Values.diffscope.extraArgs }}
8888
- {{ . | quote }}
8989
{{- end }}
90+
{{- if .Values.gitRepo.enabled }}
9091
workingDir: {{ .Values.gitRepo.mountPath }}
92+
{{- end }}
9193
ports:
9294
- name: http
9395
containerPort: {{ .Values.diffscope.port }}
@@ -153,6 +155,8 @@ spec:
153155
{{- end }}
154156
{{- if and .Values.ollama.enabled (eq .Values.ollama.mode "sidecar") }}
155157
- name: ollama
158+
securityContext:
159+
allowPrivilegeEscalation: false
156160
image: "{{ .Values.ollama.image.repository }}:{{ .Values.ollama.image.tag }}"
157161
imagePullPolicy: {{ .Values.ollama.image.pullPolicy }}
158162
ports:

src/server/api.rs

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::core::comment::CommentSynthesizer;
1414

1515
#[derive(Deserialize)]
1616
pub struct StartReviewRequest {
17-
pub diff_source: String, // "head", "staged", "branch"
17+
pub diff_source: String,
1818
pub base_branch: Option<String>,
1919
}
2020

@@ -37,7 +37,7 @@ pub struct StatusResponse {
3737
#[derive(Deserialize)]
3838
pub struct FeedbackRequest {
3939
pub comment_id: String,
40-
pub action: String, // "accept" or "reject"
40+
pub action: String,
4141
}
4242

4343
#[derive(Serialize)]
@@ -57,7 +57,6 @@ pub async fn get_status(State(state): State<Arc<AppState>>) -> Json<StatusRespon
5757
let config = state.config.read().await;
5858
let reviews = state.reviews.read().await;
5959

60-
// Try to get current branch via git2
6160
let branch = git2::Repository::discover(&state.repo_path)
6261
.ok()
6362
.and_then(|repo| {
@@ -83,18 +82,28 @@ pub async fn start_review(
8382
State(state): State<Arc<AppState>>,
8483
Json(request): Json<StartReviewRequest>,
8584
) -> Result<Json<StartReviewResponse>, (StatusCode, String)> {
86-
let id = Uuid::new_v4().to_string();
85+
// Validate diff_source
86+
let diff_source = match request.diff_source.as_str() {
87+
"head" | "staged" | "branch" => request.diff_source.clone(),
88+
_ => return Err((StatusCode::BAD_REQUEST, "Invalid diff_source: must be head, staged, or branch".to_string())),
89+
};
8790

88-
let now = std::time::SystemTime::now()
89-
.duration_since(std::time::UNIX_EPOCH)
90-
.unwrap_or_default()
91-
.as_secs() as i64;
91+
// Validate branch name if provided
92+
if let Some(ref branch) = request.base_branch {
93+
if branch.is_empty() || branch.len() > 200
94+
|| !branch.chars().all(|c| c.is_alphanumeric() || matches!(c, '/' | '-' | '_' | '.'))
95+
{
96+
return Err((StatusCode::BAD_REQUEST, "Invalid branch name".to_string()));
97+
}
98+
}
99+
100+
let id = Uuid::new_v4().to_string();
92101

93102
let session = ReviewSession {
94103
id: id.clone(),
95104
status: ReviewStatus::Pending,
96-
diff_source: request.diff_source.clone(),
97-
started_at: now,
105+
diff_source: diff_source.clone(),
106+
started_at: current_timestamp(),
98107
completed_at: None,
99108
comments: Vec::new(),
100109
summary: None,
@@ -105,10 +114,8 @@ pub async fn start_review(
105114

106115
state.reviews.write().await.insert(id.clone(), session);
107116

108-
// Spawn the review task
109117
let state_clone = state.clone();
110118
let review_id = id.clone();
111-
let diff_source = request.diff_source.clone();
112119
let base_branch = request.base_branch.clone();
113120

114121
tokio::spawn(async move {
@@ -145,30 +152,14 @@ async fn run_review_task(
145152
let config = state.config.read().await.clone();
146153
let repo_path = state.repo_path.clone();
147154

148-
// Validate branch name to prevent injection
149-
if let Some(ref branch) = base_branch {
150-
if !branch.chars().all(|c| c.is_alphanumeric() || matches!(c, '/' | '-' | '_' | '.')) {
151-
let mut reviews = state.reviews.write().await;
152-
if let Some(session) = reviews.get_mut(&review_id) {
153-
session.status = ReviewStatus::Failed;
154-
session.error = Some("Invalid branch name".to_string());
155-
session.completed_at = Some(current_timestamp());
156-
}
157-
return;
158-
}
159-
}
160-
161155
// Get the diff content based on source
162156
let diff_result = match diff_source.as_str() {
163157
"staged" => get_diff_from_git(&repo_path, "staged", None),
164158
"branch" => {
165159
let base = base_branch.as_deref().unwrap_or("main");
166160
get_diff_from_git(&repo_path, "branch", Some(base))
167161
}
168-
_ => {
169-
// "head" or default
170-
get_diff_from_git(&repo_path, "head", None)
171-
}
162+
_ => get_diff_from_git(&repo_path, "head", None),
172163
};
173164

174165
let diff_content = match diff_result {
@@ -242,7 +233,6 @@ async fn run_review_task(
242233
}
243234
}
244235
Err(_) => {
245-
// Timeout
246236
let mut reviews = state.reviews.write().await;
247237
if let Some(session) = reviews.get_mut(&review_id) {
248238
session.status = ReviewStatus::Failed;
@@ -253,6 +243,7 @@ async fn run_review_task(
253243
}
254244

255245
AppState::save_reviews_async(&state);
246+
AppState::prune_old_reviews(&state).await;
256247
}
257248

258249
fn get_diff_from_git(
@@ -274,15 +265,17 @@ fn get_diff_from_git(
274265
.current_dir(repo_path)
275266
.output()?
276267
}
277-
_ => {
278-
// head
279-
Command::new("git")
280-
.args(["diff", "HEAD~1"])
281-
.current_dir(repo_path)
282-
.output()?
283-
}
268+
_ => Command::new("git")
269+
.args(["diff", "HEAD~1"])
270+
.current_dir(repo_path)
271+
.output()?,
284272
};
285273

274+
if !output.status.success() {
275+
let stderr = String::from_utf8_lossy(&output.stderr);
276+
anyhow::bail!("git diff failed: {}", stderr.trim());
277+
}
278+
286279
Ok(String::from_utf8_lossy(&output.stdout).to_string())
287280
}
288281

@@ -307,13 +300,12 @@ pub async fn list_reviews(
307300
.values()
308301
.map(|r| {
309302
let mut r = r.clone();
310-
r.diff_content = None; // strip bulk data from list
303+
r.diff_content = None;
311304
r
312305
})
313306
.collect();
314307
list.sort_by(|a, b| b.started_at.cmp(&a.started_at));
315308

316-
// Apply pagination
317309
let page = params.page.unwrap_or(1).max(1).min(10_000);
318310
let per_page = params.per_page.unwrap_or(20).max(1).min(100);
319311
let start = (page - 1).saturating_mul(per_page);
@@ -332,10 +324,14 @@ pub async fn submit_feedback(
332324
Path(id): Path<String>,
333325
Json(request): Json<FeedbackRequest>,
334326
) -> Result<Json<FeedbackResponse>, StatusCode> {
327+
// Validate action
328+
if request.action != "accept" && request.action != "reject" {
329+
return Err(StatusCode::BAD_REQUEST);
330+
}
331+
335332
let mut reviews = state.reviews.write().await;
336333
let session = reviews.get_mut(&id).ok_or(StatusCode::NOT_FOUND)?;
337334

338-
// Find the comment and store the feedback action
339335
let comment = session
340336
.comments
341337
.iter_mut()
@@ -432,7 +428,6 @@ pub async fn get_doctor(State(state): State<Arc<AppState>>) -> Json<serde_json::
432428
pub async fn get_config(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
433429
let config = state.config.read().await;
434430
let mut value = serde_json::to_value(&*config).unwrap_or_default();
435-
// Redact API key
436431
if let Some(obj) = value.as_object_mut() {
437432
if obj.contains_key("api_key") {
438433
obj.insert("api_key".to_string(), serde_json::json!("***"));
@@ -447,12 +442,11 @@ pub async fn update_config(
447442
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
448443
let mut config = state.config.write().await;
449444

450-
// Merge updates into current config
451445
let mut current = serde_json::to_value(&*config).unwrap_or_default();
452446
if let (Some(current_obj), Some(updates_obj)) = (current.as_object_mut(), updates.as_object()) {
453447
for (key, value) in updates_obj {
454448
if key == "api_key" && value.as_str() == Some("***") {
455-
continue; // Don't overwrite with redacted value
449+
continue;
456450
}
457451
current_obj.insert(key.clone(), value.clone());
458452
}
@@ -463,8 +457,13 @@ pub async fn update_config(
463457

464458
*config = new_config;
465459
config.normalize();
460+
drop(config);
461+
462+
// Persist config to disk
463+
AppState::save_config_async(&state);
466464

467465
// Return updated config (redacted)
466+
let config = state.config.read().await;
468467
let mut result = serde_json::to_value(&*config).unwrap_or_default();
469468
if let Some(obj) = result.as_object_mut() {
470469
if obj.contains_key("api_key") {

src/server/mod.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use axum::{
77
response::{IntoResponse, Response},
88
http::{StatusCode, header},
99
};
10-
use tower_http::cors::{CorsLayer, Any};
10+
use tower_http::cors::CorsLayer;
1111
use rust_embed::Embed;
1212
use std::net::SocketAddr;
1313
use std::sync::Arc;
@@ -61,10 +61,19 @@ async fn serve_embedded(uri: axum::http::Uri) -> Response {
6161
pub async fn start_server(config: Config, host: &str, port: u16) -> anyhow::Result<()> {
6262
let state = Arc::new(state::AppState::new(config)?);
6363

64+
let allowed_origins: Vec<axum::http::HeaderValue> = [
65+
format!("http://localhost:{}", port),
66+
format!("http://127.0.0.1:{}", port),
67+
"http://localhost:5173".to_string(),
68+
]
69+
.iter()
70+
.filter_map(|s| s.parse().ok())
71+
.collect();
72+
6473
let cors = CorsLayer::new()
65-
.allow_origin(Any)
66-
.allow_methods(Any)
67-
.allow_headers(Any);
74+
.allow_origin(allowed_origins)
75+
.allow_methods(tower_http::cors::Any)
76+
.allow_headers(tower_http::cors::Any);
6877

6978
let api_routes = Router::new()
7079
.route("/status", get(api::get_status))

0 commit comments

Comments
 (0)