diff --git a/apps/api-go/cmd/platform-api/main.go b/apps/api-go/cmd/platform-api/main.go index c8086e5f..00f3ebed 100644 --- a/apps/api-go/cmd/platform-api/main.go +++ b/apps/api-go/cmd/platform-api/main.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "diffaudit/platform-api-go/internal/proxy" ) @@ -25,6 +26,12 @@ type runtimeConfig struct { const ( defaultHost = "127.0.0.1" defaultPort = "8780" + + defaultReadHeaderTimeout = 5 * time.Second + defaultReadTimeout = 15 * time.Second + defaultWriteTimeout = 30 * time.Second + defaultIdleTimeout = 60 * time.Second + defaultMaxHeaderBytes = 1 << 20 ) func parseConfig(args []string) (runtimeConfig, error) { @@ -113,7 +120,7 @@ func main() { } } - server := proxy.NewServer(proxy.Config{ + gateway := proxy.NewServer(proxy.Config{ PublicDataDir: config.PublicDataDir, RuntimeBaseURL: config.RuntimeBaseURL, BuildRevision: config.BuildRevision, @@ -126,17 +133,30 @@ func main() { }, }) - handler := server.Handler() - handler = proxy.CORSMiddleware(server.GetConfig().CORS)(handler) + handler := gateway.Handler() + handler = proxy.CORSMiddleware(gateway.GetConfig().CORS)(handler) handler = proxy.NewStructuredLogger()(handler) address := fmt.Sprintf("%s:%s", config.Host, config.Port) - if err := http.ListenAndServe(address, handler); err != nil { + server := newHTTPServer(address, handler) + if err := server.ListenAndServe(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } +func newHTTPServer(address string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: address, + Handler: handler, + ReadHeaderTimeout: defaultReadHeaderTimeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + IdleTimeout: defaultIdleTimeout, + MaxHeaderBytes: defaultMaxHeaderBytes, + } +} + func envOrDefault(fallback string, names ...string) string { for _, name := range names { if value := os.Getenv(name); value != "" { diff --git a/apps/api-go/cmd/platform-api/main_test.go b/apps/api-go/cmd/platform-api/main_test.go index 5af193e7..b8de6a87 100644 --- a/apps/api-go/cmd/platform-api/main_test.go +++ b/apps/api-go/cmd/platform-api/main_test.go @@ -1,6 +1,9 @@ package main -import "testing" +import ( + "net/http" + "testing" +) func TestParseConfigUsesDefaults(t *testing.T) { config, err := parseConfig([]string{}) @@ -82,3 +85,30 @@ func TestParseConfigAcceptsLegacyResearchAPIFlag(t *testing.T) { t.Fatalf("expected legacy alias to override runtime upstream, got %s", config.RuntimeBaseURL) } } + +func TestNewHTTPServerUsesExplicitResourceLimits(t *testing.T) { + handler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) + server := newHTTPServer("127.0.0.1:8780", handler) + + if server.Addr != "127.0.0.1:8780" { + t.Fatalf("unexpected address: %s", server.Addr) + } + if server.Handler == nil { + t.Fatal("expected handler to be configured") + } + if server.ReadHeaderTimeout != defaultReadHeaderTimeout { + t.Fatalf("unexpected ReadHeaderTimeout: %v", server.ReadHeaderTimeout) + } + if server.ReadTimeout != defaultReadTimeout { + t.Fatalf("unexpected ReadTimeout: %v", server.ReadTimeout) + } + if server.WriteTimeout != defaultWriteTimeout { + t.Fatalf("unexpected WriteTimeout: %v", server.WriteTimeout) + } + if server.IdleTimeout != defaultIdleTimeout { + t.Fatalf("unexpected IdleTimeout: %v", server.IdleTimeout) + } + if server.MaxHeaderBytes != defaultMaxHeaderBytes { + t.Fatalf("unexpected MaxHeaderBytes: %d", server.MaxHeaderBytes) + } +} diff --git a/apps/api-go/internal/proxy/server.go b/apps/api-go/internal/proxy/server.go index 784b0bcc..9bef229c 100644 --- a/apps/api-go/internal/proxy/server.go +++ b/apps/api-go/internal/proxy/server.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "encoding/json" "errors" "io" @@ -17,6 +18,9 @@ const ( defaultRuntimeTimeout = 15000 * time.Millisecond maxRetries = 3 retryDelay = 1 * time.Second + + maxAuditControlRequestBodyBytes = 1 << 20 + maxRuntimeResponseBodyBytes = 8 << 20 ) type Config struct { @@ -47,9 +51,9 @@ type Server struct { func NewServer(config Config) *Server { mux := http.NewServeMux() server := &Server{ - config: config, - mux: mux, - client: &http.Client{ + config: config, + mux: mux, + client: &http.Client{ Timeout: config.timeout(), }, cacheDir: config.PublicDataDir, @@ -212,8 +216,12 @@ func (s *Server) handleControlGet(writer http.ResponseWriter, request *http.Requ } func (s *Server) handleControlPost(writer http.ResponseWriter, request *http.Request) { - body, err := io.ReadAll(request.Body) + body, err := readBoundedRequestBody(writer, request, maxAuditControlRequestBodyBytes) if err != nil { + if errors.Is(err, errRequestBodyTooLarge) { + writeJSON(writer, http.StatusRequestEntityTooLarge, map[string]any{"detail": "request body too large"}) + return + } writePublicGatewayError(writer, "request body unavailable") return } @@ -348,7 +356,7 @@ func (s *Server) forwardControl(writer http.ResponseWriter, request *http.Reques if query := request.URL.RawQuery; query != "" { upstreamURL = upstreamURL + "?" + query } - upstreamRequest, err := http.NewRequest(request.Method, upstreamURL, strings.NewReader(string(body))) + upstreamRequest, err := http.NewRequest(request.Method, upstreamURL, bytes.NewReader(body)) if err != nil { writePublicGatewayError(writer, "runtime proxy request is misconfigured") return @@ -366,8 +374,12 @@ func (s *Server) forwardControl(writer http.ResponseWriter, request *http.Reques return } defer response.Body.Close() - responseBody, err := io.ReadAll(response.Body) + responseBody, err := readBoundedRuntimeResponseBody(response.Body) if err != nil { + if errors.Is(err, errRuntimeResponseTooLarge) { + writePublicGatewayError(writer, "runtime response too large") + return + } writePublicGatewayError(writer, "runtime response unavailable") return } @@ -405,8 +417,12 @@ func (s *Server) forwardControlWithMethod(writer http.ResponseWriter, request *h return } defer response.Body.Close() - responseBody, err := io.ReadAll(response.Body) + responseBody, err := readBoundedRuntimeResponseBody(response.Body) if err != nil { + if errors.Is(err, errRuntimeResponseTooLarge) { + writePublicGatewayError(writer, "runtime response too large") + return + } writePublicGatewayError(writer, "runtime response unavailable") return } @@ -425,7 +441,39 @@ func writePublicGatewayError(writer http.ResponseWriter, detail string) { writeJSON(writer, http.StatusBadGateway, map[string]any{"detail": detail}) } -var errSnapshotUnavailable = errors.New("snapshot unavailable") +var ( + errRequestBodyTooLarge = errors.New("request body too large") + errRuntimeResponseTooLarge = errors.New("runtime response too large") + errSnapshotUnavailable = errors.New("snapshot unavailable") +) + +func readBoundedRequestBody(writer http.ResponseWriter, request *http.Request, maxBytes int64) ([]byte, error) { + if request.ContentLength > maxBytes { + return nil, errRequestBodyTooLarge + } + + request.Body = http.MaxBytesReader(writer, request.Body, maxBytes) + body, err := io.ReadAll(request.Body) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return nil, errRequestBodyTooLarge + } + return nil, err + } + return body, nil +} + +func readBoundedRuntimeResponseBody(reader io.Reader) ([]byte, error) { + body, err := io.ReadAll(io.LimitReader(reader, maxRuntimeResponseBodyBytes+1)) + if err != nil { + return nil, err + } + if int64(len(body)) > maxRuntimeResponseBodyBytes { + return nil, errRuntimeResponseTooLarge + } + return body, nil +} func (s *Server) doWithRetry(request *http.Request, maxAttempts int) (*http.Response, error) { // Only retry safe, idempotent methods (GET, HEAD). diff --git a/apps/api-go/internal/proxy/server_test.go b/apps/api-go/internal/proxy/server_test.go index 31230bae..a0106e20 100644 --- a/apps/api-go/internal/proxy/server_test.go +++ b/apps/api-go/internal/proxy/server_test.go @@ -539,6 +539,33 @@ func TestCreateJobEndpointIsProxied(t *testing.T) { } } +func TestCreateJobRejectsOversizedBodyBeforeProxy(t *testing.T) { + upstreamCalled := false + upstream := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + upstreamCalled = true + writeJSON(writer, http.StatusAccepted, map[string]any{"job_id": "should-not-run"}) + })) + defer upstream.Close() + + server := NewServer(Config{RuntimeBaseURL: upstream.URL}) + body := bytes.Repeat([]byte("a"), maxAuditControlRequestBodyBytes+1) + request := httptest.NewRequest(http.MethodPost, "/api/v1/audit/jobs", bytes.NewReader(body)) + request.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + + server.Handler().ServeHTTP(recorder, request) + + if recorder.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d: %s", recorder.Code, recorder.Body.String()) + } + if upstreamCalled { + t.Fatal("oversized request body was forwarded to runtime") + } + if !strings.Contains(recorder.Body.String(), "request body too large") { + t.Fatalf("expected generic oversized-body detail, got %s", recorder.Body.String()) + } +} + func TestCreateGrayBoxJobEndpointIsProxied(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { if request.Method != http.MethodPost { @@ -943,6 +970,29 @@ func TestBadGatewayResponseIsSafe(t *testing.T) { } } +func TestRuntimeResponseTooLargeIsRejected(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write(bytes.Repeat([]byte("a"), maxRuntimeResponseBodyBytes+1)) + })) + defer upstream.Close() + + server := NewServer(Config{RuntimeBaseURL: upstream.URL}) + request := httptest.NewRequest(http.MethodGet, "/api/v1/audit/jobs", nil) + recorder := httptest.NewRecorder() + + server.Handler().ServeHTTP(recorder, request) + + if recorder.Code != http.StatusBadGateway { + t.Fatalf("expected 502, got %d", recorder.Code) + } + raw := recorder.Body.String() + if !strings.Contains(raw, "runtime response too large") { + t.Fatalf("expected bounded runtime response detail, got %s", raw) + } +} + // ── Retry and error handling ─────────────────────────────────────────────────── func TestRuntimeErrorHint(t *testing.T) { diff --git a/apps/web/src/app/api/v1/audit/jobs/public-facade.test.ts b/apps/web/src/app/api/v1/audit/jobs/public-facade.test.ts index 45cd52e4..817aa90f 100644 --- a/apps/web/src/app/api/v1/audit/jobs/public-facade.test.ts +++ b/apps/web/src/app/api/v1/audit/jobs/public-facade.test.ts @@ -96,6 +96,50 @@ describe("audit job public facade routes", () => { }); }); + it("rejects live audit job creation when content-length exceeds the facade body limit", async () => { + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + const route = await import("./route"); + const response = await route.POST(new Request("http://localhost/api/v1/audit/jobs", { + method: "POST", + headers: { + cookie: "platform-demo-mode=0", + "content-length": String((1 << 20) + 1), + }, + body: "{}", + })); + const payload = await response.json(); + + expect(response.status).toBe(413); + expect(payload).toEqual({ detail: "request body too large" }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("rejects live audit job creation when streamed body exceeds the facade body limit", async () => { + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + const route = await import("./route"); + const response = await route.POST(new Request("http://localhost/api/v1/audit/jobs", { + method: "POST", + headers: { cookie: "platform-demo-mode=0" }, + body: new ReadableStream({ + start(controller) { + controller.enqueue(new Uint8Array(1 << 20)); + controller.enqueue(new Uint8Array(1)); + controller.close(); + }, + }), + duplex: "half", + } as RequestInit)); + const payload = await response.json(); + + expect(response.status).toBe(413); + expect(payload).toEqual({ detail: "request body too large" }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + it("sanitizes live audit job cancellation responses before returning them", async () => { vi.stubGlobal("fetch", vi.fn().mockResolvedValue(Response.json({ ok: true, diff --git a/apps/web/src/app/api/v1/audit/jobs/route.ts b/apps/web/src/app/api/v1/audit/jobs/route.ts index 274ffab3..8f919d52 100644 --- a/apps/web/src/app/api/v1/audit/jobs/route.ts +++ b/apps/web/src/app/api/v1/audit/jobs/route.ts @@ -3,6 +3,8 @@ import { sanitizeAuditJobPayload } from "@/lib/audit-job-payload"; import { isDemoModeEnabledServer } from "@/lib/demo-mode"; import { createDemoJob, listDemoJobs } from "@/lib/demo-jobs-store"; +const MAX_AUDIT_CONTROL_REQUEST_BODY_BYTES = 1 << 20; + export async function GET(request: Request) { if (await isDemoModeEnabledServer(request)) { return Response.json(sanitizeAuditJobPayload({ jobs: listDemoJobs() })); @@ -15,8 +17,13 @@ export async function GET(request: Request) { } export async function POST(request: Request) { + const bodyResult = await readAuditControlBody(request); + if (!bodyResult.ok) { + return bodyResult.response; + } + if (await isDemoModeEnabledServer(request)) { - const payload = (await request.json().catch(() => null)) as Record | null; + const payload = parseJsonObject(bodyResult.body); const job = createDemoJob({ contract_key: typeof payload?.contract_key === "string" ? payload.contract_key : undefined, workspace_name: typeof payload?.workspace_name === "string" ? payload.workspace_name : undefined, @@ -35,8 +42,83 @@ export async function POST(request: Request) { "/api/v1/audit/jobs", { method: "POST", - body: await request.text(), + body: bodyResult.body, }, sanitizeAuditJobPayload, ); } + +type BodyReadResult = + | { ok: true; body: string } + | { ok: false; response: Response }; + +async function readAuditControlBody(request: Request): Promise { + if (isContentLengthTooLarge(request.headers.get("content-length"))) { + return oversizedBodyResponse(); + } + + if (!request.body) { + return { ok: true, body: "" }; + } + + const reader = request.body.getReader(); + const chunks: Uint8Array[] = []; + let totalBytes = 0; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + if (!value) continue; + + totalBytes += value.byteLength; + if (totalBytes > MAX_AUDIT_CONTROL_REQUEST_BODY_BYTES) { + await reader.cancel().catch(() => {}); + return oversizedBodyResponse(); + } + + chunks.push(value); + } + } catch { + return { + ok: false, + response: Response.json({ detail: "request body unavailable" }, { status: 400 }), + }; + } + + return { ok: true, body: new TextDecoder().decode(joinChunks(chunks, totalBytes)) }; +} + +function isContentLengthTooLarge(value: string | null) { + if (!value) return false; + const contentLength = Number(value); + return Number.isFinite(contentLength) && contentLength > MAX_AUDIT_CONTROL_REQUEST_BODY_BYTES; +} + +function oversizedBodyResponse(): BodyReadResult { + return { + ok: false, + response: Response.json({ detail: "request body too large" }, { status: 413 }), + }; +} + +function joinChunks(chunks: Uint8Array[], totalBytes: number) { + const result = new Uint8Array(totalBytes); + let offset = 0; + for (const chunk of chunks) { + result.set(chunk, offset); + offset += chunk.byteLength; + } + return result; +} + +function parseJsonObject(body: string): Record | null { + try { + const payload = JSON.parse(body); + return payload && typeof payload === "object" && !Array.isArray(payload) + ? (payload as Record) + : null; + } catch { + return null; + } +}