Skip to content

Commit

Permalink
Revert "feat: support payload check with crc32c when using ttheader" (c…
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh authored Feb 23, 2024
1 parent ad5a278 commit 73984a4
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 291 deletions.
6 changes: 0 additions & 6 deletions pkg/remote/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ type NocopyWrite interface {
WriteDirect(buf []byte, remainCap int) error
}

// NocopyWrittenBytesGetter is used to get the written bytes from the buffer without copy.
type NocopyWrittenBytesGetter interface {
// BytesNocopy is used to get the bytes written with nocopy.
BytesNocopy() (buf []byte, err error)
}

// FrameWrite is to write header and data buffer separately to avoid memory copy
type FrameWrite interface {
// WriteHeader set header buffer without copy
Expand Down
132 changes: 2 additions & 130 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@ package codec
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"hash/crc32"
"sync"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand All @@ -50,10 +45,6 @@ const (

// MagicMask is bit mask for checking version.
MagicMask = 0xffff0000

// PayloadBufferSize is the default buffer size for encoding payload.
// Only used in encode using remote.defaultByteBuffer
PayloadBufferSize = 1024 * 4
)

var (
Expand All @@ -64,12 +55,6 @@ var (
_ remote.MetaDecoder = (*defaultCodec)(nil)
)

// crc32cTable is used for crc32c check
var (
crc32cTable *crc32.Table
crc32TableOnce sync.Once
)

// NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf.
func NewDefaultCodec() remote.Codec {
// No size limit by default
Expand All @@ -86,32 +71,9 @@ func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec {
}
}

// NewDefaultCodecWithConfig creates the default protocol sniffing codec supporting thrift and protobuf with the input config.
func NewDefaultCodecWithConfig(cfg CodecConfig) remote.Codec {
if cfg.CRC32Check {
crc32TableOnce.Do(func() {
crc32cTable = crc32.MakeTable(crc32.Castagnoli)
})
}
return &defaultCodec{
maxSize: cfg.MaxSize,
crc32Check: cfg.CRC32Check,
}
}

// CodecConfig is the config of defaultCodec
type CodecConfig struct {
MaxSize int
CRC32Check bool
}

type defaultCodec struct {
// maxSize limits the max size of the payload
maxSize int
// If crc32Check is true, the codec will validate the payload using crc32c.
// Only effective when transport is TTHeader.
// Payload is all the data after TTHeader.
crc32Check bool
}

// EncodePayload encode payload
Expand Down Expand Up @@ -162,13 +124,10 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message

// EncodeMetaAndPayload encode meta and payload
func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
tp := message.ProtocolInfo().TransProto
if c.crc32Check && tp&transport.TTHeader == transport.TTHeader {
return c.encodeMetaAndPayloadWithCRC32C(ctx, message, out, c)
}

var err error
var totalLenField []byte
tp := message.ProtocolInfo().TransProto

// 1. encode header and return totalLenField if needed
// totalLenField will be filled after payload encoded
if tp&transport.TTHeader == transport.TTHeader {
Expand All @@ -191,50 +150,6 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.
return nil
}

// encodeMetaAndPayloadWithCRC32C encodes payload and meta with crc32c checksum of the payload.
func (c *defaultCodec) encodeMetaAndPayloadWithCRC32C(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
var err error

// 1. encode payload and calculate crc32c checksum
newPayloadOut := remote.NewWriterBuffer(PayloadBufferSize)

if err = me.EncodePayload(ctx, message, newPayloadOut); err != nil {
return err
}
// get the payload from buffer
var payload []byte
if nc, ok := newPayloadOut.(remote.NocopyWrittenBytesGetter); ok {
payload, err = nc.BytesNocopy()
} else {
payload, err = newPayloadOut.Bytes()
}
newPayloadOut.Release(err)
if err != nil {
return err
}
crc32c := getCRC32C(payload)
strInfo := message.TransInfo().TransStrInfo()
if crc32c != "" && strInfo != nil {
strInfo[transmeta.HeaderCRC32C] = crc32c
}
// set payload length before encode TTHeader.
message.SetPayloadLen(len(payload))

// 2. encode header and return totalLenField if needed
// In this case, set total length during TTHeader encode
if _, err = ttHeaderCodec.encode(ctx, message, out); err != nil {
return err
}

// 3. write payload to the buffer after TTHeader
if ncWriter, ok := out.(remote.NocopyWrite); ok {
err = ncWriter.WriteDirect(payload, 0)
} else {
_, err = out.WriteBinary(payload)
}
return err
}

// Encode implements the remote.Codec interface, it does complete message encode include header and payload.
func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) {
return c.EncodeMetaAndPayload(ctx, message, out, c)
Expand All @@ -261,11 +176,6 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
if flagBuf, err = in.Peek(2 * Size32); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if c.crc32Check && crc32cTable != nil {
if err = checkCRC32C(message, in); err != nil {
return err
}
}
} else if isMeshHeader(flagBuf) {
message.Tags()[remote.MeshHeader] = true
// MeshHeader
Expand Down Expand Up @@ -463,41 +373,3 @@ func checkPayloadSize(payloadLen, maxSize int) error {
}
return nil
}

// getCRC32C calculates the crc32c checksum of the input bytes.
// the checksum will be converted into big-endian format and encoded into hex string.
func getCRC32C(payload []byte) string {
if crc32cTable == nil {
return ""
}
csb := make([]byte, Size32)
binary.BigEndian.PutUint32(csb, crc32.Checksum(payload, crc32cTable))
return hex.EncodeToString(csb)
}

// checkCRC32C validates the crc32c checksum in the header.
func checkCRC32C(message remote.Message, in remote.ByteBuffer) error {
strInfo := message.TransInfo().TransStrInfo()
if strInfo == nil {
return nil
}
crc32HexString := strInfo[transmeta.HeaderCRC32C]
if len(crc32HexString) != 0 {
crc32Byte, err := hex.DecodeString(crc32HexString)
if err != nil {
klog.Warnf("KITEX: crc32c key found in TTHeader, value is not a valid hex string")
return nil
}
expectedChecksum := binary.BigEndian.Uint32(crc32Byte)
payloadLen := message.PayloadLen() // total length
payload, err := in.Peek(payloadLen)
if err != nil {
return err
}
realChecksum := crc32.Checksum(payload, crc32cTable)
if realChecksum != expectedChecksum {
return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("crc32c payload check failed, expected=%d, actual=%d", expectedChecksum, realChecksum))
}
}
return nil
}
114 changes: 0 additions & 114 deletions pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"testing"

"github.com/bytedance/mockey"
Expand Down Expand Up @@ -222,54 +221,6 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) {
test.Assert(t, err == nil, err)
}

