Skip to content

Commit

Permalink
feat: support payload check with crc32c when using ttheader
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Feb 23, 2024
1 parent 73984a4 commit fa276a9
Show file tree
Hide file tree
Showing 9 changed files with 376 additions and 14 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/cloudwego/fastpb v0.0.4
github.com/cloudwego/frugal v0.1.13
github.com/cloudwego/localsession v0.0.2
github.com/cloudwego/netpoll v0.5.2-0.20240220090456-7ba622bf763b
github.com/cloudwego/netpoll v0.5.2-0.20240223102227-0c594e3c8163
github.com/cloudwego/thriftgo v0.3.6
github.com/golang/mock v1.6.0
github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,8 @@ github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiX
github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E=
github.com/cloudwego/netpoll v0.3.1/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E=
github.com/cloudwego/netpoll v0.4.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.5.1 h1:zDUF7xF0C97I10fGlQFJ4jg65khZZMUvSu/TWX44Ohc=
github.com/cloudwego/netpoll v0.5.1/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.5.2-0.20240220090456-7ba622bf763b h1:ZHtA1Q20H9WoLPfMHCSkMv8wUrN7YENJfQCVybErGy8=
github.com/cloudwego/netpoll v0.5.2-0.20240220090456-7ba622bf763b/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.5.2-0.20240223102227-0c594e3c8163 h1:HGsD9cVt4x/tR3YfGLWSwYg9QGexudQdmIy/xnUaCJE=
github.com/cloudwego/netpoll v0.5.2-0.20240223102227-0c594e3c8163/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/thriftgo v0.1.2/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w=
github.com/cloudwego/thriftgo v0.2.4/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4=
github.com/cloudwego/thriftgo v0.2.7/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4=
Expand Down
154 changes: 152 additions & 2 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,20 @@ package codec
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"hash/crc32"
"sync"
"sync/atomic"

netpoll2 "github.com/cloudwego/netpoll"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand All @@ -45,6 +53,10 @@ const (

// MagicMask is bit mask for checking version.
MagicMask = 0xffff0000

// PayloadBufferSize is the default buffer size for encoding payload.
// Only used in encode using remote.defaultByteBuffer
PayloadBufferSize = 1024 * 4
)

