Skip to content

Commit

Permalink
Merge pull request cloudwego#1313 from DMwangnima/feat/skip_codec
Browse files Browse the repository at this point in the history
feat(thrift codec): implement skipDecoder to enable Frugal and FastCodec for standard Thrift Buffer Protocol
  • Loading branch information
DMwangnima authored Apr 16, 2024
2 parents 3e665fd + 8b3dd95 commit 466ded0
Show file tree
Hide file tree
Showing 7 changed files with 476 additions and 21 deletions.
201 changes: 201 additions & 0 deletions pkg/remote/codec/thrift/skip_decoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"errors"
"fmt"

"github.com/apache/thrift/lib/go/thrift"

"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
)

const (
	// EnableSkipDecoder is a CodecType flag occupying bit 4 (value 16).
	// NOTE(review): presumably OR-ed into a CodecType bitmask to switch on the
	// skipDecoder path; confirm against the CodecType definition.
	EnableSkipDecoder CodecType = 1 << 4
)

// skipBuffer wraps remote.ByteBuffer and reimplements the Next method.
// Next does not advance the reader of the underlying buffer: it peeks and only
// records how many bytes have been handed out. The underlying reader is advanced
// when the Buffer method is called.
// It is used to fit in BinaryProtocol and reduce memory allocation.
type skipBuffer struct {
remote.ByteBuffer
// readNum is the total number of bytes handed out by Next so far; Buffer
// consumes exactly this many bytes from the embedded ByteBuffer.
readNum int
}

// Next returns the next n bytes after the ones already handed out, without
// advancing the reader of the underlying ByteBuffer. It peeks readNum+n bytes,
// returns the trailing n of them and then increases readNum by n.
//
// An error is returned when n is negative (or readNum+n overflows), or when the
// underlying Peek fails; readNum is left unchanged in both cases.
func (b *skipBuffer) Next(n int) ([]byte, error) {
	start := b.readNum
	end := start + n
	// A negative n (or an overflowing sum) would otherwise pass Peek and then
	// panic on the inverted slice range below.
	if n < 0 || end < start {
		return nil, fmt.Errorf("skipBuffer: invalid read length %d", n)
	}
	buf, err := b.ByteBuffer.Peek(end)
	if err != nil {
		return nil, err
	}
	b.readNum = end
	return buf[start:end], nil
}

// Buffer consumes readNum bytes from the underlying ByteBuffer through its Next
// method and returns them, i.e. everything previously handed out by skipBuffer.Next.
// NOTE(review): readNum is not reset here, so calling Next or Buffer again on the
// same skipBuffer would work from a stale offset; it looks intended for single use.
func (b *skipBuffer) Buffer() ([]byte, error) {
return b.ByteBuffer.Next(b.readNum)
}

// newSkipBuffer returns a skipBuffer that reads from bb, with no bytes handed
// out yet (readNum is zero).
func newSkipBuffer(bb remote.ByteBuffer) *skipBuffer {
	sb := new(skipBuffer)
	sb.ByteBuffer = bb
	return sb
}

// skipDecoder parses the input byte by byte and skips over a thrift payload
// without decoding its values, so that Frugal and FastCodec can be used in the
// standard Thrift Binary Protocol scenario.
type skipDecoder struct {
// tprot is the BinaryProtocol that reads from sb.
tprot *BinaryProtocol
// sb tracks how many bytes have been skipped and yields them via Buffer.
sb *skipBuffer
}

// newSkipDecoder builds a skipDecoder on top of trans. The transport is wrapped
// in a skipBuffer, and the BinaryProtocol used for parsing reads from that
// wrapper rather than from trans directly.
func newSkipDecoder(trans remote.ByteBuffer) *skipDecoder {
	wrapped := newSkipBuffer(trans)
	dec := &skipDecoder{sb: wrapped}
	dec.tprot = NewBinaryProtocol(wrapped)
	return dec
}

// SkipStruct skips one thrift struct from the input, starting with a recursion
// budget of thrift.DEFAULT_RECURSION_DEPTH. It returns the first error hit while
// skipping, or nil once the struct has been skipped.
func (sd *skipDecoder) SkipStruct() error {
return sd.skip(thrift.STRUCT, thrift.DEFAULT_RECURSION_DEPTH)
}

