forked from cloudwego/kitex
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request cloudwego#1313 from DMwangnima/feat/skip_codec
feat(thrift codec): implement skipDecoder to enable Frugal and FastCodec for standard Thrift Buffer Protocol
- Loading branch information
Showing
7 changed files
with
476 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
/* | ||
* Copyright 2024 CloudWeGo Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package thrift | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/apache/thrift/lib/go/thrift" | ||
|
||
"github.com/cloudwego/kitex/pkg/remote" | ||
"github.com/cloudwego/kitex/pkg/remote/codec/perrors" | ||
) | ||
|
||
const ( | ||
EnableSkipDecoder CodecType = 0b10000 | ||
) | ||
|
||
// skipBuffer wraps remote.ByteBuffer and reimplement the Next method. | ||
// Next method would not advance the reader in the underlying buffer until Buffer method is called. | ||
// It is used to fit in BinaryProtocol and reduce memory allocation. | ||
type skipBuffer struct { | ||
remote.ByteBuffer | ||
readNum int | ||
} | ||
|
||
func (b *skipBuffer) Next(n int) ([]byte, error) { | ||
prev := b.readNum | ||
next := prev + n | ||
buf, err := b.ByteBuffer.Peek(next) | ||
if err != nil { | ||
return nil, err | ||
} | ||
b.readNum = next | ||
return buf[prev:next], nil | ||
} | ||
|
||
func (b *skipBuffer) Buffer() ([]byte, error) { | ||
return b.ByteBuffer.Next(b.readNum) | ||
} | ||
|
||
func newSkipBuffer(bb remote.ByteBuffer) *skipBuffer { | ||
return &skipBuffer{ | ||
ByteBuffer: bb, | ||
} | ||
} | ||
|
||
// skipDecoder is used to parse the input byte-by-byte and skip the thrift payload | ||
// for making use of Frugal and FastCodec in standard Thrift Binary Protocol scenario. | ||
type skipDecoder struct { | ||
tprot *BinaryProtocol | ||
sb *skipBuffer | ||
} | ||
|
||
func newSkipDecoder(trans remote.ByteBuffer) *skipDecoder { | ||
sb := newSkipBuffer(trans) | ||
return &skipDecoder{ | ||
tprot: NewBinaryProtocol(sb), | ||
sb: sb, | ||
} | ||
} | ||
|
||
func (sd *skipDecoder) SkipStruct() error { | ||
return sd.skip(thrift.STRUCT, thrift.DEFAULT_RECURSION_DEPTH) | ||
} | ||
|
||
func (sd *skipDecoder) skipString() error { | ||
size, err := sd.tprot.ReadI32() | ||
if err != nil { | ||
return err | ||
} | ||
if size < 0 { | ||
return perrors.InvalidDataLength | ||
} | ||
_, err = sd.tprot.next(int(size)) | ||
return err | ||
} | ||
|
||
func (sd *skipDecoder) skipMap(maxDepth int) error { | ||
keyTypeId, valTypeId, size, err := sd.tprot.ReadMapBegin() | ||
if err != nil { | ||
return err | ||
} | ||
for i := 0; i < size; i++ { | ||
if err = sd.skip(keyTypeId, maxDepth); err != nil { | ||
return err | ||
} | ||
if err = sd.skip(valTypeId, maxDepth); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (sd *skipDecoder) skipList(maxDepth int) error { | ||
elemTypeId, size, err := sd.tprot.ReadListBegin() | ||
if err != nil { | ||
return err | ||
} | ||
for i := 0; i < size; i++ { | ||
if err = sd.skip(elemTypeId, maxDepth); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (sd *skipDecoder) skipSet(maxDepth int) error { | ||
return sd.skipList(maxDepth) | ||
} | ||
|
||
func (sd *skipDecoder) skip(typeId thrift.TType, maxDepth int) (err error) { | ||
if maxDepth <= 0 { | ||
return thrift.NewTProtocolExceptionWithType(thrift.DEPTH_LIMIT, errors.New("depth limit exceeded")) | ||
} | ||
|
||
switch typeId { | ||
case thrift.BOOL, thrift.BYTE: | ||
if _, err = sd.tprot.next(1); err != nil { | ||
return | ||
} | ||
case thrift.I16: | ||
if _, err = sd.tprot.next(2); err != nil { | ||
return | ||
} | ||
case thrift.I32: | ||
if _, err = sd.tprot.next(4); err != nil { | ||
return | ||
} | ||
case thrift.I64, thrift.DOUBLE: | ||
if _, err = sd.tprot.next(8); err != nil { | ||
return | ||
} | ||
case thrift.STRING: | ||
if err = sd.skipString(); err != nil { | ||
return | ||
} | ||
case thrift.STRUCT: | ||
if err = sd.skipStruct(maxDepth - 1); err != nil { | ||
return | ||
} | ||
case thrift.MAP: | ||
if err = sd.skipMap(maxDepth - 1); err != nil { | ||
return | ||
} | ||
case thrift.SET: | ||
if err = sd.skipSet(maxDepth - 1); err != nil { | ||
return | ||
} | ||
case thrift.LIST: | ||
if err = sd.skipList(maxDepth - 1); err != nil { | ||
return | ||
} | ||
default: | ||
return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("unknown data type %d", typeId)) | ||
} | ||
return nil | ||
} | ||
|
||
func (sd *skipDecoder) skipStruct(maxDepth int) (err error) { | ||
var fieldTypeId thrift.TType | ||
|
||
for { | ||
_, fieldTypeId, _, err = sd.tprot.ReadFieldBegin() | ||
if err != nil { | ||
return err | ||
} | ||
if fieldTypeId == thrift.STOP { | ||
return err | ||
} | ||
if err = sd.skip(fieldTypeId, maxDepth); err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
// Buffer returns the skipped buffer. | ||
// Using this buffer to feed to Frugal or FastCodec | ||
func (sd *skipDecoder) Buffer() ([]byte, error) { | ||
return sd.sb.Buffer() | ||
} | ||
|
||
// Recycle recycles the internal BinaryProtocol and would not affect the buffer | ||
// returned by Buffer() | ||
func (sd *skipDecoder) Recycle() { | ||
sd.tprot.Recycle() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* | ||
* Copyright 2024 CloudWeGo Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package thrift | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/apache/thrift/lib/go/thrift" | ||
|
||
"github.com/cloudwego/kitex/internal/test" | ||
"github.com/cloudwego/kitex/pkg/remote" | ||
) | ||
|
||
func TestSkipBuffer(t *testing.T) { | ||
buf := []byte{1, 2, 3} | ||
rb := remote.NewReaderBuffer(buf) | ||
sb := newSkipBuffer(rb) | ||
peekData, err := sb.Next(3) | ||
test.Assert(t, err == nil) | ||
test.DeepEqual(t, buf, peekData[:len(buf)]) | ||
test.Assert(t, sb.readNum == len(buf)) | ||
test.Assert(t, rb.ReadLen() == 0) | ||
|
||
data, err := sb.Buffer() | ||
test.Assert(t, err == nil) | ||
test.DeepEqual(t, buf, data[:len(buf)]) | ||
test.Assert(t, rb.ReadLen() == len(buf)) | ||
} | ||
|
||
func TestSkipDecoder_SkipStruct(t *testing.T) { | ||
tProt := NewBinaryProtocol(remote.NewReaderWriterBuffer(1024)) | ||
defer tProt.Recycle() | ||
tProt.WriteStructBegin("testStruct") | ||
// 1. Byte | ||
tProt.WriteFieldBegin("Byte", thrift.BYTE, 1) | ||
tProt.WriteByte('1') | ||
tProt.WriteFieldEnd() | ||
// 2. Bool | ||
tProt.WriteFieldBegin("Bool", thrift.BOOL, 2) | ||
tProt.WriteBool(true) | ||
tProt.WriteFieldEnd() | ||
// 3. I16 | ||
tProt.WriteFieldBegin("I16", thrift.I16, 3) | ||
tProt.WriteI16(2) | ||
tProt.WriteFieldEnd() | ||
// 4. I32 | ||
tProt.WriteFieldBegin("I32", thrift.I32, 4) | ||
tProt.WriteI32(3) | ||
tProt.WriteFieldEnd() | ||
// 5. I64 | ||
tProt.WriteFieldBegin("I64", thrift.I64, 5) | ||
tProt.WriteI64(4) | ||
tProt.WriteFieldEnd() | ||
// 6. Double | ||
tProt.WriteFieldBegin("Double", thrift.DOUBLE, 6) | ||
tProt.WriteDouble(5) | ||
tProt.WriteFieldEnd() | ||
// 7. String | ||
tProt.WriteFieldBegin("String", thrift.STRING, 7) | ||
tProt.WriteString("6") | ||
tProt.WriteFieldEnd() | ||
// 8. Map | ||
tProt.WriteFieldBegin("Map", thrift.MAP, 8) | ||
tProt.WriteMapBegin(thrift.I32, thrift.I32, 1) | ||
tProt.WriteI32(7) | ||
tProt.WriteI32(8) | ||
tProt.WriteMapEnd() | ||
tProt.WriteFieldEnd() | ||
// 9. Set | ||
tProt.WriteFieldBegin("Set", thrift.SET, 9) | ||
tProt.WriteSetBegin(thrift.I32, 1) | ||
tProt.WriteI32(9) | ||
tProt.WriteSetEnd() | ||
tProt.WriteFieldEnd() | ||
// 10. List | ||
tProt.WriteFieldBegin("List", thrift.LIST, 10) | ||
tProt.WriteListBegin(thrift.I32, 1) | ||
tProt.WriteI32(9) | ||
tProt.WriteListEnd() | ||
tProt.WriteFieldEnd() | ||
|
||
tProt.WriteFieldStop() | ||
tProt.WriteStructEnd() | ||
|
||
length := tProt.ByteBuffer().ReadableLen() | ||
sd := newSkipDecoder(tProt.ByteBuffer()) | ||
defer sd.Recycle() | ||
err := sd.SkipStruct() | ||
test.Assert(t, err == nil) | ||
test.Assert(t, sd.sb.readNum == length) | ||
test.Assert(t, sd.sb.ReadLen() == 0) | ||
test.Assert(t, sd.sb.ReadableLen() == length) | ||
_, err = sd.Buffer() | ||
test.Assert(t, err == nil) | ||
test.Assert(t, sd.sb.ReadLen() == length) | ||
test.Assert(t, sd.sb.ReadableLen() == 0) | ||
} |
Oops, something went wrong.