|
| 1 | +package middleware |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "net/http/httptest" |
| 6 | + "strings" |
| 7 | + |
| 8 | + . "github.com/onsi/ginkgo/v2" |
| 9 | + . "github.com/onsi/gomega" |
| 10 | +) |
| 11 | + |
| 12 | +// The trace middleware copies request and response bodies into an in-memory |
| 13 | +// buffer that backs the admin /api/traces endpoint. With no upper bound a |
| 14 | +// chatty workload (embeddings, large completions) trivially produces a |
| 15 | +// multi-MB response that locks the Traces UI in a loading state — fetching |
| 16 | +// and parsing the payload outruns the 5-second auto-refresh. These specs |
| 17 | +// pin the capping contract so future refactors keep both the cap and the |
| 18 | +// passthrough to the real client intact. |
| 19 | + |
| 20 | +var _ = Describe("bodyWriter capping", func() { |
| 21 | + It("captures the full body when maxBytes is 0 (unlimited)", func() { |
| 22 | + downstream := httptest.NewRecorder() |
| 23 | + buf := &bytes.Buffer{} |
| 24 | + bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 0} |
| 25 | + |
| 26 | + payload := []byte(strings.Repeat("x", 4096)) |
| 27 | + n, err := bw.Write(payload) |
| 28 | + |
| 29 | + Expect(err).ToNot(HaveOccurred()) |
| 30 | + Expect(n).To(Equal(len(payload))) |
| 31 | + Expect(buf.Len()).To(Equal(len(payload))) |
| 32 | + Expect(downstream.Body.Len()).To(Equal(len(payload))) |
| 33 | + Expect(bw.truncated).To(BeFalse()) |
| 34 | + }) |
| 35 | + |
| 36 | + It("stops appending to the trace buffer once maxBytes is reached but still forwards to the client", func() { |
| 37 | + downstream := httptest.NewRecorder() |
| 38 | + buf := &bytes.Buffer{} |
| 39 | + bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 100} |
| 40 | + |
| 41 | + payload := []byte(strings.Repeat("a", 250)) |
| 42 | + n, err := bw.Write(payload) |
| 43 | + |
| 44 | + Expect(err).ToNot(HaveOccurred()) |
| 45 | + Expect(n).To(Equal(len(payload)), "Write must return the full byte count so callers see no short write") |
| 46 | + Expect(buf.Len()).To(Equal(100), "trace buffer should hold exactly maxBytes") |
| 47 | + Expect(downstream.Body.Len()).To(Equal(len(payload)), "client must still receive every byte") |
| 48 | + Expect(bw.truncated).To(BeTrue()) |
| 49 | + }) |
| 50 | + |
| 51 | + It("handles a write that straddles the cap by keeping only the leading slice", func() { |
| 52 | + downstream := httptest.NewRecorder() |
| 53 | + buf := &bytes.Buffer{} |
| 54 | + bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 10} |
| 55 | + |
| 56 | + _, err := bw.Write([]byte("12345")) |
| 57 | + Expect(err).ToNot(HaveOccurred()) |
| 58 | + Expect(bw.truncated).To(BeFalse()) |
| 59 | + |
| 60 | + _, err = bw.Write([]byte("67890ABCDE")) |
| 61 | + Expect(err).ToNot(HaveOccurred()) |
| 62 | + |
| 63 | + Expect(buf.String()).To(Equal("1234567890")) |
| 64 | + Expect(downstream.Body.String()).To(Equal("1234567890ABCDE")) |
| 65 | + Expect(bw.truncated).To(BeTrue()) |
| 66 | + }) |
| 67 | + |
| 68 | + It("ignores further writes after the cap was already hit", func() { |
| 69 | + downstream := httptest.NewRecorder() |
| 70 | + buf := &bytes.Buffer{} |
| 71 | + bw := &bodyWriter{ResponseWriter: downstream, body: buf, maxBytes: 4} |
| 72 | + |
| 73 | + _, _ = bw.Write([]byte("AAAA")) |
| 74 | + _, _ = bw.Write([]byte("BBBB")) |
| 75 | + _, _ = bw.Write([]byte("CCCC")) |
| 76 | + |
| 77 | + Expect(buf.String()).To(Equal("AAAA")) |
| 78 | + Expect(downstream.Body.String()).To(Equal("AAAABBBBCCCC")) |
| 79 | + Expect(bw.truncated).To(BeTrue()) |
| 80 | + }) |
| 81 | +}) |
| 82 | + |
| 83 | +var _ = Describe("truncateForTrace", func() { |
| 84 | + It("returns the input unchanged when below the cap", func() { |
| 85 | + in := []byte("hello") |
| 86 | + out, truncated := truncateForTrace(in, 1024) |
| 87 | + Expect(truncated).To(BeFalse()) |
| 88 | + Expect(out).To(Equal(in)) |
| 89 | + }) |
| 90 | + |
| 91 | + It("truncates when the input exceeds the cap and signals truncation", func() { |
| 92 | + in := []byte(strings.Repeat("z", 200)) |
| 93 | + out, truncated := truncateForTrace(in, 64) |
| 94 | + Expect(truncated).To(BeTrue()) |
| 95 | + Expect(out).To(HaveLen(64)) |
| 96 | + Expect(string(out)).To(Equal(strings.Repeat("z", 64))) |
| 97 | + }) |
| 98 | + |
| 99 | + It("treats maxBytes <= 0 as unlimited (back-compat with current default)", func() { |
| 100 | + in := []byte(strings.Repeat("q", 10_000)) |
| 101 | + out, truncated := truncateForTrace(in, 0) |
| 102 | + Expect(truncated).To(BeFalse()) |
| 103 | + Expect(out).To(HaveLen(len(in))) |
| 104 | + }) |
| 105 | + |
| 106 | + It("does not retain the caller's backing array (defensive copy)", func() { |
| 107 | + in := []byte("abcdefghij") |
| 108 | + out, truncated := truncateForTrace(in, 4) |
| 109 | + Expect(truncated).To(BeTrue()) |
| 110 | + Expect(string(out)).To(Equal("abcd")) |
| 111 | + |
| 112 | + // Mutating the source must not corrupt the trace copy. |
| 113 | + in[0] = 'Z' |
| 114 | + Expect(string(out)).To(Equal("abcd")) |
| 115 | + }) |
| 116 | +}) |
0 commit comments