From 79a7508d4329b532baf7549f8b3fd7e59f9e5eaa Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 25 Dec 2024 15:41:14 +0800 Subject: [PATCH] add 400 status code Signed-off-by: lhy1024 --- pkg/mcs/scheduling/server/apis/v1/api.go | 4 ++++ server/api/region.go | 4 ++++ server/api/region_test.go | 2 ++ tests/integrations/mcs/scheduling/api_test.go | 3 +++ 4 files changed, 13 insertions(+) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index c374456a6eb..1792bb7835c 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -1474,6 +1474,10 @@ func getRegionByID(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + c.String(http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs().Error()) + return + } regionInfo := svr.GetBasicCluster().GetRegion(regionID) if regionInfo == nil { c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) diff --git a/server/api/region.go b/server/api/region.go index f5e35a16ffa..8f4b9a49017 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -65,6 +65,10 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + h.rd.JSON(w, http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs()) + return + } regionInfo := rc.GetRegion(regionID) if regionInfo == nil { diff --git a/server/api/region_test.go b/server/api/region_test.go index fa5e98afc28..b3a1e2684c1 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -79,6 +79,8 @@ func (suite *regionTestSuite) TestRegion() { re := suite.Require() mustRegionHeartbeat(re, suite.svr, r) url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 0) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest))) + url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 2333) re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound))) url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) r1 := &response.RegionInfo{} diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index f3e7f235018..303c0d538a7 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -711,6 +711,9 @@ func (suite *apiTestSuite) checkRegions(cluster *tests.TestCluster) { err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3., resp["count"]) + urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/0", scheServerAddr) + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, + testutil.Status(re, http.StatusBadRequest)) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/233", scheServerAddr) testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found"))