diff --git a/.github/workflows/pd-tests.yaml b/.github/workflows/pd-tests.yaml
index 9084c7545a8..223187737e0 100644
--- a/.github/workflows/pd-tests.yaml
+++ b/.github/workflows/pd-tests.yaml
@@ -25,24 +25,33 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        worker_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        include:
+          - worker_id: 1
+            name: 'Unit Test(1)'
+          - worker_id: 2
+            name: 'Unit Test(2)'
+          - worker_id: 3
+            name: 'Tools Test'
+          - worker_id: 4
+            name: 'Client Integration Test'
+          - worker_id: 5
+            name: 'TSO Integration Test'
+          - worker_id: 6
+            name: 'MicroService Integration Test'
     outputs:
-      job-total: 13
+      job-total: 6
     steps:
       - name: Checkout code
        uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: '1.21'
-     - name: Make Test
+     - name: ${{ matrix.name }}
        env:
          WORKER_ID: ${{ matrix.worker_id }}
-         WORKER_COUNT: 13
-         JOB_COUNT: 9 # 10 is tools test, 11, 12, 13 are for other integrations jobs
        run: |
-         make ci-test-job JOB_COUNT=$(($JOB_COUNT)) JOB_INDEX=$WORKER_ID
+         make ci-test-job JOB_INDEX=$WORKER_ID
          mv covprofile covprofile_$WORKER_ID
-         sed -i "/failpoint_binding/d" covprofile_$WORKER_ID
      - name: Upload coverage result ${{ matrix.worker_id }}
        uses: actions/upload-artifact@v4
        with:
@@ -62,7 +71,11 @@ jobs:
      - name: Merge
        env:
          TOTAL_JOBS: ${{needs.chunks.outputs.job-total}}
-       run: for i in $(seq 1 $TOTAL_JOBS); do cat covprofile_$i >> covprofile; done
+       run: |
+         for i in $(seq 1 $TOTAL_JOBS); do cat covprofile_$i >> covprofile; done
+         sed -i "/failpoint_binding/d" covprofile
+         # only keep the first line (`mode: atomic`) of the coverage profile
+         sed -i '2,${/mode: atomic/d;}' covprofile
      - name: Send coverage
        uses: codecov/codecov-action@v4.2.0
        with:
diff --git a/Makefile b/Makefile
index 205896c377a..dca00012114 100644
--- a/Makefile
+++ b/Makefile
@@ -127,7 +127,7 @@ regions-dump:
 stores-dump:
 	cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/stores-dump stores-dump/main.go
 pd-ut: pd-xprog
-	cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ut pd-ut/ut.go
+	cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ut pd-ut/ut.go pd-ut/coverProfile.go
 pd-xprog:
 	cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -tags xprog -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/xprog pd-ut/xprog.go
@@ -227,7 +227,8 @@ failpoint-disable: install-tools
 ut: pd-ut
 	@$(FAILPOINT_ENABLE)
-	./bin/pd-ut run --race
+	# only run unit tests
+	./bin/pd-ut run --ignore tests --race
 	@$(CLEAN_UT_BINARY)
 	@$(FAILPOINT_DISABLE)

@@ -251,7 +252,7 @@ basic-test: install-tools
 	go test $(BASIC_TEST_PKGS) || { $(FAILPOINT_DISABLE); exit 1; }
 	@$(FAILPOINT_DISABLE)

-ci-test-job: install-tools dashboard-ui
+ci-test-job: install-tools dashboard-ui pd-ut
 	@$(FAILPOINT_ENABLE)
 	./scripts/ci-subtask.sh $(JOB_COUNT) $(JOB_INDEX) || { $(FAILPOINT_DISABLE); exit 1; }
 	@$(FAILPOINT_DISABLE)
diff --git a/client/resource_group/controller/controller.go b/client/resource_group/controller/controller.go
index 79bd6a9c3a6..11ea3f7997d 100755
--- a/client/resource_group/controller/controller.go
+++ b/client/resource_group/controller/controller.go
@@ -117,6 +117,13 @@ func WithWaitRetryTimes(times int) ResourceControlCreateOption {
 	}
 }

+// WithDegradedModeWaitDuration is the option to set the wait
duration for degraded mode. +func WithDegradedModeWaitDuration(d time.Duration) ResourceControlCreateOption { + return func(controller *ResourceGroupsController) { + controller.ruConfig.DegradedModeWaitDuration = d + } +} + var _ ResourceGroupKVInterceptor = (*ResourceGroupsController)(nil) // ResourceGroupsController implements ResourceGroupKVInterceptor. diff --git a/client/resource_group/controller/limiter.go b/client/resource_group/controller/limiter.go index 230ad46ecf1..a726b0e219a 100644 --- a/client/resource_group/controller/limiter.go +++ b/client/resource_group/controller/limiter.go @@ -330,6 +330,8 @@ func (lim *Limiter) AvailableTokens(now time.Time) float64 { return tokens } +const reserveWarnLogInterval = 10 * time.Millisecond + // reserveN is a helper method for Reserve. // maxFutureReserve specifies the maximum reservation wait duration allowed. // reserveN returns Reservation, not *Reservation. @@ -376,16 +378,19 @@ func (lim *Limiter) reserveN(now time.Time, n float64, maxFutureReserve time.Dur lim.tokens = tokens lim.maybeNotify() } else { - log.Warn("[resource group controller] cannot reserve enough tokens", - zap.Duration("need-wait-duration", waitDuration), - zap.Duration("max-wait-duration", maxFutureReserve), - zap.Float64("current-ltb-tokens", lim.tokens), - zap.Float64("current-ltb-rate", float64(lim.limit)), - zap.Float64("request-tokens", n), - zap.Float64("notify-threshold", lim.notifyThreshold), - zap.Bool("is-low-process", lim.isLowProcess), - zap.Int64("burst", lim.burst), - zap.Int("remaining-notify-times", lim.remainingNotifyTimes)) + // print log if the limiter cannot reserve for a while. + if time.Since(lim.last) > reserveWarnLogInterval { + log.Warn("[resource group controller] cannot reserve enough tokens", + zap.Duration("need-wait-duration", waitDuration), + zap.Duration("max-wait-duration", maxFutureReserve), + zap.Float64("current-ltb-tokens", lim.tokens), + zap.Float64("current-ltb-rate", float64(lim.limit)), + zap.Float64("request-tokens", n), + zap.Float64("notify-threshold", lim.notifyThreshold), + zap.Bool("is-low-process", lim.isLowProcess), + zap.Int64("burst", lim.burst), + zap.Int("remaining-notify-times", lim.remainingNotifyTimes)) + } lim.last = last if lim.limit == 0 { lim.notify() diff --git a/codecov.yml b/codecov.yml index bb439917e78..936eb3bbb11 100644 --- a/codecov.yml +++ b/codecov.yml @@ -24,9 +24,3 @@ flag_management: target: 74% # increase it if you want to enforce higher coverage for project, current setting as 74% is for do not let the error be reported and lose the meaning of warning. - type: patch target: 74% # increase it if you want to enforce higher coverage for project, current setting as 74% is for do not let the error be reported and lose the meaning of warning. 
-
-ignore:
-  # Ignore the tool tests
-  - tests/dashboard
-  - tests/pdbackup
-  - tests/pdctl
diff --git a/metrics/grafana/pd.json b/metrics/grafana/pd.json
index 9941004e0c2..e6d314c2e00 100644
--- a/metrics/grafana/pd.json
+++ b/metrics/grafana/pd.json
@@ -1218,7 +1218,6 @@
       },
       "yaxes": [
         {
-          "$$hashKey": "object:192",
           "format": "short",
           "label": null,
           "logBase": 1,
@@ -1227,7 +1226,6 @@
           "show": true
         },
         {
-          "$$hashKey": "object:193",
           "format": "short",
           "label": null,
           "logBase": 1,
@@ -7909,7 +7907,7 @@
       "tableColumn": "",
       "targets": [
         {
-          "expr": "pd_checker_patrol_regions_time{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\"} != 0",
+          "expr": "max(max(pd_checker_patrol_regions_time{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\"})by(instance))",
           "legendFormat": "{{instance}}",
           "format": "time_series",
           "intervalFactor": 1,
@@ -11431,7 +11429,6 @@
       "renderer": "flot",
       "seriesOverrides": [
         {
-          "$$hashKey": "object:1147",
           "alias": "WaitRegionsLock",
           "bars": false,
           "lines": true,
@@ -11439,7 +11436,6 @@
           "stack": false
         },
         {
-          "$$hashKey": "object:1251",
           "alias": "WaitSubRegionsLock",
           "bars": false,
           "lines": true,
@@ -11486,14 +11482,12 @@
       },
       "yaxes": [
         {
-          "$$hashKey": "object:322",
           "format": "s",
           "logBase": 1,
           "min": "0",
           "show": true
         },
         {
-          "$$hashKey": "object:323",
           "format": "s",
           "logBase": 1,
           "show": true
@@ -11606,10 +11600,15 @@
       "dashLength": 10,
       "dashes": false,
       "datasource": "${DS_TEST-CLUSTER}",
-      "description": "The count of the corresponding schedule commands which PD sends to each TiKV instance",
+      "description": "The count of the heartbeats pending in the task queue.",
       "editable": true,
       "error": false,
+      "fieldConfig": {
+        "defaults": {},
+        "overrides": []
+      },
       "fill": 0,
+      "fillGradient": 0,
       "grid": {},
       "gridPos": {
         "h": 8,
@@ -11617,6 +11616,236 @@
         "x": 12,
         "y": 39
       },
+      "hiddenSeries": false,
+      "id": 1608,
+      "legend": {
+        "alignAsTable": true,
+        "avg": true,
+        "current": true,
+        "hideEmpty": true,
+        "hideZero": true,
+        "max": true,
+        "min": false,
+        "rightSide": true,
+        "show": true,
+        "total": false,
+        "values": true
+      },
+      "lines": true,
+      "linewidth": 1,
+      "links": [],
+      "nullPointMode": "null as zero",
+      "options": {
+        "alertThreshold": true
+      },
+      "paceLength": 10,
+      "percentage": false,
+      "pluginVersion": "7.5.17",
+      "pointradius": 5,
+      "points": false,
+      "renderer": "flot",
+      "seriesOverrides": [],
+      "spaceLength": 10,
+      "stack": false,
+      "steppedLine": false,
+      "targets": [
+        {
+          "exemplar": true,
+          "expr": "pd_ratelimit_runner_task_pending_tasks{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\"}",
+          "format": "time_series",
+          "hide": false,
+          "interval": "",
+          "intervalFactor": 2,
+          "legendFormat": "{{task_type}}_({{runner_name}})",
+          "refId": "A",
+          "step": 4
+        }
+      ],
+      "thresholds": [],
+      "timeFrom": null,
+      "timeRegions": [],
+      "timeShift": null,
+      "title": "Heartbeat Runner Pending Task",
+      "tooltip": {
+        "msResolution": false,
+        "shared": true,
+        "sort": 0,
+        "value_type": "individual"
+      },
+      "type": "graph",
+      "xaxis": {
+        "buckets": null,
+        "mode": "time",
+        "name": null,
+        "show": true,
+        "values": []
+      },
+      "yaxes": [
+        {
+          "format": "opm",
+          "label": null,
+          "logBase": 1,
+          "max": null,
+          "min": "0",
+          "show": true
+        },
+        {
+          "format": "s",
+          "label": null,
+          "logBase": 1,
+          "max": null,
+          "min": null,
+          "show": true
+        }
+      ],
+      "yaxis": {
+        "align": false,
+        "alignLevel": null
+      }
+    },
+    {
+      "aliasColors": {},
+      "bars": false,
+      "dashLength": 10,
+      "dashes": false,
+      "datasource": "${DS_TEST-CLUSTER}",
+      "description": "The count of the heartbeats that failed in the task queue.",
+      "editable": true,
+      "error": false,
+      "fieldConfig": {
+        "defaults": {},
+        "overrides": []
+      },
+      "fill": 0,
+      "fillGradient": 0,
+      "grid": {},
+      "gridPos": {
+        "h": 8,
+        "w": 12,
+        "x": 0,
+        "y": 47
+      },
+      "hiddenSeries": false,
+      "id": 1609,
+      "legend": {
+        "alignAsTable": true,
+        "avg": true,
+        "current": true,
+        "hideEmpty": true,
+        "hideZero": true,
+        "max": true,
+        "min": false,
+        "rightSide": true,
+        "show": true,
+        "total": false,
+        "values": true
+      },
+      "lines": true,
+      "linewidth": 1,
+      "links": [],
+      "nullPointMode": "null as zero",
+      "options": {
+        "alertThreshold": true
+      },
+      "paceLength": 10,
+      "percentage": false,
+      "pluginVersion": "7.5.17",
+      "pointradius": 5,
+      "points": false,
+      "renderer": "flot",
+      "seriesOverrides": [
+        {
+          "alias": "/max-wait-duration.*/",
+          "bars": true,
+          "lines": false,
+          "transform": "negative-Y",
+          "yaxis": 2
+        }
+      ],
+      "spaceLength": 10,
+      "stack": false,
+      "steppedLine": false,
+      "targets": [
+        {
+          "exemplar": true,
+          "expr": "rate(pd_ratelimit_runner_task_failed_tasks_total{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\"}[1m])*60",
+          "format": "time_series",
+          "hide": false,
+          "interval": "",
+          "intervalFactor": 2,
+          "legendFormat": "failed-tasks-({{runner_name}})",
+          "refId": "A",
+          "step": 4
+        },
+        {
+          "exemplar": true,
+          "expr": "pd_ratelimit_runner_task_max_waiting_duration_seconds{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\"}",
+          "hide": false,
+          "interval": "",
+          "legendFormat": "max-wait-duration-({{runner_name}})",
+          "refId": "B"
+        }
+      ],
+      "thresholds": [],
+      "timeFrom": null,
+      "timeRegions": [],
+      "timeShift": null,
+      "title": "Concurrent Runner Failed Task",
+      "tooltip": {
+        "msResolution": false,
+        "shared": true,
+        "sort": 0,
+        "value_type": "individual"
+      },
+      "type": "graph",
+      "xaxis": {
+        "buckets": null,
+        "mode": "time",
+        "name": null,
+        "show": true,
+        "values": []
+      },
+      "yaxes": [
+        {
+          "decimals": null,
+          "format": "opm",
+          "label": "",
+          "logBase": 1,
+          "max": null,
+          "min": "0",
+          "show": true
+        },
+        {
+          "format": "s",
+          "label": null,
+          "logBase": 1,
+          "max": null,
+          "min": null,
+          "show": true
+        }
+      ],
+      "yaxis": {
+        "align": false,
+        "alignLevel": null
+      }
+    },
+    {
+      "aliasColors": {},
+      "bars": false,
+      "dashLength": 10,
+      "dashes": false,
+      "datasource": "${DS_TEST-CLUSTER}",
+      "description": "The count of the corresponding schedule commands which PD sends to each TiKV instance",
+      "editable": true,
+      "error": false,
+      "fill": 0,
+      "grid": {},
+      "gridPos": {
+        "h": 8,
+        "w": 12,
+        "x": 12,
+        "y": 47
+      },
       "id": 1305,
       "legend": {
         "alignAsTable": true,
@@ -11709,7 +11938,7 @@
         "h": 8,
         "w": 12,
         "x": 0,
-        "y": 47
+        "y": 55
       },
       "id": 1306,
       "legend": {
@@ -11799,7 +12028,7 @@
         "h": 8,
         "w": 12,
         "x": 12,
-        "y": 47
+        "y": 55
       },
       "id": 1307,
       "legend": {
@@ -11892,7 +12121,7 @@
         "h": 8,
         "w": 12,
         "x": 0,
-        "y": 55
+        "y": 63
       },
       "id": 1308,
       "legend": {
@@ -11989,7 +12218,7 @@
         "h": 8,
         "w": 12,
         "x": 12,
-        "y": 55
+        "y": 63
       },
       "id": 1309,
       "legend": {
@@ -12086,7 +12315,7 @@
         "h": 8,
         "w": 12,
         "x": 0,
-        "y": 63
+        "y": 71
       },
       "id": 1310,
       "legend": {
@@ -12183,7 +12412,7 @@
         "h": 8,
         "w": 12,
         "x": 12,
-        "y": 63
+        "y": 71
       },
       "id": 1311,
       "legend": {
@@ -12280,7 +12509,7 @@
         "h": 8,
         "w": 12,
         "x": 0,
-        "y": 71
+        "y": 79
       },
       "id": 1312,
       "legend": {
diff --git a/pkg/core/metrics.go b/pkg/core/metrics.go
index d23cf9dfcaa..7d2c904f319 100644
--- a/pkg/core/metrics.go
+++ b/pkg/core/metrics.go
@@ -15,6 +15,7 @@
 package core

 import (
+	"sync"
 	"time"
"github.com/prometheus/client_golang/prometheus" @@ -90,6 +91,12 @@ func init() { prometheus.MustRegister(AcquireRegionsLockWaitCount) } +var tracerPool = &sync.Pool{ + New: func() any { + return ®ionHeartbeatProcessTracer{} + }, +} + type saveCacheStats struct { startTime time.Time lastCheckTime time.Time @@ -114,6 +121,7 @@ type RegionHeartbeatProcessTracer interface { OnCollectRegionStatsFinished() OnAllStageFinished() LogFields() []zap.Field + Release() } type noopHeartbeatProcessTracer struct{} @@ -138,6 +146,7 @@ func (*noopHeartbeatProcessTracer) OnAllStageFinished() {} func (*noopHeartbeatProcessTracer) LogFields() []zap.Field { return nil } +func (*noopHeartbeatProcessTracer) Release() {} type regionHeartbeatProcessTracer struct { startTime time.Time @@ -151,7 +160,7 @@ type regionHeartbeatProcessTracer struct { // NewHeartbeatProcessTracer returns a heartbeat process tracer. func NewHeartbeatProcessTracer() RegionHeartbeatProcessTracer { - return ®ionHeartbeatProcessTracer{} + return tracerPool.Get().(*regionHeartbeatProcessTracer) } func (h *regionHeartbeatProcessTracer) Begin() { @@ -254,3 +263,10 @@ func (h *regionHeartbeatProcessTracer) LogFields() []zap.Field { zap.Duration("other-duration", h.OtherDuration), } } + +// Release puts the tracer back into the pool. +func (h *regionHeartbeatProcessTracer) Release() { + // Reset the fields of h to their zero values. + *h = regionHeartbeatProcessTracer{} + tracerPool.Put(h) +} diff --git a/pkg/core/region.go b/pkg/core/region.go index be8f392f05e..c9a8455d4de 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -36,7 +36,6 @@ import ( "github.com/pingcap/kvproto/pkg/replication_modepb" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" - "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/syncutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -751,19 +750,19 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { debug = func(msg string, fields ...zap.Field) { logRunner.RunTask( ctx.Context, + "DebugLog", func(_ context.Context) { d(msg, fields...) }, - ratelimit.WithTaskName("DebugLog"), ) } info = func(msg string, fields ...zap.Field) { logRunner.RunTask( ctx.Context, + "InfoLog", func(_ context.Context) { i(msg, fields...) }, - ratelimit.WithTaskName("InfoLog"), ) } } @@ -915,6 +914,8 @@ type RegionsInfo struct { learners map[uint64]*regionTree // storeID -> sub regionTree witnesses map[uint64]*regionTree // storeID -> sub regionTree pendingPeers map[uint64]*regionTree // storeID -> sub regionTree + // This tree is used to check the overlaps among all the subtrees. + overlapTree *regionTree } // NewRegionsInfo creates RegionsInfo with tree, regions, leaders and followers @@ -928,6 +929,7 @@ func NewRegionsInfo() *RegionsInfo { learners: make(map[uint64]*regionTree), witnesses: make(map[uint64]*regionTree), pendingPeers: make(map[uint64]*regionTree), + overlapTree: newRegionTreeWithCountRef(), } } @@ -1042,10 +1044,10 @@ func (r *RegionsInfo) CheckAndPutRootTree(ctx *MetaProcessContext, region *Regio // Usually used with CheckAndPutRootTree together. 
 func (r *RegionsInfo) CheckAndPutSubTree(region *RegionInfo) {
 	// new region get from root tree again
-	var newRegion *RegionInfo
-	newRegion = r.GetRegion(region.GetID())
+	newRegion := r.GetRegion(region.GetID())
 	if newRegion == nil {
-		newRegion = region
+		// Make sure the region is present in the root tree, so as to ensure the correctness of the reference count.
+		return
 	}
 	r.UpdateSubTreeOrderInsensitive(newRegion)
 }
@@ -1067,110 +1069,98 @@ func (r *RegionsInfo) UpdateSubTreeOrderInsensitive(region *RegionInfo) {
 		origin = originItem.RegionInfo
 	}
 	rangeChanged := true
 	if origin != nil {
+		rangeChanged = !origin.rangeEqualsTo(region)
+		if r.preUpdateSubTreeLocked(rangeChanged, !origin.peersEqualTo(region), true, origin, region) {
+			return
+		}
+	}
+	r.updateSubTreeLocked(rangeChanged, nil, region)
+}
+
+func (r *RegionsInfo) preUpdateSubTreeLocked(
+	rangeChanged, peerChanged, orderInsensitive bool,
+	origin, region *RegionInfo,
+) (done bool) {
+	if orderInsensitive {
 		re := region.GetRegionEpoch()
 		oe := origin.GetRegionEpoch()
 		isTermBehind := region.GetTerm() > 0 && region.GetTerm() < origin.GetTerm()
 		if (isTermBehind || re.GetVersion() < oe.GetVersion() || re.GetConfVer() < oe.GetConfVer()) && !region.isRegionRecreated() {
 			// Region meta is stale, skip.
-			return
+			return true
 		}
-		rangeChanged = !origin.rangeEqualsTo(region)
-
-		if rangeChanged || !origin.peersEqualTo(region) {
-			// If the range or peers have changed, the sub regionTree needs to be cleaned up.
-			// TODO: Improve performance by deleting only the different peers.
-			r.removeRegionFromSubTreeLocked(origin)
-		} else {
-			// The region tree and the subtree update is not atomic and the region tree is updated first.
-			// If there are two thread needs to update region tree,
-			// t1: thread-A update region tree
-			// t2: thread-B: update region tree again
-			// t3: thread-B: update subtree
-			// t4: thread-A: update region subtree
-			// to keep region tree consistent with subtree, we need to drop this update.
-			if tree, ok := r.subRegions[region.GetID()]; ok {
-				r.updateSubTreeStat(origin, region)
-				tree.RegionInfo = region
-			}
-			return
+	}
+	if rangeChanged || peerChanged {
+		// If the range or peers have changed, clean up the subtrees before updating them.
+		// TODO: improve performance by deleting only the different peers.
+		r.removeRegionFromSubTreeLocked(origin)
+	} else {
+		// Updates to the region tree and the subtree are not atomic, and the region tree is updated first.
+		// If two threads need to update the region tree,
+		// t1: thread-A updates the region tree
+		// t2: thread-B updates the region tree again
+		// t3: thread-B updates the subtree
+		// t4: thread-A updates the subtree
+		// To keep the region tree consistent with the subtree, we need to drop this update.
+		if tree, ok := r.subRegions[region.GetID()]; ok {
+			r.updateSubTreeStat(origin, region)
+			tree.RegionInfo = region
 		}
+		return true
 	}
+	return false
+}

+func (r *RegionsInfo) updateSubTreeLocked(rangeChanged bool, overlaps []*RegionInfo, region *RegionInfo) {
 	if rangeChanged {
-		overlaps := r.getOverlapRegionFromSubTreeLocked(region)
-		for _, re := range overlaps {
-			r.removeRegionFromSubTreeLocked(re)
+		// TODO: only perform the remove operation on the overlapped peer.
+		if len(overlaps) == 0 {
+			// If the range has changed but the overlapped regions are not provided, collect them as `[]*regionItem` from the overlap tree.
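+			// This is the order-insensitive heartbeat path, which has no overlap
+			// information from the root tree, so the dedicated overlap tree is
+			// queried instead of scanning every per-store subtree.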
+ for _, item := range r.getOverlapRegionFromOverlapTreeLocked(region) { + r.removeRegionFromSubTreeLocked(item.RegionInfo) + } + } else { + // Remove all provided overlapped regions from the subtrees. + for _, overlap := range overlaps { + r.removeRegionFromSubTreeLocked(overlap) + } } } - + // Reinsert the region into all subtrees. item := ®ionItem{region} r.subRegions[region.GetID()] = item - // It has been removed and all information needs to be updated again. - // Set peers then. - setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { + r.overlapTree.update(item, false) + // Add leaders and followers. + setPeer := func(peersMap map[uint64]*regionTree, storeID uint64) { store, ok := peersMap[storeID] if !ok { - if !countRef { - store = newRegionTree() - } else { - store = newRegionTreeWithCountRef() - } + store = newRegionTree() peersMap[storeID] = store } store.update(item, false) } - - // Add to leaders and followers. for _, peer := range region.GetVoters() { storeID := peer.GetStoreId() if peer.GetId() == region.leader.GetId() { - // Add leader peer to leaders. - setPeer(r.leaders, storeID, item, true) + setPeer(r.leaders, storeID) } else { - // Add follower peer to followers. - setPeer(r.followers, storeID, item, false) + setPeer(r.followers, storeID) } } - + // Add other peers. setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { for _, peer := range peers { - storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item, false) + setPeer(peersMap, peer.GetStoreId()) } } - // Add to learners. setPeers(r.learners, region.GetLearners()) - // Add to witnesses. setPeers(r.witnesses, region.GetWitnesses()) - // Add to PendingPeers setPeers(r.pendingPeers, region.GetPendingPeers()) } -func (r *RegionsInfo) getOverlapRegionFromSubTreeLocked(region *RegionInfo) []*RegionInfo { - it := ®ionItem{RegionInfo: region} - overlaps := make([]*RegionInfo, 0) - overlapsMap := make(map[uint64]struct{}) - collectFromItemSlice := func(peersMap map[uint64]*regionTree, storeID uint64) { - if tree, ok := peersMap[storeID]; ok { - items := tree.overlaps(it) - for _, item := range items { - if _, ok := overlapsMap[item.GetID()]; !ok { - overlapsMap[item.GetID()] = struct{}{} - overlaps = append(overlaps, item.RegionInfo) - } - } - } - } - for _, peer := range region.GetMeta().GetPeers() { - storeID := peer.GetStoreId() - collectFromItemSlice(r.leaders, storeID) - collectFromItemSlice(r.followers, storeID) - collectFromItemSlice(r.learners, storeID) - collectFromItemSlice(r.witnesses, storeID) - } - return overlaps +func (r *RegionsInfo) getOverlapRegionFromOverlapTreeLocked(region *RegionInfo) []*regionItem { + return r.overlapTree.overlaps(®ionItem{RegionInfo: region}) } // GetRelevantRegions returns the relevant regions for a given region. @@ -1276,72 +1266,11 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi r.st.Lock() defer r.st.Unlock() if origin != nil { - if rangeChanged || !origin.peersEqualTo(region) { - // If the range or peers have changed, the sub regionTree needs to be cleaned up. - // TODO: Improve performance by deleting only the different peers. - r.removeRegionFromSubTreeLocked(origin) - } else { - // The region tree and the subtree update is not atomic and the region tree is updated first. 
- // If there are two thread needs to update region tree, - // t1: thread-A update region tree - // t2: thread-B: update region tree again - // t3: thread-B: update subtree - // t4: thread-A: update region subtree - // to keep region tree consistent with subtree, we need to drop this update. - if tree, ok := r.subRegions[region.GetID()]; ok { - r.updateSubTreeStat(origin, region) - tree.RegionInfo = region - } + if r.preUpdateSubTreeLocked(rangeChanged, !origin.peersEqualTo(region), false, origin, region) { return } } - if rangeChanged { - for _, re := range overlaps { - r.removeRegionFromSubTreeLocked(re) - } - } - - item := ®ionItem{region} - r.subRegions[region.GetID()] = item - // It has been removed and all information needs to be updated again. - // Set peers then. - setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { - store, ok := peersMap[storeID] - if !ok { - if !countRef { - store = newRegionTree() - } else { - store = newRegionTreeWithCountRef() - } - peersMap[storeID] = store - } - store.update(item, false) - } - - // Add to leaders and followers. - for _, peer := range region.GetVoters() { - storeID := peer.GetStoreId() - if peer.GetId() == region.leader.GetId() { - // Add leader peer to leaders. - setPeer(r.leaders, storeID, item, true) - } else { - // Add follower peer to followers. - setPeer(r.followers, storeID, item, false) - } - } - - setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { - for _, peer := range peers { - storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item, false) - } - } - // Add to learners. - setPeers(r.learners, region.GetLearners()) - // Add to witnesses. - setPeers(r.witnesses, region.GetWitnesses()) - // Add to PendingPeers - setPeers(r.pendingPeers, region.GetPendingPeers()) + r.updateSubTreeLocked(rangeChanged, overlaps, region) } func (r *RegionsInfo) updateSubTreeStat(origin *RegionInfo, region *RegionInfo) { @@ -1395,7 +1324,7 @@ func (r *RegionsInfo) RemoveRegion(region *RegionInfo) { // ResetRegionCache resets the regions info. func (r *RegionsInfo) ResetRegionCache() { r.t.Lock() - r.tree = newRegionTree() + r.tree = newRegionTreeWithCountRef() r.regions = make(map[uint64]*regionItem) r.t.Unlock() r.st.Lock() @@ -1405,6 +1334,7 @@ func (r *RegionsInfo) ResetRegionCache() { r.learners = make(map[uint64]*regionTree) r.witnesses = make(map[uint64]*regionTree) r.pendingPeers = make(map[uint64]*regionTree) + r.overlapTree = newRegionTreeWithCountRef() } // RemoveRegionFromSubTree removes RegionInfo from regionSubTrees @@ -1417,7 +1347,6 @@ func (r *RegionsInfo) RemoveRegionFromSubTree(region *RegionInfo) { // removeRegionFromSubTreeLocked removes RegionInfo from regionSubTrees func (r *RegionsInfo) removeRegionFromSubTreeLocked(region *RegionInfo) { - // Remove from leaders and followers. 
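+	// Remove the region from every per-store subtree and from the overlap tree.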
for _, peer := range region.GetMeta().GetPeers() { storeID := peer.GetStoreId() r.leaders[storeID].remove(region) @@ -1426,6 +1355,7 @@ func (r *RegionsInfo) removeRegionFromSubTreeLocked(region *RegionInfo) { r.witnesses[storeID].remove(region) r.pendingPeers[storeID].remove(region) } + r.overlapTree.remove(region) delete(r.subRegions, region.GetMeta().GetId()) } diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 43629fccda0..1b8f20cf9b2 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -778,27 +778,24 @@ func BenchmarkRandomSetRegionWithGetRegionSizeByRangeParallel(b *testing.B) { ) } -const keyLength = 100 - -func randomBytes(n int) []byte { - bytes := make([]byte, n) - _, err := rand.Read(bytes) - if err != nil { - panic(err) - } - return bytes -} +const ( + peerNum = 3 + storeNum = 10 + keyLength = 100 +) func newRegionInfoIDRandom(idAllocator id.Allocator) *RegionInfo { var ( peers []*metapb.Peer leader *metapb.Peer ) - storeNum := 10 - for i := 0; i < 3; i++ { + // Randomly select a peer as the leader. + leaderIdx := mrand.Intn(peerNum) + for i := 0; i < peerNum; i++ { id, _ := idAllocator.Alloc() - p := &metapb.Peer{Id: id, StoreId: uint64(i%storeNum + 1)} - if i == 0 { + // Randomly distribute the peers to different stores. + p := &metapb.Peer{Id: id, StoreId: uint64(mrand.Intn(storeNum) + 1)} + if i == leaderIdx { leader = p } peers = append(peers, p) @@ -817,13 +814,19 @@ func newRegionInfoIDRandom(idAllocator id.Allocator) *RegionInfo { ) } +func randomBytes(n int) []byte { + bytes := make([]byte, n) + _, err := rand.Read(bytes) + if err != nil { + panic(err) + } + return bytes +} + func BenchmarkAddRegion(b *testing.B) { regions := NewRegionsInfo() idAllocator := mockid.NewIDAllocator() - var items []*RegionInfo - for i := 0; i < 10000000; i++ { - items = append(items, newRegionInfoIDRandom(idAllocator)) - } + items := generateRegionItems(idAllocator, 10000000) b.ResetTimer() for i := 0; i < b.N; i++ { origin, overlaps, rangeChanged := regions.SetRegion(items[i]) @@ -831,6 +834,54 @@ func BenchmarkAddRegion(b *testing.B) { } } +func BenchmarkUpdateSubTreeOrderInsensitive(b *testing.B) { + idAllocator := mockid.NewIDAllocator() + for _, size := range []int{10, 100, 1000, 10000, 100000, 1000000, 10000000} { + regions := NewRegionsInfo() + items := generateRegionItems(idAllocator, size) + // Update the subtrees from an empty `*RegionsInfo`. + b.Run(fmt.Sprintf("from empty with size %d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + + // Update the subtrees from a non-empty `*RegionsInfo` with the same regions, + // which means the regions are completely non-overlapped. + b.Run(fmt.Sprintf("from non-overlapped regions with size %d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + + // Update the subtrees from a non-empty `*RegionsInfo` with different regions, + // which means the regions are most likely overlapped. 
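+		// (The items are regenerated with random start and end keys over the same
+		// key space, so the new ranges almost surely overlap the previous ones.)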
+ b.Run(fmt.Sprintf("from overlapped regions with size %d", size), func(b *testing.B) { + items = generateRegionItems(idAllocator, size) + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + } +} + +func generateRegionItems(idAllocator *mockid.IDAllocator, size int) []*RegionInfo { + items := make([]*RegionInfo, size) + for i := 0; i < size; i++ { + items[i] = newRegionInfoIDRandom(idAllocator) + } + return items +} + func BenchmarkRegionFromHeartbeat(b *testing.B) { peers := make([]*metapb.Peer, 0, 3) for i := uint64(1); i <= 3; i++ { @@ -1021,3 +1072,27 @@ func TestUpdateRegionEventualConsistency(t *testing.T) { re.Equal(int32(2), item.GetRef()) } } + +func TestCheckAndPutSubTree(t *testing.T) { + re := require.New(t) + regions := NewRegionsInfo() + region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + regions.CheckAndPutSubTree(region) + // should failed to put because the root tree is missing + re.Equal(0, regions.tree.length()) +} + +func TestCntRefAfterResetRegionCache(t *testing.T) { + re := require.New(t) + regions := NewRegionsInfo() + // Put the region first. + region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + regions.CheckAndPutRegion(region) + re.Equal(int32(2), region.GetRef()) + regions.ResetRegionCache() + // Put the region after reset. + region = NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + re.Zero(region.GetRef()) + regions.CheckAndPutRegion(region) + re.Equal(int32(2), region.GetRef()) +} diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 2a5302b34dc..d3691516868 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -594,10 +594,10 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c ctx.TaskRunner.RunTask( ctx, + ratelimit.HandleStatsAsync, func(_ context.Context) { cluster.HandleStatsAsync(c, region) }, - ratelimit.WithTaskName(ratelimit.HandleStatsAsync), ) tracer.OnAsyncHotStatsFinished() hasRegionStats := c.regionStats != nil @@ -611,22 +611,22 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { ctx.TaskRunner.RunTask( ctx, + ratelimit.ObserveRegionStatsAsync, func(_ context.Context) { if c.regionStats.RegionStatsNeedUpdate(region) { cluster.Collect(c, region, hasRegionStats) } }, - ratelimit.WithTaskName(ratelimit.ObserveRegionStatsAsync), ) } // region is not updated to the subtree. 
 	if origin.GetRef() < 2 {
 		ctx.TaskRunner.RunTask(
 			ctx,
+			ratelimit.UpdateSubTree,
 			func(_ context.Context) {
 				c.CheckAndPutSubTree(region)
 			},
-			ratelimit.WithTaskName(ratelimit.UpdateSubTree),
 		)
 	}
 	return nil
@@ -646,28 +646,28 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
 	}
 	ctx.TaskRunner.RunTask(
 		ctx,
+		ratelimit.UpdateSubTree,
 		func(_ context.Context) {
 			c.CheckAndPutSubTree(region)
 		},
-		ratelimit.WithTaskName(ratelimit.UpdateSubTree),
 	)
 	tracer.OnUpdateSubTreeFinished()
 	ctx.TaskRunner.RunTask(
 		ctx,
+		ratelimit.HandleOverlaps,
 		func(_ context.Context) {
 			cluster.HandleOverlaps(c, overlaps)
 		},
-		ratelimit.WithTaskName(ratelimit.HandleOverlaps),
 	)
 	}
 	tracer.OnSaveCacheFinished()
 	// handle region stats
 	ctx.TaskRunner.RunTask(
 		ctx,
+		ratelimit.CollectRegionStatsAsync,
 		func(_ context.Context) {
 			cluster.Collect(c, region, hasRegionStats)
 		},
-		ratelimit.WithTaskName(ratelimit.CollectRegionStatsAsync),
 	)
 	tracer.OnCollectRegionStatsFinished()
 	return nil
@@ -692,3 +692,10 @@ func (c *Cluster) DropCacheAllRegion() {
 func (c *Cluster) DropCacheRegion(id uint64) {
 	c.RemoveRegionIfExist(id)
 }
+
+// IsSchedulingHalted returns whether the scheduling is halted.
+// Currently, the microservice scheduling is halted when:
+// - The `HaltScheduling` persist option is set to true.
+func (c *Cluster) IsSchedulingHalted() bool {
+	return c.persistConfig.IsSchedulingHalted()
+}
diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go
index 07bb12049c0..6465c58ff87 100644
--- a/pkg/mcs/scheduling/server/config/config.go
+++ b/pkg/mcs/scheduling/server/config/config.go
@@ -683,6 +683,10 @@ func (o *PersistConfig) SetSplitMergeInterval(splitMergeInterval time.Duration)
 	o.SetScheduleConfig(v)
 }

+// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt.
+// TODO: support this metric for the scheduling service in the future.
+func (*PersistConfig) SetSchedulingAllowanceStatus(bool, string) {}
+
 // SetHaltScheduling set HaltScheduling.
 func (o *PersistConfig) SetHaltScheduling(halt bool, _ string) {
 	v := o.GetScheduleConfig().Clone()
diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go
index 62ec1c1118f..605ec73dad5 100644
--- a/pkg/mcs/scheduling/server/grpc_service.go
+++ b/pkg/mcs/scheduling/server/grpc_service.go
@@ -275,7 +275,7 @@ func (s *Service) AskBatchSplit(_ context.Context, request *schedulingpb.AskBatc
 		}, nil
 	}

-	if c.persistConfig.IsSchedulingHalted() {
+	if c.IsSchedulingHalted() {
 		return nil, errs.ErrSchedulingIsHalted.FastGenByArgs()
 	}
 	if !c.persistConfig.IsTikvRegionSplitEnabled() {
diff --git a/pkg/ratelimit/concurrency_limiter.go b/pkg/ratelimit/concurrency_limiter.go
index af768461478..e5379bc48cc 100644
--- a/pkg/ratelimit/concurrency_limiter.go
+++ b/pkg/ratelimit/concurrency_limiter.go
@@ -106,8 +106,8 @@ func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 {
 	return l.waiting
 }

-// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout.
-func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) {
+// AcquireToken acquires a token from the limiter, blocking until a token is available or ctx is done, e.g. on timeout.
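+//
+// A sketch of the intended pairing (illustrative only):
+//
+//	token, err := limiter.AcquireToken(ctx)
+//	if err != nil {
+//		return err // ctx was canceled or timed out while waiting
+//	}
+//	defer limiter.ReleaseToken(token)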
+func (l *ConcurrencyLimiter) AcquireToken(ctx context.Context) (*TaskToken, error) { l.mu.Lock() if l.current >= l.limit { l.waiting++ @@ -129,27 +129,26 @@ func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { } } l.current++ - token := &TaskToken{limiter: l} + token := &TaskToken{} l.mu.Unlock() return token, nil } -// TaskToken is a token that must be released after the task is done. -type TaskToken struct { - released bool - limiter *ConcurrencyLimiter -} - -// Release releases the token. -func (tt *TaskToken) Release() { - tt.limiter.mu.Lock() - defer tt.limiter.mu.Unlock() - if tt.released { +// ReleaseToken releases the token. +func (l *ConcurrencyLimiter) ReleaseToken(token *TaskToken) { + l.mu.Lock() + defer l.mu.Unlock() + if token.released { return } - tt.released = true - tt.limiter.current-- - if len(tt.limiter.queue) < int(tt.limiter.limit) { - tt.limiter.queue <- tt + token.released = true + l.current-- + if len(l.queue) < int(l.limit) { + l.queue <- token } } + +// TaskToken is a token that must be released after the task is done. +type TaskToken struct { + released bool +} diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index a397b6ac50f..f0af1125d21 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -68,17 +68,17 @@ func TestConcurrencyLimiter2(t *testing.T) { defer cancel() // Acquire two tokens - token1, err := limiter.Acquire(ctx) + token1, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") - token2, err := limiter.Acquire(ctx) + token2, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2") // Try to acquire third token, it should not be able to acquire immediately due to limit go func() { - _, err := limiter.Acquire(ctx) + _, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") }() @@ -86,13 +86,13 @@ func TestConcurrencyLimiter2(t *testing.T) { require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1") // Release a token - token1.Release() + limiter.ReleaseToken(token1) time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2") require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") // Release the second token - token2.Release() + limiter.ReleaseToken(token2) time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1") } @@ -109,12 +109,12 @@ func TestConcurrencyLimiterAcquire(t *testing.T) { for i := 0; i < 100; i++ { go func(i int) { defer wg.Done() - token, err := limiter.Acquire(ctx) + token, err := limiter.AcquireToken(ctx) if err != nil { fmt.Printf("Task %d failed to acquire: %v\n", i, err) return } - defer token.Release() + defer limiter.ReleaseToken(token) // simulate takes some time time.Sleep(10 * time.Millisecond) atomic.AddInt64(&sum, 1) @@ -122,6 +122,6 @@ func TestConcurrencyLimiterAcquire(t *testing.T) { } wg.Wait() // We should have 20 tasks running concurrently, so it should take at least 50ms to complete - require.Greater(t, time.Since(start).Milliseconds(), int64(50)) + require.GreaterOrEqual(t, time.Since(start).Milliseconds(), int64(50)) 
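+	// (100 tasks through a limit of 20 finish in five waves of ~10ms each, so a
+	// total of exactly 50ms is possible; Greater would make the assertion flaky.)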
require.Equal(t, int64(100), sum) } diff --git a/pkg/ratelimit/metrics.go b/pkg/ratelimit/metrics.go index 3c5020554a8..5d4443a1cc4 100644 --- a/pkg/ratelimit/metrics.go +++ b/pkg/ratelimit/metrics.go @@ -18,7 +18,10 @@ import ( "github.com/prometheus/client_golang/prometheus" ) -const nameStr = "runner_name" +const ( + nameStr = "runner_name" + taskStr = "task_type" +) var ( RunnerTaskMaxWaitingDuration = prometheus.NewGaugeVec( @@ -35,7 +38,7 @@ var ( Subsystem: "ratelimit", Name: "runner_task_pending_tasks", Help: "The number of pending tasks in the runner.", - }, []string{nameStr}) + }, []string{nameStr, taskStr}) RunnerTaskFailedTasks = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: "pd", diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 44ee54971f5..07233af238b 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -39,17 +39,18 @@ const initialCapacity = 100 // Runner is the interface for running tasks. type Runner interface { - RunTask(ctx context.Context, f func(context.Context), opts ...TaskOption) error + RunTask(ctx context.Context, name string, f func(context.Context), opts ...TaskOption) error Start() Stop() } // Task is a task to be run. type Task struct { - Ctx context.Context - Opts *TaskOpts - f func(context.Context) + ctx context.Context submittedAt time.Time + opts *TaskOpts + f func(context.Context) + name string } // ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum. @@ -65,6 +66,7 @@ type ConcurrentRunner struct { pendingMu sync.Mutex stopChan chan struct{} wg sync.WaitGroup + pendingTaskCount map[string]int64 failedTaskCount prometheus.Counter maxWaitingDuration prometheus.Gauge } @@ -78,25 +80,18 @@ func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDur taskChan: make(chan *Task), pendingTasks: make([]*Task, 0, initialCapacity), failedTaskCount: RunnerTaskFailedTasks.WithLabelValues(name), + pendingTaskCount: make(map[string]int64), maxWaitingDuration: RunnerTaskMaxWaitingDuration.WithLabelValues(name), } return s } // TaskOpts is the options for RunTask. -type TaskOpts struct { - // TaskName is a human-readable name for the operation. TODO: metrics by name. - TaskName string -} +type TaskOpts struct{} // TaskOption configures TaskOp type TaskOption func(opts *TaskOpts) -// WithTaskName specify the task name. -func WithTaskName(name string) TaskOption { - return func(opts *TaskOpts) { opts.TaskName = name } -} - // Start starts the runner. 
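+//
+// Besides dispatching tasks, the background loop also exports the per-task-type
+// pending counts to the runner_task_pending_tasks gauge, which is now labeled
+// by task_type in addition to runner_name.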
func (cr *ConcurrentRunner) Start() { cr.stopChan = make(chan struct{}) @@ -108,13 +103,13 @@ func (cr *ConcurrentRunner) Start() { select { case task := <-cr.taskChan: if cr.limiter != nil { - token, err := cr.limiter.Acquire(context.Background()) + token, err := cr.limiter.AcquireToken(context.Background()) if err != nil { continue } - go cr.run(task.Ctx, task.f, token) + go cr.run(task, token) } else { - go cr.run(task.Ctx, task.f, nil) + go cr.run(task, nil) } case <-cr.stopChan: cr.pendingMu.Lock() @@ -128,6 +123,9 @@ func (cr *ConcurrentRunner) Start() { if len(cr.pendingTasks) > 0 { maxDuration = time.Since(cr.pendingTasks[0].submittedAt) } + for name, cnt := range cr.pendingTaskCount { + RunnerTaskPendingTasks.WithLabelValues(cr.name, name).Set(float64(cnt)) + } cr.pendingMu.Unlock() cr.maxWaitingDuration.Set(maxDuration.Seconds()) } @@ -135,10 +133,10 @@ func (cr *ConcurrentRunner) Start() { }() } -func (cr *ConcurrentRunner) run(ctx context.Context, task func(context.Context), token *TaskToken) { - task(ctx) +func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) { + task.f(task.ctx) if token != nil { - token.Release() + cr.limiter.ReleaseToken(token) cr.processPendingTasks() } } @@ -151,6 +149,7 @@ func (cr *ConcurrentRunner) processPendingTasks() { select { case cr.taskChan <- task: cr.pendingTasks = cr.pendingTasks[1:] + cr.pendingTaskCount[task.name]-- return default: return @@ -165,15 +164,16 @@ func (cr *ConcurrentRunner) Stop() { } // RunTask runs the task asynchronously. -func (cr *ConcurrentRunner) RunTask(ctx context.Context, f func(context.Context), opts ...TaskOption) error { +func (cr *ConcurrentRunner) RunTask(ctx context.Context, name string, f func(context.Context), opts ...TaskOption) error { taskOpts := &TaskOpts{} for _, opt := range opts { opt(taskOpts) } task := &Task{ - Ctx: ctx, + ctx: ctx, + name: name, f: f, - Opts: taskOpts, + opts: taskOpts, } cr.processPendingTasks() @@ -191,6 +191,7 @@ func (cr *ConcurrentRunner) RunTask(ctx context.Context, f func(context.Context) } task.submittedAt = time.Now() cr.pendingTasks = append(cr.pendingTasks, task) + cr.pendingTaskCount[task.name]++ } return nil } @@ -204,7 +205,7 @@ func NewSyncRunner() *SyncRunner { } // RunTask runs the task synchronously. 
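+//
+// The task name is now a required positional argument on both runners, for
+// example (illustrative):
+//
+//	_ = NewSyncRunner().RunTask(ctx, "InfoLog",
+//		func(ctx context.Context) { /* ... */ })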
-func (*SyncRunner) RunTask(ctx context.Context, f func(context.Context), _ ...TaskOption) error { +func (*SyncRunner) RunTask(ctx context.Context, _ string, f func(context.Context), _ ...TaskOption) error { f(ctx) return nil } diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index ccbf6ed59ed..0241536686b 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -35,11 +35,11 @@ func TestConcurrentRunner(t *testing.T) { wg.Add(1) err := runner.RunTask( context.Background(), + "test1", func(context.Context) { defer wg.Done() time.Sleep(100 * time.Millisecond) }, - WithTaskName("test1"), ) require.NoError(t, err) } @@ -55,11 +55,11 @@ func TestConcurrentRunner(t *testing.T) { wg.Add(1) err := runner.RunTask( context.Background(), + "test2", func(context.Context) { defer wg.Done() time.Sleep(100 * time.Millisecond) }, - WithTaskName("test2"), ) if err != nil { wg.Done() diff --git a/pkg/schedule/config/config_provider.go b/pkg/schedule/config/config_provider.go index 20c7f0dc2cf..90e489f86f3 100644 --- a/pkg/schedule/config/config_provider.go +++ b/pkg/schedule/config/config_provider.go @@ -46,7 +46,7 @@ func IsSchedulerRegistered(name string) bool { type SchedulerConfigProvider interface { SharedConfigProvider - IsSchedulingHalted() bool + SetSchedulingAllowanceStatus(bool, string) GetStoresLimit() map[uint64]StoreLimitConfig IsSchedulerDisabled(string) bool diff --git a/pkg/schedule/coordinator.go b/pkg/schedule/coordinator.go index 35d9c2029a1..fb22303f0b7 100644 --- a/pkg/schedule/coordinator.go +++ b/pkg/schedule/coordinator.go @@ -52,7 +52,8 @@ const ( // pushOperatorTickInterval is the interval try to push the operator. pushOperatorTickInterval = 500 * time.Millisecond - patrolScanRegionLimit = 128 // It takes about 14 minutes to iterate 1 million regions. + // It takes about 1.3 minutes(1000000/128*10/60/1000) to iterate 1 million regions(with DefaultPatrolRegionInterval=10ms). 
+ patrolScanRegionLimit = 128 // PluginLoad means action for load plugin PluginLoad = "PluginLoad" // PluginUnload means action for unload plugin @@ -178,7 +179,7 @@ func (c *Coordinator) PatrolRegions() { log.Info("patrol regions has been stopped") return } - if c.isSchedulingHalted() { + if c.cluster.IsSchedulingHalted() { continue } @@ -207,10 +208,6 @@ func (c *Coordinator) PatrolRegions() { } } -func (c *Coordinator) isSchedulingHalted() bool { - return c.cluster.GetSchedulerConfig().IsSchedulingHalted() -} - func (c *Coordinator) checkRegions(startKey []byte) (key []byte, regions []*core.RegionInfo) { regions = c.cluster.ScanRegions(startKey, nil, patrolScanRegionLimit) if len(regions) == 0 { diff --git a/pkg/schedule/core/cluster_informer.go b/pkg/schedule/core/cluster_informer.go index 63dacd0c30d..b97459d26ea 100644 --- a/pkg/schedule/core/cluster_informer.go +++ b/pkg/schedule/core/cluster_informer.go @@ -43,6 +43,7 @@ type SchedulerCluster interface { GetSchedulerConfig() sc.SchedulerConfigProvider GetRegionLabeler() *labeler.RegionLabeler GetStoreConfig() sc.StoreConfigProvider + IsSchedulingHalted() bool } // CheckerCluster is an aggregate interface that wraps multiple interfaces diff --git a/pkg/schedule/schedulers/scheduler_controller.go b/pkg/schedule/schedulers/scheduler_controller.go index 5953ecac5e3..334a2f1199a 100644 --- a/pkg/schedule/schedulers/scheduler_controller.go +++ b/pkg/schedule/schedulers/scheduler_controller.go @@ -115,7 +115,7 @@ func (c *Controller) CollectSchedulerMetrics() { var allowScheduler float64 // If the scheduler is not allowed to schedule, it will disappear in Grafana panel. // See issue #1341. - if !s.IsPaused() && !c.isSchedulingHalted() { + if !s.IsPaused() && !c.cluster.IsSchedulingHalted() { allowScheduler = 1 } schedulerStatusGauge.WithLabelValues(s.Scheduler.GetName(), "allow").Set(allowScheduler) @@ -131,10 +131,6 @@ func (c *Controller) CollectSchedulerMetrics() { ruleStatusGauge.WithLabelValues("group_count").Set(float64(groupCnt)) } -func (c *Controller) isSchedulingHalted() bool { - return c.cluster.GetSchedulerConfig().IsSchedulingHalted() -} - // ResetSchedulerMetrics resets metrics of all schedulers. func ResetSchedulerMetrics() { schedulerStatusGauge.Reset() @@ -526,7 +522,7 @@ func (s *ScheduleController) AllowSchedule(diagnosable bool) bool { } return false } - if s.isSchedulingHalted() { + if s.cluster.IsSchedulingHalted() { if diagnosable { s.diagnosticRecorder.SetResultFromStatus(Halted) } @@ -541,10 +537,6 @@ func (s *ScheduleController) AllowSchedule(diagnosable bool) bool { return true } -func (s *ScheduleController) isSchedulingHalted() bool { - return s.cluster.GetSchedulerConfig().IsSchedulingHalted() -} - // IsPaused returns if a scheduler is paused. func (s *ScheduleController) IsPaused() bool { delayUntil := atomic.LoadInt64(&s.delayUntil) diff --git a/pkg/statistics/region_collection.go b/pkg/statistics/region_collection.go index 565597b4efb..e4c159cf22d 100644 --- a/pkg/statistics/region_collection.go +++ b/pkg/statistics/region_collection.go @@ -32,7 +32,7 @@ type RegionInfoProvider interface { } // RegionStatisticType represents the type of the region's status. -type RegionStatisticType uint32 +type RegionStatisticType uint16 const emptyStatistic = RegionStatisticType(0) @@ -81,7 +81,6 @@ var ( // RegionInfoWithTS is used to record the extra timestamp status of a region. 
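+// The region ID is no longer stored here, since the enclosing stats maps are
+// already keyed by region ID.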
type RegionInfoWithTS struct { - id uint64 startMissVoterPeerTS int64 startDownPeerTS int64 } @@ -91,7 +90,7 @@ type RegionStatistics struct { syncutil.RWMutex rip RegionInfoProvider conf sc.CheckerConfigProvider - stats map[RegionStatisticType]map[uint64]*RegionInfoWithTS + stats map[RegionStatisticType]map[uint64]any index map[uint64]RegionStatisticType ruleManager *placement.RuleManager } @@ -106,11 +105,11 @@ func NewRegionStatistics( rip: rip, conf: conf, ruleManager: ruleManager, - stats: make(map[RegionStatisticType]map[uint64]*RegionInfoWithTS), + stats: make(map[RegionStatisticType]map[uint64]any), index: make(map[uint64]RegionStatisticType), } for _, typ := range regionStatisticTypes { - r.stats[typ] = make(map[uint64]*RegionInfoWithTS) + r.stats[typ] = make(map[uint64]any) } return r } @@ -207,14 +206,27 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store } } } + + peers := region.GetPeers() + downPeers := region.GetDownPeers() + pendingPeers := region.GetPendingPeers() + learners := region.GetLearners() + voters := region.GetVoters() + regionSize := region.GetApproximateSize() + regionMaxSize := int64(r.conf.GetRegionMaxSize()) + regionMaxKeys := int64(r.conf.GetRegionMaxKeys()) + maxMergeRegionSize := int64(r.conf.GetMaxMergeRegionSize()) + maxMergeRegionKeys := int64(r.conf.GetMaxMergeRegionKeys()) + leaderIsWitness := region.GetLeader().GetIsWitness() + // Better to make sure once any of these conditions changes, it will trigger the heartbeat `save_cache`. // Otherwise, the state may be out-of-date for a long time, which needs another way to apply the change ASAP. // For example, see `RegionStatsNeedUpdate` above to know how `OversizedRegion` and `UndersizedRegion` are updated. conditions := map[RegionStatisticType]bool{ - MissPeer: len(region.GetPeers()) < desiredReplicas, - ExtraPeer: len(region.GetPeers()) > desiredReplicas, - DownPeer: len(region.GetDownPeers()) > 0, - PendingPeer: len(region.GetPendingPeers()) > 0, + MissPeer: len(peers) < desiredReplicas, + ExtraPeer: len(peers) > desiredReplicas, + DownPeer: len(downPeers) > 0, + PendingPeer: len(pendingPeers) > 0, OfflinePeer: func() bool { for _, store := range stores { if store.IsRemoving() { @@ -226,39 +238,40 @@ func (r *RegionStatistics) Observe(region *core.RegionInfo, stores []*core.Store } return false }(), - LearnerPeer: len(region.GetLearners()) > 0, - EmptyRegion: region.GetApproximateSize() <= core.EmptyRegionApproximateSize, - OversizedRegion: region.IsOversized( - int64(r.conf.GetRegionMaxSize()), - int64(r.conf.GetRegionMaxKeys()), - ), - UndersizedRegion: region.NeedMerge( - int64(r.conf.GetMaxMergeRegionSize()), - int64(r.conf.GetMaxMergeRegionKeys()), - ), - WitnessLeader: region.GetLeader().GetIsWitness(), + LearnerPeer: len(learners) > 0, + EmptyRegion: regionSize <= core.EmptyRegionApproximateSize, + OversizedRegion: region.IsOversized(regionMaxSize, regionMaxKeys), + UndersizedRegion: region.NeedMerge(maxMergeRegionSize, maxMergeRegionKeys), + WitnessLeader: leaderIsWitness, } // Check if the region meets any of the conditions and update the corresponding info. 
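+	// Only DownPeer and MissPeer need a *RegionInfoWithTS value to remember when
+	// the state was first observed; every other type only tracks membership, so
+	// a zero-size struct{}{} is stored instead.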
regionID := region.GetID() for typ, c := range conditions { if c { info := r.stats[typ][regionID] - if info == nil { - info = &RegionInfoWithTS{id: regionID} - } if typ == DownPeer { - if info.startDownPeerTS != 0 { - regionDownPeerDuration.Observe(float64(time.Now().Unix() - info.startDownPeerTS)) + if info == nil { + info = &RegionInfoWithTS{} + } + if info.(*RegionInfoWithTS).startDownPeerTS != 0 { + regionDownPeerDuration.Observe(float64(time.Now().Unix() - info.(*RegionInfoWithTS).startDownPeerTS)) } else { - info.startDownPeerTS = time.Now().Unix() + info.(*RegionInfoWithTS).startDownPeerTS = time.Now().Unix() logDownPeerWithNoDisconnectedStore(region, stores) } - } else if typ == MissPeer && len(region.GetVoters()) < desiredVoters { - if info.startMissVoterPeerTS != 0 { - regionMissVoterPeerDuration.Observe(float64(time.Now().Unix() - info.startMissVoterPeerTS)) - } else { - info.startMissVoterPeerTS = time.Now().Unix() + } else if typ == MissPeer { + if info == nil { + info = &RegionInfoWithTS{} + } + if len(voters) < desiredVoters { + if info.(*RegionInfoWithTS).startMissVoterPeerTS != 0 { + regionMissVoterPeerDuration.Observe(float64(time.Now().Unix() - info.(*RegionInfoWithTS).startMissVoterPeerTS)) + } else { + info.(*RegionInfoWithTS).startMissVoterPeerTS = time.Now().Unix() + } } + } else { + info = struct{}{} } r.stats[typ][regionID] = info diff --git a/pkg/statistics/region_collection_test.go b/pkg/statistics/region_collection_test.go index cbbf7672bee..64a625a04e2 100644 --- a/pkg/statistics/region_collection_test.go +++ b/pkg/statistics/region_collection_test.go @@ -269,3 +269,43 @@ func TestRegionLabelIsolationLevel(t *testing.T) { re.Equal(res, labelLevelStats.labelCounter[i]) } } + +func BenchmarkObserve(b *testing.B) { + // Setup + store := storage.NewStorageWithMemoryBackend() + manager := placement.NewRuleManager(context.Background(), store, nil, nil) + manager.Initialize(3, []string{"zone", "rack", "host"}, "") + opt := mockconfig.NewTestOptions() + opt.SetPlacementRuleEnabled(false) + peers := []*metapb.Peer{ + {Id: 4, StoreId: 1}, + {Id: 5, StoreId: 2}, + {Id: 6, StoreId: 3}, + } + + metaStores := []*metapb.Store{ + {Id: 1, Address: "mock://tikv-1"}, + {Id: 2, Address: "mock://tikv-2"}, + {Id: 3, Address: "mock://tikv-3"}, + } + + stores := make([]*core.StoreInfo, 0, len(metaStores)) + for _, m := range metaStores { + s := core.NewStoreInfo(m) + stores = append(stores, s) + } + + regionNum := uint64(1000000) + regions := make([]*core.RegionInfo, 0, regionNum) + for i := uint64(1); i <= regionNum; i++ { + r := &metapb.Region{Id: i, Peers: peers, StartKey: []byte{byte(i)}, EndKey: []byte{byte(i + 1)}} + regions = append(regions, core.NewRegionInfo(r, peers[0])) + } + regionStats := NewRegionStatistics(nil, opt, manager) + + b.ResetTimer() + // Run the Observe function b.N times + for i := 0; i < b.N; i++ { + regionStats.Observe(regions[i%int(regionNum)], stores) + } +} diff --git a/pkg/storage/endpoint/key_path.go b/pkg/storage/endpoint/key_path.go index 69b8d0f2f8e..dbcd9690419 100644 --- a/pkg/storage/endpoint/key_path.go +++ b/pkg/storage/endpoint/key_path.go @@ -149,24 +149,23 @@ func storeRegionWeightPath(storeID uint64) string { // RegionPath returns the region meta info key path with the given region ID. 
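+//
+// The ID is rendered as a fixed-width decimal, zero-padded to keyLen digits, so
+// that the resulting keys sort in region-ID order, e.g. (illustrative):
+//
+//	RegionPath(123) // regionPathPrefix + "/0...0123"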
func RegionPath(regionID uint64) string { var buf strings.Builder + buf.Grow(len(regionPathPrefix) + 1 + keyLen) // Preallocate memory + buf.WriteString(regionPathPrefix) buf.WriteString("/") s := strconv.FormatUint(regionID, 10) - if len(s) > keyLen { - s = s[len(s)-keyLen:] - } else { - b := make([]byte, keyLen) + b := make([]byte, keyLen) + copy(b, s) + if len(s) < keyLen { diff := keyLen - len(s) - for i := 0; i < keyLen; i++ { - if i < diff { - b[i] = 48 - } else { - b[i] = s[i-diff] - } + copy(b[diff:], s) + for i := 0; i < diff; i++ { + b[i] = '0' } - s = string(b) + } else if len(s) > keyLen { + copy(b, s[len(s)-keyLen:]) } - buf.WriteString(s) + buf.Write(b) return buf.String() } diff --git a/pkg/unsaferecovery/unsafe_recovery_controller.go b/pkg/unsaferecovery/unsafe_recovery_controller.go index 044dbd182e2..d2f6125c3f3 100644 --- a/pkg/unsaferecovery/unsafe_recovery_controller.go +++ b/pkg/unsaferecovery/unsafe_recovery_controller.go @@ -493,12 +493,11 @@ func (u *Controller) GetStage() stage { } func (u *Controller) changeStage(stage stage) { - u.stage = stage - // Halt and resume the scheduling once the running state changed. - running := isRunning(stage) - if opt := u.cluster.GetSchedulerConfig(); opt.IsSchedulingHalted() != running { - opt.SetHaltScheduling(running, "online-unsafe-recovery") + // If the running stage changes, update the scheduling allowance status to add or remove "online-unsafe-recovery" halt. + if running := isRunning(stage); running != isRunning(u.stage) { + u.cluster.GetSchedulerConfig().SetSchedulingAllowanceStatus(running, "online-unsafe-recovery") } + u.stage = stage var output StageOutput output.Time = time.Now().Format("2006-01-02 15:04:05.000") diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 9b8cc2feb49..5633533ae4a 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -186,13 +186,9 @@ func ResetForwardContext(ctx context.Context) context.Context { // GetForwardedHost returns the forwarded host in metadata. 
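+//
+// metadata.ValueFromIncomingContext looks up a single key without materializing
+// a copy of the whole metadata map the way metadata.FromIncomingContext does,
+// which is what BenchmarkGetForwardedHost exercises.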
func GetForwardedHost(ctx context.Context) string { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - log.Debug("failed to get gRPC incoming metadata when getting forwarded host") - return "" - } - if t, ok := md[ForwardMetadataKey]; ok { - return t[0] + s := metadata.ValueFromIncomingContext(ctx, ForwardMetadataKey) + if len(s) > 0 { + return s[0] } return "" } diff --git a/pkg/utils/grpcutil/grpcutil_test.go b/pkg/utils/grpcutil/grpcutil_test.go index 2cbff4f3ebc..99cbeae6cde 100644 --- a/pkg/utils/grpcutil/grpcutil_test.go +++ b/pkg/utils/grpcutil/grpcutil_test.go @@ -1,6 +1,7 @@ package grpcutil import ( + "context" "os" "os/exec" "path" @@ -9,6 +10,7 @@ import ( "github.com/pingcap/errors" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/errs" + "google.golang.org/grpc/metadata" ) var ( @@ -66,3 +68,14 @@ func TestToTLSConfig(t *testing.T) { _, err = tlsConfig.ToTLSConfig() re.True(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } + +func BenchmarkGetForwardedHost(b *testing.B) { + // Without forwarded host key + md := metadata.Pairs("test", "example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + + // Run the GetForwardedHost function b.N times + for i := 0; i < b.N; i++ { + GetForwardedHost(ctx) + } +} diff --git a/scripts/ci-subtask.sh b/scripts/ci-subtask.sh index c00cba9c0a4..9bdce420d4a 100755 --- a/scripts/ci-subtask.sh +++ b/scripts/ci-subtask.sh @@ -3,63 +3,34 @@ # ./ci-subtask.sh ROOT_PATH_COV=$(pwd)/covprofile - -if [[ $2 -gt 9 ]]; then - # run tools tests - if [[ $2 -eq 10 ]]; then +# Currently, we only have 3 integration tests, so we can hardcode the task index. +integrations_dir=$(pwd)/tests/integrations + +case $1 in + 1) + # unit tests ignore `tests` + ./bin/pd-ut run --race --ignore tests --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 2) + # unit tests only in `tests` + ./bin/pd-ut run tests --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 3) + # tools tests cd ./tools && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 - exit - fi - - # Currently, we only have 3 integration tests, so we can hardcode the task index. - integrations_dir=$(pwd)/tests/integrations - integrations_tasks=($(find "$integrations_dir" -mindepth 1 -maxdepth 1 -type d)) - for t in "${integrations_tasks[@]}"; do - if [[ "$t" = "$integrations_dir/client" && $2 -eq 11 ]]; then - cd ./client && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 - cd $integrations_dir && make ci-test-job test_name=client && cat ./client/covprofile >> $ROOT_PATH_COV || exit 1 - elif [[ "$t" = "$integrations_dir/tso" && $2 -eq 12 ]]; then - cd $integrations_dir && make ci-test-job test_name=tso && cat ./tso/covprofile >> $ROOT_PATH_COV || exit 1 - elif [[ "$t" = "$integrations_dir/mcs" && $2 -eq 13 ]]; then - cd $integrations_dir && make ci-test-job test_name=mcs && cat ./mcs/covprofile >> $ROOT_PATH_COV || exit 1 - fi - done -else - # Get package test list. - packages=($(go list ./...)) - dirs=($(find . 
-iname "*_test.go" -exec dirname {} \; | sort -u | sed -e "s/^\./github.com\/tikv\/pd/")) - tasks=($(comm -12 <(printf "%s\n" "${packages[@]}") <(printf "%s\n" "${dirs[@]}"))) - - weight() { - [[ $1 == "github.com/tikv/pd/server/api" ]] && return 30 - [[ $1 == "github.com/tikv/pd/pkg/schedule" ]] && return 30 - [[ $1 == "github.com/tikv/pd/pkg/core" ]] && return 30 - [[ $1 == "github.com/tikv/pd/tests/server/api" ]] && return 30 - [[ $1 =~ "pd/tests" ]] && return 5 - return 1 - } - - # Create an associative array to store the weight of each task. - declare -A task_weights - for t in ${tasks[@]}; do - weight $t - task_weights[$t]=$? - done - - # Sort tasks by weight in descending order. - tasks=($(printf "%s\n" "${tasks[@]}" | sort -rn)) - - scores=($(seq "$1" | xargs -I{} echo 0)) - - res=() - for t in ${tasks[@]}; do - min_i=0 - for i in ${!scores[@]}; do - [[ ${scores[i]} -lt ${scores[$min_i]} ]] && min_i=$i - done - scores[$min_i]=$((${scores[$min_i]} + ${task_weights[$t]})) - [[ $(($min_i + 1)) -eq $2 ]] && res+=($t) - done - - CGO_ENABLED=1 go test -timeout=15m -tags deadlock -race -cover -covermode=atomic -coverprofile=$ROOT_PATH_COV -coverpkg=./... ${res[@]} -fi + ;; + 4) + # integration test client + ./bin/pd-ut it run client --race --coverprofile $ROOT_PATH_COV || exit 1 + # client tests + cd ./client && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 + ;; + 5) + # integration test tso + ./bin/pd-ut it run tso --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 6) + # integration test mcs + ./bin/pd-ut it run mcs --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; +esac diff --git a/server/api/diagnostic_test.go b/server/api/diagnostic_test.go index c98717902c5..8c4089a8710 100644 --- a/server/api/diagnostic_test.go +++ b/server/api/diagnostic_test.go @@ -36,7 +36,7 @@ type diagnosticTestSuite struct { cleanup tu.CleanupFunc urlPrefix string configPrefix string - schedulerPrifex string + schedulerPrefix string } func TestDiagnosticTestSuite(t *testing.T) { @@ -50,7 +50,7 @@ func (suite *diagnosticTestSuite) SetupSuite() { addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/schedulers/diagnostic", addr, apiPrefix) - suite.schedulerPrifex = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) + suite.schedulerPrefix = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) suite.configPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) mustBootstrapCluster(re, suite.svr) @@ -108,7 +108,7 @@ func (suite *diagnosticTestSuite) TestSchedulerDiagnosticAPI() { input["name"] = schedulers.BalanceRegionName body, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix, body, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("pending", balanceRegionURL) @@ -116,21 +116,23 @@ func (suite *diagnosticTestSuite) TestSchedulerDiagnosticAPI() { input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("paused", balanceRegionURL) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) + 
err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("pending", balanceRegionURL) + fmt.Println("before put region") mustPutRegion(re, suite.svr, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + fmt.Println("after put region") suite.checkStatus("normal", balanceRegionURL) - deleteURL := fmt.Sprintf("%s/%s", suite.schedulerPrifex, schedulers.BalanceRegionName) + deleteURL := fmt.Sprintf("%s/%s", suite.schedulerPrefix, schedulers.BalanceRegionName) err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("disabled", balanceRegionURL) diff --git a/server/apiv2/handlers/tso_keyspace_group.go b/server/apiv2/handlers/tso_keyspace_group.go index a9f042687f6..835bda9d7bb 100644 --- a/server/apiv2/handlers/tso_keyspace_group.go +++ b/server/apiv2/handlers/tso_keyspace_group.go @@ -453,11 +453,6 @@ func SetNodesForKeyspaceGroup(c *gin.Context) { c.AbortWithStatusJSON(http.StatusBadRequest, "keyspace group does not exist") return } - // check if nodes is less than default replica count - if len(setParams.Nodes) < utils.DefaultKeyspaceGroupReplicaCount { - c.AbortWithStatusJSON(http.StatusBadRequest, "invalid num of nodes") - return - } // check if node exists for _, node := range setParams.Nodes { if !manager.IsExistNode(node) { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 8889fdf87b6..a8558051dfa 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -107,8 +107,8 @@ const ( minSnapshotDurationSec = 5 // heartbeat relative const - heartbeatTaskRunner = "heartbeat-async-task-runner" - logTaskRunner = "log-async-task-runner" + heartbeatTaskRunner = "heartbeat-async" + logTaskRunner = "log-async" ) // Server is the interface for cluster. @@ -843,6 +843,14 @@ func (c *RaftCluster) SetPDServerConfig(cfg *config.PDServerConfig) { c.opt.SetPDServerConfig(cfg) } +// IsSchedulingHalted returns whether the scheduling is halted. +// Currently, the PD scheduling is halted when: +// - The `HaltScheduling` persist option is set to true. +// - Online unsafe recovery is running. +func (c *RaftCluster) IsSchedulingHalted() bool { + return c.opt.IsSchedulingHalted() || c.unsafeRecoveryController.IsRunning() +} + // GetUnsafeRecoveryController returns the unsafe recovery controller. func (c *RaftCluster) GetUnsafeRecoveryController() *unsaferecovery.Controller { return c.unsafeRecoveryController @@ -1018,10 +1026,10 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { ctx.TaskRunner.RunTask( ctx.Context, + ratelimit.HandleStatsAsync, func(_ context.Context) { cluster.HandleStatsAsync(c, region) }, - ratelimit.WithTaskName(ratelimit.HandleStatsAsync), ) } tracer.OnAsyncHotStatsFinished() @@ -1039,22 +1047,22 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { ctx.TaskRunner.RunTask( ctx.Context, + ratelimit.ObserveRegionStatsAsync, func(_ context.Context) { if c.regionStats.RegionStatsNeedUpdate(region) { cluster.Collect(c, region, hasRegionStats) } }, - ratelimit.WithTaskName(ratelimit.ObserveRegionStatsAsync), ) } // region is not updated to the subtree. 
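// Editor's note: the cluster.go hunks above and below all make the same mechanical
// change, so a condensed sketch may help. The task name moves from an optional
// ratelimit.WithTaskName(...) functional option to a required positional parameter
// of RunTask. The signature below is an approximation, not the real pkg/ratelimit API.
package main

import (
	"context"
	"fmt"
)

type runner struct{}

// Before (approximate): RunTask(ctx, f, ...Option), where the name was easy to omit.
// After (approximate): every submission must carry a name, so per-task concurrency
// limits and metrics are always keyed consistently.
func (runner) RunTask(ctx context.Context, name string, f func(context.Context)) {
	fmt.Println("submitting task:", name)
	f(ctx)
}

func main() {
	var r runner
	r.RunTask(context.Background(), "UpdateSubTree", func(context.Context) {})
}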
if origin.GetRef() < 2 { ctx.TaskRunner.RunTask( ctx, + ratelimit.UpdateSubTree, func(_ context.Context) { c.CheckAndPutSubTree(region) }, - ratelimit.WithTaskName(ratelimit.UpdateSubTree), ) } return nil @@ -1078,20 +1086,20 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio } ctx.TaskRunner.RunTask( ctx, + ratelimit.UpdateSubTree, func(_ context.Context) { c.CheckAndPutSubTree(region) }, - ratelimit.WithTaskName(ratelimit.UpdateSubTree), ) tracer.OnUpdateSubTreeFinished() if !c.IsServiceIndependent(mcsutils.SchedulingServiceName) { ctx.TaskRunner.RunTask( ctx.Context, + ratelimit.HandleOverlaps, func(_ context.Context) { cluster.HandleOverlaps(c, overlaps) }, - ratelimit.WithTaskName(ratelimit.HandleOverlaps), ) } regionUpdateCacheEventCounter.Inc() @@ -1101,13 +1109,13 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio // handle region stats ctx.TaskRunner.RunTask( ctx.Context, + ratelimit.CollectRegionStatsAsync, func(_ context.Context) { // TODO: Due to the accuracy requirements of the API "/regions/check/xxx", // region stats needs to be collected in API mode. // We need to think of a better way to reduce this part of the cost in the future. cluster.Collect(c, region, hasRegionStats) }, - ratelimit.WithTaskName(ratelimit.CollectRegionStatsAsync), ) tracer.OnCollectRegionStatsFinished() @@ -1115,6 +1123,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio if saveKV { ctx.TaskRunner.RunTask( ctx.Context, + ratelimit.SaveRegionToKV, func(_ context.Context) { // If there are concurrent heartbeats from the same region, the last write will win even if // writes to storage in the critical area. So don't use mutex to protect it. @@ -1136,7 +1145,6 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio } regionUpdateKVEventCounter.Inc() }, - ratelimit.WithTaskName(ratelimit.SaveRegionToKV), ) } } diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index 14a4d0c71a1..43602dbb68d 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -39,7 +39,7 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { if c.GetScheduleConfig().EnableHeartbeatBreakdownMetrics { tracer = core.NewHeartbeatProcessTracer() } - + defer tracer.Release() var taskRunner, logRunner ratelimit.Runner taskRunner, logRunner = syncRunner, syncRunner if c.GetScheduleConfig().EnableHeartbeatConcurrentRunner { @@ -69,7 +69,7 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { // HandleAskSplit handles the split request. func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if c.isSchedulingHalted() { + if c.IsSchedulingHalted() { return nil, errs.ErrSchedulingIsHalted.FastGenByArgs() } if !c.opt.IsTikvRegionSplitEnabled() { @@ -112,13 +112,9 @@ func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSp return split, nil } -func (c *RaftCluster) isSchedulingHalted() bool { - return c.opt.IsSchedulingHalted() -} - // HandleAskBatchSplit handles the batch split request. 
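// Editor's note: a condensed sketch of the halt check consolidated in cluster.go
// above. RaftCluster.IsSchedulingHalted now ORs the persisted HaltScheduling option
// with the online-unsafe-recovery running state, so HandleAskSplit,
// HandleAskBatchSplit, and the forwarding path below all share one definition.
// The types here are stand-ins, not the real collaborators.
package main

import "fmt"

type persistOptions struct{ haltScheduling bool }

func (o *persistOptions) IsSchedulingHalted() bool { return o.haltScheduling }

type recoveryController struct{ running bool }

func (u *recoveryController) IsRunning() bool { return u.running }

type raftCluster struct {
	opt      *persistOptions
	recovery *recoveryController
}

func (c *raftCluster) IsSchedulingHalted() bool {
	return c.opt.IsSchedulingHalted() || c.recovery.IsRunning()
}

func main() {
	c := &raftCluster{opt: &persistOptions{}, recovery: &recoveryController{running: true}}
	fmt.Println(c.IsSchedulingHalted()) // true while recovery runs, even with the option off
}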
func (c *RaftCluster) HandleAskBatchSplit(request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if c.isSchedulingHalted() { + if c.IsSchedulingHalted() { return nil, errs.ErrSchedulingIsHalted.FastGenByArgs() } if !c.opt.IsTikvRegionSplitEnabled() { diff --git a/server/config/persist_options.go b/server/config/persist_options.go index 62118dde593..6f5dc50f205 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -987,11 +987,8 @@ func (o *PersistOptions) SetAllStoresLimitTTL(ctx context.Context, client *clien var haltSchedulingStatus = schedulingAllowanceStatusGauge.WithLabelValues("halt-scheduling") -// SetHaltScheduling set HaltScheduling. -func (o *PersistOptions) SetHaltScheduling(halt bool, source string) { - v := o.GetScheduleConfig().Clone() - v.HaltScheduling = halt - o.SetScheduleConfig(v) +// SetSchedulingAllowanceStatus sets the scheduling allowance status to help distinguish the source of the halt. +func (*PersistOptions) SetSchedulingAllowanceStatus(halt bool, source string) { if halt { haltSchedulingStatus.Set(1) schedulingAllowanceStatusGauge.WithLabelValues(source).Set(1) @@ -1001,6 +998,14 @@ func (o *PersistOptions) SetHaltScheduling(halt bool, source string) { } } +// SetHaltScheduling set HaltScheduling. +func (o *PersistOptions) SetHaltScheduling(halt bool, source string) { + v := o.GetScheduleConfig().Clone() + v.HaltScheduling = halt + o.SetScheduleConfig(v) + o.SetSchedulingAllowanceStatus(halt, source) +} + // IsSchedulingHalted returns if PD scheduling is halted. func (o *PersistOptions) IsSchedulingHalted() bool { if o == nil { diff --git a/server/forward.go b/server/forward.go index 13bad4c7600..650833e1fc1 100644 --- a/server/forward.go +++ b/server/forward.go @@ -264,7 +264,7 @@ func forwardRegionHeartbeatToScheduling(rc *cluster.RaftCluster, forwardStream s return } // TODO: find a better way to halt scheduling immediately. - if rc.GetOpts().IsSchedulingHalted() { + if rc.IsSchedulingHalted() { continue } // The error types defined for schedulingpb and pdpb are different, so we need to convert them. diff --git a/server/server.go b/server/server.go index 8d7b83cfe4a..af9f48f8c9b 100644 --- a/server/server.go +++ b/server/server.go @@ -1042,6 +1042,7 @@ func (s *Server) GetScheduleConfig() *sc.ScheduleConfig { } // SetScheduleConfig sets the balance config information. +// This function is exported to be used by the API. func (s *Server) SetScheduleConfig(cfg sc.ScheduleConfig) error { if err := cfg.Validate(); err != nil { return err @@ -1060,6 +1061,8 @@ func (s *Server) SetScheduleConfig(cfg sc.ScheduleConfig) error { errs.ZapError(err)) return err } + // Update the scheduling halt status at the same time. 
+ s.persistOptions.SetSchedulingAllowanceStatus(cfg.HaltScheduling, "manually") log.Info("schedule config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) return nil } diff --git a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go index 160eea167d6..0c7683b569c 100644 --- a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go +++ b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go @@ -47,7 +47,6 @@ type keyspaceGroupTestSuite struct { cluster *tests.TestCluster server *tests.TestServer backendEndpoints string - dialClient *http.Client } func TestKeyspaceGroupTestSuite(t *testing.T) { @@ -67,11 +66,6 @@ func (suite *keyspaceGroupTestSuite) SetupTest() { suite.server = cluster.GetLeaderServer() re.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() - suite.dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, - } suite.cleanupFunc = func() { cancel() } @@ -81,7 +75,6 @@ func (suite *keyspaceGroupTestSuite) TearDownTest() { re := suite.Require() suite.cleanupFunc() suite.cluster.Destroy() - suite.dialClient.CloseIdleConnections() re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/acceleratedAllocNodes")) } @@ -298,7 +291,7 @@ func (suite *keyspaceGroupTestSuite) TestSetNodes() { Nodes: []string{nodesList[0]}, } _, code = suite.trySetNodesForKeyspaceGroup(re, id, params) - re.Equal(http.StatusBadRequest, code) + re.Equal(http.StatusOK, code) // the keyspace group is not exist. id = 2 @@ -347,7 +340,7 @@ func (suite *keyspaceGroupTestSuite) tryAllocNodesForKeyspaceGroup(re *require.A re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/alloc", id), bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() nodes := make([]endpoint.KeyspaceGroupMember, 0) @@ -364,7 +357,7 @@ func (suite *keyspaceGroupTestSuite) tryCreateKeyspaceGroup(re *require.Assertio re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, suite.server.GetAddr()+keyspaceGroupsPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() return resp.StatusCode @@ -373,7 +366,7 @@ func (suite *keyspaceGroupTestSuite) tryCreateKeyspaceGroup(re *require.Assertio func (suite *keyspaceGroupTestSuite) tryGetKeyspaceGroup(re *require.Assertions, id uint32) (*endpoint.KeyspaceGroup, int) { httpReq, err := http.NewRequest(http.MethodGet, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() kg := &endpoint.KeyspaceGroup{} @@ -390,7 +383,7 @@ func (suite *keyspaceGroupTestSuite) trySetNodesForKeyspaceGroup(re *require.Ass re.NoError(err) httpReq, err := http.NewRequest(http.MethodPatch, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { diff --git a/tests/integrations/mcs/members/member_test.go b/tests/integrations/mcs/members/member_test.go index 
69c848937b5..ba8f6efd7fd 100644 --- a/tests/integrations/mcs/members/member_test.go +++ b/tests/integrations/mcs/members/member_test.go @@ -36,7 +36,7 @@ type memberTestSuite struct { cluster *tests.TestCluster server *tests.TestServer backendEndpoints string - dialClient pdClient.Client + pdClient pdClient.Client tsoNodes map[string]bs.Server schedulingNodes map[string]bs.Server @@ -58,7 +58,7 @@ func (suite *memberTestSuite) SetupTest() { suite.server = cluster.GetLeaderServer() re.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() - suite.dialClient = pdClient.NewClient("mcs-member-test", []string{suite.server.GetAddr()}) + suite.pdClient = pdClient.NewClient("mcs-member-test", []string{suite.server.GetAddr()}) // TSO nodes := make(map[string]bs.Server) @@ -93,37 +93,37 @@ func (suite *memberTestSuite) TearDownTest() { for _, cleanup := range suite.cleanupFunc { cleanup() } - if suite.dialClient != nil { - suite.dialClient.Close() + if suite.pdClient != nil { + suite.pdClient.Close() } suite.cluster.Destroy() } func (suite *memberTestSuite) TestMembers() { re := suite.Require() - members, err := suite.dialClient.GetMicroServiceMembers(suite.ctx, "tso") + members, err := suite.pdClient.GetMicroServiceMembers(suite.ctx, "tso") re.NoError(err) re.Len(members, 3) - members, err = suite.dialClient.GetMicroServiceMembers(suite.ctx, "scheduling") + members, err = suite.pdClient.GetMicroServiceMembers(suite.ctx, "scheduling") re.NoError(err) re.Len(members, 3) } func (suite *memberTestSuite) TestPrimary() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) - primary, err = suite.dialClient.GetMicroServicePrimary(suite.ctx, "scheduling") + primary, err = suite.pdClient.GetMicroServicePrimary(suite.ctx, "scheduling") re.NoError(err) re.NotEmpty(primary) } func (suite *memberTestSuite) TestCampaignPrimaryWhileServerClose() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) @@ -137,7 +137,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryWhileServerClose() { nodes = suite.schedulingNodes } - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) // Close old and new primary to mock campaign primary @@ -151,7 +151,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryWhileServerClose() { tests.WaitForPrimaryServing(re, nodes) // primary should be different with before - onlyPrimary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + onlyPrimary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) re.NotEqual(primary, onlyPrimary) } @@ -159,7 +159,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryWhileServerClose() { func (suite *memberTestSuite) TestTransferPrimary() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) @@ -174,9 +174,9 @@ func (suite *memberTestSuite) TestTransferPrimary() { } // Test resign primary by random - primary, err = suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err = 
suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) - err = suite.dialClient.TransferMicroServicePrimary(suite.ctx, service, "") + err = suite.pdClient.TransferMicroServicePrimary(suite.ctx, service, "") re.NoError(err) testutil.Eventually(re, func() bool { @@ -188,7 +188,7 @@ func (suite *memberTestSuite) TestTransferPrimary() { return false }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) // Test transfer primary to a specific node @@ -199,27 +199,27 @@ func (suite *memberTestSuite) TestTransferPrimary() { break } } - err = suite.dialClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) + err = suite.pdClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) re.NoError(err) testutil.Eventually(re, func() bool { return nodes[newPrimary].IsServing() }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - primary, err = suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err = suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) re.Equal(primary, newPrimary) // Test transfer primary to a non-exist node newPrimary = "http://" - err = suite.dialClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) + err = suite.pdClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) re.Error(err) } } func (suite *memberTestSuite) TestCampaignPrimaryAfterTransfer() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) @@ -233,7 +233,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryAfterTransfer() { nodes = suite.schedulingNodes } - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) // Test transfer primary to a specific node @@ -244,11 +244,11 @@ func (suite *memberTestSuite) TestCampaignPrimaryAfterTransfer() { break } } - err = suite.dialClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) + err = suite.pdClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) re.NoError(err) tests.WaitForPrimaryServing(re, nodes) - newPrimary, err = suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + newPrimary, err = suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) re.NotEqual(primary, newPrimary) @@ -257,7 +257,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryAfterTransfer() { nodes[newPrimary].Close() tests.WaitForPrimaryServing(re, nodes) // Primary should be different with before - onlyPrimary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + onlyPrimary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) re.NotEqual(primary, onlyPrimary) re.NotEqual(newPrimary, onlyPrimary) @@ -266,7 +266,7 @@ func (suite *memberTestSuite) TestCampaignPrimaryAfterTransfer() { func (suite *memberTestSuite) TestTransferPrimaryWhileLeaseExpired() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) @@ -280,7 +280,7 @@ func (suite *memberTestSuite) 
TestTransferPrimaryWhileLeaseExpired() { nodes = suite.schedulingNodes } - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) // Test transfer primary to a specific node @@ -293,7 +293,7 @@ func (suite *memberTestSuite) TestTransferPrimaryWhileLeaseExpired() { } // Mock the new primary can not grant leader which means the lease will expire re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/election/skipGrantLeader", fmt.Sprintf("return(\"%s\")", newPrimary))) - err = suite.dialClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) + err = suite.pdClient.TransferMicroServicePrimary(suite.ctx, service, newPrimary) re.NoError(err) // Wait for the old primary exit and new primary campaign @@ -310,7 +310,7 @@ func (suite *memberTestSuite) TestTransferPrimaryWhileLeaseExpired() { tests.WaitForPrimaryServing(re, nodes) // Primary should be different with before - onlyPrimary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, service) + onlyPrimary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, service) re.NoError(err) re.NotEqual(newPrimary, onlyPrimary) } diff --git a/tests/integrations/mcs/resourcemanager/resource_manager_test.go b/tests/integrations/mcs/resourcemanager/resource_manager_test.go index 17673213a97..ab7cd5321ad 100644 --- a/tests/integrations/mcs/resourcemanager/resource_manager_test.go +++ b/tests/integrations/mcs/resourcemanager/resource_manager_test.go @@ -957,7 +957,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } createJSON, err := json.Marshal(group) re.NoError(err) - resp, err := http.Post(getAddr(i)+"/resource-manager/api/v1/config/group", "application/json", strings.NewReader(string(createJSON))) + resp, err := tests.TestDialClient.Post(getAddr(i)+"/resource-manager/api/v1/config/group", "application/json", strings.NewReader(string(createJSON))) re.NoError(err) resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -982,7 +982,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } // Get Resource Group - resp, err = http.Get(getAddr(i) + "/resource-manager/api/v1/config/group/" + tcase.name) + resp, err = tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/group/" + tcase.name) re.NoError(err) re.Equal(http.StatusOK, resp.StatusCode) respString, err := io.ReadAll(resp.Body) @@ -995,7 +995,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { // Last one, Check list and delete all resource groups if i == len(testCasesSet1)-1 { - resp, err := http.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") + resp, err := tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") re.NoError(err) re.Equal(http.StatusOK, resp.StatusCode) respString, err := io.ReadAll(resp.Body) @@ -1023,7 +1023,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } // verify again - resp1, err := http.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") + resp1, err := tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") re.NoError(err) re.Equal(http.StatusOK, resp1.StatusCode) respString1, err := io.ReadAll(resp1.Body) diff --git a/tests/integrations/mcs/resourcemanager/server_test.go b/tests/integrations/mcs/resourcemanager/server_test.go index 4e1fb018d56..24de29db3a6 100644 --- a/tests/integrations/mcs/resourcemanager/server_test.go +++ 
b/tests/integrations/mcs/resourcemanager/server_test.go @@ -63,7 +63,7 @@ func TestResourceManagerServer(t *testing.T) { // Test registered REST HTTP Handler url := addr + "/resource-manager/api/v1/config" { - resp, err := http.Get(url + "/groups") + resp, err := tests.TestDialClient.Get(url + "/groups") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -78,13 +78,13 @@ func TestResourceManagerServer(t *testing.T) { } createJSON, err := json.Marshal(group) re.NoError(err) - resp, err := http.Post(url+"/group", "application/json", strings.NewReader(string(createJSON))) + resp, err := tests.TestDialClient.Post(url+"/group", "application/json", strings.NewReader(string(createJSON))) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) } { - resp, err := http.Get(url + "/group/pingcap") + resp, err := tests.TestDialClient.Get(url + "/group/pingcap") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -95,7 +95,7 @@ func TestResourceManagerServer(t *testing.T) { // Test metrics handler { - resp, err := http.Get(addr + "/metrics") + resp, err := tests.TestDialClient.Get(addr + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -106,7 +106,7 @@ func TestResourceManagerServer(t *testing.T) { // Test status handler { - resp, err := http.Get(addr + "/status") + resp, err := tests.TestDialClient.Get(addr + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index e9033e5016a..cf2c6dd2508 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -29,12 +29,6 @@ import ( "github.com/tikv/pd/tests" ) -var testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - type apiTestSuite struct { suite.Suite env *tests.SchedulingTestEnvironment @@ -56,7 +50,6 @@ func (suite *apiTestSuite) TearDownSuite() { re := suite.Require() re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/changeCoordinatorTicker")) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/mcs/scheduling/server/changeRunCollectWaitTime")) - testDialClient.CloseIdleConnections() } func (suite *apiTestSuite) TestGetCheckerByName() { @@ -84,14 +77,14 @@ func (suite *apiTestSuite) checkGetCheckerByName(cluster *tests.TestCluster) { name := testCase.name // normal run resp := make(map[string]any) - err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) // paused err = co.PauseOrResumeChecker(name, 30) re.NoError(err) resp = make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) // resumed @@ -99,7 +92,7 @@ func (suite *apiTestSuite) checkGetCheckerByName(cluster *tests.TestCluster) { re.NoError(err) time.Sleep(time.Second) resp = make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) 
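// Editor's note: the test hunks in this area replace bare http.Get/http.Post calls
// and per-suite clients with a shared tests.TestDialClient. Its definition is not
// shown in this diff; a plausible shape, modeled on the clients deleted from
// api_test.go and tso_keyspace_group_test.go above, is sketched here. Disabling
// keep-alives closes connections eagerly, so idle sockets cannot leak between
// suites or trip goroutine-leak checks.
package tests

import "net/http"

// TestDialClient is the HTTP client shared by integration tests (assumed definition).
var TestDialClient = &http.Client{
	Transport: &http.Transport{
		DisableKeepAlives: true,
	},
}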
re.False(resp["paused"].(bool)) } @@ -121,29 +114,29 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { }) // Test operators - err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &respSlice, + err := testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &respSlice, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.Empty(respSlice) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), []byte(``), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), []byte(``), testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/records"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/records"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test checker - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.False(resp["paused"].(bool)) @@ -154,7 +147,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["delay"] = delay pauseArgs, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } @@ -173,7 +166,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { // "/schedulers", http.MethodPost // "/schedulers/{name}", http.MethodDelete testutil.Eventually(re, func() bool { - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &respSlice, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &respSlice, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) return slice.Contains(respSlice, "balance-leader-scheduler") @@ -184,18 +177,18 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["delay"] = delay pauseArgs, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), pauseArgs, + err = 
testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), pauseArgs, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } postScheduler(30) postScheduler(0) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/diagnostic/balance-leader-scheduler"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/diagnostic/balance-leader-scheduler"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "scheduler-config"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "scheduler-config"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.Contains(resp, "balance-leader-scheduler") @@ -206,16 +199,16 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { "balance-hot-region-scheduler", } for _, schedulerName := range schedulers { - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s/%s/%s", urlPrefix, "scheduler-config", schedulerName, "list"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s/%s/%s", urlPrefix, "scheduler-config", schedulerName, "list"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), nil, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), nil, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) @@ -223,74 +216,74 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["name"] = "balance-leader-scheduler" b, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), b, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), b, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) // Test hotspot var hotRegions statistics.StoreHotPeersInfos - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/write"), &hotRegions, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/write"), &hotRegions, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/read"), &hotRegions, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/read"), &hotRegions, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var stores handler.HotStoreStats - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/stores"), &stores, + err = testutil.ReadGetJSON(re, tests.TestDialClient, 
fmt.Sprintf("%s/%s", urlPrefix, "hotspot/stores"), &stores, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var buckets handler.HotBucketsResponse - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/buckets"), &buckets, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/buckets"), &buckets, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var history storage.HistoryHotRegions - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test region label var labelRules []*labeler.LabelRule - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules"), &labelRules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules"), &labelRules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSONWithBody(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules/ids"), []byte(`["rule1", "rule3"]`), + err = testutil.ReadGetJSONWithBody(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules/ids"), []byte(`["rule1", "rule3"]`), &labelRules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rule/rule1"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rule/rule1"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1"), nil, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/label/key"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/label/key"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/labels"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/labels"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test Region body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`[{"start_key":"%s", "end_key": "%s"}, {"start_key":"%s", "end_key": 
"%s"}]`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3")), hex.EncodeToString([]byte("a4")), hex.EncodeToString([]byte("a6"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule/batch"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule/batch"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("b1")), hex.EncodeToString([]byte("b3"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/scatter"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/scatter"), []byte(body), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`{"retry_limit":%v, "split_keys": ["%s","%s","%s"]}`, 3, hex.EncodeToString([]byte("bbb")), hex.EncodeToString([]byte("ccc")), hex.EncodeToString([]byte("ddd"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/split"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/split"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a2"))), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a2"))), nil, testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test rules: only forward `GET` request @@ -308,73 +301,73 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { rulesArgs, err := json.Marshal(rules) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules, + err 
= testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var fit placement.RegionFit - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, 
"config/placement-rule"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) // test redirect is disabled - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), http.NoBody) re.NoError(err) req.Header.Set(apiutil.XForbiddenForwardToMicroServiceHeader, "true") - httpResp, err := testDialClient.Do(req) + httpResp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(http.StatusOK, httpResp.StatusCode) defer httpResp.Body.Close() @@ -395,7 +388,7 @@ func (suite *apiTestSuite) checkConfig(cluster *tests.TestCluster) { urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/config", addr) var cfg config.Config - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(cfg.GetListenAddr(), s.GetConfig().GetListenAddr()) re.Equal(cfg.Schedule.LeaderScheduleLimit, s.GetConfig().Schedule.LeaderScheduleLimit) re.Equal(cfg.Schedule.EnableCrossTableMerge, s.GetConfig().Schedule.EnableCrossTableMerge) @@ -427,7 +420,7 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { // Test config forward // Expect to get same config in scheduling server and api server testutil.Eventually(re, func() bool { - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(cfg["schedule"].(map[string]any)["leader-schedule-limit"], float64(opts.GetLeaderScheduleLimit())) re.Equal(cfg["replication"].(map[string]any)["max-replicas"], @@ -442,10 +435,10 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { "max-replicas": 4, }) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, urlPrefix, reqData, testutil.StatusOK(re)) + err = testutil.CheckPostJSON(tests.TestDialClient, urlPrefix, reqData, testutil.StatusOK(re)) re.NoError(err) testutil.Eventually(re, func() bool { - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) return cfg["replication"].(map[string]any)["max-replicas"] == 4. && opts.GetReplicationConfig().MaxReplicas == 4. 
}) @@ -454,11 +447,11 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { // Expect to get new config in scheduling server but not old config in api server opts.GetScheduleConfig().LeaderScheduleLimit = 100 re.Equal(100, int(opts.GetLeaderScheduleLimit())) - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(100., cfg["schedule"].(map[string]any)["leader-schedule-limit"]) opts.GetReplicationConfig().MaxReplicas = 5 re.Equal(5, int(opts.GetReplicationConfig().MaxReplicas)) - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(5., cfg["replication"].(map[string]any)["max-replicas"]) } @@ -480,11 +473,11 @@ func (suite *apiTestSuite) checkAdminRegionCache(cluster *tests.TestCluster) { addr := schedulingServer.GetAddr() urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/admin/cache/regions", addr) - err := testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) + err := testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) re.NoError(err) re.Equal(2, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) - err = testutil.CheckDelete(testDialClient, urlPrefix, testutil.StatusOK(re)) + err = testutil.CheckDelete(tests.TestDialClient, urlPrefix, testutil.StatusOK(re)) re.NoError(err) re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) } @@ -509,12 +502,12 @@ func (suite *apiTestSuite) checkAdminRegionCacheForward(cluster *tests.TestClust addr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/admin/cache/region", addr) - err := testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) + err := testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) re.NoError(err) re.Equal(2, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(2, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) - err = testutil.CheckDelete(testDialClient, urlPrefix+"s", testutil.StatusOK(re)) + err = testutil.CheckDelete(tests.TestDialClient, urlPrefix+"s", testutil.StatusOK(re)) re.NoError(err) re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(0, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) @@ -544,14 +537,14 @@ func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // follower will forward to scheduling server directly re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true"), ) re.NoError(err) } else { // follower will forward to leader server re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader), ) re.NoError(err) @@ -560,7 +553,7 @@ 
func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { // follower will forward to leader server re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) results := make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader), ) re.NoError(err) @@ -576,7 +569,7 @@ func (suite *apiTestSuite) checkMetrics(cluster *tests.TestCluster) { testutil.Eventually(re, func() bool { return s.IsServing() }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - resp, err := http.Get(s.GetConfig().GetAdvertiseListenAddr() + "/metrics") + resp, err := tests.TestDialClient.Get(s.GetConfig().GetAdvertiseListenAddr() + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -595,7 +588,7 @@ func (suite *apiTestSuite) checkStatus(cluster *tests.TestCluster) { testutil.Eventually(re, func() bool { return s.IsServing() }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - resp, err := http.Get(s.GetConfig().GetAdvertiseListenAddr() + "/status") + resp, err := tests.TestDialClient.Get(s.GetConfig().GetAdvertiseListenAddr() + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -659,34 +652,34 @@ func (suite *apiTestSuite) checkStores(cluster *tests.TestCluster) { apiServerAddr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/stores", apiServerAddr) var resp map[string]any - err := testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["stores"].([]any), 3) scheServerAddr := cluster.GetSchedulingPrimaryServer().GetAddr() urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["stores"].([]any), 3) // Test /stores/{id} urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/1", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv1", resp["store"].(map[string]any)["address"]) re.Equal("Up", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/6", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv6", resp["store"].(map[string]any)["address"]) re.Equal("Offline", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/7", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv7", resp["store"].(map[string]any)["address"]) re.Equal("Tombstone", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/233", scheServerAddr) - testutil.CheckGetJSON(testDialClient, urlPrefix, nil, + 
testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found")) } @@ -703,27 +696,27 @@ func (suite *apiTestSuite) checkRegions(cluster *tests.TestCluster) { apiServerAddr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/regions", apiServerAddr) var resp map[string]any - err := testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["regions"].([]any), 3) scheServerAddr := cluster.GetSchedulingPrimaryServer().GetAddr() urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["regions"].([]any), 3) // Test /regions/{id} and /regions/count urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/1", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) key := fmt.Sprintf("%x", "a") re.Equal(key, resp["start_key"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/count", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3., resp["count"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/233", scheServerAddr) - testutil.CheckGetJSON(testDialClient, urlPrefix, nil, + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found")) } diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index dc9bfa1e291..4d6f9b33e3b 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -42,13 +42,6 @@ const ( tsoKeyspaceGroupsPrefix = "/tso/api/v1/keyspace-groups" ) -// dialClient used to dial http request. 
-var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - type tsoAPITestSuite struct { suite.Suite ctx context.Context @@ -110,13 +103,13 @@ func (suite *tsoAPITestSuite) TestForwardResetTS() { // Test reset ts input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err := testutil.CheckPostJSON(dialClient, url, input, + err := testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test reset ts with invalid tso input = []byte(`{}`) - err = testutil.CheckPostJSON(dialClient, url, input, + err = testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } @@ -124,7 +117,7 @@ func (suite *tsoAPITestSuite) TestForwardResetTS() { func mustGetKeyspaceGroupMembers(re *require.Assertions, server *tso.Server) map[uint32]*apis.KeyspaceGroupMember { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+tsoKeyspaceGroupsPrefix+"/members", http.NoBody) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() data, err := io.ReadAll(httpResp.Body) @@ -177,14 +170,14 @@ func TestTSOServerStartFirst(t *testing.T) { re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, addr+"/pd/api/v2/tso/keyspace-groups/0/split", bytes.NewBuffer(jsonBody)) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() re.Equal(http.StatusOK, httpResp.StatusCode) httpReq, err = http.NewRequest(http.MethodGet, addr+"/pd/api/v2/tso/keyspace-groups/0", http.NoBody) re.NoError(err) - httpResp, err = dialClient.Do(httpReq) + httpResp, err = tests.TestDialClient.Do(httpReq) re.NoError(err) data, err := io.ReadAll(httpResp.Body) re.NoError(err) @@ -219,20 +212,20 @@ func TestForwardOnlyTSONoScheduling(t *testing.T) { // Test /operators, it should not forward when there is no scheduling server. var slice []string - err = testutil.ReadGetJSON(re, dialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) re.Empty(slice) // Test admin/reset-ts, it should forward to tso server. input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err = testutil.CheckPostJSON(dialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // If close tso server, it should try forward to tso server, but return error in api mode. 
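For context: this file previously kept its own dialClient, removed above in favor of one client shared through the tests package. Its definition is not shown in this section, but given the removed copies it plausibly looks like the minimal sketch below; the key property is DisableKeepAlives, so every request opens a fresh connection instead of reusing one that may be pinned to a server that has since stopped or lost leadership. The same reasoning explains why bare http.Get calls are switched to tests.TestDialClient.Get throughout this patch.

package tests

import "net/http"

// TestDialClient is the one HTTP client shared by all integration tests.
// Disabling keep-alives forces a new TCP connection per request, which keeps
// tests from accidentally reusing a connection to a stale or stopped server.
var TestDialClient = &http.Client{
	Transport: &http.Transport{
		DisableKeepAlives: true,
	},
}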
ttc.Destroy() - err = testutil.CheckPostJSON(dialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, testutil.Status(re, http.StatusInternalServerError), testutil.StringContain(re, "[PD:apiutil:ErrRedirect]redirect failed")) re.NoError(err) } @@ -241,7 +234,7 @@ func (suite *tsoAPITestSuite) TestMetrics() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/metrics") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -254,7 +247,7 @@ func (suite *tsoAPITestSuite) TestStatus() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/status") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -271,7 +264,7 @@ func (suite *tsoAPITestSuite) TestConfig() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/tso/api/v1/config") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/tso/api/v1/config") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index 108740e46f9..260395e4209 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -111,13 +111,13 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { url := s.GetAddr() + tsoapi.APIPathPrefix + "/admin/reset-ts" // Test reset ts input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err = testutil.CheckPostJSON(dialClient, url, input, + err = testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) re.NoError(err) // Test reset ts with invalid tso input = []byte(`{}`) - err = testutil.CheckPostJSON(dialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, + err = testutil.CheckPostJSON(tests.TestDialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) re.NoError(err) } @@ -583,7 +583,7 @@ func (suite *CommonTestSuite) TestBootstrapDefaultKeyspaceGroup() { // check the default keyspace group check := func() { - resp, err := http.Get(suite.pdLeader.GetServer().GetConfig().AdvertiseClientUrls + "/pd/api/v2/tso/keyspace-groups") + resp, err := tests.TestDialClient.Get(suite.pdLeader.GetServer().GetConfig().AdvertiseClientUrls + "/pd/api/v2/tso/keyspace-groups") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/scheduling_cluster.go b/tests/scheduling_cluster.go index 1768c4128cc..434a6bd9a48 100644 --- a/tests/scheduling_cluster.go +++ b/tests/scheduling_cluster.go @@ -113,7 +113,7 @@ func (tc *TestSchedulingCluster) WaitForPrimaryServing(re *require.Assertions) * } } return false - }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) + }, 
testutil.WithWaitFor(10*time.Second), testutil.WithTickInterval(50*time.Millisecond)) return primary } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 9e1636df045..091d1488177 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -66,7 +66,7 @@ func TestReconnect(t *testing.T) { re.NotEmpty(leader) for name, s := range cluster.GetServers() { if name != leader { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) res.Body.Close() re.Equal(http.StatusOK, res.StatusCode) @@ -83,7 +83,7 @@ func TestReconnect(t *testing.T) { for name, s := range cluster.GetServers() { if name != leader { testutil.Eventually(re, func() bool { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusOK @@ -98,7 +98,7 @@ func TestReconnect(t *testing.T) { for name, s := range cluster.GetServers() { if name != leader && name != newLeader { testutil.Eventually(re, func() bool { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusServiceUnavailable @@ -148,7 +148,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) @@ -156,7 +156,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { labels := make(map[string]any) labels["testkey"] = "testvalue" data, _ = json.Marshal(labels) - resp, err = dialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?force=true", "application/json", bytes.NewBuffer(data)) + resp, err = tests.TestDialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?seconds=1", "application/json", bytes.NewBuffer(data)) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -164,7 +164,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { re.Equal(http.StatusOK, resp.StatusCode) re.Equal("Profile", resp.Header.Get("service-label")) - re.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param")) + re.Equal("{\"seconds\":[\"1\"]}", resp.Header.Get("url-param")) re.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) re.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) re.Equal("anonymous", resp.Header.Get("caller-id")) @@ -176,13 +176,13 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) header := mustRequestSuccess(re, leader.GetServer()) - re.Equal("", 
header.Get("service-label")) + re.Equal("GetVersion", header.Get("service-label")) re.NoError(failpoint.Disable("github.com/tikv/pd/server/api/addRequestInfoMiddleware")) } @@ -199,7 +199,7 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -219,14 +219,14 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) // returns StatusOK when no rate-limit config req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -240,7 +240,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { jsonBody, err := json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config/rate-limit", bytes.NewBuffer(jsonBody)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -249,7 +249,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -266,7 +266,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second * 2) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -283,7 +283,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -310,7 +310,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -327,7 +327,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second * 2) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) 
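The middleware tests above repeat one round-trip many times: POST a config fragment to /pd/api/v1/service-middleware/config, then assert that the persisted option flipped. A hypothetical helper condensing that pattern (the name setRateLimitConfig and the function itself are illustrative, not part of this diff; endpoint, payload key, and accessor are taken from the tests):

package api

import (
	"bytes"
	"encoding/json"
	"net/http"
	"strconv"

	"github.com/stretchr/testify/require"
	"github.com/tikv/pd/tests"
)

// setRateLimitConfig is a hypothetical helper mirroring the inline pattern above.
func setRateLimitConfig(re *require.Assertions, leader *tests.TestServer, enabled bool) {
	input := map[string]any{"enable-rate-limit": strconv.FormatBool(enabled)}
	data, err := json.Marshal(input)
	re.NoError(err)
	req, _ := http.NewRequest(http.MethodPost,
		leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data))
	resp, err := tests.TestDialClient.Do(req)
	re.NoError(err)
	resp.Body.Close()
	// the option is persisted server-side, so assert on the server's view
	re.Equal(enabled, leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled())
}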
re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -344,7 +344,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -359,20 +359,32 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() re.NoError(err) re.Equal(http.StatusOK, resp.StatusCode) } + + // reset rate limit + input = map[string]any{ + "enable-rate-limit": "true", + } + data, err = json.Marshal(input) + re.NoError(err) + req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) + resp, err = tests.TestDialClient.Do(req) + re.NoError(err) + resp.Body.Close() + re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) } func (suite *middlewareTestSuite) TestSwaggerUrl() { @@ -380,7 +392,7 @@ func (suite *middlewareTestSuite) TestSwaggerUrl() { leader := suite.cluster.GetLeaderServer() re.NotNil(leader) req, _ := http.NewRequest(http.MethodGet, leader.GetAddr()+"/swagger/ui/index", http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(http.StatusNotFound, resp.StatusCode) resp.Body.Close() @@ -396,20 +408,20 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) timeUnix := time.Now().Unix() - 20 req, _ = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() re.NoError(err) req, _ = http.NewRequest(http.MethodGet, leader.GetAddr()+"/metrics", http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) @@ -428,14 +440,14 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { timeUnix = time.Now().Unix() - 20 req, _ = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() re.NoError(err) req, _ = 
http.NewRequest(http.MethodGet, leader.GetAddr()+"/metrics", http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) @@ -448,7 +460,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) @@ -466,13 +478,13 @@ func (suite *middlewareTestSuite) TestAuditLocalLogBackend() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -494,7 +506,7 @@ func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -516,7 +528,7 @@ func BenchmarkDoRequestWithPrometheusAudit(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -538,7 +550,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -551,7 +563,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { func doTestRequestWithLogAudit(srv *tests.TestServer) { req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody) req.Header.Set(apiutil.XCallerIDHeader, "test") - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() } @@ -559,7 +571,7 @@ func doTestRequestWithPrometheus(srv *tests.TestServer) { timeUnix := time.Now().Unix() - 20 req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody) req.Header.Set(apiutil.XCallerIDHeader, "test") - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() } @@ -623,7 +635,7 @@ func (suite *redirectorTestSuite) TestAllowFollowerHandle() { request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) request.Header.Add(apiutil.PDAllowFollowerHandleHeader, "true") - resp, err := dialClient.Do(request) + resp, err := 
tests.TestDialClient.Do(request) re.NoError(err) re.Equal("", resp.Header.Get(apiutil.PDRedirectorHeader)) defer resp.Body.Close() @@ -648,7 +660,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { // Request to follower without redirectorHeader is OK. request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) - resp, err := dialClient.Do(request) + resp, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -658,7 +670,7 @@ // Request to follower with redirectorHeader will fail. request.RequestURI = "" request.Header.Set(apiutil.PDRedirectorHeader, "pd") - resp1, err := dialClient.Do(request) + resp1, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp1.Body.Close() re.NotEqual(http.StatusOK, resp1.StatusCode) @@ -677,7 +689,7 @@ func (suite *redirectorTestSuite) TestXForwardedFor() { addr := follower.GetAddr() + "/pd/api/v1/regions" request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) - resp, err := dialClient.Do(request) + resp, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -689,7 +701,7 @@ } func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { - resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version") + resp, err := tests.TestDialClient.Get(s.GetAddr() + "/pd/api/v1/version") re.NoError(err) defer resp.Body.Close() _, err = io.ReadAll(resp.Body) @@ -783,7 +795,7 @@ func TestRemovingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=removing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -807,7 +819,7 @@ } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=removing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -957,13 +969,20 @@ func TestPreparingProgress(t *testing.T) { StartTimestamp: time.Now().UnixNano() - 100, }, } - - for _, store := range stores { + // store 4 and store 5 are in the preparing state, while store 1, store 2 and store 3 are in the serving state + for _, store := range stores[:2] { tests.MustPutStore(re, cluster, store) } - for i := 0; i < 100; i++ { + for i := 0; i < core.InitClusterRegionThreshold; i++ { tests.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("%20d", i)), []byte(fmt.Sprintf("%20d", i+1)), core.SetApproximateSize(10)) } + testutil.Eventually(re, func() bool { + return leader.GetRaftCluster().GetTotalRegionCount() == core.InitClusterRegionThreshold + }) + // to avoid forcing the stores into the `serving` state while the cluster still has too few regions + for _, store := range stores[2:] { + tests.MustPutStore(re, cluster, store) + } // no store preparing output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) re.Contains(string(output), "no progress found for the action") @@ -977,7 +996,7 @@ } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=preparing"
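Why the store setup above is split in two: as the new comment notes, PD forces a store straight to the serving state when the cluster has too few regions, so registering stores 4 and 5 before the region count reaches core.InitClusterRegionThreshold would leave no preparing progress for the test to observe. An annotated restatement of the new sequencing (code as in the diff, numbered comments added):

// 1. register only the first two stores
for _, store := range stores[:2] {
	tests.MustPutStore(re, cluster, store)
}
// 2. feed exactly the threshold number of regions
for i := 0; i < core.InitClusterRegionThreshold; i++ {
	tests.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1,
		[]byte(fmt.Sprintf("%20d", i)), []byte(fmt.Sprintf("%20d", i+1)),
		core.SetApproximateSize(10))
}
// 3. wait until the leader has actually observed all of them
testutil.Eventually(re, func() bool {
	return leader.GetRaftCluster().GetTotalRegionCount() == core.InitClusterRegionThreshold
})
// 4. only now add the remaining stores, so they start out preparing
for _, store := range stores[2:] {
	tests.MustPutStore(re, cluster, store)
}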
req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { @@ -1002,7 +1021,7 @@ func TestPreparingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=preparing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() output, err := io.ReadAll(resp.Body) @@ -1059,7 +1078,7 @@ func TestPreparingProgress(t *testing.T) { func sendRequest(re *require.Assertions, url string, method string, statusCode int) []byte { req, _ := http.NewRequest(method, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(statusCode, resp.StatusCode) output, err := io.ReadAll(resp.Body) diff --git a/tests/server/api/checker_test.go b/tests/server/api/checker_test.go index 0304d7fd369..54298b405f1 100644 --- a/tests/server/api/checker_test.go +++ b/tests/server/api/checker_test.go @@ -73,14 +73,14 @@ func testErrCases(re *require.Assertions, cluster *tests.TestCluster) { input := make(map[string]any) pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) re.NoError(err) // negative delay input["delay"] = -10 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) re.NoError(err) // wrong name @@ -88,12 +88,12 @@ func testErrCases(re *require.Assertions, cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) re.NoError(err) } @@ -102,28 +102,28 @@ func testGetStatus(re *require.Assertions, cluster *tests.TestCluster, name stri urlPrefix := fmt.Sprintf("%s/pd/api/v1/checker", cluster.GetLeaderServer().GetAddr()) // normal run resp := make(map[string]any) - err := tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err := tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) // paused input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) resp = make(map[string]any) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) // resumed input["delay"] = 0 pauseArgs, err = 
json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) resp = make(map[string]any) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) } @@ -137,18 +137,18 @@ func testPauseOrResume(re *require.Assertions, cluster *tests.TestCluster, name input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) @@ -157,14 +157,14 @@ func testPauseOrResume(re *require.Assertions, cluster *tests.TestCluster, name input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) } diff --git a/tests/server/api/operator_test.go b/tests/server/api/operator_test.go index a5cd865b454..c3b86f9fde0 100644 --- a/tests/server/api/operator_test.go +++ b/tests/server/api/operator_test.go @@ -18,7 +18,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "sort" "strconv" "strings" @@ -35,15 +34,6 @@ import ( "github.com/tikv/pd/tests" ) -var ( - // testDialClient used to dial http request. only used for test. 
- testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, - } -) - type operatorTestSuite struct { suite.Suite env *tests.SchedulingTestEnvironment @@ -70,7 +60,7 @@ func (suite *operatorTestSuite) TestAddRemovePeer() { func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { re := suite.Require() - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) stores := []*metapb.Store{ { Id: 1, @@ -112,35 +102,35 @@ func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) regionURL := fmt.Sprintf("%s/operators/%d", urlPrefix, region.GetId()) - err := tu.CheckGetJSON(testDialClient, regionURL, nil, + err := tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) recordURL := fmt.Sprintf("%s/operators/records?from=%s", urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "add learner peer 1 on store 3"), tu.StringContain(re, "RUNNING")) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusOK(re), tu.StringContain(re, "admin-add-peer {add peer: store [3]}")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "remove peer on store 2"), tu.StringContain(re, "RUNNING")) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusOK(re), tu.StringContain(re, "admin-remove-peer {rm peer: store [2]}")) re.NoError(err) @@ -150,26 +140,26 @@ func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { NodeState: metapb.NodeState_Serving, LastHeartbeat: time.Now().UnixNano(), }) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), 
[]byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "add learner peer 2 on store 4")) re.NoError(err) // Fail to add peer to tombstone store. err = cluster.GetLeaderServer().GetRaftCluster().RemoveStore(3, true) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(re)) re.NoError(err) // Fail to get operator if from is latest. time.Sleep(time.Second) url := fmt.Sprintf("%s/operators/records?from=%s", urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) - err = tu.CheckGetJSON(testDialClient, url, nil, + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) } @@ -205,7 +195,7 @@ func (suite *operatorTestSuite) checkMergeRegionOperator(cluster *tests.TestClus tests.MustPutStore(re, cluster, store) } - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) r1 := core.NewTestRegionInfo(10, 1, []byte(""), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1)) tests.MustPutRegionInfo(re, cluster, r1) r2 := core.NewTestRegionInfo(20, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3)) @@ -214,17 +204,17 @@ func (suite *operatorTestSuite) checkMergeRegionOperator(cluster *tests.TestClus tests.MustPutRegionInfo(re, cluster, r3) urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - tu.CheckDelete(testDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(re)) + tu.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, 
fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(re)) re.NoError(err) - tu.CheckDelete(testDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), + tu.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) re.NoError(err) } @@ -241,7 +231,7 @@ func (suite *operatorTestSuite) TestTransferRegionWithPlacementRule() { func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *tests.TestCluster) { re := suite.Require() - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) stores := []*metapb.Store{ { Id: 1, @@ -287,7 +277,7 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) regionURL := fmt.Sprintf("%s/operators/%d", urlPrefix, region.GetId()) - err := tu.CheckGetJSON(testDialClient, regionURL, nil, + err := tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) convertStepsToStr := func(steps []string) string { @@ -462,7 +452,7 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te } reqData, e := json.Marshal(data) re.NoError(e) - err := tu.CheckPostJSON(testDialClient, url, reqData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, url, reqData, tu.StatusOK(re)) re.NoError(err) if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // wait for the scheduling server to update the config @@ -491,19 +481,19 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te re.NoError(err) } if testCase.expectedError == nil { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusOK(re)) } else { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusNotOK(re), tu.StringContain(re, testCase.expectedError.Error())) } re.NoError(err) if len(testCase.expectSteps) > 0 { - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, testCase.expectSteps)) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) } else { - err = tu.CheckDelete(testDialClient, regionURL, 
tu.StatusNotOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusNotOK(re)) } re.NoError(err) } @@ -521,7 +511,7 @@ func (suite *operatorTestSuite) TestGetOperatorsAsObject() { func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestCluster) { re := suite.Require() - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) stores := []*metapb.Store{ { Id: 1, @@ -552,7 +542,7 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu resp := make([]operator.OpObject, 0) // No operator. - err := tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err := tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Empty(resp) @@ -564,9 +554,9 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu r3 := core.NewTestRegionInfo(30, 1, []byte("c"), []byte("d"), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) tests.MustPutRegionInfo(re, cluster, r3) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Len(resp, 2) less := func(i, j int) bool { @@ -601,9 +591,9 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu } regionInfo := core.NewRegionInfo(region, peer1) tests.MustPutRegionInfo(re, cluster, regionInfo) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 40, "store_id": 3}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 40, "store_id": 3}`), tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Len(resp, 3) sort.Slice(resp, less) @@ -642,7 +632,7 @@ func (suite *operatorTestSuite) checkRemoveOperators(cluster *tests.TestCluster) tests.MustPutStore(re, cluster, store) } - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) r1 := core.NewTestRegionInfo(10, 1, []byte(""), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1)) tests.MustPutRegionInfo(re, cluster, r1) r2 := core.NewTestRegionInfo(20, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3)) @@ -651,15 +641,15 @@ func (suite *operatorTestSuite) checkRemoveOperators(cluster *tests.TestCluster) tests.MustPutRegionInfo(re, cluster, r3) urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, 
fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 30, "store_id": 4}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 30, "store_id": 4}`), tu.StatusOK(re)) re.NoError(err) url := fmt.Sprintf("%s/operators", urlPrefix) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "merge: region 10 to 20"), tu.StringContain(re, "add peer: store [4]")) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "merge: region 10 to 20"), tu.StringContain(re, "add peer: store [4]")) re.NoError(err) - err = tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringNotContain(re, "merge: region 10 to 20"), tu.StringNotContain(re, "add peer: store [4]")) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re), tu.StringNotContain(re, "merge: region 10 to 20"), tu.StringNotContain(re, "add peer: store [4]")) re.NoError(err) } diff --git a/tests/server/api/region_test.go b/tests/server/api/region_test.go index b233ce94a99..2ff0b5d4b86 100644 --- a/tests/server/api/region_test.go +++ b/tests/server/api/region_test.go @@ -57,7 +57,7 @@ func (suite *regionTestSuite) TearDownTest() { pdAddr := cluster.GetConfig().GetClientURL() for _, region := range leader.GetRegions() { url := fmt.Sprintf("%s/pd/api/v1/admin/cache/region/%d", pdAddr, region.GetID()) - err := tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err := tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) } re.Empty(leader.GetRegions()) @@ -71,7 +71,7 @@ func (suite *regionTestSuite) TearDownTest() { data, err := json.Marshal([]placement.GroupBundle{def}) re.NoError(err) urlPrefix := cluster.GetLeaderServer().GetAddr() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) // clean stores for _, store := range leader.GetStores() { @@ -132,7 +132,7 @@ func (suite *regionTestSuite) checkSplitRegions(cluster *tests.TestCluster) { re.Equal([]uint64{newRegionID}, s.NewRegionsID) } re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/handler/splitResponses", fmt.Sprintf("return(%v)", newRegionID))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/split", urlPrefix), []byte(body), checkOpt) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/split", urlPrefix), []byte(body), checkOpt) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/handler/splitResponses")) re.NoError(err) } @@ -162,7 +162,7 @@ func (suite *regionTestSuite) checkAccelerateRegionsScheduleInRange(cluster *tes checkRegionCount(re, cluster, regionCount) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", urlPrefix), []byte(body), + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) idList := leader.GetRaftCluster().GetSuspectRegions() @@ -198,7 +198,7 @@ func (suite *regionTestSuite) 
checkAccelerateRegionsScheduleInRanges(cluster *te body := fmt.Sprintf(`[{"start_key":"%s", "end_key": "%s"}, {"start_key":"%s", "end_key": "%s"}]`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3")), hex.EncodeToString([]byte("a4")), hex.EncodeToString([]byte("a6"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule/batch", urlPrefix), []byte(body), + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/accelerate-schedule/batch", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) idList := leader.GetRaftCluster().GetSuspectRegions() @@ -239,7 +239,7 @@ func (suite *regionTestSuite) checkScatterRegions(cluster *tests.TestCluster) { checkRegionCount(re, cluster, 3) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("b1")), hex.EncodeToString([]byte("b3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) oc := leader.GetRaftCluster().GetOperatorController() if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { @@ -253,7 +253,7 @@ func (suite *regionTestSuite) checkScatterRegions(cluster *tests.TestCluster) { re.True(op1 != nil || op2 != nil || op3 != nil) body = `{"regions_id": [701, 702, 703]}` - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) } @@ -263,7 +263,7 @@ func (suite *regionTestSuite) TestCheckRegionsReplicated() { func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) { re := suite.Require() - pauseRuleChecker(re, cluster) + pauseAllCheckers(re, cluster) leader := cluster.GetLeaderServer() urlPrefix := leader.GetAddr() + "/pd/api/v1" @@ -295,40 +295,40 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) // invalid url url := fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, "_", "t") - err := tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + err := tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) re.NoError(err) url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString(r1.GetStartKey()), "_") - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) re.NoError(err) // correct test url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString(r1.GetStartKey()), hex.EncodeToString(r1.GetEndKey())) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re)) re.NoError(err) // test one rule data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = 
tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) return len(respBundle) == 1 && respBundle[0].ID == "5" }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/handler/mockPending", "return(true)")) - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) re.Equal("PENDING", status) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/handler/mockPending")) @@ -342,19 +342,19 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) data, err = json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) return len(respBundle) == 1 && len(respBundle[0].Rules) == 2 }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) @@ -371,12 +371,12 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) data, err = json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) if len(respBundle) != 2 { @@ -388,7 +388,7 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "INPROGRESS" }) @@ -398,7 +398,7 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) tests.MustPutRegionInfo(re, cluster, r1) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) @@ -416,15 +416,16 @@ func checkRegionCount(re *require.Assertions, cluster *tests.TestCluster, count } } -// pauseRuleChecker will pause rule checker to avoid unexpected operator. 
-func pauseRuleChecker(re *require.Assertions, cluster *tests.TestCluster) { - checkerName := "rule" +func pauseAllCheckers(re *require.Assertions, cluster *tests.TestCluster) { + checkerNames := []string{"learner", "replica", "rule", "split", "merge", "joint-state"} addr := cluster.GetLeaderServer().GetAddr() - resp := make(map[string]any) - url := fmt.Sprintf("%s/pd/api/v1/checker/%s", addr, checkerName) - err := tu.CheckPostJSON(testDialClient, url, []byte(`{"delay":1000}`), tu.StatusOK(re)) - re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, url, &resp) - re.NoError(err) - re.True(resp["paused"].(bool)) + for _, checkerName := range checkerNames { + resp := make(map[string]any) + url := fmt.Sprintf("%s/pd/api/v1/checker/%s", addr, checkerName) + err := tu.CheckPostJSON(tests.TestDialClient, url, []byte(`{"delay":1000}`), tu.StatusOK(re)) + re.NoError(err) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) + re.NoError(err) + re.True(resp["paused"].(bool)) + } } diff --git a/tests/server/api/rule_test.go b/tests/server/api/rule_test.go index 4f60b5cfb28..16077a308f6 100644 --- a/tests/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -71,7 +71,7 @@ func (suite *ruleTestSuite) TearDownTest() { data, err := json.Marshal([]placement.GroupBundle{def}) re.NoError(err) urlPrefix := cluster.GetLeaderServer().GetAddr() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) } suite.env.RunFuncInTwoModes(cleanFunc) @@ -171,7 +171,7 @@ func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { // clear suspect keyRanges to prevent test case from others leaderServer.GetRaftCluster().ClearSuspectKeyRanges() if testCase.success { - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange() @@ -185,7 +185,7 @@ func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { re.True(ok) } } else { - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) } @@ -206,7 +206,7 @@ func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "a", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -234,11 +234,11 @@ func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.rule.GroupID, testCase.rule.ID) if testCase.found { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) return compareRule(&resp, &testCase.rule) }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, 
nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -257,11 +257,11 @@ func (suite *ruleTestSuite) checkGetAll(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "b", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) var resp2 []*placement.Rule - err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/rules", &resp2) + err = tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix+"/rules", &resp2) re.NoError(err) re.NotEmpty(resp2) } @@ -369,13 +369,13 @@ func (suite *ruleTestSuite) checkSetAll(cluster *tests.TestCluster) { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) re.NoError(err) if testCase.isDefaultRule { re.Equal(int(leaderServer.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count) } } else { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules", testCase.rawData, tu.StringEqual(re, testCase.response)) re.NoError(err) } @@ -395,13 +395,13 @@ func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "c", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) rule1 := placement.Rule{GroupID: "c", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err = json.Marshal(rule1) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -426,7 +426,7 @@ func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) { var resp []*placement.Rule url := fmt.Sprintf("%s/rules/group/%s", urlPrefix, testCase.groupID) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) re.NoError(err) if len(resp) != testCase.count { return false @@ -452,7 +452,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "e", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) r := core.NewTestRegionInfo(4, 1, []byte{0x22, 0x22}, []byte{0x33, 0x33}) @@ -489,7 +489,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { if testCase.success { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) for _, r := range resp { if r.GroupID == "e" { 
return compareRule(r, &rule) @@ -498,7 +498,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { return true }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -517,7 +517,7 @@ func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "f", ID: "40", StartKeyHex: "8888", EndKeyHex: "9111", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -553,11 +553,11 @@ func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rules/key/%s", urlPrefix, testCase.key) if testCase.success { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) return len(resp) == testCase.respSize }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -576,7 +576,7 @@ func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "g", ID: "10", StartKeyHex: "8888", EndKeyHex: "9111", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) oldStartKey, err := hex.DecodeString(rule.StartKeyHex) re.NoError(err) @@ -610,7 +610,7 @@ func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.groupID, testCase.id) // clear suspect keyRanges to prevent test case from others leaderServer.GetRaftCluster().ClearSuspectKeyRanges() - err = tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) if len(testCase.popKeyRange) > 0 { popKeyRangeMap := map[string]struct{}{} @@ -747,10 +747,10 @@ func (suite *ruleTestSuite) checkBatch(cluster *tests.TestCluster) { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) re.NoError(err) } else { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) re.NoError(err) @@ -793,7 +793,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err := json.Marshal(b2) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) re.NoError(err) // Get @@ -803,7 +803,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { assertBundlesEqual(re, 
urlPrefix+"/placement-rule", []placement.GroupBundle{b1, b2}, 2) // Delete - err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(re)) re.NoError(err) // GetAll again @@ -815,14 +815,14 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { b3 := placement.GroupBundle{ID: "foobar", Index: 100} data, err = json.Marshal([]placement.GroupBundle{b1, b2, b3}) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) re.NoError(err) // GetAll again assertBundlesEqual(re, urlPrefix+"/placement-rule", []placement.GroupBundle{b1, b2, b3}, 3) // Delete using regexp - err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(re)) re.NoError(err) // GetAll again @@ -838,7 +838,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err = json.Marshal(b4) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) re.NoError(err) b4.ID = id @@ -859,7 +859,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err = json.Marshal([]placement.GroupBundle{b1, b4, b5}) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) re.NoError(err) b5.Rules[0].GroupID = b5.ID @@ -891,7 +891,7 @@ func (suite *ruleTestSuite) checkBundleBadRequest(cluster *tests.TestCluster) { {"/placement-rule", `[{"group_id":"foo", "rules": [{"group_id":"bar", "id":"baz", "role":"voter", "count":1}]}]`, false}, } for _, testCase := range testCases { - err := tu.CheckPostJSON(testDialClient, urlPrefix+testCase.uri, []byte(testCase.data), + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+testCase.uri, []byte(testCase.data), func(_ []byte, code int, _ http.Header) { re.Equal(testCase.ok, code == http.StatusOK) }) @@ -976,12 +976,12 @@ func (suite *ruleTestSuite) checkLeaderAndVoter(cluster *tests.TestCluster) { for _, bundle := range bundles { data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err := tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) re.Len(respBundle, 1) @@ -1144,7 +1144,7 @@ func (suite *ruleTestSuite) checkConcurrencyWith(cluster *tests.TestCluster, re.NoError(err) for j := 0; j < 10; j++ { expectResult.Lock() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) 
re.NoError(err) expectResult.val = i expectResult.Unlock() @@ -1158,7 +1158,7 @@ func (suite *ruleTestSuite) checkConcurrencyWith(cluster *tests.TestCluster, re.NotZero(expectResult.val) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err := tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) re.Len(respBundle, 1) @@ -1197,7 +1197,7 @@ func (suite *ruleTestSuite) checkLargeRules(cluster *tests.TestCluster) { func assertBundleEqual(re *require.Assertions, url string, expectedBundle placement.GroupBundle) { var bundle placement.GroupBundle tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, url, &bundle) + err := tu.ReadGetJSON(re, tests.TestDialClient, url, &bundle) if err != nil { return false } @@ -1208,7 +1208,7 @@ func assertBundleEqual(re *require.Assertions, url string, expectedBundle placem func assertBundlesEqual(re *require.Assertions, url string, expectedBundles []placement.GroupBundle, expectedLen int) { var bundles []placement.GroupBundle tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, url, &bundles) + err := tu.ReadGetJSON(re, tests.TestDialClient, url, &bundles) if err != nil { return false } @@ -1253,12 +1253,12 @@ func (suite *ruleTestSuite) postAndCheckRuleBundle(urlPrefix string, bundle []pl re := suite.Require() data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) if len(respBundle) != len(bundle) { @@ -1364,19 +1364,19 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl fit := &placement.RegionFit{} u := fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) - err := tu.ReadGetJSON(re, testDialClient, u, fit) + err := tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Len(fit.RuleFits, 1) re.Len(fit.OrphanPeers, 1) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 2) fit = &placement.RegionFit{} - err = tu.ReadGetJSON(re, testDialClient, u, fit) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Len(fit.RuleFits, 2) re.Empty(fit.OrphanPeers) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 3) fit = &placement.RegionFit{} - err = tu.ReadGetJSON(re, testDialClient, u, fit) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Empty(fit.RuleFits) re.Len(fit.OrphanPeers, 2) @@ -1384,26 +1384,26 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl var label labeler.LabelRule escapedID := url.PathEscape("keyspaces/0") u = fmt.Sprintf("%s/config/region-label/rule/%s", urlPrefix, escapedID) - err = tu.ReadGetJSON(re, testDialClient, u, &label) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &label) re.NoError(err) re.Equal("keyspaces/0", label.ID) var labels []labeler.LabelRule u = fmt.Sprintf("%s/config/region-label/rules", urlPrefix) - 
err = tu.ReadGetJSON(re, testDialClient, u, &labels) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &labels) re.NoError(err) re.Len(labels, 1) re.Equal("keyspaces/0", labels[0].ID) u = fmt.Sprintf("%s/config/region-label/rules/ids", urlPrefix) - err = tu.CheckGetJSON(testDialClient, u, []byte(`["rule1", "rule3"]`), func(resp []byte, _ int, _ http.Header) { + err = tu.CheckGetJSON(tests.TestDialClient, u, []byte(`["rule1", "rule3"]`), func(resp []byte, _ int, _ http.Header) { err := json.Unmarshal(resp, &labels) re.NoError(err) re.Empty(labels) }) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, u, []byte(`["keyspaces/0"]`), func(resp []byte, _ int, _ http.Header) { + err = tu.CheckGetJSON(tests.TestDialClient, u, []byte(`["keyspaces/0"]`), func(resp []byte, _ int, _ http.Header) { err := json.Unmarshal(resp, &labels) re.NoError(err) re.Len(labels, 1) @@ -1412,12 +1412,12 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl re.NoError(err) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 4) - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusNotFound), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusNotFound), tu.StringContain( re, "region 4 not found")) re.NoError(err) u = fmt.Sprintf("%s/config/rules/region/%s/detail", urlPrefix, "id") - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain( re, errs.ErrRegionInvalidID.Error())) re.NoError(err) @@ -1426,7 +1426,7 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl reqData, e := json.Marshal(data) re.NoError(e) u = fmt.Sprintf("%s/config", urlPrefix) - err = tu.CheckPostJSON(testDialClient, u, reqData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, reqData, tu.StatusOK(re)) re.NoError(err) if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // wait for the scheduling server to update the config @@ -1435,7 +1435,7 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl }) } u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain( re, "placement rules feature is disabled")) re.NoError(err) } diff --git a/tests/server/api/scheduler_test.go b/tests/server/api/scheduler_test.go index 4f71315803a..10631dab158 100644 --- a/tests/server/api/scheduler_test.go +++ b/tests/server/api/scheduler_test.go @@ -84,12 +84,12 @@ func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { input["store_id"] = 1 body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, "evict-leader-scheduler") - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 1) input1 := make(map[string]any) 
input1["name"] = "evict-leader-scheduler" @@ -97,35 +97,35 @@ func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { body, err = json.Marshal(input1) re.NoError(err) re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/schedulers/persistFail", "return(true)")) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusNotOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusNotOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp = make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 1) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/schedulers/persistFail")) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp = make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 2) deleteURL := fmt.Sprintf("%s/%s", urlPrefix, "evict-leader-scheduler-1") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp1 := make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp1)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp1)) re.Len(resp1["store-id-ranges"], 1) deleteURL = fmt.Sprintf("%s/%s", urlPrefix, "evict-leader-scheduler-2") re.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistFail", "return(true)")) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusInternalServerError)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusInternalServerError)) re.NoError(err) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") re.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistFail")) - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) assertNoScheduler(re, urlPrefix, "evict-leader-scheduler") - re.NoError(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(re, http.StatusNotFound))) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + re.NoError(tu.CheckGetJSON(tests.TestDialClient, listURL, nil, tu.Status(re, http.StatusNotFound))) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) } @@ -164,7 +164,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) resp := make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 4.0 }) dataMap := make(map[string]any) @@ -172,15 +172,15 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := 
json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { // wait for scheduling server to be synced. - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "\"Config is the same with origin, so do nothing.\"\n")) re.NoError(err) @@ -189,17 +189,17 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["batch"] = 100 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid batch size which should be an integer between 1 and 10\"\n")) re.NoError(err) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -208,7 +208,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"Config item is not found.\"\n")) re.NoError(err) @@ -245,7 +245,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { "history-sample-interval": "30s", } tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Equal(len(expectMap), len(resp), "expect %v, got %v", expectMap, resp) for key := range expectMap { if !reflect.DeepEqual(resp[key], expectMap[key]) { @@ -260,10 +260,10 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) for key := range expectMap { if !reflect.DeepEqual(resp[key], expectMap[key]) { return false @@ -273,7 +273,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "Config is the same with origin, so do nothing.")) re.NoError(err) @@ -282,7 +282,7 @@ func 
(suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "Config item is not found.")) re.NoError(err) @@ -295,7 +295,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) resp := make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["degree"] == 3.0 && resp["split-limit"] == 0.0 }) dataMap := make(map[string]any) @@ -303,19 +303,19 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["degree"] == 4.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "Config is the same with origin, so do nothing.")) re.NoError(err) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -324,7 +324,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "Config item is not found.")) re.NoError(err) @@ -353,7 +353,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 4.0 }) dataMap := make(map[string]any) @@ -361,14 +361,14 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, 
body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "\"Config is the same with origin, so do nothing.\"\n")) re.NoError(err) @@ -377,17 +377,17 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["batch"] = 100 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid batch size which should be an integer between 1 and 10\"\n")) re.NoError(err) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -396,7 +396,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"Config item is not found.\"\n")) re.NoError(err) @@ -412,7 +412,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { expectedMap := make(map[string]any) expectedMap["1"] = []any{map[string]any{"end-key": "", "start-key": ""}} tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) @@ -423,25 +423,25 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) expectedMap["2"] = []any{map[string]any{"end-key": "", "start-key": ""}} resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) // using /pd/v1/schedule-config/grant-leader-scheduler/config to delete exists store from grant-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name, "2") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) delete(expectedMap, "2") resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) }, }, @@ 
-454,7 +454,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["start-key"] == "" && resp["end-key"] == "" && resp["range-name"] == "test" }) resp["start-key"] = "a_00" @@ -462,10 +462,10 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(resp) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["start-key"] == "a_00" && resp["end-key"] == "a_99" && resp["range-name"] == "test" }) }, @@ -481,7 +481,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { expectedMap := make(map[string]any) expectedMap["3"] = []any{map[string]any{"end-key": "", "start-key": ""}} tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) @@ -492,25 +492,25 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) expectedMap["4"] = []any{map[string]any{"end-key": "", "start-key": ""}} resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) // using /pd/v1/schedule-config/evict-leader-scheduler/config to delete exist store from evict-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name, "4") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) delete(expectedMap, "4") resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) }, }, @@ -558,7 +558,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = 
tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) for _, testCase := range testCases { @@ -572,7 +572,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) for _, testCase := range testCases { @@ -588,12 +588,12 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) for _, testCase := range testCases { createdName := testCase.createdName @@ -642,14 +642,14 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { u := fmt.Sprintf("%s%s/api/v1/config/schedule", leaderAddr, apiPrefix) var scheduleConfig sc.ScheduleConfig - err = tu.ReadGetJSON(re, testDialClient, u, &scheduleConfig) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &scheduleConfig) re.NoError(err) originSchedulers := scheduleConfig.Schedulers scheduleConfig.Schedulers = sc.SchedulerConfigs{sc.SchedulerConfig{Type: "shuffle-leader", Disable: true}} body, err = json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, body, tu.StatusOK(re)) re.NoError(err) assertNoScheduler(re, urlPrefix, name) @@ -659,7 +659,7 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { scheduleConfig.Schedulers = originSchedulers body, err = json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, body, tu.StatusOK(re)) re.NoError(err) deleteScheduler(re, urlPrefix, name) @@ -667,13 +667,13 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { } func addScheduler(re *require.Assertions, urlPrefix string, body []byte) { - err := tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re)) re.NoError(err) } func deleteScheduler(re *require.Assertions, urlPrefix string, createdName string) { deleteURL := fmt.Sprintf("%s/%s", urlPrefix, createdName) - err := tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err := tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) } @@ -682,9 +682,9 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre createdName = name } var schedulers []string - tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers) + tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers) if !slice.Contains(schedulers, createdName) { - err := tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re)) 
re.NoError(err) } suite.assertSchedulerExists(urlPrefix, createdName) // wait for scheduler to be synced. @@ -694,14 +694,14 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) isPaused := isSchedulerPaused(re, urlPrefix, createdName) re.True(isPaused) input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second * 2) isPaused = isSchedulerPaused(re, urlPrefix, createdName) @@ -712,12 +712,12 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) isPaused = isSchedulerPaused(re, urlPrefix, createdName) re.False(isPaused) @@ -742,7 +742,7 @@ func (suite *scheduleTestSuite) checkEmptySchedulers(cluster *tests.TestCluster) } for _, query := range []string{"", "?status=paused", "?status=disabled"} { schedulers := make([]string, 0) - re.NoError(tu.ReadGetJSON(re, testDialClient, urlPrefix+query, &schedulers)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix+query, &schedulers)) for _, scheduler := range schedulers { if strings.Contains(query, "disable") { input := make(map[string]any) @@ -755,7 +755,7 @@ func (suite *scheduleTestSuite) checkEmptySchedulers(cluster *tests.TestCluster) } } tu.Eventually(re, func() bool { - resp, err := apiutil.GetJSON(testDialClient, urlPrefix+query, nil) + resp, err := apiutil.GetJSON(tests.TestDialClient, urlPrefix+query, nil) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -770,7 +770,7 @@ func (suite *scheduleTestSuite) assertSchedulerExists(urlPrefix string, schedule var schedulers []string re := suite.Require() tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers, tu.StatusOK(re)) re.NoError(err) return slice.Contains(schedulers, scheduler) @@ -780,7 +780,7 @@ func (suite *scheduleTestSuite) assertSchedulerExists(urlPrefix string, schedule func assertNoScheduler(re *require.Assertions, urlPrefix string, scheduler string) { var schedulers []string tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers, tu.StatusOK(re)) re.NoError(err) return !slice.Contains(schedulers, scheduler) @@ -789,7 +789,7 @@ func assertNoScheduler(re *require.Assertions, urlPrefix string, scheduler strin func isSchedulerPaused(re *require.Assertions, 
urlPrefix, name string) bool { var schedulers []string - err := tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s?status=paused", urlPrefix), &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s?status=paused", urlPrefix), &schedulers, tu.StatusOK(re)) re.NoError(err) for _, scheduler := range schedulers { diff --git a/tests/server/api/testutil.go b/tests/server/api/testutil.go index 1b2f3d09e3d..163a25c9bbb 100644 --- a/tests/server/api/testutil.go +++ b/tests/server/api/testutil.go @@ -23,6 +23,7 @@ import ( "path" "github.com/stretchr/testify/require" + "github.com/tikv/pd/tests" ) const ( @@ -30,13 +31,6 @@ const ( schedulerConfigPrefix = "/pd/api/v1/scheduler-config" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - // MustAddScheduler adds a scheduler with HTTP API. func MustAddScheduler( re *require.Assertions, serverAddr string, @@ -53,7 +47,7 @@ func MustAddScheduler( httpReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s", serverAddr, schedulersPrefix), bytes.NewBuffer(data)) re.NoError(err) // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -65,7 +59,7 @@ func MustAddScheduler( func MustDeleteScheduler(re *require.Assertions, serverAddr, schedulerName string) { httpReq, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s%s/%s", serverAddr, schedulersPrefix, schedulerName), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -84,7 +78,7 @@ func MustCallSchedulerConfigAPI( args = append([]string{schedulerConfigPrefix, schedulerName}, args...) httpReq, err := http.NewRequest(method, fmt.Sprintf("%s%s", serverAddr, path.Join(args...)), bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) diff --git a/tests/server/apiv2/handlers/testutil.go b/tests/server/apiv2/handlers/testutil.go index c5682aafbce..1a40e8d1ac7 100644 --- a/tests/server/apiv2/handlers/testutil.go +++ b/tests/server/apiv2/handlers/testutil.go @@ -34,13 +34,6 @@ const ( keyspaceGroupsPrefix = "/pd/api/v2/tso/keyspace-groups" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - func sendLoadRangeRequest(re *require.Assertions, server *tests.TestServer, token, limit string) *handlers.LoadAllKeyspacesResponse { // Construct load range request. httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+keyspacesPrefix, http.NoBody) @@ -50,7 +43,7 @@ func sendLoadRangeRequest(re *require.Assertions, server *tests.TestServer, toke query.Add("limit", limit) httpReq.URL.RawQuery = query.Encode() // Send request. 
- httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() re.Equal(http.StatusOK, httpResp.StatusCode) @@ -67,7 +60,7 @@ func sendUpdateStateRequest(re *require.Assertions, server *tests.TestServer, na re.NoError(err) httpReq, err := http.NewRequest(http.MethodPut, server.GetAddr()+keyspacesPrefix+"/"+name+"/state", bytes.NewBuffer(data)) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() if httpResp.StatusCode != http.StatusOK { @@ -86,7 +79,7 @@ func MustCreateKeyspace(re *require.Assertions, server *tests.TestServer, reques re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspacesPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -110,7 +103,7 @@ func mustUpdateKeyspaceConfig(re *require.Assertions, server *tests.TestServer, re.NoError(err) httpReq, err := http.NewRequest(http.MethodPatch, server.GetAddr()+keyspacesPrefix+"/"+name+"/config", bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -122,7 +115,7 @@ func mustUpdateKeyspaceConfig(re *require.Assertions, server *tests.TestServer, } func mustLoadKeyspaces(re *require.Assertions, server *tests.TestServer, name string) *keyspacepb.KeyspaceMeta { - resp, err := dialClient.Get(server.GetAddr() + keyspacesPrefix + "/" + name) + resp, err := tests.TestDialClient.Get(server.GetAddr() + keyspacesPrefix + "/" + name) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -143,7 +136,7 @@ func MustLoadKeyspaceGroups(re *require.Assertions, server *tests.TestServer, to query.Add("limit", limit) httpReq.URL.RawQuery = query.Encode() // Send request. 
- httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() data, err := io.ReadAll(httpResp.Body) @@ -159,7 +152,7 @@ func tryCreateKeyspaceGroup(re *require.Assertions, server *tests.TestServer, re re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -184,7 +177,7 @@ func MustLoadKeyspaceGroupByID(re *require.Assertions, server *tests.TestServer, func TryLoadKeyspaceGroupByID(re *require.Assertions, server *tests.TestServer, id uint32) (*endpoint.KeyspaceGroup, int) { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -214,7 +207,7 @@ func FailCreateKeyspaceGroupWithCode(re *require.Assertions, server *tests.TestS func MustDeleteKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id uint32) { httpReq, err := http.NewRequest(http.MethodDelete, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -229,7 +222,7 @@ func MustSplitKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/split", id), bytes.NewBuffer(data)) re.NoError(err) // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -245,7 +238,7 @@ func MustFinishSplitKeyspaceGroup(re *require.Assertions, server *tests.TestServ return false } // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) if err != nil { return false } @@ -270,7 +263,7 @@ func MustMergeKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/merge", id), bytes.NewBuffer(data)) re.NoError(err) // Send request. 
- resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index aea5ff73968..61a4561c55a 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -753,20 +753,19 @@ func TestConcurrentHandleRegion(t *testing.T) { re.NoError(err) peerID, err := id.Alloc() re.NoError(err) - regionID, err := id.Alloc() - re.NoError(err) peer := &metapb.Peer{Id: peerID, StoreId: store.GetId()} regionReq := &pdpb.RegionHeartbeatRequest{ Header: testutil.NewRequestHeader(clusterID), Region: &metapb.Region{ - Id: regionID, + // mock an error message to trigger stream.Recv() + Id: 0, Peers: []*metapb.Peer{peer}, }, Leader: peer, } err = stream.Send(regionReq) re.NoError(err) - // make sure the first store can receive one response + // make sure the first store can receive one response (error message) if i == 0 { wg.Add(1) } diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index 57e4272f7ea..67d7478caa0 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -36,13 +36,6 @@ import ( "github.com/tikv/pd/tests" ) -// testDialClient used to dial http request. -var testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - func TestRateLimitConfigReload(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -65,7 +58,7 @@ func TestRateLimitConfigReload(t *testing.T) { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := testDialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) @@ -109,7 +102,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { addr := fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) cfg := &config.Config{} tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, addr, cfg) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, cfg) re.NoError(err) return cfg.PDServerCfg.DashboardAddress != "auto" }) @@ -118,7 +111,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l := map[string]any{ "location-labels": "zone,rack", } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l = map[string]any{ @@ -134,7 +127,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) cfg.Replication.MaxReplicas = 5 cfg.Replication.LocationLabels = []string{"zone", "rack"} @@ -143,7
+136,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { tu.Eventually(re, func() bool { newCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, newCfg) re.NoError(err) return suite.Equal(newCfg, cfg) }) @@ -160,7 +153,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) cfg.Schedule.EnableTiKVSplitRegion = false cfg.Schedule.TolerantSizeRatio = 2.5 @@ -174,7 +167,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { cfg.ClusterVersion = *v tu.Eventually(re, func() bool { newCfg1 := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg1) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, newCfg1) re.NoError(err) return suite.Equal(cfg, newCfg1) }) @@ -183,7 +176,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { l["schedule.enable-tikv-split-region"] = "true" postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) // illegal prefix @@ -192,7 +185,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) re.NoError(err) @@ -203,7 +196,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "cannot update config prefix")) re.NoError(err) @@ -214,7 +207,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) re.NoError(err) } @@ -230,16 +223,16 @@ func (suite *configTestSuite) checkConfigSchedule(cluster *tests.TestCluster) { addr := fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) scheduleConfig := &sc.ScheduleConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addr, scheduleConfig)) scheduleConfig.MaxStoreDownTime.Duration = time.Second postData, err := json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { scheduleConfig1 := &sc.ScheduleConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig1)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addr, scheduleConfig1)) return reflect.DeepEqual(*scheduleConfig1, *scheduleConfig) }) } @@ -255,33 +248,33 @@ func (suite *configTestSuite) checkConfigReplication(cluster *tests.TestCluster) 
addr := fmt.Sprintf("%s/pd/api/v1/config/replicate", urlPrefix) rc := &sc.ReplicationConfig{} - err := tu.ReadGetJSON(re, testDialClient, addr, rc) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, rc) re.NoError(err) rc.MaxReplicas = 5 rc1 := map[string]int{"max-replicas": 5} postData, err := json.Marshal(rc1) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc.LocationLabels = []string{"zone", "rack"} rc2 := map[string]string{"location-labels": "zone,rack"} postData, err = json.Marshal(rc2) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc.IsolationLevel = "zone" rc3 := map[string]string{"isolation-level": "zone"} postData, err = json.Marshal(rc3) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc4 := &sc.ReplicationConfig{} tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, addr, rc4) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, rc4) re.NoError(err) return reflect.DeepEqual(*rc4, *rc) }) @@ -299,7 +292,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste addr := urlPrefix + "/pd/api/v1/config/label-property" loadProperties := func() config.LabelPropertyConfig { var cfg config.LabelPropertyConfig - err := tu.ReadGetJSON(re, testDialClient, addr, &cfg) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, &cfg) re.NoError(err) return cfg } @@ -313,7 +306,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste `{"type": "bar", "action": "set", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, addr, []byte(cmd), tu.StatusOK(re)) re.NoError(err) } @@ -330,7 +323,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste `{"type": "bar", "action": "delete", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, addr, []byte(cmd), tu.StatusOK(re)) re.NoError(err) } @@ -353,7 +346,7 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l := map[string]any{ "location-labels": "zone,rack", @@ -361,7 +354,7 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l = map[string]any{ @@ -369,12 +362,12 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, 
addr, postData, tu.StatusOK(re)) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config/default", urlPrefix) defaultCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, defaultCfg) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, defaultCfg) re.NoError(err) re.Equal(uint64(3), defaultCfg.Replication.MaxReplicas) @@ -398,10 +391,10 @@ func (suite *configTestSuite) checkConfigPDServer(cluster *tests.TestCluster) { } postData, err := json.Marshal(ms) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, addrPost, postData, tu.StatusOK(re))) addrGet := fmt.Sprintf("%s/pd/api/v1/config/pd-server", urlPrefix) sc := &config.PDServerConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addrGet, sc)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addrGet, sc)) re.Equal(bool(true), sc.UseRegionStorage) re.Equal("table", sc.KeyType) re.Equal(typeutil.StringSlice([]string{}), sc.RuntimeServices) @@ -525,28 +518,28 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { re.NoError(err) // test no config and cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, false) // test time goes by - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) time.Sleep(5 * time.Second) assertTTLConfig(re, cluster, false) // test cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, false) postData, err = json.Marshal(invalidTTLConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) re.NoError(err) @@ -557,7 +550,7 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { postData, err = json.Marshal(mergeConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfigItemEqual(re, cluster, "max-merge-region-size", uint64(999)) // max-merge-region-keys should stay consistent with max-merge-region-size.
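For context on what these checkConfigTTL hunks exercise: PD accepts config updates with a time-to-live, applies them immediately, and reverts them once the TTL lapses, which is why the test sleeps past the deadline and then re-asserts. The sketch below approximates the request that the tu.CheckPostJSON calls issue; the ?ttlSecond=N URL shape is an assumption inferred from the createTTLUrl call sites, since the helper's body is not part of this diff.

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

// postConfigWithTTL posts a JSON config patch with a TTL, mirroring what
// checkConfigTTL drives through tu.CheckPostJSON. The ttlSecond query
// parameter is an assumption based on the createTTLUrl call sites above;
// the posted item is expected to revert once the TTL expires.
func postConfigWithTTL(client *http.Client, urlPrefix string, ttlSeconds int, body []byte) error {
	url := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=%d", urlPrefix, ttlSeconds)
	resp, err := client.Post(url, "application/json", bytes.NewBuffer(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status: %s", resp.Status)
	}
	return nil
}
```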
@@ -569,7 +562,7 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { } postData, err = json.Marshal(mergeConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 10), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 10), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfigItemEqual(re, cluster, "enable-tikv-split-region", true) } @@ -585,7 +578,7 @@ func (suite *configTestSuite) checkTTLConflict(cluster *tests.TestCluster) { addr := createTTLUrl(urlPrefix, 1) postData, err := json.Marshal(ttlConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) @@ -593,16 +586,16 @@ func (suite *configTestSuite) checkTTLConflict(cluster *tests.TestCluster) { postData, err = json.Marshal(cfg) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) re.NoError(err) cfg = map[string]any{"schedule.max-snapshot-count": 30} postData, err = json.Marshal(cfg) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 92ed11a75ce..c581eb39390 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -84,15 +84,13 @@ func TestMemberDelete(t *testing.T) { {path: fmt.Sprintf("id/%d", members[1].GetServerID()), members: []*config.Config{leader.GetConfig()}}, } - httpClient := &http.Client{Timeout: 15 * time.Second, Transport: &http.Transport{DisableKeepAlives: true}} - defer httpClient.CloseIdleConnections() for _, table := range tables { t.Log(time.Now(), "try to delete:", table.path) testutil.Eventually(re, func() bool { addr := leader.GetConfig().ClientUrls + "/pd/api/v1/members/" + table.path req, err := http.NewRequest(http.MethodDelete, addr, http.NoBody) re.NoError(err) - res, err := httpClient.Do(req) + res, err := tests.TestDialClient.Do(req) re.NoError(err) defer res.Body.Close() // Check by status. @@ -105,7 +103,7 @@ func TestMemberDelete(t *testing.T) { } // Check by member list. 
cluster.WaitLeader() - if err = checkMemberList(re, *httpClient, leader.GetConfig().ClientUrls, table.members); err != nil { + if err = checkMemberList(re, leader.GetConfig().ClientUrls, table.members); err != nil { t.Logf("check member fail: %v", err) time.Sleep(time.Second) return false @@ -122,9 +120,9 @@ func TestMemberDelete(t *testing.T) { } } -func checkMemberList(re *require.Assertions, httpClient http.Client, clientURL string, configs []*config.Config) error { +func checkMemberList(re *require.Assertions, clientURL string, configs []*config.Config) error { addr := clientURL + "/pd/api/v1/members" - res, err := httpClient.Get(addr) + res, err := tests.TestDialClient.Get(addr) re.NoError(err) defer res.Body.Close() buf, err := io.ReadAll(res.Body) @@ -183,7 +181,7 @@ func TestLeaderPriority(t *testing.T) { func post(t *testing.T, re *require.Assertions, url string, body string) { testutil.Eventually(re, func() bool { - res, err := http.Post(url, "", bytes.NewBufferString(body)) // #nosec + res, err := tests.TestDialClient.Post(url, "", bytes.NewBufferString(body)) // #nosec re.NoError(err) b, err := io.ReadAll(res.Body) res.Body.Close() diff --git a/tests/testutil.go b/tests/testutil.go index 495dd547c4f..79917bf9961 100644 --- a/tests/testutil.go +++ b/tests/testutil.go @@ -17,8 +17,12 @@ package tests import ( "context" "fmt" + "math/rand" + "net" + "net/http" "os" "runtime" + "strconv" "strings" "sync" "testing" @@ -45,6 +49,45 @@ import ( "go.uber.org/zap" ) +var ( + TestDialClient = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } + + testPortMutex sync.Mutex + testPortMap = make(map[string]struct{}) +) + +// SetRangePort sets the range of ports for test. +func SetRangePort(start, end int) { + portRange := []int{start, end} + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{} + randomPort := strconv.Itoa(rand.Intn(portRange[1]-portRange[0]) + portRange[0]) + testPortMutex.Lock() + for i := 0; i < 10; i++ { + if _, ok := testPortMap[randomPort]; !ok { + break + } + randomPort = strconv.Itoa(rand.Intn(portRange[1]-portRange[0]) + portRange[0]) + } + testPortMutex.Unlock() + localAddr, err := net.ResolveTCPAddr(network, "0.0.0.0:"+randomPort) + if err != nil { + return nil, err + } + dialer.LocalAddr = localAddr + return dialer.DialContext(ctx, network, addr) + } + + TestDialClient.Transport = &http.Transport{ + DisableKeepAlives: true, + DialContext: dialContext, + } +} + var once sync.Once // InitLogger initializes the logger for test. 
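One caveat in the SetRangePort helper added above: testPortMap is consulted inside the retry loop but never written to, so the collision check can never exclude a port that an earlier dial already picked. The sketch below shows the bookkeeping the loop appears to intend, reusing the package-level variables and imports from the hunk; it is an illustrative fix, not part of this PR.

```go
// reservePort picks a local port in [start, end) that has not been handed
// out yet and records it, so concurrent dials cannot collide. As written in
// the hunk above, testPortMap stays empty forever and the check is a no-op.
func reservePort(start, end int) string {
	testPortMutex.Lock()
	defer testPortMutex.Unlock()
	port := strconv.Itoa(rand.Intn(end-start) + start)
	for i := 0; i < 10; i++ {
		if _, ok := testPortMap[port]; !ok {
			break
		}
		port = strconv.Itoa(rand.Intn(end-start) + start)
	}
	testPortMap[port] = struct{}{} // reserve the port for this dial
	return port
}
```

A ten-try cap can still collide on a narrow range; releasing ports once their connections close, or binding to port 0 and letting the kernel choose, would be the more robust variants.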
diff --git a/tools/go.mod b/tools/go.mod index 8d0f0d4ec35..2febbe1ad68 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -35,6 +35,7 @@ require ( go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/text v0.14.0 + golang.org/x/tools v0.14.0 google.golang.org/grpc v1.62.1 ) @@ -172,7 +173,6 @@ require ( golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.14.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect diff --git a/tools/pd-ctl/tests/keyspace/keyspace_group_test.go b/tools/pd-ctl/tests/keyspace/keyspace_group_test.go index 87fd17a97d4..1e3763d5d6e 100644 --- a/tools/pd-ctl/tests/keyspace/keyspace_group_test.go +++ b/tools/pd-ctl/tests/keyspace/keyspace_group_test.go @@ -263,7 +263,7 @@ func TestSetNodeAndPriorityKeyspaceGroup(t *testing.T) { args := []string{"-u", pdAddr, "keyspace-group", "set-node", defaultKeyspaceGroupID, tsoAddrs[0]} output, err := tests.ExecuteCommand(cmd, args...) re.NoError(err) - re.Contains(string(output), "invalid num of nodes") + re.Contains(string(output), "Success!") args = []string{"-u", pdAddr, "keyspace-group", "set-node", defaultKeyspaceGroupID, "", ""} output, err = tests.ExecuteCommand(cmd, args...) re.NoError(err) diff --git a/tools/pd-ut/README.md b/tools/pd-ut/README.md index 77b59bea4f7..805ee5cf322 100644 --- a/tools/pd-ut/README.md +++ b/tools/pd-ut/README.md @@ -63,4 +63,8 @@ pd-ut run --junitfile xxx // test with race flag pd-ut run --race + +// test with coverprofile +pd-ut run --coverprofile xxx +go tool cover --func=xxx ``` diff --git a/tools/pd-ut/coverProfile.go b/tools/pd-ut/coverProfile.go new file mode 100644 index 00000000000..0ed1c3f3c61 --- /dev/null +++ b/tools/pd-ut/coverProfile.go @@ -0,0 +1,176 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "fmt" + "os" + "path" + "sort" + + "golang.org/x/tools/cover" +) + +func collectCoverProfileFile() { + // Combine the cover files of all single test functions into a whole.
+ files, err := os.ReadDir(coverFileTempDir) + if err != nil { + fmt.Println("collect cover file error:", err) + os.Exit(-1) + } + + w, err := os.Create(coverProfile) + if err != nil { + fmt.Println("create cover file error:", err) + os.Exit(-1) + } + //nolint: errcheck + defer w.Close() + w.WriteString("mode: atomic\n") + + result := make(map[string]*cover.Profile) + for _, file := range files { + if file.IsDir() { + continue + } + collectOneCoverProfileFile(result, file) + } + + w1 := bufio.NewWriter(w) + for _, prof := range result { + for _, block := range prof.Blocks { + fmt.Fprintf(w1, "%s:%d.%d,%d.%d %d %d\n", + prof.FileName, + block.StartLine, + block.StartCol, + block.EndLine, + block.EndCol, + block.NumStmt, + block.Count, + ) + } + if err := w1.Flush(); err != nil { + fmt.Println("flush data to cover profile file error:", err) + os.Exit(-1) + } + } +} + +func collectOneCoverProfileFile(result map[string]*cover.Profile, file os.DirEntry) { + f, err := os.Open(path.Join(coverFileTempDir, file.Name())) + if err != nil { + fmt.Println("open temp cover file error:", err) + os.Exit(-1) + } + //nolint: errcheck + defer f.Close() + + profs, err := cover.ParseProfilesFromReader(f) + if err != nil { + fmt.Println("parse cover profile file error:", err) + os.Exit(-1) + } + mergeProfile(result, profs) +} + +func mergeProfile(m map[string]*cover.Profile, profs []*cover.Profile) { + for _, prof := range profs { + sort.Sort(blocksByStart(prof.Blocks)) + old, ok := m[prof.FileName] + if !ok { + m[prof.FileName] = prof + continue + } + + // Merge samples from the same location. + // The data has already been sorted. + tmp := old.Blocks[:0] + var i, j int + for i < len(old.Blocks) && j < len(prof.Blocks) { + v1 := old.Blocks[i] + v2 := prof.Blocks[j] + + switch compareProfileBlock(v1, v2) { + case -1: + tmp = appendWithReduce(tmp, v1) + i++ + case 1: + tmp = appendWithReduce(tmp, v2) + j++ + default: + tmp = appendWithReduce(tmp, v1) + tmp = appendWithReduce(tmp, v2) + i++ + j++ + } + } + for ; i < len(old.Blocks); i++ { + tmp = appendWithReduce(tmp, old.Blocks[i]) + } + for ; j < len(prof.Blocks); j++ { + tmp = appendWithReduce(tmp, prof.Blocks[j]) + } + + m[prof.FileName] = old + } +} + +// appendWithReduce works like append(), but it merges duplicated values. +func appendWithReduce(input []cover.ProfileBlock, b cover.ProfileBlock) []cover.ProfileBlock { + if len(input) >= 1 { + last := &input[len(input)-1] + if b.StartLine == last.StartLine && + b.StartCol == last.StartCol && + b.EndLine == last.EndLine && + b.EndCol == last.EndCol { + if b.NumStmt != last.NumStmt { + panic(fmt.Errorf("inconsistent NumStmt: changed from %d to %d", last.NumStmt, b.NumStmt)) + } + // Merge the data with the last one of the slice.
+ last.Count |= b.Count + return input + } + } + return append(input, b) +} + +type blocksByStart []cover.ProfileBlock + +func compareProfileBlock(x, y cover.ProfileBlock) int { + if x.StartLine < y.StartLine { + return -1 + } + if x.StartLine > y.StartLine { + return 1 + } + + // Now x.StartLine == y.StartLine + if x.StartCol < y.StartCol { + return -1 + } + if x.StartCol > y.StartCol { + return 1 + } + + return 0 +} + +func (b blocksByStart) Len() int { return len(b) } +func (b blocksByStart) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b blocksByStart) Less(i, j int) bool { + bi, bj := b[i], b[j] + return bi.StartLine < bj.StartLine || bi.StartLine == bj.StartLine && bi.StartCol < bj.StartCol +} diff --git a/tools/pd-ut/ut.go b/tools/pd-ut/ut.go index 7fc96ee11cf..9419363c152 100644 --- a/tools/pd-ut/ut.go +++ b/tools/pd-ut/ut.go @@ -74,27 +74,49 @@ pd-ut build xxx pd-ut run --junitfile xxx // test with race flag -pd-ut run --race` +pd-ut run --race + +// test with coverprofile +pd-ut run --coverprofile xxx +go tool cover --func=xxx` fmt.Println(msg) return true } -const modulePath = "github.com/tikv/pd" +var ( + modulePath = "github.com/tikv/pd" + integrationsTestPath = "tests/integrations" +) var ( // runtime - p int - buildParallel int - workDir string + p int + buildParallel int + workDir string + coverFileTempDir string // arguments - race bool - junitFile string + race bool + junitFile string + coverProfile string + ignoreDir string ) func main() { race = handleFlag("--race") junitFile = stripFlag("--junitfile") + coverProfile = stripFlag("--coverprofile") + ignoreDir = stripFlag("--ignore") + + if coverProfile != "" { + var err error + coverFileTempDir, err = os.MkdirTemp(os.TempDir(), "cov") + if err != nil { + fmt.Println("create temp dir failed:", err) + os.Exit(1) + } + defer os.RemoveAll(coverFileTempDir) + } // Get the correct count of CPU if it's in docker. p = runtime.GOMAXPROCS(0) @@ -120,6 +142,18 @@ func main() { isSucceed = cmdBuild(os.Args[2:]...) case "run": isSucceed = cmdRun(os.Args[2:]...) + case "it": + // run integration tests + if len(os.Args) >= 3 { + modulePath = path.Join(modulePath, integrationsTestPath) + workDir = path.Join(workDir, integrationsTestPath) + switch os.Args[2] { + case "run": + isSucceed = cmdRun(os.Args[3:]...)
+ default: + isSucceed = usage() + } + } default: isSucceed = usage() } @@ -204,10 +238,16 @@ func cmdBuild(args ...string) bool { // build test binary of a single package if len(args) >= 1 { - pkg := args[0] - err := buildTestBinary(pkg) + var dirPkgs []string + for _, pkg := range pkgs { + if strings.Contains(pkg, args[0]) { + dirPkgs = append(dirPkgs, pkg) + } + } + + err := buildTestBinaryMulti(dirPkgs) if err != nil { - log.Println("build package error", pkg, err) + log.Println("build package error", dirPkgs, err) return false } } @@ -248,23 +288,32 @@ func cmdRun(args ...string) bool { // run tests for a single package if len(args) == 1 { - pkg := args[0] - err := buildTestBinary(pkg) - if err != nil { - log.Println("build package error", pkg, err) - return false + var dirPkgs []string + for _, pkg := range pkgs { + if strings.Contains(pkg, args[0]) { + dirPkgs = append(dirPkgs, pkg) + } } - exist, err := testBinaryExist(pkg) + + err := buildTestBinaryMulti(dirPkgs) if err != nil { - log.Println("check test binary existence error", err) + log.Println("build package error", dirPkgs, err) return false } - if !exist { - fmt.Println("no test case in ", pkg) - return false + for _, pkg := range dirPkgs { + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return false + } + if !exist { + fmt.Println("no test case in ", pkg) + continue + } + + tasks = listTestCases(pkg, tasks) } - tasks = listTestCases(pkg, tasks) } // run a single test @@ -326,6 +375,10 @@ func cmdRun(args ...string) bool { } } + if coverProfile != "" { + collectCoverProfileFile() + } + for _, work := range works { if work.Fail { return false @@ -336,7 +389,7 @@ func cmdRun(args ...string) bool { // stripFlag strip the '--flag xxx' from the command line os.Args // Example of the os.Args changes -// Before: ut run pkg TestXXX --junitfile yyy +// Before: ut run pkg TestXXX --coverprofile xxx --junitfile yyy // After: ut run pkg TestXXX // The value of the flag is returned. func stripFlag(flag string) string { @@ -421,6 +474,7 @@ func filterTestCases(tasks []task, arg1 string) ([]task, error) { func listPackages() ([]string, error) { cmd := exec.Command("go", "list", "./...") + cmd.Dir = workDir ss, err := cmdToLines(cmd) if err != nil { return nil, withTrace(err) @@ -565,7 +619,16 @@ func failureCases(input []JUnitTestCase) int { func (*numa) testCommand(pkg string, fn string) *exec.Cmd { args := make([]string, 0, 10) exe := "./" + testFileName(pkg) - args = append(args, "-test.cpu", "1") + if coverProfile != "" { + fileName := strings.ReplaceAll(pkg, "/", "_") + "." + fn + tmpFile := path.Join(coverFileTempDir, fileName) + args = append(args, "-test.coverprofile", tmpFile) + } + if strings.Contains(fn, "Suite") { + args = append(args, "-test.cpu", fmt.Sprint(p/2)) + } else { + args = append(args, "-test.cpu", "1") + } if !race { args = append(args, []string{"-test.timeout", "2m"}...) 
} else { @@ -580,7 +643,10 @@ func (*numa) testCommand(pkg string, fn string) *exec.Cmd { } func skipDIR(pkg string) bool { - skipDir := []string{"tests", "bin", "cmd", "tools"} + skipDir := []string{"bin", "cmd", "realcluster"} + if ignoreDir != "" { + skipDir = append(skipDir, ignoreDir) + } for _, ignore := range skipDir { if strings.HasPrefix(pkg, ignore) { return true @@ -593,8 +659,14 @@ func generateBuildCache() error { // cd cmd/pd-server && go test -tags=tso_function_test,deadlock -exec=true -vet=off -toolexec=go-compile-without-link cmd := exec.Command("go", "test", "-exec=true", "-vet", "off", "--tags=tso_function_test,deadlock") goCompileWithoutLink := fmt.Sprintf("-toolexec=%s/tools/pd-ut/go-compile-without-link.sh", workDir) - cmd.Args = append(cmd.Args, goCompileWithoutLink) cmd.Dir = fmt.Sprintf("%s/cmd/pd-server", workDir) + if strings.Contains(workDir, integrationsTestPath) { + cmd.Dir = fmt.Sprintf("%s/cmd/pd-server", workDir[:strings.LastIndex(workDir, integrationsTestPath)]) + goCompileWithoutLink = fmt.Sprintf("-toolexec=%s/tools/pd-ut/go-compile-without-link.sh", + workDir[:strings.LastIndex(workDir, integrationsTestPath)]) + } + cmd.Args = append(cmd.Args, goCompileWithoutLink) + cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { @@ -612,7 +684,11 @@ func buildTestBinaryMulti(pkgs []string) error { } // go test --exec=xprog --tags=tso_function_test,deadlock -vet=off --count=0 $(pkgs) + // workDir is something like `/data/nvme0n1/husharp/proj/pd/tests/integrations` xprogPath := path.Join(workDir, "bin/xprog") + if strings.Contains(workDir, integrationsTestPath) { + xprogPath = path.Join(workDir[:strings.LastIndex(workDir, integrationsTestPath)], "bin/xprog") + } packages := make([]string, 0, len(pkgs)) for _, pkg := range pkgs { packages = append(packages, path.Join(modulePath, pkg)) } p := strconv.Itoa(buildParallel) cmd := exec.Command("go", "test", "-p", p, "--exec", xprogPath, "-vet", "off", "--tags=tso_function_test,deadlock") + if coverProfile != "" { + coverpkg := "./..." + if strings.Contains(workDir, integrationsTestPath) { + coverpkg = "../../..." + } + cmd.Args = append(cmd.Args, "-cover", fmt.Sprintf("-coverpkg=%s", coverpkg)) + } cmd.Args = append(cmd.Args, packages...) cmd.Dir = workDir cmd.Stdout = os.Stdout @@ -633,6 +716,9 @@ func buildTestBinary(pkg string) error { //nolint:gosec cmd := exec.Command("go", "test", "-c", "-vet", "off", "--tags=tso_function_test,deadlock", "-o", testFileName(pkg), "-v") + if coverProfile != "" { + cmd.Args = append(cmd.Args, "-cover", "-coverpkg=./...") + } if race { cmd.Args = append(cmd.Args, "-race") }
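A subtlety in the new coverProfile.go worth flagging: mergeProfile accumulates merged blocks into tmp, which aliases the backing array of old.Blocks (tmp := old.Blocks[:0]), but tmp is never assigned back before old is stored into the map. That only works while every per-test profile enumerates exactly the same block set, so the in-place writes land where the stale old.Blocks header already points; if a merge ever produced more blocks than old.Blocks held, append would reallocate and the merged data would be silently dropped. A defensive ending for the loop body adds one assignment; this is a sketch of a hardening change, not the PR's code.

```go
// Tail of mergeProfile, with the merged slice written back explicitly so
// the profile stays correct even if append reallocated the backing array.
for ; i < len(old.Blocks); i++ {
	tmp = appendWithReduce(tmp, old.Blocks[i])
}
for ; j < len(prof.Blocks); j++ {
	tmp = appendWithReduce(tmp, prof.Blocks[j])
}
old.Blocks = tmp
m[prof.FileName] = old
```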
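Putting the ut.go changes together, the tool's invocation surface now looks like this, in the style of the README snippet above; the profile name and the ignored directory are illustrative:

```
// unit tests only, skipping a directory via the new --ignore flag
pd-ut run --ignore tests --race

// integration tests under tests/integrations via the new it subcommand
pd-ut it run

// collect a merged coverage profile and inspect it
pd-ut run --coverprofile covprofile
go tool cover --func=covprofile
```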