Skip to content

Commit

Permalink
feat: add processBeforeValidate and rpc event
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Aug 12, 2024
1 parent ef621e0 commit 39ad24c
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 36 deletions.
1 change: 1 addition & 0 deletions pkg/consts/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package consts
// Method key used in context.
const (
CtxKeyMethod = "K_METHOD"
CtxKeyLogID = "K_LOGID"
)

const (
Expand Down
20 changes: 4 additions & 16 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,8 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if c.payloadValidator != nil {
fillRPCInfoBeforeValidate(message)
if vErr := validate(ctx, message, in, c.payloadValidator); vErr != nil {
return vErr
if pErr := payloadValidate(ctx, c.payloadValidator, in, message); pErr != nil {
return pErr
}
}
} else if isMeshHeader(flagBuf) {
Expand Down Expand Up @@ -301,19 +300,8 @@ func (c *defaultCodec) encodeMetaAndPayloadWithPayloadValidator(ctx context.Cont
return err
}
if c.payloadValidator != nil {
need, value, pErr := c.payloadValidator.Generate(ctx, flatten2DSlice(payload, payloadLen))
if pErr != nil {
return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("generate failed, err=%v", pErr))
}
if need {
if len(value) > maxPayloadChecksumLength {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("payload validator value exceeds the limit, actual length=%d, limit=%d", len(value), maxPayloadChecksumLength))
}
key := getValidatorKey(ctx, c.payloadValidator)
strInfo := message.TransInfo().TransStrInfo()
if strInfo != nil {
strInfo[key] = value
}
if err = payloadChecksumGenerate(ctx, c.payloadValidator, flatten2DSlice(payload, payloadLen), message); err != nil {
return err
}
}
// set payload length before encode TTHeader
Expand Down
70 changes: 62 additions & 8 deletions pkg/remote/codec/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/cloudwego/kitex/pkg/consts"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/stats"
"hash/crc32"
"sync"
)
Expand All @@ -27,8 +29,46 @@ func getValidatorKey(ctx context.Context, p PayloadValidator) string {
return PayloadValidatorPrefix + key
}

func validate(ctx context.Context, message remote.Message, in remote.ByteBuffer, p PayloadValidator) error {
key := getValidatorKey(ctx, p)
func payloadChecksumGenerate(ctx context.Context, pv PayloadValidator, outboundPayload []byte, message remote.Message) (err error) {
rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumGenerateStart, nil)
defer func() {
rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumGenerateFinish, err)
}()

need, value, pErr := pv.Generate(ctx, outboundPayload)
if pErr != nil {
err = kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("generate failed, err=%v", pErr))
return err
}
if need {
if len(value) > maxPayloadChecksumLength {
err = perrors.NewProtocolErrorWithMsg(fmt.Sprintf("payload validator value exceeds the limit, actual length=%d, limit=%d", len(value), maxPayloadChecksumLength))
return err
}
key := getValidatorKey(ctx, pv)
strInfo := message.TransInfo().TransStrInfo()
if strInfo != nil {
strInfo[key] = value
}
}
return nil
}

func payloadValidate(ctx context.Context, pv PayloadValidator, in remote.ByteBuffer, message remote.Message) (err error) {
rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumValidateStart, nil)
defer func() {
rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumValidateFinish, err)
}()

ctx = fillRPCInfoBeforeValidate(ctx, message)
// before validate
ctx, err = pv.ProcessBeforeValidate(ctx, message)
if err != nil {
return err
}

