Skip to content

Commit cdc768e

Browse files
authored
fix(arrow/flight): deliver response headers eagerly in streaming client middleware (#801)
### Rationale for this change Fixes #755. The cookie middleware (`NewClientCookieMiddleware`) does not capture `Set-Cookie` headers returned in response to a streaming RPC like `Handshake` when the server also sends back a response payload. `ClientHeadersMiddleware.HeadersReceived` was only invoked from `finishFn`, which fires when `Recv()` returns `io.EOF` or a non-`io.EOF` error. `AuthenticateBasicToken` calls `Recv()` exactly once; if the server sends a `HandshakeResponse` payload (common when the Handshake carries auth data or a session cookie), `Recv()` returns that message rather than `io.EOF` and `finishFn` never fires. The cookie middleware never sees the response headers, so the session cookie is dropped and subsequent RPCs go out without it, even though the user reports cookies ARE delivered on other endpoints like `GetFlightInfo` (unary RPCs capture headers synchronously via `grpc.Header(&md)`). ### What changes are included in this PR? - `clientStream.Header()` now delivers response metadata to `ClientHeadersMiddleware` at-most-once (guarded by `atomic.Bool.CompareAndSwap`) the first time headers are successfully retrieved for a streaming RPC. - The existing `finishFn` path is unchanged so: - trailers are still captured when the stream completes, and - callers that never explicitly invoke `Header()` get the exact same behavior as before. - Added four regression tests in `arrow/flight/handshake_cookie_test.go` covering: 1. `Set-Cookie` in Handshake response **headers** (via `AuthenticateBasicToken`) 2. `Set-Cookie` in Handshake response **trailers** 3. `Set-Cookie` + server-sent `HandshakeResponse` payload (the precise scenario reported in #755 — fails without this fix) 4. Eager capture when `stream.Header()` is inspected before draining the stream (also fails without this fix) ### Are these changes tested? Yes. The four new tests in `arrow/flight/handshake_cookie_test.go` reproduce the regression. Tests 3 and 4 fail without the fix and pass with it. The existing middleware/cookie tests continue to pass, including with `-race`. ### Are there any user-facing changes? Minor behavioral refinement of `ClientHeadersMiddleware` for streaming RPCs: `HeadersReceived` may now be invoked up to twice on a streaming RPC whose caller explicitly calls `stream.Header()` — once with just the response headers (from `Header()`), and again with headers+trailers joined (from the existing `finishFn` path at stream completion). This is backward compatible for the in-tree `clientCookieMiddleware` (cookie updates are keyed by `name+path` and idempotent). Callers that never explicitly call `stream.Header()` see no change in behavior.
1 parent 23c1ed3 commit cdc768e

3 files changed

Lines changed: 354 additions & 0 deletions

File tree

arrow/flight/client.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,19 @@ func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware
175175
desc: desc,
176176
finishFn: finishFunc,
177177
}
178+
if isHdrs {
179+
// Deliver response headers to the middleware as soon as they
180+
// are first retrieved via Header(), rather than waiting for
181+
// the stream to finish. This is necessary for streaming RPCs
182+
// like Handshake where the caller may inspect headers (e.g.
183+
// Set-Cookie) and issue subsequent RPCs before the stream
184+
// reaches io.EOF (e.g. when the server sends a response
185+
// payload that causes Recv to return a message instead of
186+
// EOF). See GH-755.
187+
newCS.onHeaders = func(md metadata.MD) {
188+
hdrs.HeadersReceived(csCtx, md)
189+
}
190+
}
178191
// The `ClientStream` interface allows one to omit calling `Recv` if it's
179192
// known that the result will be `io.EOF`. See
180193
// http://stackoverflow.com/q/42915337
@@ -193,12 +206,24 @@ type clientStream struct {
193206
grpc.ClientStream
194207
desc *grpc.StreamDesc
195208
finishFn func(error)
209+
210+
// onHeaders, when non-nil, is invoked at most once with the response
211+
// metadata the first time Header() returns successfully. It allows
212+
// middleware (e.g. cookie middleware) to observe server headers as
213+
// soon as they arrive on streaming RPCs, rather than waiting for the
214+
// stream to finish via finishFn. See GH-755.
215+
onHeaders func(md metadata.MD)
216+
headersObserved atomic.Bool
196217
}
197218

