Skip to content

Commit b81eeff

Browse files
committed
review: TLS test case, comments, span start extraction
1 parent fa31e80 commit b81eeff

2 files changed

Lines changed: 33 additions & 19 deletions

File tree

passthrough.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package aibridge
22

33
import (
4+
"context"
45
"net/http"
56
"net/http/httputil"
67
"net/url"
@@ -58,10 +59,7 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
5859
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
5960
}
6061

61-
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
62-
attribute.String(tracing.PassthroughURL, r.URL.String()),
63-
attribute.String(tracing.PassthroughMethod, r.Method),
64-
))
62+
ctx, span := startSpan(r, tracer)
6563
defer span.End()
6664

6765
proxy.ServeHTTP(w, r.WithContext(ctx))
@@ -73,6 +71,10 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
7371
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) {
7472
pr.SetURL(provBaseURL)
7573

74+
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
75+
// To preserve old Director behavior pr.In "X-Forwarded-For" header
76+
// values need to be copied manually.
77+
// https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded
7678
if prior, ok := pr.In.Header["X-Forwarded-For"]; ok {
7779
pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...)
7880
}
@@ -90,22 +92,26 @@ func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL,
9092
prov.InjectAuthHeader(&pr.Out.Header)
9193
}
9294

93-
// newInvalidBaseURLHandler returns a handler that always returns 502 because
94-
// the provider's base URL is invalid.
95+
// newInvalidBaseURLHandler returns a handler that always returns 502
96+
// when the provider's base URL is invalid.
9597
func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc {
9698
return func(w http.ResponseWriter, r *http.Request) {
99+
ctx, span := startSpan(r, tracer)
100+
defer span.End()
101+
97102
if m != nil {
98103
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
99104
}
100105

101-
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
102-
attribute.String(tracing.PassthroughURL, r.URL.String()),
103-
attribute.String(tracing.PassthroughMethod, r.Method),
104-
))
105-
defer span.End()
106-
107106
logger.Warn(ctx, "invalid provider base URL", slog.Error(baseURLErr))
108107
http.Error(w, "invalid provider base URL", http.StatusBadGateway)
109108
span.SetStatus(codes.Error, "invalid provider base URL: "+baseURLErr.Error())
110109
}
111110
}
111+
112+
func startSpan(r *http.Request, tracer trace.Tracer) (context.Context, trace.Span) {
113+
return tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
114+
attribute.String(tracing.PassthroughURL, r.URL.String()),
115+
attribute.String(tracing.PassthroughMethod, r.Method),
116+
))
117+
}

passthrough_test.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
140140
reqRemoteAddr string
141141
reqHeaders http.Header
142142
reqTLS bool
143-
baseURL string
144143
provider *testutil.MockProvider
145144
expectURL string
146145
expectHeaders http.Header
@@ -149,7 +148,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
149148
name: "sets_upstream_url_and_forwarded_headers_from_client_peer",
150149
reqPath: "http://client-host/chat?stream=true",
151150
reqRemoteAddr: "1.1.1.1:1111",
152-
baseURL: "https://upstream-host/base",
153151
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
154152
expectURL: "https://upstream-host/base/chat?stream=true",
155153
expectHeaders: http.Header{
@@ -164,7 +162,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
164162
reqPath: "http://client-host/chat",
165163
reqRemoteAddr: "1.1.1.1:1111",
166164
reqHeaders: http.Header{"User-Agent": {"custom-agent/1.0"}},
167-
baseURL: "https://upstream-host/base",
168165
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
169166
expectURL: "https://upstream-host/base/chat",
170167
expectHeaders: http.Header{
@@ -178,7 +175,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
178175
name: "injects_auth_header",
179176
reqPath: "http://client-host/chat",
180177
reqRemoteAddr: "1.1.1.1:1111",
181-
baseURL: "https://upstream-host/base",
182178
provider: &testutil.MockProvider{
183179
URL: "https://upstream-host/base",
184180
InjectAuthHeaderFunc: func(h *http.Header) {
@@ -201,7 +197,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
201197
reqHeaders: http.Header{
202198
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3"},
203199
},
204-
baseURL: "https://upstream-host/base",
205200
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
206201
expectURL: "https://upstream-host/base/chat",
207202
expectHeaders: http.Header{
@@ -211,14 +206,27 @@ func TestRewritePassthroughRequest(t *testing.T) {
211206
"User-Agent": {"aibridge"},
212207
},
213208
},
209+
{
210+
name: "tls_request_sets_forwarded_proto_to_https",
211+
reqPath: "http://client-host/chat",
212+
reqRemoteAddr: "1.1.1.1:1111",
213+
reqTLS: true,
214+
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
215+
expectURL: "https://upstream-host/base/chat",
216+
expectHeaders: http.Header{
217+
"X-Forwarded-Host": {"client-host"},
218+
"X-Forwarded-Proto": {"https"},
219+
"X-Forwarded-For": {"1.1.1.1"},
220+
"User-Agent": {"aibridge"},
221+
},
222+
},
214223
{
215224
name: "omits_forwarded_for_when_remote_addr_is_not_parseable",
216225
reqPath: "http://client-host/chat",
217226
reqRemoteAddr: "not-a-socket-address",
218227
reqHeaders: http.Header{
219228
"X-Forwarded-For": {"1.1.1.1"},
220229
},
221-
baseURL: "https://upstream-host/base",
222230
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
223231
expectURL: "https://upstream-host/base/chat",
224232
expectHeaders: http.Header{
@@ -239,7 +247,7 @@ func TestRewritePassthroughRequest(t *testing.T) {
239247
if tc.reqTLS {
240248
r.TLS = &tls.ConnectionState{}
241249
}
242-
provBaseURL, err := url.Parse(tc.baseURL)
250+
provBaseURL, err := url.Parse(tc.provider.URL)
243251
assert.NoError(t, err)
244252

245253
pr := &httputil.ProxyRequest{

0 commit comments

Comments
 (0)