From 34b88b29ae18b61bcb82ace8638676e2c50f7bd8 Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:25:33 +0800 Subject: [PATCH] feat: support thrift and pb multi service (#1217) --- client/service_inline.go | 3 +- internal/mocks/serviceinfo.go | 25 ++++- internal/server/option.go | 2 + internal/server/register_option.go | 41 +++++++ pkg/diagnosis/interface.go | 7 +- pkg/generic/binarythrift_codec_test.go | 8 ++ pkg/generic/json_test/generic_init.go | 4 + pkg/remote/codec/header_codec.go | 5 + pkg/remote/codec/header_codec_test.go | 18 ++- pkg/remote/codec/util.go | 8 +- pkg/remote/codec/util_test.go | 20 +++- pkg/remote/message.go | 76 ++++++++++--- pkg/remote/option.go | 7 +- pkg/remote/trans/default_server_handler.go | 55 +++++---- .../trans/default_server_handler_test.go | 50 ++++++--- .../trans/detection/server_handler_test.go | 21 +++- pkg/remote/trans/gonet/server_handler_test.go | 2 +- pkg/remote/trans/gonet/trans_server_test.go | 21 +++- pkg/remote/trans/mocks_test.go | 8 ++ .../trans/netpoll/http_client_handler_test.go | 14 ++- .../trans/netpoll/server_handler_test.go | 2 +- pkg/remote/trans/netpoll/trans_server_test.go | 22 +++- pkg/remote/trans/netpollmux/mocks_test.go | 8 ++ pkg/remote/trans/netpollmux/server_handler.go | 76 ++++++++----- .../trans/netpollmux/server_handler_test.go | 38 +++++-- pkg/remote/trans/nphttp2/mocks_test.go | 10 +- pkg/remote/trans/nphttp2/server_handler.go | 16 +-- pkg/remote/trans_errors.go | 1 + pkg/remote/transmeta/metakey.go | 1 + pkg/transmeta/ttheader.go | 1 + pkg/transmeta/ttheader_test.go | 5 + server/invoke.go | 2 +- server/option.go | 9 ++ server/option_advanced_test.go | 2 +- server/option_test.go | 28 ++++- server/register_option.go | 33 ++++++ server/register_option_test.go | 30 +++++ server/server.go | 66 ++++++++--- server/server_test.go | 106 ++++++++++++++++-- server/service.go | 96 ++++++++++++++-- server/service_test.go | 88 +++++++++++++++ tool/internal_pkg/generator/type.go | 1 + tool/internal_pkg/tpl/client.go | 1 + tool/internal_pkg/tpl/server.go | 4 + 44 files changed, 873 insertions(+), 168 deletions(-) create mode 100644 internal/server/register_option.go create mode 100644 server/register_option.go create mode 100644 server/register_option_test.go create mode 100644 server/service_test.go diff --git a/client/service_inline.go b/client/service_inline.go index f4930057f1..eda548c478 100644 --- a/client/service_inline.go +++ b/client/service_inline.go @@ -78,7 +78,8 @@ func NewServiceInlineClient(svcInfo *serviceinfo.ServiceInfo, s ServerInitialInf kc.opt = client.NewOptions(opts) kc.serverEps = s.Endpoints() kc.serverOpt = s.Option() - kc.serverOpt.RemoteOpt.SvcMap = s.GetServiceInfos() + kc.serverOpt.RemoteOpt.TargetSvcInfo = svcInfo + kc.serverOpt.RemoteOpt.SvcSearchMap = s.GetServiceInfos() if err := kc.init(); err != nil { _ = kc.Close() return nil, err diff --git a/internal/mocks/serviceinfo.go b/internal/mocks/serviceinfo.go index ec60603171..7d6f0020b5 100644 --- a/internal/mocks/serviceinfo.go +++ b/internal/mocks/serviceinfo.go @@ -30,6 +30,7 @@ import ( const ( MockServiceName = "MockService" MockService2Name = "MockService2" + MockService3Name = "MockService3" MockMethod string = "mock" Mock2Method string = "mock2" MockExceptionMethod string = "mockException" @@ -66,7 +67,7 @@ func newServiceInfo() *serviceinfo.ServiceInfo { return svcInfo } -// ServiceInfo return mock serviceInfo +// Service2Info return mock serviceInfo func Service2Info() *serviceinfo.ServiceInfo { return myServiceService2Info } @@ -88,6 +89,28 @@ func newService2Info() *serviceinfo.ServiceInfo { return svcInfo } +// Service3Info return mock serviceInfo +func Service3Info() *serviceinfo.ServiceInfo { + return myServiceService3Info +} + +var myServiceService3Info = newService3Info() + +func newService3Info() *serviceinfo.ServiceInfo { + methods := map[string]serviceinfo.MethodInfo{ + "mock": serviceinfo.NewMethodInfo(mockHandler, NewMockArgs, NewMockResult, false), + } + + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: MockService3Name, + Methods: methods, + Extra: map[string]interface{}{ + "PackageName": "mock", + }, + } + return svcInfo +} + func mockHandler(ctx context.Context, handler, args, result interface{}) error { a := args.(*myServiceMockArgs) r := result.(*myServiceMockResult) diff --git a/internal/server/option.go b/internal/server/option.go index 18051120b1..35aa34dcec 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -97,6 +97,8 @@ type Options struct { BackupOpt backup.Options Streaming stream.StreamingConfig + + RefuseTrafficWithoutServiceName bool } type Limit struct { diff --git a/internal/server/register_option.go b/internal/server/register_option.go new file mode 100644 index 0000000000..8f582c6a56 --- /dev/null +++ b/internal/server/register_option.go @@ -0,0 +1,41 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +// RegisterOption is the only way to config service registration. +type RegisterOption struct { + F func(o *RegisterOptions) +} + +// RegisterOptions is used to config service registration. +type RegisterOptions struct { + IsFallbackService bool +} + +// NewRegisterOptions creates a register options. +func NewRegisterOptions(opts []RegisterOption) *RegisterOptions { + o := &RegisterOptions{} + ApplyRegisterOptions(opts, o) + return o +} + +// ApplyRegisterOptions applies the given register options. +func ApplyRegisterOptions(opts []RegisterOption, o *RegisterOptions) { + for _, op := range opts { + op.F(o) + } +} diff --git a/pkg/diagnosis/interface.go b/pkg/diagnosis/interface.go index b6391024c4..26f7c7bed8 100644 --- a/pkg/diagnosis/interface.go +++ b/pkg/diagnosis/interface.go @@ -42,9 +42,10 @@ func RegisterProbeFunc(svc Service, name ProbeName, pf ProbeFunc) { // If you want to register other info, please use RegisterProbeFunc(ProbeName, ProbeFunc) to do that. const ( // Common - ChangeEventsKey ProbeName = "events" - ServiceInfosKey ProbeName = "service_infos" - OptionsKey ProbeName = "options" + ChangeEventsKey ProbeName = "events" + ServiceInfosKey ProbeName = "service_infos" + FallbackServiceKey ProbeName = "fallback_service" + OptionsKey ProbeName = "options" // Client DestServiceKey ProbeName = "dest_service" diff --git a/pkg/generic/binarythrift_codec_test.go b/pkg/generic/binarythrift_codec_test.go index 4b6403ab08..f67d82cec0 100644 --- a/pkg/generic/binarythrift_codec_test.go +++ b/pkg/generic/binarythrift_codec_test.go @@ -168,6 +168,7 @@ var _ remote.Message = &mockMessage{} type mockMessage struct { RPCInfoFunc func() rpcinfo.RPCInfo ServiceInfoFunc func() *serviceinfo.ServiceInfo + SetServiceInfoFunc func(svcName, methodName string) (*serviceinfo.ServiceInfo, error) DataFunc func() interface{} NewDataFunc func(method string) (ok bool) MessageTypeFunc func() remote.MessageType @@ -198,6 +199,13 @@ func (m *mockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) { return } +func (m *mockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) { + if m.SetServiceInfoFunc != nil { + return m.SetServiceInfoFunc(svcName, methodName) + } + return nil, nil +} + func (m *mockMessage) Data() interface{} { if m.DataFunc != nil { return m.DataFunc() diff --git a/pkg/generic/json_test/generic_init.go b/pkg/generic/json_test/generic_init.go index 4adc15b476..20ec556aa2 100644 --- a/pkg/generic/json_test/generic_init.go +++ b/pkg/generic/json_test/generic_init.go @@ -34,6 +34,7 @@ import ( "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/genericclient" + "github.com/cloudwego/kitex/internal/mocks" kt "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/generic" @@ -269,6 +270,9 @@ func newMockServer(handler kt.Mock, addr net.Addr, opts ...server.Option) server if err := svr.RegisterService(serviceInfo(), handler); err != nil { panic(err) } + if err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()); err != nil { + panic(err) + } go func() { err := svr.Run() if err != nil { diff --git a/pkg/remote/codec/header_codec.go b/pkg/remote/codec/header_codec.go index a5ad3563b0..524dd6a903 100644 --- a/pkg/remote/codec/header_codec.go +++ b/pkg/remote/codec/header_codec.go @@ -492,6 +492,11 @@ func fillBasicInfoOfTTHeader(msg remote.Message) { fi.SetServiceName(v) } } + if ink, ok := msg.RPCInfo().Invocation().(rpcinfo.InvocationSetter); ok { + if svcName, ok := msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName]; ok { + ink.SetServiceName(svcName) + } + } } else { ti := remoteinfo.AsRemoteInfo(msg.RPCInfo().To()) if ti != nil { diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 0c52ca7c5f..f23c936669 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -32,6 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" tm "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/transport" ) @@ -306,8 +307,18 @@ var ( ) func initServerRecvMsg() remote.Message { - var req interface{} - msg := remote.NewMessage(req, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Call, remote.Server) + svcInfo := mocks.ServiceInfo() + svcSearchMap := map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server, false) return msg } @@ -367,7 +378,7 @@ func prepareIntKVInfo() map[uint16]string { } func prepareStrKVInfo() map[string]string { - kvInfo := map[string]string{} + kvInfo := map[string]string{transmeta.HeaderIDLServiceName: mocks.MockServiceName} return kvInfo } @@ -375,6 +386,7 @@ func prepareStrKVInfoWithGDPRToken() map[string]string { kvInfo := map[string]string{ transmeta.GDPRToken: "mockToken", transmeta.HeaderTransRemoteAddr: "mockRemoteAddr", + transmeta.HeaderIDLServiceName: mocks.MockServiceName, } return kvInfo } diff --git a/pkg/remote/codec/util.go b/pkg/remote/codec/util.go index 7cf354da57..78b13347f2 100644 --- a/pkg/remote/codec/util.go +++ b/pkg/remote/codec/util.go @@ -45,7 +45,10 @@ func SetOrCheckMethodName(methodName string, message remote.Message) error { if message.RPCRole() == remote.Client { return fmt.Errorf("wrong method name, expect=%s, actual=%s", callMethodName, methodName) } - svcInfo := message.ServiceInfo() + svcInfo, err := message.SpecifyServiceInfo(ink.ServiceName(), methodName) + if err != nil { + return err + } if ink, ok := ink.(rpcinfo.InvocationSetter); ok { ink.SetMethodName(methodName) ink.SetPackageName(svcInfo.GetPackageName()) @@ -53,9 +56,6 @@ func SetOrCheckMethodName(methodName string, message remote.Message) error { } else { return errors.New("the interface Invocation doesn't implement InvocationSetter") } - if mt := svcInfo.MethodInfo(methodName); mt == nil { - return remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) - } // unknown method doesn't set methodName for RPCInfo.To(), or lead inconsistent with old version rpcinfo.AsMutableEndpointInfo(ri.To()).SetMethod(methodName) diff --git a/pkg/remote/codec/util_test.go b/pkg/remote/codec/util_test.go index dadc7abcb3..b177fe97b6 100644 --- a/pkg/remote/codec/util_test.go +++ b/pkg/remote/codec/util_test.go @@ -23,13 +23,24 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) func TestSetOrCheckMethodName(t *testing.T) { - var req interface{} ri := rpcinfo.NewRPCInfo(nil, rpcinfo.NewEndpointInfo("", "mock", nil, nil), rpcinfo.NewServerInvocation(), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) - msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) + svcInfo := mocks.ServiceInfo() + svcSearchMap := map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) err := SetOrCheckMethodName("mock", msg) test.Assert(t, err == nil) ri = msg.RPCInfo() @@ -37,4 +48,9 @@ func TestSetOrCheckMethodName(t *testing.T) { test.Assert(t, ri.Invocation().PackageName() == "mock") test.Assert(t, ri.Invocation().MethodName() == "mock") test.Assert(t, ri.To().Method() == "mock") + + msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server, false) + err = SetOrCheckMethodName("dummy", msg) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "unknown method dummy") } diff --git a/pkg/remote/message.go b/pkg/remote/message.go index 80238ad2fd..57749e385a 100644 --- a/pkg/remote/message.go +++ b/pkg/remote/message.go @@ -17,10 +17,12 @@ package remote import ( + "fmt" "sync" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/transport" ) @@ -83,6 +85,7 @@ func NewProtocolInfo(tp transport.Protocol, ct serviceinfo.PayloadCodec) Protoco type Message interface { RPCInfo() rpcinfo.RPCInfo ServiceInfo() *serviceinfo.ServiceInfo + SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.ServiceInfo, error) Data() interface{} NewData(method string) (ok bool) MessageType() MessageType @@ -104,7 +107,7 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R msg := messagePool.Get().(*message) msg.data = data msg.rpcInfo = ri - msg.svcInfo = svcInfo + msg.targetSvcInfo = svcInfo msg.msgType = msgType msg.rpcRole = rpcRole msg.transInfo = transInfoPool.Get().(*transInfo) @@ -112,13 +115,15 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R } // NewMessageWithNewer creates a new Message and set data later. -func NewMessageWithNewer(svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole) Message { +func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole, refuseTrafficWithoutServiceName bool) Message { msg := messagePool.Get().(*message) msg.rpcInfo = ri - msg.svcInfo = svcInfo + msg.targetSvcInfo = targetSvcInfo + msg.svcSearchMap = svcSearchMap msg.msgType = msgType msg.rpcRole = rpcRole msg.transInfo = transInfoPool.Get().(*transInfo) + msg.refuseTrafficWithoutServiceName = refuseTrafficWithoutServiceName return msg } @@ -134,24 +139,26 @@ func newMessage() interface{} { } type message struct { - msgType MessageType - data interface{} - rpcInfo rpcinfo.RPCInfo - svcInfo *serviceinfo.ServiceInfo - rpcRole RPCRole - compressType CompressType - payloadSize int - transInfo TransInfo - tags map[string]interface{} - protocol ProtocolInfo - payloadCodec PayloadCodec + msgType MessageType + data interface{} + rpcInfo rpcinfo.RPCInfo + targetSvcInfo *serviceinfo.ServiceInfo + svcSearchMap map[string]*serviceinfo.ServiceInfo + rpcRole RPCRole + compressType CompressType + payloadSize int + transInfo TransInfo + tags map[string]interface{} + protocol ProtocolInfo + payloadCodec PayloadCodec + refuseTrafficWithoutServiceName bool } func (m *message) zero() { m.msgType = InvalidMessageType m.data = nil m.rpcInfo = nil - m.svcInfo = &emptyServiceInfo + m.targetSvcInfo = &emptyServiceInfo m.rpcRole = -1 m.compressType = NoCompress m.payloadSize = 0 @@ -172,7 +179,32 @@ func (m *message) RPCInfo() rpcinfo.RPCInfo { // ServiceInfo implements the Message interface. func (m *message) ServiceInfo() *serviceinfo.ServiceInfo { - return m.svcInfo + return m.targetSvcInfo +} + +func (m *message) SpecifyServiceInfo(svcName, methodName string) (*serviceinfo.ServiceInfo, error) { + // for non-multi-service including generic server scenario + if m.targetSvcInfo != nil { + if mt := m.targetSvcInfo.MethodInfo(methodName); mt == nil { + return nil, NewTransErrorWithMsg(UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) + } + return m.targetSvcInfo, nil + } + if svcName == "" && m.refuseTrafficWithoutServiceName { + return nil, NewTransErrorWithMsg(NoServiceName, "no service name while the server has WithRefuseTrafficWithoutServiceName option enabled") + } + var key string + if svcName == "" { + key = methodName + } else { + key = BuildMultiServiceKey(svcName, methodName) + } + svcInfo := m.svcSearchMap[key] + if svcInfo == nil { + return nil, NewTransErrorWithMsg(UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) + } + m.targetSvcInfo = svcInfo + return svcInfo, nil } // Data implements the Message interface. @@ -185,7 +217,7 @@ func (m *message) NewData(method string) (ok bool) { if m.data != nil { return false } - if mt := m.svcInfo.MethodInfo(method); mt != nil { + if mt := m.targetSvcInfo.MethodInfo(method); mt != nil { m.data = mt.NewArgs() } if m.data == nil { @@ -334,3 +366,13 @@ func FillSendMsgFromRecvMsg(recvMsg, sendMsg Message) { sendMsg.SetProtocolInfo(recvMsg.ProtocolInfo()) sendMsg.SetPayloadCodec(recvMsg.PayloadCodec()) } + +// BuildMultiServiceKey is used to create a key to search svcInfo from svcSearchMap. +func BuildMultiServiceKey(serviceName, methodName string) string { + var builder utils.StringBuilder + builder.Grow(len(serviceName) + len(methodName) + 1) + builder.WriteString(serviceName) + builder.WriteString(".") + builder.WriteString(methodName) + return builder.String() +} diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 3b13f55ec2..afe0f567b9 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -70,7 +70,9 @@ func (o *Option) AppendBoundHandler(h BoundHandler) { // ServerOption contains option that is used to init the remote server. type ServerOption struct { - SvcMap map[string]*serviceinfo.ServiceInfo + TargetSvcInfo *serviceinfo.ServiceInfo + + SvcSearchMap map[string]*serviceinfo.ServiceInfo TransServerFactory TransServerFactory @@ -111,6 +113,9 @@ type ServerOption struct { GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error + // RefuseTrafficWithoutServiceName is used for a server with multi services + RefuseTrafficWithoutServiceName bool + Option // invoking chain with recv/send middlewares for streaming APIs diff --git a/pkg/remote/trans/default_server_handler.go b/pkg/remote/trans/default_server_handler.go index 15b25d3c89..ebeb42729f 100644 --- a/pkg/remote/trans/default_server_handler.go +++ b/pkg/remote/trans/default_server_handler.go @@ -34,12 +34,12 @@ import ( // NewDefaultSvrTransHandler to provide default impl of svrTransHandler, it can be reused in netpoll, shm-ipc, framework-sdk extensions func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.ServerTransHandler, error) { - svcInfo := GetDefaultSvcInfo(opt.SvcMap) svrHdlr := &svrTransHandler{ - opt: opt, - codec: opt.Codec, - svcInfo: svcInfo, - ext: ext, + opt: opt, + codec: opt.Codec, + svcSearchMap: opt.SvcSearchMap, + targetSvcInfo: opt.TargetSvcInfo, + ext: ext, } if svrHdlr.opt.TracerCtl == nil { // init TraceCtl when it is nil, or it will lead some unit tests panic @@ -49,12 +49,13 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote. } type svrTransHandler struct { - opt *remote.ServerOption - svcInfo *serviceinfo.ServiceInfo - inkHdlFunc endpoint.Endpoint - codec remote.Codec - transPipe *remote.TransPipeline - ext Extension + opt *remote.ServerOption + svcSearchMap map[string]*serviceinfo.ServiceInfo + targetSvcInfo *serviceinfo.ServiceInfo + inkHdlFunc endpoint.Endpoint + codec remote.Codec + transPipe *remote.TransPipeline + ext Extension } // Write implements the remote.ServerTransHandler interface. @@ -67,9 +68,12 @@ func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remo rpcinfo.Record(ctx, ri, stats.WriteFinish, err) }() - if methodInfo, _ := GetMethodInfo(ri, t.svcInfo); methodInfo != nil { - if methodInfo.OneWay() { - return ctx, nil + svcInfo := sendMsg.ServiceInfo() + if svcInfo != nil { + if methodInfo, _ := GetMethodInfo(ri, svcInfo); methodInfo != nil { + if methodInfo.OneWay() { + return ctx, nil + } } } @@ -163,7 +167,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) }() ctx = t.startTracer(ctx, ri) ctx = t.startProfiler(ctx) - recvMsg = remote.NewMessageWithNewer(t.svcInfo, ri, remote.Call, remote.Server) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) recvMsg.SetPayloadCodec(t.opt.PayloadCodec) ctx, err = t.transPipe.Read(ctx, conn, recvMsg) if err != nil { @@ -172,15 +176,16 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) return err } + svcInfo := recvMsg.ServiceInfo() // heartbeat processing // recvMsg.MessageType would be set to remote.Heartbeat in previous Read procedure // if specified codec support heartbeat if recvMsg.MessageType() == remote.Heartbeat { - sendMsg = remote.NewMessage(nil, t.svcInfo, ri, remote.Heartbeat, remote.Server) + sendMsg = remote.NewMessage(nil, svcInfo, ri, remote.Heartbeat, remote.Server) } else { // reply processing var methodInfo serviceinfo.MethodInfo - if methodInfo, err = GetMethodInfo(ri, t.svcInfo); err != nil { + if methodInfo, err = GetMethodInfo(ri, svcInfo); err != nil { // it won't be err, because the method has been checked in decode, err check here just do defensive inspection t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true) // for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded @@ -188,9 +193,9 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) return err } if methodInfo.OneWay() { - sendMsg = remote.NewMessage(nil, t.svcInfo, ri, remote.Reply, remote.Server) + sendMsg = remote.NewMessage(nil, svcInfo, ri, remote.Reply, remote.Server) } else { - sendMsg = remote.NewMessage(methodInfo.NewResult(), t.svcInfo, ri, remote.Reply, remote.Server) + sendMsg = remote.NewMessage(methodInfo.NewResult(), svcInfo, ri, remote.Reply, remote.Server) } ctx, err = t.transPipe.OnMessage(ctx, recvMsg, sendMsg) @@ -273,16 +278,20 @@ func (t *svrTransHandler) writeErrorReplyIfNeeded( // conn is closed, no need reply return } - if methodInfo, _ := GetMethodInfo(ri, t.svcInfo); methodInfo != nil { - if methodInfo.OneWay() { - return + svcInfo := recvMsg.ServiceInfo() + if svcInfo != nil { + if methodInfo, _ := GetMethodInfo(ri, svcInfo); methodInfo != nil { + if methodInfo.OneWay() { + return + } } } + transErr, isTransErr := err.(*remote.TransError) if !isTransErr { return } - errMsg := remote.NewMessage(transErr, t.svcInfo, ri, remote.Exception, remote.Server) + errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server) remote.FillSendMsgFromRecvMsg(recvMsg, errMsg) if doOnMessage { // if error happen before normal OnMessage, exec it to transfer header trans info into rpcinfo diff --git a/pkg/remote/trans/default_server_handler_test.go b/pkg/remote/trans/default_server_handler_test.go index 450b9e9e83..c8ed141988 100644 --- a/pkg/remote/trans/default_server_handler_test.go +++ b/pkg/remote/trans/default_server_handler_test.go @@ -33,6 +33,20 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) +var ( + svcInfo = mocks.ServiceInfo() + svcSearchMap = map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } +) + func TestDefaultSvrTransHandler(t *testing.T) { buf := remote.NewReaderWriterBuffer(1024) ext := &MockExtension{ @@ -45,8 +59,6 @@ func TestDefaultSvrTransHandler(t *testing.T) { } tagEncode, tagDecode := 0, 0 - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt := &remote.ServerOption{ Codec: &MockCodec{ EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { @@ -60,7 +72,8 @@ func TestDefaultSvrTransHandler(t *testing.T) { return nil }, }, - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, } handler, err := NewDefaultSvrTransHandler(opt, ext) @@ -72,6 +85,13 @@ func TestDefaultSvrTransHandler(t *testing.T) { RPCInfoFunc: func() rpcinfo.RPCInfo { return newMockRPCInfo() }, + ServiceInfoFunc: func() *serviceinfo.ServiceInfo { + return &serviceinfo.ServiceInfo{ + Methods: map[string]serviceinfo.MethodInfo{ + "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), + }, + } + }, } ctx, err = handler.Write(ctx, conn, msg) test.Assert(t, ctx != nil, ctx) @@ -109,19 +129,19 @@ func TestSvrTransHandlerBizError(t *testing.T) { tracerCtl := &rpcinfo.TraceController{} tracerCtl.Append(mockTracer) - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt := &remote.ServerOption{ Codec: &MockCodec{ EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { return nil }, DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) return nil }, }, - SvcMap: svcMap, - TracerCtl: tracerCtl, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, + TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) return ri @@ -168,19 +188,19 @@ func TestSvrTransHandlerReadErr(t *testing.T) { mockErr := errors.New("mock") tracerCtl := &rpcinfo.TraceController{} tracerCtl.Append(mockTracer) - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt := &remote.ServerOption{ Codec: &MockCodec{ EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { return nil }, DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) return mockErr }, }, - SvcMap: svcMap, - TracerCtl: tracerCtl, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, + TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) return ri @@ -222,8 +242,6 @@ func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) { tracerCtl := &rpcinfo.TraceController{} tracerCtl.Append(mockTracer) - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt := &remote.ServerOption{ Codec: &MockCodec{ EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { @@ -234,11 +252,13 @@ func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) { }, DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { msg.SetMessageType(remote.Heartbeat) + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) return nil }, }, - SvcMap: svcMap, - TracerCtl: tracerCtl, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, + TracerCtl: tracerCtl, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) return ri diff --git a/pkg/remote/trans/detection/server_handler_test.go b/pkg/remote/trans/detection/server_handler_test.go index 0c6680db04..04343edc18 100644 --- a/pkg/remote/trans/detection/server_handler_test.go +++ b/pkg/remote/trans/detection/server_handler_test.go @@ -49,12 +49,23 @@ var ( } return grpc.ClientPrefaceLen }() - svcMap = map[string]*serviceinfo.ServiceInfo{mocks.MockServiceName: mocks.ServiceInfo()} + svcInfo = mocks.ServiceInfo() + svcSearchMap = map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } ) func TestServerHandlerCall(t *testing.T) { transHdler, _ := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, }) ctrl := gomock.NewController(t) @@ -123,7 +134,8 @@ func TestOnError(t *testing.T) { ctrl.Finish() }() transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, }) test.Assert(t, err == nil) @@ -152,7 +164,8 @@ func TestOnError(t *testing.T) { // TestOnInactive covers onInactive() codes to check panic func TestOnInactive(t *testing.T) { transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, }) test.Assert(t, err == nil) diff --git a/pkg/remote/trans/gonet/server_handler_test.go b/pkg/remote/trans/gonet/server_handler_test.go index 14c4ff5dfd..a34f1c4c69 100644 --- a/pkg/remote/trans/gonet/server_handler_test.go +++ b/pkg/remote/trans/gonet/server_handler_test.go @@ -171,7 +171,7 @@ func TestNoMethodInfo(t *testing.T) { } remote.NewTransPipeline(svrTransHdlr) - svcInfo := svrOpt.SvcMap[mocks.MockServiceName] + svcInfo := svrOpt.TargetSvcInfo delete(svcInfo.Methods, method) // 2. test diff --git a/pkg/remote/trans/gonet/trans_server_test.go b/pkg/remote/trans/gonet/trans_server_test.go index 5a94d2398f..fda743e03f 100644 --- a/pkg/remote/trans/gonet/trans_server_test.go +++ b/pkg/remote/trans/gonet/trans_server_test.go @@ -17,6 +17,7 @@ package gonet import ( + "context" "net" "os" "testing" @@ -41,6 +42,7 @@ var ( ) func TestMain(m *testing.M) { + svcInfo := mocks.ServiceInfo() svrOpt = &remote.ServerOption{ InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { fromInfo := rpcinfo.EmptyEndpointInfo() @@ -55,10 +57,23 @@ func TestMain(m *testing.M) { }, Codec: &MockCodec{ EncodeFunc: nil, - DecodeFunc: nil, + DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) + return nil + }, }, - SvcMap: map[string]*serviceinfo.ServiceInfo{mocks.MockServiceName: mocks.ServiceInfo()}, - TracerCtl: &rpcinfo.TraceController{}, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + }, + TargetSvcInfo: svcInfo, + TracerCtl: &rpcinfo.TraceController{}, } svrTransHdlr, _ = newSvrTransHandler(svrOpt) transSvr = NewTransServerFactory().NewTransServer(svrOpt, svrTransHdlr).(*transServer) diff --git a/pkg/remote/trans/mocks_test.go b/pkg/remote/trans/mocks_test.go index a3b60920b8..d83a26fed2 100644 --- a/pkg/remote/trans/mocks_test.go +++ b/pkg/remote/trans/mocks_test.go @@ -120,6 +120,7 @@ var _ remote.Message = &MockMessage{} type MockMessage struct { RPCInfoFunc func() rpcinfo.RPCInfo ServiceInfoFunc func() *serviceinfo.ServiceInfo + SetServiceInfoFunc func(svcName, methodName string) (*serviceinfo.ServiceInfo, error) DataFunc func() interface{} NewDataFunc func(method string) (ok bool) MessageTypeFunc func() remote.MessageType @@ -150,6 +151,13 @@ func (m *MockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) { return } +func (m *MockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) { + if m.SetServiceInfoFunc != nil { + return m.SetServiceInfoFunc(svcName, methodName) + } + return nil, nil +} + func (m *MockMessage) Data() interface{} { if m.DataFunc != nil { return m.DataFunc() diff --git a/pkg/remote/trans/netpoll/http_client_handler_test.go b/pkg/remote/trans/netpoll/http_client_handler_test.go index 30a763af6d..fed8d05a30 100644 --- a/pkg/remote/trans/netpoll/http_client_handler_test.go +++ b/pkg/remote/trans/netpoll/http_client_handler_test.go @@ -30,6 +30,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var ( @@ -118,10 +119,19 @@ func TestHTTPRead(t *testing.T) { func TestHTTPOnMessage(t *testing.T) { // 1. prepare mock data svcInfo := mocks.ServiceInfo() - + svcSearchMap := map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, method), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) sendMsg := remote.NewMessage(svcInfo.MethodInfo(method).NewResult(), svcInfo, ri, remote.Reply, remote.Server) // 2. test diff --git a/pkg/remote/trans/netpoll/server_handler_test.go b/pkg/remote/trans/netpoll/server_handler_test.go index 82130a8809..fed95dd5a5 100644 --- a/pkg/remote/trans/netpoll/server_handler_test.go +++ b/pkg/remote/trans/netpoll/server_handler_test.go @@ -266,7 +266,7 @@ func TestNoMethodInfo(t *testing.T) { }, } remote.NewTransPipeline(svrTransHdlr) - svcInfo := svrOpt.SvcMap[mocks.MockServiceName] + svcInfo := svrOpt.TargetSvcInfo delete(svcInfo.Methods, method) // 2. test diff --git a/pkg/remote/trans/netpoll/trans_server_test.go b/pkg/remote/trans/netpoll/trans_server_test.go index 952d5efe7f..f8c9ef1502 100644 --- a/pkg/remote/trans/netpoll/trans_server_test.go +++ b/pkg/remote/trans/netpoll/trans_server_test.go @@ -47,8 +47,7 @@ var ( ) func TestMain(m *testing.M) { - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() + svcInfo := mocks.ServiceInfo() svrOpt = &remote.ServerOption{ InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { fromInfo := rpcinfo.EmptyEndpointInfo() @@ -63,10 +62,23 @@ func TestMain(m *testing.M) { }, Codec: &MockCodec{ EncodeFunc: nil, - DecodeFunc: nil, + DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) + return nil + }, + }, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, }, - SvcMap: svcMap, - TracerCtl: &rpcinfo.TraceController{}, + TargetSvcInfo: svcInfo, + TracerCtl: &rpcinfo.TraceController{}, } svrTransHdlr, _ = newSvrTransHandler(svrOpt) transSvr = NewTransServerFactory().NewTransServer(svrOpt, svrTransHdlr).(*transServer) diff --git a/pkg/remote/trans/netpollmux/mocks_test.go b/pkg/remote/trans/netpollmux/mocks_test.go index fd0be796b3..1bd39da128 100644 --- a/pkg/remote/trans/netpollmux/mocks_test.go +++ b/pkg/remote/trans/netpollmux/mocks_test.go @@ -72,6 +72,7 @@ var _ remote.Message = &MockMessage{} type MockMessage struct { RPCInfoFunc func() rpcinfo.RPCInfo ServiceInfoFunc func() *serviceinfo.ServiceInfo + SetServiceInfoFunc func(svcName, methodName string) (*serviceinfo.ServiceInfo, error) DataFunc func() interface{} NewDataFunc func(method string) (ok bool) MessageTypeFunc func() remote.MessageType @@ -102,6 +103,13 @@ func (m *MockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) { return } +func (m *MockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) { + if m.SetServiceInfoFunc != nil { + return m.SetServiceInfoFunc(svcName, methodName) + } + return nil, nil +} + func (m *MockMessage) Data() interface{} { if m.DataFunc != nil { return m.DataFunc() diff --git a/pkg/remote/trans/netpollmux/server_handler.go b/pkg/remote/trans/netpollmux/server_handler.go index 2ddfcc4fcb..028954fc74 100644 --- a/pkg/remote/trans/netpollmux/server_handler.go +++ b/pkg/remote/trans/netpollmux/server_handler.go @@ -62,12 +62,12 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo } func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { - svcInfo := trans.GetDefaultSvcInfo(opt.SvcMap) svrHdlr := &svrTransHandler{ - opt: opt, - codec: opt.Codec, - svcInfo: svcInfo, - ext: np.NewNetpollConnExtension(), + opt: opt, + codec: opt.Codec, + svcSearchMap: opt.SvcSearchMap, + targetSvcInfo: opt.TargetSvcInfo, + ext: np.NewNetpollConnExtension(), } if svrHdlr.opt.TracerCtl == nil { // init TraceCtl when it is nil, or it will lead some unit tests panic @@ -83,15 +83,16 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { var _ remote.ServerTransHandler = &svrTransHandler{} type svrTransHandler struct { - opt *remote.ServerOption - svcInfo *serviceinfo.ServiceInfo - inkHdlFunc endpoint.Endpoint - codec remote.Codec - transPipe *remote.TransPipeline - ext trans.Extension - funcPool sync.Pool - conns sync.Map - tasks sync.WaitGroup + opt *remote.ServerOption + svcSearchMap map[string]*serviceinfo.ServiceInfo + targetSvcInfo *serviceinfo.ServiceInfo + inkHdlFunc endpoint.Endpoint + codec remote.Codec + transPipe *remote.TransPipeline + ext trans.Extension + funcPool sync.Pool + conns sync.Map + tasks sync.WaitGroup } // Write implements the remote.ServerTransHandler interface. @@ -102,11 +103,15 @@ func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remo rpcinfo.Record(ctx, ri, stats.WriteFinish, nil) }() - if methodInfo, _ := trans.GetMethodInfo(ri, t.svcInfo); methodInfo != nil { - if methodInfo.OneWay() { - return ctx, nil + svcInfo := sendMsg.ServiceInfo() + if svcInfo != nil { + if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil { + if methodInfo.OneWay() { + return ctx, nil + } } } + wbuf := netpoll.NewLinkBuffer() bufWriter := np.NewWriterByteBuffer(wbuf) err = t.codec.Encode(ctx, sendMsg, bufWriter) @@ -234,7 +239,7 @@ func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, rea }() // read - recvMsg = remote.NewMessageWithNewer(t.svcInfo, rpcInfo, remote.Call, remote.Server) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) bufReader := np.NewReaderByteBuffer(reader) err = t.readWithByteBuffer(ctx, bufReader, recvMsg) if err != nil { @@ -246,19 +251,21 @@ func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, rea return } + svcInfo := recvMsg.ServiceInfo() + t.targetSvcInfo = svcInfo if recvMsg.MessageType() == remote.Heartbeat { - sendMsg = remote.NewMessage(nil, t.svcInfo, rpcInfo, remote.Heartbeat, remote.Server) + sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Heartbeat, remote.Server) } else { var methodInfo serviceinfo.MethodInfo - if methodInfo, err = trans.GetMethodInfo(rpcInfo, t.svcInfo); err != nil { + if methodInfo, err = trans.GetMethodInfo(rpcInfo, svcInfo); err != nil { closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, true) t.OnError(ctx, err, muxSvrConn) return } if methodInfo.OneWay() { - sendMsg = remote.NewMessage(nil, t.svcInfo, rpcInfo, remote.Reply, remote.Server) + sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Reply, remote.Server) } else { - sendMsg = remote.NewMessage(methodInfo.NewResult(), t.svcInfo, rpcInfo, remote.Reply, remote.Server) + sendMsg = remote.NewMessage(methodInfo.NewResult(), svcInfo, rpcInfo, remote.Reply, remote.Server) } ctx, err = t.transPipe.OnMessage(ctx, recvMsg, sendMsg) @@ -322,7 +329,8 @@ func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { iv.SetSeqID(0) ri := rpcinfo.NewRPCInfo(nil, nil, iv, nil, nil) data := NewControlFrame() - msg := remote.NewMessage(data, t.svcInfo, ri, remote.Reply, remote.Server) + svcInfo := t.getSvcInfo() + msg := remote.NewMessage(data, svcInfo, ri, remote.Reply, remote.Server) msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset] = "1" @@ -405,16 +413,19 @@ func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { func (t *svrTransHandler) writeErrorReplyIfNeeded( ctx context.Context, recvMsg remote.Message, conn net.Conn, ri rpcinfo.RPCInfo, err error, doOnMessage bool, ) (shouldCloseConn bool) { - if methodInfo, _ := trans.GetMethodInfo(ri, t.svcInfo); methodInfo != nil { - if methodInfo.OneWay() { - return + svcInfo := recvMsg.ServiceInfo() + if svcInfo != nil { + if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil { + if methodInfo.OneWay() { + return + } } } transErr, isTransErr := err.(*remote.TransError) if !isTransErr { return } - errMsg := remote.NewMessage(transErr, t.svcInfo, ri, remote.Exception, remote.Server) + errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server) remote.FillSendMsgFromRecvMsg(recvMsg, errMsg) if doOnMessage { // if error happen before normal OnMessage, exec it to transfer header trans info into rpcinfo @@ -468,6 +479,17 @@ func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, rpcStats.SetLevel(sl) } +// getSvcInfo is used to get one ServiceInfo +func (t *svrTransHandler) getSvcInfo() *serviceinfo.ServiceInfo { + if t.targetSvcInfo != nil { + return t.targetSvcInfo + } + for _, svcInfo := range t.svcSearchMap { + return svcInfo + } + return nil +} + func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) { rAddr := conn.RemoteAddr() if ri == nil { diff --git a/pkg/remote/trans/netpollmux/server_handler_test.go b/pkg/remote/trans/netpollmux/server_handler_test.go index c4d6893194..812b1747ad 100644 --- a/pkg/remote/trans/netpollmux/server_handler_test.go +++ b/pkg/remote/trans/netpollmux/server_handler_test.go @@ -42,6 +42,18 @@ var ( addrStr = "test addr" addr = utils.NewNetAddr("tcp", addrStr) method = "mock" + + svcInfo = mocks.ServiceInfo() + svcSearchMap = map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } ) func newTestRpcInfo() rpcinfo.RPCInfo { @@ -61,8 +73,6 @@ func newTestRpcInfo() rpcinfo.RPCInfo { func init() { body := "hello world" rpcInfo := newTestRpcInfo() - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt = &remote.ServerOption{ InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { @@ -77,10 +87,12 @@ func init() { DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { in.Skip(3 * codec.Size32) _, err := in.ReadString(len(body)) + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) return err }, }, - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, } @@ -146,6 +158,13 @@ func TestMuxSvrWrite(t *testing.T) { RPCInfoFunc: func() rpcinfo.RPCInfo { return rpcInfo }, + ServiceInfoFunc: func() *serviceinfo.ServiceInfo { + return &serviceinfo.ServiceInfo{ + Methods: map[string]serviceinfo.MethodInfo{ + "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), + }, + } + }, } // 2. test @@ -444,8 +463,6 @@ func TestInvokeError(t *testing.T) { } body := "hello world" - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() opt := &remote.ServerOption{ InitOrResetRPCInfoFunc: func(rpcInfo rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { fromInfo := rpcinfo.EmptyEndpointInfo() @@ -467,10 +484,12 @@ func TestInvokeError(t *testing.T) { DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { in.Skip(3 * codec.Size32) _, err := in.ReadString(len(body)) + msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) return err }, }, - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, } @@ -628,7 +647,7 @@ func TestInvokeNoMethod(t *testing.T) { pl := remote.NewTransPipeline(svrTransHdlr) svrTransHdlr.SetPipeline(pl) - svcInfo := opt.SvcMap[mocks.MockServiceName] + svcInfo = opt.TargetSvcInfo delete(svcInfo.Methods, method) if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { @@ -676,8 +695,6 @@ func TestMuxSvrOnReadHeartbeat(t *testing.T) { ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) // use newOpt cause we need to add heartbeat logic to EncodeFunc and DecodeFunc - svcMap := map[string]*serviceinfo.ServiceInfo{} - svcMap[mocks.MockServiceName] = mocks.ServiceInfo() newOpt := &remote.ServerOption{ InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { return rpcInfo @@ -704,7 +721,8 @@ func TestMuxSvrOnReadHeartbeat(t *testing.T) { return err }, }, - SvcMap: svcMap, + SvcSearchMap: svcSearchMap, + TargetSvcInfo: svcInfo, TracerCtl: &rpcinfo.TraceController{}, ReadWriteTimeout: rwTimeout, } diff --git a/pkg/remote/trans/nphttp2/mocks_test.go b/pkg/remote/trans/nphttp2/mocks_test.go index 9c231fbbab..f0aeaf3bf2 100644 --- a/pkg/remote/trans/nphttp2/mocks_test.go +++ b/pkg/remote/trans/nphttp2/mocks_test.go @@ -296,7 +296,7 @@ func newMockConnOption() remote.ConnOption { func newMockServerOption() *remote.ServerOption { return &remote.ServerOption{ - SvcMap: nil, + SvcSearchMap: nil, TransServerFactory: nil, SvrHandlerFactory: nil, Codec: nil, @@ -409,6 +409,7 @@ var _ remote.Message = &mockMessage{} type mockMessage struct { RPCInfoFunc func() rpcinfo.RPCInfo ServiceInfoFunc func() *serviceinfo.ServiceInfo + SetServiceInfoFunc func(svcName, methodName string) (*serviceinfo.ServiceInfo, error) DataFunc func() interface{} NewDataFunc func(method string) (ok bool) MessageTypeFunc func() remote.MessageType @@ -439,6 +440,13 @@ func (m *mockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) { return } +func (m *mockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) { + if m.SetServiceInfoFunc != nil { + return m.SetServiceInfoFunc(svcName, methodName) + } + return nil, nil +} + func (m *mockMessage) Data() interface{} { if m.DataFunc != nil { return m.DataFunc() diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index 1d70ad3a11..982f7f26ef 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -59,19 +59,19 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { return &svrTransHandler{ - opt: opt, - svcMap: opt.SvcMap, - codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), + opt: opt, + svcSearchMap: opt.SvcSearchMap, + codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), }, nil } var _ remote.ServerTransHandler = &svrTransHandler{} type svrTransHandler struct { - opt *remote.ServerOption - svcMap map[string]*serviceinfo.ServiceInfo - inkHdlFunc endpoint.Endpoint - codec remote.Codec + opt *remote.ServerOption + svcSearchMap map[string]*serviceinfo.ServiceInfo + inkHdlFunc endpoint.Endpoint + codec remote.Codec } var prefaceReadAtMost = func() int { @@ -193,7 +193,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { // set send grpc compressor at server to encode reply pack remote.SetSendCompressor(ri, s.SendCompress()) - svcInfo := t.svcMap[serviceName] + svcInfo := t.svcSearchMap[remote.BuildMultiServiceKey(serviceName, methodName)] var methodInfo serviceinfo.MethodInfo if svcInfo != nil { methodInfo = svcInfo.MethodInfo(methodName) diff --git a/pkg/remote/trans_errors.go b/pkg/remote/trans_errors.go index 2b226a49f9..6adcf992ce 100644 --- a/pkg/remote/trans_errors.go +++ b/pkg/remote/trans_errors.go @@ -38,6 +38,7 @@ const ( UnsupportedClientType = 10 // kitex's own type id from number 20 UnknownService = 20 + NoServiceName = 21 ) var defaultTransErrorMessage = map[int32]string{ diff --git a/pkg/remote/transmeta/metakey.go b/pkg/remote/transmeta/metakey.go index 24799d0465..a2ad43ce3d 100644 --- a/pkg/remote/transmeta/metakey.go +++ b/pkg/remote/transmeta/metakey.go @@ -51,6 +51,7 @@ const ( // key of header transport const ( + HeaderIDLServiceName = "isn" HeaderTransRemoteAddr = "rip" HeaderTransToCluster = "tc" HeaderTransToIDC = "ti" diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index 832b7b2e01..801ba3f025 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -76,6 +76,7 @@ func (ch *clientTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa } transInfo.PutTransIntInfo(hd) + transInfo.PutTransStrInfo(map[string]string{transmeta.HeaderIDLServiceName: ri.Invocation().ServiceName()}) return ctx, nil } diff --git a/pkg/transmeta/ttheader_test.go b/pkg/transmeta/ttheader_test.go index 24b26c5782..c01963e4dd 100644 --- a/pkg/transmeta/ttheader_test.go +++ b/pkg/transmeta/ttheader_test.go @@ -65,6 +65,8 @@ func TestTTHeaderClientWriteMetainfo(t *testing.T) { kvs := msg.TransInfo().TransIntInfo() test.Assert(t, err == nil) test.Assert(t, len(kvs) == 0) + strKvs := msg.TransInfo().TransStrInfo() + test.Assert(t, len(strKvs) == 0) // ttheader msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) @@ -79,6 +81,9 @@ func TestTTHeaderClientWriteMetainfo(t *testing.T) { test.Assert(t, kvs[transmeta.MsgType] == strconv.Itoa(int(remote.Call))) test.Assert(t, kvs[transmeta.TransportType] == unframedTransportType) test.Assert(t, kvs[transmeta.RPCTimeout] == "100") + strKvs = msg.TransInfo().TransStrInfo() + test.Assert(t, len(strKvs) == 1) + test.Assert(t, strKvs[transmeta.HeaderIDLServiceName] == "") } func TestTTHeaderServerReadMetainfo(t *testing.T) { diff --git a/server/invoke.go b/server/invoke.go index ea8563b0e6..8fada9d0c9 100644 --- a/server/invoke.go +++ b/server/invoke.go @@ -35,7 +35,7 @@ type InvokeCaller interface { // Invoker is the abstraction for invoker. type Invoker interface { - RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) error + RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error Init() (err error) InvokeCaller } diff --git a/server/option.go b/server/option.go index a6341a6b22..484fa64eb2 100644 --- a/server/option.go +++ b/server/option.go @@ -356,3 +356,12 @@ func WithContextBackup(enable, async bool) Option { o.BackupOpt.EnableImplicitlyTransmitAsync = async }} } + +// WithRefuseTrafficWithoutServiceName returns an Option that only accepts traffics with service name. +// This is used for a server with multi services and is one of the options to avoid a server startup error +// when having conflicting method names between services without specifying a fallback service for the method. +func WithRefuseTrafficWithoutServiceName() Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.RefuseTrafficWithoutServiceName = true + }} +} diff --git a/server/option_advanced_test.go b/server/option_advanced_test.go index c18d25ec43..565e21352a 100644 --- a/server/option_advanced_test.go +++ b/server/option_advanced_test.go @@ -245,7 +245,7 @@ func TestWithSupportedTransportsFunc(t *testing.T) { svcInfo := mocks.ServiceInfo() svr.RegisterService(svcInfo, new(mockImpl)) svr.(*server).fillMoreServiceInfo(nil) - test.Assert(t, reflect.DeepEqual(svr.GetServiceInfos()[svcInfo.ServiceName].Extra["transports"], tcase.wantTransports)) + test.Assert(t, reflect.DeepEqual(svr.GetServiceInfos()[remote.BuildMultiServiceKey(svcInfo.ServiceName, mocks.MockMethod)].Extra["transports"], tcase.wantTransports)) } } diff --git a/server/option_test.go b/server/option_test.go index 946e6f6ec1..5ccb28b16c 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -35,6 +35,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/utils" ) @@ -440,7 +441,18 @@ func TestWithProfilerMessageTagging(t *testing.T) { to := rpcinfo.NewEndpointInfo("callee", "method", nil, nil) ri := rpcinfo.NewRPCInfo(from, to, nil, nil, nil) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - msg := remote.NewMessageWithNewer(mocks.ServiceInfo(), ri, remote.Call, remote.Server) + svcInfo := mocks.ServiceInfo() + svcSearchMap := map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) newCtx, tags := iSvr.opt.RemoteOpt.ProfilerMessageTagging(ctx, msg) test.Assert(t, len(tags) == 8) @@ -448,3 +460,17 @@ func TestWithProfilerMessageTagging(t *testing.T) { test.Assert(t, newCtx.Value("ctx1").(int) == 1) test.Assert(t, newCtx.Value("ctx2").(int) == 2) } + +func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) { + svr := NewServer(WithRefuseTrafficWithoutServiceName()) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err = svr.Run() + test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RefuseTrafficWithoutServiceName) +} diff --git a/server/register_option.go b/server/register_option.go new file mode 100644 index 0000000000..92e327e461 --- /dev/null +++ b/server/register_option.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + internal_server "github.com/cloudwego/kitex/internal/server" +) + +// RegisterOption is the only way to config service registration. +type RegisterOption = internal_server.RegisterOption + +// RegisterOptions is used to config service registration. +type RegisterOptions = internal_server.RegisterOptions + +func WithFallbackService() RegisterOption { + return RegisterOption{F: func(o *internal_server.RegisterOptions) { + o.IsFallbackService = true + }} +} diff --git a/server/register_option_test.go b/server/register_option_test.go new file mode 100644 index 0000000000..0e6e4d9587 --- /dev/null +++ b/server/register_option_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "testing" + + internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/internal/test" +) + +func TestWithFallbackService(t *testing.T) { + opts := []RegisterOption{WithFallbackService()} + registerOpts := internal_server.NewRegisterOptions(opts) + test.Assert(t, registerOpts.IsFallbackService) +} diff --git a/server/server.go b/server/server.go index fd7a9c0b04..382d7ef986 100644 --- a/server/server.go +++ b/server/server.go @@ -47,18 +47,19 @@ import ( "github.com/cloudwego/kitex/pkg/stats" ) -// Server is a abstraction of a RPC server. It accepts connections and dispatches them to the service +// Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service // registered to it. type Server interface { - RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) error + RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error GetServiceInfos() map[string]*serviceinfo.ServiceInfo Run() error Stop() error } type server struct { - opt *internal_server.Options - svcs *services + opt *internal_server.Options + svcs *services + targetSvcInfo *serviceinfo.ServiceInfo // actual rpc service implement of biz eps endpoint.Endpoint @@ -183,7 +184,7 @@ func (s *server) buildInvokeChain() { } // RegisterService should not be called by users directly. -func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) error { +func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error { s.Lock() defer s.Unlock() if s.isRun { @@ -195,16 +196,19 @@ func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler inter if handler == nil || reflect.ValueOf(handler).IsNil() { panic("handler is nil. please specify non-nil handler") } - if s.svcs.getService(svcInfo.ServiceName) != nil { + if s.svcs.svcMap[svcInfo.ServiceName] != nil { panic(fmt.Sprintf("Service[%s] is already defined", svcInfo.ServiceName)) } - s.svcs.addService(svcInfo, handler) + registerOpts := internal_server.NewRegisterOptions(opts) + if err := s.svcs.addService(svcInfo, handler, registerOpts); err != nil { + panic(err.Error()) + } return nil } func (s *server) GetServiceInfos() map[string]*serviceinfo.ServiceInfo { - return s.svcs.getSvcInfoMap() + return s.svcs.getSvcInfoSearchMap() } // Run runs the server. @@ -215,7 +219,11 @@ func (s *server) Run() (err error) { if err = s.check(); err != nil { return err } + s.findAndSetDefaultService() diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.ServiceInfosKey, diagnosis.WrapAsProbeFunc(s.svcs.getSvcInfoMap())) + if s.svcs.fallbackSvc != nil { + diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.FallbackServiceKey, diagnosis.WrapAsProbeFunc(s.svcs.fallbackSvc.svcInfo.ServiceName)) + } svrCfg := s.opt.RemoteOpt addr := svrCfg.Address // should not be nil if s.opt.Proxy != nil { @@ -306,10 +314,7 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { ri := rpcinfo.GetRPCInfo(ctx) methodName := ri.Invocation().MethodName() serviceName := ri.Invocation().ServiceName() - svc := s.svcs.getService(serviceName) - if svc == nil { - svc = s.svcs.defaultSvc - } + svc := s.svcs.svcMap[serviceName] svcInfo := svc.svcInfo if methodName == "" && svcInfo.ServiceName != serviceinfo.GenericService { return errors.New("method name is empty in rpcinfo, should not happen") @@ -348,7 +353,9 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { func (s *server) initBasicRemoteOption() { remoteOpt := s.opt.RemoteOpt - remoteOpt.SvcMap = s.svcs.getSvcInfoMap() + remoteOpt.TargetSvcInfo = s.targetSvcInfo + remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap() + remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc() remoteOpt.TracerCtl = s.opt.TracerCtl remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout() @@ -436,7 +443,7 @@ func (s *server) check() error { if len(s.svcs.svcMap) == 0 { return errors.New("run: no service. Use RegisterService to set one") } - return nil + return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName) } func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) { @@ -502,7 +509,7 @@ func (s *server) buildRegistryInfo(lAddr net.Addr) { info.ServiceName = s.opt.Svr.ServiceName } if info.PayloadCodec == "" { - info.PayloadCodec = s.svcs.defaultSvc.svcInfo.PayloadCodec.String() + info.PayloadCodec = getDefaultSvcInfo(s.svcs).PayloadCodec.String() } if info.Weight == 0 { info.Weight = discovery.DefaultWeight @@ -548,3 +555,32 @@ func (s *server) waitExit(errCh chan error) error { } } } + +func (s *server) findAndSetDefaultService() { + if len(s.svcs.svcMap) == 1 { + s.targetSvcInfo = getDefaultSvcInfo(s.svcs) + } +} + +// getDefaultSvc is used to get one ServiceInfo from map +func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo { + if len(svcs.svcMap) > 1 && svcs.fallbackSvc != nil { + return svcs.fallbackSvc.svcInfo + } + for _, svc := range svcs.svcMap { + return svc.svcInfo + } + return nil +} + +func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error { + if refuseTrafficWithoutServiceName { + return nil + } + for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap { + if !hasFallbackSvc { + return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name) + } + } + return nil +} diff --git a/server/server_test.go b/server/server_test.go index eadd9a5951..f932054cbc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -53,6 +53,20 @@ import ( "github.com/cloudwego/kitex/transport" ) +var ( + svcInfo = mocks.ServiceInfo() + svcSearchMap = map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, + remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, + mocks.MockMethod: svcInfo, + mocks.MockExceptionMethod: svcInfo, + mocks.MockErrorMethod: svcInfo, + mocks.MockOnewayMethod: svcInfo, + } +) + func TestServerRun(t *testing.T) { var opts []Option opts = append(opts, WithMetaHandler(noopMetahandler{})) @@ -430,8 +444,8 @@ func TestGRPCServerMultipleServices(t *testing.T) { test.Assert(t, err == nil) err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler()) test.Assert(t, err == nil) - test.DeepEqual(t, svr.GetServiceInfos()[mocks.ServiceInfo().ServiceName], mocks.ServiceInfo()) - test.DeepEqual(t, svr.GetServiceInfos()[mocks.Service2Info().ServiceName], mocks.Service2Info()) + test.DeepEqual(t, svr.GetServiceInfos()[mocks.MockMethod], mocks.ServiceInfo()) + test.DeepEqual(t, svr.GetServiceInfos()[mocks.Mock2Method], mocks.Service2Info()) time.AfterFunc(1000*time.Millisecond, func() { err := svr.Stop() test.Assert(t, err == nil, err) @@ -635,7 +649,6 @@ func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { opts = append(opts, WithCodec(&mockCodec{})) transHdlrFact := &mockSvrTransHandlerFactory{} - svcInfo := mocks.ServiceInfo() exitCh := make(chan bool) var ln net.Listener transSvr := &mocks.MockTransServer{ @@ -643,7 +656,7 @@ func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -723,7 +736,6 @@ func TestInvokeHandlerExec(t *testing.T) { })) opts = append(opts, WithCodec(&mockCodec{})) transHdlrFact := &mockSvrTransHandlerFactory{} - svcInfo := mocks.ServiceInfo() exitCh := make(chan bool) var ln net.Listener transSvr := &mocks.MockTransServer{ @@ -731,7 +743,7 @@ func TestInvokeHandlerExec(t *testing.T) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -786,7 +798,6 @@ func TestInvokeHandlerPanic(t *testing.T) { })) opts = append(opts, WithCodec(&mockCodec{})) transHdlrFact := &mockSvrTransHandlerFactory{} - svcInfo := mocks.ServiceInfo() exitCh := make(chan bool) var ln net.Listener transSvr := &mocks.MockTransServer{ @@ -795,7 +806,7 @@ func TestInvokeHandlerPanic(t *testing.T) { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -836,6 +847,83 @@ func TestInvokeHandlerPanic(t *testing.T) { test.Assert(t, serviceHandler) } +func TestRegisterService(t *testing.T) { + svr := NewServer() + time.AfterFunc(time.Second, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + + svr.Run() + + test.PanicAt(t, func() { + _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + }, func(err interface{}) bool { + if errMsg, ok := err.(string); ok { + return strings.Contains(errMsg, "server is running") + } + return true + }) + svr.Stop() + + svr = NewServer() + time.AfterFunc(time.Second, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + + test.PanicAt(t, func() { + _ = svr.RegisterService(nil, mocks.MyServiceHandler()) + }, func(err interface{}) bool { + if errMsg, ok := err.(string); ok { + return strings.Contains(errMsg, "svcInfo is nil") + } + return true + }) + + test.PanicAt(t, func() { + _ = svr.RegisterService(mocks.ServiceInfo(), nil) + }, func(err interface{}) bool { + if errMsg, ok := err.(string); ok { + return strings.Contains(errMsg, "handler is nil") + } + return true + }) + + test.PanicAt(t, func() { + _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService()) + _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + }, func(err interface{}) bool { + if errMsg, ok := err.(string); ok { + return strings.Contains(errMsg, "Service[MockService] is already defined") + } + return true + }) + + test.PanicAt(t, func() { + _ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService()) + }, func(err interface{}) bool { + if errMsg, ok := err.(string); ok { + return strings.Contains(errMsg, "multiple fallback services cannot be registered") + } + return true + }) + svr.Stop() + + svr = NewServer() + time.AfterFunc(time.Second, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + + _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + _ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler()) + err := svr.Run() + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified") + svr.Stop() +} + type noopMetahandler struct{} func (noopMetahandler) WriteMeta(ctx context.Context, msg remote.Message) (context.Context, error) { @@ -907,7 +995,7 @@ func (m *mockCodec) Decode(ctx context.Context, msg remote.Message, in remote.By func TestDuplicatedRegisterInfoPanic(t *testing.T) { svcs := newServices() - svcs.addService(mocks.ServiceInfo(), nil) + svcs.addService(mocks.ServiceInfo(), nil, &RegisterOptions{}) s := &server{ opt: internal_server.NewOptions(nil), svcs: svcs, diff --git a/server/service.go b/server/service.go index 0624477d8d..4d679917a2 100644 --- a/server/service.go +++ b/server/service.go @@ -16,7 +16,13 @@ package server -import "github.com/cloudwego/kitex/pkg/serviceinfo" +import ( + "errors" + "fmt" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) type service struct { svcInfo *serviceinfo.ServiceInfo @@ -28,24 +34,88 @@ func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) *service } type services struct { - svcMap map[string]*service - defaultSvc *service + svcSearchMap map[string]*service // key: "svcName.methodName" and "methodName", value: svcInfo + svcMap map[string]*service // key: service name, value: svcInfo + conflictingMethodHasFallbackSvcMap map[string]bool + fallbackSvc *service } func newServices() *services { - return &services{svcMap: map[string]*service{}} + return &services{ + svcSearchMap: map[string]*service{}, + svcMap: map[string]*service{}, + conflictingMethodHasFallbackSvcMap: map[string]bool{}, + } } -func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}) { +func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, registerOpts *RegisterOptions) error { svc := newService(svcInfo, handler) - if s.defaultSvc == nil { - s.defaultSvc = svc + + if err := s.checkCombineServiceWithOtherService(svcInfo); err != nil { + return err } + + if err := s.checkMultipleFallbackService(registerOpts, svc); err != nil { + return err + } + s.svcMap[svcInfo.ServiceName] = svc + s.createSearchMap(svcInfo, svc, registerOpts) + return nil +} + +// when registering combine service, it does not allow the registration of other services +func (s *services) checkCombineServiceWithOtherService(svcInfo *serviceinfo.ServiceInfo) error { + if len(s.svcMap) > 0 { + if _, ok := s.svcMap["CombineService"]; ok || svcInfo.ServiceName == "CombineService" { + return errors.New("only one service can be registered when registering combine service") + } + } + return nil +} + +func (s *services) checkMultipleFallbackService(registerOpts *RegisterOptions, svc *service) error { + if registerOpts.IsFallbackService { + if s.fallbackSvc != nil { + return fmt.Errorf("multiple fallback services cannot be registered. [%s] is already registered as a fallback service", s.fallbackSvc.svcInfo.ServiceName) + } + s.fallbackSvc = svc + } + return nil } -func (s *services) getService(svcName string) *service { - return s.svcMap[svcName] +func (s *services) createSearchMap(svcInfo *serviceinfo.ServiceInfo, svc *service, registerOpts *RegisterOptions) { + for methodName := range svcInfo.Methods { + s.svcSearchMap[remote.BuildMultiServiceKey(svcInfo.ServiceName, methodName)] = svc + if svcFromMap, ok := s.svcSearchMap[methodName]; ok { + s.handleConflictingMethod(svcFromMap, svc, methodName, registerOpts) + } else { + s.svcSearchMap[methodName] = svc + } + } +} + +func (s *services) handleConflictingMethod(svcFromMap, svc *service, methodName string, registerOpts *RegisterOptions) { + s.registerConflictingMethodHasFallbackSvcMap(svcFromMap, methodName) + s.updateWithFallbackSvc(registerOpts, svc, methodName) +} + +func (s *services) registerConflictingMethodHasFallbackSvcMap(svcFromMap *service, methodName string) { + if _, ok := s.conflictingMethodHasFallbackSvcMap[methodName]; !ok { + if s.fallbackSvc != nil && svcFromMap.svcInfo.ServiceName == s.fallbackSvc.svcInfo.ServiceName { + // svc which is already registered is a fallback service + s.conflictingMethodHasFallbackSvcMap[methodName] = true + } else { + s.conflictingMethodHasFallbackSvcMap[methodName] = false + } + } +} + +func (s *services) updateWithFallbackSvc(registerOpts *RegisterOptions, svc *service, methodName string) { + if registerOpts.IsFallbackService { + s.svcSearchMap[methodName] = svc + s.conflictingMethodHasFallbackSvcMap[methodName] = true + } } func (s *services) getSvcInfoMap() map[string]*serviceinfo.ServiceInfo { @@ -55,3 +125,11 @@ func (s *services) getSvcInfoMap() map[string]*serviceinfo.ServiceInfo { } return svcInfoMap } + +func (s *services) getSvcInfoSearchMap() map[string]*serviceinfo.ServiceInfo { + svcInfoSearchMap := map[string]*serviceinfo.ServiceInfo{} + for name, svc := range s.svcSearchMap { + svcInfoSearchMap[name] = svc.svcInfo + } + return svcInfoSearchMap +} diff --git a/server/service_test.go b/server/service_test.go new file mode 100644 index 0000000000..5d800c499c --- /dev/null +++ b/server/service_test.go @@ -0,0 +1,88 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "fmt" + "testing" + + "github.com/cloudwego/kitex/internal/mocks" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func TestAddService(t *testing.T) { + svcs := newServices() + err := svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{}) + test.Assert(t, err == nil) + test.Assert(t, len(svcs.svcMap) == 1) + fmt.Println(svcs.svcSearchMap) + test.Assert(t, len(svcs.svcSearchMap) == 10) + test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 0) + test.Assert(t, svcs.fallbackSvc == nil) + + err = svcs.addService(mocks.Service3Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) + test.Assert(t, err == nil) + test.Assert(t, len(svcs.svcMap) == 2) + test.Assert(t, len(svcs.svcSearchMap) == 11) + test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 1) + test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap["mock"]) + + err = svcs.addService(mocks.Service2Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true}) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService3] is already registered as a fallback service") +} + +func TestCheckCombineServiceWithOtherService(t *testing.T) { + svcs := newServices() + combineSvcInfo := &serviceinfo.ServiceInfo{ServiceName: "CombineService"} + svcs.svcMap[combineSvcInfo.ServiceName] = newService(combineSvcInfo, nil) + err := svcs.checkCombineServiceWithOtherService(mocks.ServiceInfo()) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "only one service can be registered when registering combine service") + + svcs = newServices() + svcs.svcMap[mocks.MockServiceName] = newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + err = svcs.checkCombineServiceWithOtherService(combineSvcInfo) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "only one service can be registered when registering combine service") +} + +func TestCheckMultipleFallbackService(t *testing.T) { + svcs := newServices() + svc := newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + registerOpts := &RegisterOptions{IsFallbackService: true} + err := svcs.checkMultipleFallbackService(registerOpts, svc) + test.Assert(t, err == nil) + test.Assert(t, svcs.fallbackSvc == svc) + + err = svcs.checkMultipleFallbackService(registerOpts, newService(mocks.Service2Info(), nil)) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService] is already registered as a fallback service", err) +} + +func TestRegisterConflictingMethodHasFallbackSvcMap(t *testing.T) { + svcs := newServices() + svcFromMap := newService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod) + test.Assert(t, !svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod]) + + svcs = newServices() + svcs.fallbackSvc = svcFromMap + svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod) + test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod]) +} diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 43f88dcac8..367ac997e4 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -50,6 +50,7 @@ type PackageInfo struct { Module string Protocol transport.Protocol IDLName string + ServerPkg string } // AddImport . diff --git a/tool/internal_pkg/tpl/client.go b/tool/internal_pkg/tpl/client.go index 96d4674554..666e966898 100644 --- a/tool/internal_pkg/tpl/client.go +++ b/tool/internal_pkg/tpl/client.go @@ -80,6 +80,7 @@ type {{.ServiceName}}_{{.RawName}}Client interface { func NewClient(destService string, opts ...client.Option) (Client, error) { var options []client.Option options = append(options, client.WithDestService(destService)) + {{template "@client.go-NewClient-option" .}} {{if and (eq $.Codec "protobuf") .HasStreaming}}{{/* Thrift Streaming only in StreamClient */}} options = append(options, client.WithTransportProtocol(transport.GRPC)) diff --git a/tool/internal_pkg/tpl/server.go b/tool/internal_pkg/tpl/server.go index 73d0ef3449..d292a97932 100644 --- a/tool/internal_pkg/tpl/server.go +++ b/tool/internal_pkg/tpl/server.go @@ -49,4 +49,8 @@ func NewServer(handler {{call .ServiceTypeName}}, opts ...server.Option) server. return svr } {{template "@server.go-EOF" .}} + +func RegisterService(svr server.Server, handler {{call .ServiceTypeName}}, opts ...server.RegisterOption) error { + return svr.RegisterService(serviceInfo(), handler, opts...) +} `