198219
func (cs *clientStream) Header() (metadata.MD, error) {
199220
md, err := cs.ClientStream.Header()
200221
if err != nil {
201222
cs.finishFn(err)
223+
return md, err
224+
}
225+
if cs.onHeaders != nil && cs.headersObserved.CompareAndSwap(false, true) {
226+
cs.onHeaders(md)
202227
}
203228
return md, err
204229
}

arrow/flight/flightsql/driver/driver_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,11 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight
18191819
if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
18201820
return nil, errors.New("parameter schema: unexpected")
18211821
}
1822+
// See GH-35328: drain remaining batches before returning to avoid
1823+
// the io.EOF race between server close and client Write. The other
1824+
// success path below already does this; this branch must too.
1825+
for r.Next() {
1826+
}
18221827
return qry.GetPreparedStatementHandle(), nil
18231828
}
18241829

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
package flight_test
18+
19+
import (
20+
"context"
21+
"encoding/base64"
22+
"errors"
23+
"io"
24+
"strings"
25+
"sync"
26+
"testing"
27+
28+
"github.com/apache/arrow-go/v18/arrow/flight"
29+
"github.com/stretchr/testify/assert"
30+
"github.com/stretchr/testify/require"
31+
"google.golang.org/grpc"
32+
"google.golang.org/grpc/credentials/insecure"
33+
"google.golang.org/grpc/metadata"
34+
)
35+
36+
// handshakeCookieFlightServer is a flight server that emits Set-Cookie
37+
// response headers (and trailers) during Handshake, simulating a server
38+
// that creates a session during the authentication flow (see GH-755).
39+
type handshakeCookieFlightServer struct {
40+
flight.BaseFlightServer
41+
42+
headerCookie string // cookie attached via SendHeader during Handshake
43+
trailerCookie string // cookie attached via SetTrailer during Handshake
44+
bearerToken string // authorization header returned during Handshake
45+
sendPayload bool // if true, server sends a HandshakeResponse payload before closing
46+
mu sync.Mutex
47+
lastIncomingCook []string // incoming Cookie header values observed on ListFlights
48+
}
49+
50+
func (h *handshakeCookieFlightServer) Handshake(stream flight.FlightService_HandshakeServer) error {
51+
md := metadata.MD{}
52+
if h.headerCookie != "" {
53+
md.Append("set-cookie", h.headerCookie)
54+
}
55+
if h.bearerToken != "" {
56+
md.Append("authorization", "Bearer "+h.bearerToken)
57+
}
58+
if len(md) > 0 {
59+
if err := stream.SendHeader(md); err != nil {
60+
return err
61+
}
62+
}
63+
64+
if h.trailerCookie != "" {
65+
stream.SetTrailer(metadata.Pairs("set-cookie", h.trailerCookie))
66+
}
67+
68+
if h.sendPayload {
69+
if err := stream.Send(&flight.HandshakeResponse{Payload: []byte("handshake-ok")}); err != nil {
70+
return err
71+
}
72+
}
73+
74+
// Drain the client stream until it closes.
75+
for {
76+
if _, err := stream.Recv(); err != nil {
77+
if errors.Is(err, io.EOF) {
78+
return nil
79+
}
80+
return err
81+
}
82+
}
83+
}
84+
85+
func (h *handshakeCookieFlightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error {
86+
h.mu.Lock()
87+
if md, ok := metadata.FromIncomingContext(fs.Context()); ok {
88+
h.lastIncomingCook = append([]string(nil), md.Get("cookie")...)
89+
} else {
90+
h.lastIncomingCook = nil
91+
}
92+
h.mu.Unlock()
93+
return nil
94+
}
95+
96+
func (h *handshakeCookieFlightServer) observedCookies() []string {
97+
h.mu.Lock()
98+
defer h.mu.Unlock()
99+
return append([]string(nil), h.lastIncomingCook...)
100+
}
101+
102+
// TestHandshakeCookiePropagationViaAuthenticateBasicToken is a regression
103+
// test for GH-755. It asserts that Set-Cookie headers returned by a
104+
// Handshake/DoHandshake response are captured by the cookie middleware
105+
// and attached to subsequent requests.
106+
func TestHandshakeCookiePropagationViaAuthenticateBasicToken(t *testing.T) {
107+
srv := &handshakeCookieFlightServer{
108+
headerCookie: "session_id=sess_header_abc",
109+
bearerToken: "my-bearer-token",
110+
}
111+
112+
s := flight.NewServerWithMiddleware(nil)
113+
s.Init("localhost:0")
114+
s.RegisterFlightService(srv)
115+
116+
go s.Serve()
117+
defer s.Shutdown()
118+
119+
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
120+
client, err := flight.NewClientWithMiddleware(
121+
s.Addr().String(),
122+
nil,
123+
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
124+
creds,
125+
)
126+
require.NoError(t, err)
127+
defer client.Close()
128+
129+
ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
130+
require.NoError(t, err)
131+
132+
// Make a follow-up RPC. The cookie middleware must have captured
133+
// Set-Cookie from the Handshake response, and StartCall should
134+
// attach it as a Cookie header on this call.
135+
stream, err := client.ListFlights(ctx, &flight.Criteria{})
136+
require.NoError(t, err)
137+
for {
138+
if _, err := stream.Recv(); err != nil {
139+
if errors.Is(err, io.EOF) {
140+
break
141+
}
142+
require.NoError(t, err)
143+
}
144+
}
145+
146+
cookies := srv.observedCookies()
147+
require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies)
148+
assert.Contains(t, cookies[0], "session_id=sess_header_abc",
149+
"cookie middleware should propagate Set-Cookie from Handshake response headers")
150+
}
151+
152+
// TestHandshakeCookiePropagationFromTrailers ensures cookies delivered as
153+
// gRPC trailers (instead of initial metadata headers) are also captured
154+
// by the cookie middleware during Handshake.
155+
func TestHandshakeCookiePropagationFromTrailers(t *testing.T) {
156+
srv := &handshakeCookieFlightServer{
157+
trailerCookie: "session_id=sess_trailer_xyz",
158+
bearerToken: "my-bearer-token",
159+
}
160+
161+
s := flight.NewServerWithMiddleware(nil)
162+
s.Init("localhost:0")
163+
s.RegisterFlightService(srv)
164+
165+
go s.Serve()
166+
defer s.Shutdown()
167+
168+
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
169+
client, err := flight.NewClientWithMiddleware(
170+
s.Addr().String(),
171+
nil,
172+
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
173+
creds,
174+
)
175+
require.NoError(t, err)
176+
defer client.Close()
177+
178+
ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
179+
require.NoError(t, err)
180+
181+
stream, err := client.ListFlights(ctx, &flight.Criteria{})
182+
require.NoError(t, err)
183+
for {
184+
if _, err := stream.Recv(); err != nil {
185+
if errors.Is(err, io.EOF) {
186+
break
187+
}
188+
require.NoError(t, err)
189+
}
190+
}
191+
192+
cookies := srv.observedCookies()
193+
require.Len(t, cookies, 1, "expected exactly one Cookie header, got %v", cookies)
194+
assert.Contains(t, cookies[0], "session_id=sess_trailer_xyz",
195+
"cookie middleware should propagate Set-Cookie from Handshake response trailers")
196+
}
197+
198+
// TestHandshakeCookiePropagationWithServerPayload is the precise scenario
199+
// reported in GH-755. The server attaches a Set-Cookie header AND sends
200+
// back a HandshakeResponse payload. AuthenticateBasicToken only calls
201+
// stream.Recv() once, which returns the payload (not io.EOF), so the
202+
// streaming finishFn that would normally invoke HeadersReceived never
203+
// fires. The cookie middleware must still capture the header cookie.
204+
func TestHandshakeCookiePropagationWithServerPayload(t *testing.T) {
205+
srv := &handshakeCookieFlightServer{
206+
headerCookie: "session_id=sess_with_payload",
207+
bearerToken: "my-bearer-token",
208+
sendPayload: true,
209+
}
210+
211+
s := flight.NewServerWithMiddleware(nil)
212+
s.Init("localhost:0")
213+
s.RegisterFlightService(srv)
214+
215+
go s.Serve()
216+
defer s.Shutdown()
217+
218+
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
219+
client, err := flight.NewClientWithMiddleware(
220+
s.Addr().String(),
221+
nil,
222+
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
223+
creds,
224+
)
225+
require.NoError(t, err)
226+
defer client.Close()
227+
228+
ctx, err := client.AuthenticateBasicToken(context.Background(), "user", "pass")
229+
require.NoError(t, err)
230+
231+
stream, err := client.ListFlights(ctx, &flight.Criteria{})
232+
require.NoError(t, err)
233+
for {
234+
if _, err := stream.Recv(); err != nil {
235+
if errors.Is(err, io.EOF) {
236+
break
237+
}
238+
require.NoError(t, err)
239+
}
240+
}
241+
242+
cookies := srv.observedCookies()
243+
require.Len(t, cookies, 1,
244+
"expected exactly one Cookie header, got %v (GH-755: cookie lost when Handshake returns a payload)", cookies)
245+
assert.Contains(t, cookies[0], "session_id=sess_with_payload")
246+
}
247+
248+
// TestHandshakeCookieProcessedBeforeRecv verifies cookies are captured
249+
// eagerly once stream.Header() returns successfully. This models the
250+
// scenario where an application-level Handshake flow inspects response
251+
// headers and makes further RPCs before draining the stream.
252+
func TestHandshakeCookieProcessedBeforeRecv(t *testing.T) {
253+
srv := &handshakeCookieFlightServer{
254+
headerCookie: "session_id=eager_capture",
255+
}
256+
257+
s := flight.NewServerWithMiddleware(nil)
258+
s.Init("localhost:0")
259+
s.RegisterFlightService(srv)
260+
261+
go s.Serve()
262+
defer s.Shutdown()
263+
264+
cookies := flight.NewCookieMiddleware()
265+
creds := grpc.WithTransportCredentials(insecure.NewCredentials())
266+
client, err := flight.NewClientWithMiddleware(
267+
s.Addr().String(),
268+
nil,
269+
[]flight.ClientMiddleware{flight.CreateClientMiddleware(cookies)},
270+
creds,
271+
)
272+
require.NoError(t, err)
273+
defer client.Close()
274+
275+
// Drive the Handshake manually; inspect headers before calling Recv().
276+
authCtx := metadata.AppendToOutgoingContext(context.Background(),
277+
"Authorization", "Basic "+base64.RawStdEncoding.EncodeToString([]byte("user:pass")))
278+
279+
stream, err := client.Handshake(authCtx)
280+
require.NoError(t, err)
281+
require.NoError(t, stream.CloseSend())
282+
283+
hdr, err := stream.Header()
284+
require.NoError(t, err)
285+
require.Contains(t, strings.Join(hdr.Get("set-cookie"), ","), "eager_capture")
286+
287+
// Clone the middleware while the original Handshake stream is still
288+
// open. If cookies were processed eagerly from the header, the clone
289+
// should already contain the session cookie.
290+
cloned := cookies.Clone()
291+
292+
// Using the clone, make a unary-ish request against a second client
293+
// to observe the outgoing Cookie header.
294+
clientB, err := flight.NewClientWithMiddleware(
295+
s.Addr().String(),
296+
nil,
297+
[]flight.ClientMiddleware{flight.CreateClientMiddleware(cloned)},
298+
creds,
299+
)
300+
require.NoError(t, err)
301+
defer clientB.Close()
302+
303+
ls, err := clientB.ListFlights(context.Background(), &flight.Criteria{})
304+
require.NoError(t, err)
305+
for {
306+
if _, err := ls.Recv(); err != nil {
307+
if errors.Is(err, io.EOF) {
308+
break
309+
}
310+
require.NoError(t, err)
311+
}
312+
}
313+
314+
got := srv.observedCookies()
315+
require.Len(t, got, 1, "expected cloned middleware to send cookie from eagerly captured Handshake header, got %v", got)
316+
assert.Contains(t, got[0], "session_id=eager_capture")
317+
318+
// Clean up original stream.
319+
for {
320+
if _, err := stream.Recv(); err != nil {
321+
break
322+
}
323+
}
324+
}

0 commit comments

Comments
 (0)