diff --git a/net.go b/net.go index 1555b172a..66c1dcd94 100644 --- a/net.go +++ b/net.go @@ -800,11 +800,17 @@ func (m *Memberlist) sendMsg(a Address, msg []byte) error { msgs = append(msgs, msg) msgs = append(msgs, extra...) - // Create a compound message - compound := makeCompoundMessage(msgs) + // Create one or more compound messages. + compounds := makeCompoundMessages(msgs) - // Send the message - return m.rawSendMsgPacket(a, nil, compound.Bytes()) + // Send the messages. + for _, compound := range compounds { + if err := m.rawSendMsgPacket(a, nil, compound.Bytes()); err != nil { + return err + } + } + + return nil } // rawSendMsgPacket is used to send message via packet to another host without diff --git a/util.go b/util.go index 24112210d..8f609c1e0 100644 --- a/util.go +++ b/util.go @@ -154,11 +154,37 @@ OUTER: // makeCompoundMessages takes a list of messages and packs // them into one or multiple messages based on the limitations -// of compound messages (255 messages each). +// of compound messages (255 messages each, 64KB max message size). +// +// The input msgs can be modified in-place. func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer { - const maxMsgs = 255 + const ( + maxMsgs = math.MaxUint8 + maxMsgLength = math.MaxUint16 + ) + + // Optimistically assume there will be no big message. bufs := make([]*bytes.Buffer, 0, (len(msgs)+(maxMsgs-1))/maxMsgs) + // Do not add to a compound message any message bigger than the max message length + // we can store. + r, w := 0, 0 + for r < len(msgs) { + if len(msgs[r]) <= maxMsgLength { + // Keep it. + msgs[w] = msgs[r] + r++ + w++ + continue + } + + // This message is a large one, so we send it alone. + bufs = append(bufs, bytes.NewBuffer(msgs[r])) + r++ + } + msgs = msgs[:w] + + // Group remaining messages in compound message(s). for ; len(msgs) > maxMsgs; msgs = msgs[maxMsgs:] { bufs = append(bufs, makeCompoundMessage(msgs[:maxMsgs])) } diff --git a/util_test.go b/util_test.go index 5e3edb633..4b57f7aa0 100644 --- a/util_test.go +++ b/util_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -384,3 +385,137 @@ func TestCompressDecompressPayload(t *testing.T) { t.Fatalf("bad payload: %v", decomp) } } + +func TestMakeCompoundMessages(t *testing.T) { + const ( + smallMsgSeqNo = uint32(1) + smallMsgPayloadLength = 1 + bigMsgSeqNo = uint32(2) + bigMsgPayloadLength = 70000 + ) + + // Generate some fixtures. + smallMessages := make([][]byte, 300) + for i := 0; i < len(smallMessages); i++ { + msg := &ackResp{SeqNo: smallMsgSeqNo, Payload: []byte{byte(i)}} + encoded, err := encode(ackRespMsg, msg) + require.NoError(t, err) + + smallMessages[i] = encoded.Bytes() + } + + bigMessages := make([][]byte, 3) + for i := 0; i < len(bigMessages); i++ { + payload := []byte{bigMsgPayloadLength - 1: byte(i)} + require.Len(t, payload, bigMsgPayloadLength) + + msg := &ackResp{SeqNo: bigMsgSeqNo, Payload: payload} + encoded, err := encode(ackRespMsg, msg) + require.NoError(t, err) + + bigMessages[i] = encoded.Bytes() + } + + tests := map[string]struct { + input [][]byte + expected [][]byte + }{ + "no input": { + input: [][]byte{}, + expected: [][]byte{}, + }, + "one small message": { + input: smallMessages[0:1], + expected: [][]byte{makeCompoundMessage(smallMessages[0:1]).Bytes()}, + }, + "few small messages": { + input: smallMessages[0:3], + expected: [][]byte{makeCompoundMessage(smallMessages[0:3]).Bytes()}, + }, + "many small messages (more than 255)": { + input: smallMessages[0:300], + expected: [][]byte{ + makeCompoundMessage(smallMessages[0:255]).Bytes(), + makeCompoundMessage(smallMessages[255:300]).Bytes(), + }, + }, + "one big message": { + input: bigMessages[0:1], + expected: bigMessages[0:1], + }, + "few big messages": { + input: bigMessages[0:3], + expected: bigMessages[0:3], + }, + "mix of many small and big messages": { + input: func() [][]byte { + var out [][]byte + + out = append(out, bigMessages[0]) + out = append(out, smallMessages[0:20]...) + out = append(out, bigMessages[1]) + out = append(out, smallMessages[20:260]...) + out = append(out, bigMessages[2]) + out = append(out, smallMessages[260:300]...) + + return out + }(), + expected: [][]byte{ + bigMessages[0], + bigMessages[1], + bigMessages[2], + makeCompoundMessage(smallMessages[0:255]).Bytes(), + makeCompoundMessage(smallMessages[255:300]).Bytes(), + }, + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + actual := makeCompoundMessages(testData.input) + + // Get the actual []byte of each message. + actualBytes := make([][]byte, 0, len(actual)) + for _, data := range actual { + actualBytes = append(actualBytes, data.Bytes()) + } + + assert.Equal(t, testData.expected, actualBytes) + + // Ensure we can successfully decode every message. + for i := 0; i < len(actual); i++ { + msg := actualBytes[i] + typ := messageType(msg[0]) + + switch typ { + case ackRespMsg: + var got ackResp + require.NoError(t, decode(msg[1:], &got)) + + if got.SeqNo == smallMsgSeqNo { + assert.Len(t, got.Payload, smallMsgPayloadLength) + } else if got.SeqNo == bigMsgSeqNo { + assert.Len(t, got.Payload, bigMsgPayloadLength) + } else { + require.Fail(t, "unexpected seq no") + } + case compoundMsg: + trunc, parts, err := decodeCompoundMessage(msg[1:]) + require.NoError(t, err) + require.Equal(t, 0, trunc) + + for _, part := range parts { + require.Equal(t, ackRespMsg, messageType(part[0])) + + var got ackResp + require.NoError(t, decode(part[1:], &got)) + assert.Equal(t, smallMsgSeqNo, got.SeqNo) + assert.Len(t, got.Payload, smallMsgPayloadLength) + } + default: + require.Fail(t, "unexpected message") + } + } + }) + } +}