var (
Expand All @@ -55,6 +67,12 @@ var (
_ remote.MetaDecoder = (*defaultCodec)(nil)
)

// crc32cTable is used for crc32c check
var (
crc32cTable *crc32.Table
crc32TableOnce sync.Once
)

// NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf.
func NewDefaultCodec() remote.Codec {
// No size limit by default
Expand All @@ -71,9 +89,32 @@ func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec {
}
}

// NewDefaultCodecWithConfig creates the default protocol sniffing codec supporting thrift and protobuf with the input config.
func NewDefaultCodecWithConfig(cfg CodecConfig) remote.Codec {
if cfg.CRC32Check {
crc32TableOnce.Do(func() {
crc32cTable = crc32.MakeTable(crc32.Castagnoli)
})
}
return &defaultCodec{
maxSize: cfg.MaxSize,
crc32Check: cfg.CRC32Check,
}
}

// CodecConfig is the config of defaultCodec
type CodecConfig struct {
MaxSize int
CRC32Check bool
}

type defaultCodec struct {
// maxSize limits the max size of the payload
maxSize int
// If crc32Check is true, the codec will validate the payload using crc32c.
// Only effective when transport is TTHeader.
// Payload is all the data after TTHeader.
crc32Check bool
}

// EncodePayload encode payload
Expand Down Expand Up @@ -124,10 +165,13 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message

// EncodeMetaAndPayload encode meta and payload
func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
var err error
var totalLenField []byte
tp := message.ProtocolInfo().TransProto
if c.crc32Check && tp&transport.TTHeader == transport.TTHeader {
return c.encodeMetaAndPayloadWithCRC32C(ctx, message, out, me)
}

var err error
var totalLenField []byte
// 1. encode header and return totalLenField if needed
// totalLenField will be filled after payload encoded
if tp&transport.TTHeader == transport.TTHeader {
Expand All @@ -150,6 +194,55 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.
return nil
}

// encodeMetaAndPayloadWithCRC32C encodes payload and meta with crc32c checksum of the payload.
func (c *defaultCodec) encodeMetaAndPayloadWithCRC32C(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
var err error

// 1. encode payload and calculate crc32c checksum
payloadOut := netpoll.NewWriterByteBuffer(netpoll2.NewLinkBuffer())

if err = me.EncodePayload(ctx, message, payloadOut); err != nil {
return err
}
// get the payload from buffer
payload, payloadLen, err := payloadOut.(interface{ GetBytes() ([][]byte, int, error) }).GetBytes()
if err != nil {
// release if err
payloadOut.Release(err)
return err
}

crc32c := getCRC32C(payload)
strInfo := message.TransInfo().TransStrInfo()
if crc32c != "" && strInfo != nil {
strInfo[transmeta.HeaderCRC32C] = crc32c
}
// set payload length before encode TTHeader.
message.SetPayloadLen(payloadLen)

// 2. encode header and return totalLenField if needed
// In this case, set total length during TTHeader encode
if _, err = ttHeaderCodec.encode(ctx, message, out); err != nil {
return err
}

// 3. write payload to the buffer after TTHeader
if netpoll.IsNetpollByteBuffer(out) {
// append buffer only if the input buffer is a netpollByteBuffer
// release will be executed in AppendBuffer
err = out.AppendBuffer(payloadOut)
} else {
// convert [][]byte to []byte
p := convert(payload, payloadLen)
if ncWriter, ok := out.(remote.NocopyWrite); ok {
err = ncWriter.WriteDirect(p, 0)
} else {
_, err = out.WriteBinary(p)
}
}
return err
}

// Encode implements the remote.Codec interface, it does complete message encode include header and payload.
func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) {
return c.EncodeMetaAndPayload(ctx, message, out, c)
Expand All @@ -176,6 +269,11 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
if flagBuf, err = in.Peek(2 * Size32); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if c.crc32Check && crc32cTable != nil {
if err = checkCRC32C(message, in); err != nil {
return err
}
}
} else if isMeshHeader(flagBuf) {
message.Tags()[remote.MeshHeader] = true
// MeshHeader
Expand Down Expand Up @@ -373,3 +471,55 @@ func checkPayloadSize(payloadLen, maxSize int) error {
}
return nil
}

// getCRC32C calculates the crc32c checksum of the input bytes.
// the checksum will be converted into big-endian format and encoded into hex string.
func getCRC32C(payload [][]byte) string {
if crc32cTable == nil {
return ""
}
csb := make([]byte, Size32)
var checksum uint32
for i := 0; i < len(payload); i++ {
checksum = crc32.Update(checksum, crc32cTable, payload[i])
}
binary.BigEndian.PutUint32(csb, checksum)
return hex.EncodeToString(csb)
}

func convert(b2 [][]byte, length int) []byte {
b1 := make([]byte, length)
off := 0
for i := 0; i < len(b2); i++ {
copy(b1[off:off+len(b2[i])], b2[i])
off += len(b2[i])
}
return b1
}

// checkCRC32C validates the crc32c checksum in the header.
func checkCRC32C(message remote.Message, in remote.ByteBuffer) error {
strInfo := message.TransInfo().TransStrInfo()
if strInfo == nil {
return nil
}
crc32HexString := strInfo[transmeta.HeaderCRC32C]
if len(crc32HexString) != 0 {
crc32Byte, err := hex.DecodeString(crc32HexString)
if err != nil {
klog.Warnf("KITEX: crc32c key found in TTHeader, value is not a valid hex string")
return nil
}
expectedChecksum := binary.BigEndian.Uint32(crc32Byte)
payloadLen := message.PayloadLen() // total length
payload, err := in.Peek(payloadLen)
if err != nil {
return err
}
realChecksum := crc32.Checksum(payload, crc32cTable)
if realChecksum != expectedChecksum {
return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("crc32c payload check failed, expected=%d, actual=%d", expectedChecksum, realChecksum))
}
}
return nil
}
Loading

0 comments on commit fa276a9

Please sign in to comment.