Skip to content

Commit 7ff1a81

Browse files
authored
aitools: extract pollStatement helper and pin OnWaitTimeout (#5092)
## Stack This PR is part of a 4-PR stack making `aitools` data exploration faster for ai-dev-kit. Each PR is independently reviewable; merge in order. 1. **#5092 — aitools: extract pollStatement helper and pin OnWaitTimeout** *(base: `main`)* — **this PR** 2. #5093 — aitools: run multiple SQL queries in parallel from one query invocation *(base: #5092)* 3. #5095 — aitools: add 'tools statement' lifecycle commands *(base: #5093)* 4. #5097 — aitools: parallelize discover-schema across tables and probes *(base: #5095)* Use `git diff <base>...HEAD` or set the comparison base in the GitHub UI to see only this PR's changes; the default "Files changed" diff against `main` includes ancestor PRs. --- ## Why The query command in `experimental/aitools/cmd/query.go` works today, but two things make it fragile and hard to reuse: 1. The polling loop, signal handling, spinner, and server-side cancellation are entangled in one ~100-line function. Upcoming features (parallel batch queries, a statement lifecycle command tree) need pure polling without the signal-handler side effects, so the helper has to come out cleanly. 2. The `ExecuteStatement` request sets `WaitTimeout: 0s` but does not set `OnWaitTimeout`. That relies on the SDK's default being `CONTINUE`. It is today, but a flip would silently break the command: the statement would be cancelled before our first GET and we'd never see the result. This PR is a pure refactor + one explicit-default fix. No user-visible behavior change. ## Changes - Extract `pollStatement(ctx, api, resp)` from `executeAndPoll`. The helper polls until the statement reaches a terminal state and returns the response. It does not call `CancelExecution` on context cancellation, that's the caller's job (and a deliberate design choice for the upcoming `statement get` command, where Ctrl+C should stop polling without killing the server-side statement). - Pin `OnWaitTimeout: CONTINUE` explicitly on the `ExecuteStatement` call. - Update `executeAndPoll` to delegate to `pollStatement` and keep the existing signal-handling, spinner, and server-side cancel-on-Ctrl+C semantics intact. - Add five unit tests covering the new helper: - Immediate terminal short-circuit (no Get calls) - Failed terminal returned without error (caller decides) - Eventual success across multiple polls - Context cancellation returns ctx error and does NOT call CancelExecution - GetStatement transport error is wrapped and propagated - Update the existing `TestExecuteAndPollImmediateSuccess` matcher to assert `OnWaitTimeout == CONTINUE` so a future SDK default flip cannot regress us. ## Test plan - [x] `go test ./experimental/aitools/...` passes (10 polling-related cases including the 5 new ones). - [x] `make checks` clean (tidy, whitespace, dead code). - [x] `make fmt` no drift. - [x] `make lint` 0 issues. - [x] Existing `executeAndPoll` tests (immediate success, immediate failure, polling, fail-during-poll, ctx-cancellation-calls-cancel-execution) all still pass without modification beyond the matcher tweak.
1 parent 30e21b5 commit 7ff1a81

2 files changed

Lines changed: 146 additions & 22 deletions

File tree

experimental/aitools/cmd/query.go

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -262,21 +262,17 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e
262262
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) {
263263
// Submit asynchronously to get the statement ID immediately for cancellation.
264264
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
265-
WarehouseId: warehouseID,
266-
Statement: statement,
267-
WaitTimeout: "0s",
265+
WarehouseId: warehouseID,
266+
Statement: statement,
267+
WaitTimeout: "0s",
268+
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
268269
})
269270
if err != nil {
270271
return nil, fmt.Errorf("execute statement: %w", err)
271272
}
272273

273274
statementID := resp.StatementId
274275

275-
// Check if it completed immediately.
276-
if isTerminalState(resp.Status) {
277-
return resp, checkFailedState(resp.Status)
278-
}
279-
280276
// Set up Ctrl+C: signal cancels the poll context, cleanup is unified below.
281277
pollCtx, pollCancel := context.WithCancel(ctx)
282278
defer pollCancel()
@@ -327,34 +323,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa
327323
}
328324
}()
329325

326+
pollResp, err := pollStatement(pollCtx, api, resp)
327+
if err != nil {
328+
if pollCtx.Err() != nil {
329+
cancelStatement()
330+
cmdio.LogString(ctx, "Query cancelled.")
331+
return nil, root.ErrAlreadyPrinted
332+
}
333+
return nil, err
334+
}
335+
336+
sp.Close()
337+
if err := checkFailedState(pollResp.Status); err != nil {
338+
return nil, err
339+
}
340+
return pollResp, nil
341+
}
342+
343+
// pollStatement polls until the statement reaches a terminal state.
344+
//
345+
// On context cancellation it returns the context error WITHOUT cancelling the
346+
// server-side statement. Callers that want server-side cancellation should
347+
// invoke CancelExecution explicitly.
348+
//
349+
// If the input response is already in a terminal state, it is returned without
350+
// further polling.
351+
func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) {
352+
if isTerminalState(resp.Status) {
353+
return resp, nil
354+
}
355+
356+
statementID := resp.StatementId
357+
start := time.Now()
358+
330359
// Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped).
331360
interval := pollIntervalInitial
332361
for {
333362
select {
334-
case <-pollCtx.Done():
335-
cancelStatement()
336-
cmdio.LogString(ctx, "Query cancelled.")
337-
return nil, root.ErrAlreadyPrinted
363+
case <-ctx.Done():
364+
return nil, ctx.Err()
338365
case <-time.After(interval):
339366
}
340367

341368
log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second))
342369

