From 8d2a27af7d1d23119da66585d71f7b1e60fde3d3 Mon Sep 17 00:00:00 2001 From: Steve Simpson Date: Thu, 8 Jul 2021 15:51:22 +0200 Subject: [PATCH 01/12] Fix for compound messages containing >255 messages. --- memberlist_test.go | 6 +++--- state.go | 10 ++++++---- util.go | 17 +++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/memberlist_test.go b/memberlist_test.go index 70b5d780c..6a1eb70a0 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -1200,9 +1200,9 @@ func TestMemberlist_UserData(t *testing.T) { bindPort := m1.config.BindPort - bcasts := [][]byte{ - []byte("test"), - []byte("foobar"), + bcasts := make([][]byte, 256) + for i := range bcasts { + bcasts[i] = []byte(fmt.Sprintf("%d", i)) } // Create a second node diff --git a/state.go b/state.go index 5e4f7fdd7..95e6ddc48 100644 --- a/state.go +++ b/state.go @@ -606,10 +606,12 @@ func (m *Memberlist) gossip() { m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) } } else { - // Otherwise create and send a compound message - compound := makeCompoundMessage(msgs) - if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + // Otherwise create and send one or more compound messages + compounds := makeCompoundMessages(msgs) + for _, compound := range compounds { + if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + } } } } diff --git a/util.go b/util.go index 16a7d36d0..96d2c6b8f 100644 --- a/util.go +++ b/util.go @@ -152,6 +152,23 @@ OUTER: return kNodes } +// makeCompoundMessages takes a list of messages and packs +// them into one or multiple messages based on the limitations +// of compound messages (255 messages each). +func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer { + const maxMsgs = 255 + bufs := make([]*bytes.Buffer, 0, (len(msgs)+(maxMsgs-1))/maxMsgs) + + for ; len(msgs) > maxMsgs; msgs = msgs[maxMsgs:] { + bufs = append(bufs, makeCompoundMessage(msgs[:maxMsgs])) + } + if len(msgs) > 0 { + bufs = append(bufs, makeCompoundMessage(msgs)) + } + + return bufs +} + // makeCompoundMessage takes a list of messages and generates // a single compound message containing all of them func makeCompoundMessage(msgs [][]byte) *bytes.Buffer { From 7e2219b23f3331101f56694b69a09c76b9707f7b Mon Sep 17 00:00:00 2001 From: Austin Date: Tue, 21 Sep 2021 12:56:42 -0600 Subject: [PATCH 02/12] Fix udp writes not causing nodes to become suspect (#242) Handle the udp write error case in failedRemote --- state.go | 5 +++++ state_test.go | 2 ++ 2 files changed, 7 insertions(+) diff --git a/state.go b/state.go index 5e4f7fdd7..327a4dce6 100644 --- a/state.go +++ b/state.go @@ -274,6 +274,11 @@ func failedRemote(err error) bool { case "dial", "read", "write": return true } + } else if strings.HasPrefix(t.Net, "udp") { + switch t.Op { + case "write": + return true + } } } return false diff --git a/state_test.go b/state_test.go index 059918af8..204f5020c 100644 --- a/state_test.go +++ b/state_test.go @@ -2218,6 +2218,8 @@ func TestMemberlist_FailedRemote(t *testing.T) { {"net.OpError for tcp with dial", &net.OpError{Net: "tcp", Op: "dial"}, true}, {"net.OpError for tcp with write", &net.OpError{Net: "tcp", Op: "write"}, true}, {"net.OpError for tcp with read", &net.OpError{Net: "tcp", Op: "read"}, true}, + {"net.OpError for udp with write", &net.OpError{Net: "udp", Op: "write"}, true}, + {"net.OpError for udp with read", &net.OpError{Net: "udp", Op: "read"}, false}, } for _, test := range tests { From 123f3fbfeacd70bbe467775ab563a2b87e8c5cd8 Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" <4903+rboyer@users.noreply.github.com> Date: Thu, 14 Oct 2021 11:11:22 -0500 Subject: [PATCH 03/12] pass go vet (#244) --- memberlist_test.go | 2 +- net_test.go | 64 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/memberlist_test.go b/memberlist_test.go index 70b5d780c..4e41186a3 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -541,7 +541,7 @@ func TestMemberList_ResolveAddr_TCP_First(t *testing.T) { go func() { if err := server.ListenAndServe(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) } }() wg.Wait() diff --git a/net_test.go b/net_test.go index 9f1a9aae7..8f016226c 100644 --- a/net_test.go +++ b/net_test.go @@ -292,46 +292,56 @@ func TestTCPPing(t *testing.T) { // Do a normal round trip. pingOut := ping{SeqNo: 23, Node: "mongo"} + pingErrCh := make(chan error, 1) go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() msgType, _, dec, err := m.readStream(conn) if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } if msgType != pingMsg { - t.Fatalf("expecting ping, got message type (%d)", msgType) + pingErrCh <- fmt.Errorf("expecting ping, got message type (%d)", msgType) + return } var pingIn ping if err := dec.Decode(&pingIn); err != nil { - t.Fatalf("failed to decode ping: %s", err) + pingErrCh <- fmt.Errorf("failed to decode ping: %s", err) + return } if pingIn.SeqNo != pingOut.SeqNo { - t.Fatalf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo) + pingErrCh <- fmt.Errorf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo) + return } if pingIn.Node != pingOut.Node { - t.Fatalf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) + pingErrCh <- fmt.Errorf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) + return } ack := ackResp{pingIn.SeqNo, nil} out, err := encode(ackRespMsg, &ack) if err != nil { - t.Fatalf("failed to encode ack: %s", err) + pingErrCh <- fmt.Errorf("failed to encode ack: %s", err) + return } err = m.rawSendMsgStream(conn, out.Bytes()) if err != nil { - t.Fatalf("failed to send ack: %s", err) + pingErrCh <- fmt.Errorf("failed to send ack: %s", err) + return } + pingErrCh <- nil }() deadline := time.Now().Add(pingTimeout) didContact, err := m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -341,36 +351,45 @@ func TestTCPPing(t *testing.T) { if !didContact { t.Fatalf("expected successful ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure a mis-matched sequence number is caught. go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() _, _, dec, err := m.readStream(conn) if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } var pingIn ping if err := dec.Decode(&pingIn); err != nil { - t.Fatalf("failed to decode ping: %s", err) + pingErrCh <- fmt.Errorf("failed to decode ping: %s", err) + return } ack := ackResp{pingIn.SeqNo + 1, nil} out, err := encode(ackRespMsg, &ack) if err != nil { - t.Fatalf("failed to encode ack: %s", err) + pingErrCh <- fmt.Errorf("failed to encode ack: %s", err) + return } err = m.rawSendMsgStream(conn, out.Bytes()) if err != nil { - t.Fatalf("failed to send ack: %s", err) + pingErrCh <- fmt.Errorf("failed to send ack: %s", err) + return } + pingErrCh <- nil }() deadline = time.Now().Add(pingTimeout) didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -380,31 +399,39 @@ func TestTCPPing(t *testing.T) { if didContact { t.Fatalf("expected failed ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure an unexpected message type is handled gracefully. go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() _, _, _, err = m.readStream(conn) if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } bogus := indirectPingReq{} out, err := encode(indirectPingMsg, &bogus) if err != nil { - t.Fatalf("failed to encode bogus msg: %s", err) + pingErrCh <- fmt.Errorf("failed to encode bogus msg: %s", err) + return } err = m.rawSendMsgStream(conn, out.Bytes()) if err != nil { - t.Fatalf("failed to send bogus msg: %s", err) + pingErrCh <- fmt.Errorf("failed to send bogus msg: %s", err) + return } + pingErrCh <- nil }() deadline = time.Now().Add(pingTimeout) didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -414,6 +441,9 @@ func TestTCPPing(t *testing.T) { if didContact { t.Fatalf("expected failed ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure failed I/O respects the deadline. In this case we try the // common case of the receiving node being totally down. From 923f1b205dd4653f2ea35e1c9531088e52053aa0 Mon Sep 17 00:00:00 2001 From: "R.B. Boyer" <4903+rboyer@users.noreply.github.com> Date: Fri, 12 Nov 2021 16:15:55 -0600 Subject: [PATCH 04/12] add an optional wrapper labels to packets and streams (#246) --- .circleci/config.yml | 2 +- config.go | 11 ++ label.go | 178 +++++++++++++++++++++ label_test.go | 364 +++++++++++++++++++++++++++++++++++++++++++ memberlist.go | 13 +- memberlist_test.go | 92 +++++++++++ net.go | 132 ++++++++++++---- net_test.go | 16 +- peeked_conn.go | 48 ++++++ security.go | 19 +++ state.go | 2 +- transport.go | 47 ++++++ 12 files changed, 886 insertions(+), 38 deletions(-) create mode 100644 label.go create mode 100644 label_test.go create mode 100644 peeked_conn.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 78ac11fb7..a6eebf2db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,7 +4,7 @@ version: 2.1 executors: go: docker: - - image: docker.mirror.hashicorp.services/circleci/golang:1.15 + - image: docker.mirror.hashicorp.services/circleci/golang:1.16 environment: - TEST_RESULTS: /tmp/test-results # path to where test results are saved diff --git a/config.go b/config.go index 31099e75f..d7fe4c37b 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,17 @@ type Config struct { // make a NetTransport using BindAddr and BindPort from this structure. Transport Transport + // Label is an optional set of bytes to include on the outside of each + // packet and stream. + // + // If gossip encryption is enabled and this is set it is treated as GCM + // authenticated data. + Label string + + // SkipInboundLabelCheck skips the check that inbound packets and gossip + // streams need to be label prefixed. + SkipInboundLabelCheck bool + // Configuration related to what address to bind to and ports to // listen on. The port is used for both UDP and TCP gossip. It is // assumed other nodes are running on this port, but they do not need diff --git a/label.go b/label.go new file mode 100644 index 000000000..bbe0163ab --- /dev/null +++ b/label.go @@ -0,0 +1,178 @@ +package memberlist + +import ( + "bufio" + "fmt" + "io" + "net" +) + +// General approach is to prefix all packets and streams with the same structure: +// +// magic type byte (244): uint8 +// length of label name: uint8 (because labels can't be longer than 255 bytes) +// label name: []uint8 + +// LabelMaxSize is the maximum length of a packet or stream label. +const LabelMaxSize = 255 + +// AddLabelHeaderToPacket prefixes outgoing packets with the correct header if +// the label is not empty. +func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) { + if label == "" { + return buf, nil + } + if len(label) > LabelMaxSize { + return nil, fmt.Errorf("label %q is too long", label) + } + + return makeLabelHeader(label, buf), nil +} + +// RemoveLabelHeaderFromPacket removes any label header from the provided +// packet and returns it along with the remaining packet contents. +func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) { + if len(buf) == 0 { + return buf, "", nil // can't possibly be labeled + } + + // [type:byte] [size:byte] [size bytes] + + msgType := messageType(buf[0]) + if msgType != hasLabelMsg { + return buf, "", nil + } + + if len(buf) < 2 { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + size := int(buf[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + + if len(buf) < 2+size { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + label = string(buf[2 : 2+size]) + newBuf = buf[2+size:] + + return newBuf, label, nil +} + +// AddLabelHeaderToStream prefixes outgoing streams with the correct header if +// the label is not empty. +func AddLabelHeaderToStream(conn net.Conn, label string) error { + if label == "" { + return nil + } + if len(label) > LabelMaxSize { + return fmt.Errorf("label %q is too long", label) + } + + header := makeLabelHeader(label, nil) + + _, err := conn.Write(header) + return err +} + +// RemoveLabelHeaderFromStream removes any label header from the beginning of +// the stream if present and returns it along with an updated conn with that +// header removed. +// +// Note that on error it is the caller's responsibility to close the +// connection. +func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) { + br := bufio.NewReader(conn) + + // First check for the type byte. + peeked, err := br.Peek(1) + if err != nil { + if err == io.EOF { + // It is safe to return the original net.Conn at this point because + // it never contained any data in the first place so we don't have + // to splice the buffer into the conn because both are empty. + return conn, "", nil + } + return nil, "", err + } + + msgType := messageType(peeked[0]) + if msgType != hasLabelMsg { + conn, err = newPeekedConnFromBufferedReader(conn, br, 0) + return conn, "", err + } + + // We are guaranteed to get a size byte as well. + peeked, err = br.Peek(2) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + size := int(peeked[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + // NOTE: we don't have to check this against LabelMaxSize because a byte + // already has a max value of 255. + + // Once we know the size we can peek the label as well. Note that since we + // are using the default bufio.Reader size of 4096, the entire label header + // fits in the initial buffer fill so this should be free. + peeked, err = br.Peek(2 + size) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + label := string(peeked[2 : 2+size]) + + conn, err = newPeekedConnFromBufferedReader(conn, br, 2+size) + if err != nil { + return nil, "", err + } + + return conn, label, nil +} + +// newPeekedConnFromBufferedReader will splice the buffer contents after the +// offset into the provided net.Conn and return the result so that the rest of +// the buffer contents are returned first when reading from the returned +// peekedConn before moving on to the unbuffered conn contents. +func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) { + // Extract any of the readahead buffer. + peeked, err := br.Peek(br.Buffered()) + if err != nil { + return nil, err + } + + return &peekedConn{ + Peeked: peeked[offset:], + Conn: conn, + }, nil +} + +func makeLabelHeader(label string, rest []byte) []byte { + newBuf := make([]byte, 2, 2+len(label)+len(rest)) + newBuf[0] = byte(hasLabelMsg) + newBuf[1] = byte(len(label)) + newBuf = append(newBuf, []byte(label)...) + if len(rest) > 0 { + newBuf = append(newBuf, []byte(rest)...) + } + return newBuf +} + +func labelOverhead(label string) int { + if label == "" { + return 0 + } + return 2 + len(label) +} diff --git a/label_test.go b/label_test.go new file mode 100644 index 000000000..a78ea4a9c --- /dev/null +++ b/label_test.go @@ -0,0 +1,364 @@ +package memberlist + +import ( + "bytes" + "io" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAddLabelHeaderToPacket(t *testing.T) { + type testcase struct { + buf []byte + label string + expectPacket []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + got, err := AddLabelHeaderToPacket(tc.buf, tc.label) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectPacket, got) + } + } + + longLabel := strings.Repeat("a", 255) + + cases := map[string]testcase{ + "nil buf with no label": testcase{ + buf: nil, + label: "", + expectPacket: nil, + }, + "nil buf with label": testcase{ + buf: nil, + label: "foo", + expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foo")...), + }, + "message with label": testcase{ + buf: []byte("something"), + label: "foo", + expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foosomething")...), + }, + "message with no label": testcase{ + buf: []byte("something"), + label: "", + expectPacket: []byte("something"), + }, + "message with almost too long label": testcase{ + buf: []byte("something"), + label: longLabel, + expectPacket: append([]byte{byte(hasLabelMsg), 255}, []byte(longLabel+"something")...), + }, + "label too long by one byte": testcase{ + buf: []byte("something"), + label: longLabel + "x", + expectErr: `label "` + longLabel + `x" is too long`, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestRemoveLabelHeaderFromPacket(t *testing.T) { + type testcase struct { + buf []byte + expectLabel string + expectPacket []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + gotBuf, gotLabel, err := RemoveLabelHeaderFromPacket(tc.buf) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectPacket, gotBuf) + require.Equal(t, tc.expectLabel, gotLabel) + } + } + + cases := map[string]testcase{ + "empty buf": testcase{ + buf: []byte{}, + expectLabel: "", + expectPacket: []byte{}, + }, + "ping with no label": testcase{ + buf: buildBuffer(t, pingMsg, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, pingMsg, "blah"), + }, + "error with no label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, errMsg, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with no label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, maxEncryptionVersion, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + "buf too small for label": testcase{ + buf: buildBuffer(t, hasLabelMsg, "x"), + expectErr: `cannot decode label; packet has been truncated`, + }, + "buf too small for label size": testcase{ + buf: buildBuffer(t, hasLabelMsg), + expectErr: `cannot decode label; packet has been truncated`, + }, + "label empty": testcase{ + buf: buildBuffer(t, hasLabelMsg, 0, "x"), + expectErr: `label header cannot be empty when present`, + }, + "label truncated": testcase{ + buf: buildBuffer(t, hasLabelMsg, 2, "x"), + expectErr: `cannot decode label; packet has been truncated`, + }, + "ping with label": testcase{ + buf: buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, pingMsg, "blah"), + }, + "error with label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestAddLabelHeaderToStream(t *testing.T) { + type testcase struct { + label string + expectData []byte + expectErr string + } + + suffixData := "EXTRA DATA" + + run := func(t *testing.T, tc testcase) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + var ( + dataCh = make(chan []byte, 1) + errCh = make(chan error, 1) + ) + go func() { + var buf bytes.Buffer + _, err := io.Copy(&buf, server) + if err != nil { + errCh <- err + } + dataCh <- buf.Bytes() + }() + + err := AddLabelHeaderToStream(client, tc.label) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + return + } + require.NoError(t, err) + + client.Write([]byte(suffixData)) + client.Close() + + expect := make([]byte, 0, len(suffixData)+len(tc.expectData)) + expect = append(expect, tc.expectData...) + expect = append(expect, suffixData...) + + select { + case err := <-errCh: + require.NoError(t, err) + case got := <-dataCh: + require.Equal(t, expect, got) + } + } + + longLabel := strings.Repeat("a", 255) + + cases := map[string]testcase{ + "no label": testcase{ + label: "", + expectData: nil, + }, + "with label": testcase{ + label: "foo", + expectData: buildBuffer(t, hasLabelMsg, 3, "foo"), + }, + "almost too long label": testcase{ + label: longLabel, + expectData: buildBuffer(t, hasLabelMsg, 255, longLabel), + }, + "label too long by one byte": testcase{ + label: longLabel + "x", + expectErr: `label "` + longLabel + `x" is too long`, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestRemoveLabelHeaderFromStream(t *testing.T) { + type testcase struct { + buf []byte + expectLabel string + expectData []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + var ( + errCh = make(chan error, 1) + ) + go func() { + _, err := server.Write(tc.buf) + if err != nil { + errCh <- err + } + server.Close() + }() + + newConn, gotLabel, err := RemoveLabelHeaderFromStream(client) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + return + } + require.NoError(t, err) + + gotBuf, err := io.ReadAll(newConn) + require.NoError(t, err) + + require.Equal(t, tc.expectData, gotBuf) + require.Equal(t, tc.expectLabel, gotLabel) + } + + cases := map[string]testcase{ + "empty buf": testcase{ + buf: []byte{}, + expectLabel: "", + expectData: []byte{}, + }, + "ping with no label": testcase{ + buf: buildBuffer(t, pingMsg, "blah"), + expectLabel: "", + expectData: buildBuffer(t, pingMsg, "blah"), + }, + "error with no label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, errMsg, "blah"), + expectLabel: "", + expectData: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with no label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, maxEncryptionVersion, "blah"), + expectLabel: "", + expectData: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + "buf too small for label": testcase{ + buf: buildBuffer(t, hasLabelMsg, "x"), + expectErr: `cannot decode label; stream has been truncated`, + }, + "buf too small for label size": testcase{ + buf: buildBuffer(t, hasLabelMsg), + expectErr: `cannot decode label; stream has been truncated`, + }, + "label empty": testcase{ + buf: buildBuffer(t, hasLabelMsg, 0, "x"), + expectErr: `label header cannot be empty when present`, + }, + "label truncated": testcase{ + buf: buildBuffer(t, hasLabelMsg, 2, "x"), + expectErr: `cannot decode label; stream has been truncated`, + }, + "ping with label": testcase{ + buf: buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, pingMsg, "blah"), + }, + "error with label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func buildBuffer(t *testing.T, stuff ...interface{}) []byte { + var buf bytes.Buffer + for _, item := range stuff { + switch x := item.(type) { + case int: + x2 := uint(x) + if x2 > 255 { + t.Fatalf("int is too big") + } + buf.WriteByte(byte(x2)) + case byte: + buf.WriteByte(byte(x)) + case messageType: + buf.WriteByte(byte(x)) + case encryptionVersion: + buf.WriteByte(byte(x)) + case string: + buf.Write([]byte(x)) + case []byte: + buf.Write(x) + default: + t.Fatalf("unexpected type %T", item) + } + } + return buf.Bytes() +} + +func TestLabelOverhead(t *testing.T) { + require.Equal(t, 0, labelOverhead("")) + require.Equal(t, 3, labelOverhead("a")) + require.Equal(t, 9, labelOverhead("abcdefg")) +} diff --git a/memberlist.go b/memberlist.go index 7ee040091..cab6db69f 100644 --- a/memberlist.go +++ b/memberlist.go @@ -187,6 +187,17 @@ func newMemberlist(conf *Config) (*Memberlist, error) { nodeAwareTransport = &shimNodeAwareTransport{transport} } + if len(conf.Label) > LabelMaxSize { + return nil, fmt.Errorf("could not use %q as a label: too long", conf.Label) + } + + if conf.Label != "" { + nodeAwareTransport = &labelWrappedTransport{ + label: conf.Label, + NodeAwareTransport: nodeAwareTransport, + } + } + m := &Memberlist{ config: conf, shutdownCh: make(chan struct{}), @@ -262,7 +273,7 @@ func (m *Memberlist) Join(existing []string) (int, error) { hp := joinHostPort(addr.ip.String(), addr.port) a := Address{Addr: hp, Name: addr.nodeName} if err := m.pushPullNode(a, true); err != nil { - err = fmt.Errorf("Failed to join %s: %v", addr.ip, err) + err = fmt.Errorf("Failed to join %s: %v", a.Addr, err) errs = multierror.Append(errs, err) m.logger.Printf("[DEBUG] memberlist: %v", err) continue diff --git a/memberlist_test.go b/memberlist_test.go index 4e41186a3..ab310ebc0 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -16,6 +16,7 @@ import ( iretry "github.com/hashicorp/memberlist/internal/retry" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -643,6 +644,90 @@ func TestMemberlist_Join(t *testing.T) { } } +func TestMemberlist_Join_with_Labels(t *testing.T) { + testMemberlist_Join_with_Labels(t, nil) +} +func TestMemberlist_Join_with_Labels_and_Encryption(t *testing.T) { + secretKey := TestKeys[0] + testMemberlist_Join_with_Labels(t, secretKey) +} +func testMemberlist_Join_with_Labels(t *testing.T, secretKey []byte) { + c1 := testConfig(t) + c1.Label = "blah" + c1.SecretKey = secretKey + m1, err := Create(c1) + require.NoError(t, err) + defer m1.Shutdown() + + bindPort := m1.config.BindPort + + // Create a second node + c2 := testConfig(t) + c2.Label = "blah" + c2.BindPort = bindPort + c2.SecretKey = secretKey + m2, err := Create(c2) + require.NoError(t, err) + defer m2.Shutdown() + + checkHost := func(t *testing.T, m *Memberlist, expected int) { + assert.Equal(t, expected, len(m.Members())) + assert.Equal(t, expected, m.estNumNodes()) + } + + runStep(t, "same label can join", func(t *testing.T) { + num, err := m2.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.NoError(t, err) + require.Equal(t, 1, num) + + // Check the hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) + + // Create a third node that uses no label + c3 := testConfig(t) + c3.Label = "" + c3.BindPort = bindPort + c3.SecretKey = secretKey + m3, err := Create(c3) + require.NoError(t, err) + defer m3.Shutdown() + + runStep(t, "no label cannot join", func(t *testing.T) { + _, err := m3.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.Error(t, err) + + // Check the failed host + checkHost(t, m3, 1) + // Check the existing hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) + + // Create a fourth node that uses a mismatched label + c4 := testConfig(t) + c4.Label = "not-blah" + c4.BindPort = bindPort + c4.SecretKey = secretKey + m4, err := Create(c4) + require.NoError(t, err) + defer m4.Shutdown() + + runStep(t, "mismatched label cannot join", func(t *testing.T) { + _, err := m4.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.Error(t, err) + + // Check the failed host + checkHost(t, m4, 1) + // Check the previous failed host + checkHost(t, m3, 1) + // Check the existing hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) +} + func TestMemberlist_JoinDifferentNetworksUniqueMask(t *testing.T) { c1 := testConfigNet(t, 0) c1.CIDRsAllowed, _ = ParseCIDRs([]string{"127.0.0.0/8"}) @@ -2071,3 +2156,10 @@ func TestMemberlist_EncryptedGossipTransition(t *testing.T) { // t.Fatalf("bad role for %s: %s", c2.Name, r) // } //} + +func runStep(t *testing.T, name string, fn func(t *testing.T)) { + t.Helper() + if !t.Run(name, fn) { + t.FailNow() + } +} diff --git a/net.go b/net.go index bac73bd89..1d015afb2 100644 --- a/net.go +++ b/net.go @@ -42,6 +42,9 @@ const ( type messageType uint8 // The list of available message types. +// +// WARNING: ONLY APPEND TO THIS LIST! The numeric values are part of the +// protocol itself. const ( pingMsg messageType = iota indirectPingMsg @@ -59,6 +62,13 @@ const ( errMsg ) +const ( + // hasLabelMsg has a deliberately high value so that you can disambiguate + // it from the encryptionVersion header which is either 0/1 right now and + // also any of the existing messageTypes + hasLabelMsg messageType = 244 +) + // compressionType is used to specify the compression algorithm type compressionType uint8 @@ -226,7 +236,32 @@ func (m *Memberlist) handleConn(conn net.Conn) { metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + + var ( + streamLabel string + err error + ) + conn, streamLabel, err = RemoveLabelHeaderFromStream(conn) + if err != nil { + m.logger.Printf("[ERR] memberlist: failed to receive and remove the stream label header: %s %s", err, LogConn(conn)) + return + } + + if m.config.SkipInboundLabelCheck { + if streamLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double stream label header: %s", LogConn(conn)) + return + } + // Set this from config so that the auth data assertions work below. + streamLabel = m.config.Label + } + + if m.config.Label != streamLabel { + m.logger.Printf("[ERR] memberlist: discarding stream with unacceptable label %q: %s", streamLabel, LogConn(conn)) + return + } + + msgType, bufConn, dec, err := m.readStream(conn, streamLabel) if err != nil { if err != io.EOF { m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn)) @@ -238,7 +273,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn)) return @@ -269,7 +304,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, streamLabel); err != nil { m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn)) return } @@ -297,7 +332,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn)) return @@ -322,10 +357,35 @@ func (m *Memberlist) packetListen() { } func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { + var ( + packetLabel string + err error + ) + buf, packetLabel, err = RemoveLabelHeaderFromPacket(buf) + if err != nil { + m.logger.Printf("[ERR] memberlist: %v %s", err, LogAddress(from)) + return + } + + if m.config.SkipInboundLabelCheck { + if packetLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double packet label header: %s", LogAddress(from)) + return + } + // Set this from config so that the auth data assertions work below. + packetLabel = m.config.Label + } + + if m.config.Label != packetLabel { + m.logger.Printf("[ERR] memberlist: discarding packet with unacceptable label %q: %s", packetLabel, LogAddress(from)) + return + } + // Check if encryption is enabled if m.config.EncryptionEnabled() { // Decrypt the payload - plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) + authData := []byte(packetLabel) + plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, authData) if err != nil { if !m.config.GossipVerifyIncoming { // Treat the message as plaintext @@ -723,7 +783,7 @@ func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interf // opportunistically create a compoundMsg and piggy back other broadcasts. func (m *Memberlist) sendMsg(a Address, msg []byte) error { // Check if we can piggy back any messages - bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { bytesAvail -= encryptOverhead(m.encryptionVersion()) } @@ -795,9 +855,12 @@ func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { // Check if we have encryption enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { // Encrypt the payload - var buf bytes.Buffer - primaryKey := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) + var ( + primaryKey = m.config.Keyring.GetPrimaryKey() + packetLabel = []byte(m.config.Label) + buf bytes.Buffer + ) + err := encryptPayload(m.encryptionVersion(), primaryKey, msg, packetLabel, &buf) if err != nil { m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) return err @@ -812,7 +875,7 @@ func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { // rawSendMsgStream is used to stream a message to another host without // modification, other than applying compression and encryption if enabled. -func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { +func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte, streamLabel string) error { // Check if compression is enabled if m.config.EnableCompression { compBuf, err := compressPayload(sendBuf) @@ -825,7 +888,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { // Check if encryption is enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { - crypt, err := m.encryptLocalState(sendBuf) + crypt, err := m.encryptLocalState(sendBuf, streamLabel) if err != nil { m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) return err @@ -871,7 +934,8 @@ func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { if _, err := bufConn.Write(sendBuf); err != nil { return err } - return m.rawSendMsgStream(conn, bufConn.Bytes()) + + return m.rawSendMsgStream(conn, bufConn.Bytes(), m.config.Label) } // sendAndReceiveState is used to initiate a push/pull over a stream with a @@ -891,12 +955,12 @@ func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) // Send our state - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, m.config.Label); err != nil { return nil, nil, err } conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + msgType, bufConn, dec, err := m.readStream(conn, m.config.Label) if err != nil { return nil, nil, err } @@ -921,7 +985,7 @@ func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, } // sendLocalState is invoked to send our local state over a stream connection. -func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { +func (m *Memberlist) sendLocalState(conn net.Conn, join bool, streamLabel string) error { // Setup a deadline conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) @@ -978,11 +1042,11 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { } // Get the send buffer - return m.rawSendMsgStream(conn, bufConn.Bytes()) + return m.rawSendMsgStream(conn, bufConn.Bytes(), streamLabel) } // encryptLocalState is used to help encrypt local state before sending -func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { +func (m *Memberlist) encryptLocalState(sendBuf []byte, streamLabel string) ([]byte, error) { var buf bytes.Buffer // Write the encryptMsg byte @@ -995,9 +1059,15 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { binary.BigEndian.PutUint32(sizeBuf, uint32(encLen)) buf.Write(sizeBuf) + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [stream_label; optional] + // + dataBytes := appendBytes(buf.Bytes()[:5], []byte(streamLabel)) + // Write the encrypted cipher text to the buffer key := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf) + err := encryptPayload(encVsn, key, sendBuf, dataBytes, &buf) if err != nil { return nil, err } @@ -1005,7 +1075,7 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { } // decryptRemoteState is used to help decrypt the remote state -func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { +func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ([]byte, error) { // Read in enough to determine message length cipherText := bytes.NewBuffer(nil) cipherText.WriteByte(byte(encryptMsg)) @@ -1027,8 +1097,13 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return nil, err } - // Decrypt the cipherText - dataBytes := cipherText.Bytes()[:5] + // Decrypt the cipherText with some authenticated data + // + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [label_data; optional] + // + dataBytes := appendBytes(cipherText.Bytes()[:5], []byte(streamLabel)) cipherBytes := cipherText.Bytes()[5:] // Decrypt the payload @@ -1036,15 +1111,18 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return decryptPayload(keys, cipherBytes, dataBytes) } -// readStream is used to read from a stream connection, decrypting and +// readStream is used to read messages from a stream connection, decrypting and // decompressing the stream if necessary. -func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) { +// +// The provided streamLabel if present will be authenticated during decryption +// of each message. +func (m *Memberlist) readStream(conn net.Conn, streamLabel string) (messageType, io.Reader, *codec.Decoder, error) { // Created a buffered reader var bufConn io.Reader = bufio.NewReader(conn) // Read the message type buf := [1]byte{0} - if _, err := bufConn.Read(buf[:]); err != nil { + if _, err := io.ReadFull(bufConn, buf[:]); err != nil { return 0, nil, nil, err } msgType := messageType(buf[0]) @@ -1056,7 +1134,7 @@ func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.D fmt.Errorf("Remote state is encrypted and encryption is not configured") } - plain, err := m.decryptRemoteState(bufConn) + plain, err := m.decryptRemoteState(bufConn, streamLabel) if err != nil { return 0, nil, nil, err } @@ -1236,11 +1314,11 @@ func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.T return false, err } - if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil { + if err = m.rawSendMsgStream(conn, out.Bytes(), m.config.Label); err != nil { return false, err } - msgType, _, dec, err := m.readStream(conn) + msgType, _, dec, err := m.readStream(conn, m.config.Label) if err != nil { return false, err } diff --git a/net_test.go b/net_test.go index 8f016226c..ce287547e 100644 --- a/net_test.go +++ b/net_test.go @@ -302,7 +302,7 @@ func TestTCPPing(t *testing.T) { } defer conn.Close() - msgType, _, dec, err := m.readStream(conn) + msgType, _, dec, err := m.readStream(conn, "") if err != nil { pingErrCh <- fmt.Errorf("failed to read ping: %s", err) return @@ -336,7 +336,7 @@ func TestTCPPing(t *testing.T) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { pingErrCh <- fmt.Errorf("failed to send ack: %s", err) return @@ -365,7 +365,7 @@ func TestTCPPing(t *testing.T) { } defer conn.Close() - _, _, dec, err := m.readStream(conn) + _, _, dec, err := m.readStream(conn, "") if err != nil { pingErrCh <- fmt.Errorf("failed to read ping: %s", err) return @@ -384,7 +384,7 @@ func TestTCPPing(t *testing.T) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { pingErrCh <- fmt.Errorf("failed to send ack: %s", err) return @@ -413,7 +413,7 @@ func TestTCPPing(t *testing.T) { } defer conn.Close() - _, _, _, err = m.readStream(conn) + _, _, _, err = m.readStream(conn, "") if err != nil { pingErrCh <- fmt.Errorf("failed to read ping: %s", err) return @@ -426,7 +426,7 @@ func TestTCPPing(t *testing.T) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { pingErrCh <- fmt.Errorf("failed to send bogus msg: %s", err) return @@ -697,7 +697,7 @@ func TestEncryptDecryptState(t *testing.T) { } defer m.Shutdown() - crypt, err := m.encryptLocalState(state) + crypt, err := m.encryptLocalState(state, "") if err != nil { t.Fatalf("err: %v", err) } @@ -706,7 +706,7 @@ func TestEncryptDecryptState(t *testing.T) { buf := bytes.NewReader(crypt) buf.Seek(1, 0) - plain, err := m.decryptRemoteState(buf) + plain, err := m.decryptRemoteState(buf, "") if err != nil { t.Fatalf("err: %v", err) } diff --git a/peeked_conn.go b/peeked_conn.go new file mode 100644 index 000000000..3181d90ce --- /dev/null +++ b/peeked_conn.go @@ -0,0 +1,48 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Originally from: https://github.com/google/tcpproxy/blob/master/tcpproxy.go +// at f5c09fbedceb69e4b238dec52cdf9f2fe9a815e2 + +package memberlist + +import "net" + +// peekedConn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type peekedConn struct { + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + net.Conn +} + +func (c *peekedConn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} diff --git a/security.go b/security.go index 4cb4da36f..6831be3bc 100644 --- a/security.go +++ b/security.go @@ -199,3 +199,22 @@ func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { return nil, fmt.Errorf("No installed keys could decrypt the message") } + +func appendBytes(first []byte, second []byte) []byte { + hasFirst := len(first) > 0 + hasSecond := len(second) > 0 + + switch { + case hasFirst && hasSecond: + out := make([]byte, 0, len(first)+len(second)) + out = append(out, first...) + out = append(out, second...) + return out + case hasFirst: + return first + case hasSecond: + return second + default: + return nil + } +} diff --git a/state.go b/state.go index 327a4dce6..a6351b4b0 100644 --- a/state.go +++ b/state.go @@ -592,7 +592,7 @@ func (m *Memberlist) gossip() { m.nodeLock.RUnlock() // Compute the bytes available - bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() { bytesAvail -= encryptOverhead(m.encryptionVersion()) } diff --git a/transport.go b/transport.go index b23b83914..f3d05364d 100644 --- a/transport.go +++ b/transport.go @@ -111,3 +111,50 @@ func (t *shimNodeAwareTransport) WriteToAddress(b []byte, addr Address) (time.Ti func (t *shimNodeAwareTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { return t.DialTimeout(addr.Addr, timeout) } + +type labelWrappedTransport struct { + label string + NodeAwareTransport +} + +var _ NodeAwareTransport = (*labelWrappedTransport)(nil) + +func (t *labelWrappedTransport) WriteToAddress(buf []byte, addr Address) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, fmt.Errorf("failed to add label header to packet: %w", err) + } + return t.NodeAwareTransport.WriteToAddress(buf, addr) +} + +func (t *labelWrappedTransport) WriteTo(buf []byte, addr string) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, err + } + return t.NodeAwareTransport.WriteTo(buf, addr) +} + +func (t *labelWrappedTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialAddressTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} + +func (t *labelWrappedTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} From c25dd115fe9e4169876b82afd0e0cc4b7ce27364 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Thu, 23 Dec 2021 09:24:23 -0600 Subject: [PATCH 05/12] Add warning message for remote node state limit --- net.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/net.go b/net.go index 1d015afb2..a3745f1da 100644 --- a/net.go +++ b/net.go @@ -1089,6 +1089,11 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ( moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5]) if moreBytes > maxPushStateBytes { return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) + + } + //Start reporting the size before you cross the limit + if moreBytes > .6 * maxPushStateBytes { + m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)",moreBytes,maxPushStateBytes ) } // Read in the rest of the payload From 2d2beb5e2cf12781dec1b45416ca356bedbc63a4 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Thu, 23 Dec 2021 09:29:19 -0600 Subject: [PATCH 06/12] Tweak format --- net.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/net.go b/net.go index a3745f1da..7d6d5f084 100644 --- a/net.go +++ b/net.go @@ -1091,9 +1091,10 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ( return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) } + //Start reporting the size before you cross the limit if moreBytes > .6 * maxPushStateBytes { - m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)",moreBytes,maxPushStateBytes ) + m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)", moreBytes, maxPushStateBytes ) } // Read in the rest of the payload From 17f26816f7a702eecb9ab324820ee37b0d976b94 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Thu, 23 Dec 2021 09:31:39 -0600 Subject: [PATCH 07/12] run go fmt --- net.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/net.go b/net.go index 7d6d5f084..bd6f08386 100644 --- a/net.go +++ b/net.go @@ -1091,10 +1091,10 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ( return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) } - + //Start reporting the size before you cross the limit - if moreBytes > .6 * maxPushStateBytes { - m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)", moreBytes, maxPushStateBytes ) + if moreBytes > .6*maxPushStateBytes { + m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)", moreBytes, maxPushStateBytes) } // Read in the rest of the payload From 678bbf45cdda6d5e48c218fc53969c9162a01657 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Thu, 23 Dec 2021 11:07:00 -0600 Subject: [PATCH 08/12] Need type conversion --- net.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/net.go b/net.go index bd6f08386..f2712a9b6 100644 --- a/net.go +++ b/net.go @@ -7,6 +7,7 @@ import ( "fmt" "hash/crc32" "io" + "math" "net" "sync/atomic" "time" @@ -1093,8 +1094,8 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ( } //Start reporting the size before you cross the limit - if moreBytes > .6*maxPushStateBytes { - m.logger.Printf("[WARN] memberlist: Remote node state size is %d approaching limit (%d)", moreBytes, maxPushStateBytes) + if moreBytes > uint32(math.Floor(.06*maxPushStateBytes)) { + m.logger.Printf("[INFO] memberlist: Remote node state size is %d limit is (%d)", moreBytes, maxPushStateBytes) } // Read in the rest of the payload From c0dd9bd4487526ec9abeff00770b575aa3c3a112 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Wed, 12 Jan 2022 07:59:44 -0600 Subject: [PATCH 09/12] Send out warning before crossing limit --- net.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net.go b/net.go index f2712a9b6..1555b172a 100644 --- a/net.go +++ b/net.go @@ -1094,8 +1094,8 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ( } //Start reporting the size before you cross the limit - if moreBytes > uint32(math.Floor(.06*maxPushStateBytes)) { - m.logger.Printf("[INFO] memberlist: Remote node state size is %d limit is (%d)", moreBytes, maxPushStateBytes) + if moreBytes > uint32(math.Floor(.6*maxPushStateBytes)) { + m.logger.Printf("[WARN] memberlist: Remote node state size is (%d) limit is (%d)", moreBytes, maxPushStateBytes) } // Read in the rest of the payload From 0bff30975d0b772ab5fe6bfe689c79cc6b1650f1 Mon Sep 17 00:00:00 2001 From: Bryan Waters Date: Wed, 12 Jan 2022 09:27:17 -0600 Subject: [PATCH 10/12] Purge left nodes along with dead ones (#254) Keep the state from infinitely expanding --- util.go | 4 ++-- util_test.go | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/util.go b/util.go index 16a7d36d0..cfae4c9e7 100644 --- a/util.go +++ b/util.go @@ -96,13 +96,13 @@ func pushPullScale(interval time.Duration, n int) time.Duration { return time.Duration(multiplier) * interval } -// moveDeadNodes moves nodes that are dead and beyond the gossip to the dead interval +// moveDeadNodes moves dead and left nodes that that have not changed during the gossipToTheDeadTime interval // to the end of the slice and returns the index of the first moved node. func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int { numDead := 0 n := len(nodes) for i := 0; i < n-numDead; i++ { - if nodes[i].State != StateDead { + if !nodes[i].DeadOrLeft() { continue } diff --git a/util_test.go b/util_test.go index f97eb1703..5e3edb633 100644 --- a/util_test.go +++ b/util_test.go @@ -178,6 +178,16 @@ func TestMoveDeadNodes(t *testing.T) { State: StateDead, StateChange: time.Now().Add(-10 * time.Second), }, + // This left node should not be moved, as its state changed + // less than the specified GossipToTheDead time ago + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-10 * time.Second), + }, + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-20 * time.Second), + }, &nodeState{ State: StateAlive, StateChange: time.Now().Add(-20 * time.Second), @@ -190,10 +200,14 @@ func TestMoveDeadNodes(t *testing.T) { State: StateAlive, StateChange: time.Now().Add(-20 * time.Second), }, + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-20 * time.Second), + }, } idx := moveDeadNodes(nodes, (15 * time.Second)) - if idx != 4 { + if idx != 5 { t.Fatalf("bad index") } for i := 0; i < idx; i++ { @@ -204,6 +218,11 @@ func TestMoveDeadNodes(t *testing.T) { if nodes[i].State != StateDead { t.Fatalf("Bad state %d", i) } + case 3: + //Recently left node should remain at 3 + if nodes[i].State != StateLeft { + t.Fatalf("Bad State %d", i) + } default: if nodes[i].State != StateAlive { t.Fatalf("Bad state %d", i) @@ -211,7 +230,7 @@ func TestMoveDeadNodes(t *testing.T) { } } for i := idx; i < len(nodes); i++ { - if nodes[i].State != StateDead { + if !nodes[i].DeadOrLeft() { t.Fatalf("Bad state %d", i) } } From 3c0cdff93db1763323d5fdc5dd5f9a5c4fd3c2f0 Mon Sep 17 00:00:00 2001 From: NODA Kai Date: Wed, 19 Jan 2022 22:39:35 +0800 Subject: [PATCH 11/12] state.go: log the timeout value on a ping timeout (#153) This will help users' troubleshooting Signed-off-by: NODA, Kai --- state.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/state.go b/state.go index a6351b4b0..6ba413515 100644 --- a/state.go +++ b/state.go @@ -449,7 +449,11 @@ HANDLE_REMOTE_FAILURE: defer close(fallbackCh) didContact, err := m.sendPingAndWaitForAck(node.FullAddress(), ping, deadline) if err != nil { - m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s", err) + var to string + if ne, ok := err.(net.Error); ok && ne.Timeout() { + to = fmt.Sprintf("timeout %s: ", probeInterval) + } + m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s%s", to, err) } else { fallbackCh <- didContact } From 696ff46201c1b64d31b60d57971d07d8cf15df7a Mon Sep 17 00:00:00 2001 From: Krastin Krastev Date: Fri, 25 Feb 2022 22:26:17 +0100 Subject: [PATCH 12/12] be explicit about protocol in state errors: TCP or UDP (#258) --- state.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/state.go b/state.go index f90fbe9bd..7a2339e9b 100644 --- a/state.go +++ b/state.go @@ -329,7 +329,7 @@ func (m *Memberlist) probeNode(node *nodeState) { }() if node.State == StateAlive { if err := m.encodeAndSendMsg(node.FullAddress(), pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to send UDP ping: %s", err) if failedRemote(err) { goto HANDLE_REMOTE_FAILURE } else { @@ -339,7 +339,7 @@ func (m *Memberlist) probeNode(node *nodeState) { } else { var msgs [][]byte if buf, err := encode(pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to encode ping message: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to encode UDP ping message: %s", err) return } else { msgs = append(msgs, buf.Bytes()) @@ -354,7 +354,7 @@ func (m *Memberlist) probeNode(node *nodeState) { compound := makeCompoundMessage(msgs) if err := m.rawSendMsgPacket(node.FullAddress(), &node.Node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send compound ping and suspect message to %s: %s", addr, err) + m.logger.Printf("[ERR] memberlist: Failed to send UDP compound ping and suspect message to %s: %s", addr, err) if failedRemote(err) { goto HANDLE_REMOTE_FAILURE } else { @@ -393,7 +393,7 @@ func (m *Memberlist) probeNode(node *nodeState) { // probe interval it will give the TCP fallback more time, which // is more active in dealing with lost packets, and it gives more // time to wait for indirect acks/nacks. - m.logger.Printf("[DEBUG] memberlist: Failed ping: %s (timeout reached)", node.Name) + m.logger.Printf("[DEBUG] memberlist: Failed UDP ping: %s (timeout reached)", node.Name) } HANDLE_REMOTE_FAILURE: @@ -426,7 +426,7 @@ HANDLE_REMOTE_FAILURE: } if err := m.encodeAndSendMsg(peer.FullAddress(), indirectPingMsg, &ind); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to send indirect UDP ping: %s", err) } } @@ -453,7 +453,7 @@ HANDLE_REMOTE_FAILURE: if ne, ok := err.(net.Error); ok && ne.Timeout() { to = fmt.Sprintf("timeout %s: ", probeInterval) } - m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s%s", to, err) + m.logger.Printf("[ERR] memberlist: Failed fallback TCP ping: %s%s", to, err) } else { fallbackCh <- didContact } @@ -478,7 +478,7 @@ HANDLE_REMOTE_FAILURE: // any additional time here. for didContact := range fallbackCh { if didContact { - m.logger.Printf("[WARN] memberlist: Was able to connect to %s but other probes failed, network may be misconfigured", node.Name) + m.logger.Printf("[WARN] memberlist: Was able to connect to %s over TCP but UDP probes failed, network may be misconfigured", node.Name) return } }