diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel
index 1201705f51f1..cb05deca9d45 100644
--- a/pkg/sql/BUILD.bazel
+++ b/pkg/sql/BUILD.bazel
@@ -876,6 +876,7 @@ go_test(
         "//pkg/sql/sessiondatapb",
         "//pkg/sql/sessionphase",
         "//pkg/sql/sqlclustersettings",
+        "//pkg/sql/sqlerrors",
         "//pkg/sql/sqlinstance",
         "//pkg/sql/sqlliveness",
         "//pkg/sql/sqlliveness/sqllivenesstestutils",
diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go
index a57100f25ff5..a031c17fabf2 100644
--- a/pkg/sql/conn_executor.go
+++ b/pkg/sql/conn_executor.go
@@ -4029,9 +4029,7 @@ func (ex *connExecutor) txnStateTransitionsApplyWrapper(
 			return advanceInfo{}, err
 		}
 		ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.SessionStartPostCommitJob, timeutil.Now())
-		if err := ex.server.cfg.JobRegistry.Run(
-			ex.ctxHolder.connCtx, ex.extraTxnState.jobs.created,
-		); err != nil {
+		if err := ex.waitForTxnJobs(); err != nil {
 			handleErr(err)
 		}
 		ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.SessionEndPostCommitJob, timeutil.Now())
@@ -4062,6 +4060,79 @@ func (ex *connExecutor) txnStateTransitionsApplyWrapper(
 	return advInfo, nil
 }
 
