11package aibridge
22
33import (
4- "net"
54 "net/http"
65 "net/http/httputil"
76 "net/url"
@@ -21,7 +20,39 @@ import (
2120
2221// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
2322// by a [intercept.Provider].
23+ // A single reverse proxy is created per provider and reused across all requests.
2424func newPassthroughRouter (prov provider.Provider , logger slog.Logger , m * metrics.Metrics , tracer trace.Tracer ) http.HandlerFunc {
25+ provBaseURL , err := url .Parse (prov .BaseURL ())
26+ if err != nil {
27+ return newInvalidBaseURLHandler (prov , logger , m , tracer , err )
28+ }
29+ if _ , err := url .JoinPath (provBaseURL .Path , "/" ); err != nil {
30+ return newInvalidBaseURLHandler (prov , logger , m , tracer , err )
31+ }
32+
33+ // Transport tuned for streaming (no response header timeout).
34+ t := & http.Transport {
35+ Proxy : http .ProxyFromEnvironment ,
36+ ForceAttemptHTTP2 : true ,
37+ MaxIdleConns : 100 ,
38+ IdleConnTimeout : 90 * time .Second ,
39+ TLSHandshakeTimeout : 10 * time .Second ,
40+ ExpectContinueTimeout : 1 * time .Second ,
41+ }
42+
43+ // Build a reverse proxy to the upstream, reused across all requests for this provider.
44+ // All request modifications happen in Rewrite.
45+ proxy := & httputil.ReverseProxy {
46+ Rewrite : func (pr * httputil.ProxyRequest ) {
47+ rewritePassthroughRequest (pr , provBaseURL , prov )
48+ },
49+ Transport : apidump .NewPassthroughMiddleware (t , prov .APIDumpDir (), prov .Name (), logger , quartz .NewReal ()),
50+ ErrorHandler : func (rw http.ResponseWriter , req * http.Request , e error ) {
51+ logger .Warn (req .Context (), "reverse proxy error" , slog .Error (e ), slog .F ("path" , req .URL .Path ))
52+ http .Error (rw , "upstream proxy error" , http .StatusBadGateway )
53+ },
54+ }
55+
2556 return func (w http.ResponseWriter , r * http.Request ) {
2657 if m != nil {
2758 m .PassthroughCount .WithLabelValues (prov .Name (), r .URL .Path , r .Method ).Add (1 )
@@ -33,88 +64,48 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
3364 ))
3465 defer span .End ()
3566
36- upURL , err := url .Parse (prov .BaseURL ())
37- if err != nil {
38- logger .Warn (ctx , "failed to parse provider base URL" , slog .Error (err ))
39- http .Error (w , "request error" , http .StatusBadGateway )
40- span .SetStatus (codes .Error , "failed to parse provider base URL: " + err .Error ())
41- return
42- }
67+ proxy .ServeHTTP (w , r .WithContext (ctx ))
68+ }
69+ }
4370
44- // Append the request path to the upstream base path.
45- reqPath , err := url .JoinPath (upURL .Path , r .URL .Path )
46- if err != nil {
47- logger .Warn (ctx , "failed to join upstream path" , slog .Error (err ), slog .F ("upstream_path" , upURL .Path ), slog .F ("request_path" , r .URL .Path ))
48- http .Error (w , "failed to join upstream path" , http .StatusInternalServerError )
49- span .SetStatus (codes .Error , "failed to join upstream path: " + err .Error ())
50- return
51- }
52- // Ensure leading slash, proxied requests should have absolute paths.
53- // JoinPath can return relative paths, eg. when upURL path is empty.
54- if len (reqPath ) == 0 || reqPath [0 ] != '/' {
55- reqPath = "/" + reqPath
56- }
71+ // rewritePassthroughRequest configures the outbound request for the upstream and
72+ // applies proxy headers and provider auth.
73+ func rewritePassthroughRequest (pr * httputil.ProxyRequest , provBaseURL * url.URL , prov provider.Provider ) {
74+ pr .SetURL (provBaseURL )
5775
58- // Build a reverse proxy to the upstream.
59- proxy := & httputil.ReverseProxy {
60- Director : func (req * http.Request ) {
61- // Set scheme/host to upstream.
62- req .URL .Scheme = upURL .Scheme
63- req .URL .Host = upURL .Host
64- req .URL .Path = reqPath
65- req .URL .RawPath = ""
66-
67- // Preserve query string.
68- req .URL .RawQuery = r .URL .RawQuery
69-
70- // Set Host header for upstream.
71- req .Host = upURL .Host
72- span .SetAttributes (attribute .String (tracing .PassthroughUpstreamURL , req .URL .String ()))
73-
74- // Copy headers from client.
75- req .Header = r .Header .Clone ()
76-
77- // Standard proxy headers.
78- host , _ , herr := net .SplitHostPort (r .RemoteAddr )
79- if herr != nil {
80- host = r .RemoteAddr
81- }
82- if prior := req .Header .Get ("X-Forwarded-For" ); prior != "" {
83- req .Header .Set ("X-Forwarded-For" , prior + ", " + host )
84- } else {
85- req .Header .Set ("X-Forwarded-For" , host )
86- }
87- req .Header .Set ("X-Forwarded-Host" , r .Host )
88- if r .TLS != nil {
89- req .Header .Set ("X-Forwarded-Proto" , "https" )
90- } else {
91- req .Header .Set ("X-Forwarded-Proto" , "http" )
92- }
93- // Avoid default Go user-agent if none provided.
94- if _ , ok := req .Header ["User-Agent" ]; ! ok {
95- req .Header .Set ("User-Agent" , "aibridge" ) // TODO: use build tag.
96- }
97-
98- // Inject provider auth.
99- prov .InjectAuthHeader (& req .Header )
100- },
101- ErrorHandler : func (rw http.ResponseWriter , req * http.Request , e error ) {
102- logger .Warn (req .Context (), "reverse proxy error" , slog .Error (e ), slog .F ("path" , req .URL .Path ))
103- http .Error (rw , "upstream proxy error" , http .StatusBadGateway )
104- },
105- }
76+ if prior , ok := pr .In .Header ["X-Forwarded-For" ]; ok {
77+ pr .Out .Header ["X-Forwarded-For" ] = append ([]string (nil ), prior ... )
78+ }
79+ pr .SetXForwarded ()
10680
107- // Transport tuned for streaming (no response header timeout).
108- t := & http.Transport {
109- Proxy : http .ProxyFromEnvironment ,
110- ForceAttemptHTTP2 : true ,
111- MaxIdleConns : 100 ,
112- IdleConnTimeout : 90 * time .Second ,
113- TLSHandshakeTimeout : 10 * time .Second ,
114- ExpectContinueTimeout : 1 * time .Second ,
81+ span := trace .SpanFromContext (pr .Out .Context ())
82+ span .SetAttributes (attribute .String (tracing .PassthroughUpstreamURL , pr .Out .URL .String ()))
83+
84+ // Avoid default Go user-agent if none provided.
85+ if _ , ok := pr .Out .Header ["User-Agent" ]; ! ok {
86+ pr .Out .Header .Set ("User-Agent" , "aibridge" ) // TODO: use build tag.
87+ }
88+
89+ // Inject provider auth.
90+ prov .InjectAuthHeader (& pr .Out .Header )
91+ }
92+
93+ // newInvalidBaseURLHandler returns a handler that always returns 502 because
94+ // the provider's base URL is invalid.
95+ func newInvalidBaseURLHandler (prov provider.Provider , logger slog.Logger , m * metrics.Metrics , tracer trace.Tracer , baseURLErr error ) http.HandlerFunc {
96+ return func (w http.ResponseWriter , r * http.Request ) {
97+ if m != nil {
98+ m .PassthroughCount .WithLabelValues (prov .Name (), r .URL .Path , r .Method ).Add (1 )
11599 }
116- proxy .Transport = apidump .NewPassthroughMiddleware (t , prov .APIDumpDir (), prov .Name (), logger , quartz .NewReal ())
117100
118- proxy .ServeHTTP (w , r )
101+ ctx , span := tracer .Start (r .Context (), "Passthrough" , trace .WithAttributes (
102+ attribute .String (tracing .PassthroughURL , r .URL .String ()),
103+ attribute .String (tracing .PassthroughMethod , r .Method ),
104+ ))
105+ defer span .End ()
106+
107+ logger .Warn (ctx , "invalid provider base URL" , slog .Error (baseURLErr ))
108+ http .Error (w , "invalid provider base URL" , http .StatusBadGateway )
109+ span .SetStatus (codes .Error , "invalid provider base URL: " + baseURLErr .Error ())
119110 }
120111}
0 commit comments