Skip to content

Commit 1d53ad0

Browse files
authored
chore: create proxy instance once per provider (#267)
1 parent 8bdd397 commit 1d53ad0

3 files changed

Lines changed: 324 additions & 106 deletions

File tree

internal/testutil/mockprovider.go

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,26 @@ import (
1111
)
1212

1313
type MockProvider struct {
14-
NameStr string
15-
URL string
16-
Bridged []string
17-
Passthrough []string
18-
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
14+
NameStr string
15+
URL string
16+
Bridged []string
17+
Passthrough []string
18+
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
19+
InjectAuthHeaderFunc func(h *http.Header)
1920
}
2021

21-
func (m *MockProvider) Type() string { return m.NameStr }
22-
func (m *MockProvider) Name() string { return m.NameStr }
23-
func (m *MockProvider) BaseURL() string { return m.URL }
24-
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
25-
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
26-
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
27-
func (*MockProvider) AuthHeader() string { return "Authorization" }
28-
func (*MockProvider) InjectAuthHeader(_ *http.Header) {}
22+
func (m *MockProvider) Type() string { return m.NameStr }
23+
func (m *MockProvider) Name() string { return m.NameStr }
24+
func (m *MockProvider) BaseURL() string { return m.URL }
25+
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
26+
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
27+
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
28+
func (*MockProvider) AuthHeader() string { return "Authorization" }
29+
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
30+
if m.InjectAuthHeaderFunc != nil {
31+
m.InjectAuthHeaderFunc(h)
32+
}
33+
}
2934
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
3035
func (*MockProvider) APIDumpDir() string { return "" }
3136
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {

passthrough.go

Lines changed: 80 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package aibridge
22

33
import (
4-
"net"
4+
"context"
55
"net/http"
66
"net/http/httputil"
77
"net/url"
@@ -21,100 +21,97 @@ import (
2121

2222
// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
2323
// by a [intercept.Provider].
24+
// A single reverse proxy is created per provider and reused across all requests.
2425
func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
26+
provBaseURL, err := url.Parse(prov.BaseURL())
27+
if err != nil {
28+
return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
29+
}
30+
if _, err := url.JoinPath(provBaseURL.Path, "/"); err != nil {
31+
return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
32+
}
33+
34+
// Transport tuned for streaming (no response header timeout).
35+
t := &http.Transport{
36+
Proxy: http.ProxyFromEnvironment,
37+
ForceAttemptHTTP2: true,
38+
MaxIdleConns: 100,
39+
IdleConnTimeout: 90 * time.Second,
40+
TLSHandshakeTimeout: 10 * time.Second,
41+
ExpectContinueTimeout: 1 * time.Second,
42+
}
43+
44+
// Build a reverse proxy to the upstream, reused across all requests for this provider.
45+
// All request modifications happen in Rewrite.
46+
proxy := &httputil.ReverseProxy{
47+
Rewrite: func(pr *httputil.ProxyRequest) {
48+
rewritePassthroughRequest(pr, provBaseURL, prov)
49+
},
50+
Transport: apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
51+
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
52+
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
53+
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
54+
},
55+
}
56+
2557
return func(w http.ResponseWriter, r *http.Request) {
2658
if m != nil {
2759
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
2860
}
2961

30-
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
31-
attribute.String(tracing.PassthroughURL, r.URL.String()),
32-
attribute.String(tracing.PassthroughMethod, r.Method),
33-
))
62+
ctx, span := startSpan(r, tracer)
3463
defer span.End()
3564

36-
upURL, err := url.Parse(prov.BaseURL())
37-
if err != nil {
38-
logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err))
39-
http.Error(w, "request error", http.StatusBadGateway)
40-
span.SetStatus(codes.Error, "failed to parse provider base URL: "+err.Error())
41-
return
42-
}
65+
proxy.ServeHTTP(w, r.WithContext(ctx))
66+
}
67+
}
4368