+// waitForTxnJobs notifies the job registry about the jobs created inside this
+// txn and waits for them to finish. When the session has a statement timeout,
+// the wait is bounded by whatever is left of it, measured from the
+// SessionQueryReceived phase time. If the timeout cuts the wait short, a
+// notice listing the jobs that keep running in the background is sent and
+// sqlerrors.QueryTimeoutError is returned. If the timeout had already elapsed
+// before the wait could start, the wait is skipped and the notice is sent, but
+// nil is returned.
+//
+// NOTE(review): nothing in this function distinguishes implicit from explicit
+// transactions; confirm with the caller which ones are meant to be covered.
+func (ex *connExecutor) waitForTxnJobs() error {
+	if len(ex.extraTxnState.jobs.created) == 0 {
+		return nil
+	}
+	ex.server.cfg.JobRegistry.NotifyToResume(
+		ex.ctxHolder.connCtx, ex.extraTxnState.jobs.created...,
+	)
+	// Set up a context for waiting for the jobs, which can be cancelled if
+	// a statement timeout exists.
+	jobWaitCtx := ex.ctxHolder.ctx()
+	var retErr error
+	var timerFired atomic.Bool
+	timedOut := false
+	if stmtTimeout := ex.sessionData().StmtTimeout; stmtTimeout > 0 {
+		timePassed := timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionQueryReceived))
+		if timePassed > stmtTimeout {
+			timedOut = true
+		} else {
+			var cancelFn context.CancelFunc
+			jobWaitCtx, cancelFn = context.WithCancel(jobWaitCtx)
+			queryTimeTicker := time.AfterFunc(stmtTimeout-timePassed, func() {
+				// Publish the flag before cancelling: the waiter below reads it
+				// as soon as it observes the cancellation, and must be able to
+				// tell a timeout apart from any other cancellation.
+				timerFired.Store(true)
+				cancelFn()
+			})
+			defer cancelFn()
+			defer queryTimeTicker.Stop()
+		}
+	}
+	if !timedOut {
+		if err := ex.server.cfg.JobRegistry.WaitForJobs(jobWaitCtx,
+			ex.extraTxnState.jobs.created); err != nil {
+			if !errors.Is(err, context.Canceled) || !timerFired.Load() {
+				return err
+			}
+			// Only an interrupted wait counts as a timeout; a timer that fires
+			// after the jobs have finished must not trigger the notice.
+			timedOut = true
+			retErr = sqlerrors.QueryTimeoutError
+		}
+	}
+	// If the query timed out indicate that there are jobs left behind.
+	if timedOut {
+		var jobList strings.Builder
+		for i, j := range ex.extraTxnState.jobs.created {
+			if i > 0 {
+				jobList.WriteString(",")
+			}
+			jobList.WriteString(j.String())
+		}
+		if err := ex.planner.noticeSender.SendNotice(ex.ctxHolder.connCtx,
+			pgnotice.Newf("The statement has timed out, but the following "+
+				"background jobs have been created and will continue running: %s.",
+				jobList.String())); err != nil {
+			return err
+		}
+	}
+	return retErr
+}
+
 func (ex *connExecutor) maybeSetSQLLivenessSession() error {
 	if !ex.server.cfg.Codec.ForSystemTenant() ||
 		ex.server.cfg.TestingKnobs.ForceSQLLivenessSession {
diff --git a/pkg/sql/run_control_test.go b/pkg/sql/run_control_test.go
index 3b2fbb6c620c..235cc1c59813 100644
--- a/pkg/sql/run_control_test.go
+++ b/pkg/sql/run_control_test.go
@@ -29,6 +29,10 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
 	"github.com/cockroachdb/cockroach/pkg/sql"
 	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descs"
+	"github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scexec"
+	"github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scop"
+	"github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scplan"
+	"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqltestutils"
 	"github.com/cockroachdb/cockroach/pkg/testutils"
 	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
@@ -40,6 +44,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/errors"
+	"github.com/lib/pq"
 	"github.com/petermattis/goid"
 	"github.com/stretchr/testify/require"
 )
@@ -968,3 +973,74 @@ func TestTenantStatementTimeoutAdmissionQueueCancellation(t *testing.T) {
 	wg.Wait()
 	require.ErrorIs(t, ctx.Err(), context.Canceled)
 }
+
+// TestStatementTimeoutForSchemaChangeCommit confirms that waiting for the job
+// phase of the schema change respects statement timeout.
+func TestStatementTimeoutForSchemaChangeCommit(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+	ctx := context.Background()
+
+	for _, implicitTxn := range []bool{true, false} {
+		t.Run(fmt.Sprintf("implicitTxn=%t", implicitTxn),
+			func(t *testing.T) {
+				numNodes := 1
+				var blockSchemaChange atomic.Bool
+				waitForTimeout := make(chan struct{})
+				tc := serverutils.StartCluster(t, numNodes,
+					base.TestClusterArgs{
+						ServerArgs: base.TestServerArgs{
+							Knobs: base.TestingKnobs{
+								SQLDeclarativeSchemaChanger: &scexec.TestingKnobs{
+									AfterStage: func(p scplan.Plan, stageIdx int) error {
+										if blockSchemaChange.Load() && p.Params.ExecutionPhase == scop.PostCommitPhase {
+											<-waitForTimeout
+										}
+										return nil
+									},
+								},
+							},
+						},
+					})
+				defer tc.Stopper().Stop(ctx)
+
+				url, cleanup := tc.ApplicationLayer(0).PGUrl(t)
+				defer cleanup()
+				baseConn, err := pq.NewConnector(url.String())
+				require.NoError(t, err)
+				var actualNotices []string
+				connector := pq.ConnectorWithNoticeHandler(baseConn,
+					func(n *pq.Error) {
+						actualNotices = append(actualNotices, n.Message)
+					})
+				dbWithHandler := gosql.OpenDB(connector)
+				defer dbWithHandler.Close()
+				conn := sqlutils.MakeSQLRunner(dbWithHandler)
+				conn.Exec(t, "CREATE TABLE t1 (n int primary key)")
+				conn.Exec(t, `SET statement_timeout = '30s'`)
+				// From here on the post-commit phase of the schema change blocks
+				// until waitForTimeout is closed. Release it in a defer so that a
+				// failed assertion below cannot leave the job stuck while the
+				// cluster is being stopped.
+				defer func() {
+					close(waitForTimeout)
+					blockSchemaChange.Store(false)
+				}()
+				blockSchemaChange.Store(true)
+				if implicitTxn {
+					_, err := conn.DB.ExecContext(ctx, "ALTER TABLE t1 ADD COLUMN j INT DEFAULT 32")
+					require.ErrorContains(t, err, sqlerrors.QueryTimeoutError.Error())
+					require.Len(t, actualNotices, 1)
+					require.Regexp(t,
+						"The statement has timed out, but the following background jobs have been created and will continue running: \\d+",
+						actualNotices[0])
+				} else {
+					txn := conn.Begin(t)
+					_, err := txn.Exec("ALTER TABLE t1 ADD COLUMN j INT DEFAULT 32")
+					require.NoError(t, err)
+					err = txn.Commit()
+					require.NoError(t, err)
+				}
+			})
+	}
+}