diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go new file mode 100644 index 00000000000..c36964a059c --- /dev/null +++ b/pkg/mcs/scheduling/server/cluster.go @@ -0,0 +1,701 @@ +package server + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/schedulingpb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cluster" + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/scheduling/server/config" + "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/schedule" + sc "github.com/tikv/pd/pkg/schedule/config" + "github.com/tikv/pd/pkg/schedule/hbstream" + "github.com/tikv/pd/pkg/schedule/labeler" + "github.com/tikv/pd/pkg/schedule/operator" + "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/schedule/scatter" + "github.com/tikv/pd/pkg/schedule/schedulers" + "github.com/tikv/pd/pkg/schedule/splitter" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/statistics" + "github.com/tikv/pd/pkg/statistics/buckets" + "github.com/tikv/pd/pkg/statistics/utils" + "github.com/tikv/pd/pkg/storage" + "github.com/tikv/pd/pkg/utils/logutil" + "go.uber.org/zap" +) + +// Cluster is used to manage all information for scheduling purpose. +type Cluster struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + *core.BasicCluster + persistConfig *config.PersistConfig + ruleManager *placement.RuleManager + labelerManager *labeler.RegionLabeler + regionStats *statistics.RegionStatistics + labelStats *statistics.LabelStatistics + hotStat *statistics.HotStat + storage storage.Storage + coordinator *schedule.Coordinator + checkMembershipCh chan struct{} + apiServerLeader atomic.Value + clusterID uint64 + running atomic.Bool + + heartbeatRunnner ratelimit.Runner + logRunner ratelimit.Runner +} + +const ( + regionLabelGCInterval = time.Hour + requestTimeout = 3 * time.Second + collectWaitTime = time.Minute + + // heartbeat relative const + heartbeatTaskRunner = "heartbeat-task-runner" + logTaskRunner = "log-task-runner" +) + +var syncRunner = ratelimit.NewSyncRunner() + +// NewCluster creates a new cluster. 
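The heartbeat and log task runners declared above (and constructed in NewCluster below) are concurrency-limited runners capped at twice the number of CPUs. The `ratelimit` package internals are not part of this diff; the following is only an illustrative, stdlib-only sketch of how such a bounded runner can be built with a semaphore channel, using hypothetical names.

```go
package main

import (
	"fmt"
	"runtime"
	"sync"
)

// boundedRunner is a stand-in for a concurrency-limited task runner: it
// executes submitted tasks concurrently, but never more than `limit` at a time.
type boundedRunner struct {
	sem chan struct{}
	wg  sync.WaitGroup
}

func newBoundedRunner(limit int) *boundedRunner {
	return &boundedRunner{sem: make(chan struct{}, limit)}
}

// RunTask blocks once `limit` tasks are already in flight, then runs the task
// in its own goroutine and releases the slot when it finishes.
func (r *boundedRunner) RunTask(task func()) {
	r.wg.Add(1)
	r.sem <- struct{}{}
	go func() {
		defer r.wg.Done()
		defer func() { <-r.sem }()
		task()
	}()
}

func main() {
	// Same sizing as the heartbeat runner: 2 * NumCPU concurrent tasks at most.
	r := newBoundedRunner(runtime.NumCPU() * 2)
	for i := 0; i < 10; i++ {
		i := i
		r.RunTask(func() { fmt.Println("task", i) })
	}
	r.wg.Wait()
}
```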
+func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, storage storage.Storage, basicCluster *core.BasicCluster, hbStreams *hbstream.HeartbeatStreams, clusterID uint64, checkMembershipCh chan struct{}) (*Cluster, error) { + ctx, cancel := context.WithCancel(parentCtx) + labelerManager, err := labeler.NewRegionLabeler(ctx, storage, regionLabelGCInterval) + if err != nil { + cancel() + return nil, err + } + ruleManager := placement.NewRuleManager(ctx, storage, basicCluster, persistConfig) + c := &Cluster{ + ctx: ctx, + cancel: cancel, + BasicCluster: basicCluster, + ruleManager: ruleManager, + labelerManager: labelerManager, + persistConfig: persistConfig, + hotStat: statistics.NewHotStat(ctx), + labelStats: statistics.NewLabelStatistics(), + regionStats: statistics.NewRegionStatistics(basicCluster, persistConfig, ruleManager), + storage: storage, + clusterID: clusterID, + checkMembershipCh: checkMembershipCh, + + heartbeatRunnner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + } + c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) + err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) + if err != nil { + cancel() + return nil, err + } + return c, nil +} + +// GetCoordinator returns the coordinator +func (c *Cluster) GetCoordinator() *schedule.Coordinator { + return c.coordinator +} + +// GetHotStat gets hot stat. +func (c *Cluster) GetHotStat() *statistics.HotStat { + return c.hotStat +} + +// GetStoresStats returns stores' statistics from cluster. +// And it will be unnecessary to filter unhealthy store, because it has been solved in process heartbeat +func (c *Cluster) GetStoresStats() *statistics.StoresStats { + return c.hotStat.StoresStats +} + +// GetRegionStats gets region statistics. +func (c *Cluster) GetRegionStats() *statistics.RegionStatistics { + return c.regionStats +} + +// GetLabelStats gets label statistics. +func (c *Cluster) GetLabelStats() *statistics.LabelStatistics { + return c.labelStats +} + +// GetBasicCluster returns the basic cluster. +func (c *Cluster) GetBasicCluster() *core.BasicCluster { + return c.BasicCluster +} + +// GetSharedConfig returns the shared config. +func (c *Cluster) GetSharedConfig() sc.SharedConfigProvider { + return c.persistConfig +} + +// GetRuleManager returns the rule manager. +func (c *Cluster) GetRuleManager() *placement.RuleManager { + return c.ruleManager +} + +// GetRegionLabeler returns the region labeler. +func (c *Cluster) GetRegionLabeler() *labeler.RegionLabeler { + return c.labelerManager +} + +// GetRegionSplitter returns the region splitter. +func (c *Cluster) GetRegionSplitter() *splitter.RegionSplitter { + return c.coordinator.GetRegionSplitter() +} + +// GetRegionScatterer returns the region scatter. +func (c *Cluster) GetRegionScatterer() *scatter.RegionScatterer { + return c.coordinator.GetRegionScatterer() +} + +// GetStoresLoads returns load stats of all stores. +func (c *Cluster) GetStoresLoads() map[uint64][]float64 { + return c.hotStat.GetStoresLoads() +} + +// IsRegionHot checks if a region is in hot state. 
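NewCluster above derives a cancellable context first and then runs several fallible initialization steps (region labeler, rule manager); note that every error path calls cancel() before returning so the derived context is not leaked. A minimal sketch of the same pattern, with hypothetical names:

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

type component struct {
	ctx    context.Context
	cancel context.CancelFunc
}

// newComponent mirrors the error-handling shape of NewCluster: once
// context.WithCancel has been called, every early-error return must call
// cancel() so the derived context is released instead of leaking.
func newComponent(parent context.Context, failInit bool) (*component, error) {
	ctx, cancel := context.WithCancel(parent)
	if failInit { // stands in for labeler or rule-manager initialization failing
		cancel()
		return nil, errors.New("init failed")
	}
	return &component{ctx: ctx, cancel: cancel}, nil
}

func main() {
	if _, err := newComponent(context.Background(), true); err != nil {
		fmt.Println("constructor cleaned up after itself:", err)
	}
}
```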
+func (c *Cluster) IsRegionHot(region *core.RegionInfo) bool { + return c.hotStat.IsRegionHot(region, c.persistConfig.GetHotRegionCacheHitsThreshold()) +} + +// GetHotPeerStat returns hot peer stat with specified regionID and storeID. +func (c *Cluster) GetHotPeerStat(rw utils.RWType, regionID, storeID uint64) *statistics.HotPeerStat { + return c.hotStat.GetHotPeerStat(rw, regionID, storeID) +} + +// RegionReadStats returns hot region's read stats. +// The result only includes peers that are hot enough. +// RegionStats is a thread-safe method +func (c *Cluster) RegionReadStats() map[uint64][]*statistics.HotPeerStat { + // As read stats are reported by store heartbeat, the threshold needs to be adjusted. + threshold := c.persistConfig.GetHotRegionCacheHitsThreshold() * + (utils.RegionHeartBeatReportInterval / utils.StoreHeartBeatReportInterval) + return c.hotStat.RegionStats(utils.Read, threshold) +} + +// RegionWriteStats returns hot region's write stats. +// The result only includes peers that are hot enough. +func (c *Cluster) RegionWriteStats() map[uint64][]*statistics.HotPeerStat { + // RegionStats is a thread-safe method + return c.hotStat.RegionStats(utils.Write, c.persistConfig.GetHotRegionCacheHitsThreshold()) +} + +// BucketsStats returns hot region's buckets stats. +func (c *Cluster) BucketsStats(degree int, regionIDs ...uint64) map[uint64][]*buckets.BucketStat { + return c.hotStat.BucketsStats(degree, regionIDs...) +} + +// GetStorage returns the storage. +func (c *Cluster) GetStorage() storage.Storage { + return c.storage +} + +// GetCheckerConfig returns the checker config. +func (c *Cluster) GetCheckerConfig() sc.CheckerConfigProvider { return c.persistConfig } + +// GetSchedulerConfig returns the scheduler config. +func (c *Cluster) GetSchedulerConfig() sc.SchedulerConfigProvider { return c.persistConfig } + +// GetStoreConfig returns the store config. +func (c *Cluster) GetStoreConfig() sc.StoreConfigProvider { return c.persistConfig } + +// AllocID allocates a new ID. +func (c *Cluster) AllocID() (uint64, error) { + client, err := c.getAPIServerLeaderClient() + if err != nil { + return 0, err + } + ctx, cancel := context.WithTimeout(c.ctx, requestTimeout) + defer cancel() + resp, err := client.AllocID(ctx, &pdpb.AllocIDRequest{Header: &pdpb.RequestHeader{ClusterId: c.clusterID}}) + if err != nil { + c.triggerMembershipCheck() + return 0, err + } + return resp.GetId(), nil +} + +func (c *Cluster) getAPIServerLeaderClient() (pdpb.PDClient, error) { + cli := c.apiServerLeader.Load() + if cli == nil { + c.triggerMembershipCheck() + return nil, errors.New("API server leader is not found") + } + return cli.(pdpb.PDClient), nil +} + +func (c *Cluster) triggerMembershipCheck() { + select { + case c.checkMembershipCh <- struct{}{}: + default: // avoid blocking + } +} + +// SwitchAPIServerLeader switches the API server leader. +func (c *Cluster) SwitchAPIServerLeader(new pdpb.PDClient) bool { + old := c.apiServerLeader.Load() + return c.apiServerLeader.CompareAndSwap(old, new) +} + +func trySend(notifier chan struct{}) { + select { + case notifier <- struct{}{}: + // If the channel is not empty, it means the check is triggered. + default: + } +} + +// updateScheduler listens on the schedulers updating notifier and manage the scheduler creation and deletion. +func (c *Cluster) updateScheduler() { + defer logutil.LogPanic() + defer c.wg.Done() + + // Make sure the coordinator has initialized all the existing schedulers. 
+ c.waitSchedulersInitialized() + // Establish a notifier to listen the schedulers updating. + notifier := make(chan struct{}, 1) + // Make sure the check will be triggered once later. + trySend(notifier) + c.persistConfig.SetSchedulersUpdatingNotifier(notifier) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("cluster is closing, stop listening the schedulers updating notifier") + return + case <-notifier: + // This is triggered by the watcher when the schedulers are updated. + } + + if !c.running.Load() { + select { + case <-c.ctx.Done(): + log.Info("cluster is closing, stop listening the schedulers updating notifier") + return + case <-ticker.C: + // retry + trySend(notifier) + continue + } + } + + log.Info("schedulers updating notifier is triggered, try to update the scheduler") + var ( + schedulersController = c.coordinator.GetSchedulersController() + latestSchedulersConfig = c.persistConfig.GetScheduleConfig().Schedulers + ) + // Create the newly added schedulers. + for _, scheduler := range latestSchedulersConfig { + s, err := schedulers.CreateScheduler( + scheduler.Type, + c.coordinator.GetOperatorController(), + c.storage, + schedulers.ConfigSliceDecoder(scheduler.Type, scheduler.Args), + schedulersController.RemoveScheduler, + ) + if err != nil { + log.Error("failed to create scheduler", + zap.String("scheduler-type", scheduler.Type), + zap.Strings("scheduler-args", scheduler.Args), + errs.ZapError(err)) + continue + } + name := s.GetName() + if existed, _ := schedulersController.IsSchedulerExisted(name); existed { + log.Info("scheduler has already existed, skip adding it", + zap.String("scheduler-name", name), + zap.Strings("scheduler-args", scheduler.Args)) + continue + } + if err := schedulersController.AddScheduler(s, scheduler.Args...); err != nil { + log.Error("failed to add scheduler", + zap.String("scheduler-name", name), + zap.Strings("scheduler-args", scheduler.Args), + errs.ZapError(err)) + continue + } + log.Info("add scheduler successfully", + zap.String("scheduler-name", name), + zap.Strings("scheduler-args", scheduler.Args)) + } + // Remove the deleted schedulers. + for _, name := range schedulersController.GetSchedulerNames() { + scheduler := schedulersController.GetScheduler(name) + if slice.AnyOf(latestSchedulersConfig, func(i int) bool { + return latestSchedulersConfig[i].Type == scheduler.GetType() + }) { + continue + } + if err := schedulersController.RemoveScheduler(name); err != nil { + log.Error("failed to remove scheduler", + zap.String("scheduler-name", name), + errs.ZapError(err)) + continue + } + log.Info("remove scheduler successfully", + zap.String("scheduler-name", name)) + } + } +} + +func (c *Cluster) waitSchedulersInitialized() { + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + for { + if c.coordinator.AreSchedulersInitialized() { + return + } + select { + case <-c.ctx.Done(): + log.Info("cluster is closing, stop waiting the schedulers initialization") + return + case <-ticker.C: + } + } +} + +// TODO: implement the following methods + +// UpdateRegionsLabelLevelStats updates the status of the region label level by types. 
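The body of updateScheduler above is a reconcile loop: schedulers present in the latest persisted config but not yet running are created and added, while running schedulers that no longer appear in the config are removed. A self-contained sketch of that add/remove set reconciliation, with illustrative names rather than PD APIs:

```go
package main

import "fmt"

// reconcile mirrors the shape of updateScheduler's loop body: anything in the
// desired config but not running is added, anything running but no longer in
// the config is removed.
func reconcile(desired, running []string, add, remove func(string)) {
	want := make(map[string]struct{}, len(desired))
	for _, d := range desired {
		want[d] = struct{}{}
	}
	have := make(map[string]struct{}, len(running))
	for _, r := range running {
		have[r] = struct{}{}
	}
	for d := range want {
		if _, ok := have[d]; !ok {
			add(d)
		}
	}
	for r := range have {
		if _, ok := want[r]; !ok {
			remove(r)
		}
	}
}

func main() {
	reconcile(
		[]string{"balance-leader", "balance-region"},
		[]string{"balance-region", "evict-leader"},
		func(s string) { fmt.Println("add", s) },
		func(s string) { fmt.Println("remove", s) },
	)
}
```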
+func (c *Cluster) UpdateRegionsLabelLevelStats(regions []*core.RegionInfo) { + for _, region := range regions { + c.labelStats.Observe(region, c.getStoresWithoutLabelLocked(region, core.EngineKey, core.EngineTiFlash), c.persistConfig.GetLocationLabels()) + } +} + +func (c *Cluster) getStoresWithoutLabelLocked(region *core.RegionInfo, key, value string) []*core.StoreInfo { + stores := make([]*core.StoreInfo, 0, len(region.GetPeers())) + for _, p := range region.GetPeers() { + if store := c.GetStore(p.GetStoreId()); store != nil && !core.IsStoreContainLabel(store.GetMeta(), key, value) { + stores = append(stores, store) + } + } + return stores +} + +// HandleStoreHeartbeat updates the store status. +func (c *Cluster) HandleStoreHeartbeat(heartbeat *schedulingpb.StoreHeartbeatRequest) error { + stats := heartbeat.GetStats() + storeID := stats.GetStoreId() + store := c.GetStore(storeID) + if store == nil { + return errors.Errorf("store %v not found", storeID) + } + + nowTime := time.Now() + newStore := store.Clone(core.SetStoreStats(stats), core.SetLastHeartbeatTS(nowTime)) + + if store := c.GetStore(storeID); store != nil { + statistics.UpdateStoreHeartbeatMetrics(store) + } + c.PutStore(newStore) + c.hotStat.Observe(storeID, newStore.GetStoreStats()) + c.hotStat.FilterUnhealthyStore(c) + reportInterval := stats.GetInterval() + interval := reportInterval.GetEndTimestamp() - reportInterval.GetStartTimestamp() + + regions := make(map[uint64]*core.RegionInfo, len(stats.GetPeerStats())) + for _, peerStat := range stats.GetPeerStats() { + regionID := peerStat.GetRegionId() + region := c.GetRegion(regionID) + regions[regionID] = region + if region == nil { + log.Warn("discard hot peer stat for unknown region", + zap.Uint64("region-id", regionID), + zap.Uint64("store-id", storeID)) + continue + } + peer := region.GetStorePeer(storeID) + if peer == nil { + log.Warn("discard hot peer stat for unknown region peer", + zap.Uint64("region-id", regionID), + zap.Uint64("store-id", storeID)) + continue + } + readQueryNum := core.GetReadQueryNum(peerStat.GetQueryStats()) + loads := []float64{ + utils.RegionReadBytes: float64(peerStat.GetReadBytes()), + utils.RegionReadKeys: float64(peerStat.GetReadKeys()), + utils.RegionReadQueryNum: float64(readQueryNum), + utils.RegionWriteBytes: 0, + utils.RegionWriteKeys: 0, + utils.RegionWriteQueryNum: 0, + } + peerInfo := core.NewPeerInfo(peer, loads, interval) + c.hotStat.CheckReadAsync(statistics.NewCheckPeerTask(peerInfo, region)) + } + + // Here we will compare the reported regions with the previous hot peers to decide if it is still hot. + c.hotStat.CheckReadAsync(statistics.NewCollectUnReportedPeerTask(storeID, regions, interval)) + return nil +} + +// runUpdateStoreStats updates store stats periodically. +func (c *Cluster) runUpdateStoreStats() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(9 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("update store stats background jobs has been stopped") + return + case <-ticker.C: + c.UpdateAllStoreStatus() + } + } +} + +// runCoordinator runs the main scheduling loop. 
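HandleStoreHeartbeat above builds the per-peer `loads` slice with keyed composite-literal indices (utils.RegionReadBytes, utils.RegionReadKeys, ...), so each statistic lands in a fixed slot regardless of declaration order and the write dimensions stay zero for a read-only sample. A small standalone sketch of the same Go idiom with made-up dimension names:

```go
package main

import "fmt"

// Dimension indices, analogous to utils.RegionReadBytes / RegionReadKeys / ...
const (
	readBytes = iota
	readKeys
	writeBytes
	writeKeys
)

func main() {
	// Keyed indices make the slot assignment explicit; the zeroed write slots
	// mirror how a read-only heartbeat sample is recorded.
	loads := []float64{
		readBytes:  4096,
		readKeys:   12,
		writeBytes: 0,
		writeKeys:  0,
	}
	fmt.Println(loads) // [4096 12 0 0]
}
```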
+func (c *Cluster) runCoordinator() { + defer logutil.LogPanic() + defer c.wg.Done() + // force wait for 1 minute to make prepare checker won't be directly skipped + runCollectWaitTime := collectWaitTime + failpoint.Inject("changeRunCollectWaitTime", func() { + runCollectWaitTime = 1 * time.Second + }) + c.coordinator.RunUntilStop(runCollectWaitTime) +} + +func (c *Cluster) runMetricsCollectionJob() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + log.Info("metrics are reset") + resetMetrics() + log.Info("metrics collection job has been stopped") + return + case <-ticker.C: + c.collectMetrics() + } + } +} + +func (c *Cluster) collectMetrics() { + statsMap := statistics.NewStoreStatisticsMap(c.persistConfig) + stores := c.GetStores() + for _, s := range stores { + statsMap.Observe(s) + statistics.ObserveHotStat(s, c.hotStat.StoresStats) + } + statsMap.Collect() + + c.coordinator.GetSchedulersController().CollectSchedulerMetrics() + c.coordinator.CollectHotSpotMetrics() + if c.regionStats == nil { + return + } + c.regionStats.Collect() + c.labelStats.Collect() + // collect hot cache metrics + c.hotStat.CollectMetrics() + // collect the lock metrics + c.RegionsInfo.CollectWaitLockMetrics() +} + +func resetMetrics() { + statistics.Reset() + schedulers.ResetSchedulerMetrics() + schedule.ResetHotSpotMetrics() +} + +// StartBackgroundJobs starts background jobs. +func (c *Cluster) StartBackgroundJobs() { + c.wg.Add(4) + go c.updateScheduler() + go c.runUpdateStoreStats() + go c.runCoordinator() + go c.runMetricsCollectionJob() + c.heartbeatRunnner.Start() + c.logRunner.Start() + c.running.Store(true) +} + +// StopBackgroundJobs stops background jobs. +func (c *Cluster) StopBackgroundJobs() { + if !c.running.Load() { + return + } + c.running.Store(false) + c.coordinator.Stop() + c.heartbeatRunnner.Stop() + c.logRunner.Stop() + c.cancel() + c.wg.Wait() +} + +// IsBackgroundJobsRunning returns whether the background jobs are running. Only for test purpose. +func (c *Cluster) IsBackgroundJobsRunning() bool { + return c.running.Load() +} + +// HandleRegionHeartbeat processes RegionInfo reports from client. +func (c *Cluster) HandleRegionHeartbeat(region *core.RegionInfo) error { + tracer := core.NewNoopHeartbeatProcessTracer() + if c.persistConfig.GetScheduleConfig().EnableHeartbeatBreakdownMetrics { + tracer = core.NewHeartbeatProcessTracer() + } + var taskRunner, logRunner ratelimit.Runner + taskRunner, logRunner = syncRunner, syncRunner + if c.persistConfig.GetScheduleConfig().EnableHeartbeatConcurrentRunner { + taskRunner = c.heartbeatRunnner + logRunner = c.logRunner + } + ctx := &core.MetaProcessContext{ + Context: c.ctx, + Tracer: tracer, + TaskRunner: taskRunner, + LogRunner: logRunner, + } + tracer.Begin() + if err := c.processRegionHeartbeat(ctx, region); err != nil { + tracer.OnAllStageFinished() + return err + } + tracer.OnAllStageFinished() + c.coordinator.GetOperatorController().Dispatch(region, operator.DispatchFromHeartBeat, c.coordinator.RecordOpStepWithTTL) + return nil +} + +// processRegionHeartbeat updates the region information. 
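StartBackgroundJobs and StopBackgroundJobs above use the classic context-plus-WaitGroup lifecycle: the WaitGroup is bumped once per worker before the goroutines start, each ticker-driven worker exits when the shared context is cancelled, and Stop cancels first and then waits. A minimal sketch of that lifecycle, stdlib only:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type jobs struct {
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

// start launches a ticker-driven worker the same way StartBackgroundJobs does:
// Add() happens before the goroutine starts, and the goroutine returns once
// the shared context is cancelled.
func (j *jobs) start() {
	j.wg.Add(1)
	go func() {
		defer j.wg.Done()
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-j.ctx.Done():
				fmt.Println("worker stopped")
				return
			case <-ticker.C:
				// collect metrics / update store stats here
			}
		}
	}()
}

// stop mirrors StopBackgroundJobs: cancel first, then wait for all workers.
func (j *jobs) stop() {
	j.cancel()
	j.wg.Wait()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	j := &jobs{ctx: ctx, cancel: cancel}
	j.start()
	time.Sleep(50 * time.Millisecond)
	j.stop()
}
```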
+func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *core.RegionInfo) error { + tracer := ctx.Tracer + origin, _, err := c.PreCheckPutRegion(region) + tracer.OnPreCheckFinished() + if err != nil { + return err + } + region.Inherit(origin, c.GetStoreConfig().IsEnableRegionBucket()) + + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + cluster.HandleStatsAsync(c, region) + }, + ratelimit.WithTaskName(ratelimit.HandleStatsAsync), + ) + tracer.OnAsyncHotStatsFinished() + hasRegionStats := c.regionStats != nil + // Save to storage if meta is updated, except for flashback. + // Save to cache if meta or leader is updated, or contains any down/pending peer. + _, saveCache, _ := core.GenerateRegionGuideFunc(true)(ctx, region, origin) + + if !saveCache { + // Due to some config changes need to update the region stats as well, + // so we do some extra checks here. + if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + if c.regionStats.RegionStatsNeedUpdate(region) { + cluster.Collect(c, region, hasRegionStats) + } + }, + ratelimit.WithTaskName(ratelimit.ObserveRegionStatsAsync), + ) + } + // region is not updated to the subtree. + if origin.GetRef() < 2 { + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + c.CheckAndPutSubTree(region) + }, + ratelimit.WithTaskName(ratelimit.UpdateSubTree), + ) + } + return nil + } + tracer.OnSaveCacheBegin() + var overlaps []*core.RegionInfo + if saveCache { + // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, + // check its validation again here. + // + // However, it can't solve the race condition of concurrent heartbeats from the same region. + + // Async task in next PR. + if overlaps, err = c.CheckAndPutRootTree(ctx, region); err != nil { + tracer.OnSaveCacheFinished() + return err + } + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + c.CheckAndPutSubTree(region) + }, + ratelimit.WithTaskName(ratelimit.UpdateSubTree), + ) + tracer.OnUpdateSubTreeFinished() + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + cluster.HandleOverlaps(c, overlaps) + }, + ratelimit.WithTaskName(ratelimit.HandleOverlaps), + ) + } + tracer.OnSaveCacheFinished() + // handle region stats + ctx.TaskRunner.RunTask( + ctx, + func(_ context.Context) { + cluster.Collect(c, region, hasRegionStats) + }, + ratelimit.WithTaskName(ratelimit.CollectRegionStatsAsync), + ) + tracer.OnCollectRegionStatsFinished() + return nil +} + +// IsPrepared return true if the prepare checker is ready. +func (c *Cluster) IsPrepared() bool { + return c.coordinator.GetPrepareChecker().IsPrepared() +} + +// SetPrepared set the prepare check to prepared. Only for test purpose. +func (c *Cluster) SetPrepared() { + c.coordinator.GetPrepareChecker().SetPrepared() +} + +// DropCacheAllRegion removes all cached regions. +func (c *Cluster) DropCacheAllRegion() { + c.ResetRegionCache() +} + +// DropCacheRegion removes a region from the cache. +func (c *Cluster) DropCacheRegion(id uint64) { + c.RemoveRegionIfExist(id) +} + +// IsSchedulingHalted returns whether the scheduling is halted. +// Currently, the microservice scheduling is halted when: +// - The `HaltScheduling` persist option is set to true. 
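processRegionHeartbeat above reports progress to a heartbeat process tracer (OnPreCheckFinished, OnSaveCacheFinished, OnCollectRegionStatsFinished, ...) so the cost of a heartbeat can be broken down per stage. The real tracer API in `core` is richer than shown here; this is only a toy sketch of the stage-marking idea:

```go
package main

import (
	"fmt"
	"time"
)

// stageTracer is a toy version of a heartbeat process tracer: each mark()
// records how long the work since the previous mark took.
type stageTracer struct {
	last   time.Time
	stages map[string]time.Duration
}

func newStageTracer() *stageTracer {
	return &stageTracer{last: time.Now(), stages: map[string]time.Duration{}}
}

func (t *stageTracer) mark(stage string) {
	now := time.Now()
	t.stages[stage] = now.Sub(t.last)
	t.last = now
}

func main() {
	tr := newStageTracer()
	time.Sleep(2 * time.Millisecond) // stands in for the pre-check work
	tr.mark("pre-check")
	time.Sleep(5 * time.Millisecond) // stands in for saving the region cache
	tr.mark("save-cache")
	fmt.Println(tr.stages)
}
```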
+func (c *Cluster) IsSchedulingHalted() bool { + return c.persistConfig.IsSchedulingHalted() +} diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go new file mode 100644 index 00000000000..9dc6590a0b4 --- /dev/null +++ b/pkg/mcs/scheduling/server/config/config.go @@ -0,0 +1,829 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + "strings" + "sync/atomic" + "time" + "unsafe" + + "github.com/BurntSushi/toml" + "github.com/coreos/go-semver/semver" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/spf13/pflag" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/core/constant" + "github.com/tikv/pd/pkg/core/storelimit" + "github.com/tikv/pd/pkg/mcs/utils" + sc "github.com/tikv/pd/pkg/schedule/config" + "github.com/tikv/pd/pkg/slice" + "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/utils/configutil" + "github.com/tikv/pd/pkg/utils/grpcutil" + "github.com/tikv/pd/pkg/utils/metricutil" + "github.com/tikv/pd/pkg/utils/typeutil" + "go.uber.org/zap" +) + +const ( + defaultName = "Scheduling" + defaultBackendEndpoints = "http://127.0.0.1:2379" + defaultListenAddr = "http://127.0.0.1:3379" +) + +// Config is the configuration for the scheduling. +type Config struct { + BackendEndpoints string `toml:"backend-endpoints" json:"backend-endpoints"` + ListenAddr string `toml:"listen-addr" json:"listen-addr"` + AdvertiseListenAddr string `toml:"advertise-listen-addr" json:"advertise-listen-addr"` + Name string `toml:"name" json:"name"` + DataDir string `toml:"data-dir" json:"data-dir"` // TODO: remove this after refactoring + EnableGRPCGateway bool `json:"enable-grpc-gateway"` // TODO: use it + + Metric metricutil.MetricConfig `toml:"metric" json:"metric"` + + // Log related config. + Log log.Config `toml:"log" json:"log"` + Logger *zap.Logger `json:"-"` + LogProps *log.ZapProperties `json:"-"` + + Security configutil.SecurityConfig `toml:"security" json:"security"` + + // WarningMsgs contains all warnings during parsing. + WarningMsgs []string + + // LeaderLease defines the time within which a Scheduling primary/leader must + // update its TTL in etcd, otherwise etcd will expire the leader key and other servers + // can campaign the primary/leader again. Etcd only supports seconds TTL, so here is + // second too. + LeaderLease int64 `toml:"lease" json:"lease"` + + ClusterVersion semver.Version `toml:"cluster-version" json:"cluster-version"` + + Schedule sc.ScheduleConfig `toml:"schedule" json:"schedule"` + Replication sc.ReplicationConfig `toml:"replication" json:"replication"` +} + +// NewConfig creates a new config. +func NewConfig() *Config { + return &Config{} +} + +// Parse parses flag definitions from the argument list. +func (c *Config) Parse(flagSet *pflag.FlagSet) error { + // Load config file if specified. 
+ var ( + meta *toml.MetaData + err error + ) + if configFile, _ := flagSet.GetString("config"); configFile != "" { + meta, err = configutil.ConfigFromFile(c, configFile) + if err != nil { + return err + } + } + + // Ignore the error check here + configutil.AdjustCommandLineString(flagSet, &c.Log.Level, "log-level") + configutil.AdjustCommandLineString(flagSet, &c.Log.File.Filename, "log-file") + configutil.AdjustCommandLineString(flagSet, &c.Metric.PushAddress, "metrics-addr") + configutil.AdjustCommandLineString(flagSet, &c.Security.CAPath, "cacert") + configutil.AdjustCommandLineString(flagSet, &c.Security.CertPath, "cert") + configutil.AdjustCommandLineString(flagSet, &c.Security.KeyPath, "key") + configutil.AdjustCommandLineString(flagSet, &c.BackendEndpoints, "backend-endpoints") + configutil.AdjustCommandLineString(flagSet, &c.ListenAddr, "listen-addr") + configutil.AdjustCommandLineString(flagSet, &c.AdvertiseListenAddr, "advertise-listen-addr") + return c.adjust(meta) +} + +// adjust is used to adjust the scheduling configurations. +func (c *Config) adjust(meta *toml.MetaData) error { + configMetaData := configutil.NewConfigMetadata(meta) + if err := configMetaData.CheckUndecoded(); err != nil { + c.WarningMsgs = append(c.WarningMsgs, err.Error()) + } + + if c.Name == "" { + hostname, err := os.Hostname() + if err != nil { + return err + } + configutil.AdjustString(&c.Name, fmt.Sprintf("%s-%s", defaultName, hostname)) + } + configutil.AdjustString(&c.DataDir, fmt.Sprintf("default.%s", c.Name)) + configutil.AdjustPath(&c.DataDir) + + if err := c.validate(); err != nil { + return err + } + + configutil.AdjustString(&c.BackendEndpoints, defaultBackendEndpoints) + configutil.AdjustString(&c.ListenAddr, defaultListenAddr) + configutil.AdjustString(&c.AdvertiseListenAddr, c.ListenAddr) + + if !configMetaData.IsDefined("enable-grpc-gateway") { + c.EnableGRPCGateway = utils.DefaultEnableGRPCGateway + } + + c.adjustLog(configMetaData.Child("log")) + c.Security.Encryption.Adjust() + + configutil.AdjustInt64(&c.LeaderLease, utils.DefaultLeaderLease) + + if err := c.Schedule.Adjust(configMetaData.Child("schedule"), false); err != nil { + return err + } + return c.Replication.Adjust(configMetaData.Child("replication")) +} + +func (c *Config) adjustLog(meta *configutil.ConfigMetaData) { + if !meta.IsDefined("disable-error-verbose") { + c.Log.DisableErrorVerbose = utils.DefaultDisableErrorVerbose + } + configutil.AdjustString(&c.Log.Format, utils.DefaultLogFormat) + configutil.AdjustString(&c.Log.Level, utils.DefaultLogLevel) +} + +// GetName returns the Name +func (c *Config) GetName() string { + return c.Name +} + +// GeBackendEndpoints returns the BackendEndpoints +func (c *Config) GeBackendEndpoints() string { + return c.BackendEndpoints +} + +// GetListenAddr returns the ListenAddr +func (c *Config) GetListenAddr() string { + return c.ListenAddr +} + +// GetAdvertiseListenAddr returns the AdvertiseListenAddr +func (c *Config) GetAdvertiseListenAddr() string { + return c.AdvertiseListenAddr +} + +// GetTLSConfig returns the TLS config. +func (c *Config) GetTLSConfig() *grpcutil.TLSConfig { + return &c.Security.TLSConfig +} + +// validate is used to validate if some configurations are right. 
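Parse and adjust above implement a simple precedence: values from the TOML file are kept, command-line flags override them, and the configutil.Adjust* helpers only fill in defaults (such as the default listen address) where nothing was set. A tiny sketch of that "fill only if empty" helper, assuming nothing beyond what the defaults in this file show:

```go
package main

import "fmt"

// adjustString fills a value only when it is still empty, which is how the
// configutil.Adjust* helpers let defaults apply without overriding what the
// TOML file or command-line flags already set.
func adjustString(v *string, defaultValue string) {
	if *v == "" {
		*v = defaultValue
	}
}

func main() {
	listenAddr := ""                  // nothing from the file or flags
	backend := "http://10.0.0.1:2379" // set by --backend-endpoints
	adjustString(&listenAddr, "http://127.0.0.1:3379")
	adjustString(&backend, "http://127.0.0.1:2379") // keeps the flag value
	fmt.Println(listenAddr, backend)
}
```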
+func (c *Config) validate() error { + dataDir, err := filepath.Abs(c.DataDir) + if err != nil { + return errors.WithStack(err) + } + logFile, err := filepath.Abs(c.Log.File.Filename) + if err != nil { + return errors.WithStack(err) + } + rel, err := filepath.Rel(dataDir, filepath.Dir(logFile)) + if err != nil { + return errors.WithStack(err) + } + if !strings.HasPrefix(rel, "..") { + return errors.New("log directory shouldn't be the subdirectory of data directory") + } + + return nil +} + +// Clone creates a copy of current config. +func (c *Config) Clone() *Config { + cfg := &Config{} + *cfg = *c + return cfg +} + +// PersistConfig wraps all configurations that need to persist to storage and +// allows to access them safely. +type PersistConfig struct { + ttl *cache.TTLString + // Store the global configuration that is related to the scheduling. + clusterVersion unsafe.Pointer + schedule atomic.Value + replication atomic.Value + storeConfig atomic.Value + // schedulersUpdatingNotifier is used to notify that the schedulers have been updated. + // Store as `chan<- struct{}`. + schedulersUpdatingNotifier atomic.Value +} + +// NewPersistConfig creates a new PersistConfig instance. +func NewPersistConfig(cfg *Config, ttl *cache.TTLString) *PersistConfig { + o := &PersistConfig{} + o.SetClusterVersion(&cfg.ClusterVersion) + o.schedule.Store(&cfg.Schedule) + o.replication.Store(&cfg.Replication) + // storeConfig will be fetched from TiKV by PD API server, + // so we just set an empty value here first. + o.storeConfig.Store(&sc.StoreConfig{}) + o.ttl = ttl + return o +} + +// SetSchedulersUpdatingNotifier sets the schedulers updating notifier. +func (o *PersistConfig) SetSchedulersUpdatingNotifier(notifier chan<- struct{}) { + o.schedulersUpdatingNotifier.Store(notifier) +} + +func (o *PersistConfig) getSchedulersUpdatingNotifier() chan<- struct{} { + v := o.schedulersUpdatingNotifier.Load() + if v == nil { + return nil + } + return v.(chan<- struct{}) +} + +func (o *PersistConfig) tryNotifySchedulersUpdating() { + notifier := o.getSchedulersUpdatingNotifier() + if notifier == nil { + return + } + notifier <- struct{}{} +} + +// GetClusterVersion returns the cluster version. +func (o *PersistConfig) GetClusterVersion() *semver.Version { + return (*semver.Version)(atomic.LoadPointer(&o.clusterVersion)) +} + +// SetClusterVersion sets the cluster version. +func (o *PersistConfig) SetClusterVersion(v *semver.Version) { + atomic.StorePointer(&o.clusterVersion, unsafe.Pointer(v)) +} + +// GetScheduleConfig returns the scheduling configurations. +func (o *PersistConfig) GetScheduleConfig() *sc.ScheduleConfig { + return o.schedule.Load().(*sc.ScheduleConfig) +} + +// SetScheduleConfig sets the scheduling configuration dynamically. +func (o *PersistConfig) SetScheduleConfig(cfg *sc.ScheduleConfig) { + old := o.GetScheduleConfig() + o.schedule.Store(cfg) + // The coordinator is not aware of the underlying scheduler config changes, + // we should notify it to update the schedulers proactively. + if !reflect.DeepEqual(old.Schedulers, cfg.Schedulers) { + o.tryNotifySchedulersUpdating() + } +} + +// AdjustScheduleCfg adjusts the schedule config during the initialization. +func AdjustScheduleCfg(scheduleCfg *sc.ScheduleConfig) { + // In case we add new default schedulers. 
+ for _, ps := range sc.DefaultSchedulers { + if slice.NoneOf(scheduleCfg.Schedulers, func(i int) bool { + return scheduleCfg.Schedulers[i].Type == ps.Type + }) { + scheduleCfg.Schedulers = append(scheduleCfg.Schedulers, ps) + } + } +} + +// GetReplicationConfig returns replication configurations. +func (o *PersistConfig) GetReplicationConfig() *sc.ReplicationConfig { + return o.replication.Load().(*sc.ReplicationConfig) +} + +// SetReplicationConfig sets the PD replication configuration. +func (o *PersistConfig) SetReplicationConfig(cfg *sc.ReplicationConfig) { + o.replication.Store(cfg) +} + +// SetStoreConfig sets the TiKV store configuration. +func (o *PersistConfig) SetStoreConfig(cfg *sc.StoreConfig) { + // Some of the fields won't be persisted and watched, + // so we need to adjust it here before storing it. + cfg.Adjust() + o.storeConfig.Store(cfg) +} + +// GetStoreConfig returns the TiKV store configuration. +func (o *PersistConfig) GetStoreConfig() *sc.StoreConfig { + return o.storeConfig.Load().(*sc.StoreConfig) +} + +// GetMaxReplicas returns the max replicas. +func (o *PersistConfig) GetMaxReplicas() int { + return int(o.GetReplicationConfig().MaxReplicas) +} + +// IsPlacementRulesEnabled returns if the placement rules is enabled. +func (o *PersistConfig) IsPlacementRulesEnabled() bool { + return o.GetReplicationConfig().EnablePlacementRules +} + +// GetLowSpaceRatio returns the low space ratio. +func (o *PersistConfig) GetLowSpaceRatio() float64 { + return o.GetScheduleConfig().LowSpaceRatio +} + +// GetHighSpaceRatio returns the high space ratio. +func (o *PersistConfig) GetHighSpaceRatio() float64 { + return o.GetScheduleConfig().HighSpaceRatio +} + +// GetLeaderSchedulePolicy is to get leader schedule policy. +func (o *PersistConfig) GetLeaderSchedulePolicy() constant.SchedulePolicy { + return constant.StringToSchedulePolicy(o.GetScheduleConfig().LeaderSchedulePolicy) +} + +// GetMaxStoreDownTime returns the max store downtime. +func (o *PersistConfig) GetMaxStoreDownTime() time.Duration { + return o.GetScheduleConfig().MaxStoreDownTime.Duration +} + +// GetIsolationLevel returns the isolation label for each region. +func (o *PersistConfig) GetIsolationLevel() string { + return o.GetReplicationConfig().IsolationLevel +} + +// GetLocationLabels returns the location labels. +func (o *PersistConfig) GetLocationLabels() []string { + return o.GetReplicationConfig().LocationLabels +} + +// IsUseJointConsensus returns if the joint consensus is enabled. +func (o *PersistConfig) IsUseJointConsensus() bool { + return o.GetScheduleConfig().EnableJointConsensus +} + +// GetKeyType returns the key type. +func (*PersistConfig) GetKeyType() constant.KeyType { + return constant.StringToKeyType("table") +} + +// IsCrossTableMergeEnabled returns if the cross table merge is enabled. +func (o *PersistConfig) IsCrossTableMergeEnabled() bool { + return o.GetScheduleConfig().EnableCrossTableMerge +} + +// IsOneWayMergeEnabled returns if the one way merge is enabled. +func (o *PersistConfig) IsOneWayMergeEnabled() bool { + return o.GetScheduleConfig().EnableOneWayMerge +} + +// GetRegionScoreFormulaVersion returns the region score formula version. +func (o *PersistConfig) GetRegionScoreFormulaVersion() string { + return o.GetScheduleConfig().RegionScoreFormulaVersion +} + +// GetHotRegionCacheHitsThreshold returns the hot region cache hits threshold. 
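PersistConfig stores each configuration section in an atomic.Value: readers Load() an immutable snapshot on the hot path, and the Set* methods publish a fresh pointer (callers Clone() before mutating), so no mutex is needed. A condensed, self-contained sketch of this copy-on-write pattern with a toy config type:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// toyConfig stands in for one of the persisted config sections.
type toyConfig struct {
	MaxReplicas uint64
}

var store atomic.Value

// setMaxReplicas follows the Set*Config pattern: clone the current snapshot,
// modify the copy, and publish it; the old snapshot is never mutated in place.
func setMaxReplicas(n uint64) {
	old := store.Load().(*toyConfig)
	clone := *old
	clone.MaxReplicas = n
	store.Store(&clone)
}

func main() {
	store.Store(&toyConfig{MaxReplicas: 3})
	setMaxReplicas(5)
	fmt.Println(store.Load().(*toyConfig).MaxReplicas) // 5
}
```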
+func (o *PersistConfig) GetHotRegionCacheHitsThreshold() int { + return int(o.GetScheduleConfig().HotRegionCacheHitsThreshold) +} + +// GetMaxMovableHotPeerSize returns the max movable hot peer size. +func (o *PersistConfig) GetMaxMovableHotPeerSize() int64 { + return o.GetScheduleConfig().MaxMovableHotPeerSize +} + +// GetSwitchWitnessInterval returns the interval between promote to non-witness and starting to switch to witness. +func (o *PersistConfig) GetSwitchWitnessInterval() time.Duration { + return o.GetScheduleConfig().SwitchWitnessInterval.Duration +} + +// GetSplitMergeInterval returns the interval between finishing split and starting to merge. +func (o *PersistConfig) GetSplitMergeInterval() time.Duration { + return o.GetScheduleConfig().SplitMergeInterval.Duration +} + +// GetSlowStoreEvictingAffectedStoreRatioThreshold returns the affected ratio threshold when judging a store is slow. +func (o *PersistConfig) GetSlowStoreEvictingAffectedStoreRatioThreshold() float64 { + return o.GetScheduleConfig().SlowStoreEvictingAffectedStoreRatioThreshold +} + +// GetPatrolRegionInterval returns the interval of patrolling region. +func (o *PersistConfig) GetPatrolRegionInterval() time.Duration { + return o.GetScheduleConfig().PatrolRegionInterval.Duration +} + +// GetTolerantSizeRatio gets the tolerant size ratio. +func (o *PersistConfig) GetTolerantSizeRatio() float64 { + return o.GetScheduleConfig().TolerantSizeRatio +} + +// IsDebugMetricsEnabled returns if debug metrics is enabled. +func (o *PersistConfig) IsDebugMetricsEnabled() bool { + return o.GetScheduleConfig().EnableDebugMetrics +} + +// IsDiagnosticAllowed returns whether is enable to use diagnostic. +func (o *PersistConfig) IsDiagnosticAllowed() bool { + return o.GetScheduleConfig().EnableDiagnostic +} + +// IsRemoveDownReplicaEnabled returns if remove down replica is enabled. +func (o *PersistConfig) IsRemoveDownReplicaEnabled() bool { + return o.GetScheduleConfig().EnableRemoveDownReplica +} + +// IsReplaceOfflineReplicaEnabled returns if replace offline replica is enabled. +func (o *PersistConfig) IsReplaceOfflineReplicaEnabled() bool { + return o.GetScheduleConfig().EnableReplaceOfflineReplica +} + +// IsMakeUpReplicaEnabled returns if make up replica is enabled. +func (o *PersistConfig) IsMakeUpReplicaEnabled() bool { + return o.GetScheduleConfig().EnableMakeUpReplica +} + +// IsRemoveExtraReplicaEnabled returns if remove extra replica is enabled. +func (o *PersistConfig) IsRemoveExtraReplicaEnabled() bool { + return o.GetScheduleConfig().EnableRemoveExtraReplica +} + +// IsWitnessAllowed returns if the witness is allowed. +func (o *PersistConfig) IsWitnessAllowed() bool { + return o.GetScheduleConfig().EnableWitness +} + +// IsPlacementRulesCacheEnabled returns if the placement rules cache is enabled. +func (o *PersistConfig) IsPlacementRulesCacheEnabled() bool { + return o.GetReplicationConfig().EnablePlacementRulesCache +} + +// IsSchedulingHalted returns if PD scheduling is halted. +func (o *PersistConfig) IsSchedulingHalted() bool { + return o.GetScheduleConfig().HaltScheduling +} + +// GetStoresLimit gets the stores' limit. +func (o *PersistConfig) GetStoresLimit() map[uint64]sc.StoreLimitConfig { + return o.GetScheduleConfig().StoreLimit +} + +// TTL related methods. + +// GetLeaderScheduleLimit returns the limit for leader schedule. 
+func (o *PersistConfig) GetLeaderScheduleLimit() uint64 { + return o.getTTLUintOr(sc.LeaderScheduleLimitKey, o.GetScheduleConfig().LeaderScheduleLimit) +} + +// GetRegionScheduleLimit returns the limit for region schedule. +func (o *PersistConfig) GetRegionScheduleLimit() uint64 { + return o.getTTLUintOr(sc.RegionScheduleLimitKey, o.GetScheduleConfig().RegionScheduleLimit) +} + +// GetWitnessScheduleLimit returns the limit for region schedule. +func (o *PersistConfig) GetWitnessScheduleLimit() uint64 { + return o.getTTLUintOr(sc.WitnessScheduleLimitKey, o.GetScheduleConfig().WitnessScheduleLimit) +} + +// GetReplicaScheduleLimit returns the limit for replica schedule. +func (o *PersistConfig) GetReplicaScheduleLimit() uint64 { + return o.getTTLUintOr(sc.ReplicaRescheduleLimitKey, o.GetScheduleConfig().ReplicaScheduleLimit) +} + +// GetMergeScheduleLimit returns the limit for merge schedule. +func (o *PersistConfig) GetMergeScheduleLimit() uint64 { + return o.getTTLUintOr(sc.MergeScheduleLimitKey, o.GetScheduleConfig().MergeScheduleLimit) +} + +// GetHotRegionScheduleLimit returns the limit for hot region schedule. +func (o *PersistConfig) GetHotRegionScheduleLimit() uint64 { + return o.getTTLUintOr(sc.HotRegionScheduleLimitKey, o.GetScheduleConfig().HotRegionScheduleLimit) +} + +// GetStoreLimit returns the limit of a store. +func (o *PersistConfig) GetStoreLimit(storeID uint64) (returnSC sc.StoreLimitConfig) { + defer func() { + returnSC.RemovePeer = o.getTTLFloatOr(fmt.Sprintf("remove-peer-%v", storeID), returnSC.RemovePeer) + returnSC.AddPeer = o.getTTLFloatOr(fmt.Sprintf("add-peer-%v", storeID), returnSC.AddPeer) + }() + if limit, ok := o.GetScheduleConfig().StoreLimit[storeID]; ok { + return limit + } + cfg := o.GetScheduleConfig().Clone() + sc := sc.StoreLimitConfig{ + AddPeer: sc.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.AddPeer), + RemovePeer: sc.DefaultStoreLimit.GetDefaultStoreLimit(storelimit.RemovePeer), + } + v, ok1, err := o.getTTLFloat("default-add-peer") + if err != nil { + log.Warn("failed to parse default-add-peer from PersistOptions's ttl storage", zap.Error(err)) + } + canSetAddPeer := ok1 && err == nil + if canSetAddPeer { + returnSC.AddPeer = v + } + + v, ok2, err := o.getTTLFloat("default-remove-peer") + if err != nil { + log.Warn("failed to parse default-remove-peer from PersistOptions's ttl storage", zap.Error(err)) + } + canSetRemovePeer := ok2 && err == nil + if canSetRemovePeer { + returnSC.RemovePeer = v + } + + if canSetAddPeer || canSetRemovePeer { + return returnSC + } + cfg.StoreLimit[storeID] = sc + o.SetScheduleConfig(cfg) + return o.GetScheduleConfig().StoreLimit[storeID] +} + +// GetStoreLimitByType returns the limit of a store with a given type. +func (o *PersistConfig) GetStoreLimitByType(storeID uint64, typ storelimit.Type) (returned float64) { + defer func() { + if typ == storelimit.RemovePeer { + returned = o.getTTLFloatOr(fmt.Sprintf("remove-peer-%v", storeID), returned) + } else if typ == storelimit.AddPeer { + returned = o.getTTLFloatOr(fmt.Sprintf("add-peer-%v", storeID), returned) + } + }() + limit := o.GetStoreLimit(storeID) + switch typ { + case storelimit.AddPeer: + return limit.AddPeer + case storelimit.RemovePeer: + return limit.RemovePeer + // todo: impl it in store limit v2. + case storelimit.SendSnapshot: + return 0.0 + default: + panic("no such limit type") + } +} + +// GetMaxSnapshotCount returns the number of the max snapshot which is allowed to send. 
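The schedule-limit getters above all share one shape: a TTL override wins when it exists and parses, otherwise the persisted config value is used. A standalone sketch of that parse-or-fall-back helper; the map and key string here are illustrative stand-ins for the TTL storage, not the real keys:

```go
package main

import (
	"fmt"
	"strconv"
)

// ttlCache stands in for the TTL storage behind GetTTLData.
var ttlCache = map[string]string{
	"leader-schedule-limit": "8",
}

// uintOr returns the TTL override when it exists and parses, and falls back to
// the persisted value otherwise — the same shape as getTTLUintOr.
func uintOr(key string, persisted uint64) uint64 {
	s, ok := ttlCache[key]
	if !ok {
		return persisted
	}
	v, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return persisted // a malformed override never wins
	}
	return v
}

func main() {
	fmt.Println(uintOr("leader-schedule-limit", 4))    // 8: override wins
	fmt.Println(uintOr("region-schedule-limit", 2048)) // 2048: no override set
}
```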
+func (o *PersistConfig) GetMaxSnapshotCount() uint64 { + return o.getTTLUintOr(sc.MaxSnapshotCountKey, o.GetScheduleConfig().MaxSnapshotCount) +} + +// GetMaxPendingPeerCount returns the number of the max pending peers. +func (o *PersistConfig) GetMaxPendingPeerCount() uint64 { + return o.getTTLUintOr(sc.MaxPendingPeerCountKey, o.GetScheduleConfig().MaxPendingPeerCount) +} + +// GetMaxMergeRegionSize returns the max region size. +func (o *PersistConfig) GetMaxMergeRegionSize() uint64 { + return o.getTTLUintOr(sc.MaxMergeRegionSizeKey, o.GetScheduleConfig().MaxMergeRegionSize) +} + +// GetMaxMergeRegionKeys returns the max number of keys. +// It returns size * 10000 if the key of max-merge-region-Keys doesn't exist. +func (o *PersistConfig) GetMaxMergeRegionKeys() uint64 { + keys, exist, err := o.getTTLUint(sc.MaxMergeRegionKeysKey) + if exist && err == nil { + return keys + } + size, exist, err := o.getTTLUint(sc.MaxMergeRegionSizeKey) + if exist && err == nil { + return size * 10000 + } + return o.GetScheduleConfig().GetMaxMergeRegionKeys() +} + +// GetSchedulerMaxWaitingOperator returns the number of the max waiting operators. +func (o *PersistConfig) GetSchedulerMaxWaitingOperator() uint64 { + return o.getTTLUintOr(sc.SchedulerMaxWaitingOperatorKey, o.GetScheduleConfig().SchedulerMaxWaitingOperator) +} + +// IsLocationReplacementEnabled returns if location replace is enabled. +func (o *PersistConfig) IsLocationReplacementEnabled() bool { + return o.getTTLBoolOr(sc.EnableLocationReplacement, o.GetScheduleConfig().EnableLocationReplacement) +} + +// IsTikvRegionSplitEnabled returns whether tikv split region is enabled. +func (o *PersistConfig) IsTikvRegionSplitEnabled() bool { + return o.getTTLBoolOr(sc.EnableTiKVSplitRegion, o.GetScheduleConfig().EnableTiKVSplitRegion) +} + +// SetAllStoresLimit sets all store limit for a given type and rate. +func (o *PersistConfig) SetAllStoresLimit(typ storelimit.Type, ratePerMin float64) { + v := o.GetScheduleConfig().Clone() + switch typ { + case storelimit.AddPeer: + sc.DefaultStoreLimit.SetDefaultStoreLimit(storelimit.AddPeer, ratePerMin) + for storeID := range v.StoreLimit { + sc := sc.StoreLimitConfig{AddPeer: ratePerMin, RemovePeer: v.StoreLimit[storeID].RemovePeer} + v.StoreLimit[storeID] = sc + } + case storelimit.RemovePeer: + sc.DefaultStoreLimit.SetDefaultStoreLimit(storelimit.RemovePeer, ratePerMin) + for storeID := range v.StoreLimit { + sc := sc.StoreLimitConfig{AddPeer: v.StoreLimit[storeID].AddPeer, RemovePeer: ratePerMin} + v.StoreLimit[storeID] = sc + } + } + + o.SetScheduleConfig(v) +} + +// SetMaxReplicas sets the number of replicas for each region. +func (o *PersistConfig) SetMaxReplicas(replicas int) { + v := o.GetReplicationConfig().Clone() + v.MaxReplicas = uint64(replicas) + o.SetReplicationConfig(v) +} + +// IsSchedulerDisabled returns if the scheduler is disabled. +func (o *PersistConfig) IsSchedulerDisabled(t string) bool { + schedulers := o.GetScheduleConfig().Schedulers + for _, s := range schedulers { + if t == s.Type { + return s.Disable + } + } + return false +} + +// SetPlacementRulesCacheEnabled sets if the placement rules cache is enabled. +func (o *PersistConfig) SetPlacementRulesCacheEnabled(enabled bool) { + v := o.GetReplicationConfig().Clone() + v.EnablePlacementRulesCache = enabled + o.SetReplicationConfig(v) +} + +// SetEnableWitness sets if the witness is enabled. 
+func (o *PersistConfig) SetEnableWitness(enable bool) { + v := o.GetScheduleConfig().Clone() + v.EnableWitness = enable + o.SetScheduleConfig(v) +} + +// SetPlacementRuleEnabled set PlacementRuleEnabled +func (o *PersistConfig) SetPlacementRuleEnabled(enabled bool) { + v := o.GetReplicationConfig().Clone() + v.EnablePlacementRules = enabled + o.SetReplicationConfig(v) +} + +// SetSplitMergeInterval to set the interval between finishing split and starting to merge. It's only used to test. +func (o *PersistConfig) SetSplitMergeInterval(splitMergeInterval time.Duration) { + v := o.GetScheduleConfig().Clone() + v.SplitMergeInterval = typeutil.Duration{Duration: splitMergeInterval} + o.SetScheduleConfig(v) +} + +// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt. +// TODO: support this metrics for the scheduling service in the future. +func (*PersistConfig) SetSchedulingAllowanceStatus(bool, string) {} + +// SetHaltScheduling set HaltScheduling. +func (o *PersistConfig) SetHaltScheduling(halt bool, _ string) { + v := o.GetScheduleConfig().Clone() + v.HaltScheduling = halt + o.SetScheduleConfig(v) +} + +// CheckRegionKeys return error if the smallest region's keys is less than mergeKeys +func (o *PersistConfig) CheckRegionKeys(keys, mergeKeys uint64) error { + return o.GetStoreConfig().CheckRegionKeys(keys, mergeKeys) +} + +// CheckRegionSize return error if the smallest region's size is less than mergeSize +func (o *PersistConfig) CheckRegionSize(size, mergeSize uint64) error { + return o.GetStoreConfig().CheckRegionSize(size, mergeSize) +} + +// GetRegionMaxSize returns the max region size in MB +func (o *PersistConfig) GetRegionMaxSize() uint64 { + return o.GetStoreConfig().GetRegionMaxSize() +} + +// GetRegionMaxKeys returns the max region keys +func (o *PersistConfig) GetRegionMaxKeys() uint64 { + return o.GetStoreConfig().GetRegionMaxKeys() +} + +// GetRegionSplitSize returns the region split size in MB +func (o *PersistConfig) GetRegionSplitSize() uint64 { + return o.GetStoreConfig().GetRegionSplitSize() +} + +// GetRegionSplitKeys returns the region split keys +func (o *PersistConfig) GetRegionSplitKeys() uint64 { + return o.GetStoreConfig().GetRegionSplitKeys() +} + +// IsEnableRegionBucket return true if the region bucket is enabled. +func (o *PersistConfig) IsEnableRegionBucket() bool { + return o.GetStoreConfig().IsEnableRegionBucket() +} + +// IsRaftKV2 returns the whether the cluster use `raft-kv2` engine. +func (o *PersistConfig) IsRaftKV2() bool { + return o.GetStoreConfig().IsRaftKV2() +} + +// TODO: implement the following methods + +// AddSchedulerCfg adds the scheduler configurations. +// This method is a no-op since we only use configurations derived from one-way synchronization from API server now. +func (*PersistConfig) AddSchedulerCfg(string, []string) {} + +// RemoveSchedulerCfg removes the scheduler configurations. +// This method is a no-op since we only use configurations derived from one-way synchronization from API server now. +func (*PersistConfig) RemoveSchedulerCfg(string) {} + +// CheckLabelProperty checks if the label property is satisfied. +func (*PersistConfig) CheckLabelProperty(string, []*metapb.StoreLabel) bool { + return false +} + +// IsTraceRegionFlow returns if the region flow is tracing. +// If the accuracy cannot reach 0.1 MB, it is considered not. +func (*PersistConfig) IsTraceRegionFlow() bool { + return false +} + +// Persist saves the configuration to the storage. 
+func (*PersistConfig) Persist(endpoint.ConfigStorage) error { + return nil +} + +func (o *PersistConfig) getTTLUint(key string) (uint64, bool, error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return 0, false, nil + } + r, err := strconv.ParseUint(stringForm, 10, 64) + return r, true, err +} + +func (o *PersistConfig) getTTLUintOr(key string, defaultValue uint64) uint64 { + if v, ok, err := o.getTTLUint(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse "+key+" from PersistOptions's ttl storage", zap.Error(err)) + } + return defaultValue +} + +func (o *PersistConfig) getTTLBool(key string) (result bool, contains bool, err error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return + } + result, err = strconv.ParseBool(stringForm) + contains = true + return +} + +func (o *PersistConfig) getTTLBoolOr(key string, defaultValue bool) bool { + if v, ok, err := o.getTTLBool(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse "+key+" from PersistOptions's ttl storage", zap.Error(err)) + } + return defaultValue +} + +func (o *PersistConfig) getTTLFloat(key string) (float64, bool, error) { + stringForm, ok := o.GetTTLData(key) + if !ok { + return 0, false, nil + } + r, err := strconv.ParseFloat(stringForm, 64) + return r, true, err +} + +func (o *PersistConfig) getTTLFloatOr(key string, defaultValue float64) float64 { + if v, ok, err := o.getTTLFloat(key); ok { + if err == nil { + return v + } + log.Warn("failed to parse "+key+" from PersistOptions's ttl storage", zap.Error(err)) + } + return defaultValue +} + +// GetTTLData returns if there is a TTL data for a given key. +func (o *PersistConfig) GetTTLData(key string) (string, bool) { + if o.ttl == nil { + return "", false + } + if result, ok := o.ttl.Get(key); ok { + return result.(string), ok + } + return "", false +} diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go new file mode 100644 index 00000000000..605ec73dad5 --- /dev/null +++ b/pkg/mcs/scheduling/server/grpc_service.go @@ -0,0 +1,367 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "io" + "net/http" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/schedulingpb" + "github.com/pingcap/log" + bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/registry" + "github.com/tikv/pd/pkg/utils/apiutil" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/versioninfo" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// gRPC errors +var ( + ErrNotStarted = status.Errorf(codes.Unavailable, "server not started") + ErrClusterMismatched = status.Errorf(codes.Unavailable, "cluster mismatched") +) + +// SetUpRestHandler is a hook to sets up the REST service. 
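SetUpRestHandler below is a package-level hook variable: by default it returns a dummy handler that answers 501, and another package can overwrite the variable to wire in the real REST service without creating an import cycle. A condensed sketch of the hook idiom, with hypothetical names:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// setUpHandler is a swappable hook, defaulting to a "not implemented" stub,
// mirroring how SetUpRestHandler is replaced by the real API installer.
var setUpHandler = func() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotImplemented)
	})
}

func main() {
	// A later package overrides the hook before the server starts serving.
	setUpHandler = func() http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			fmt.Fprint(w, "real handler")
		})
	}
	rec := httptest.NewRecorder()
	setUpHandler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	fmt.Println(rec.Code, rec.Body.String())
}
```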
+var SetUpRestHandler = func(*Service) (http.Handler, apiutil.APIServiceGroup) { + return dummyRestService{}, apiutil.APIServiceGroup{} +} + +type dummyRestService struct{} + +func (dummyRestService) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("not implemented")) +} + +// ConfigProvider is used to get scheduling config from the given +// `bs.server` without modifying its interface. +type ConfigProvider any + +// Service is the scheduling grpc service. +type Service struct { + *Server +} + +// NewService creates a new scheduling service. +func NewService[T ConfigProvider](svr bs.Server) registry.RegistrableService { + server, ok := svr.(*Server) + if !ok { + log.Fatal("create scheduling server failed") + } + return &Service{ + Server: server, + } +} + +// heartbeatServer wraps Scheduling_RegionHeartbeatServer to ensure when any error +// occurs on Send() or Recv(), both endpoints will be closed. +type heartbeatServer struct { + stream schedulingpb.Scheduling_RegionHeartbeatServer + closed int32 +} + +func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { + if atomic.LoadInt32(&s.closed) == 1 { + return io.EOF + } + done := make(chan error, 1) + go func() { + defer logutil.LogPanic() + done <- s.stream.Send(m.(*schedulingpb.RegionHeartbeatResponse)) + }() + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case err := <-done: + if err != nil { + atomic.StoreInt32(&s.closed, 1) + } + return errors.WithStack(err) + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") + } +} + +func (s *heartbeatServer) Recv() (*schedulingpb.RegionHeartbeatRequest, error) { + if atomic.LoadInt32(&s.closed) == 1 { + return nil, io.EOF + } + req, err := s.stream.Recv() + if err != nil { + atomic.StoreInt32(&s.closed, 1) + return nil, errors.WithStack(err) + } + return req, nil +} + +// RegionHeartbeat implements gRPC SchedulingServer. +func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeatServer) error { + var ( + server = &heartbeatServer{stream: stream} + cancel context.CancelFunc + lastBind time.Time + ) + defer func() { + // cancel the forward stream + if cancel != nil { + cancel() + } + }() + + for { + request, err := server.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + + c := s.GetCluster() + if c == nil { + resp := &schedulingpb.RegionHeartbeatResponse{Header: s.notBootstrappedHeader()} + err := server.Send(resp) + return errors.WithStack(err) + } + + storeID := request.GetLeader().GetStoreId() + store := c.GetStore(storeID) + if store == nil { + return errors.Errorf("invalid store ID %d, not found", storeID) + } + + if time.Since(lastBind) > time.Minute { + s.hbStreams.BindStream(storeID, server) + lastBind = time.Now() + } + region := core.RegionFromHeartbeat(request) + err = c.HandleRegionHeartbeat(region) + if err != nil { + // TODO: if we need to send the error back to API server. + log.Error("failed handle region heartbeat", zap.Error(err)) + continue + } + } +} + +// StoreHeartbeat implements gRPC SchedulingServer. 
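heartbeatServer.Send above runs the blocking stream send in its own goroutine and gives up after a fixed deadline instead of hanging on a stuck stream; the buffered `done` channel keeps the goroutine from leaking when the timeout fires first. A minimal, stream-agnostic sketch of that send-with-timeout pattern:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// sendWithTimeout shows the shape of heartbeatServer.Send: the blocking send
// runs in its own goroutine and the caller gives up after a deadline.
func sendWithTimeout(send func() error, timeout time.Duration) error {
	done := make(chan error, 1) // buffered so the goroutine never leaks on timeout
	go func() { done <- send() }()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case err := <-done:
		return err
	case <-timer.C:
		return errors.New("send heartbeat timeout")
	}
}

func main() {
	slow := func() error { time.Sleep(50 * time.Millisecond); return nil }
	fmt.Println(sendWithTimeout(slow, 10*time.Millisecond))                      // times out
	fmt.Println(sendWithTimeout(func() error { return nil }, 10*time.Millisecond)) // nil
}
```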
+func (s *Service) StoreHeartbeat(_ context.Context, request *schedulingpb.StoreHeartbeatRequest) (*schedulingpb.StoreHeartbeatResponse, error) { + c := s.GetCluster() + if c == nil { + // TODO: add metrics + log.Info("cluster isn't initialized") + return &schedulingpb.StoreHeartbeatResponse{Header: s.notBootstrappedHeader()}, nil + } + + if c.GetStore(request.GetStats().GetStoreId()) == nil { + s.metaWatcher.GetStoreWatcher().ForceLoad() + } + + // TODO: add metrics + if err := c.HandleStoreHeartbeat(request); err != nil { + log.Error("handle store heartbeat failed", zap.Error(err)) + } + return &schedulingpb.StoreHeartbeatResponse{Header: &schedulingpb.ResponseHeader{ClusterId: s.clusterID}}, nil +} + +// SplitRegions split regions by the given split keys +func (s *Service) SplitRegions(ctx context.Context, request *schedulingpb.SplitRegionsRequest) (*schedulingpb.SplitRegionsResponse, error) { + c := s.GetCluster() + if c == nil { + return &schedulingpb.SplitRegionsResponse{Header: s.notBootstrappedHeader()}, nil + } + finishedPercentage, newRegionIDs := c.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) + return &schedulingpb.SplitRegionsResponse{ + Header: s.header(), + RegionsId: newRegionIDs, + FinishedPercentage: uint64(finishedPercentage), + }, nil +} + +// ScatterRegions implements gRPC SchedulingServer. +func (s *Service) ScatterRegions(_ context.Context, request *schedulingpb.ScatterRegionsRequest) (*schedulingpb.ScatterRegionsResponse, error) { + c := s.GetCluster() + if c == nil { + return &schedulingpb.ScatterRegionsResponse{Header: s.notBootstrappedHeader()}, nil + } + + opsCount, failures, err := c.GetRegionScatterer().ScatterRegionsByID(request.GetRegionsId(), request.GetGroup(), int(request.GetRetryLimit()), request.GetSkipStoreLimit()) + if err != nil { + header := s.errorHeader(&schedulingpb.Error{ + Type: schedulingpb.ErrorType_UNKNOWN, + Message: err.Error(), + }) + return &schedulingpb.ScatterRegionsResponse{Header: header}, nil + } + percentage := 100 + if len(failures) > 0 { + percentage = 100 - 100*len(failures)/(opsCount+len(failures)) + log.Debug("scatter regions", zap.Errors("failures", func() []error { + r := make([]error, 0, len(failures)) + for _, err := range failures { + r = append(r, err) + } + return r + }())) + } + return &schedulingpb.ScatterRegionsResponse{ + Header: s.header(), + FinishedPercentage: uint64(percentage), + }, nil +} + +// GetOperator gets information about the operator belonging to the specify region. +func (s *Service) GetOperator(_ context.Context, request *schedulingpb.GetOperatorRequest) (*schedulingpb.GetOperatorResponse, error) { + c := s.GetCluster() + if c == nil { + return &schedulingpb.GetOperatorResponse{Header: s.notBootstrappedHeader()}, nil + } + + opController := c.GetCoordinator().GetOperatorController() + requestID := request.GetRegionId() + r := opController.GetOperatorStatus(requestID) + if r == nil { + header := s.errorHeader(&schedulingpb.Error{ + Type: schedulingpb.ErrorType_UNKNOWN, + Message: "region not found", + }) + return &schedulingpb.GetOperatorResponse{Header: header}, nil + } + + return &schedulingpb.GetOperatorResponse{ + Header: s.header(), + RegionId: requestID, + Desc: []byte(r.Desc()), + Kind: []byte(r.Kind().String()), + Status: r.Status, + }, nil +} + +// AskBatchSplit implements gRPC SchedulingServer. 
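ScatterRegions above reports progress as `100 - 100*len(failures)/(opsCount+len(failures))`, i.e. the share of attempted regions that were successfully scheduled, as an integer percentage. A tiny sketch of that arithmetic in isolation:

```go
package main

import "fmt"

// finishedPercentage reproduces the arithmetic used in ScatterRegions: the
// ratio of successfully scheduled regions over everything that was attempted,
// expressed as an integer percentage.
func finishedPercentage(opsCount, failed int) int {
	if failed == 0 {
		return 100
	}
	return 100 - 100*failed/(opsCount+failed)
}

func main() {
	fmt.Println(finishedPercentage(8, 2)) // 80: 2 of 10 regions failed to scatter
	fmt.Println(finishedPercentage(5, 0)) // 100
}
```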
+func (s *Service) AskBatchSplit(_ context.Context, request *schedulingpb.AskBatchSplitRequest) (*schedulingpb.AskBatchSplitResponse, error) { + c := s.GetCluster() + if c == nil { + return &schedulingpb.AskBatchSplitResponse{Header: s.notBootstrappedHeader()}, nil + } + + if request.GetRegion() == nil { + return &schedulingpb.AskBatchSplitResponse{ + Header: s.wrapErrorToHeader(schedulingpb.ErrorType_UNKNOWN, + "missing region for split"), + }, nil + } + + if c.IsSchedulingHalted() { + return nil, errs.ErrSchedulingIsHalted.FastGenByArgs() + } + if !c.persistConfig.IsTikvRegionSplitEnabled() { + return nil, errs.ErrSchedulerTiKVSplitDisabled.FastGenByArgs() + } + reqRegion := request.GetRegion() + splitCount := request.GetSplitCount() + err := c.ValidRegion(reqRegion) + if err != nil { + return nil, err + } + splitIDs := make([]*pdpb.SplitID, 0, splitCount) + recordRegions := make([]uint64, 0, splitCount+1) + + for i := 0; i < int(splitCount); i++ { + newRegionID, err := c.AllocID() + if err != nil { + return nil, errs.ErrSchedulerNotFound.FastGenByArgs() + } + + peerIDs := make([]uint64, len(request.Region.Peers)) + for i := 0; i < len(peerIDs); i++ { + if peerIDs[i], err = c.AllocID(); err != nil { + return nil, err + } + } + + recordRegions = append(recordRegions, newRegionID) + splitIDs = append(splitIDs, &pdpb.SplitID{ + NewRegionId: newRegionID, + NewPeerIds: peerIDs, + }) + + log.Info("alloc ids for region split", zap.Uint64("region-id", newRegionID), zap.Uint64s("peer-ids", peerIDs)) + } + + recordRegions = append(recordRegions, reqRegion.GetId()) + if versioninfo.IsFeatureSupported(c.persistConfig.GetClusterVersion(), versioninfo.RegionMerge) { + // Disable merge the regions in a period of time. + c.GetCoordinator().GetMergeChecker().RecordRegionSplit(recordRegions) + } + + // If region splits during the scheduling process, regions with abnormal + // status may be left, and these regions need to be checked with higher + // priority. + c.GetCoordinator().GetCheckerController().AddSuspectRegions(recordRegions...) + + return &schedulingpb.AskBatchSplitResponse{ + Header: s.header(), + Ids: splitIDs, + }, nil +} + +// RegisterGRPCService registers the service to gRPC server. +func (s *Service) RegisterGRPCService(g *grpc.Server) { + schedulingpb.RegisterSchedulingServer(g, s) +} + +// RegisterRESTHandler registers the service to REST server. 
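As a rough usage sketch of RegisterGRPCService above: it assumes the same package, the `net` and `google.golang.org/grpc` imports, and a *Service obtained from NewService; error handling is abbreviated and the listen address is a placeholder.

// serveScheduling wires the scheduling Service into a plain gRPC server.
func serveScheduling(svc *Service, addr string) error {
	lis, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	g := grpc.NewServer()
	svc.RegisterGRPCService(g) // schedulingpb.RegisterSchedulingServer(g, svc)
	return g.Serve(lis)
}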
+func (s *Service) RegisterRESTHandler(userDefineHandlers map[string]http.Handler) { + handler, group := SetUpRestHandler(s) + apiutil.RegisterUserDefinedHandlers(userDefineHandlers, &group, handler) +} + +func (s *Service) errorHeader(err *schedulingpb.Error) *schedulingpb.ResponseHeader { + return &schedulingpb.ResponseHeader{ + ClusterId: s.clusterID, + Error: err, + } +} + +func (s *Service) notBootstrappedHeader() *schedulingpb.ResponseHeader { + return s.errorHeader(&schedulingpb.Error{ + Type: schedulingpb.ErrorType_NOT_BOOTSTRAPPED, + Message: "cluster is not initialized", + }) +} + +func (s *Service) header() *schedulingpb.ResponseHeader { + if s.clusterID == 0 { + return s.wrapErrorToHeader(schedulingpb.ErrorType_NOT_BOOTSTRAPPED, "cluster id is not ready") + } + return &schedulingpb.ResponseHeader{ClusterId: s.clusterID} +} + +func (s *Service) wrapErrorToHeader( + errorType schedulingpb.ErrorType, message string) *schedulingpb.ResponseHeader { + return s.errorHeader(&schedulingpb.Error{Type: errorType, Message: message}) +} diff --git a/pkg/schedule/config/config_provider.go b/pkg/schedule/config/config_provider.go new file mode 100644 index 00000000000..90e489f86f3 --- /dev/null +++ b/pkg/schedule/config/config_provider.go @@ -0,0 +1,149 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "sync" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/tikv/pd/pkg/core/constant" + "github.com/tikv/pd/pkg/core/storelimit" + "github.com/tikv/pd/pkg/storage/endpoint" +) + +// RejectLeader is the label property type that suggests a store should not +// have any region leaders. +const RejectLeader = "reject-leader" + +var schedulerMap sync.Map + +// RegisterScheduler registers the scheduler type. +func RegisterScheduler(typ string) { + schedulerMap.Store(typ, struct{}{}) +} + +// IsSchedulerRegistered checks if the named scheduler type is registered. +func IsSchedulerRegistered(name string) bool { + _, ok := schedulerMap.Load(name) + return ok +} + +// SchedulerConfigProvider is the interface for scheduler configurations. 
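Before the provider interfaces below, a hedged usage sketch for the RegisterScheduler / IsSchedulerRegistered registry above; the scheduler type name is invented for illustration, and registration is assumed to happen once, typically from an init function in the scheduler's package.

// "example-scheduler" is a made-up type name, not a real PD scheduler.
func init() {
	RegisterScheduler("example-scheduler")
}

func exampleSchedulerKnown() bool {
	return IsSchedulerRegistered("example-scheduler") // true once init has run
}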
+type SchedulerConfigProvider interface { + SharedConfigProvider + + SetSchedulingAllowanceStatus(bool, string) + GetStoresLimit() map[uint64]StoreLimitConfig + + IsSchedulerDisabled(string) bool + AddSchedulerCfg(string, []string) + RemoveSchedulerCfg(string) + Persist(endpoint.ConfigStorage) error + + GetRegionScheduleLimit() uint64 + GetLeaderScheduleLimit() uint64 + GetHotRegionScheduleLimit() uint64 + GetWitnessScheduleLimit() uint64 + + GetHotRegionCacheHitsThreshold() int + GetMaxMovableHotPeerSize() int64 + IsTraceRegionFlow() bool + + GetTolerantSizeRatio() float64 + GetLeaderSchedulePolicy() constant.SchedulePolicy + + IsDebugMetricsEnabled() bool + IsDiagnosticAllowed() bool + GetSlowStoreEvictingAffectedStoreRatioThreshold() float64 + + GetScheduleConfig() *ScheduleConfig + SetScheduleConfig(*ScheduleConfig) +} + +// CheckerConfigProvider is the interface for checker configurations. +type CheckerConfigProvider interface { + SharedConfigProvider + StoreConfigProvider + + GetSwitchWitnessInterval() time.Duration + IsRemoveExtraReplicaEnabled() bool + IsRemoveDownReplicaEnabled() bool + IsReplaceOfflineReplicaEnabled() bool + IsMakeUpReplicaEnabled() bool + IsLocationReplacementEnabled() bool + GetIsolationLevel() string + GetSplitMergeInterval() time.Duration + GetPatrolRegionInterval() time.Duration + GetMaxMergeRegionSize() uint64 + GetMaxMergeRegionKeys() uint64 + GetReplicaScheduleLimit() uint64 +} + +// SharedConfigProvider is the interface for shared configurations. +type SharedConfigProvider interface { + GetMaxReplicas() int + IsPlacementRulesEnabled() bool + GetMaxSnapshotCount() uint64 + GetMaxPendingPeerCount() uint64 + GetLowSpaceRatio() float64 + GetHighSpaceRatio() float64 + GetMaxStoreDownTime() time.Duration + GetLocationLabels() []string + CheckLabelProperty(string, []*metapb.StoreLabel) bool + GetClusterVersion() *semver.Version + IsUseJointConsensus() bool + GetKeyType() constant.KeyType + IsCrossTableMergeEnabled() bool + IsOneWayMergeEnabled() bool + GetMergeScheduleLimit() uint64 + GetRegionScoreFormulaVersion() string + GetSchedulerMaxWaitingOperator() uint64 + GetStoreLimitByType(uint64, storelimit.Type) float64 + IsWitnessAllowed() bool + IsPlacementRulesCacheEnabled() bool + SetHaltScheduling(bool, string) + GetHotRegionCacheHitsThreshold() int + + // for test purpose + SetPlacementRuleEnabled(bool) + SetPlacementRulesCacheEnabled(bool) + SetEnableWitness(bool) +} + +// ConfProvider is the interface that wraps the ConfProvider related methods. +type ConfProvider interface { + SchedulerConfigProvider + CheckerConfigProvider + StoreConfigProvider + // for test purpose + SetPlacementRuleEnabled(bool) + SetSplitMergeInterval(time.Duration) + SetMaxReplicas(int) + SetAllStoresLimit(storelimit.Type, float64) +} + +// StoreConfigProvider is the interface that wraps the StoreConfigProvider related methods. +type StoreConfigProvider interface { + GetRegionMaxSize() uint64 + GetRegionMaxKeys() uint64 + GetRegionSplitSize() uint64 + GetRegionSplitKeys() uint64 + CheckRegionSize(uint64, uint64) error + CheckRegionKeys(uint64, uint64) error + IsEnableRegionBucket() bool + IsRaftKV2() bool +} diff --git a/pkg/schedule/coordinator.go b/pkg/schedule/coordinator.go new file mode 100644 index 00000000000..5ab38aad81d --- /dev/null +++ b/pkg/schedule/coordinator.go @@ -0,0 +1,831 @@ +// Copyright 2016 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package schedule + +import ( + "bytes" + "context" + "strconv" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/cache" + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/schedule/checker" + sc "github.com/tikv/pd/pkg/schedule/config" + sche "github.com/tikv/pd/pkg/schedule/core" + "github.com/tikv/pd/pkg/schedule/diagnostic" + "github.com/tikv/pd/pkg/schedule/hbstream" + "github.com/tikv/pd/pkg/schedule/operator" + "github.com/tikv/pd/pkg/schedule/scatter" + "github.com/tikv/pd/pkg/schedule/schedulers" + "github.com/tikv/pd/pkg/schedule/splitter" + "github.com/tikv/pd/pkg/statistics" + "github.com/tikv/pd/pkg/statistics/utils" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/syncutil" + "go.uber.org/zap" +) + +const ( + runSchedulerCheckInterval = 3 * time.Second + checkSuspectRangesInterval = 100 * time.Millisecond + collectFactor = 0.9 + collectTimeout = 5 * time.Minute + maxLoadConfigRetries = 10 + // pushOperatorTickInterval is the interval try to push the operator. + pushOperatorTickInterval = 500 * time.Millisecond + + patrolScanRegionLimit = 128 // It takes about 14 minutes to iterate 1 million regions. + // PluginLoad means action for load plugin + PluginLoad = "PluginLoad" + // PluginUnload means action for unload plugin + PluginUnload = "PluginUnload" +) + +var ( + // WithLabelValues is a heavy operation, define variable to avoid call it every time. + waitingListGauge = regionListGauge.WithLabelValues("waiting_list") + priorityListGauge = regionListGauge.WithLabelValues("priority_list") +) + +// Coordinator is used to manage all schedulers and checkers to decide if the region needs to be scheduled. +type Coordinator struct { + syncutil.RWMutex + + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + + schedulersInitialized bool + patrolRegionsDuration time.Duration + + cluster sche.ClusterInformer + prepareChecker *prepareChecker + checkers *checker.Controller + regionScatterer *scatter.RegionScatterer + regionSplitter *splitter.RegionSplitter + schedulers *schedulers.Controller + opController *operator.Controller + hbStreams *hbstream.HeartbeatStreams + pluginInterface *PluginInterface + diagnosticManager *diagnostic.Manager +} + +// NewCoordinator creates a new Coordinator. 
+func NewCoordinator(parentCtx context.Context, cluster sche.ClusterInformer, hbStreams *hbstream.HeartbeatStreams) *Coordinator { + ctx, cancel := context.WithCancel(parentCtx) + opController := operator.NewController(ctx, cluster.GetBasicCluster(), cluster.GetSharedConfig(), hbStreams) + schedulers := schedulers.NewController(ctx, cluster, cluster.GetStorage(), opController) + checkers := checker.NewController(ctx, cluster, cluster.GetCheckerConfig(), cluster.GetRuleManager(), cluster.GetRegionLabeler(), opController) + return &Coordinator{ + ctx: ctx, + cancel: cancel, + schedulersInitialized: false, + cluster: cluster, + prepareChecker: newPrepareChecker(), + checkers: checkers, + regionScatterer: scatter.NewRegionScatterer(ctx, cluster, opController, checkers.AddSuspectRegions), + regionSplitter: splitter.NewRegionSplitter(cluster, splitter.NewSplitRegionsHandler(cluster, opController), checkers.AddSuspectRegions), + schedulers: schedulers, + opController: opController, + hbStreams: hbStreams, + pluginInterface: NewPluginInterface(), + diagnosticManager: diagnostic.NewManager(schedulers, cluster.GetSchedulerConfig()), + } +} + +// GetPatrolRegionsDuration returns the duration of the last patrol region round. +func (c *Coordinator) GetPatrolRegionsDuration() time.Duration { + if c == nil { + return 0 + } + c.RLock() + defer c.RUnlock() + return c.patrolRegionsDuration +} + +func (c *Coordinator) setPatrolRegionsDuration(dur time.Duration) { + c.Lock() + defer c.Unlock() + c.patrolRegionsDuration = dur +} + +// markSchedulersInitialized marks the scheduler initialization is finished. +func (c *Coordinator) markSchedulersInitialized() { + c.Lock() + defer c.Unlock() + c.schedulersInitialized = true +} + +// AreSchedulersInitialized returns whether the schedulers have been initialized. +func (c *Coordinator) AreSchedulersInitialized() bool { + c.RLock() + defer c.RUnlock() + return c.schedulersInitialized +} + +// GetWaitingRegions returns the regions in the waiting list. +func (c *Coordinator) GetWaitingRegions() []*cache.Item { + return c.checkers.GetWaitingRegions() +} + +// IsPendingRegion returns if the region is in the pending list. +func (c *Coordinator) IsPendingRegion(region uint64) bool { + return c.checkers.IsPendingRegion(region) +} + +// PatrolRegions is used to scan regions. +// The checkers will check these regions to decide if they need to do some operations. +// The function is exposed for test purpose. +func (c *Coordinator) PatrolRegions() { + defer logutil.LogPanic() + + defer c.wg.Done() + ticker := time.NewTicker(c.cluster.GetCheckerConfig().GetPatrolRegionInterval()) + defer ticker.Stop() + + log.Info("coordinator starts patrol regions") + start := time.Now() + var ( + key []byte + regions []*core.RegionInfo + ) + for { + select { + case <-ticker.C: + // Note: we reset the ticker here to support updating configuration dynamically. + ticker.Reset(c.cluster.GetCheckerConfig().GetPatrolRegionInterval()) + case <-c.ctx.Done(): + patrolCheckRegionsGauge.Set(0) + c.setPatrolRegionsDuration(0) + log.Info("patrol regions has been stopped") + return + } + if c.cluster.IsSchedulingHalted() { + continue + } + + // Check priority regions first. + c.checkPriorityRegions() + // Check suspect regions first. + c.checkSuspectRegions() + // Check regions in the waiting list + c.checkWaitingRegions() + + key, regions = c.checkRegions(key) + if len(regions) == 0 { + continue + } + // Updates the label level isolation statistics. 
+ c.cluster.UpdateRegionsLabelLevelStats(regions) + if len(key) == 0 { + dur := time.Since(start) + patrolCheckRegionsGauge.Set(dur.Seconds()) + c.setPatrolRegionsDuration(dur) + start = time.Now() + } + failpoint.Inject("break-patrol", func() { + failpoint.Break() + }) + } +} + +func (c *Coordinator) checkRegions(startKey []byte) (key []byte, regions []*core.RegionInfo) { + regions = c.cluster.ScanRegions(startKey, nil, patrolScanRegionLimit) + if len(regions) == 0 { + // Resets the scan key. + key = nil + return + } + + for _, region := range regions { + c.tryAddOperators(region) + key = region.GetEndKey() + } + return +} + +func (c *Coordinator) checkSuspectRegions() { + for _, id := range c.checkers.GetSuspectRegions() { + region := c.cluster.GetRegion(id) + c.tryAddOperators(region) + } +} + +func (c *Coordinator) checkWaitingRegions() { + items := c.checkers.GetWaitingRegions() + waitingListGauge.Set(float64(len(items))) + for _, item := range items { + region := c.cluster.GetRegion(item.Key) + c.tryAddOperators(region) + } +} + +// checkPriorityRegions checks priority regions +func (c *Coordinator) checkPriorityRegions() { + items := c.checkers.GetPriorityRegions() + removes := make([]uint64, 0) + priorityListGauge.Set(float64(len(items))) + for _, id := range items { + region := c.cluster.GetRegion(id) + if region == nil { + removes = append(removes, id) + continue + } + ops := c.checkers.CheckRegion(region) + // it should skip if region needs to merge + if len(ops) == 0 || ops[0].Kind()&operator.OpMerge != 0 { + continue + } + if !c.opController.ExceedStoreLimit(ops...) { + c.opController.AddWaitingOperator(ops...) + } + } + for _, v := range removes { + c.checkers.RemovePriorityRegions(v) + } +} + +// checkSuspectRanges would pop one suspect key range group +// The regions of new version key range and old version key range would be placed into +// the suspect regions map +func (c *Coordinator) checkSuspectRanges() { + defer logutil.LogPanic() + defer c.wg.Done() + log.Info("coordinator begins to check suspect key ranges") + ticker := time.NewTicker(checkSuspectRangesInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("check suspect key ranges has been stopped") + return + case <-ticker.C: + keyRange, success := c.checkers.PopOneSuspectKeyRange() + if !success { + continue + } + limit := 1024 + regions := c.cluster.ScanRegions(keyRange[0], keyRange[1], limit) + if len(regions) == 0 { + continue + } + regionIDList := make([]uint64, 0, len(regions)) + for _, region := range regions { + regionIDList = append(regionIDList, region.GetID()) + } + + // if the last region's end key is smaller the keyRange[1] which means there existed the remaining regions between + // keyRange[0] and keyRange[1] after scan regions, so we put the end key and keyRange[1] into Suspect KeyRanges + lastRegion := regions[len(regions)-1] + if lastRegion.GetEndKey() != nil && bytes.Compare(lastRegion.GetEndKey(), keyRange[1]) < 0 { + c.checkers.AddSuspectKeyRange(lastRegion.GetEndKey(), keyRange[1]) + } + c.checkers.AddSuspectRegions(regionIDList...) + } + } +} + +func (c *Coordinator) tryAddOperators(region *core.RegionInfo) { + if region == nil { + // the region could be recent split, continue to wait. 
+ return + } + id := region.GetID() + if c.opController.GetOperator(id) != nil { + c.checkers.RemoveWaitingRegion(id) + c.checkers.RemoveSuspectRegion(id) + return + } + ops := c.checkers.CheckRegion(region) + if len(ops) == 0 { + return + } + + if !c.opController.ExceedStoreLimit(ops...) { + c.opController.AddWaitingOperator(ops...) + c.checkers.RemoveWaitingRegion(id) + c.checkers.RemoveSuspectRegion(id) + } else { + c.checkers.AddWaitingRegion(region) + } +} + +// drivePushOperator is used to push the unfinished operator to the executor. +func (c *Coordinator) drivePushOperator() { + defer logutil.LogPanic() + + defer c.wg.Done() + log.Info("coordinator begins to actively drive push operator") + ticker := time.NewTicker(pushOperatorTickInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("drive push operator has been stopped") + return + case <-ticker.C: + c.opController.PushOperators(c.RecordOpStepWithTTL) + } + } +} + +// driveSlowNodeScheduler is used to enable slow node scheduler when using `raft-kv2`. +func (c *Coordinator) driveSlowNodeScheduler() { + defer logutil.LogPanic() + defer c.wg.Done() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + log.Info("drive slow node scheduler is stopped") + return + case <-ticker.C: + { + // If enabled, exit. + if exists, _ := c.schedulers.IsSchedulerExisted(schedulers.EvictSlowTrendName); exists { + return + } + // If the cluster was set up with `raft-kv2` engine, this cluster should + // enable `evict-slow-trend` scheduler as default. + if c.GetCluster().GetStoreConfig().IsRaftKV2() { + typ := schedulers.EvictSlowTrendType + args := []string{} + + s, err := schedulers.CreateScheduler(typ, c.opController, c.cluster.GetStorage(), schedulers.ConfigSliceDecoder(typ, args), c.schedulers.RemoveScheduler) + if err != nil { + log.Warn("initializing evict-slow-trend scheduler failed", errs.ZapError(err)) + } else if err = c.schedulers.AddScheduler(s, args...); err != nil { + log.Error("can not add scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", args), errs.ZapError(err)) + } + } + } + } + } +} + +// RunUntilStop runs the coordinator until receiving the stop signal. +func (c *Coordinator) RunUntilStop(collectWaitTime ...time.Duration) { + c.Run(collectWaitTime...) + <-c.ctx.Done() + log.Info("coordinator is stopping") + c.GetSchedulersController().Wait() + c.wg.Wait() + log.Info("coordinator has been stopped") +} + +// Run starts coordinator. +func (c *Coordinator) Run(collectWaitTime ...time.Duration) { + ticker := time.NewTicker(runSchedulerCheckInterval) + failpoint.Inject("changeCoordinatorTicker", func() { + ticker = time.NewTicker(100 * time.Millisecond) + }) + defer ticker.Stop() + log.Info("coordinator starts to collect cluster information") + for { + if c.ShouldRun(collectWaitTime...) { + log.Info("coordinator has finished cluster information preparation") + break + } + select { + case <-ticker.C: + case <-c.ctx.Done(): + log.Info("coordinator stops running") + return + } + } + log.Info("coordinator starts to run schedulers") + c.InitSchedulers(true) + + c.wg.Add(4) + // Starts to patrol regions. + go c.PatrolRegions() + // Checks suspect key ranges + go c.checkSuspectRanges() + go c.drivePushOperator() + // Checks whether to create evict-slow-trend scheduler. + go c.driveSlowNodeScheduler() +} + +// InitSchedulers initializes schedulers. 
+func (c *Coordinator) InitSchedulers(needRun bool) { + var ( + scheduleNames []string + configs []string + err error + ) + for i := 0; i < maxLoadConfigRetries; i++ { + scheduleNames, configs, err = c.cluster.GetStorage().LoadAllSchedulerConfigs() + select { + case <-c.ctx.Done(): + log.Info("init schedulers has been stopped") + return + default: + } + if err == nil { + break + } + log.Error("cannot load schedulers' config", zap.Int("retry-times", i), errs.ZapError(err)) + } + if err != nil { + log.Fatal("cannot load schedulers' config", errs.ZapError(err)) + } + scheduleCfg := c.cluster.GetSchedulerConfig().GetScheduleConfig().Clone() + // The new way to create scheduler with the independent configuration. + for i, name := range scheduleNames { + data := configs[i] + typ := schedulers.FindSchedulerTypeByName(name) + var cfg sc.SchedulerConfig + for _, c := range scheduleCfg.Schedulers { + if c.Type == typ { + cfg = c + break + } + } + if len(cfg.Type) == 0 { + log.Error("the scheduler type not found", zap.String("scheduler-name", name), errs.ZapError(errs.ErrSchedulerNotFound)) + continue + } + if cfg.Disable { + log.Info("skip create scheduler with independent configuration", zap.String("scheduler-name", name), zap.String("scheduler-type", cfg.Type), zap.Strings("scheduler-args", cfg.Args)) + continue + } + s, err := schedulers.CreateScheduler(cfg.Type, c.opController, c.cluster.GetStorage(), schedulers.ConfigJSONDecoder([]byte(data)), c.schedulers.RemoveScheduler) + if err != nil { + log.Error("can not create scheduler with independent configuration", zap.String("scheduler-name", name), zap.Strings("scheduler-args", cfg.Args), errs.ZapError(err)) + continue + } + if needRun { + log.Info("create scheduler with independent configuration", zap.String("scheduler-name", s.GetName())) + if err = c.schedulers.AddScheduler(s); err != nil { + log.Error("can not add scheduler with independent configuration", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", cfg.Args), errs.ZapError(err)) + } + } else { + log.Info("create scheduler handler with independent configuration", zap.String("scheduler-name", s.GetName())) + if err = c.schedulers.AddSchedulerHandler(s); err != nil { + log.Error("can not add scheduler handler with independent configuration", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", cfg.Args), errs.ZapError(err)) + } + } + } + + // The old way to create the scheduler. 
+ k := 0 + for _, schedulerCfg := range scheduleCfg.Schedulers { + if schedulerCfg.Disable { + scheduleCfg.Schedulers[k] = schedulerCfg + k++ + log.Info("skip create scheduler", zap.String("scheduler-type", schedulerCfg.Type), zap.Strings("scheduler-args", schedulerCfg.Args)) + continue + } + + s, err := schedulers.CreateScheduler(schedulerCfg.Type, c.opController, c.cluster.GetStorage(), schedulers.ConfigSliceDecoder(schedulerCfg.Type, schedulerCfg.Args), c.schedulers.RemoveScheduler) + if err != nil { + log.Error("can not create scheduler", zap.String("scheduler-type", schedulerCfg.Type), zap.Strings("scheduler-args", schedulerCfg.Args), errs.ZapError(err)) + continue + } + + if needRun { + log.Info("create scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args)) + if err = c.schedulers.AddScheduler(s, schedulerCfg.Args...); err != nil && !errors.ErrorEqual(err, errs.ErrSchedulerExisted.FastGenByArgs()) { + log.Error("can not add scheduler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args), errs.ZapError(err)) + } else { + // Only records the valid scheduler config. + scheduleCfg.Schedulers[k] = schedulerCfg + k++ + } + } else { + log.Info("create scheduler handler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args)) + if err = c.schedulers.AddSchedulerHandler(s, schedulerCfg.Args...); err != nil && !errors.ErrorEqual(err, errs.ErrSchedulerExisted.FastGenByArgs()) { + log.Error("can not add scheduler handler", zap.String("scheduler-name", s.GetName()), zap.Strings("scheduler-args", schedulerCfg.Args), errs.ZapError(err)) + } else { + scheduleCfg.Schedulers[k] = schedulerCfg + k++ + } + } + } + + // Removes the invalid scheduler config and persist. + scheduleCfg.Schedulers = scheduleCfg.Schedulers[:k] + c.cluster.GetSchedulerConfig().SetScheduleConfig(scheduleCfg) + if err := c.cluster.GetSchedulerConfig().Persist(c.cluster.GetStorage()); err != nil { + log.Error("cannot persist schedule config", errs.ZapError(err)) + } + log.Info("scheduler config is updated", zap.Reflect("scheduler-config", scheduleCfg.Schedulers)) + + c.markSchedulersInitialized() +} + +// LoadPlugin load user plugin +func (c *Coordinator) LoadPlugin(pluginPath string, ch chan string) { + log.Info("load plugin", zap.String("plugin-path", pluginPath)) + // get func: SchedulerType from plugin + SchedulerType, err := c.pluginInterface.GetFunction(pluginPath, "SchedulerType") + if err != nil { + log.Error("GetFunction SchedulerType error", errs.ZapError(err)) + return + } + schedulerType := SchedulerType.(func() string) + // get func: SchedulerArgs from plugin + SchedulerArgs, err := c.pluginInterface.GetFunction(pluginPath, "SchedulerArgs") + if err != nil { + log.Error("GetFunction SchedulerArgs error", errs.ZapError(err)) + return + } + schedulerArgs := SchedulerArgs.(func() []string) + // create and add user scheduler + s, err := schedulers.CreateScheduler(schedulerType(), c.opController, c.cluster.GetStorage(), schedulers.ConfigSliceDecoder(schedulerType(), schedulerArgs()), c.schedulers.RemoveScheduler) + if err != nil { + log.Error("can not create scheduler", zap.String("scheduler-type", schedulerType()), errs.ZapError(err)) + return + } + log.Info("create scheduler", zap.String("scheduler-name", s.GetName())) + // TODO: handle the plugin in API service mode. 
+ if err = c.schedulers.AddScheduler(s); err != nil { + log.Error("can't add scheduler", zap.String("scheduler-name", s.GetName()), errs.ZapError(err)) + return + } + + c.wg.Add(1) + go c.waitPluginUnload(pluginPath, s.GetName(), ch) +} + +func (c *Coordinator) waitPluginUnload(pluginPath, schedulerName string, ch chan string) { + defer logutil.LogPanic() + defer c.wg.Done() + // Get signal from channel which means user unload the plugin + for { + select { + case action := <-ch: + if action == PluginUnload { + err := c.schedulers.RemoveScheduler(schedulerName) + if err != nil { + log.Error("can not remove scheduler", zap.String("scheduler-name", schedulerName), errs.ZapError(err)) + } else { + log.Info("unload plugin", zap.String("plugin", pluginPath)) + return + } + } else { + log.Error("unknown action", zap.String("action", action)) + } + case <-c.ctx.Done(): + log.Info("unload plugin has been stopped") + return + } + } +} + +// Stop stops the coordinator. +func (c *Coordinator) Stop() { + c.cancel() +} + +// GetHotRegionsByType gets hot regions' statistics by RWType. +func (c *Coordinator) GetHotRegionsByType(typ utils.RWType) *statistics.StoreHotPeersInfos { + isTraceFlow := c.cluster.GetSchedulerConfig().IsTraceRegionFlow() + storeLoads := c.cluster.GetStoresLoads() + stores := c.cluster.GetStores() + var infos *statistics.StoreHotPeersInfos + switch typ { + case utils.Write: + regionStats := c.cluster.RegionWriteStats() + infos = statistics.GetHotStatus(stores, storeLoads, regionStats, utils.Write, isTraceFlow) + case utils.Read: + regionStats := c.cluster.RegionReadStats() + infos = statistics.GetHotStatus(stores, storeLoads, regionStats, utils.Read, isTraceFlow) + default: + } + // update params `IsLearner` and `LastUpdateTime` + s := []statistics.StoreHotPeersStat{infos.AsLeader, infos.AsPeer} + for i, stores := range s { + for j, store := range stores { + for k := range store.Stats { + h := &s[i][j].Stats[k] + region := c.cluster.GetRegion(h.RegionID) + if region != nil { + h.IsLearner = core.IsLearner(region.GetPeer(h.StoreID)) + } + switch typ { + case utils.Write: + if region != nil { + h.LastUpdateTime = time.Unix(int64(region.GetInterval().GetEndTimestamp()), 0) + } + case utils.Read: + store := c.cluster.GetStore(h.StoreID) + if store != nil { + ts := store.GetMeta().GetLastHeartbeat() + h.LastUpdateTime = time.Unix(ts/1e9, ts%1e9) + } + default: + } + } + } + } + return infos +} + +// GetHotRegions gets hot regions' statistics by RWType and storeIDs. +// If storeIDs is empty, it returns all hot regions' statistics by RWType. +func (c *Coordinator) GetHotRegions(typ utils.RWType, storeIDs ...uint64) *statistics.StoreHotPeersInfos { + hotRegions := c.GetHotRegionsByType(typ) + if len(storeIDs) > 0 && hotRegions != nil { + asLeader := statistics.StoreHotPeersStat{} + asPeer := statistics.StoreHotPeersStat{} + for _, storeID := range storeIDs { + asLeader[storeID] = hotRegions.AsLeader[storeID] + asPeer[storeID] = hotRegions.AsPeer[storeID] + } + return &statistics.StoreHotPeersInfos{ + AsLeader: asLeader, + AsPeer: asPeer, + } + } + return hotRegions +} + +// GetWaitGroup returns the wait group. Only for test purpose. +func (c *Coordinator) GetWaitGroup() *sync.WaitGroup { + return &c.wg +} + +// CollectHotSpotMetrics collects hot spot metrics. +func (c *Coordinator) CollectHotSpotMetrics() { + stores := c.cluster.GetStores() + // Collects hot write region metrics. + collectHotMetrics(c.cluster, stores, utils.Write) + // Collects hot read region metrics. 
+ collectHotMetrics(c.cluster, stores, utils.Read) +} + +func collectHotMetrics(cluster sche.ClusterInformer, stores []*core.StoreInfo, typ utils.RWType) { + var ( + kind string + regionStats map[uint64][]*statistics.HotPeerStat + ) + + switch typ { + case utils.Read: + regionStats = cluster.RegionReadStats() + kind = utils.Read.String() + case utils.Write: + regionStats = cluster.RegionWriteStats() + kind = utils.Write.String() + } + status := statistics.CollectHotPeerInfos(stores, regionStats) // only returns TotalBytesRate,TotalKeysRate,TotalQueryRate,Count + + for _, s := range stores { + // TODO: pre-allocate gauge metrics + storeAddress := s.GetAddress() + storeID := s.GetID() + storeLabel := strconv.FormatUint(storeID, 10) + stat, hasHotLeader := status.AsLeader[storeID] + if hasHotLeader { + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader").Set(stat.TotalBytesRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader").Set(stat.TotalKeysRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader").Set(stat.TotalQueryRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader").Set(float64(stat.Count)) + } else { + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_leader") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_leader") + } + + stat, hasHotPeer := status.AsPeer[storeID] + if hasHotPeer { + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer").Set(stat.TotalBytesRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer").Set(stat.TotalKeysRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer").Set(stat.TotalQueryRate) + hotSpotStatusGauge.WithLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer").Set(float64(stat.Count)) + } else { + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_bytes_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_keys_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "total_"+kind+"_query_as_peer") + hotSpotStatusGauge.DeleteLabelValues(storeAddress, storeLabel, "hot_"+kind+"_region_as_peer") + } + + if !hasHotLeader && !hasHotPeer { + utils.ForeachRegionStats(func(rwTy utils.RWType, dim int, _ utils.RegionStatKind) { + schedulers.HotPendingSum.DeleteLabelValues(storeLabel, rwTy.String(), utils.DimToString(dim)) + }) + } + } +} + +// ResetHotSpotMetrics resets hot spot metrics. +func ResetHotSpotMetrics() { + hotSpotStatusGauge.Reset() + schedulers.HotPendingSum.Reset() +} + +// ShouldRun returns true if the coordinator should run. +func (c *Coordinator) ShouldRun(collectWaitTime ...time.Duration) bool { + return c.prepareChecker.check(c.cluster.GetBasicCluster(), collectWaitTime...) +} + +// GetSchedulersController returns the schedulers controller. +func (c *Coordinator) GetSchedulersController() *schedulers.Controller { + return c.schedulers +} + +// PauseOrResumeChecker pauses or resumes a checker by name. 
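collectHotMetrics above sets a gauge per label combination while a store is hot and deletes the series otherwise, so dashboards do not keep showing stale values. A minimal self-contained sketch of that set-or-delete idiom, using an invented metric rather than one of PD's real ones:

package main

import "github.com/prometheus/client_golang/prometheus"

// hotGauge is an invented metric used only for illustration.
var hotGauge = prometheus.NewGaugeVec(
	prometheus.GaugeOpts{Name: "example_hot_status", Help: "illustrative only"},
	[]string{"store", "kind"},
)

func report(store string, rate float64, hasHot bool) {
	if hasHot {
		hotGauge.WithLabelValues(store, "total_read_bytes_as_peer").Set(rate)
		return
	}
	// Delete the series so it disappears instead of freezing at its last value.
	hotGauge.DeleteLabelValues(store, "total_read_bytes_as_peer")
}

func main() {
	prometheus.MustRegister(hotGauge)
	report("1", 2048, true)
	report("1", 0, false)
}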
+func (c *Coordinator) PauseOrResumeChecker(name string, t int64) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + p, err := c.checkers.GetPauseController(name) + if err != nil { + return err + } + p.PauseOrResume(t) + return nil +} + +// IsCheckerPaused returns whether a checker is paused. +func (c *Coordinator) IsCheckerPaused(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + p, err := c.checkers.GetPauseController(name) + if err != nil { + return false, err + } + return p.IsPaused(), nil +} + +// GetRegionScatterer returns the region scatterer. +func (c *Coordinator) GetRegionScatterer() *scatter.RegionScatterer { + return c.regionScatterer +} + +// GetRegionSplitter returns the region splitter. +func (c *Coordinator) GetRegionSplitter() *splitter.RegionSplitter { + return c.regionSplitter +} + +// GetOperatorController returns the operator controller. +func (c *Coordinator) GetOperatorController() *operator.Controller { + return c.opController +} + +// GetCheckerController returns the checker controller. +func (c *Coordinator) GetCheckerController() *checker.Controller { + return c.checkers +} + +// GetMergeChecker returns the merge checker. +func (c *Coordinator) GetMergeChecker() *checker.MergeChecker { + return c.checkers.GetMergeChecker() +} + +// GetRuleChecker returns the rule checker. +func (c *Coordinator) GetRuleChecker() *checker.RuleChecker { + return c.checkers.GetRuleChecker() +} + +// GetPrepareChecker returns the prepare checker. +func (c *Coordinator) GetPrepareChecker() *prepareChecker { + return c.prepareChecker +} + +// GetHeartbeatStreams returns the heartbeat streams. Only for test purpose. +func (c *Coordinator) GetHeartbeatStreams() *hbstream.HeartbeatStreams { + return c.hbStreams +} + +// GetCluster returns the cluster. Only for test purpose. +func (c *Coordinator) GetCluster() sche.ClusterInformer { + return c.cluster +} + +// GetDiagnosticResult returns the diagnostic result. +func (c *Coordinator) GetDiagnosticResult(name string) (*schedulers.DiagnosticResult, error) { + return c.diagnosticManager.GetDiagnosticResult(name) +} + +// RecordOpStepWithTTL records OpStep with TTL +func (c *Coordinator) RecordOpStepWithTTL(regionID uint64) { + c.GetRuleChecker().RecordRegionPromoteToNonWitness(regionID) +} diff --git a/pkg/schedule/core/cluster_informer.go b/pkg/schedule/core/cluster_informer.go new file mode 100644 index 00000000000..b97459d26ea --- /dev/null +++ b/pkg/schedule/core/cluster_informer.go @@ -0,0 +1,73 @@ +// Copyright 2017 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
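A hedged usage sketch for the Coordinator pause/resume API shown above. The checker name "merge" is an assumption, and a zero delay is assumed to resume by the same convention as PauseOrResumeScheduler later in this change.

// Assumes a *schedule.Coordinator named coord obtained from NewCoordinator.
func pauseMergeCheckerBriefly(coord *schedule.Coordinator) error {
	if err := coord.PauseOrResumeChecker("merge", 300); err != nil {
		return err // e.g. the checker name is not recognized
	}
	paused, err := coord.IsCheckerPaused("merge")
	if err != nil {
		return err
	}
	_ = paused // expected to be true until the 300-second window elapses
	return coord.PauseOrResumeChecker("merge", 0) // assumed to resume, as with schedulers
}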
+ +package core + +import ( + "github.com/tikv/pd/pkg/core" + sc "github.com/tikv/pd/pkg/schedule/config" + "github.com/tikv/pd/pkg/schedule/labeler" + "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/statistics" + "github.com/tikv/pd/pkg/statistics/buckets" + "github.com/tikv/pd/pkg/storage" +) + +// ClusterInformer provides the necessary information of a cluster. +type ClusterInformer interface { + SchedulerCluster + CheckerCluster + + GetStorage() storage.Storage + UpdateRegionsLabelLevelStats(regions []*core.RegionInfo) +} + +// SchedulerCluster is an aggregate interface that wraps multiple interfaces +type SchedulerCluster interface { + SharedCluster + + statistics.StoreStatInformer + buckets.BucketStatInformer + + GetSchedulerConfig() sc.SchedulerConfigProvider + GetRegionLabeler() *labeler.RegionLabeler + GetStoreConfig() sc.StoreConfigProvider + IsSchedulingHalted() bool +} + +// CheckerCluster is an aggregate interface that wraps multiple interfaces +type CheckerCluster interface { + SharedCluster + + GetCheckerConfig() sc.CheckerConfigProvider + GetStoreConfig() sc.StoreConfigProvider +} + +// SharedCluster is an aggregate interface that wraps multiple interfaces +type SharedCluster interface { + BasicCluster + statistics.RegionStatInformer + + GetBasicCluster() *core.BasicCluster + GetSharedConfig() sc.SharedConfigProvider + GetRuleManager() *placement.RuleManager + AllocID() (uint64, error) +} + +// BasicCluster is an aggregate interface that wraps multiple interfaces +type BasicCluster interface { + core.StoreSetInformer + core.StoreSetController + core.RegionSetInformer +} diff --git a/pkg/schedule/schedulers/scheduler_controller.go b/pkg/schedule/schedulers/scheduler_controller.go new file mode 100644 index 00000000000..334a2f1199a --- /dev/null +++ b/pkg/schedule/schedulers/scheduler_controller.go @@ -0,0 +1,595 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package schedulers + +import ( + "context" + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/errs" + sche "github.com/tikv/pd/pkg/schedule/core" + "github.com/tikv/pd/pkg/schedule/labeler" + "github.com/tikv/pd/pkg/schedule/operator" + "github.com/tikv/pd/pkg/schedule/plan" + "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/syncutil" + "go.uber.org/zap" +) + +const maxScheduleRetries = 10 + +var ( + denySchedulersByLabelerCounter = labeler.LabelerEventCounter.WithLabelValues("schedulers", "deny") +) + +// Controller is used to manage all schedulers. +type Controller struct { + syncutil.RWMutex + wg sync.WaitGroup + ctx context.Context + cluster sche.SchedulerCluster + storage endpoint.ConfigStorage + // schedulers are used to manage all schedulers, which will only be initialized + // and used in the PD leader service mode now. 
+ schedulers map[string]*ScheduleController + // schedulerHandlers is used to manage the HTTP handlers of schedulers, + // which will only be initialized and used in the API service mode now. + schedulerHandlers map[string]http.Handler + opController *operator.Controller +} + +// NewController creates a scheduler controller. +func NewController(ctx context.Context, cluster sche.SchedulerCluster, storage endpoint.ConfigStorage, opController *operator.Controller) *Controller { + return &Controller{ + ctx: ctx, + cluster: cluster, + storage: storage, + schedulers: make(map[string]*ScheduleController), + schedulerHandlers: make(map[string]http.Handler), + opController: opController, + } +} + +// Wait waits on all schedulers to exit. +func (c *Controller) Wait() { + c.Lock() + defer c.Unlock() + c.wg.Wait() +} + +// GetScheduler returns a schedule controller by name. +func (c *Controller) GetScheduler(name string) *ScheduleController { + c.RLock() + defer c.RUnlock() + return c.schedulers[name] +} + +// GetSchedulerNames returns all names of schedulers. +func (c *Controller) GetSchedulerNames() []string { + c.RLock() + defer c.RUnlock() + names := make([]string, 0, len(c.schedulers)) + for name := range c.schedulers { + names = append(names, name) + } + return names +} + +// GetSchedulerHandlers returns all handlers of schedulers. +func (c *Controller) GetSchedulerHandlers() map[string]http.Handler { + c.RLock() + defer c.RUnlock() + if len(c.schedulerHandlers) > 0 { + return c.schedulerHandlers + } + handlers := make(map[string]http.Handler, len(c.schedulers)) + for name, scheduler := range c.schedulers { + handlers[name] = scheduler.Scheduler + } + return handlers +} + +// CollectSchedulerMetrics collects metrics of all schedulers. +func (c *Controller) CollectSchedulerMetrics() { + c.RLock() + for _, s := range c.schedulers { + var allowScheduler float64 + // If the scheduler is not allowed to schedule, it will disappear in Grafana panel. + // See issue #1341. + if !s.IsPaused() && !c.cluster.IsSchedulingHalted() { + allowScheduler = 1 + } + schedulerStatusGauge.WithLabelValues(s.Scheduler.GetName(), "allow").Set(allowScheduler) + } + c.RUnlock() + ruleMgr := c.cluster.GetRuleManager() + if ruleMgr == nil { + return + } + ruleCnt := ruleMgr.GetRulesCount() + groupCnt := ruleMgr.GetGroupsCount() + ruleStatusGauge.WithLabelValues("rule_count").Set(float64(ruleCnt)) + ruleStatusGauge.WithLabelValues("group_count").Set(float64(groupCnt)) +} + +// ResetSchedulerMetrics resets metrics of all schedulers. +func ResetSchedulerMetrics() { + schedulerStatusGauge.Reset() + ruleStatusGauge.Reset() +} + +// AddSchedulerHandler adds the HTTP handler for a scheduler. +func (c *Controller) AddSchedulerHandler(scheduler Scheduler, args ...string) error { + c.Lock() + defer c.Unlock() + + name := scheduler.GetName() + if _, ok := c.schedulerHandlers[name]; ok { + return errs.ErrSchedulerExisted.FastGenByArgs() + } + + c.schedulerHandlers[name] = scheduler + if err := SaveSchedulerConfig(c.storage, scheduler); err != nil { + log.Error("can not save HTTP scheduler config", zap.String("scheduler-name", scheduler.GetName()), errs.ZapError(err)) + return err + } + c.cluster.GetSchedulerConfig().AddSchedulerCfg(scheduler.GetType(), args) + err := scheduler.PrepareConfig(c.cluster) + return err +} + +// RemoveSchedulerHandler removes the HTTP handler for a scheduler. 
+func (c *Controller) RemoveSchedulerHandler(name string) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulerHandlers[name] + if !ok { + return errs.ErrSchedulerNotFound.FastGenByArgs() + } + + conf := c.cluster.GetSchedulerConfig() + conf.RemoveSchedulerCfg(s.(Scheduler).GetType()) + if err := conf.Persist(c.storage); err != nil { + log.Error("the option can not persist scheduler config", errs.ZapError(err)) + return err + } + + if err := c.storage.RemoveSchedulerConfig(name); err != nil { + log.Error("can not remove the scheduler config", errs.ZapError(err)) + return err + } + + s.(Scheduler).CleanConfig(c.cluster) + delete(c.schedulerHandlers, name) + + return nil +} + +// AddScheduler adds a scheduler. +func (c *Controller) AddScheduler(scheduler Scheduler, args ...string) error { + c.Lock() + defer c.Unlock() + + if _, ok := c.schedulers[scheduler.GetName()]; ok { + return errs.ErrSchedulerExisted.FastGenByArgs() + } + + s := NewScheduleController(c.ctx, c.cluster, c.opController, scheduler) + if err := s.Scheduler.PrepareConfig(c.cluster); err != nil { + return err + } + + c.wg.Add(1) + go c.runScheduler(s) + c.schedulers[s.Scheduler.GetName()] = s + if err := SaveSchedulerConfig(c.storage, scheduler); err != nil { + log.Error("can not save scheduler config", zap.String("scheduler-name", scheduler.GetName()), errs.ZapError(err)) + return err + } + c.cluster.GetSchedulerConfig().AddSchedulerCfg(s.Scheduler.GetType(), args) + return nil +} + +// RemoveScheduler removes a scheduler by name. +func (c *Controller) RemoveScheduler(name string) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return errs.ErrSchedulerNotFound.FastGenByArgs() + } + + conf := c.cluster.GetSchedulerConfig() + conf.RemoveSchedulerCfg(s.Scheduler.GetType()) + if err := conf.Persist(c.storage); err != nil { + log.Error("the option can not persist scheduler config", errs.ZapError(err)) + return err + } + + if err := c.storage.RemoveSchedulerConfig(name); err != nil { + log.Error("can not remove the scheduler config", errs.ZapError(err)) + return err + } + + s.Stop() + schedulerStatusGauge.DeleteLabelValues(name, "allow") + delete(c.schedulers, name) + + return nil +} + +// PauseOrResumeScheduler pauses or resumes a scheduler by name. +func (c *Controller) PauseOrResumeScheduler(name string, t int64) error { + c.Lock() + defer c.Unlock() + if c.cluster == nil { + return errs.ErrNotBootstrapped.FastGenByArgs() + } + var s []*ScheduleController + if name != "all" { + sc, ok := c.schedulers[name] + if !ok { + return errs.ErrSchedulerNotFound.FastGenByArgs() + } + s = append(s, sc) + } else { + for _, sc := range c.schedulers { + s = append(s, sc) + } + } + var err error + for _, sc := range s { + var delayAt, delayUntil int64 + if t > 0 { + delayAt = time.Now().Unix() + delayUntil = delayAt + t + } + sc.SetDelay(delayAt, delayUntil) + } + return err +} + +// ReloadSchedulerConfig reloads a scheduler's config if it exists. +func (c *Controller) ReloadSchedulerConfig(name string) error { + if exist, _ := c.IsSchedulerExisted(name); !exist { + return fmt.Errorf("scheduler %s is not existed", name) + } + return c.GetScheduler(name).ReloadConfig() +} + +// IsSchedulerAllowed returns whether a scheduler is allowed to schedule, a scheduler is not allowed to schedule if it is paused or blocked by unsafe recovery. 
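A small worked sketch of the delay bookkeeping behind PauseOrResumeScheduler above, with illustrative timestamps: pausing with t = 300 records delayAt = now and delayUntil = now + 300, and the scheduler reports paused until the wall clock passes delayUntil; t = 0 leaves both at zero, which resumes it.

package main

import (
	"fmt"
	"time"
)

func main() {
	t := int64(300) // pause for 300 seconds
	delayAt := time.Now().Unix()
	delayUntil := delayAt + t
	fmt.Println("paused:", time.Now().Unix() < delayUntil) // true for the next five minutes
	// Passing t = 0 would leave delayAt and delayUntil at zero, so IsPaused reports false.
}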
+func (c *Controller) IsSchedulerAllowed(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.AllowSchedule(false), nil +} + +// IsSchedulerPaused returns whether a scheduler is paused. +func (c *Controller) IsSchedulerPaused(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.IsPaused(), nil +} + +// IsSchedulerDisabled returns whether a scheduler is disabled. +func (c *Controller) IsSchedulerDisabled(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return c.cluster.GetSchedulerConfig().IsSchedulerDisabled(s.Scheduler.GetType()), nil +} + +// IsSchedulerExisted returns whether a scheduler is existed. +func (c *Controller) IsSchedulerExisted(name string) (bool, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return false, errs.ErrNotBootstrapped.FastGenByArgs() + } + _, existScheduler := c.schedulers[name] + _, existHandler := c.schedulerHandlers[name] + if !existScheduler && !existHandler { + return false, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return true, nil +} + +func (c *Controller) runScheduler(s *ScheduleController) { + defer logutil.LogPanic() + defer c.wg.Done() + defer s.Scheduler.CleanConfig(c.cluster) + + ticker := time.NewTicker(s.GetInterval()) + defer ticker.Stop() + for { + select { + case <-ticker.C: + diagnosable := s.IsDiagnosticAllowed() + if !s.AllowSchedule(diagnosable) { + continue + } + if op := s.Schedule(diagnosable); len(op) > 0 { + added := c.opController.AddWaitingOperator(op...) + log.Debug("add operator", zap.Int("added", added), zap.Int("total", len(op)), zap.String("scheduler", s.Scheduler.GetName())) + } + // Note: we reset the ticker here to support updating configuration dynamically. + ticker.Reset(s.GetInterval()) + case <-s.Ctx().Done(): + log.Info("scheduler has been stopped", + zap.String("scheduler-name", s.Scheduler.GetName()), + errs.ZapError(s.Ctx().Err())) + return + } + } +} + +// GetPausedSchedulerDelayAt returns paused timestamp of a paused scheduler +func (c *Controller) GetPausedSchedulerDelayAt(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayAt(), nil +} + +// GetPausedSchedulerDelayUntil returns the delay time until the scheduler is paused. 
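runScheduler above resets its ticker after every tick so a dynamically updated scheduling interval takes effect without restarting the goroutine. A standalone sketch of that idiom, where getInterval and work stand in for s.GetInterval and the scheduling step (assumes the `context` and `time` imports):

func runLoop(ctx context.Context, getInterval func() time.Duration, work func()) {
	ticker := time.NewTicker(getInterval())
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			work()
			// Re-read the interval so a dynamically updated config takes effect.
			ticker.Reset(getInterval())
		case <-ctx.Done():
			return
		}
	}
}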
+func (c *Controller) GetPausedSchedulerDelayUntil(name string) (int64, error) { + c.RLock() + defer c.RUnlock() + if c.cluster == nil { + return -1, errs.ErrNotBootstrapped.FastGenByArgs() + } + s, ok := c.schedulers[name] + if !ok { + return -1, errs.ErrSchedulerNotFound.FastGenByArgs() + } + return s.GetDelayUntil(), nil +} + +// CheckTransferWitnessLeader determines if transfer leader is required, then sends to the scheduler if needed +func (c *Controller) CheckTransferWitnessLeader(region *core.RegionInfo) { + if core.NeedTransferWitnessLeader(region) { + c.RLock() + s, ok := c.schedulers[TransferWitnessLeaderName] + c.RUnlock() + if ok { + select { + case RecvRegionInfo(s.Scheduler) <- region: + default: + log.Warn("drop transfer witness leader due to recv region channel full", zap.Uint64("region-id", region.GetID())) + } + } + } +} + +// GetAllSchedulerConfigs returns all scheduler configs. +func (c *Controller) GetAllSchedulerConfigs() ([]string, []string, error) { + return c.storage.LoadAllSchedulerConfigs() +} + +// ScheduleController is used to manage a scheduler. +type ScheduleController struct { + Scheduler + cluster sche.SchedulerCluster + opController *operator.Controller + nextInterval time.Duration + ctx context.Context + cancel context.CancelFunc + delayAt int64 + delayUntil int64 + diagnosticRecorder *DiagnosticRecorder +} + +// NewScheduleController creates a new ScheduleController. +func NewScheduleController(ctx context.Context, cluster sche.SchedulerCluster, opController *operator.Controller, s Scheduler) *ScheduleController { + ctx, cancel := context.WithCancel(ctx) + return &ScheduleController{ + Scheduler: s, + cluster: cluster, + opController: opController, + nextInterval: s.GetMinInterval(), + ctx: ctx, + cancel: cancel, + diagnosticRecorder: NewDiagnosticRecorder(s.GetName(), cluster.GetSchedulerConfig()), + } +} + +// Ctx returns the context of ScheduleController +func (s *ScheduleController) Ctx() context.Context { + return s.ctx +} + +// Stop stops the ScheduleController +func (s *ScheduleController) Stop() { + s.cancel() +} + +// Schedule tries to create some operators. +func (s *ScheduleController) Schedule(diagnosable bool) []*operator.Operator { + for i := 0; i < maxScheduleRetries; i++ { + // no need to retry if schedule should stop to speed exit + select { + case <-s.ctx.Done(): + return nil + default: + } + cacheCluster := newCacheCluster(s.cluster) + // we need only process diagnostic once in the retry loop + diagnosable = diagnosable && i == 0 + ops, plans := s.Scheduler.Schedule(cacheCluster, diagnosable) + if diagnosable { + s.diagnosticRecorder.SetResultFromPlans(ops, plans) + } + foundDisabled := false + for _, op := range ops { + if labelMgr := s.cluster.GetRegionLabeler(); labelMgr != nil { + region := s.cluster.GetRegion(op.RegionID()) + if region == nil { + continue + } + if labelMgr.ScheduleDisabled(region) { + denySchedulersByLabelerCounter.Inc() + foundDisabled = true + break + } + } + } + if len(ops) > 0 { + // If we have schedule, reset interval to the minimal interval. + s.nextInterval = s.Scheduler.GetMinInterval() + // try regenerating operators + if foundDisabled { + continue + } + return ops + } + } + s.nextInterval = s.Scheduler.GetNextInterval(s.nextInterval) + return nil +} + +// DiagnoseDryRun returns the operators and plans of a scheduler. 
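CheckTransferWitnessLeader above uses a non-blocking select so a full scheduler channel drops the region instead of stalling the heartbeat path. A minimal sketch of that pattern; the helper name is invented, and it assumes the pd `core` package import.

// tryEnqueue mirrors the non-blocking send in CheckTransferWitnessLeader above.
func tryEnqueue(ch chan<- *core.RegionInfo, region *core.RegionInfo) bool {
	select {
	case ch <- region:
		return true
	default:
		return false // channel full: the caller may log and drop the region
	}
}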
+func (s *ScheduleController) DiagnoseDryRun() ([]*operator.Operator, []plan.Plan) { + cacheCluster := newCacheCluster(s.cluster) + return s.Scheduler.Schedule(cacheCluster, true) +} + +// GetInterval returns the interval of scheduling for a scheduler. +func (s *ScheduleController) GetInterval() time.Duration { + return s.nextInterval +} + +// SetInterval sets the interval of scheduling for a scheduler. for test purpose. +func (s *ScheduleController) SetInterval(interval time.Duration) { + s.nextInterval = interval +} + +// AllowSchedule returns if a scheduler is allowed to +func (s *ScheduleController) AllowSchedule(diagnosable bool) bool { + if !s.Scheduler.IsScheduleAllowed(s.cluster) { + if diagnosable { + s.diagnosticRecorder.SetResultFromStatus(Pending) + } + return false + } + if s.cluster.IsSchedulingHalted() { + if diagnosable { + s.diagnosticRecorder.SetResultFromStatus(Halted) + } + return false + } + if s.IsPaused() { + if diagnosable { + s.diagnosticRecorder.SetResultFromStatus(Paused) + } + return false + } + return true +} + +// IsPaused returns if a scheduler is paused. +func (s *ScheduleController) IsPaused() bool { + delayUntil := atomic.LoadInt64(&s.delayUntil) + return time.Now().Unix() < delayUntil +} + +// GetDelayAt returns paused timestamp of a paused scheduler +func (s *ScheduleController) GetDelayAt() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayAt) + } + return 0 +} + +// GetDelayUntil returns resume timestamp of a paused scheduler +func (s *ScheduleController) GetDelayUntil() int64 { + if s.IsPaused() { + return atomic.LoadInt64(&s.delayUntil) + } + return 0 +} + +// SetDelay sets the delay of a scheduler. +func (s *ScheduleController) SetDelay(delayAt, delayUntil int64) { + atomic.StoreInt64(&s.delayAt, delayAt) + atomic.StoreInt64(&s.delayUntil, delayUntil) +} + +// GetDiagnosticRecorder returns the diagnostic recorder of a scheduler. +func (s *ScheduleController) GetDiagnosticRecorder() *DiagnosticRecorder { + return s.diagnosticRecorder +} + +// IsDiagnosticAllowed returns if a scheduler is allowed to do diagnostic. +func (s *ScheduleController) IsDiagnosticAllowed() bool { + return s.diagnosticRecorder.IsAllowed() +} + +// cacheCluster include cache info to improve the performance. +type cacheCluster struct { + sche.SchedulerCluster + stores []*core.StoreInfo +} + +// GetStores returns store infos from cache +func (c *cacheCluster) GetStores() []*core.StoreInfo { + return c.stores +} + +// newCacheCluster constructor for cache +func newCacheCluster(c sche.SchedulerCluster) *cacheCluster { + return &cacheCluster{ + SchedulerCluster: c, + stores: c.GetStores(), + } +} diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 660e7347bd2..530731ba5e9 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -809,6 +809,7 @@ func (c *RaftCluster) SetPDServerConfig(cfg *config.PDServerConfig) { c.opt.SetPDServerConfig(cfg) } +<<<<<<< HEAD // AddSuspectRegions adds regions to suspect list. func (c *RaftCluster) AddSuspectRegions(regionIDs ...uint64) { c.coordinator.checkers.AddSuspectRegions(regionIDs...) @@ -827,6 +828,14 @@ func (c *RaftCluster) GetHotStat() *statistics.HotStat { // RemoveSuspectRegion removes region from suspect list. func (c *RaftCluster) RemoveSuspectRegion(id uint64) { c.coordinator.checkers.RemoveSuspectRegion(id) +======= +// IsSchedulingHalted returns whether the scheduling is halted. +// Currently, the PD scheduling is halted when: +// - The `HaltScheduling` persist option is set to true. 
+// - Online unsafe recovery is running. +func (c *RaftCluster) IsSchedulingHalted() bool { + return c.opt.IsSchedulingHalted() || c.unsafeRecoveryController.IsRunning() +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)) } // GetUnsafeRecoveryController returns the unsafe recovery controller. diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index fd0acfe7466..8c0aba40bf8 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -43,8 +43,13 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { // HandleAskSplit handles the split request. func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { +<<<<<<< HEAD if c.GetUnsafeRecoveryController().IsRunning() { return nil, errs.ErrUnsafeRecoveryIsRunning.FastGenByArgs() +======= + if c.IsSchedulingHalted() { + return nil, errs.ErrSchedulingIsHalted.FastGenByArgs() +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)) } if !c.opt.IsTikvRegionSplitEnabled() { return nil, errs.ErrSchedulerTiKVSplitDisabled.FastGenByArgs() @@ -86,6 +91,7 @@ func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSp return split, nil } +<<<<<<< HEAD // ValidRequestRegion is used to decide if the region is valid. func (c *RaftCluster) ValidRequestRegion(reqRegion *metapb.Region) error { startKey := reqRegion.GetStartKey() @@ -107,6 +113,12 @@ func (c *RaftCluster) ValidRequestRegion(reqRegion *metapb.Region) error { func (c *RaftCluster) HandleAskBatchSplit(request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { if c.GetUnsafeRecoveryController().IsRunning() { return nil, errs.ErrUnsafeRecoveryIsRunning.FastGenByArgs() +======= +// HandleAskBatchSplit handles the batch split request. +func (c *RaftCluster) HandleAskBatchSplit(request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { + if c.IsSchedulingHalted() { + return nil, errs.ErrSchedulingIsHalted.FastGenByArgs() +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)) } if !c.opt.IsTikvRegionSplitEnabled() { return nil, errs.ErrSchedulerTiKVSplitDisabled.FastGenByArgs() diff --git a/server/cluster/unsafe_recovery_controller.go b/server/cluster/unsafe_recovery_controller.go index 578ab0f6976..71d3388a905 100644 --- a/server/cluster/unsafe_recovery_controller.go +++ b/server/cluster/unsafe_recovery_controller.go @@ -474,7 +474,15 @@ func (u *unsafeRecoveryController) GetStage() unsafeRecoveryStage { return u.stage } +<<<<<<< HEAD:server/cluster/unsafe_recovery_controller.go func (u *unsafeRecoveryController) changeStage(stage unsafeRecoveryStage) { +======= +func (u *Controller) changeStage(stage stage) { + // If the running stage changes, update the scheduling allowance status to add or remove "online-unsafe-recovery" halt. 
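+ // Note that SetSchedulingAllowanceStatus only updates the halt-source metrics gauge; the actual halt decision for unsafe recovery is made by RaftCluster.IsSchedulingHalted, which checks the controller's running state directly.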
+ if running := isRunning(stage); running != isRunning(u.stage) { + u.cluster.GetSchedulerConfig().SetSchedulingAllowanceStatus(running, "online-unsafe-recovery") + } +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)):pkg/unsaferecovery/unsafe_recovery_controller.go u.stage = stage var output StageOutput diff --git a/server/config/persist_options.go b/server/config/persist_options.go index ce4565b5502..078997b9ff5 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -919,3 +919,88 @@ func (o *PersistOptions) SetAllStoresLimitTTL(ctx context.Context, client *clien } return err } +<<<<<<< HEAD +======= + +var haltSchedulingStatus = schedulingAllowanceStatusGauge.WithLabelValues("halt-scheduling") + +// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt. +func (*PersistOptions) SetSchedulingAllowanceStatus(halt bool, source string) { + if halt { + haltSchedulingStatus.Set(1) + schedulingAllowanceStatusGauge.WithLabelValues(source).Set(1) + } else { + haltSchedulingStatus.Set(0) + schedulingAllowanceStatusGauge.WithLabelValues(source).Set(0) + } +} + +// SetHaltScheduling sets HaltScheduling. +func (o *PersistOptions) SetHaltScheduling(halt bool, source string) { + v := o.GetScheduleConfig().Clone() + v.HaltScheduling = halt + o.SetScheduleConfig(v) + o.SetSchedulingAllowanceStatus(halt, source) +} + +// IsSchedulingHalted returns whether PD scheduling is halted. +func (o *PersistOptions) IsSchedulingHalted() bool { + if o == nil { + return false + } + return o.GetScheduleConfig().HaltScheduling +} + +// GetRegionMaxSize returns the max region size in MB. +func (o *PersistOptions) GetRegionMaxSize() uint64 { + return o.GetStoreConfig().GetRegionMaxSize() +} + +// GetRegionMaxKeys returns the max region keys. +func (o *PersistOptions) GetRegionMaxKeys() uint64 { + return o.GetStoreConfig().GetRegionMaxKeys() +} + +// GetRegionSplitSize returns the region split size in MB. +func (o *PersistOptions) GetRegionSplitSize() uint64 { + return o.GetStoreConfig().GetRegionSplitSize() +} + +// GetRegionSplitKeys returns the region split keys. +func (o *PersistOptions) GetRegionSplitKeys() uint64 { + return o.GetStoreConfig().GetRegionSplitKeys() +} + +// CheckRegionSize returns an error if the smallest region's size is less than mergeSize. +func (o *PersistOptions) CheckRegionSize(size, mergeSize uint64) error { + return o.GetStoreConfig().CheckRegionSize(size, mergeSize) +} + +// CheckRegionKeys returns an error if the smallest region's key count is less than mergeKeys. +func (o *PersistOptions) CheckRegionKeys(keys, mergeKeys uint64) error { + return o.GetStoreConfig().CheckRegionKeys(keys, mergeKeys) +} + +// IsEnableRegionBucket returns true if the region bucket is enabled. +func (o *PersistOptions) IsEnableRegionBucket() bool { + return o.GetStoreConfig().IsEnableRegionBucket() +} + +// IsRaftKV2 returns true if the raft kv is v2. +func (o *PersistOptions) IsRaftKV2() bool { + return o.GetStoreConfig().IsRaftKV2() +} + +// SetRegionBucketEnabled sets whether the region bucket is enabled. +// Only used for tests. +func (o *PersistOptions) SetRegionBucketEnabled(enabled bool) { + cfg := o.GetStoreConfig().Clone() + cfg.SetRegionBucketEnabled(enabled) + o.SetStoreConfig(cfg) +} + +// GetRegionBucketSize returns the region bucket size.
+func (o *PersistOptions) GetRegionBucketSize() uint64 { + return o.GetStoreConfig().GetRegionBucketSize() +} +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)) diff --git a/server/forward.go b/server/forward.go new file mode 100644 index 00000000000..650833e1fc1 --- /dev/null +++ b/server/forward.go @@ -0,0 +1,505 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "io" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/schedulingpb" + "github.com/pingcap/kvproto/pkg/tsopb" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/tso" + "github.com/tikv/pd/pkg/utils/grpcutil" + "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/tsoutil" + "github.com/tikv/pd/server/cluster" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func forwardTSORequest( + ctx context.Context, + request *pdpb.TsoRequest, + forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) { + tsopbReq := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: request.GetHeader().GetClusterId(), + SenderId: request.GetHeader().GetSenderId(), + KeyspaceId: utils.DefaultKeyspaceID, + KeyspaceGroupId: utils.DefaultKeyspaceGroupID, + }, + Count: request.GetCount(), + DcLocation: request.GetDcLocation(), + } + + failpoint.Inject("tsoProxySendToTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-ctx.Done() + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if err := forwardStream.Send(tsopbReq); err != nil { + return nil, err + } + + failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { + // block until watchDeadline routine cancels the context. + <-ctx.Done() + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + return forwardStream.Recv() +} + +// forwardTSO forwards the TSO requests to the TSO service.
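+// It reads each request from the PD client stream, enforces the configured limit on concurrent proxy streams, rebuilds the forwarding stream whenever the TSO primary address changes, applies a per-request forwarding deadline, and relays the converted response back to the client.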
+func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { + var ( + server = &tsoServer{stream: stream} + forwardStream tsopb.TSO_TsoClient + forwardCtx context.Context + cancelForward context.CancelFunc + tsoStreamErr error + lastForwardedHost string + ) + defer func() { + s.concurrentTSOProxyStreamings.Add(-1) + if cancelForward != nil { + cancelForward() + } + if grpcutil.NeedRebuildConnection(tsoStreamErr) { + s.closeDelegateClient(lastForwardedHost) + } + }() + + maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) + if maxConcurrentTSOProxyStreamings >= 0 { + if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { + return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) + } + } + + tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) + go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) + + for { + select { + case <-s.ctx.Done(): + return errors.WithStack(s.ctx.Err()) + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout()) + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + if request.GetCount() == 0 { + err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") + return status.Errorf(codes.Unknown, err.Error()) + } + + forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) + if !ok || len(forwardedHost) == 0 { + tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr) + return tsoStreamErr + } + if forwardStream == nil || lastForwardedHost != forwardedHost { + if cancelForward != nil { + cancelForward() + } + + clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + tsoStreamErr = errors.WithStack(err) + return tsoStreamErr + } + forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn) + if err != nil { + tsoStreamErr = errors.WithStack(err) + return tsoStreamErr + } + lastForwardedHost = forwardedHost + } + + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + if err != nil { + tsoStreamErr = errors.WithStack(err) + return tsoStreamErr + } + + // The error types defined for tsopb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + tsopbErr := tsopbResp.GetHeader().GetError() + if tsopbErr != nil { + if tsopbErr.Type == tsopb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: tsopbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
+ pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: tsopbErr.GetMessage(), + } + } + } + + response := &pdpb.TsoResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: tsopbResp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + Count: tsopbResp.GetCount(), + Timestamp: tsopbResp.GetTimestamp(), + } + if err := server.Send(response); err != nil { + return errors.WithStack(err) + } + } +} + +func (s *GrpcServer) forwardTSORequestWithDeadLine( + forwardCtx context.Context, + cancelForward context.CancelFunc, + forwardStream tsopb.TSO_TsoClient, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) { + done := make(chan struct{}) + dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + select { + case tsDeadlineCh <- dl: + case <-forwardCtx.Done(): + return nil, forwardCtx.Err() + } + + start := time.Now() + resp, err := forwardTSORequest(forwardCtx, request, forwardStream) + close(done) + if err != nil { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + } + return nil, err + } + tsoProxyBatchSize.Observe(float64(request.GetCount())) + tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + return resp, nil +} + +func createTSOForwardStream(ctx context.Context, client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func (s *GrpcServer) createRegionHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func createRegionHeartbeatSchedulingStream(ctx context.Context, client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func forwardRegionHeartbeatToScheduling(rc *cluster.RaftCluster, forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err == io.EOF { + errCh <- errors.WithStack(err) + return + } + if err != nil { + errCh <- errors.WithStack(err) + return + } + // TODO: find a better way to halt scheduling immediately. + if rc.IsSchedulingHalted() { + continue + } + // The error types defined for schedulingpb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + schedulingpbErr := resp.GetHeader().GetError() + if schedulingpbErr != nil { + if schedulingpbErr.Type == schedulingpb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: schedulingpbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
+ pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: schedulingpbErr.GetMessage(), + } + } + } + response := &pdpb.RegionHeartbeatResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: resp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + ChangePeer: resp.GetChangePeer(), + TransferLeader: resp.GetTransferLeader(), + RegionId: resp.GetRegionId(), + RegionEpoch: resp.GetRegionEpoch(), + TargetPeer: resp.GetTargetPeer(), + Merge: resp.GetMerge(), + SplitRegion: resp.GetSplitRegion(), + ChangePeerV2: resp.GetChangePeerV2(), + SwitchWitnesses: resp.GetSwitchWitnesses(), + } + + if err := server.Send(response); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.CloseAndRecv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { + client, ok := s.clientConns.Load(forwardedHost) + if ok { + // Mostly, the connection is already established, and return it directly. + return client.(*grpc.ClientConn), nil + } + + tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + if err != nil { + return nil, err + } + ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) + defer cancel() + newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) + if err != nil { + return nil, err + } + conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) + if !loaded { + // Successfully stored the connection we created. + return newConn, nil + } + // Loaded a connection created/stored by another goroutine, so close the one we created + // and return the one we loaded. 
+ newConn.Close() + return conn.(*grpc.ClientConn), nil +} + +func (s *GrpcServer) closeDelegateClient(forwardedHost string) { + client, ok := s.clientConns.LoadAndDelete(forwardedHost) + if !ok { + return + } + client.(*grpc.ClientConn).Close() + log.Debug("close delegate client connection", zap.String("forwarded-host", forwardedHost)) +} + +func (s *GrpcServer) isLocalRequest(host string) bool { + failpoint.Inject("useForwardRequest", func() { + failpoint.Return(false) + }) + if host == "" { + return true + } + memberAddrs := s.GetMember().Member().GetClientUrls() + for _, addr := range memberAddrs { + if addr == host { + return true + } + } + return false +} + +func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { + if !s.IsAPIServiceMode() { + return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) + } + request := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: s.ClusterID(), + KeyspaceId: utils.DefaultKeyspaceID, + KeyspaceGroupId: utils.DefaultKeyspaceGroupID, + }, + Count: 1, + } + var ( + forwardedHost string + forwardStream tsopb.TSO_TsoClient + ts *tsopb.TsoResponse + err error + ok bool + ) + handleStreamError := func(err error) (needRetry bool) { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + if grpcutil.NeedRebuildConnection(err) { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + return false + } + for i := 0; i < maxRetryTimesRequestTSOServer; i++ { + if i > 0 { + time.Sleep(retryIntervalRequestTSOServer) + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if !ok || forwardedHost == "" { + return pdpb.Timestamp{}, ErrNotFoundTSOAddr + } + forwardStream, err = s.getTSOForwardStream(forwardedHost) + if err != nil { + return pdpb.Timestamp{}, err + } + err = forwardStream.Send(request) + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + ts, err = forwardStream.Recv() + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + return *ts.GetTimestamp(), nil + } + log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err +} + +func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { + s.tsoClientPool.RLock() + forwardStream, ok := s.tsoClientPool.clients[forwardedHost] + s.tsoClientPool.RUnlock() + if ok { + // This is the common case to return here + return forwardStream, nil + } + + s.tsoClientPool.Lock() + defer s.tsoClientPool.Unlock() + + // Double check after entering the critical section + forwardStream, ok = s.tsoClientPool.clients[forwardedHost] + if ok { + return forwardStream, nil + } + + // Now let's create the client connection and the forward stream + client, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + return 
nil, err + } + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) + done <- struct{}{} + if err != nil { + return nil, err + } + s.tsoClientPool.clients[forwardedHost] = forwardStream + return forwardStream, nil +} diff --git a/server/server.go b/server/server.go index 5830f2e0a87..ba9b678fc9d 100644 --- a/server/server.go +++ b/server/server.go @@ -930,7 +930,12 @@ func (s *Server) GetScheduleConfig() *config.ScheduleConfig { } // SetScheduleConfig sets the balance config information. +<<<<<<< HEAD func (s *Server) SetScheduleConfig(cfg config.ScheduleConfig) error { +======= +// This function is exported to be used by the API. +func (s *Server) SetScheduleConfig(cfg sc.ScheduleConfig) error { +>>>>>>> 740f15e65 (*: individually check the scheduling halt for online unsafe recovery (#8147)) if err := cfg.Validate(); err != nil { return err } @@ -948,6 +953,8 @@ func (s *Server) SetScheduleConfig(cfg config.ScheduleConfig) error { errs.ZapError(err)) return err } + // Update the scheduling halt status at the same time. + s.persistOptions.SetSchedulingAllowanceStatus(cfg.HaltScheduling, "manually") log.Info("schedule config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) return nil }
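A minimal usage sketch of the halt-source API added in server/config/persist_options.go (illustrative only, not part of the patch; it assumes the existing config.NewConfig and config.NewPersistOptions constructors from server/config):

package main

import (
	"fmt"

	"github.com/tikv/pd/server/config"
)

func main() {
	// SetHaltScheduling persists the HaltScheduling flag and records the "manually"
	// source in the scheduling-allowance gauge; IsSchedulingHalted only reads the
	// persisted flag, while the unsafe-recovery halt is checked separately by
	// RaftCluster.IsSchedulingHalted.
	opt := config.NewPersistOptions(config.NewConfig()) // assumed constructors
	opt.SetHaltScheduling(true, "manually")
	fmt.Println(opt.IsSchedulingHalted()) // true
	opt.SetHaltScheduling(false, "manually")
	fmt.Println(opt.IsSchedulingHalted()) // false
}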