343-
pollResp, err := api.GetStatementByStatementId(pollCtx, statementID)
370+
pollResp, err := api.GetStatementByStatementId(ctx, statementID)
344371
if err != nil {
345-
if pollCtx.Err() != nil {
346-
cancelStatement()
347-
cmdio.LogString(ctx, "Query cancelled.")
348-
return nil, root.ErrAlreadyPrinted
372+
if ctx.Err() != nil {
373+
return nil, ctx.Err()
349374
}
350375
return nil, fmt.Errorf("poll statement status: %w", err)
351376
}
352377

353378
if isTerminalState(pollResp.Status) {
354-
sp.Close()
355-
if err := checkFailedState(pollResp.Status); err != nil {
356-
return nil, err
357-
}
358379
return &sql.StatementResponse{
359380
StatementId: pollResp.StatementId,
360381
Status: pollResp.Status,

experimental/aitools/cmd/query_test.go

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package aitools
22

33
import (
44
"context"
5+
"errors"
56
"os"
67
"path/filepath"
78
"strings"
@@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) {
4849
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
4950

5051
mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
51-
return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s"
52+
return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" &&
53+
req.WaitTimeout == "0s" &&
54+
req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue
5255
})).Return(&sql.StatementResponse{
5356
StatementId: "stmt-1",
5457
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
@@ -154,6 +157,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) {
154157
require.ErrorIs(t, err, root.ErrAlreadyPrinted)
155158
}
156159

160+
func TestPollStatementImmediateTerminal(t *testing.T) {
161+
ctx := cmdio.MockDiscard(t.Context())
162+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
163+
164+
resp := &sql.StatementResponse{
165+
StatementId: "stmt-1",
166+
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
167+
Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}},
168+
Result: &sql.ResultData{DataArray: [][]string{{"1"}}},
169+
}
170+
171+
pollResp, err := pollStatement(ctx, mockAPI, resp)
172+
require.NoError(t, err)
173+
assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State)
174+
assert.Equal(t, "stmt-1", pollResp.StatementId)
175+
}
176+
177+
func TestPollStatementTerminalFailureNotErrored(t *testing.T) {
178+
// pollStatement returns the response without erroring on failed terminal
179+
// states; callers (e.g. executeAndPoll) decide what to do via checkFailedState.
180+
ctx := cmdio.MockDiscard(t.Context())
181+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
182+
183+
resp := &sql.StatementResponse{
184+
StatementId: "stmt-1",
185+
Status: &sql.StatementStatus{
186+
State: sql.StatementStateFailed,
187+
Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"},
188+
},
189+
}
190+
191+
pollResp, err := pollStatement(ctx, mockAPI, resp)
192+
require.NoError(t, err)
193+
assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State)
194+
}
195+
196+
func TestPollStatementEventualSuccess(t *testing.T) {
197+
ctx := cmdio.MockDiscard(t.Context())
198+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
199+
200+
initial := &sql.StatementResponse{
201+
StatementId: "stmt-1",
202+
Status: &sql.StatementStatus{State: sql.StatementStatePending},
203+
}
204+
205+
mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{
206+
StatementId: "stmt-1",
207+
Status: &sql.StatementStatus{State: sql.StatementStateRunning},
208+
}, nil).Once()
209+
210+
mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{
211+
StatementId: "stmt-1",
212+
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
213+
Result: &sql.ResultData{DataArray: [][]string{{"42"}}},
214+
}, nil).Once()
215+
216+
pollResp, err := pollStatement(ctx, mockAPI, initial)
217+
require.NoError(t, err)
218+
assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State)
219+
assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray)
220+
}
221+
222+
func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) {
223+
// The mock asserts (via t.Cleanup) that no unexpected calls are made.
224+
// Specifically, pollStatement must NOT call CancelExecution on context
225+
// cancellation; that is the caller's responsibility.
226+
ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context()))
227+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
228+
229+
initial := &sql.StatementResponse{
230+
StatementId: "stmt-1",
231+
Status: &sql.StatementStatus{State: sql.StatementStatePending},
232+
}
233+
234+
cancel()
235+
236+
pollResp, err := pollStatement(ctx, mockAPI, initial)
237+
require.ErrorIs(t, err, context.Canceled)
238+
assert.Nil(t, pollResp)
239+
}
240+
241+
func TestPollStatementGetErrorPropagated(t *testing.T) {
242+
ctx := cmdio.MockDiscard(t.Context())
243+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
244+
245+
initial := &sql.StatementResponse{
246+
StatementId: "stmt-1",
247+
Status: &sql.StatementStatus{State: sql.StatementStatePending},
248+
}
249+
250+
mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
251+
Return(nil, errors.New("network unreachable")).Once()
252+
253+
pollResp, err := pollStatement(ctx, mockAPI, initial)
254+
require.Error(t, err)
255+
assert.Contains(t, err.Error(), "poll statement status")
256+
assert.Contains(t, err.Error(), "network unreachable")
257+
assert.Nil(t, pollResp)
258+
}
259+
157260
func TestResolveWarehouseIDWithFlag(t *testing.T) {
158261
ctx := t.Context()
159262
id, err := resolveWarehouseID(ctx, nil, "explicit-id")

0 commit comments

Comments
 (0)