Skip to content

Commit

Permalink
Merge pull request cloudwego#1244 from felix021/feat/server-handler-t…
Browse files Browse the repository at this point in the history
…imeout

[WIP] feat: server handler timeout
  • Loading branch information
felix021 authored Feb 19, 2024
2 parents 34b88b2 + d4c7726 commit 320839b
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 0 deletions.
1 change: 1 addition & 0 deletions internal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ type Options struct {
Streaming stream.StreamingConfig

RefuseTrafficWithoutServiceName bool
EnableContextTimeout bool
}

type Limit struct {
Expand Down
47 changes: 47 additions & 0 deletions server/middlewares.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package server

import (
"context"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/rpcinfo"
)

func serverTimeoutMW(initCtx context.Context) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request, response interface{}) (err error) {
// Regardless of the underlying protocol, only by checking the RPCTimeout
// For TTHeader, it will be set by transmeta.ServerTTHeaderHandler (not added by default though)
// For GRPC/HTTP2, the timeout deadline is already set in the context, so no need to check it
ri := rpcinfo.GetRPCInfo(ctx)
timeout := ri.Config().RPCTimeout()
if timeout <= 0 {
return next(ctx, request, response)
}

ctx, cancel := context.WithTimeout(ctx, timeout)
defer func() {
if err != nil {
cancel()
}
}()
return next(ctx, request, response)
}
}
}
197 changes: 197 additions & 0 deletions server/middlewares_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package server

import (
"context"
"errors"
"net"
"testing"
"time"

"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/rpcinfo"
)

var _ context.Context = (*mockCtx)(nil)

type mockCtx struct {
err error
ddl time.Time
hasDDL bool
done chan struct{}
data map[interface{}]interface{}
}

func (m *mockCtx) Deadline() (deadline time.Time, ok bool) {
return m.ddl, m.hasDDL
}

func (m *mockCtx) Done() <-chan struct{} {
return m.done
}

func (m *mockCtx) Err() error {
return m.err
}

func (m *mockCtx) Value(key interface{}) interface{} {
return m.data[key]
}

func Test_serverTimeoutMW(t *testing.T) {
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
from := rpcinfo.NewEndpointInfo("from_service", "from_method", addr, nil)
to := rpcinfo.NewEndpointInfo("to_service", "to_method", nil, nil)
newCtxWithRPCInfo := func(timeout time.Duration) context.Context {
cfg := rpcinfo.NewRPCConfig()
_ = rpcinfo.AsMutableRPCConfig(cfg).SetRPCTimeout(timeout)
ri := rpcinfo.NewRPCInfo(from, to, nil, cfg, nil)
return rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
}
timeoutMW := serverTimeoutMW(context.Background())

t.Run("no_timeout(fastPath)", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(0)

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
ddl, ok := ctx.Deadline()
test.Assert(t, !ok)
test.Assert(t, ddl.IsZero())
return nil
})(ctx, nil, nil)

// assert
test.Assert(t, err == nil, err)
})

t.Run("finish_before_timeout_without_error", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(time.Millisecond * 50)
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 20)
select {
case <-ctx.Done():
t.Errorf("ctx done, error: %v", ctx.Err())
case <-timer.C:
t.Logf("(expected) ctx not done")
}
waitFinish <- struct{}{}
}()
return nil
})(ctx, nil, nil)

// assert
test.Assert(t, err == nil, err)
<-waitFinish
})

t.Run("finish_before_timeout_with_error", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(time.Millisecond * 50)
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 20)
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
t.Logf("(expected) cancel called")
} else {
t.Errorf("cancel not called, error: %v", ctx.Err())
}
case <-timer.C:
t.Error("ctx not done")
}
waitFinish <- struct{}{}
}()
return errors.New("error")
})(ctx, nil, nil)

// assert
test.Assert(t, err.Error() == "error", err)
<-waitFinish
})

t.Run("finish_after_timeout_without_error", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(time.Millisecond * 20)
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 60)
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Logf("(expected) deadline exceeded")
} else {
t.Error("deadline not exceeded, error: ", ctx.Err())
}
case <-timer.C:
t.Error("ctx not done")
}
waitFinish <- struct{}{}
}()
time.Sleep(time.Millisecond * 40)
return nil
})(ctx, nil, nil)

// assert
test.Assert(t, err == nil, err)
<-waitFinish
})

t.Run("finish_after_timeout_with_error", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(time.Millisecond * 20)
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 60)
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Logf("(expected) deadline exceeded")
} else {
t.Error("deadline not exceeded, error: ", ctx.Err())
}
case <-timer.C:
t.Error("ctx not done")
}
waitFinish <- struct{}{}
}()
time.Sleep(time.Millisecond * 40)
return errors.New("error")
})(ctx, nil, nil)

// assert
test.Assert(t, err.Error() == "error", err)
<-waitFinish
})
}
24 changes: 24 additions & 0 deletions server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,27 @@ func WithRefuseTrafficWithoutServiceName() Option {
o.RefuseTrafficWithoutServiceName = true
}}
}

// WithEnableContextTimeout enables handler timeout.
// Available since Kitex >= v0.9.0
// If enabled, a timeout middleware will be added to the beginning of the middleware chain.
// The timeout value will be read from RPCInfo.Config().RPCTimeout(), which can be set by a custom MetaHandler
// NOTE:
// If there's an error (excluding BizStatusError) returned by server handler or any middleware, cancel will be
// called automatically.
//
// For an opensource Kitex user, TTHeader has builtin support of timeout-passing (not enabled by default):
// - Client side: add the following NewClient options for enabling TTHeader and setting the timeout to TTHeader
// client.WithTransportProtocol(transport.TTHeader),
// client.WithMetaHandler(transmeta.ClientTTHeaderHandler),
// - Server side: add the following NewServer options for reading from TTHeader and enable timeout control
// server.WithMetaHandler(transmeta.ServerTTHeaderHandler)
// server.WithEnableContextTimeout(true)
//
// For requests on GRPC transport, a deadline will be added to the context if the header 'grpc-timeout' is positive,
// so there's no need to use this option.
func WithEnableContextTimeout(enable bool) Option {
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
o.EnableContextTimeout = enable
}}
}
4 changes: 4 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ func NewServer(ops ...Option) Server {

func (s *server) init() {
ctx := fillContext(s.opt)
if s.opt.EnableContextTimeout {
// prepend for adding timeout to the context for all middlewares and the handler
s.opt.MWBs = append([]endpoint.MiddlewareBuilder{serverTimeoutMW}, s.opt.MWBs...)
}
s.mws = richMWsWithBuilder(ctx, s.opt.MWBs, s)
s.mws = append(s.mws, acl.NewACLMiddleware(s.opt.ACLRules))
s.initStreamMiddlewares(ctx)
Expand Down

0 comments on commit 320839b

Please sign in to comment.