// skipString skips a length-prefixed value: it reads the i32 length and then
// steps over that many bytes. A negative length yields perrors.InvalidDataLength.
func (sd *skipDecoder) skipString() error {
	length, err := sd.tprot.ReadI32()
	switch {
	case err != nil:
		return err
	case length < 0:
		return perrors.InvalidDataLength
	}
	_, err = sd.tprot.next(int(length))
	return err
}

// skipMap reads the map header and then skips every key/value pair.
// maxDepth is the remaining recursion budget handed to skip for each key and value.
func (sd *skipDecoder) skipMap(maxDepth int) error {
	keyType, valType, remaining, err := sd.tprot.ReadMapBegin()
	if err != nil {
		return err
	}
	for ; remaining > 0; remaining-- {
		if keyErr := sd.skip(keyType, maxDepth); keyErr != nil {
			return keyErr
		}
		if valErr := sd.skip(valType, maxDepth); valErr != nil {
			return valErr
		}
	}
	return nil
}

// skipList reads the list header and then skips every element.
// maxDepth is the remaining recursion budget handed to skip for each element.
func (sd *skipDecoder) skipList(maxDepth int) error {
	elemType, remaining, err := sd.tprot.ReadListBegin()
	if err != nil {
		return err
	}
	for ; remaining > 0; remaining-- {
		if elemErr := sd.skip(elemType, maxDepth); elemErr != nil {
			return elemErr
		}
	}
	return nil
}

// skipSet skips a set by delegating to skipList, passing maxDepth through unchanged.
func (sd *skipDecoder) skipSet(maxDepth int) error {
return sd.skipList(maxDepth)
}

// skip steps over one value of the given thrift type.
//
// Fixed-width types are skipped by advancing 1, 2, 4 or 8 bytes; STRING goes
// through skipString; STRUCT, MAP, SET and LIST recurse with maxDepth-1.
// A maxDepth of zero or less yields a DEPTH_LIMIT protocol exception, and any
// other type id yields an INVALID_DATA protocol exception.
func (sd *skipDecoder) skip(typeId thrift.TType, maxDepth int) (err error) {
	if maxDepth <= 0 {
		return thrift.NewTProtocolExceptionWithType(thrift.DEPTH_LIMIT, errors.New("depth limit exceeded"))
	}

	switch typeId {
	case thrift.BOOL, thrift.BYTE:
		_, err = sd.tprot.next(1)
	case thrift.I16:
		_, err = sd.tprot.next(2)
	case thrift.I32:
		_, err = sd.tprot.next(4)
	case thrift.I64, thrift.DOUBLE:
		_, err = sd.tprot.next(8)
	case thrift.STRING:
		err = sd.skipString()
	case thrift.STRUCT:
		err = sd.skipStruct(maxDepth - 1)
	case thrift.MAP:
		err = sd.skipMap(maxDepth - 1)
	case thrift.SET:
		err = sd.skipSet(maxDepth - 1)
	case thrift.LIST:
		err = sd.skipList(maxDepth - 1)
	default:
		err = thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("unknown data type %d", typeId))
	}
	return err
}

// skipStruct skips the fields of a struct one after another until a field header
// of type thrift.STOP is read. Each field value is skipped with the given
// maxDepth; the first error from reading a header or skipping a value is returned.
func (sd *skipDecoder) skipStruct(maxDepth int) (err error) {
	for {
		_, fieldType, _, readErr := sd.tprot.ReadFieldBegin()
		if readErr != nil {
			return readErr
		}
		if fieldType == thrift.STOP {
			// End of struct reached without error.
			return nil
		}
		if skipErr := sd.skip(fieldType, maxDepth); skipErr != nil {
			return skipErr
		}
	}
}

// Buffer returns the skipped bytes by delegating to the Buffer method of the
// wrapped skipBuffer.
// The returned buffer is meant to be fed to Frugal or FastCodec.
func (sd *skipDecoder) Buffer() ([]byte, error) {
return sd.sb.Buffer()
}

// Recycle recycles the internal BinaryProtocol. It does not affect the buffer
// returned by Buffer().
func (sd *skipDecoder) Recycle() {
sd.tprot.Recycle()
}
111 changes: 111 additions & 0 deletions pkg/remote/codec/thrift/skip_decoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"testing"

"github.com/apache/thrift/lib/go/thrift"

"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote"
)