// get key and value
key := getValidatorKey(ctx, pv)
strInfo := message.TransInfo().TransStrInfo()
if strInfo == nil {
return nil
Expand All @@ -39,7 +79,9 @@ func validate(ctx context.Context, message remote.Message, in remote.ByteBuffer,
if err != nil {
return err
}
pass, err := p.Validate(ctx, expectedValue, payload)

// validate
pass, err := pv.Validate(ctx, expectedValue, payload)
if err != nil {
return kerrors.ErrPayloadValidation.WithCause(fmt.Errorf("validation failed, err=%v", err))
}
Expand All @@ -50,23 +92,23 @@ func validate(ctx context.Context, message remote.Message, in remote.ByteBuffer,
}

// fillRPCInfoBeforeValidate reads header and set into the RPCInfo, which allows Validate() to use RPCInfo.
func fillRPCInfoBeforeValidate(message remote.Message) {
func fillRPCInfoBeforeValidate(ctx context.Context, message remote.Message) context.Context {
if message.RPCRole() != remote.Server {
// only fill when server-side reading the request header
// TODO: client-side can read from the response header
return
return ctx
}
ri := message.RPCInfo()
if ri == nil {
return
return ctx
}
transInfo := message.TransInfo()
if transInfo == nil {
return
return ctx
}
intInfo := transInfo.TransIntInfo()
if intInfo == nil {
return
return ctx
}
from := rpcinfo.AsMutableEndpointInfo(ri.From())
if from != nil {
Expand All @@ -87,6 +129,10 @@ func fillRPCInfoBeforeValidate(message remote.Message) {
to.SetServiceName(v)
}
}
if logid := intInfo[transmeta.LogID]; logid != "" {
ctx = context.WithValue(ctx, consts.CtxKeyLogID, logid)
}
return ctx
}

// PayloadValidator is the interface for validating the payload of RPC requests, which allows customized Checksum function.
Expand All @@ -102,6 +148,10 @@ type PayloadValidator interface {
// Validate validates the input payload with the attached checksum.
// Return pass if validation succeed, or return error.
Validate(ctx context.Context, expectedValue string, inboundPayload []byte) (pass bool, err error)

// ProcessBeforeValidate is used to do some preprocess before validate.
// For example, you can extract some value from ttheader and set to the context, which may be useful for validation.
ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error)
}

// NewCRC32PayloadValidator returns a new crcPayloadValidator
Expand Down Expand Up @@ -136,6 +186,10 @@ func (p *crcPayloadValidator) Validate(ctx context.Context, expectedValue string
return true, nil
}

func (p *crcPayloadValidator) ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error) {
return ctx, nil
}

// crc32cTable is used for crc32c check
var (
crc32cTable *crc32.Table
Expand Down
15 changes: 13 additions & 2 deletions pkg/remote/codec/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"github.com/bytedance/gopkg/util/xxhash3"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote"
"strconv"
"testing"
)
Expand All @@ -15,6 +16,7 @@ type mockPayloadValidator struct {

const (
mockExceedLimitKey = "mockExceedLimit"
mockPreprocessKey = "key"
)

func (m *mockPayloadValidator) Key(ctx context.Context) string {
Expand All @@ -38,15 +40,24 @@ func (m *mockPayloadValidator) Validate(ctx context.Context, expectedValue strin
return value == expectedValue, nil
}

func (m *mockPayloadValidator) ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error) {
ctx = context.WithValue(ctx, mockPreprocessKey, true)
return ctx, nil
}

func TestPayloadValidator(t *testing.T) {
p := &crcPayloadValidator{}
p := &mockPayloadValidator{}
payload := make([]byte, 0)

need, value, err := p.Generate(context.Background(), payload)
test.Assert(t, err == nil, err)
test.Assert(t, need)

pass, err := p.Validate(context.Background(), value, payload)
ctx, err := p.ProcessBeforeValidate(context.Background(), nil)
test.Assert(t, err == nil, err)
test.Assert(t, ctx.Value(mockPreprocessKey) == true)

pass, err := p.Validate(ctx, value, payload)
test.Assert(t, err == nil, err)
test.Assert(t, pass, true)
}
28 changes: 18 additions & 10 deletions pkg/stats/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ const (
writeFinish
streamRecv
streamSend
checksumGenerateStart
checksumGenerateFinish
checksumValidateStart
checksumValidateFinish

// NOTE: add new events before this line
predefinedEventNum
Expand All @@ -82,16 +86,20 @@ var (
RPCStart = newEvent(rpcStart, LevelBase)
RPCFinish = newEvent(rpcFinish, LevelBase)

ServerHandleStart = newEvent(serverHandleStart, LevelDetailed)
ServerHandleFinish = newEvent(serverHandleFinish, LevelDetailed)
ClientConnStart = newEvent(clientConnStart, LevelDetailed)
ClientConnFinish = newEvent(clientConnFinish, LevelDetailed)
ReadStart = newEvent(readStart, LevelDetailed)
ReadFinish = newEvent(readFinish, LevelDetailed)
WaitReadStart = newEvent(waitReadStart, LevelDetailed)
WaitReadFinish = newEvent(waitReadFinish, LevelDetailed)
WriteStart = newEvent(writeStart, LevelDetailed)
WriteFinish = newEvent(writeFinish, LevelDetailed)
ServerHandleStart = newEvent(serverHandleStart, LevelDetailed)
ServerHandleFinish = newEvent(serverHandleFinish, LevelDetailed)
ClientConnStart = newEvent(clientConnStart, LevelDetailed)
ClientConnFinish = newEvent(clientConnFinish, LevelDetailed)
ReadStart = newEvent(readStart, LevelDetailed)
ReadFinish = newEvent(readFinish, LevelDetailed)
WaitReadStart = newEvent(waitReadStart, LevelDetailed)
WaitReadFinish = newEvent(waitReadFinish, LevelDetailed)
WriteStart = newEvent(writeStart, LevelDetailed)
WriteFinish = newEvent(writeFinish, LevelDetailed)
ChecksumValidateStart = newEvent(checksumValidateStart, LevelDetailed)
ChecksumValidateFinish = newEvent(checksumValidateFinish, LevelDetailed)
ChecksumGenerateStart = newEvent(checksumGenerateStart, LevelDetailed)
ChecksumGenerateFinish = newEvent(checksumGenerateFinish, LevelDetailed)

// Streaming Events
StreamRecv = newEvent(streamRecv, LevelDetailed)
Expand Down

0 comments on commit 39ad24c

Please sign in to comment.