func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) {
remote.PutPayloadCode(serviceinfo.Thrift, mpc)

dc := NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true})
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initClientSendMsg(transport.TTHeader, 3*1024)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

// test encode err
out := remote.NewReaderBuffer([]byte{})
err := dc.Encode(ctx, sendMsg, out)
test.Assert(t, err != nil)

// encode
out = remote.NewWriterBuffer(256)
err = dc.Encode(ctx, sendMsg, out)
test.Assert(t, err == nil, err)

// decode, succeed
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = dc.Decode(ctx, recvMsg, in)
test.Assert(t, err == nil, err)
intKVInfoRecv := recvMsg.TransInfo().TransIntInfo()
strKVInfoRecv := recvMsg.TransInfo().TransStrInfo()
test.DeepEqual(t, intKVInfoRecv, intKVInfo)
test.DeepEqual(t, strKVInfoRecv, strKVInfo)
test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID())

// decode, crc32c check failed
buf, err = out.Bytes()
test.Assert(t, err == nil, err)
bufLen := len(buf)
modifiedBuf := make([]byte, bufLen)
copy(modifiedBuf, buf)
for i := bufLen - 1; i > bufLen-10; i-- {
modifiedBuf[i] = 123
}
in = remote.NewReaderBuffer(modifiedBuf)
err = dc.Decode(ctx, recvMsg, in)
test.Assert(t, err != nil, err)
}

func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) {
var req interface{}
remote.PutPayloadCode(serviceinfo.Thrift, mpc)
Expand Down Expand Up @@ -299,42 +250,6 @@ func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) {
test.Assert(t, err == nil)
}

func BenchmarkDefaultEncodeDecode(b *testing.B) {
ctx := context.Background()
remote.PutPayloadCode(serviceinfo.Thrift, mpc)
type factory func() remote.Codec
testCases := map[string]factory{"normal": NewDefaultCodec, "crc32c": func() remote.Codec { return NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true}) }}

for name, f := range testCases {
b.Run(name, func(b *testing.B) {
msgLen := 1
for i := 0; i < 6; i++ {
b.ReportAllocs()
b.ResetTimer()
b.Run(fmt.Sprintf("payload-%d", msgLen), func(b *testing.B) {
for j := 0; j < b.N; j++ {
codec := f()
sendMsg := initClientSendMsg(transport.TTHeader, msgLen)
// encode
out := remote.NewWriterBuffer(1024)
err := codec.Encode(ctx, sendMsg, out)
test.Assert(b, err == nil, err)

// decode
recvMsg := initServerRecvMsgWithMockMsg()
buf, err := out.Bytes()
test.Assert(b, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = codec.Decode(ctx, recvMsg, in)
test.Assert(b, err == nil, err)
}
})
msgLen *= 10
}
})
}
}

var mpc remote.PayloadCodec = mockPayloadCodec{}

type mockPayloadCodec struct{}
Expand All @@ -343,23 +258,6 @@ func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, o
WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out)
WriteString(message.RPCInfo().Invocation().MethodName(), out)
WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out)
var (
dataLen uint32
dataStr string
)
// write data
if data := message.Data(); data != nil {
if mm, ok := data.(*mockMsg); ok {
if len(mm.msg) != 0 {
dataStr = mm.msg
dataLen = uint32(len(mm.msg))
}
}
}
WriteUint32(dataLen, out)
if dataLen > 0 {
WriteString(dataStr, out)
}
return nil
}

Expand Down Expand Up @@ -390,18 +288,6 @@ func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message,
if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) {
return err
}
// read data
dataLen, err := PeekUint32(in)
if err != nil {
return err
}
if dataLen == 0 {
// no data
return nil
}
if _, _, err = ReadString(in); err != nil {
return err
}
return nil
}

Expand Down
8 changes: 0 additions & 8 deletions pkg/remote/codec/header_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ const (
type ttHeader struct{}

func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) {
mallocLenBefore := out.MallocLen()

// 1. header meta
var headerMeta []byte
headerMeta, err = out.Malloc(TTHeaderMetaSize)
Expand Down Expand Up @@ -155,12 +153,6 @@ func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote
return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize))
}
binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4))
if message.PayloadLen() != 0 {
// payload encoded before. set total length here.
headerLen := out.MallocLen() - mallocLenBefore
totalLen := message.PayloadLen() + headerLen - Size32
binary.BigEndian.PutUint32(totalLenField, uint32(totalLen))
}
return totalLenField, err
}

Expand Down
Loading

0 comments on commit 73984a4

Please sign in to comment.