From 988c9a3b181f97304a0962cd5aa08c4e05ee8898 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 25 Oct 2024 11:51:53 +0800 Subject: [PATCH] server: refactor the independent service check (#8738) ref tikv/pd#8477 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/utils/apiutil/serverapi/middleware.go | 2 +- server/api/config.go | 6 +- server/cluster/cluster.go | 9 -- server/forward.go | 112 ++++++++++-------- server/grpc_service.go | 8 +- server/server.go | 12 ++ tests/integrations/mcs/scheduling/api_test.go | 2 +- tests/testutil.go | 2 +- 8 files changed, 86 insertions(+), 67 deletions(-) diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 23723d9b254..d6fc98082d6 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -129,7 +129,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri for _, rule := range h.microserviceRedirectRules { // Now we only support checking the scheduling service whether it is independent if rule.targetServiceName == constant.SchedulingServiceName { - if !h.s.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) { + if !h.s.IsServiceIndependent(constant.SchedulingServiceName) { continue } } diff --git a/server/api/config.go b/server/api/config.go index 7b011957d22..511f47284a9 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -62,7 +62,7 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler { // @Router /config [get] func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) { cfg := h.svr.GetConfig() - if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) && + if h.svr.IsServiceIndependent(constant.SchedulingServiceName) && r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" { schedulingServerConfig, err := h.getSchedulingServerConfig() if err != nil { @@ -336,7 +336,7 @@ func getConfigMap(cfg map[string]any, key []string, value any) map[string]any { // @Success 200 {object} sc.ScheduleConfig // @Router /config/schedule [get] func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) { - if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) && + if h.svr.IsServiceIndependent(constant.SchedulingServiceName) && r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" { cfg, err := h.getSchedulingServerConfig() if err != nil { @@ -409,7 +409,7 @@ func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request) // @Success 200 {object} sc.ReplicationConfig // @Router /config/replicate [get] func (h *confHandler) GetReplicationConfig(w http.ResponseWriter, r *http.Request) { - if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) && + if h.svr.IsServiceIndependent(constant.SchedulingServiceName) && r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" { cfg, err := h.getSchedulingServerConfig() if err != nil { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 3869308d9dc..69b815e6b95 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -2510,25 +2510,16 @@ func IsClientURL(addr string, etcdClient *clientv3.Client) bool { // IsServiceIndependent returns whether the service is independent. func (c *RaftCluster) IsServiceIndependent(name string) bool { - if c == nil { - return false - } _, exist := c.independentServices.Load(name) return exist } // SetServiceIndependent sets the service to be independent. func (c *RaftCluster) SetServiceIndependent(name string) { - if c == nil { - return - } c.independentServices.Store(name, struct{}{}) } // UnsetServiceIndependent unsets the service to be independent. func (c *RaftCluster) UnsetServiceIndependent(name string) { - if c == nil { - return - } c.independentServices.Delete(name) } diff --git a/server/forward.go b/server/forward.go index 79aea2da119..7fbbb8e04f8 100644 --- a/server/forward.go +++ b/server/forward.go @@ -133,66 +133,82 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") return status.Error(codes.Unknown, err.Error()) } - - forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr) + forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward) + if tsoStreamErr != nil { return tsoStreamErr } - if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancelForward != nil { - cancelForward() - } + if err != nil { + return err + } + } +} - clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - tsoStreamErr = errors.WithStack(err) - return tsoStreamErr - } - forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn) - if err != nil { - tsoStreamErr = errors.WithStack(err) - return tsoStreamErr - } - lastForwardedHost = forwardedHost +func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer, + request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) ( + context.Context, + context.CancelFunc, + tsopb.TSO_TsoClient, + string, + error, // tso stream error + error, // send error +) { + forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName) + if !ok || len(forwardedHost) == 0 { + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(ErrNotFoundTSOAddr), nil + } + if forwardStream == nil || lastForwardedHost != forwardedHost { + if cancelForward != nil { + cancelForward() } - tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) if err != nil { - tsoStreamErr = errors.WithStack(err) - return tsoStreamErr + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil } + forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn) + if err != nil { + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + } + lastForwardedHost = forwardedHost + } - // The error types defined for tsopb and pdpb are different, so we need to convert them. - var pdpbErr *pdpb.Error - tsopbErr := tsopbResp.GetHeader().GetError() - if tsopbErr != nil { - if tsopbErr.Type == tsopb.ErrorType_OK { - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_OK, - Message: tsopbErr.GetMessage(), - } - } else { - // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_UNKNOWN, - Message: tsopbErr.GetMessage(), - } + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + if err != nil { + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil + } + + // The error types defined for tsopb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + tsopbErr := tsopbResp.GetHeader().GetError() + if tsopbErr != nil { + if tsopbErr.Type == tsopb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: tsopbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: tsopbErr.GetMessage(), } } + } - response := &pdpb.TsoResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: tsopbResp.GetHeader().GetClusterId(), - Error: pdpbErr, - }, - Count: tsopbResp.GetCount(), - Timestamp: tsopbResp.GetTimestamp(), - } - if err := server.send(response); err != nil { - return errors.WithStack(err) - } + response := &pdpb.TsoResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: tsopbResp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + Count: tsopbResp.GetCount(), + Timestamp: tsopbResp.GetTimestamp(), + } + if server != nil { + err = server.send(response) + } else { + err = stream.Send(response) } + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err) } func (s *GrpcServer) forwardTSORequestWithDeadLine( diff --git a/server/grpc_service.go b/server/grpc_service.go index 9e892dda161..ec03819ccaf 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -274,7 +274,7 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest var tsoServiceAddrs []string svcModes := make([]pdpb.ServiceMode, 0) - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(constant.TSOServiceName) { svcModes = append(svcModes, pdpb.ServiceMode_API_SVC_MODE) tsoServiceAddrs = s.keyspaceGroupManager.GetTSOServiceAddrs() } else { @@ -318,7 +318,7 @@ func (s *GrpcServer) GetMinTS( minTS *pdpb.Timestamp err error ) - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(constant.TSOServiceName) { minTS, err = s.GetMinTSFromTSOService(tso.GlobalDCLocation) } else { start := time.Now() @@ -486,7 +486,7 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb } tsoAllocatorLeaders := make(map[string]*pdpb.Member) - if !s.IsAPIServiceMode() { + if !s.IsServiceIndependent(constant.TSOServiceName) { tsoAllocatorManager := s.GetTSOAllocatorManager() tsoAllocatorLeaders, err = tsoAllocatorManager.GetLocalAllocatorLeaders() } @@ -524,7 +524,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return err } } - if s.IsAPIServiceMode() { + if s.IsServiceIndependent(constant.TSOServiceName) { return s.forwardTSO(stream) } diff --git a/server/server.go b/server/server.go index 26f8ebb614c..96e359e40d8 100644 --- a/server/server.go +++ b/server/server.go @@ -1417,6 +1417,18 @@ func (s *Server) GetRaftCluster() *cluster.RaftCluster { return s.cluster } +// IsServiceIndependent returns whether the service is independent. +func (s *Server) IsServiceIndependent(name string) bool { + if s.mode == APIServiceMode && !s.IsClosed() { + // TODO: remove it after we support tso discovery + if name == constant.TSOServiceName { + return true + } + return s.cluster.IsServiceIndependent(name) + } + return false +} + // DirectlyGetRaftCluster returns raft cluster directly. // Only used for test. func (s *Server) DirectlyGetRaftCluster() *cluster.RaftCluster { diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 326068c29b5..443bee2cd6a 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -110,7 +110,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { var respSlice []string var resp map[string]any testutil.Eventually(re, func() bool { - return leader.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) + return leader.IsServiceIndependent(constant.SchedulingServiceName) }) // Test operators diff --git a/tests/testutil.go b/tests/testutil.go index 22a5ab40a7e..98a64c4686c 100644 --- a/tests/testutil.go +++ b/tests/testutil.go @@ -412,7 +412,7 @@ func (s *SchedulingTestEnvironment) startCluster(m SchedulerMode) { cluster.SetSchedulingCluster(tc) time.Sleep(200 * time.Millisecond) // wait for scheduling cluster to update member testutil.Eventually(re, func() bool { - return cluster.GetLeaderServer().GetServer().GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) + return cluster.GetLeaderServer().GetServer().IsServiceIndependent(constant.SchedulingServiceName) }) s.clusters[APIMode] = cluster }