Skip to content

Commit

Permalink
client,server: configurable wire message size limits.
Browse files Browse the repository at this point in the history
Implement configurable limits for the maximum accepted message
size of the wire protocol. The default limit can be overridden
using the WithClientWireMessageLimit() option for clients and
using the WithServerWireMessageLimit() option for servers. Add
exported constants for the minimum, maximum and default limits.

Signed-off-by: Krisztian Litkey <[email protected]>
  • Loading branch information
klihub committed Sep 13, 2024
1 parent 525ddce commit e1f03b3
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 29 deletions.
61 changes: 43 additions & 18 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ import (
"io"
"net"
"sync"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
messageHeaderLength = 10
messageLengthMax = 4 << 20
messageHeaderLength = 10
MinMessageLengthLimit = 4 << 10
MaxMessageLengthLimit = 4 << 22
DefaultMessageLengthLimit = 4 << 20
)

type messageType uint8
Expand Down Expand Up @@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
var buffers sync.Pool

type channel struct {
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
maxMsgLen int
}

func newChannel(conn net.Conn) *channel {
func newChannel(conn net.Conn, maxMsgLen int) *channel {
if maxMsgLen == 0 {
maxMsgLen = DefaultMessageLengthLimit
}
return &channel{
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
maxMsgLen: maxMsgLen,
}
}

Expand All @@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
return messageHeader{}, nil, err
}

if mh.Length > uint32(messageLengthMax) {
if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) {
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
}

return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen)
}

var p []byte
Expand All @@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
if len(p) > messageLengthMax {
return OversizedMessageError(len(p))
if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 {
if len(p) > maxMsgLen {
return OversizedMessageError(len(p), maxMsgLen)
}
}

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
Expand Down Expand Up @@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

func (ch *channel) maxMsgLimit(recv bool) int {
if ch.maxMsgLen == 0 && recv {
return DefaultMessageLengthLimit
}
return ch.maxMsgLen
}

func clampWireMessageLimit(maxMsgLen int) int {
switch {
case maxMsgLen == 0:
return 0
case maxMsgLen < MinMessageLengthLimit:
return MinMessageLengthLimit
case maxMsgLen > MaxMessageLengthLimit:
return MaxMessageLengthLimit
}
return maxMsgLen
}
19 changes: 14 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ import (

// Client for a ttrpc server
type Client struct {
codec codec
conn net.Conn
channel *channel
codec codec
conn net.Conn
channel *channel
maxMsgLen int

streamLock sync.RWMutex
streams map[streamID]*stream
Expand Down Expand Up @@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker
}
}

// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client.
func WithClientWireMessageLimit(maxMsgLen int) ClientOpts {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *Client) {
c.maxMsgLen = maxMsgLen
}
}

// NewClient creates a new ttrpc client using the given connection
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
channel := newChannel(conn)
c := &Client{
codec: codec{},
conn: conn,
channel: channel,
streams: make(map[streamID]*stream),
nextStreamID: 1,
closed: cancel,
Expand All @@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
o(c)
}

c.channel = newChannel(conn, c.maxMsgLen)

if c.interceptor == nil {
c.interceptor = defaultClientInterceptor
}
Expand Down
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
type serverConfig struct {
handshaker Handshaker
interceptor UnaryServerInterceptor
maxMsgLen int
}

// ServerOpt for configuring a ttrpc server
Expand Down Expand Up @@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep
chainUnaryServerInterceptors(info, method, interceptors[1:]))
}
}

// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server.
func WithServerWireMessageLimit(maxMsgLen int) ServerOpt {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *serverConfig) error {
c.maxMsgLen = maxMsgLen
return nil
}
}
50 changes: 45 additions & 5 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ttrpc

import (
"errors"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -43,20 +44,59 @@ var (
// length.
type OversizedMessageErr struct {
messageLength int
maxLength int
err error
}

var (
oversizedMsgFmt = "message length %d exceeds maximum message size of %d"
oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt))
)

// OversizedMessageError returns an OversizedMessageErr error for the given message
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
func OversizedMessageError(messageLength int) error {
if messageLength <= messageLengthMax {
func OversizedMessageError(messageLength, maxLength int) error {
if messageLength <= maxLength {
return nil
}

return &OversizedMessageErr{
messageLength: messageLength,
err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
maxLength: maxLength,
err: OversizedMessageStatus(messageLength, maxLength).Err(),
}
}

// OversizedMessageStatus returns a Status for an oversized message error.
func OversizedMessageStatus(messageLength, maxLength int) *status.Status {
return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength)
}

// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status.
func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) {
var (
messageLength int
maxLength int
)

st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}

// TODO(klihub): might be too ugly to recover an error this way... An
// alternative would be to define our custom status detail proto type,
// then use status.WithDetails() and status.Details().

n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength)
if n != 2 {
n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength)
}
if n != 2 {
return nil, false
}

return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true
}

// Error returns the error message for the corresponding grpc Status for the error.
Expand All @@ -75,6 +115,6 @@ func (e *OversizedMessageErr) RejectedLength() int {
}

// MaximumLength retrieves the maximum allowed message length that triggered the error.
func (*OversizedMessageErr) MaximumLength() int {
return messageLengthMax
func (e *OversizedMessageErr) MaximumLength() int {
return e.maxLength
}
21 changes: 20 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) {
)

var (
ch = newChannel(c.conn)
ch = newChannel(c.conn, c.server.config.maxMsgLen)
ctx, cancel = context.WithCancel(sctx)
state connState = connStateIdle
responses = make(chan response)
Expand Down Expand Up @@ -373,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) {
}
}

isResourceExhaustedError := func(err error) (*status.Status, bool) {
st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}
return st, true
}

go func(recvErr chan error) {
defer close(recvErr)
for {
Expand Down Expand Up @@ -525,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) {
}

if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
if st, ok := isResourceExhaustedError(err); ok {
p, err = c.server.codec.Marshal(&Response{
Status: st.Proto(),
})
if err != nil {
log.G(ctx).WithError(err).Error("failed marshaling error response")
return
}
ch.send(response.id, messageTypeResponse, 0, p)
return
}
log.G(ctx).WithError(err).Error("failed sending message on channel")
return
}
Expand Down

0 comments on commit e1f03b3

Please sign in to comment.