Skip to content

Commit

Permalink
chore: no customized process function before validate
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Aug 19, 2024
1 parent b1f40bc commit a7d2e89
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 103 deletions.
6 changes: 0 additions & 6 deletions pkg/remote/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ type NocopyWrite interface {
MallocAck(n int) error
}

// NocopyRead is to read [][]byte without copying.
// It is used with linked buffer.
type NocopyRead interface {
GetBytesNoCopy() ([][]byte, int, error)
}

// FrameWrite is to write header and data buffer separately to avoid memory copy
type FrameWrite interface {
// WriteHeader set header buffer without copy
Expand Down
24 changes: 0 additions & 24 deletions pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"math/rand"
"strings"
"testing"

Expand Down Expand Up @@ -571,26 +570,3 @@ func TestCornerCase(t *testing.T) {
err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer)
test.Assert(t, err.Error() == "error malloc")
}

func TestFlatten2DSlice(t *testing.T) {
var (
b2 [][]byte
expectedB1 []byte
)
row, column := 10, 10
for i := 0; i < row; i++ {
var b []byte
for j := 0; j < column; j++ {
curr := rand.Int()
b = append(b, byte(curr))
expectedB1 = append(expectedB1, byte(curr))
}
b2 = append(b2, b)
}
length := row * column
actualB1 := flatten2DSlice(b2, length)
test.Assert(t, len(actualB1) == length)
for i := 0; i < length; i++ {
test.Assert(t, actualB1[i] == expectedB1[i])
}
}
38 changes: 7 additions & 31 deletions pkg/remote/codec/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"hash/crc32"
"sync"

"github.com/cloudwego/kitex/pkg/consts"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/stats"
"hash/crc32"
"sync"
)

const (
Expand All @@ -34,10 +35,6 @@ type PayloadValidator interface {
// Validate validates the input payload with the attached checksum.
// Return pass if validation succeed, or return error.
Validate(ctx context.Context, expectedValue string, inboundPayload []byte) (pass bool, err error)

// ProcessBeforeValidate is used to do some preprocess before validate.
// For example, you can extract some value from ttheader and set to the rpcinfo, which may be useful for validation.
ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error)
}

