diff --git a/cdc/sink/dmlsink/txn/mysql/mysql.go b/cdc/sink/dmlsink/txn/mysql/mysql.go index 10e88ab792c..29897f88614 100644 --- a/cdc/sink/dmlsink/txn/mysql/mysql.go +++ b/cdc/sink/dmlsink/txn/mysql/mysql.go @@ -762,6 +762,24 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare start, s.changefeed, "BEGIN", dmls.rowCount, dmls.startTs) } + // Set session variables first and then execute the transaction. + // we try to set write source for each txn, + // so we can use it to trace the data source + if err = s.setWriteSource(pctx, tx); err != nil { + err := logDMLTxnErr( + cerror.WrapError(cerror.ErrMySQLTxnError, err), + start, s.changefeed, + fmt.Sprintf("SET SESSION %s = %d", "tidb_cdc_write_source", + s.cfg.SourceID), + dmls.rowCount, dmls.startTs) + if rbErr := tx.Rollback(); rbErr != nil { + if errors.Cause(rbErr) != context.Canceled { + log.Warn("failed to rollback txn", zap.String("changefeed", s.changefeed), zap.Error(rbErr)) + } + } + return 0, 0, err + } + // If interplated SQL size exceeds maxAllowedPacket, mysql driver will // fall back to the sequantial way. // error can be ErrPrepareMulti, ErrBadConn etc. @@ -780,23 +798,6 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare } } - // we try to set write source for each txn, - // so we can use it to trace the data source - if err = s.setWriteSource(pctx, tx); err != nil { - err := logDMLTxnErr( - cerror.WrapError(cerror.ErrMySQLTxnError, err), - start, s.changefeed, - fmt.Sprintf("SET SESSION %s = %d", "tidb_cdc_write_source", - s.cfg.SourceID), - dmls.rowCount, dmls.startTs) - if rbErr := tx.Rollback(); rbErr != nil { - if errors.Cause(rbErr) != context.Canceled { - log.Warn("failed to rollback txn", zap.String("changefeed", s.changefeed), zap.Error(rbErr)) - } - } - return 0, 0, err - } - if err = tx.Commit(); err != nil { return 0, 0, logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err),