From acdcba0eaab1f4d7e6f947e4186f7fdce578e50c Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 11:35:30 +0100 Subject: [PATCH 1/6] Add --warehouse flag, spinner, async polling, and Ctrl+C cancellation to query command Replace ExecuteAndWait with a custom poll loop that submits asynchronously (WaitTimeout=0s) to get the statement ID immediately. This enables: - Spinner with elapsed time in interactive mode - Server-side cancellation on Ctrl+C via CancelExecution API - Exponential backoff polling (1s -> 5s cap) Add --warehouse/-w flag to override auto-detection. Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 199 ++++++++++++++++++--- experimental/aitools/cmd/query_test.go | 229 +++++++++++++++++++++++++ 2 files changed, 402 insertions(+), 26 deletions(-) create mode 100644 experimental/aitools/cmd/query_test.go diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 23fc3c5f2d..98afe8e438 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -1,30 +1,51 @@ package mcp import ( + "context" "encoding/json" "errors" "fmt" + "os" + "os/signal" "strings" + "syscall" + "time" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/experimental/aitools/lib/middlewares" "github.com/databricks/cli/experimental/aitools/lib/session" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) +const ( + // pollIntervalInitial is the starting interval between status polls. + pollIntervalInitial = 1 * time.Second + + // pollIntervalMax is the maximum interval between status polls. + pollIntervalMax = 5 * time.Second + + // cancelTimeout is how long to wait for server-side cancellation. + cancelTimeout = 10 * time.Second +) + func newQueryCmd() *cobra.Command { + var warehouseID string + cmd := &cobra.Command{ Use: "query SQL", Short: "Execute SQL against a Databricks warehouse", Long: `Execute a SQL statement against a Databricks SQL warehouse and return results. -The command auto-detects an available warehouse unless DATABRICKS_WAREHOUSE_ID is set. +The command auto-detects an available warehouse unless --warehouse is set +or the DATABRICKS_WAREHOUSE_ID environment variable is configured. Output includes the query results as JSON and row count.`, - Example: ` databricks experimental aitools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"`, + Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" + databricks experimental aitools tools query --warehouse abc123 "SELECT 1"`, Args: cobra.ExactArgs(1), PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { @@ -36,31 +57,14 @@ Output includes the query results as JSON and row count.`, return errors.New("SQL statement is required") } - // set up session with client for middleware compatibility - sess := session.NewSession() - sess.Set(middlewares.DatabricksClientKey, w) - ctx = session.WithSession(ctx, sess) - - warehouseID, err := middlewares.GetWarehouseID(ctx, true) + wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - resp, err := w.StatementExecution.ExecuteAndWait(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: sqlStatement, - WaitTimeout: "50s", - }) + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement) if err != nil { - return fmt.Errorf("execute statement: %w", err) - } - - if resp.Status != nil && resp.Status.State == sql.StatementStateFailed { - errMsg := "query failed" - if resp.Status.Error != nil { - errMsg = resp.Status.Error.Message - } - return errors.New(errMsg) + return err } output, err := formatQueryResult(resp) @@ -73,13 +77,159 @@ Output includes the query results as JSON and row count.`, }, } + cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") + return cmd } +// resolveWarehouseID returns the warehouse ID to use for query execution. +// Priority: explicit flag > middleware auto-detection (env var > server default > first running). +func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, error) { + if flagValue != "" { + return flagValue, nil + } + + sess := session.NewSession() + sess.Set(middlewares.DatabricksClientKey, w) + ctx = session.WithSession(ctx, sess) + + return middlewares.GetWarehouseID(ctx, true) +} + +// executeAndPoll submits a SQL statement asynchronously and polls until completion. +// It shows a spinner in interactive mode and supports Ctrl+C cancellation. +func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { + // Submit asynchronously to get the statement ID immediately for cancellation. + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + }) + if err != nil { + return nil, fmt.Errorf("execute statement: %w", err) + } + + statementID := resp.StatementId + + // Check if it completed immediately. + if isTerminalState(resp.Status) { + return resp, checkFailedState(resp.Status) + } + + // Set up Ctrl+C handling: cancel context + server-side cancellation. + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + cancelDone := make(chan struct{}) + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling query %s", statementID) + pollCancel() + + // Best-effort server-side cancellation with independent context. + cancelCtx, cancel := context.WithTimeout(context.Background(), cancelTimeout) + defer cancel() + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ + StatementId: statementID, + }); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", statementID, err) + } + close(cancelDone) + case <-pollCtx.Done(): + } + }() + + // Spinner for interactive feedback. + sp := cmdio.NewSpinner(pollCtx) + defer sp.Close() + start := time.Now() + sp.Update("Executing query...") + + interval := pollIntervalInitial + for { + select { + case <-pollCtx.Done(): + select { + case <-cancelDone: + case <-time.After(cancelTimeout): + } + return nil, errors.New("query cancelled") + case <-time.After(interval): + } + + elapsed := time.Since(start).Truncate(time.Second) + sp.Update(fmt.Sprintf("Executing query... (%s elapsed)", elapsed)) + + pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + if err != nil { + if pollCtx.Err() != nil { + select { + case <-cancelDone: + case <-time.After(cancelTimeout): + } + return nil, errors.New("query cancelled") + } + return nil, fmt.Errorf("poll statement status: %w", err) + } + + if isTerminalState(pollResp.Status) { + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return &sql.StatementResponse{ + StatementId: pollResp.StatementId, + Status: pollResp.Status, + Manifest: pollResp.Manifest, + Result: pollResp.Result, + }, nil + } + + interval = min(interval*2, pollIntervalMax) + } +} + +// isTerminalState returns true if the statement has reached a final state. +func isTerminalState(status *sql.StatementStatus) bool { + if status == nil { + return false + } + switch status.State { + case sql.StatementStateSucceeded, sql.StatementStateFailed, + sql.StatementStateCanceled, sql.StatementStateClosed: + return true + } + return false +} + +// checkFailedState returns an error if the statement is in a non-success terminal state. +func checkFailedState(status *sql.StatementStatus) error { + if status == nil { + return nil + } + switch status.State { + case sql.StatementStateFailed: + msg := "query failed" + if status.Error != nil { + msg = fmt.Sprintf("query failed: %s %s", status.Error.ErrorCode, status.Error.Message) + } + return errors.New(msg) + case sql.StatementStateCanceled: + return errors.New("query was cancelled") + case sql.StatementStateClosed: + return errors.New("query was closed before results could be fetched") + } + return nil +} + // cleanSQL removes surrounding quotes, empty lines, and SQL comments. func cleanSQL(s string) string { s = strings.TrimSpace(s) - // remove surrounding quotes if present if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) || (strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) { s = s[1 : len(s)-1] @@ -88,7 +238,6 @@ func cleanSQL(s string) string { var lines []string for _, line := range strings.Split(s, "\n") { line = strings.TrimSpace(line) - // skip empty lines and single-line comments if line == "" || strings.HasPrefix(line, "--") { continue } @@ -105,7 +254,6 @@ func formatQueryResult(resp *sql.StatementResponse) (string, error) { return sb.String(), nil } - // get column names var columns []string if resp.Manifest.Schema != nil { for _, col := range resp.Manifest.Schema.Columns { @@ -113,7 +261,6 @@ func formatQueryResult(resp *sql.StatementResponse) (string, error) { } } - // format as JSON array for consistency with Neon API var rows []map[string]any if resp.Result.DataArray != nil { for _, row := range resp.Result.DataArray { diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go new file mode 100644 index 0000000000..9d8d7b8b8c --- /dev/null +++ b/experimental/aitools/cmd/query_test.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "testing" + "time" + + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestCleanSQL(t *testing.T) { + tests := []struct { + name string + in string + out string + }{ + {"plain", "SELECT 1", "SELECT 1"}, + {"double quoted", `"SELECT 1"`, "SELECT 1"}, + {"single quoted", `'SELECT 1'`, "SELECT 1"}, + {"strips comments", "-- comment\nSELECT 1", "SELECT 1"}, + {"strips empty lines", "\n\nSELECT 1\n\n", "SELECT 1"}, + {"multiline", "SELECT\n 1\nFROM\n dual", "SELECT\n1\nFROM\ndual"}, + {"empty", "", ""}, + {"only comments", "-- comment\n-- another", ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.out, cleanSQL(tc.in)) + }) + } +} + +func TestExecuteAndPollImmediateSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil) + + resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) + assert.Equal(t, "stmt-1", resp.StatementId) +} + +func TestExecuteAndPollImmediateFailure(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'SELCT': syntax error", + }, + }, + }, nil) + + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELCT 1") + require.Error(t, err) + assert.Contains(t, err.Error(), "SYNTAX_ERROR") + assert.Contains(t, err.Error(), "syntax error") +} + +func TestExecuteAndPollWithPolling(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil) + + // First poll: still RUNNING. + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + // Second poll: SUCCEEDED. + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "result"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 42") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) + assert.Equal(t, [][]string{{"42"}}, resp.Result.DataArray) +} + +func TestExecuteAndPollFailsDuringPolling(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "RESOURCE_EXHAUSTED", Message: "warehouse unavailable"}, + }, + }, nil).Once() + + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + require.Error(t, err) + assert.Contains(t, err.Error(), "RESOURCE_EXHAUSTED") +} + +func TestExecuteAndPollCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(cmdio.MockDiscard(context.Background())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil) + + cancel() + + _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") + require.Error(t, err) + assert.Contains(t, err.Error(), "cancelled") +} + +func TestResolveWarehouseIDWithFlag(t *testing.T) { + ctx := context.Background() + id, err := resolveWarehouseID(ctx, nil, "explicit-id") + require.NoError(t, err) + assert.Equal(t, "explicit-id", id) +} + +func TestFormatQueryResultNoResults(t *testing.T) { + resp := &sql.StatementResponse{ + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + } + output, err := formatQueryResult(resp) + require.NoError(t, err) + assert.Contains(t, output, "no results") +} + +func TestFormatQueryResultWithData(t *testing.T) { + resp := &sql.StatementResponse{ + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{ + Schema: &sql.ResultSchema{ + Columns: []sql.ColumnInfo{{Name: "id"}, {Name: "name"}}, + }, + }, + Result: &sql.ResultData{ + DataArray: [][]string{{"1", "alice"}, {"2", "bob"}}, + }, + } + output, err := formatQueryResult(resp) + require.NoError(t, err) + assert.Contains(t, output, "alice") + assert.Contains(t, output, "bob") + assert.Contains(t, output, "Row count: 2") +} + +func TestIsTerminalState(t *testing.T) { + tests := []struct { + state sql.StatementState + terminal bool + }{ + {sql.StatementStateSucceeded, true}, + {sql.StatementStateFailed, true}, + {sql.StatementStateCanceled, true}, + {sql.StatementStateClosed, true}, + {sql.StatementStatePending, false}, + {sql.StatementStateRunning, false}, + } + + for _, tc := range tests { + t.Run(string(tc.state), func(t *testing.T) { + status := &sql.StatementStatus{State: tc.state} + assert.Equal(t, tc.terminal, isTerminalState(status)) + }) + } + + assert.False(t, isTerminalState(nil)) +} + +func TestCheckFailedState(t *testing.T) { + assert.NoError(t, checkFailedState(nil)) + assert.NoError(t, checkFailedState(&sql.StatementStatus{State: sql.StatementStateSucceeded})) + + err := checkFailedState(&sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "bad"}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "ERR") + assert.Contains(t, err.Error(), "bad") + + err = checkFailedState(&sql.StatementStatus{State: sql.StatementStateCanceled}) + require.Error(t, err) + assert.Contains(t, err.Error(), "cancelled") + + err = checkFailedState(&sql.StatementStatus{State: sql.StatementStateClosed}) + require.Error(t, err) + assert.Contains(t, err.Error(), "closed") +} + +func TestPollingConstants(t *testing.T) { + assert.Equal(t, 1*time.Second, pollIntervalInitial) + assert.Equal(t, 5*time.Second, pollIntervalMax) + assert.Equal(t, 10*time.Second, cancelTimeout) +} From 6d731b19a95c567997b0d582fb27b90da33e25f3 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 11:45:19 +0100 Subject: [PATCH 2/6] Fix cancellation: always call CancelExecution on context cancel, fix poll backoff Three fixes from code review: 1. Unify cancellation path: CancelExecution is now called for any context cancellation (signal or parent), not just Ctrl+C. Previously parent-context cancellation would block for 10s without server-side cleanup. 2. Fix poll backoff: change from exponential (1,2,4,5s) to additive (1,2,3,4,5s) to match the plan specification. 3. Test now verifies CancelExecution is called on context cancel and runs in <1s instead of 10s. Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 38 ++++++++++++-------------- experimental/aitools/cmd/query_test.go | 7 ++++- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 98afe8e438..ba1215ee0e 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -116,7 +116,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa return resp, checkFailedState(resp.Status) } - // Set up Ctrl+C handling: cancel context + server-side cancellation. + // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -124,40 +124,39 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigCh) - cancelDone := make(chan struct{}) go func() { select { case <-sigCh: log.Infof(ctx, "Received interrupt, cancelling query %s", statementID) pollCancel() - - // Best-effort server-side cancellation with independent context. - cancelCtx, cancel := context.WithTimeout(context.Background(), cancelTimeout) - defer cancel() - if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ - StatementId: statementID, - }); err != nil { - log.Warnf(ctx, "Failed to cancel statement %s: %v", statementID, err) - } - close(cancelDone) case <-pollCtx.Done(): } }() + // cancelStatement performs best-effort server-side cancellation. + // Called on any poll exit due to context cancellation (signal or parent). + cancelStatement := func() { + cancelCtx, cancel := context.WithTimeout(context.Background(), cancelTimeout) + defer cancel() + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ + StatementId: statementID, + }); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", statementID, err) + } + } + // Spinner for interactive feedback. sp := cmdio.NewSpinner(pollCtx) defer sp.Close() start := time.Now() sp.Update("Executing query...") + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). interval := pollIntervalInitial for { select { case <-pollCtx.Done(): - select { - case <-cancelDone: - case <-time.After(cancelTimeout): - } + cancelStatement() return nil, errors.New("query cancelled") case <-time.After(interval): } @@ -168,10 +167,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) if err != nil { if pollCtx.Err() != nil { - select { - case <-cancelDone: - case <-time.After(cancelTimeout): - } + cancelStatement() return nil, errors.New("query cancelled") } return nil, fmt.Errorf("poll statement status: %w", err) @@ -190,7 +186,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa }, nil } - interval = min(interval*2, pollIntervalMax) + interval = min(interval+time.Second, pollIntervalMax) } } diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 9d8d7b8b8c..22c6daacb9 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -127,7 +127,7 @@ func TestExecuteAndPollFailsDuringPolling(t *testing.T) { assert.Contains(t, err.Error(), "RESOURCE_EXHAUSTED") } -func TestExecuteAndPollCancelledContext(t *testing.T) { +func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { ctx, cancel := context.WithCancel(cmdio.MockDiscard(context.Background())) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -136,6 +136,11 @@ func TestExecuteAndPollCancelledContext(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil) + // CancelExecution must be called when context is cancelled (not just on signal). + mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + StatementId: "stmt-1", + }).Return(nil).Once() + cancel() _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") From 1065b2175c4b117705b1360aff332fa977c95d58 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 11:51:16 +0100 Subject: [PATCH 3/6] Print "Query cancelled." to stderr, add debug logging for non-interactive polling - On cancellation, print "Query cancelled." via cmdio.LogString (stderr) and return root.ErrAlreadyPrinted so the root handler doesn't double-print with "Error:" prefix. - Add log.Debugf on each poll iteration so non-interactive mode has a status trace (spinner is silent when stderr is not a TTY). Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 7 +++++-- experimental/aitools/cmd/query_test.go | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index ba1215ee0e..7dbda528ad 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -157,18 +157,21 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa select { case <-pollCtx.Done(): cancelStatement() - return nil, errors.New("query cancelled") + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted case <-time.After(interval): } elapsed := time.Since(start).Truncate(time.Second) sp.Update(fmt.Sprintf("Executing query... (%s elapsed)", elapsed)) + log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, elapsed) pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) if err != nil { if pollCtx.Err() != nil { cancelStatement() - return nil, errors.New("query cancelled") + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted } return nil, fmt.Errorf("poll statement status: %w", err) } diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 22c6daacb9..265b244a7c 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -5,9 +5,10 @@ import ( "testing" "time" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" "github.com/databricks/databricks-sdk-go/service/sql" - "github.com/databricks/cli/libs/cmdio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -144,8 +145,7 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { cancel() _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1") - require.Error(t, err) - assert.Contains(t, err.Error(), "cancelled") + require.ErrorIs(t, err, root.ErrAlreadyPrinted) } func TestResolveWarehouseIDWithFlag(t *testing.T) { From df30e88be0832e3149a457e237b4520bb0e23111 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 12:03:43 +0100 Subject: [PATCH 4/6] Update elapsed time in spinner every second via ticker The elapsed time text was only updating on each poll cycle, which with additive backoff + API latency (up to 5s per call) meant gaps of 6-13s between updates. Now a background ticker updates the spinner text every second, independent of the poll interval. Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7dbda528ad..ef0c4bd679 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -145,12 +145,26 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } } - // Spinner for interactive feedback. + // Spinner for interactive feedback, updated every second via ticker. sp := cmdio.NewSpinner(pollCtx) defer sp.Close() start := time.Now() sp.Update("Executing query...") + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + go func() { + for { + select { + case <-pollCtx.Done(): + return + case <-ticker.C: + elapsed := time.Since(start).Truncate(time.Second) + sp.Update(fmt.Sprintf("Executing query... (%s elapsed)", elapsed)) + } + } + }() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). interval := pollIntervalInitial for { @@ -162,9 +176,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa case <-time.After(interval): } - elapsed := time.Since(start).Truncate(time.Second) - sp.Update(fmt.Sprintf("Executing query... (%s elapsed)", elapsed)) - log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, elapsed) + log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) if err != nil { From db26b61065cbe61ec37e50d1ddb1960a196574a3 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 12:27:50 +0100 Subject: [PATCH 5/6] Add hint for UNRESOLVED_MAP_KEY errors about shell quote stripping When Databricks returns an UNRESOLVED_MAP_KEY error, append a hint suggesting the user switch to single quotes for map keys or use --file to avoid shell quote stripping. This is a common issue when passing SQL with map access like info["key"] inside a shell double-quoted string. Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 5 +++++ experimental/aitools/cmd/query_test.go | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index ef0c4bd679..27b4b3df81 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -228,6 +228,10 @@ func checkFailedState(status *sql.StatementStatus) error { msg := "query failed" if status.Error != nil { msg = fmt.Sprintf("query failed: %s %s", status.Error.ErrorCode, status.Error.Message) + if strings.Contains(status.Error.Message, "UNRESOLVED_MAP_KEY") { + msg += "\n\nHint: your shell may have stripped quotes from the SQL string. " + + "Use single quotes for map keys (e.g. info['key']) or pass the query via --file." + } } return errors.New(msg) case sql.StatementStateCanceled: @@ -254,6 +258,7 @@ func cleanSQL(s string) string { } lines = append(lines, line) } + return strings.Join(lines, "\n") } diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 265b244a7c..b945e88606 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -227,6 +227,20 @@ func TestCheckFailedState(t *testing.T) { assert.Contains(t, err.Error(), "closed") } +func TestCheckFailedStateMapKeyHint(t *testing.T) { + err := checkFailedState(&sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "BAD_REQUEST", + Message: "[UNRESOLVED_MAP_KEY.WITH_SUGGESTION] Cannot resolve column", + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "Hint:") + assert.Contains(t, err.Error(), "single quotes") + assert.Contains(t, err.Error(), "--file") +} + func TestPollingConstants(t *testing.T) { assert.Equal(t, 1*time.Second, pollIntervalInitial) assert.Equal(t, 5*time.Second, pollIntervalMax) From 03eedc38ff23d1c599ac61387a069a8604f4a330 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 27 Feb 2026 12:59:58 +0100 Subject: [PATCH 6/6] Add exhaustive switch cases for StatementState to satisfy linter Co-Authored-By: Claude Opus 4.6 (1M context) --- experimental/aitools/cmd/query.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 27b4b3df81..f57d1fd6f4 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -214,6 +214,8 @@ func isTerminalState(status *sql.StatementStatus) bool { case sql.StatementStateSucceeded, sql.StatementStateFailed, sql.StatementStateCanceled, sql.StatementStateClosed: return true + case sql.StatementStatePending, sql.StatementStateRunning: + return false } return false } @@ -238,6 +240,8 @@ func checkFailedState(status *sql.StatementStatus) error { return errors.New("query was cancelled") case sql.StatementStateClosed: return errors.New("query was closed before results could be fetched") + case sql.StatementStatePending, sql.StatementStateRunning, sql.StatementStateSucceeded: + return nil } return nil }