func getValidatorKey(ctx context.Context, p PayloadValidator) string {
Expand Down Expand Up @@ -79,12 +76,8 @@ func payloadChecksumValidate(ctx context.Context, pv PayloadValidator, in remote
rpcinfo.Record(ctx, message.RPCInfo(), stats.ChecksumValidateFinish, err)
}()

// this return ctx can only be used in Validate part since Decode has no return argument for context
ctx = fillRPCInfoBeforeValidate(ctx, message)
// before validate
ctx, err = pv.ProcessBeforeValidate(ctx, message)
if err != nil {
return err
}

// get key and value
key := getValidatorKey(ctx, pv)
Expand Down Expand Up @@ -172,7 +165,7 @@ func (p *crcPayloadValidator) Key(ctx context.Context) string {
}

func (p *crcPayloadValidator) Generate(ctx context.Context, outPayload []byte) (need bool, value string, err error) {
return true, getCRC32C([][]byte{outPayload}), nil
return true, getCRC32C(outPayload), nil
}

func (p *crcPayloadValidator) Validate(ctx context.Context, expectedValue string, inputPayload []byte) (pass bool, err error) {
Expand All @@ -186,10 +179,6 @@ func (p *crcPayloadValidator) Validate(ctx context.Context, expectedValue string
return true, nil
}

func (p *crcPayloadValidator) ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error) {
return ctx, nil
}

// crc32cTable is used for crc32c check
var (
crc32cTable *crc32.Table
Expand All @@ -198,26 +187,13 @@ var (

// getCRC32C calculates the crc32c checksum of the input bytes.
// the checksum will be converted into big-endian format and encoded into hex string.
func getCRC32C(payload [][]byte) string {
func getCRC32C(payload []byte) string {
if crc32cTable == nil {
return ""
}
csb := make([]byte, Size32)
var checksum uint32
for i := 0; i < len(payload); i++ {
checksum = crc32.Update(checksum, crc32cTable, payload[i])
}
checksum = crc32.Update(checksum, crc32cTable, payload)
binary.BigEndian.PutUint32(csb, checksum)
return hex.EncodeToString(csb)
}

// flatten2DSlice converts 2d slice to 1d.
// total length should be provided.
func flatten2DSlice(b2 [][]byte, length int) []byte {
b1 := make([]byte, length)
off := 0
for i := 0; i < len(b2); i++ {
off += copy(b1[off:], b2[i])
}
return b1
}
51 changes: 9 additions & 42 deletions pkg/remote/codec/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,25 @@ package codec
import (
"context"
"errors"
"strconv"
"testing"

"github.com/bytedance/gopkg/util/xxhash3"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/transport"
"strconv"
"testing"
)

var _ PayloadValidator = &mockPayloadValidator{}

type mockPayloadValidator struct {
}
type mockPayloadValidator struct{}

const (
mockGenerateErrorKey = "mockGenerateError"
mockExceedLimitKey = "mockExceedLimit"
mockPreprocessKey = "preprocessSucceedKey"
mockPreprocessFailedKey = "preprocessFailed"
mockGenerateErrorKey = "mockGenerateError"
mockExceedLimitKey = "mockExceedLimit"
)

type preprocessStruct struct {
value string
}

func (m *mockPayloadValidator) Key(ctx context.Context) string {
return "mockValidator"
}
Expand All @@ -52,17 +46,6 @@ func (m *mockPayloadValidator) Validate(ctx context.Context, expectedValue strin
return value == expectedValue, nil
}

func (m *mockPayloadValidator) ProcessBeforeValidate(ctx context.Context, message remote.Message) (context.Context, error) {
if l := ctx.Value(mockPreprocessFailedKey); l != nil {
return ctx, errors.New("mockPreprocessFailed")
}
s := ctx.Value(mockPreprocessKey)
if s != nil {
s.(*preprocessStruct).value = "123"
}
return ctx, nil
}

func TestPayloadValidator(t *testing.T) {
p := &mockPayloadValidator{}
payload := make([]byte, 0)
Expand All @@ -71,14 +54,7 @@ func TestPayloadValidator(t *testing.T) {
test.Assert(t, err == nil, err)
test.Assert(t, need)

ctx := context.Background()
s := &preprocessStruct{value: "1"}
ctx = context.WithValue(ctx, mockPreprocessKey, s)
ctx, err = p.ProcessBeforeValidate(ctx, nil)
test.Assert(t, err == nil, err)
test.Assert(t, s.value == "123")

pass, err := p.Validate(ctx, value, payload)
pass, err := p.Validate(context.Background(), value, payload)
test.Assert(t, err == nil, err)
test.Assert(t, pass, true)
}
Expand All @@ -98,6 +74,7 @@ func TestPayloadChecksumGenerate(t *testing.T) {
test.Assert(t, err == nil, err)
test.Assert(t, len(strInfo) != 0)
test.Assert(t, strInfo[getValidatorKey(ctx, pv)] != "")
message.Recycle()

// failed, generate error
message = initClientSendMsg(transport.TTHeader)
Expand All @@ -106,6 +83,7 @@ func TestPayloadChecksumGenerate(t *testing.T) {
err = payloadChecksumGenerate(ctx, pv, payload, message)
test.Assert(t, err != nil, err)
test.Assert(t, errors.Is(err, kerrors.ErrPayloadValidation))
message.Recycle()

// failed, exceed limit
message = initClientSendMsg(transport.TTHeader)
Expand Down Expand Up @@ -133,11 +111,8 @@ func TestPayloadChecksumValidate(t *testing.T) {
message := initClientRecvMsg()
message.TransInfo().PutTransStrInfo(sendMsg.TransInfo().TransStrInfo()) // put header strinfo
message.SetPayloadLen(len(payload))
s := &preprocessStruct{value: "1"}
ctx = context.WithValue(context.Background(), mockPreprocessKey, s)
err = payloadChecksumValidate(ctx, pv, in, message)
test.Assert(t, err == nil, err)
test.Assert(t, s.value == "123")

// validate failed, checksum validation error
in = remote.NewReaderBuffer(payload)
Expand All @@ -146,12 +121,4 @@ func TestPayloadChecksumValidate(t *testing.T) {
message.SetPayloadLen(len(payload))
err = payloadChecksumValidate(context.Background(), pv, in, message)
test.Assert(t, err != nil)

// validate failed, preprocess error
ctx = context.WithValue(context.Background(), mockPreprocessFailedKey, true)
s = &preprocessStruct{value: "1"}
ctx = context.WithValue(ctx, mockPreprocessKey, s)
err = payloadChecksumValidate(ctx, pv, in, message)
test.Assert(t, err != nil)
test.Assert(t, s.value == "1") // will not be modified
}

0 comments on commit a7d2e89

Please sign in to comment.