diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go
index 1e502a542199..629989fef176 100644
--- a/pkg/mcs/scheduling/server/cluster.go
+++ b/pkg/mcs/scheduling/server/cluster.go
@@ -568,3 +568,28 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo) error {
 func (c *Cluster) IsPrepared() bool {
     return c.coordinator.GetPrepareChecker().IsPrepared()
 }
+<<<<<<< HEAD
+=======
+
+// SetPrepared set the prepare check to prepared. Only for test purpose.
+func (c *Cluster) SetPrepared() {
+    c.coordinator.GetPrepareChecker().SetPrepared()
+}
+
+// DropCacheAllRegion removes all cached regions.
+func (c *Cluster) DropCacheAllRegion() {
+    c.ResetRegionCache()
+}
+
+// DropCacheRegion removes a region from the cache.
+func (c *Cluster) DropCacheRegion(id uint64) {
+    c.RemoveRegionIfExist(id)
+}
+
+// IsSchedulingHalted returns whether the scheduling is halted.
+// Currently, the microservice scheduling is halted when:
+// - The `HaltScheduling` persist option is set to true.
+func (c *Cluster) IsSchedulingHalted() bool {
+    return c.persistConfig.IsSchedulingHalted()
+}
+>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147))
diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go
index d462a9a58b53..72e46cb341fd 100644
--- a/pkg/mcs/scheduling/server/config/config.go
+++ b/pkg/mcs/scheduling/server/config/config.go
@@ -603,6 +603,10 @@ func (o *PersistConfig) SetSplitMergeInterval(splitMergeInterval time.Duration)
     o.SetScheduleConfig(v)
 }
 
+// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt.
+// TODO: support this metrics for the scheduling service in the future.
+func (*PersistConfig) SetSchedulingAllowanceStatus(bool, string) {}
+
 // SetHaltScheduling set HaltScheduling.
 func (o *PersistConfig) SetHaltScheduling(halt bool, source string) {
     v := o.GetScheduleConfig().Clone()
diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go
index 204cd8ffe0c9..45500f113a61 100644
--- a/pkg/mcs/scheduling/server/grpc_service.go
+++ b/pkg/mcs/scheduling/server/grpc_service.go
@@ -191,6 +191,150 @@ func (s *Service) StoreHeartbeat(ctx context.Context, request *schedulingpb.Stor
     return &schedulingpb.StoreHeartbeatResponse{Header: &schedulingpb.ResponseHeader{ClusterId: s.clusterID}}, nil
 }
 
+<<<<<<< HEAD
+=======
+// SplitRegions split regions by the given split keys
+func (s *Service) SplitRegions(ctx context.Context, request *schedulingpb.SplitRegionsRequest) (*schedulingpb.SplitRegionsResponse, error) {
+    c := s.GetCluster()
+    if c == nil {
+        return &schedulingpb.SplitRegionsResponse{Header: s.notBootstrappedHeader()}, nil
+    }
+    finishedPercentage, newRegionIDs := c.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit()))
+    return &schedulingpb.SplitRegionsResponse{
+        Header:             s.header(),
+        RegionsId:          newRegionIDs,
+        FinishedPercentage: uint64(finishedPercentage),
+    }, nil
+}
+
+// ScatterRegions implements gRPC SchedulingServer.
+func (s *Service) ScatterRegions(_ context.Context, request *schedulingpb.ScatterRegionsRequest) (*schedulingpb.ScatterRegionsResponse, error) {
+    c := s.GetCluster()
+    if c == nil {
+        return &schedulingpb.ScatterRegionsResponse{Header: s.notBootstrappedHeader()}, nil
+    }
+
+    opsCount, failures, err := c.GetRegionScatterer().ScatterRegionsByID(request.GetRegionsId(), request.GetGroup(), int(request.GetRetryLimit()), request.GetSkipStoreLimit())
+    if err != nil {
+        header := s.errorHeader(&schedulingpb.Error{
+            Type:    schedulingpb.ErrorType_UNKNOWN,
+            Message: err.Error(),
+        })
+        return &schedulingpb.ScatterRegionsResponse{Header: header}, nil
+    }
+    percentage := 100
+    if len(failures) > 0 {
+        percentage = 100 - 100*len(failures)/(opsCount+len(failures))
+        log.Debug("scatter regions", zap.Errors("failures", func() []error {
+            r := make([]error, 0, len(failures))
+            for _, err := range failures {
+                r = append(r, err)
+            }
+            return r
+        }()))
+    }
+    return &schedulingpb.ScatterRegionsResponse{
+        Header:             s.header(),
+        FinishedPercentage: uint64(percentage),
+    }, nil
+}
+
+// GetOperator gets information about the operator belonging to the specify region.
+func (s *Service) GetOperator(_ context.Context, request *schedulingpb.GetOperatorRequest) (*schedulingpb.GetOperatorResponse, error) {
+    c := s.GetCluster()
+    if c == nil {
+        return &schedulingpb.GetOperatorResponse{Header: s.notBootstrappedHeader()}, nil
+    }
+
+    opController := c.GetCoordinator().GetOperatorController()
+    requestID := request.GetRegionId()
+    r := opController.GetOperatorStatus(requestID)
+    if r == nil {
+        header := s.errorHeader(&schedulingpb.Error{
+            Type:    schedulingpb.ErrorType_UNKNOWN,
+            Message: "region not found",
+        })
+        return &schedulingpb.GetOperatorResponse{Header: header}, nil
+    }
+
+    return &schedulingpb.GetOperatorResponse{
+        Header:   s.header(),
+        RegionId: requestID,
+        Desc:     []byte(r.Desc()),
+        Kind:     []byte(r.Kind().String()),
+        Status:   r.Status,
+    }, nil
+}
+
+// AskBatchSplit implements gRPC SchedulingServer.
+func (s *Service) AskBatchSplit(_ context.Context, request *schedulingpb.AskBatchSplitRequest) (*schedulingpb.AskBatchSplitResponse, error) {
+    c := s.GetCluster()
+    if c == nil {
+        return &schedulingpb.AskBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil
+    }
+
+    if request.GetRegion() == nil {
+        return &schedulingpb.AskBatchSplitResponse{
+            Header: s.wrapErrorToHeader(schedulingpb.ErrorType_UNKNOWN,
+                "missing region for split"),
+        }, nil
+    }
+
+    if c.IsSchedulingHalted() {
+        return nil, errs.ErrSchedulingIsHalted.FastGenByArgs()
+    }
+    if !c.persistConfig.IsTikvRegionSplitEnabled() {
+        return nil, errs.ErrSchedulerTiKVSplitDisabled.FastGenByArgs()
+    }
+    reqRegion := request.GetRegion()
+    splitCount := request.GetSplitCount()
+    err := c.ValidRegion(reqRegion)
+    if err != nil {
+        return nil, err
+    }
+    splitIDs := make([]*pdpb.SplitID, 0, splitCount)
+    recordRegions := make([]uint64, 0, splitCount+1)
+
+    for i := 0; i < int(splitCount); i++ {
+        newRegionID, err := c.AllocID()
+        if err != nil {
+            return nil, errs.ErrSchedulerNotFound.FastGenByArgs()
+        }
+
+        peerIDs := make([]uint64, len(request.Region.Peers))
+        for i := 0; i < len(peerIDs); i++ {
+            if peerIDs[i], err = c.AllocID(); err != nil {
+                return nil, err
+            }
+        }
+
+        recordRegions = append(recordRegions, newRegionID)
+        splitIDs = append(splitIDs, &pdpb.SplitID{
+            NewRegionId: newRegionID,
+            NewPeerIds:  peerIDs,
+        })
+
+        log.Info("alloc ids for region split", zap.Uint64("region-id", newRegionID), zap.Uint64s("peer-ids", peerIDs))
+    }
+
+    recordRegions = append(recordRegions, reqRegion.GetId())
+    if versioninfo.IsFeatureSupported(c.persistConfig.GetClusterVersion(), versioninfo.RegionMerge) {
+        // Disable merge the regions in a period of time.
+        c.GetCoordinator().GetMergeChecker().RecordRegionSplit(recordRegions)
+    }
+
+    // If region splits during the scheduling process, regions with abnormal
+    // status may be left, and these regions need to be checked with higher
+    // priority.
+    c.GetCoordinator().GetCheckerController().AddSuspectRegions(recordRegions...)
+
+    return &schedulingpb.AskBatchSplitResponse{
+        Header: s.header(),
+        Ids:    splitIDs,
+    }, nil
+}
+
+>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147))
 // RegisterGRPCService registers the service to gRPC server.
 func (s *Service) RegisterGRPCService(g *grpc.Server) {
     schedulingpb.RegisterSchedulingServer(g, s)
diff --git a/pkg/schedule/config/config_provider.go b/pkg/schedule/config/config_provider.go
index 20c7f0dc2cf1..90e489f86f3e 100644
--- a/pkg/schedule/config/config_provider.go
+++ b/pkg/schedule/config/config_provider.go
@@ -46,7 +46,7 @@ func IsSchedulerRegistered(name string) bool {
 type SchedulerConfigProvider interface {
     SharedConfigProvider
 
-    IsSchedulingHalted() bool
+    SetSchedulingAllowanceStatus(bool, string)
     GetStoresLimit() map[uint64]StoreLimitConfig
     IsSchedulerDisabled(string) bool
 
diff --git a/pkg/schedule/coordinator.go b/pkg/schedule/coordinator.go
index fcdd8c9a32c8..7d4b839385ab 100644
--- a/pkg/schedule/coordinator.go
+++ b/pkg/schedule/coordinator.go
@@ -178,7 +178,7 @@ func (c *Coordinator) PatrolRegions() {
             log.Info("patrol regions has been stopped")
             return
         }
-        if c.isSchedulingHalted() {
+        if c.cluster.IsSchedulingHalted() {
             continue
         }
 
@@ -207,10 +207,6 @@ func (c *Coordinator) PatrolRegions() {
     }
 }
 
-func (c *Coordinator) isSchedulingHalted() bool {
-    return c.cluster.GetSchedulerConfig().IsSchedulingHalted()
-}
-
 func (c *Coordinator) checkRegions(startKey []byte) (key []byte, regions []*core.RegionInfo) {
     regions = c.cluster.ScanRegions(startKey, nil, patrolScanRegionLimit)
     if len(regions) == 0 {
diff --git a/pkg/schedule/core/cluster_informer.go b/pkg/schedule/core/cluster_informer.go
index 63dacd0c30dd..b97459d26ea6 100644
--- a/pkg/schedule/core/cluster_informer.go
+++ b/pkg/schedule/core/cluster_informer.go
@@ -43,6 +43,7 @@ type SchedulerCluster interface {
     GetSchedulerConfig() sc.SchedulerConfigProvider
     GetRegionLabeler() *labeler.RegionLabeler
     GetStoreConfig() sc.StoreConfigProvider
+    IsSchedulingHalted() bool
 }
 
 // CheckerCluster is an aggregate interface that wraps multiple interfaces
diff --git a/pkg/schedule/schedulers/scheduler_controller.go b/pkg/schedule/schedulers/scheduler_controller.go
index 25a2c8b2afe4..28bb5c96c030 100644
--- a/pkg/schedule/schedulers/scheduler_controller.go
+++ b/pkg/schedule/schedulers/scheduler_controller.go
@@ -114,7 +114,7 @@ func (c *Controller) CollectSchedulerMetrics() {
         var allowScheduler float64
         // If the scheduler is not allowed to schedule, it will disappear in Grafana panel.
         // See issue #1341.
-        if !s.IsPaused() && !c.isSchedulingHalted() {
+        if !s.IsPaused() && !c.cluster.IsSchedulingHalted() {
             allowScheduler = 1
         }
         schedulerStatusGauge.WithLabelValues(s.Scheduler.GetName(), "allow").Set(allowScheduler)
@@ -130,10 +130,6 @@ func (c *Controller) CollectSchedulerMetrics() {
     ruleStatusGauge.WithLabelValues("group_count").Set(float64(groupCnt))
 }
 
-func (c *Controller) isSchedulingHalted() bool {
-    return c.cluster.GetSchedulerConfig().IsSchedulingHalted()
-}
-
 // ResetSchedulerMetrics resets metrics of all schedulers.
 func ResetSchedulerMetrics() {
     schedulerStatusGauge.Reset()
@@ -518,7 +514,7 @@ func (s *ScheduleController) AllowSchedule(diagnosable bool) bool {
         }
         return false
     }
-    if s.isSchedulingHalted() {
+    if s.cluster.IsSchedulingHalted() {
         if diagnosable {
             s.diagnosticRecorder.SetResultFromStatus(Halted)
         }
@@ -533,10 +529,6 @@ func (s *ScheduleController) AllowSchedule(diagnosable bool) bool {
     return true
 }
 
-func (s *ScheduleController) isSchedulingHalted() bool {
-    return s.cluster.GetSchedulerConfig().IsSchedulingHalted()
-}
-
 // IsPaused returns if a scheduler is paused.
 func (s *ScheduleController) IsPaused() bool {
     delayUntil := atomic.LoadInt64(&s.delayUntil)
diff --git a/pkg/unsaferecovery/unsafe_recovery_controller.go b/pkg/unsaferecovery/unsafe_recovery_controller.go
index aa45ba6a2bdc..3ce299ca6e01 100644
--- a/pkg/unsaferecovery/unsafe_recovery_controller.go
+++ b/pkg/unsaferecovery/unsafe_recovery_controller.go
@@ -492,12 +492,11 @@ func (u *Controller) GetStage() stage {
 }
 
 func (u *Controller) changeStage(stage stage) {
-    u.stage = stage
-    // Halt and resume the scheduling once the running state changed.
-    running := isRunning(stage)
-    if opt := u.cluster.GetSchedulerConfig(); opt.IsSchedulingHalted() != running {
-        opt.SetHaltScheduling(running, "online-unsafe-recovery")
+    // If the running stage changes, update the scheduling allowance status to add or remove "online-unsafe-recovery" halt.
+    if running := isRunning(stage); running != isRunning(u.stage) {
+        u.cluster.GetSchedulerConfig().SetSchedulingAllowanceStatus(running, "online-unsafe-recovery")
     }
+    u.stage = stage
 
     var output StageOutput
     output.Time = time.Now().Format("2006-01-02 15:04:05.000")
diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go
index d6454f48ee36..5514be003ecc 100644
--- a/server/cluster/cluster.go
+++ b/server/cluster/cluster.go
@@ -890,6 +890,7 @@ func (c *RaftCluster) SetPDServerConfig(cfg *config.PDServerConfig) {
     c.opt.SetPDServerConfig(cfg)
 }
 
+<<<<<<< HEAD
 // AddSuspectRegions adds regions to suspect list.
 func (c *RaftCluster) AddSuspectRegions(regionIDs ...uint64) {
     c.coordinator.GetCheckerController().AddSuspectRegions(regionIDs...)
@@ -918,6 +919,14 @@ func (c *RaftCluster) GetLabelStats() *statistics.LabelStatistics {
 // RemoveSuspectRegion removes region from suspect list.
 func (c *RaftCluster) RemoveSuspectRegion(id uint64) {
     c.coordinator.GetCheckerController().RemoveSuspectRegion(id)
+=======
+// IsSchedulingHalted returns whether the scheduling is halted.
+// Currently, the PD scheduling is halted when:
+// - The `HaltScheduling` persist option is set to true.
+// - Online unsafe recovery is running.
+func (c *RaftCluster) IsSchedulingHalted() bool {
+    return c.opt.IsSchedulingHalted() || c.unsafeRecoveryController.IsRunning()
+>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147))
 }
 
 // GetUnsafeRecoveryController returns the unsafe recovery controller.
diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go
index c1da97363b53..fd1501aba725 100644
--- a/server/cluster/cluster_worker.go
+++ b/server/cluster/cluster_worker.go
@@ -43,7 +43,7 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error {
 
 // HandleAskSplit handles the split request.
 func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) {
-    if c.isSchedulingHalted() {
+    if c.IsSchedulingHalted() {
         return nil, errs.ErrSchedulingIsHalted.FastGenByArgs()
     }
     if !c.opt.IsTikvRegionSplitEnabled() {
@@ -86,6 +86,7 @@ func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSp
     return split, nil
 }
 
+<<<<<<< HEAD
 func (c *RaftCluster) isSchedulingHalted() bool {
     return c.opt.IsSchedulingHalted()
 }
@@ -107,9 +108,11 @@ func (c *RaftCluster) ValidRequestRegion(reqRegion *metapb.Region) error {
     return nil
 }
 
+=======
+>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147))
 // HandleAskBatchSplit handles the batch split request.
 func (c *RaftCluster) HandleAskBatchSplit(request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) {
-    if c.isSchedulingHalted() {
+    if c.IsSchedulingHalted() {
         return nil, errs.ErrSchedulingIsHalted.FastGenByArgs()
     }
     if !c.opt.IsTikvRegionSplitEnabled() {
diff --git a/server/config/persist_options.go b/server/config/persist_options.go
index c0a0ebf5c47a..032e3736a98d 100644
--- a/server/config/persist_options.go
+++ b/server/config/persist_options.go
@@ -973,11 +973,8 @@ func (o *PersistOptions) SetAllStoresLimitTTL(ctx context.Context, client *clien
 
 var haltSchedulingStatus = schedulingAllowanceStatusGauge.WithLabelValues("halt-scheduling")
 
-// SetHaltScheduling set HaltScheduling.
-func (o *PersistOptions) SetHaltScheduling(halt bool, source string) {
-    v := o.GetScheduleConfig().Clone()
-    v.HaltScheduling = halt
-    o.SetScheduleConfig(v)
+// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt.
+func (*PersistOptions) SetSchedulingAllowanceStatus(halt bool, source string) {
     if halt {
         haltSchedulingStatus.Set(1)
         schedulingAllowanceStatusGauge.WithLabelValues(source).Set(1)
@@ -987,6 +984,14 @@ func (o *PersistOptions) SetHaltScheduling(halt bool, source string) {
     }
 }
 
+// SetHaltScheduling set HaltScheduling.
+func (o *PersistOptions) SetHaltScheduling(halt bool, source string) {
+    v := o.GetScheduleConfig().Clone()
+    v.HaltScheduling = halt
+    o.SetScheduleConfig(v)
+    o.SetSchedulingAllowanceStatus(halt, source)
+}
+
 // IsSchedulingHalted returns if PD scheduling is halted.
 func (o *PersistOptions) IsSchedulingHalted() bool {
     if o == nil {
diff --git a/server/forward.go b/server/forward.go
new file mode 100644
index 000000000000..650833e1fc17
--- /dev/null
+++ b/server/forward.go
@@ -0,0 +1,505 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package server
+
+import (
+    "context"
+    "io"
+    "strings"
+    "time"
+
+    "github.com/pingcap/errors"
+    "github.com/pingcap/failpoint"
+    "github.com/pingcap/kvproto/pkg/pdpb"
+    "github.com/pingcap/kvproto/pkg/schedulingpb"
+    "github.com/pingcap/kvproto/pkg/tsopb"
+    "github.com/pingcap/log"
+    "github.com/tikv/pd/pkg/errs"
+    "github.com/tikv/pd/pkg/mcs/utils"
+    "github.com/tikv/pd/pkg/tso"
+    "github.com/tikv/pd/pkg/utils/grpcutil"
+    "github.com/tikv/pd/pkg/utils/logutil"
+    "github.com/tikv/pd/pkg/utils/tsoutil"
+    "github.com/tikv/pd/server/cluster"
+    "go.uber.org/zap"
+    "google.golang.org/grpc"
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
+)
+
+func forwardTSORequest(
+    ctx context.Context,
+    request *pdpb.TsoRequest,
+    forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) {
+    tsopbReq := &tsopb.TsoRequest{
+        Header: &tsopb.RequestHeader{
+            ClusterId:       request.GetHeader().GetClusterId(),
+            SenderId:        request.GetHeader().GetSenderId(),
+            KeyspaceId:      utils.DefaultKeyspaceID,
+            KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
+        },
+        Count:      request.GetCount(),
+        DcLocation: request.GetDcLocation(),
+    }
+
+    failpoint.Inject("tsoProxySendToTSOTimeout", func() {
+        // block until watchDeadline routine cancels the context.
+        <-ctx.Done()
+    })
+
+    select {
+    case <-ctx.Done():
+        return nil, ctx.Err()
+    default:
+    }
+
+    if err := forwardStream.Send(tsopbReq); err != nil {
+        return nil, err
+    }
+
+    failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() {
+        // block until watchDeadline routine cancels the context.
+        <-ctx.Done()
+    })
+
+    select {
+    case <-ctx.Done():
+        return nil, ctx.Err()
+    default:
+    }
+
+    return forwardStream.Recv()
+}
+
+// forwardTSO forward the TSO requests to the TSO service.
+func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
+    var (
+        server            = &tsoServer{stream: stream}
+        forwardStream     tsopb.TSO_TsoClient
+        forwardCtx        context.Context
+        cancelForward     context.CancelFunc
+        tsoStreamErr      error
+        lastForwardedHost string
+    )
+    defer func() {
+        s.concurrentTSOProxyStreamings.Add(-1)
+        if cancelForward != nil {
+            cancelForward()
+        }
+        if grpcutil.NeedRebuildConnection(tsoStreamErr) {
+            s.closeDelegateClient(lastForwardedHost)
+        }
+    }()
+
+    maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings())
+    if maxConcurrentTSOProxyStreamings >= 0 {
+        if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings {
+            return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded)
+        }
+    }
+
+    tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1)
+    go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh)
+
+    for {
+        select {
+        case <-s.ctx.Done():
+            return errors.WithStack(s.ctx.Err())
+        case <-stream.Context().Done():
+            return stream.Context().Err()
+        default:
+        }
+
+        request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout())
+        if err == io.EOF {
+            return nil
+        }
+        if err != nil {
+            return errors.WithStack(err)
+        }
+        if request.GetCount() == 0 {
+            err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
+            return status.Errorf(codes.Unknown, err.Error())
+        }
+
+        forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName)
+        if !ok || len(forwardedHost) == 0 {
+            tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr)
+            return tsoStreamErr
+        }
+        if forwardStream == nil || lastForwardedHost != forwardedHost {
+            if cancelForward != nil {
+                cancelForward()
+            }
+
+            clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
+            if err != nil {
+                tsoStreamErr = errors.WithStack(err)
+                return tsoStreamErr
+            }
+            forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn)
+            if err != nil {
+                tsoStreamErr = errors.WithStack(err)
+                return tsoStreamErr
+            }
+            lastForwardedHost = forwardedHost
+        }
+
+        tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
+        if err != nil {
+            tsoStreamErr = errors.WithStack(err)
+            return tsoStreamErr
+        }
+
+        // The error types defined for tsopb and pdpb are different, so we need to convert them.
+        var pdpbErr *pdpb.Error
+        tsopbErr := tsopbResp.GetHeader().GetError()
+        if tsopbErr != nil {
+            if tsopbErr.Type == tsopb.ErrorType_OK {
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_OK,
+                    Message: tsopbErr.GetMessage(),
+                }
+            } else {
+                // TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_UNKNOWN,
+                    Message: tsopbErr.GetMessage(),
+                }
+            }
+        }
+
+        response := &pdpb.TsoResponse{
+            Header: &pdpb.ResponseHeader{
+                ClusterId: tsopbResp.GetHeader().GetClusterId(),
+                Error:     pdpbErr,
+            },
+            Count:     tsopbResp.GetCount(),
+            Timestamp: tsopbResp.GetTimestamp(),
+        }
+        if err := server.Send(response); err != nil {
+            return errors.WithStack(err)
+        }
+    }
+}
+
+func (s *GrpcServer) forwardTSORequestWithDeadLine(
+    forwardCtx context.Context,
+    cancelForward context.CancelFunc,
+    forwardStream tsopb.TSO_TsoClient,
+    request *pdpb.TsoRequest,
+    tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) {
+    done := make(chan struct{})
+    dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward)
+    select {
+    case tsDeadlineCh <- dl:
+    case <-forwardCtx.Done():
+        return nil, forwardCtx.Err()
+    }
+
+    start := time.Now()
+    resp, err := forwardTSORequest(forwardCtx, request, forwardStream)
+    close(done)
+    if err != nil {
+        if strings.Contains(err.Error(), errs.NotLeaderErr) {
+            s.tsoPrimaryWatcher.ForceLoad()
+        }
+        return nil, err
+    }
+    tsoProxyBatchSize.Observe(float64(request.GetCount()))
+    tsoProxyHandleDuration.Observe(time.Since(start).Seconds())
+    return resp, nil
+}
+
+func createTSOForwardStream(ctx context.Context, client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) {
+    done := make(chan struct{})
+    forwardCtx, cancelForward := context.WithCancel(ctx)
+    go grpcutil.CheckStream(forwardCtx, cancelForward, done)
+    forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx)
+    done <- struct{}{}
+    return forwardStream, forwardCtx, cancelForward, err
+}
+
+func (s *GrpcServer) createRegionHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) {
+    done := make(chan struct{})
+    ctx, cancel := context.WithCancel(s.ctx)
+    go grpcutil.CheckStream(ctx, cancel, done)
+    forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx)
+    done <- struct{}{}
+    return forwardStream, cancel, err
+}
+
+func createRegionHeartbeatSchedulingStream(ctx context.Context, client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.Context, context.CancelFunc, error) {
+    done := make(chan struct{})
+    forwardCtx, cancelForward := context.WithCancel(ctx)
+    go grpcutil.CheckStream(forwardCtx, cancelForward, done)
+    forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(forwardCtx)
+    done <- struct{}{}
+    return forwardStream, forwardCtx, cancelForward, err
+}
+
+func forwardRegionHeartbeatToScheduling(rc *cluster.RaftCluster, forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) {
+    defer logutil.LogPanic()
+    defer close(errCh)
+    for {
+        resp, err := forwardStream.Recv()
+        if err == io.EOF {
+            errCh <- errors.WithStack(err)
+            return
+        }
+        if err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+        // TODO: find a better way to halt scheduling immediately.
+        if rc.IsSchedulingHalted() {
+            continue
+        }
+        // The error types defined for schedulingpb and pdpb are different, so we need to convert them.
+        var pdpbErr *pdpb.Error
+        schedulingpbErr := resp.GetHeader().GetError()
+        if schedulingpbErr != nil {
+            if schedulingpbErr.Type == schedulingpb.ErrorType_OK {
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_OK,
+                    Message: schedulingpbErr.GetMessage(),
+                }
+            } else {
+                // TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
+                pdpbErr = &pdpb.Error{
+                    Type:    pdpb.ErrorType_UNKNOWN,
+                    Message: schedulingpbErr.GetMessage(),
+                }
+            }
+        }
+        response := &pdpb.RegionHeartbeatResponse{
+            Header: &pdpb.ResponseHeader{
+                ClusterId: resp.GetHeader().GetClusterId(),
+                Error:     pdpbErr,
+            },
+            ChangePeer:      resp.GetChangePeer(),
+            TransferLeader:  resp.GetTransferLeader(),
+            RegionId:        resp.GetRegionId(),
+            RegionEpoch:     resp.GetRegionEpoch(),
+            TargetPeer:      resp.GetTargetPeer(),
+            Merge:           resp.GetMerge(),
+            SplitRegion:     resp.GetSplitRegion(),
+            ChangePeerV2:    resp.GetChangePeerV2(),
+            SwitchWitnesses: resp.GetSwitchWitnesses(),
+        }
+
+        if err := server.Send(response); err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+    }
+}
+
+func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) {
+    defer logutil.LogPanic()
+    defer close(errCh)
+    for {
+        resp, err := forwardStream.Recv()
+        if err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+        if err := server.Send(resp); err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+    }
+}
+
+func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) {
+    defer logutil.LogPanic()
+    defer close(errCh)
+    for {
+        resp, err := forwardStream.CloseAndRecv()
+        if err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+        if err := server.Send(resp); err != nil {
+            errCh <- errors.WithStack(err)
+            return
+        }
+    }
+}
+
+func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) {
+    done := make(chan struct{})
+    ctx, cancel := context.WithCancel(s.ctx)
+    go grpcutil.CheckStream(ctx, cancel, done)
+    forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx)
+    done <- struct{}{}
+    return forwardStream, cancel, err
+}
+
+func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) {
+    client, ok := s.clientConns.Load(forwardedHost)
+    if ok {
+        // Mostly, the connection is already established, and return it directly.
+        return client.(*grpc.ClientConn), nil
+    }
+
+    tlsConfig, err := s.GetTLSConfig().ToTLSConfig()
+    if err != nil {
+        return nil, err
+    }
+    ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout)
+    defer cancel()
+    newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig)
+    if err != nil {
+        return nil, err
+    }
+    conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn)
+    if !loaded {
+        // Successfully stored the connection we created.
+        return newConn, nil
+    }
+    // Loaded a connection created/stored by another goroutine, so close the one we created
+    // and return the one we loaded.
+    newConn.Close()
+    return conn.(*grpc.ClientConn), nil
+}
+
+func (s *GrpcServer) closeDelegateClient(forwardedHost string) {
+    client, ok := s.clientConns.LoadAndDelete(forwardedHost)
+    if !ok {
+        return
+    }
+    client.(*grpc.ClientConn).Close()
+    log.Debug("close delegate client connection", zap.String("forwarded-host", forwardedHost))
+}
+
+func (s *GrpcServer) isLocalRequest(host string) bool {
+    failpoint.Inject("useForwardRequest", func() {
+        failpoint.Return(false)
+    })
+    if host == "" {
+        return true
+    }
+    memberAddrs := s.GetMember().Member().GetClientUrls()
+    for _, addr := range memberAddrs {
+        if addr == host {
+            return true
+        }
+    }
+    return false
+}
+
+func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) {
+    if !s.IsAPIServiceMode() {
+        return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1)
+    }
+    request := &tsopb.TsoRequest{
+        Header: &tsopb.RequestHeader{
+            ClusterId:       s.ClusterID(),
+            KeyspaceId:      utils.DefaultKeyspaceID,
+            KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
+        },
+        Count: 1,
+    }
+    var (
+        forwardedHost string
+        forwardStream tsopb.TSO_TsoClient
+        ts            *tsopb.TsoResponse
+        err           error
+        ok            bool
+    )
+    handleStreamError := func(err error) (needRetry bool) {
+        if strings.Contains(err.Error(), errs.NotLeaderErr) {
+            s.tsoPrimaryWatcher.ForceLoad()
+            log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+            return true
+        }
+        if grpcutil.NeedRebuildConnection(err) {
+            s.tsoClientPool.Lock()
+            delete(s.tsoClientPool.clients, forwardedHost)
+            s.tsoClientPool.Unlock()
+            log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost))
+            return true
+        }
+        return false
+    }
+    for i := 0; i < maxRetryTimesRequestTSOServer; i++ {
+        if i > 0 {
+            time.Sleep(retryIntervalRequestTSOServer)
+        }
+        forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName)
+        if !ok || forwardedHost == "" {
+            return pdpb.Timestamp{}, ErrNotFoundTSOAddr
+        }
+        forwardStream, err = s.getTSOForwardStream(forwardedHost)
+        if err != nil {
+            return pdpb.Timestamp{}, err
+        }
+        err = forwardStream.Send(request)
+        if err != nil {
+            if needRetry := handleStreamError(err); needRetry {
+                continue
+            }
+            log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+            return pdpb.Timestamp{}, err
+        }
+        ts, err = forwardStream.Recv()
+        if err != nil {
+            if needRetry := handleStreamError(err); needRetry {
+                continue
+            }
+            log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+            return pdpb.Timestamp{}, err
+        }
+        return *ts.GetTimestamp(), nil
+    }
+    log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost))
+    return pdpb.Timestamp{}, err
+}
+
+func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) {
+    s.tsoClientPool.RLock()
+    forwardStream, ok := s.tsoClientPool.clients[forwardedHost]
+    s.tsoClientPool.RUnlock()
+    if ok {
+        // This is the common case to return here
+        return forwardStream, nil
+    }
+
+    s.tsoClientPool.Lock()
+    defer s.tsoClientPool.Unlock()
+
+    // Double check after entering the critical section
+    forwardStream, ok = s.tsoClientPool.clients[forwardedHost]
+    if ok {
+        return forwardStream, nil
+    }
+
+    // Now let's create the client connection and the forward stream
+    client, err := s.getDelegateClient(s.ctx, forwardedHost)
+    if err != nil {
+        return nil, err
+    }
+    done := make(chan struct{})
+    ctx, cancel := context.WithCancel(s.ctx)
+    go grpcutil.CheckStream(ctx, cancel, done)
+    forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx)
+    done <- struct{}{}
+    if err != nil {
+        return nil, err
+    }
+    s.tsoClientPool.clients[forwardedHost] = forwardStream
+    return forwardStream, nil
+}
diff --git a/server/server.go b/server/server.go
index 1d2c4516abef..a4ff1405d608 100644
--- a/server/server.go
+++ b/server/server.go
@@ -998,6 +998,7 @@ func (s *Server) GetScheduleConfig() *sc.ScheduleConfig {
 }
 
 // SetScheduleConfig sets the balance config information.
+// This function is exported to be used by the API.
 func (s *Server) SetScheduleConfig(cfg sc.ScheduleConfig) error {
     if err := cfg.Validate(); err != nil {
         return err
@@ -1016,6 +1017,8 @@ func (s *Server) SetScheduleConfig(cfg sc.ScheduleConfig) error {
         errs.ZapError(err))
         return err
     }
+    // Update the scheduling halt status at the same time.
+    s.persistOptions.SetSchedulingAllowanceStatus(cfg.HaltScheduling, "manually")
     log.Info("schedule config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old))
     return nil
 }
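
Note (not part of the patch): the sketch below is a minimal, self-contained Go illustration of the halt-check split this diff introduces. The coordinator and schedulers now ask the cluster directly through the new SchedulerCluster.IsSchedulingHalted method, and on the PD side RaftCluster answers by combining the persisted HaltScheduling option with the online unsafe recovery state, while SetSchedulingAllowanceStatus is left to update only the halt metrics. All type and field names here are simplified stand-ins, not the real PD types.

// halt_check_sketch.go - illustrative only; simplified stand-ins for PD types.
package main

import "fmt"

// schedulerCluster mirrors the method added to core.SchedulerCluster in this diff.
type schedulerCluster interface {
	IsSchedulingHalted() bool
}

// persistOptions stands in for the persisted `HaltScheduling` switch.
type persistOptions struct{ haltScheduling bool }

func (o *persistOptions) IsSchedulingHalted() bool { return o.haltScheduling }

// unsafeRecoveryController stands in for the online unsafe recovery controller.
type unsafeRecoveryController struct{ running bool }

func (u *unsafeRecoveryController) IsRunning() bool { return u.running }

// raftCluster mirrors RaftCluster.IsSchedulingHalted: scheduling is halted when
// the persist option is set OR online unsafe recovery is running.
type raftCluster struct {
	opt      *persistOptions
	recovery *unsafeRecoveryController
}

func (c *raftCluster) IsSchedulingHalted() bool {
	return c.opt.IsSchedulingHalted() || c.recovery.IsRunning()
}

// allowSchedule shows how a scheduler-side caller now consults the cluster
// instead of the scheduler config, as the coordinator and ScheduleController do.
func allowSchedule(c schedulerCluster) bool { return !c.IsSchedulingHalted() }

func main() {
	c := &raftCluster{opt: &persistOptions{}, recovery: &unsafeRecoveryController{}}
	fmt.Println(allowSchedule(c)) // true: nothing halts scheduling

	c.recovery.running = true     // unsafe recovery alone now halts scheduling
	fmt.Println(allowSchedule(c)) // false
}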