Skip to content

Commit

Permalink
fix: use netpoll buffer to encode payload to prevent length lost
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Feb 23, 2024
1 parent 1c063ba commit dad3def
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 26 deletions.
6 changes: 0 additions & 6 deletions pkg/remote/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ type NocopyWrite interface {
WriteDirect(buf []byte, remainCap int) error
}

// NocopyWrittenBytesGetter is used to get the written bytes from the buffer without copy.
type NocopyWrittenBytesGetter interface {
// BytesNocopy is used to get the bytes written with nocopy.
BytesNocopy() (buf []byte, err error)
}

// FrameWrite is to write header and data buffer separately to avoid memory copy
type FrameWrite interface {
// WriteHeader set header buffer without copy
Expand Down
37 changes: 25 additions & 12 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ import (
"sync"
"sync/atomic"

netpoll2 "github.com/cloudwego/netpoll"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
Expand Down Expand Up @@ -164,7 +167,7 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message
func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
tp := message.ProtocolInfo().TransProto
if c.crc32Check && tp&transport.TTHeader == transport.TTHeader {
return c.encodeMetaAndPayloadWithCRC32C(ctx, message, out, c)
return c.encodeMetaAndPayloadWithCRC32C(ctx, message, out, me)
}

var err error
Expand Down Expand Up @@ -196,22 +199,26 @@ func (c *defaultCodec) encodeMetaAndPayloadWithCRC32C(ctx context.Context, messa
var err error

// 1. encode payload and calculate crc32c checksum
newPayloadOut := remote.NewWriterBuffer(PayloadBufferSize)
payloadOut := netpoll.NewWriterByteBuffer(netpoll2.NewLinkBuffer())

if err = me.EncodePayload(ctx, message, newPayloadOut); err != nil {
if err = me.EncodePayload(ctx, message, payloadOut); err != nil {
return err
}
// get the payload from buffer
var payload []byte
if nc, ok := newPayloadOut.(remote.NocopyWrittenBytesGetter); ok {
payload, err = nc.BytesNocopy()
} else {
payload, err = newPayloadOut.Bytes()
}
newPayloadOut.Release(err)
payload, err = payloadOut.Bytes()
if err != nil {
// release if err
payloadOut.Release(err)
return err
}
outIsNetpollBuffer := netpoll.IsNetpollByteBuffer(out)
if !outIsNetpollBuffer {
// release payloadOut if the original out is not a netpoll buffer
// because it won't be used later.
payloadOut.Release(nil)
}

crc32c := getCRC32C(payload)
strInfo := message.TransInfo().TransStrInfo()
if crc32c != "" && strInfo != nil {
Expand All @@ -227,10 +234,16 @@ func (c *defaultCodec) encodeMetaAndPayloadWithCRC32C(ctx context.Context, messa
}

// 3. write payload to the buffer after TTHeader
if ncWriter, ok := out.(remote.NocopyWrite); ok {
err = ncWriter.WriteDirect(payload, 0)
if outIsNetpollBuffer {
// append buffer only if the input buffer is a netpollByteBuffer
// release will be executed in AppendBuffer
err = out.AppendBuffer(payloadOut)
} else {
_, err = out.WriteBinary(payload)
if ncWriter, ok := out.(remote.NocopyWrite); ok {
err = ncWriter.WriteDirect(payload, 0)
} else {
_, err = out.WriteBinary(payload)
}
}
return err
}
Expand Down
16 changes: 9 additions & 7 deletions pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import (
"testing"

"github.com/bytedance/mockey"
netpoll2 "github.com/cloudwego/netpoll"
"github.com/golang/mock/gomock"

"github.com/cloudwego/kitex/internal/mocks"
mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/transport"
Expand Down Expand Up @@ -229,24 +231,25 @@ func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) {
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initClientSendMsg(transport.TTHeader, 3*1024)
sendMsg := initClientSendMsg(transport.TTHeaderFramed, 3*1024)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

// test encode err
out := remote.NewReaderBuffer([]byte{})
err := dc.Encode(ctx, sendMsg, out)
badOut := netpoll.NewReaderByteBuffer(netpoll2.NewLinkBuffer())
err := dc.Encode(ctx, sendMsg, badOut)
test.Assert(t, err != nil)

// encode
out = remote.NewWriterBuffer(256)
err = dc.Encode(ctx, sendMsg, out)
npBuffer := netpoll.NewReaderWriterByteBuffer(netpoll2.NewLinkBuffer())
err = dc.Encode(ctx, sendMsg, npBuffer)
test.Assert(t, err == nil, err)

// decode, succeed
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
buf, err := npBuffer.Bytes()
test.Assert(t, err == nil, err)

in := remote.NewReaderBuffer(buf)
err = dc.Decode(ctx, recvMsg, in)
test.Assert(t, err == nil, err)
Expand All @@ -257,7 +260,6 @@ func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) {
test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID())

// decode, crc32c check failed
buf, err = out.Bytes()
test.Assert(t, err == nil, err)
bufLen := len(buf)
modifiedBuf := make([]byte, bufLen)
Expand Down
11 changes: 10 additions & 1 deletion pkg/remote/trans/netpoll/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer {
return bytebuf
}

func IsNetpollByteBuffer(b remote.ByteBuffer) bool {
_, ok := b.(*netpollByteBuffer)
return ok
}

func newNetpollByteBuffer() interface{} {
return &netpollByteBuffer{}
}
Expand Down Expand Up @@ -239,7 +244,11 @@ func (b *netpollByteBuffer) AppendBuffer(buf remote.ByteBuffer) (err error) {

// Bytes are not supported in netpoll bytebuf.
func (b *netpollByteBuffer) Bytes() (buf []byte, err error) {
return nil, errors.New("method Bytes() not support in netpoll bytebuf")
lb := b.writer.(*netpoll.LinkBuffer)
if err = lb.Flush(); err != nil {
return nil, err
}
return lb.Bytes(), nil
}

// Release will free the buffer already read.
Expand Down

0 comments on commit dad3def

Please sign in to comment.