// TestSkipBuffer checks that skipBuffer.Next hands out bytes without moving the
// underlying reader, and that skipBuffer.Buffer then consumes them.
func TestSkipBuffer(t *testing.T) {
	payload := []byte{1, 2, 3}
	reader := remote.NewReaderBuffer(payload)
	skipBuf := newSkipBuffer(reader)

	// Next only peeks: readNum grows while the reader's ReadLen stays at zero.
	peeked, err := skipBuf.Next(3)
	test.Assert(t, err == nil)
	test.DeepEqual(t, payload, peeked[:len(payload)])
	test.Assert(t, skipBuf.readNum == len(payload))
	test.Assert(t, reader.ReadLen() == 0)

	// Buffer advances the underlying reader by the bytes handed out so far.
	consumed, err := skipBuf.Buffer()
	test.Assert(t, err == nil)
	test.DeepEqual(t, payload, consumed[:len(payload)])
	test.Assert(t, reader.ReadLen() == len(payload))
}

// TestSkipDecoder_SkipStruct encodes a struct holding one field of every type
// handled by skipDecoder.skip (except a nested struct) and checks that
// SkipStruct walks over exactly the encoded bytes, which Buffer then consumes.
func TestSkipDecoder_SkipStruct(t *testing.T) {
	enc := NewBinaryProtocol(remote.NewReaderWriterBuffer(1024))
	defer enc.Recycle()

	enc.WriteStructBegin("testStruct")

	// field 1: byte
	enc.WriteFieldBegin("Byte", thrift.BYTE, 1)
	enc.WriteByte('1')
	enc.WriteFieldEnd()

	// field 2: bool
	enc.WriteFieldBegin("Bool", thrift.BOOL, 2)
	enc.WriteBool(true)
	enc.WriteFieldEnd()

	// field 3: i16
	enc.WriteFieldBegin("I16", thrift.I16, 3)
	enc.WriteI16(2)
	enc.WriteFieldEnd()

	// field 4: i32
	enc.WriteFieldBegin("I32", thrift.I32, 4)
	enc.WriteI32(3)
	enc.WriteFieldEnd()

	// field 5: i64
	enc.WriteFieldBegin("I64", thrift.I64, 5)
	enc.WriteI64(4)
	enc.WriteFieldEnd()

	// field 6: double
	enc.WriteFieldBegin("Double", thrift.DOUBLE, 6)
	enc.WriteDouble(5)
	enc.WriteFieldEnd()

	// field 7: string
	enc.WriteFieldBegin("String", thrift.STRING, 7)
	enc.WriteString("6")
	enc.WriteFieldEnd()

	// field 8: map<i32, i32> with a single entry
	enc.WriteFieldBegin("Map", thrift.MAP, 8)
	enc.WriteMapBegin(thrift.I32, thrift.I32, 1)
	enc.WriteI32(7)
	enc.WriteI32(8)
	enc.WriteMapEnd()
	enc.WriteFieldEnd()

	// field 9: set<i32> with a single element
	enc.WriteFieldBegin("Set", thrift.SET, 9)
	enc.WriteSetBegin(thrift.I32, 1)
	enc.WriteI32(9)
	enc.WriteSetEnd()
	enc.WriteFieldEnd()

	// field 10: list<i32> with a single element
	enc.WriteFieldBegin("List", thrift.LIST, 10)
	enc.WriteListBegin(thrift.I32, 1)
	enc.WriteI32(9)
	enc.WriteListEnd()
	enc.WriteFieldEnd()

	enc.WriteFieldStop()
	enc.WriteStructEnd()

	encodedLen := enc.ByteBuffer().ReadableLen()
	dec := newSkipDecoder(enc.ByteBuffer())
	defer dec.Recycle()

	// Skipping must account for every encoded byte but leave the reader untouched.
	err := dec.SkipStruct()
	test.Assert(t, err == nil)
	test.Assert(t, dec.sb.readNum == encodedLen)
	test.Assert(t, dec.sb.ReadLen() == 0)
	test.Assert(t, dec.sb.ReadableLen() == encodedLen)

	// Buffer consumes the skipped bytes from the underlying reader.
	_, err = dec.Buffer()
	test.Assert(t, err == nil)
	test.Assert(t, dec.sb.ReadLen() == encodedLen)
	test.Assert(t, dec.sb.ReadableLen() == 0)
}
Loading

0 comments on commit 466ded0

Please sign in to comment.