From 843102b9bed256efd2a3ede9f5c568363b2bf390 Mon Sep 17 00:00:00 2001 From: QihengZhou Date: Mon, 19 Feb 2024 17:04:44 +0800 Subject: [PATCH] feat: support payload check with crc32c when using ttheader (#1239) --- pkg/remote/bytebuf.go | 6 ++ pkg/remote/codec/default_codec.go | 132 ++++++++++++++++++++++++- pkg/remote/codec/default_codec_test.go | 114 +++++++++++++++++++++ pkg/remote/codec/header_codec.go | 8 ++ pkg/remote/codec/header_codec_test.go | 28 +++++- pkg/remote/default_bytebuf.go | 9 +- pkg/remote/transmeta/metakey.go | 2 + 7 files changed, 291 insertions(+), 8 deletions(-) diff --git a/pkg/remote/bytebuf.go b/pkg/remote/bytebuf.go index 9f1dbe4645..df74aff917 100644 --- a/pkg/remote/bytebuf.go +++ b/pkg/remote/bytebuf.go @@ -34,6 +34,12 @@ type NocopyWrite interface { WriteDirect(buf []byte, remainCap int) error } +// NocopyWrittenBytesGetter is used to get the written bytes from the buffer without copy. +type NocopyWrittenBytesGetter interface { + // BytesNocopy is used to get the bytes written with nocopy. + BytesNocopy() (buf []byte, err error) +} + // FrameWrite is to write header and data buffer separately to avoid memory copy type FrameWrite interface { // WriteHeader set header buffer without copy diff --git a/pkg/remote/codec/default_codec.go b/pkg/remote/codec/default_codec.go index a4d29c4523..8cdc43e47a 100644 --- a/pkg/remote/codec/default_codec.go +++ b/pkg/remote/codec/default_codec.go @@ -19,12 +19,17 @@ package codec import ( "context" "encoding/binary" + "encoding/hex" "fmt" + "hash/crc32" + "sync" "sync/atomic" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/retry" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -45,6 +50,10 @@ const ( // MagicMask is bit mask for checking version. MagicMask = 0xffff0000 + + // PayloadBufferSize is the default buffer size for encoding payload. + // Only used in encode using remote.defaultByteBuffer + PayloadBufferSize = 1024 * 4 ) var ( @@ -55,6 +64,12 @@ var ( _ remote.MetaDecoder = (*defaultCodec)(nil) ) +// crc32cTable is used for crc32c check +var ( + crc32cTable *crc32.Table + crc32TableOnce sync.Once +) + // NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf. func NewDefaultCodec() remote.Codec { // No size limit by default @@ -71,9 +86,32 @@ func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec { } } +// NewDefaultCodecWithConfig creates the default protocol sniffing codec supporting thrift and protobuf with the input config. +func NewDefaultCodecWithConfig(cfg CodecConfig) remote.Codec { + if cfg.CRC32Check { + crc32TableOnce.Do(func() { + crc32cTable = crc32.MakeTable(crc32.Castagnoli) + }) + } + return &defaultCodec{ + maxSize: cfg.MaxSize, + crc32Check: cfg.CRC32Check, + } +} + +// CodecConfig is the config of defaultCodec +type CodecConfig struct { + MaxSize int + CRC32Check bool +} + type defaultCodec struct { // maxSize limits the max size of the payload maxSize int + // If crc32Check is true, the codec will validate the payload using crc32c. + // Only effective when transport is TTHeader. + // Payload is all the data after TTHeader. + crc32Check bool } // EncodePayload encode payload @@ -124,10 +162,13 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message // EncodeMetaAndPayload encode meta and payload func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error { - var err error - var totalLenField []byte tp := message.ProtocolInfo().TransProto + if c.crc32Check && tp&transport.TTHeader == transport.TTHeader { + return c.encodeMetaAndPayloadWithCRC32C(ctx, message, out, c) + } + var err error + var totalLenField []byte // 1. encode header and return totalLenField if needed // totalLenField will be filled after payload encoded if tp&transport.TTHeader == transport.TTHeader { @@ -150,6 +191,50 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote. return nil } +// encodeMetaAndPayloadWithCRC32C encodes payload and meta with crc32c checksum of the payload. +func (c *defaultCodec) encodeMetaAndPayloadWithCRC32C(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error { + var err error + + // 1. encode payload and calculate crc32c checksum + newPayloadOut := remote.NewWriterBuffer(PayloadBufferSize) + + if err = me.EncodePayload(ctx, message, newPayloadOut); err != nil { + return err + } + // get the payload from buffer + var payload []byte + if nc, ok := newPayloadOut.(remote.NocopyWrittenBytesGetter); ok { + payload, err = nc.BytesNocopy() + } else { + payload, err = newPayloadOut.Bytes() + } + newPayloadOut.Release(err) + if err != nil { + return err + } + crc32c := getCRC32C(payload) + strInfo := message.TransInfo().TransStrInfo() + if crc32c != "" && strInfo != nil { + strInfo[transmeta.HeaderCRC32C] = crc32c + } + // set payload length before encode TTHeader. + message.SetPayloadLen(len(payload)) + + // 2. encode header and return totalLenField if needed + // In this case, set total length during TTHeader encode + if _, err = ttHeaderCodec.encode(ctx, message, out); err != nil { + return err + } + + // 3. write payload to the buffer after TTHeader + if ncWriter, ok := out.(remote.NocopyWrite); ok { + err = ncWriter.WriteDirect(payload, 0) + } else { + _, err = out.WriteBinary(payload) + } + return err +} + // Encode implements the remote.Codec interface, it does complete message encode include header and payload. func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { return c.EncodeMetaAndPayload(ctx, message, out, c) @@ -176,6 +261,11 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i if flagBuf, err = in.Peek(2 * Size32); err != nil { return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error())) } + if c.crc32Check && crc32cTable != nil { + if err = checkCRC32C(message, in); err != nil { + return err + } + } } else if isMeshHeader(flagBuf) { message.Tags()[remote.MeshHeader] = true // MeshHeader @@ -373,3 +463,41 @@ func checkPayloadSize(payloadLen, maxSize int) error { } return nil } + +// getCRC32C calculates the crc32c checksum of the input bytes. +// the checksum will be converted into big-endian format and encoded into hex string. +func getCRC32C(payload []byte) string { + if crc32cTable == nil { + return "" + } + csb := make([]byte, Size32) + binary.BigEndian.PutUint32(csb, crc32.Checksum(payload, crc32cTable)) + return hex.EncodeToString(csb) +} + +// checkCRC32C validates the crc32c checksum in the header. +func checkCRC32C(message remote.Message, in remote.ByteBuffer) error { + strInfo := message.TransInfo().TransStrInfo() + if strInfo == nil { + return nil + } + crc32HexString := strInfo[transmeta.HeaderCRC32C] + if len(crc32HexString) != 0 { + crc32Byte, err := hex.DecodeString(crc32HexString) + if err != nil { + klog.Warnf("KITEX: crc32c key found in TTHeader, value is not a valid hex string") + return nil + } + expectedChecksum := binary.BigEndian.Uint32(crc32Byte) + payloadLen := message.PayloadLen() // total length + payload, err := in.Peek(payloadLen) + if err != nil { + return err + } + realChecksum := crc32.Checksum(payload, crc32cTable) + if realChecksum != expectedChecksum { + return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("crc32c payload check failed, expected=%d, actual=%d", expectedChecksum, realChecksum)) + } + } + return nil +} diff --git a/pkg/remote/codec/default_codec_test.go b/pkg/remote/codec/default_codec_test.go index e27ecfca29..a02fb85b5f 100644 --- a/pkg/remote/codec/default_codec_test.go +++ b/pkg/remote/codec/default_codec_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "testing" "github.com/bytedance/mockey" @@ -221,6 +222,54 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) { test.Assert(t, err == nil, err) } +func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) { + remote.PutPayloadCode(serviceinfo.Thrift, mpc) + + dc := NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true}) + ctx := context.Background() + intKVInfo := prepareIntKVInfo() + strKVInfo := prepareStrKVInfo() + sendMsg := initClientSendMsg(transport.TTHeader, 3*1024) + sendMsg.TransInfo().PutTransIntInfo(intKVInfo) + sendMsg.TransInfo().PutTransStrInfo(strKVInfo) + + // test encode err + out := remote.NewReaderBuffer([]byte{}) + err := dc.Encode(ctx, sendMsg, out) + test.Assert(t, err != nil) + + // encode + out = remote.NewWriterBuffer(256) + err = dc.Encode(ctx, sendMsg, out) + test.Assert(t, err == nil, err) + + // decode, succeed + recvMsg := initServerRecvMsg() + buf, err := out.Bytes() + test.Assert(t, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err == nil, err) + intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() + strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() + test.DeepEqual(t, intKVInfoRecv, intKVInfo) + test.DeepEqual(t, strKVInfoRecv, strKVInfo) + test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID()) + + // decode, crc32c check failed + buf, err = out.Bytes() + test.Assert(t, err == nil, err) + bufLen := len(buf) + modifiedBuf := make([]byte, bufLen) + copy(modifiedBuf, buf) + for i := bufLen - 1; i > bufLen-10; i-- { + modifiedBuf[i] = 123 + } + in = remote.NewReaderBuffer(modifiedBuf) + err = dc.Decode(ctx, recvMsg, in) + test.Assert(t, err != nil, err) +} + func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) { var req interface{} remote.PutPayloadCode(serviceinfo.Thrift, mpc) @@ -250,6 +299,42 @@ func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) { test.Assert(t, err == nil) } +func BenchmarkDefaultEncodeDecode(b *testing.B) { + ctx := context.Background() + remote.PutPayloadCode(serviceinfo.Thrift, mpc) + type factory func() remote.Codec + testCases := map[string]factory{"normal": NewDefaultCodec, "crc32c": func() remote.Codec { return NewDefaultCodecWithConfig(CodecConfig{CRC32Check: true}) }} + + for name, f := range testCases { + b.Run(name, func(b *testing.B) { + msgLen := 1 + for i := 0; i < 6; i++ { + b.ReportAllocs() + b.ResetTimer() + b.Run(fmt.Sprintf("payload-%d", msgLen), func(b *testing.B) { + for j := 0; j < b.N; j++ { + codec := f() + sendMsg := initClientSendMsg(transport.TTHeader, msgLen) + // encode + out := remote.NewWriterBuffer(1024) + err := codec.Encode(ctx, sendMsg, out) + test.Assert(b, err == nil, err) + + // decode + recvMsg := initServerRecvMsgWithMockMsg() + buf, err := out.Bytes() + test.Assert(b, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = codec.Decode(ctx, recvMsg, in) + test.Assert(b, err == nil, err) + } + }) + msgLen *= 10 + } + }) + } +} + var mpc remote.PayloadCodec = mockPayloadCodec{} type mockPayloadCodec struct{} @@ -258,6 +343,23 @@ func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, o WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out) WriteString(message.RPCInfo().Invocation().MethodName(), out) WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out) + var ( + dataLen uint32 + dataStr string + ) + // write data + if data := message.Data(); data != nil { + if mm, ok := data.(*mockMsg); ok { + if len(mm.msg) != 0 { + dataStr = mm.msg + dataLen = uint32(len(mm.msg)) + } + } + } + WriteUint32(dataLen, out) + if dataLen > 0 { + WriteString(dataStr, out) + } return nil } @@ -288,6 +390,18 @@ func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message, if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) { return err } + // read data + dataLen, err := PeekUint32(in) + if err != nil { + return err + } + if dataLen == 0 { + // no data + return nil + } + if _, _, err = ReadString(in); err != nil { + return err + } return nil } diff --git a/pkg/remote/codec/header_codec.go b/pkg/remote/codec/header_codec.go index 524dd6a903..38903b5230 100644 --- a/pkg/remote/codec/header_codec.go +++ b/pkg/remote/codec/header_codec.go @@ -117,6 +117,8 @@ const ( type ttHeader struct{} func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) { + mallocLenBefore := out.MallocLen() + // 1. header meta var headerMeta []byte headerMeta, err = out.Malloc(TTHeaderMetaSize) @@ -153,6 +155,12 @@ func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize)) } binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4)) + if message.PayloadLen() != 0 { + // payload encoded before. set total length here. + headerLen := out.MallocLen() - mallocLenBefore + totalLen := message.PayloadLen() + headerLen - Size32 + binary.BigEndian.PutUint32(totalLenField, uint32(totalLen)) + } return totalLenField, err } diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index f23c936669..f05f64e855 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -306,6 +306,15 @@ var ( rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) ) +type mockMsg struct { + msg string +} + +func initServerRecvMsgWithMockMsg() remote.Message { + req := &mockMsg{} + return remote.NewMessage(req, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Call, remote.Server) +} + func initServerRecvMsg() remote.Message { svcInfo := mocks.ServiceInfo() svcSearchMap := map[string]*serviceinfo.ServiceInfo{ @@ -322,23 +331,32 @@ func initServerRecvMsg() remote.Message { return msg } -func initClientSendMsg(tp transport.Protocol) remote.Message { - var req interface{} +func initClientSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message { + req := &mockMsg{} + if len(payloadLen) != 0 { + req.msg = string(make([]byte, payloadLen[0])) + } + svcInfo := mocks.ServiceInfo() + mi := svcInfo.MethodInfo(mockCliRPCInfo.Invocation().MethodName()) + mi.NewArgs() msg := remote.NewMessage(req, svcInfo, mockCliRPCInfo, remote.Call, remote.Client) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) return msg } -func initServerSendMsg(tp transport.Protocol) remote.Message { - var resp interface{} +func initServerSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message { + resp := &mockMsg{} + if len(payloadLen) != 0 { + resp.msg = string(make([]byte, payloadLen[0])) + } msg := remote.NewMessage(resp, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Reply, remote.Server) msg.SetProtocolInfo(remote.NewProtocolInfo(tp, mocks.ServiceInfo().PayloadCodec)) return msg } func initClientRecvMsg() remote.Message { - var resp interface{} + resp := &mockMsg{} svcInfo := mocks.ServiceInfo() msg := remote.NewMessage(resp, svcInfo, mockCliRPCInfo, remote.Reply, remote.Client) return msg diff --git a/pkg/remote/default_bytebuf.go b/pkg/remote/default_bytebuf.go index 1a8ff17933..e8eb5ccab7 100644 --- a/pkg/remote/default_bytebuf.go +++ b/pkg/remote/default_bytebuf.go @@ -36,7 +36,6 @@ func init() { } // NewWriterBuffer is used to create a defaultByteBuffer using the given size. -// NOTICE: defaultByteBuffer is only used for testing. func NewWriterBuffer(size int) ByteBuffer { return newWriterByteBuffer(size) } @@ -279,6 +278,14 @@ func (b *defaultByteBuffer) Bytes() (buf []byte, err error) { return buf, nil } +// BytesNocopy is used to get the bytes written with nocopy. +func (b *defaultByteBuffer) BytesNocopy() (buf []byte, err error) { + if b.status&BitWritable == 0 { + return nil, errors.New("unwritable buffer, cannot support Bytes") + } + return b.buff[:b.writeIdx], nil +} + // NewBuffer returns a new writable remote.ByteBuffer. func (b *defaultByteBuffer) NewBuffer() ByteBuffer { return NewWriterBuffer(256) diff --git a/pkg/remote/transmeta/metakey.go b/pkg/remote/transmeta/metakey.go index a2ad43ce3d..f1a43afbd3 100644 --- a/pkg/remote/transmeta/metakey.go +++ b/pkg/remote/transmeta/metakey.go @@ -63,6 +63,8 @@ const ( // the connection peer will shutdown later,so it send back the header to tell client to close the connection. HeaderConnectionReadyToReset = "crrst" HeaderProcessAtTime = "K_ProcessAtTime" + // HeaderCRC32C is used to store the crc32c checksum of payload + HeaderCRC32C = "crc32c" ) // key of acl token