diff --git a/internal/server/option.go b/internal/server/option.go index 35aa34dcec..b540f03ce4 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -99,6 +99,7 @@ type Options struct { Streaming stream.StreamingConfig RefuseTrafficWithoutServiceName bool + EnableContextTimeout bool } type Limit struct { diff --git a/server/middlewares.go b/server/middlewares.go new file mode 100644 index 0000000000..a8783b3030 --- /dev/null +++ b/server/middlewares.go @@ -0,0 +1,47 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +func serverTimeoutMW(initCtx context.Context) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request, response interface{}) (err error) { + // Regardless of the underlying protocol, only by checking the RPCTimeout + // For TTHeader, it will be set by transmeta.ServerTTHeaderHandler (not added by default though) + // For GRPC/HTTP2, the timeout deadline is already set in the context, so no need to check it + ri := rpcinfo.GetRPCInfo(ctx) + timeout := ri.Config().RPCTimeout() + if timeout <= 0 { + return next(ctx, request, response) + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer func() { + if err != nil { + cancel() + } + }() + return next(ctx, request, response) + } + } +} diff --git a/server/middlewares_test.go b/server/middlewares_test.go new file mode 100644 index 0000000000..5a8fd0aabf --- /dev/null +++ b/server/middlewares_test.go @@ -0,0 +1,197 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +var _ context.Context = (*mockCtx)(nil) + +type mockCtx struct { + err error + ddl time.Time + hasDDL bool + done chan struct{} + data map[interface{}]interface{} +} + +func (m *mockCtx) Deadline() (deadline time.Time, ok bool) { + return m.ddl, m.hasDDL +} + +func (m *mockCtx) Done() <-chan struct{} { + return m.done +} + +func (m *mockCtx) Err() error { + return m.err +} + +func (m *mockCtx) Value(key interface{}) interface{} { + return m.data[key] +} + +func Test_serverTimeoutMW(t *testing.T) { + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") + from := rpcinfo.NewEndpointInfo("from_service", "from_method", addr, nil) + to := rpcinfo.NewEndpointInfo("to_service", "to_method", nil, nil) + newCtxWithRPCInfo := func(timeout time.Duration) context.Context { + cfg := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(cfg).SetRPCTimeout(timeout) + ri := rpcinfo.NewRPCInfo(from, to, nil, cfg, nil) + return rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + } + timeoutMW := serverTimeoutMW(context.Background()) + + t.Run("no_timeout(fastPath)", func(t *testing.T) { + // prepare + ctx := newCtxWithRPCInfo(0) + + // test + err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { + ddl, ok := ctx.Deadline() + test.Assert(t, !ok) + test.Assert(t, ddl.IsZero()) + return nil + })(ctx, nil, nil) + + // assert + test.Assert(t, err == nil, err) + }) + + t.Run("finish_before_timeout_without_error", func(t *testing.T) { + // prepare + ctx := newCtxWithRPCInfo(time.Millisecond * 50) + waitFinish := make(chan struct{}) + + // test + err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { + go func() { + timer := time.NewTimer(time.Millisecond * 20) + select { + case <-ctx.Done(): + t.Errorf("ctx done, error: %v", ctx.Err()) + case <-timer.C: + t.Logf("(expected) ctx not done") + } + waitFinish <- struct{}{} + }() + return nil + })(ctx, nil, nil) + + // assert + test.Assert(t, err == nil, err) + <-waitFinish + }) + + t.Run("finish_before_timeout_with_error", func(t *testing.T) { + // prepare + ctx := newCtxWithRPCInfo(time.Millisecond * 50) + waitFinish := make(chan struct{}) + + // test + err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { + go func() { + timer := time.NewTimer(time.Millisecond * 20) + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + t.Logf("(expected) cancel called") + } else { + t.Errorf("cancel not called, error: %v", ctx.Err()) + } + case <-timer.C: + t.Error("ctx not done") + } + waitFinish <- struct{}{} + }() + return errors.New("error") + })(ctx, nil, nil) + + // assert + test.Assert(t, err.Error() == "error", err) + <-waitFinish + }) + + t.Run("finish_after_timeout_without_error", func(t *testing.T) { + // prepare + ctx := newCtxWithRPCInfo(time.Millisecond * 20) + waitFinish := make(chan struct{}) + + // test + err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { + go func() { + timer := time.NewTimer(time.Millisecond * 60) + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Logf("(expected) deadline exceeded") + } else { + t.Error("deadline not exceeded, error: ", ctx.Err()) + } + case <-timer.C: + t.Error("ctx not done") + } + waitFinish <- struct{}{} + }() + time.Sleep(time.Millisecond * 40) + return nil + })(ctx, nil, nil) + + // assert + test.Assert(t, err == nil, err) + <-waitFinish + }) + + t.Run("finish_after_timeout_with_error", func(t *testing.T) { + // prepare + ctx := newCtxWithRPCInfo(time.Millisecond * 20) + waitFinish := make(chan struct{}) + + // test + err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { + go func() { + timer := time.NewTimer(time.Millisecond * 60) + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Logf("(expected) deadline exceeded") + } else { + t.Error("deadline not exceeded, error: ", ctx.Err()) + } + case <-timer.C: + t.Error("ctx not done") + } + waitFinish <- struct{}{} + }() + time.Sleep(time.Millisecond * 40) + return errors.New("error") + })(ctx, nil, nil) + + // assert + test.Assert(t, err.Error() == "error", err) + <-waitFinish + }) +} diff --git a/server/option.go b/server/option.go index 484fa64eb2..59e587d7cd 100644 --- a/server/option.go +++ b/server/option.go @@ -365,3 +365,27 @@ func WithRefuseTrafficWithoutServiceName() Option { o.RefuseTrafficWithoutServiceName = true }} } + +// WithEnableContextTimeout enables handler timeout. +// Available since Kitex >= v0.9.0 +// If enabled, a timeout middleware will be added to the beginning of the middleware chain. +// The timeout value will be read from RPCInfo.Config().RPCTimeout(), which can be set by a custom MetaHandler +// NOTE: +// If there's an error (excluding BizStatusError) returned by server handler or any middleware, cancel will be +// called automatically. +// +// For an opensource Kitex user, TTHeader has builtin support of timeout-passing (not enabled by default): +// - Client side: add the following NewClient options for enabling TTHeader and setting the timeout to TTHeader +// client.WithTransportProtocol(transport.TTHeader), +// client.WithMetaHandler(transmeta.ClientTTHeaderHandler), +// - Server side: add the following NewServer options for reading from TTHeader and enable timeout control +// server.WithMetaHandler(transmeta.ServerTTHeaderHandler) +// server.WithEnableContextTimeout(true) +// +// For requests on GRPC transport, a deadline will be added to the context if the header 'grpc-timeout' is positive, +// so there's no need to use this option. +func WithEnableContextTimeout(enable bool) Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.EnableContextTimeout = enable + }} +} diff --git a/server/server.go b/server/server.go index 382d7ef986..f2d6e135e2 100644 --- a/server/server.go +++ b/server/server.go @@ -83,6 +83,10 @@ func NewServer(ops ...Option) Server { func (s *server) init() { ctx := fillContext(s.opt) + if s.opt.EnableContextTimeout { + // prepend for adding timeout to the context for all middlewares and the handler + s.opt.MWBs = append([]endpoint.MiddlewareBuilder{serverTimeoutMW}, s.opt.MWBs...) + } s.mws = richMWsWithBuilder(ctx, s.opt.MWBs, s) s.mws = append(s.mws, acl.NewACLMiddleware(s.opt.ACLRules)) s.initStreamMiddlewares(ctx)