44-
// Append the request path to the upstream base path.
45-
reqPath, err := url.JoinPath(upURL.Path, r.URL.Path)
46-
if err != nil {
47-
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path))
48-
http.Error(w, "failed to join upstream path", http.StatusInternalServerError)
49-
span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error())
50-
return
51-
}
52-
// Ensure leading slash, proxied requests should have absolute paths.
53-
// JoinPath can return relative paths, eg. when upURL path is empty.
54-
if len(reqPath) == 0 || reqPath[0] != '/' {
55-
reqPath = "/" + reqPath
56-
}
69+
// rewritePassthroughRequest configures the outbound request for the upstream and
70+
// applies proxy headers and provider auth.
71+
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) {
72+
pr.SetURL(provBaseURL)
5773

58-
// Build a reverse proxy to the upstream.
59-
proxy := &httputil.ReverseProxy{
60-
Director: func(req *http.Request) {
61-
// Set scheme/host to upstream.
62-
req.URL.Scheme = upURL.Scheme
63-
req.URL.Host = upURL.Host
64-
req.URL.Path = reqPath
65-
req.URL.RawPath = ""
66-
67-
// Preserve query string.
68-
req.URL.RawQuery = r.URL.RawQuery
69-
70-
// Set Host header for upstream.
71-
req.Host = upURL.Host
72-
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String()))
73-
74-
// Copy headers from client.
75-
req.Header = r.Header.Clone()
76-
77-
// Standard proxy headers.
78-
host, _, herr := net.SplitHostPort(r.RemoteAddr)
79-
if herr != nil {
80-
host = r.RemoteAddr
81-
}
82-
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
83-
req.Header.Set("X-Forwarded-For", prior+", "+host)
84-
} else {
85-
req.Header.Set("X-Forwarded-For", host)
86-
}
87-
req.Header.Set("X-Forwarded-Host", r.Host)
88-
if r.TLS != nil {
89-
req.Header.Set("X-Forwarded-Proto", "https")
90-
} else {
91-
req.Header.Set("X-Forwarded-Proto", "http")
92-
}
93-
// Avoid default Go user-agent if none provided.
94-
if _, ok := req.Header["User-Agent"]; !ok {
95-
req.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
96-
}
97-
98-
// Inject provider auth.
99-
prov.InjectAuthHeader(&req.Header)
100-
},
101-
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
102-
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
103-
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
104-
},
105-
}
74+
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
75+
// To preserve old Director behavior pr.In "X-Forwarded-For" header
76+
// values need to be copied manually.
77+
// https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded
78+
if prior, ok := pr.In.Header["X-Forwarded-For"]; ok {
79+
pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...)
80+
}
81+
pr.SetXForwarded()
82+
83+
span := trace.SpanFromContext(pr.Out.Context())
84+
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, pr.Out.URL.String()))
85+
86+
// Avoid default Go user-agent if none provided.
87+
if _, ok := pr.Out.Header["User-Agent"]; !ok {
88+
pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
89+
}
10690

107-
// Transport tuned for streaming (no response header timeout).
108-
t := &http.Transport{
109-
Proxy: http.ProxyFromEnvironment,
110-
ForceAttemptHTTP2: true,
111-
MaxIdleConns: 100,
112-
IdleConnTimeout: 90 * time.Second,
113-
TLSHandshakeTimeout: 10 * time.Second,
114-
ExpectContinueTimeout: 1 * time.Second,
91+
// Inject provider auth.
92+
prov.InjectAuthHeader(&pr.Out.Header)
93+
}
94+
95+
// newInvalidBaseURLHandler returns a handler that always returns 502
96+
// when the provider's base URL is invalid.
97+
func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc {
98+
return func(w http.ResponseWriter, r *http.Request) {
99+
ctx, span := startSpan(r, tracer)
100+
defer span.End()
101+
102+
if m != nil {
103+
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
115104
}
116-
proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal())
117105

118-
proxy.ServeHTTP(w, r)
106+
logger.Warn(ctx, "invalid provider base URL", slog.Error(baseURLErr))
107+
http.Error(w, "invalid provider base URL", http.StatusBadGateway)
108+
span.SetStatus(codes.Error, "invalid provider base URL: "+baseURLErr.Error())
119109
}
120110
}
111+
112+
func startSpan(r *http.Request, tracer trace.Tracer) (context.Context, trace.Span) {
113+
return tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
114+
attribute.String(tracing.PassthroughURL, r.URL.String()),
115+
attribute.String(tracing.PassthroughMethod, r.Method),
116+
))
117+
}

0 commit comments

Comments
 (0)