Skip to content

Commit

Permalink
{channel,server}_test: add tests for message limits.
Browse files Browse the repository at this point in the history
Adjust unit test to accomodate for altered internal interfaces.
Add unit tests to exercise the new message size limit options.

Signed-off-by: Krisztian Litkey <[email protected]>
  • Loading branch information
klihub committed Sep 13, 2024
1 parent 5a0da2b commit e01a569
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 32 deletions.
6 changes: 3 additions & 3 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
ch = newChannel(w, 0)
rch = newChannel(r, 0)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
func TestMessageOversize(t *testing.T) {
var (
w, _ = net.Pipe()
wch = newChannel(w)
wch = newChannel(w, 0)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)
Expand Down
273 changes: 244 additions & 29 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ttrpc
import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
}

// testingServer is what would be implemented by the user of this package.
type testingServer struct{}
type testingServer struct {
echoOnce bool
}

func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
tp := &internal.TestPayload{}
if s.echoOnce {
tp.Foo = req.Foo
} else {
tp.Foo = strings.Repeat(req.Foo, 2)
}
if dl, ok := ctx.Deadline(); ok {
tp.Deadline = dl.UnixNano()
}
Expand Down Expand Up @@ -330,38 +338,238 @@ func TestImmediateServerShutdown(t *testing.T) {
}

func TestOversizeCall(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()
type testCase struct {
name string
echoOnce bool
clientLimit int
serverLimit int
requestSize int
clientFail bool
sendFail bool
serverFail bool
}

overhead := getWireMessageOverhead(t)

clientOpts := func(tc *testCase) []ClientOpts {
if tc.clientLimit == 0 {
return nil
}
return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)}
}
serverOpts := func(tc *testCase) []ServerOpt {
if tc.serverLimit == 0 {
return nil
}
return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)}
}

registerTestingService(server, &testingServer{})
runTest := func(t *testing.T, tc *testCase) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(serverOpts(tc)...))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr, clientOpts(tc)...)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()

registerTestingService(server, &testingServer{echoOnce: tc.echoOnce})

tp := &internal.TestPayload{
Foo: strings.Repeat("a", 1+messageLengthMax),
req := &internal.TestPayload{
Foo: strings.Repeat("a", tc.requestSize),
}
rsp := &internal.TestPayload{}

err := client.Call(ctx, serviceName, "Test", req, rsp)
if tc.clientFail {
if err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
}
if tc.sendFail {
var msgLenErr *OversizedMessageErr
if !errors.As(err, &msgLenErr) {
t.Fatalf("failed to retrieve client send OversizedMessageErr")
}
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
if rejLen == 0 {
t.Fatalf("zero rejected length in client send oversized message error")
}
if maxLen == 0 {
t.Fatalf("zero maximum length in client send oversized message error")
}
if rejLen <= maxLen {
t.Fatalf("client send oversized message error rejected < max. length (%d < %d)",
rejLen, maxLen)
}
}
} else if tc.serverFail {
if err == nil {
t.Fatalf("expected error from server-side oversized message")
} else {
if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
}
if msgLenErr, ok := OversizedMessageFromError(err); !ok {
t.Fatalf("failed to retrieve oversized message error")
} else {
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
if rejLen == 0 {
t.Fatalf("zero rejected length in oversized message error")
}
if maxLen == 0 {
t.Fatalf("zero maximum length in oversized message error")
}
if rejLen <= maxLen {
t.Fatalf("oversized message error rejected < max. length (%d < %d)",
rejLen, maxLen)
}
}
}
} else {
if err != nil {
t.Fatalf("expected success, got error %v", err)
}
}

if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)
}
}
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)

for _, tc := range []*testCase{
{
name: "default limits, fitting request and response",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit - overhead,
},
{
name: "default limits, only recv side check",
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit - overhead,
serverFail: true,
},

{
name: "default limits, oversized request",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit,
clientFail: true,
},
{
name: "default limits, oversized response",
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit / 2,
serverFail: true,
},
{
name: "8K limits, 4K request and response",
echoOnce: true,
clientLimit: 8 * 1024,
serverLimit: 8 * 1024,
requestSize: 4 * 1024,
},
{
name: "4K limits, barely fitting cc. 4K request and response",
echoOnce: true,
clientLimit: 4 * 1024,
serverLimit: 4 * 1024,
requestSize: 4*1024 - overhead,
},
{
name: "4K limits, oversized request on client side",
echoOnce: true,
clientLimit: 4 * 1024,
serverLimit: 4 * 1024,
requestSize: 4 * 1024,
clientFail: true,
sendFail: true,
},
{
name: "4K limits, oversized request on server side",
echoOnce: true,
clientLimit: 4*1024 + overhead,
serverLimit: 4 * 1024,
requestSize: 4 * 1024,
serverFail: true,
},
{
name: "4K limits, oversized response on client side",
clientLimit: 4*1024 + overhead,
serverLimit: 4 * 1024,
requestSize: 8*1024 + overhead,
clientFail: true,
},
{
name: "4K limits, oversized response on server side",
clientLimit: 4*1024 + overhead,
serverLimit: 4 * 1024,
requestSize: 4 * 1024,
serverFail: true,
},
{
name: "too small limits, adjusted to minimum accepted limit",
echoOnce: true,
clientLimit: 4,
serverLimit: 4,
requestSize: 4*1024 - overhead,
},
{
name: "maximum allowed protocol limit",
echoOnce: true,
clientLimit: MaxMessageLengthLimit,
serverLimit: MaxMessageLengthLimit,
requestSize: MaxMessageLengthLimit - overhead,
},
} {
t.Run(tc.name, func(t *testing.T) {
runTest(t, tc)
})
}
}

if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
func getWireMessageOverhead(t *testing.T) int {
emptyReq, err := codec{}.Marshal(&Request{
Service: serviceName,
Method: "Test",
})
if err != nil {
t.Fatalf("failed to marshal empty request: %v", err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)

emptyRsp, err := codec{}.Marshal(&Response{
Status: status.New(codes.OK, "").Proto(),
})
if err != nil {
t.Fatalf("failed to marshal empty response: %v", err)
}

reqLen := len(emptyReq)
rspLen := len(emptyRsp)
if reqLen > rspLen {
return reqLen + messageHeaderLength
}

return rspLen + messageHeaderLength
}

func TestClientEOF(t *testing.T) {
Expand Down Expand Up @@ -582,13 +790,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
}

func newTestListener(t testing.TB) (string, net.Listener) {
var prefix string
var (
name = t.Name()
prefix string
)

// Abstracts sockets are only available on Linux.
if runtime.GOOS == "linux" {
prefix = "\x00"
} else {
if split := strings.SplitN(name, "/", 2); len(split) == 2 {
name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1])))
}
}
addr := prefix + t.Name()
addr := prefix + name
listener, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit e01a569

Please sign in to comment.