11package aibridge
22
33import (
4- "net "
4+ "context "
55 "net/http"
66 "net/http/httputil"
77 "net/url"
@@ -21,100 +21,97 @@ import (
2121
2222// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
2323// by a [intercept.Provider].
24+ // A single reverse proxy is created per provider and reused across all requests.
2425func newPassthroughRouter (prov provider.Provider , logger slog.Logger , m * metrics.Metrics , tracer trace.Tracer ) http.HandlerFunc {
26+ provBaseURL , err := url .Parse (prov .BaseURL ())
27+ if err != nil {
28+ return newInvalidBaseURLHandler (prov , logger , m , tracer , err )
29+ }
30+ if _ , err := url .JoinPath (provBaseURL .Path , "/" ); err != nil {
31+ return newInvalidBaseURLHandler (prov , logger , m , tracer , err )
32+ }
33+
34+ // Transport tuned for streaming (no response header timeout).
35+ t := & http.Transport {
36+ Proxy : http .ProxyFromEnvironment ,
37+ ForceAttemptHTTP2 : true ,
38+ MaxIdleConns : 100 ,
39+ IdleConnTimeout : 90 * time .Second ,
40+ TLSHandshakeTimeout : 10 * time .Second ,
41+ ExpectContinueTimeout : 1 * time .Second ,
42+ }
43+
44+ // Build a reverse proxy to the upstream, reused across all requests for this provider.
45+ // All request modifications happen in Rewrite.
46+ proxy := & httputil.ReverseProxy {
47+ Rewrite : func (pr * httputil.ProxyRequest ) {
48+ rewritePassthroughRequest (pr , provBaseURL , prov )
49+ },
50+ Transport : apidump .NewPassthroughMiddleware (t , prov .APIDumpDir (), prov .Name (), logger , quartz .NewReal ()),
51+ ErrorHandler : func (rw http.ResponseWriter , req * http.Request , e error ) {
52+ logger .Warn (req .Context (), "reverse proxy error" , slog .Error (e ), slog .F ("path" , req .URL .Path ))
53+ http .Error (rw , "upstream proxy error" , http .StatusBadGateway )
54+ },
55+ }
56+
2557 return func (w http.ResponseWriter , r * http.Request ) {
2658 if m != nil {
2759 m .PassthroughCount .WithLabelValues (prov .Name (), r .URL .Path , r .Method ).Add (1 )
2860 }
2961
30- ctx , span := tracer .Start (r .Context (), "Passthrough" , trace .WithAttributes (
31- attribute .String (tracing .PassthroughURL , r .URL .String ()),
32- attribute .String (tracing .PassthroughMethod , r .Method ),
33- ))
62+ ctx , span := startSpan (r , tracer )
3463 defer span .End ()
3564
36- upURL , err := url .Parse (prov .BaseURL ())
37- if err != nil {
38- logger .Warn (ctx , "failed to parse provider base URL" , slog .Error (err ))
39- http .Error (w , "request error" , http .StatusBadGateway )
40- span .SetStatus (codes .Error , "failed to parse provider base URL: " + err .Error ())
41- return
42- }
65+ proxy .ServeHTTP (w , r .WithContext (ctx ))
66+ }
67+ }
4368
44- // Append the request path to the upstream base path.
45- reqPath , err := url .JoinPath (upURL .Path , r .URL .Path )
46- if err != nil {
47- logger .Warn (ctx , "failed to join upstream path" , slog .Error (err ), slog .F ("upstream_path" , upURL .Path ), slog .F ("request_path" , r .URL .Path ))
48- http .Error (w , "failed to join upstream path" , http .StatusInternalServerError )
49- span .SetStatus (codes .Error , "failed to join upstream path: " + err .Error ())
50- return
51- }
52- // Ensure leading slash, proxied requests should have absolute paths.
53- // JoinPath can return relative paths, eg. when upURL path is empty.
54- if len (reqPath ) == 0 || reqPath [0 ] != '/' {
55- reqPath = "/" + reqPath
56- }
69+ // rewritePassthroughRequest configures the outbound request for the upstream and
70+ // applies proxy headers and provider auth.
71+ func rewritePassthroughRequest (pr * httputil.ProxyRequest , provBaseURL * url.URL , prov provider.Provider ) {
72+ pr .SetURL (provBaseURL )
5773
58- // Build a reverse proxy to the upstream.
59- proxy := & httputil.ReverseProxy {
60- Director : func (req * http.Request ) {
61- // Set scheme/host to upstream.
62- req .URL .Scheme = upURL .Scheme
63- req .URL .Host = upURL .Host
64- req .URL .Path = reqPath
65- req .URL .RawPath = ""
66-
67- // Preserve query string.
68- req .URL .RawQuery = r .URL .RawQuery
69-
70- // Set Host header for upstream.
71- req .Host = upURL .Host
72- span .SetAttributes (attribute .String (tracing .PassthroughUpstreamURL , req .URL .String ()))
73-
74- // Copy headers from client.
75- req .Header = r .Header .Clone ()
76-
77- // Standard proxy headers.
78- host , _ , herr := net .SplitHostPort (r .RemoteAddr )
79- if herr != nil {
80- host = r .RemoteAddr
81- }
82- if prior := req .Header .Get ("X-Forwarded-For" ); prior != "" {
83- req .Header .Set ("X-Forwarded-For" , prior + ", " + host )
84- } else {
85- req .Header .Set ("X-Forwarded-For" , host )
86- }
87- req .Header .Set ("X-Forwarded-Host" , r .Host )
88- if r .TLS != nil {
89- req .Header .Set ("X-Forwarded-Proto" , "https" )
90- } else {
91- req .Header .Set ("X-Forwarded-Proto" , "http" )
92- }
93- // Avoid default Go user-agent if none provided.
94- if _ , ok := req .Header ["User-Agent" ]; ! ok {
95- req .Header .Set ("User-Agent" , "aibridge" ) // TODO: use build tag.
96- }
97-
98- // Inject provider auth.
99- prov .InjectAuthHeader (& req .Header )
100- },
101- ErrorHandler : func (rw http.ResponseWriter , req * http.Request , e error ) {
102- logger .Warn (req .Context (), "reverse proxy error" , slog .Error (e ), slog .F ("path" , req .URL .Path ))
103- http .Error (rw , "upstream proxy error" , http .StatusBadGateway )
104- },
105- }
74+ // Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
75+ // To preserve old Director behavior pr.In "X-Forwarded-For" header
76+ // values need to be copied manually.
77+ // https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded
78+ if prior , ok := pr .In .Header ["X-Forwarded-For" ]; ok {
79+ pr .Out .Header ["X-Forwarded-For" ] = append ([]string (nil ), prior ... )
80+ }
81+ pr .SetXForwarded ()
82+
83+ span := trace .SpanFromContext (pr .Out .Context ())
84+ span .SetAttributes (attribute .String (tracing .PassthroughUpstreamURL , pr .Out .URL .String ()))
85+
86+ // Avoid default Go user-agent if none provided.
87+ if _ , ok := pr .Out .Header ["User-Agent" ]; ! ok {
88+ pr .Out .Header .Set ("User-Agent" , "aibridge" ) // TODO: use build tag.
89+ }
10690
107- // Transport tuned for streaming (no response header timeout).
108- t := & http.Transport {
109- Proxy : http .ProxyFromEnvironment ,
110- ForceAttemptHTTP2 : true ,
111- MaxIdleConns : 100 ,
112- IdleConnTimeout : 90 * time .Second ,
113- TLSHandshakeTimeout : 10 * time .Second ,
114- ExpectContinueTimeout : 1 * time .Second ,
91+ // Inject provider auth.
92+ prov .InjectAuthHeader (& pr .Out .Header )
93+ }
94+
95+ // newInvalidBaseURLHandler returns a handler that always returns 502
96+ // when the provider's base URL is invalid.
97+ func newInvalidBaseURLHandler (prov provider.Provider , logger slog.Logger , m * metrics.Metrics , tracer trace.Tracer , baseURLErr error ) http.HandlerFunc {
98+ return func (w http.ResponseWriter , r * http.Request ) {
99+ ctx , span := startSpan (r , tracer )
100+ defer span .End ()
101+
102+ if m != nil {
103+ m .PassthroughCount .WithLabelValues (prov .Name (), r .URL .Path , r .Method ).Add (1 )
115104 }
116- proxy .Transport = apidump .NewPassthroughMiddleware (t , prov .APIDumpDir (), prov .Name (), logger , quartz .NewReal ())
117105
118- proxy .ServeHTTP (w , r )
106+ logger .Warn (ctx , "invalid provider base URL" , slog .Error (baseURLErr ))
107+ http .Error (w , "invalid provider base URL" , http .StatusBadGateway )
108+ span .SetStatus (codes .Error , "invalid provider base URL: " + baseURLErr .Error ())
119109 }
120110}
111+
112+ func startSpan (r * http.Request , tracer trace.Tracer ) (context.Context , trace.Span ) {
113+ return tracer .Start (r .Context (), "Passthrough" , trace .WithAttributes (
114+ attribute .String (tracing .PassthroughURL , r .URL .String ()),
115+ attribute .String (tracing .PassthroughMethod , r .Method ),
116+ ))
117+ }
0 commit comments