diff --git a/.circleci/config.yml b/.circleci/config.yml index 78ac11fb7..a6eebf2db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,7 +4,7 @@ version: 2.1 executors: go: docker: - - image: docker.mirror.hashicorp.services/circleci/golang:1.15 + - image: docker.mirror.hashicorp.services/circleci/golang:1.16 environment: - TEST_RESULTS: /tmp/test-results # path to where test results are saved diff --git a/config.go b/config.go index 31099e75f..d7fe4c37b 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,17 @@ type Config struct { // make a NetTransport using BindAddr and BindPort from this structure. Transport Transport + // Label is an optional set of bytes to include on the outside of each + // packet and stream. + // + // If gossip encryption is enabled and this is set it is treated as GCM + // authenticated data. + Label string + + // SkipInboundLabelCheck skips the check that inbound packets and gossip + // streams need to be label prefixed. + SkipInboundLabelCheck bool + // Configuration related to what address to bind to and ports to // listen on. The port is used for both UDP and TCP gossip. It is // assumed other nodes are running on this port, but they do not need diff --git a/label.go b/label.go new file mode 100644 index 000000000..bbe0163ab --- /dev/null +++ b/label.go @@ -0,0 +1,178 @@ +package memberlist + +import ( + "bufio" + "fmt" + "io" + "net" +) + +// General approach is to prefix all packets and streams with the same structure: +// +// magic type byte (244): uint8 +// length of label name: uint8 (because labels can't be longer than 255 bytes) +// label name: []uint8 + +// LabelMaxSize is the maximum length of a packet or stream label. +const LabelMaxSize = 255 + +// AddLabelHeaderToPacket prefixes outgoing packets with the correct header if +// the label is not empty. +func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) { + if label == "" { + return buf, nil + } + if len(label) > LabelMaxSize { + return nil, fmt.Errorf("label %q is too long", label) + } + + return makeLabelHeader(label, buf), nil +} + +// RemoveLabelHeaderFromPacket removes any label header from the provided +// packet and returns it along with the remaining packet contents. +func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) { + if len(buf) == 0 { + return buf, "", nil // can't possibly be labeled + } + + // [type:byte] [size:byte] [size bytes] + + msgType := messageType(buf[0]) + if msgType != hasLabelMsg { + return buf, "", nil + } + + if len(buf) < 2 { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + size := int(buf[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + + if len(buf) < 2+size { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + label = string(buf[2 : 2+size]) + newBuf = buf[2+size:] + + return newBuf, label, nil +} + +// AddLabelHeaderToStream prefixes outgoing streams with the correct header if +// the label is not empty. +func AddLabelHeaderToStream(conn net.Conn, label string) error { + if label == "" { + return nil + } + if len(label) > LabelMaxSize { + return fmt.Errorf("label %q is too long", label) + } + + header := makeLabelHeader(label, nil) + + _, err := conn.Write(header) + return err +} + +// RemoveLabelHeaderFromStream removes any label header from the beginning of +// the stream if present and returns it along with an updated conn with that +// header removed. +// +// Note that on error it is the caller's responsibility to close the +// connection. +func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) { + br := bufio.NewReader(conn) + + // First check for the type byte. + peeked, err := br.Peek(1) + if err != nil { + if err == io.EOF { + // It is safe to return the original net.Conn at this point because + // it never contained any data in the first place so we don't have + // to splice the buffer into the conn because both are empty. + return conn, "", nil + } + return nil, "", err + } + + msgType := messageType(peeked[0]) + if msgType != hasLabelMsg { + conn, err = newPeekedConnFromBufferedReader(conn, br, 0) + return conn, "", err + } + + // We are guaranteed to get a size byte as well. + peeked, err = br.Peek(2) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + size := int(peeked[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + // NOTE: we don't have to check this against LabelMaxSize because a byte + // already has a max value of 255. + + // Once we know the size we can peek the label as well. Note that since we + // are using the default bufio.Reader size of 4096, the entire label header + // fits in the initial buffer fill so this should be free. + peeked, err = br.Peek(2 + size) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + label := string(peeked[2 : 2+size]) + + conn, err = newPeekedConnFromBufferedReader(conn, br, 2+size) + if err != nil { + return nil, "", err + } + + return conn, label, nil +} + +// newPeekedConnFromBufferedReader will splice the buffer contents after the +// offset into the provided net.Conn and return the result so that the rest of +// the buffer contents are returned first when reading from the returned +// peekedConn before moving on to the unbuffered conn contents. +func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) { + // Extract any of the readahead buffer. + peeked, err := br.Peek(br.Buffered()) + if err != nil { + return nil, err + } + + return &peekedConn{ + Peeked: peeked[offset:], + Conn: conn, + }, nil +} + +func makeLabelHeader(label string, rest []byte) []byte { + newBuf := make([]byte, 2, 2+len(label)+len(rest)) + newBuf[0] = byte(hasLabelMsg) + newBuf[1] = byte(len(label)) + newBuf = append(newBuf, []byte(label)...) + if len(rest) > 0 { + newBuf = append(newBuf, []byte(rest)...) + } + return newBuf +} + +func labelOverhead(label string) int { + if label == "" { + return 0 + } + return 2 + len(label) +} diff --git a/label_test.go b/label_test.go new file mode 100644 index 000000000..a78ea4a9c --- /dev/null +++ b/label_test.go @@ -0,0 +1,364 @@ +package memberlist + +import ( + "bytes" + "io" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAddLabelHeaderToPacket(t *testing.T) { + type testcase struct { + buf []byte + label string + expectPacket []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + got, err := AddLabelHeaderToPacket(tc.buf, tc.label) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectPacket, got) + } + } + + longLabel := strings.Repeat("a", 255) + + cases := map[string]testcase{ + "nil buf with no label": testcase{ + buf: nil, + label: "", + expectPacket: nil, + }, + "nil buf with label": testcase{ + buf: nil, + label: "foo", + expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foo")...), + }, + "message with label": testcase{ + buf: []byte("something"), + label: "foo", + expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foosomething")...), + }, + "message with no label": testcase{ + buf: []byte("something"), + label: "", + expectPacket: []byte("something"), + }, + "message with almost too long label": testcase{ + buf: []byte("something"), + label: longLabel, + expectPacket: append([]byte{byte(hasLabelMsg), 255}, []byte(longLabel+"something")...), + }, + "label too long by one byte": testcase{ + buf: []byte("something"), + label: longLabel + "x", + expectErr: `label "` + longLabel + `x" is too long`, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestRemoveLabelHeaderFromPacket(t *testing.T) { + type testcase struct { + buf []byte + expectLabel string + expectPacket []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + gotBuf, gotLabel, err := RemoveLabelHeaderFromPacket(tc.buf) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectPacket, gotBuf) + require.Equal(t, tc.expectLabel, gotLabel) + } + } + + cases := map[string]testcase{ + "empty buf": testcase{ + buf: []byte{}, + expectLabel: "", + expectPacket: []byte{}, + }, + "ping with no label": testcase{ + buf: buildBuffer(t, pingMsg, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, pingMsg, "blah"), + }, + "error with no label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, errMsg, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with no label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, maxEncryptionVersion, "blah"), + expectLabel: "", + expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + "buf too small for label": testcase{ + buf: buildBuffer(t, hasLabelMsg, "x"), + expectErr: `cannot decode label; packet has been truncated`, + }, + "buf too small for label size": testcase{ + buf: buildBuffer(t, hasLabelMsg), + expectErr: `cannot decode label; packet has been truncated`, + }, + "label empty": testcase{ + buf: buildBuffer(t, hasLabelMsg, 0, "x"), + expectErr: `label header cannot be empty when present`, + }, + "label truncated": testcase{ + buf: buildBuffer(t, hasLabelMsg, 2, "x"), + expectErr: `cannot decode label; packet has been truncated`, + }, + "ping with label": testcase{ + buf: buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, pingMsg, "blah"), + }, + "error with label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"), + expectLabel: "abc", + expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestAddLabelHeaderToStream(t *testing.T) { + type testcase struct { + label string + expectData []byte + expectErr string + } + + suffixData := "EXTRA DATA" + + run := func(t *testing.T, tc testcase) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + var ( + dataCh = make(chan []byte, 1) + errCh = make(chan error, 1) + ) + go func() { + var buf bytes.Buffer + _, err := io.Copy(&buf, server) + if err != nil { + errCh <- err + } + dataCh <- buf.Bytes() + }() + + err := AddLabelHeaderToStream(client, tc.label) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + return + } + require.NoError(t, err) + + client.Write([]byte(suffixData)) + client.Close() + + expect := make([]byte, 0, len(suffixData)+len(tc.expectData)) + expect = append(expect, tc.expectData...) + expect = append(expect, suffixData...) + + select { + case err := <-errCh: + require.NoError(t, err) + case got := <-dataCh: + require.Equal(t, expect, got) + } + } + + longLabel := strings.Repeat("a", 255) + + cases := map[string]testcase{ + "no label": testcase{ + label: "", + expectData: nil, + }, + "with label": testcase{ + label: "foo", + expectData: buildBuffer(t, hasLabelMsg, 3, "foo"), + }, + "almost too long label": testcase{ + label: longLabel, + expectData: buildBuffer(t, hasLabelMsg, 255, longLabel), + }, + "label too long by one byte": testcase{ + label: longLabel + "x", + expectErr: `label "` + longLabel + `x" is too long`, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func TestRemoveLabelHeaderFromStream(t *testing.T) { + type testcase struct { + buf []byte + expectLabel string + expectData []byte + expectErr string + } + + run := func(t *testing.T, tc testcase) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + var ( + errCh = make(chan error, 1) + ) + go func() { + _, err := server.Write(tc.buf) + if err != nil { + errCh <- err + } + server.Close() + }() + + newConn, gotLabel, err := RemoveLabelHeaderFromStream(client) + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + return + } + require.NoError(t, err) + + gotBuf, err := io.ReadAll(newConn) + require.NoError(t, err) + + require.Equal(t, tc.expectData, gotBuf) + require.Equal(t, tc.expectLabel, gotLabel) + } + + cases := map[string]testcase{ + "empty buf": testcase{ + buf: []byte{}, + expectLabel: "", + expectData: []byte{}, + }, + "ping with no label": testcase{ + buf: buildBuffer(t, pingMsg, "blah"), + expectLabel: "", + expectData: buildBuffer(t, pingMsg, "blah"), + }, + "error with no label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, errMsg, "blah"), + expectLabel: "", + expectData: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with no label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, maxEncryptionVersion, "blah"), + expectLabel: "", + expectData: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + "buf too small for label": testcase{ + buf: buildBuffer(t, hasLabelMsg, "x"), + expectErr: `cannot decode label; stream has been truncated`, + }, + "buf too small for label size": testcase{ + buf: buildBuffer(t, hasLabelMsg), + expectErr: `cannot decode label; stream has been truncated`, + }, + "label empty": testcase{ + buf: buildBuffer(t, hasLabelMsg, 0, "x"), + expectErr: `label header cannot be empty when present`, + }, + "label truncated": testcase{ + buf: buildBuffer(t, hasLabelMsg, 2, "x"), + expectErr: `cannot decode label; stream has been truncated`, + }, + "ping with label": testcase{ + buf: buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, pingMsg, "blah"), + }, + "error with label": testcase{ // 2021-10: largest standard message type + buf: buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, errMsg, "blah"), + }, + "v1 encrypt with label": testcase{ // 2021-10: highest encryption version + buf: buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"), + expectLabel: "abc", + expectData: buildBuffer(t, maxEncryptionVersion, "blah"), + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func buildBuffer(t *testing.T, stuff ...interface{}) []byte { + var buf bytes.Buffer + for _, item := range stuff { + switch x := item.(type) { + case int: + x2 := uint(x) + if x2 > 255 { + t.Fatalf("int is too big") + } + buf.WriteByte(byte(x2)) + case byte: + buf.WriteByte(byte(x)) + case messageType: + buf.WriteByte(byte(x)) + case encryptionVersion: + buf.WriteByte(byte(x)) + case string: + buf.Write([]byte(x)) + case []byte: + buf.Write(x) + default: + t.Fatalf("unexpected type %T", item) + } + } + return buf.Bytes() +} + +func TestLabelOverhead(t *testing.T) { + require.Equal(t, 0, labelOverhead("")) + require.Equal(t, 3, labelOverhead("a")) + require.Equal(t, 9, labelOverhead("abcdefg")) +} diff --git a/memberlist.go b/memberlist.go index 7ee040091..cab6db69f 100644 --- a/memberlist.go +++ b/memberlist.go @@ -187,6 +187,17 @@ func newMemberlist(conf *Config) (*Memberlist, error) { nodeAwareTransport = &shimNodeAwareTransport{transport} } + if len(conf.Label) > LabelMaxSize { + return nil, fmt.Errorf("could not use %q as a label: too long", conf.Label) + } + + if conf.Label != "" { + nodeAwareTransport = &labelWrappedTransport{ + label: conf.Label, + NodeAwareTransport: nodeAwareTransport, + } + } + m := &Memberlist{ config: conf, shutdownCh: make(chan struct{}), @@ -262,7 +273,7 @@ func (m *Memberlist) Join(existing []string) (int, error) { hp := joinHostPort(addr.ip.String(), addr.port) a := Address{Addr: hp, Name: addr.nodeName} if err := m.pushPullNode(a, true); err != nil { - err = fmt.Errorf("Failed to join %s: %v", addr.ip, err) + err = fmt.Errorf("Failed to join %s: %v", a.Addr, err) errs = multierror.Append(errs, err) m.logger.Printf("[DEBUG] memberlist: %v", err) continue diff --git a/memberlist_test.go b/memberlist_test.go index 6a1eb70a0..686389248 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -16,6 +16,7 @@ import ( iretry "github.com/hashicorp/memberlist/internal/retry" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -541,7 +542,7 @@ func TestMemberList_ResolveAddr_TCP_First(t *testing.T) { go func() { if err := server.ListenAndServe(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - t.Fatalf("err: %v", err) + t.Errorf("err: %v", err) } }() wg.Wait() @@ -643,6 +644,90 @@ func TestMemberlist_Join(t *testing.T) { } } +func TestMemberlist_Join_with_Labels(t *testing.T) { + testMemberlist_Join_with_Labels(t, nil) +} +func TestMemberlist_Join_with_Labels_and_Encryption(t *testing.T) { + secretKey := TestKeys[0] + testMemberlist_Join_with_Labels(t, secretKey) +} +func testMemberlist_Join_with_Labels(t *testing.T, secretKey []byte) { + c1 := testConfig(t) + c1.Label = "blah" + c1.SecretKey = secretKey + m1, err := Create(c1) + require.NoError(t, err) + defer m1.Shutdown() + + bindPort := m1.config.BindPort + + // Create a second node + c2 := testConfig(t) + c2.Label = "blah" + c2.BindPort = bindPort + c2.SecretKey = secretKey + m2, err := Create(c2) + require.NoError(t, err) + defer m2.Shutdown() + + checkHost := func(t *testing.T, m *Memberlist, expected int) { + assert.Equal(t, expected, len(m.Members())) + assert.Equal(t, expected, m.estNumNodes()) + } + + runStep(t, "same label can join", func(t *testing.T) { + num, err := m2.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.NoError(t, err) + require.Equal(t, 1, num) + + // Check the hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) + + // Create a third node that uses no label + c3 := testConfig(t) + c3.Label = "" + c3.BindPort = bindPort + c3.SecretKey = secretKey + m3, err := Create(c3) + require.NoError(t, err) + defer m3.Shutdown() + + runStep(t, "no label cannot join", func(t *testing.T) { + _, err := m3.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.Error(t, err) + + // Check the failed host + checkHost(t, m3, 1) + // Check the existing hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) + + // Create a fourth node that uses a mismatched label + c4 := testConfig(t) + c4.Label = "not-blah" + c4.BindPort = bindPort + c4.SecretKey = secretKey + m4, err := Create(c4) + require.NoError(t, err) + defer m4.Shutdown() + + runStep(t, "mismatched label cannot join", func(t *testing.T) { + _, err := m4.Join([]string{m1.config.Name + "/" + m1.config.BindAddr}) + require.Error(t, err) + + // Check the failed host + checkHost(t, m4, 1) + // Check the previous failed host + checkHost(t, m3, 1) + // Check the existing hosts + checkHost(t, m2, 2) + checkHost(t, m1, 2) + }) +} + func TestMemberlist_JoinDifferentNetworksUniqueMask(t *testing.T) { c1 := testConfigNet(t, 0) c1.CIDRsAllowed, _ = ParseCIDRs([]string{"127.0.0.0/8"}) @@ -2071,3 +2156,10 @@ func TestMemberlist_EncryptedGossipTransition(t *testing.T) { // t.Fatalf("bad role for %s: %s", c2.Name, r) // } //} + +func runStep(t *testing.T, name string, fn func(t *testing.T)) { + t.Helper() + if !t.Run(name, fn) { + t.FailNow() + } +} diff --git a/net.go b/net.go index fe4acbc2e..66c1dcd94 100644 --- a/net.go +++ b/net.go @@ -7,6 +7,7 @@ import ( "fmt" "hash/crc32" "io" + "math" "net" "sync/atomic" "time" @@ -42,6 +43,9 @@ const ( type messageType uint8 // The list of available message types. +// +// WARNING: ONLY APPEND TO THIS LIST! The numeric values are part of the +// protocol itself. const ( pingMsg messageType = iota indirectPingMsg @@ -59,6 +63,13 @@ const ( errMsg ) +const ( + // hasLabelMsg has a deliberately high value so that you can disambiguate + // it from the encryptionVersion header which is either 0/1 right now and + // also any of the existing messageTypes + hasLabelMsg messageType = 244 +) + // compressionType is used to specify the compression algorithm type compressionType uint8 @@ -226,7 +237,32 @@ func (m *Memberlist) handleConn(conn net.Conn) { metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + + var ( + streamLabel string + err error + ) + conn, streamLabel, err = RemoveLabelHeaderFromStream(conn) + if err != nil { + m.logger.Printf("[ERR] memberlist: failed to receive and remove the stream label header: %s %s", err, LogConn(conn)) + return + } + + if m.config.SkipInboundLabelCheck { + if streamLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double stream label header: %s", LogConn(conn)) + return + } + // Set this from config so that the auth data assertions work below. + streamLabel = m.config.Label + } + + if m.config.Label != streamLabel { + m.logger.Printf("[ERR] memberlist: discarding stream with unacceptable label %q: %s", streamLabel, LogConn(conn)) + return + } + + msgType, bufConn, dec, err := m.readStream(conn, streamLabel) if err != nil { if err != io.EOF { m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn)) @@ -238,7 +274,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn)) return @@ -269,7 +305,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, streamLabel); err != nil { m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn)) return } @@ -297,7 +333,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn)) return @@ -322,10 +358,35 @@ func (m *Memberlist) packetListen() { } func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { + var ( + packetLabel string + err error + ) + buf, packetLabel, err = RemoveLabelHeaderFromPacket(buf) + if err != nil { + m.logger.Printf("[ERR] memberlist: %v %s", err, LogAddress(from)) + return + } + + if m.config.SkipInboundLabelCheck { + if packetLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double packet label header: %s", LogAddress(from)) + return + } + // Set this from config so that the auth data assertions work below. + packetLabel = m.config.Label + } + + if m.config.Label != packetLabel { + m.logger.Printf("[ERR] memberlist: discarding packet with unacceptable label %q: %s", packetLabel, LogAddress(from)) + return + } + // Check if encryption is enabled if m.config.EncryptionEnabled() { // Decrypt the payload - plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) + authData := []byte(packetLabel) + plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, authData) if err != nil { if !m.config.GossipVerifyIncoming { // Treat the message as plaintext @@ -723,7 +784,7 @@ func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interf // opportunistically create a compoundMsg and piggy back other broadcasts. func (m *Memberlist) sendMsg(a Address, msg []byte) error { // Check if we can piggy back any messages - bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { bytesAvail -= encryptOverhead(m.encryptionVersion()) } @@ -801,9 +862,12 @@ func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { // Check if we have encryption enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { // Encrypt the payload - var buf bytes.Buffer - primaryKey := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) + var ( + primaryKey = m.config.Keyring.GetPrimaryKey() + packetLabel = []byte(m.config.Label) + buf bytes.Buffer + ) + err := encryptPayload(m.encryptionVersion(), primaryKey, msg, packetLabel, &buf) if err != nil { m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) return err @@ -818,7 +882,7 @@ func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { // rawSendMsgStream is used to stream a message to another host without // modification, other than applying compression and encryption if enabled. -func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { +func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte, streamLabel string) error { // Check if compression is enabled if m.config.EnableCompression { compBuf, err := compressPayload(sendBuf) @@ -831,7 +895,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { // Check if encryption is enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { - crypt, err := m.encryptLocalState(sendBuf) + crypt, err := m.encryptLocalState(sendBuf, streamLabel) if err != nil { m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) return err @@ -877,7 +941,8 @@ func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { if _, err := bufConn.Write(sendBuf); err != nil { return err } - return m.rawSendMsgStream(conn, bufConn.Bytes()) + + return m.rawSendMsgStream(conn, bufConn.Bytes(), m.config.Label) } // sendAndReceiveState is used to initiate a push/pull over a stream with a @@ -897,12 +962,12 @@ func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) // Send our state - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, m.config.Label); err != nil { return nil, nil, err } conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + msgType, bufConn, dec, err := m.readStream(conn, m.config.Label) if err != nil { return nil, nil, err } @@ -927,7 +992,7 @@ func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, } // sendLocalState is invoked to send our local state over a stream connection. -func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { +func (m *Memberlist) sendLocalState(conn net.Conn, join bool, streamLabel string) error { // Setup a deadline conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) @@ -984,11 +1049,11 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { } // Get the send buffer - return m.rawSendMsgStream(conn, bufConn.Bytes()) + return m.rawSendMsgStream(conn, bufConn.Bytes(), streamLabel) } // encryptLocalState is used to help encrypt local state before sending -func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { +func (m *Memberlist) encryptLocalState(sendBuf []byte, streamLabel string) ([]byte, error) { var buf bytes.Buffer // Write the encryptMsg byte @@ -1001,9 +1066,15 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { binary.BigEndian.PutUint32(sizeBuf, uint32(encLen)) buf.Write(sizeBuf) + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [stream_label; optional] + // + dataBytes := appendBytes(buf.Bytes()[:5], []byte(streamLabel)) + // Write the encrypted cipher text to the buffer key := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf) + err := encryptPayload(encVsn, key, sendBuf, dataBytes, &buf) if err != nil { return nil, err } @@ -1011,7 +1082,7 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { } // decryptRemoteState is used to help decrypt the remote state -func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { +func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ([]byte, error) { // Read in enough to determine message length cipherText := bytes.NewBuffer(nil) cipherText.WriteByte(byte(encryptMsg)) @@ -1025,6 +1096,12 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5]) if moreBytes > maxPushStateBytes { return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) + + } + + //Start reporting the size before you cross the limit + if moreBytes > uint32(math.Floor(.6*maxPushStateBytes)) { + m.logger.Printf("[WARN] memberlist: Remote node state size is (%d) limit is (%d)", moreBytes, maxPushStateBytes) } // Read in the rest of the payload @@ -1033,8 +1110,13 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return nil, err } - // Decrypt the cipherText - dataBytes := cipherText.Bytes()[:5] + // Decrypt the cipherText with some authenticated data + // + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [label_data; optional] + // + dataBytes := appendBytes(cipherText.Bytes()[:5], []byte(streamLabel)) cipherBytes := cipherText.Bytes()[5:] // Decrypt the payload @@ -1042,15 +1124,18 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return decryptPayload(keys, cipherBytes, dataBytes) } -// readStream is used to read from a stream connection, decrypting and +// readStream is used to read messages from a stream connection, decrypting and // decompressing the stream if necessary. -func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) { +// +// The provided streamLabel if present will be authenticated during decryption +// of each message. +func (m *Memberlist) readStream(conn net.Conn, streamLabel string) (messageType, io.Reader, *codec.Decoder, error) { // Created a buffered reader var bufConn io.Reader = bufio.NewReader(conn) // Read the message type buf := [1]byte{0} - if _, err := bufConn.Read(buf[:]); err != nil { + if _, err := io.ReadFull(bufConn, buf[:]); err != nil { return 0, nil, nil, err } msgType := messageType(buf[0]) @@ -1062,7 +1147,7 @@ func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.D fmt.Errorf("Remote state is encrypted and encryption is not configured") } - plain, err := m.decryptRemoteState(bufConn) + plain, err := m.decryptRemoteState(bufConn, streamLabel) if err != nil { return 0, nil, nil, err } @@ -1242,11 +1327,11 @@ func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.T return false, err } - if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil { + if err = m.rawSendMsgStream(conn, out.Bytes(), m.config.Label); err != nil { return false, err } - msgType, _, dec, err := m.readStream(conn) + msgType, _, dec, err := m.readStream(conn, m.config.Label) if err != nil { return false, err } diff --git a/net_test.go b/net_test.go index 9f1a9aae7..ce287547e 100644 --- a/net_test.go +++ b/net_test.go @@ -292,46 +292,56 @@ func TestTCPPing(t *testing.T) { // Do a normal round trip. pingOut := ping{SeqNo: 23, Node: "mongo"} + pingErrCh := make(chan error, 1) go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() - msgType, _, dec, err := m.readStream(conn) + msgType, _, dec, err := m.readStream(conn, "") if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } if msgType != pingMsg { - t.Fatalf("expecting ping, got message type (%d)", msgType) + pingErrCh <- fmt.Errorf("expecting ping, got message type (%d)", msgType) + return } var pingIn ping if err := dec.Decode(&pingIn); err != nil { - t.Fatalf("failed to decode ping: %s", err) + pingErrCh <- fmt.Errorf("failed to decode ping: %s", err) + return } if pingIn.SeqNo != pingOut.SeqNo { - t.Fatalf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo) + pingErrCh <- fmt.Errorf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo) + return } if pingIn.Node != pingOut.Node { - t.Fatalf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) + pingErrCh <- fmt.Errorf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node) + return } ack := ackResp{pingIn.SeqNo, nil} out, err := encode(ackRespMsg, &ack) if err != nil { - t.Fatalf("failed to encode ack: %s", err) + pingErrCh <- fmt.Errorf("failed to encode ack: %s", err) + return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { - t.Fatalf("failed to send ack: %s", err) + pingErrCh <- fmt.Errorf("failed to send ack: %s", err) + return } + pingErrCh <- nil }() deadline := time.Now().Add(pingTimeout) didContact, err := m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -341,36 +351,45 @@ func TestTCPPing(t *testing.T) { if !didContact { t.Fatalf("expected successful ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure a mis-matched sequence number is caught. go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() - _, _, dec, err := m.readStream(conn) + _, _, dec, err := m.readStream(conn, "") if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } var pingIn ping if err := dec.Decode(&pingIn); err != nil { - t.Fatalf("failed to decode ping: %s", err) + pingErrCh <- fmt.Errorf("failed to decode ping: %s", err) + return } ack := ackResp{pingIn.SeqNo + 1, nil} out, err := encode(ackRespMsg, &ack) if err != nil { - t.Fatalf("failed to encode ack: %s", err) + pingErrCh <- fmt.Errorf("failed to encode ack: %s", err) + return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { - t.Fatalf("failed to send ack: %s", err) + pingErrCh <- fmt.Errorf("failed to send ack: %s", err) + return } + pingErrCh <- nil }() deadline = time.Now().Add(pingTimeout) didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -380,31 +399,39 @@ func TestTCPPing(t *testing.T) { if didContact { t.Fatalf("expected failed ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure an unexpected message type is handled gracefully. go func() { tcp.SetDeadline(time.Now().Add(pingTimeMax)) conn, err := tcp.AcceptTCP() if err != nil { - t.Fatalf("failed to connect: %s", err) + pingErrCh <- fmt.Errorf("failed to connect: %s", err) + return } defer conn.Close() - _, _, _, err = m.readStream(conn) + _, _, _, err = m.readStream(conn, "") if err != nil { - t.Fatalf("failed to read ping: %s", err) + pingErrCh <- fmt.Errorf("failed to read ping: %s", err) + return } bogus := indirectPingReq{} out, err := encode(indirectPingMsg, &bogus) if err != nil { - t.Fatalf("failed to encode bogus msg: %s", err) + pingErrCh <- fmt.Errorf("failed to encode bogus msg: %s", err) + return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), "") if err != nil { - t.Fatalf("failed to send bogus msg: %s", err) + pingErrCh <- fmt.Errorf("failed to send bogus msg: %s", err) + return } + pingErrCh <- nil }() deadline = time.Now().Add(pingTimeout) didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline) @@ -414,6 +441,9 @@ func TestTCPPing(t *testing.T) { if didContact { t.Fatalf("expected failed ping") } + if err = <-pingErrCh; err != nil { + t.Fatal(err) + } // Make sure failed I/O respects the deadline. In this case we try the // common case of the receiving node being totally down. @@ -667,7 +697,7 @@ func TestEncryptDecryptState(t *testing.T) { } defer m.Shutdown() - crypt, err := m.encryptLocalState(state) + crypt, err := m.encryptLocalState(state, "") if err != nil { t.Fatalf("err: %v", err) } @@ -676,7 +706,7 @@ func TestEncryptDecryptState(t *testing.T) { buf := bytes.NewReader(crypt) buf.Seek(1, 0) - plain, err := m.decryptRemoteState(buf) + plain, err := m.decryptRemoteState(buf, "") if err != nil { t.Fatalf("err: %v", err) } diff --git a/peeked_conn.go b/peeked_conn.go new file mode 100644 index 000000000..3181d90ce --- /dev/null +++ b/peeked_conn.go @@ -0,0 +1,48 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Originally from: https://github.com/google/tcpproxy/blob/master/tcpproxy.go +// at f5c09fbedceb69e4b238dec52cdf9f2fe9a815e2 + +package memberlist + +import "net" + +// peekedConn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type peekedConn struct { + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + net.Conn +} + +func (c *peekedConn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} diff --git a/security.go b/security.go index 4cb4da36f..6831be3bc 100644 --- a/security.go +++ b/security.go @@ -199,3 +199,22 @@ func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { return nil, fmt.Errorf("No installed keys could decrypt the message") } + +func appendBytes(first []byte, second []byte) []byte { + hasFirst := len(first) > 0 + hasSecond := len(second) > 0 + + switch { + case hasFirst && hasSecond: + out := make([]byte, 0, len(first)+len(second)) + out = append(out, first...) + out = append(out, second...) + return out + case hasFirst: + return first + case hasSecond: + return second + default: + return nil + } +} diff --git a/state.go b/state.go index 95e6ddc48..7a2339e9b 100644 --- a/state.go +++ b/state.go @@ -274,6 +274,11 @@ func failedRemote(err error) bool { case "dial", "read", "write": return true } + } else if strings.HasPrefix(t.Net, "udp") { + switch t.Op { + case "write": + return true + } } } return false @@ -324,7 +329,7 @@ func (m *Memberlist) probeNode(node *nodeState) { }() if node.State == StateAlive { if err := m.encodeAndSendMsg(node.FullAddress(), pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to send UDP ping: %s", err) if failedRemote(err) { goto HANDLE_REMOTE_FAILURE } else { @@ -334,7 +339,7 @@ func (m *Memberlist) probeNode(node *nodeState) { } else { var msgs [][]byte if buf, err := encode(pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to encode ping message: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to encode UDP ping message: %s", err) return } else { msgs = append(msgs, buf.Bytes()) @@ -349,7 +354,7 @@ func (m *Memberlist) probeNode(node *nodeState) { compound := makeCompoundMessage(msgs) if err := m.rawSendMsgPacket(node.FullAddress(), &node.Node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send compound ping and suspect message to %s: %s", addr, err) + m.logger.Printf("[ERR] memberlist: Failed to send UDP compound ping and suspect message to %s: %s", addr, err) if failedRemote(err) { goto HANDLE_REMOTE_FAILURE } else { @@ -388,7 +393,7 @@ func (m *Memberlist) probeNode(node *nodeState) { // probe interval it will give the TCP fallback more time, which // is more active in dealing with lost packets, and it gives more // time to wait for indirect acks/nacks. - m.logger.Printf("[DEBUG] memberlist: Failed ping: %s (timeout reached)", node.Name) + m.logger.Printf("[DEBUG] memberlist: Failed UDP ping: %s (timeout reached)", node.Name) } HANDLE_REMOTE_FAILURE: @@ -421,7 +426,7 @@ HANDLE_REMOTE_FAILURE: } if err := m.encodeAndSendMsg(peer.FullAddress(), indirectPingMsg, &ind); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to send indirect UDP ping: %s", err) } } @@ -444,7 +449,11 @@ HANDLE_REMOTE_FAILURE: defer close(fallbackCh) didContact, err := m.sendPingAndWaitForAck(node.FullAddress(), ping, deadline) if err != nil { - m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s", err) + var to string + if ne, ok := err.(net.Error); ok && ne.Timeout() { + to = fmt.Sprintf("timeout %s: ", probeInterval) + } + m.logger.Printf("[ERR] memberlist: Failed fallback TCP ping: %s%s", to, err) } else { fallbackCh <- didContact } @@ -469,7 +478,7 @@ HANDLE_REMOTE_FAILURE: // any additional time here. for didContact := range fallbackCh { if didContact { - m.logger.Printf("[WARN] memberlist: Was able to connect to %s but other probes failed, network may be misconfigured", node.Name) + m.logger.Printf("[WARN] memberlist: Was able to connect to %s over TCP but UDP probes failed, network may be misconfigured", node.Name) return } } @@ -587,7 +596,7 @@ func (m *Memberlist) gossip() { m.nodeLock.RUnlock() // Compute the bytes available - bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() { bytesAvail -= encryptOverhead(m.encryptionVersion()) } diff --git a/state_test.go b/state_test.go index 059918af8..204f5020c 100644 --- a/state_test.go +++ b/state_test.go @@ -2218,6 +2218,8 @@ func TestMemberlist_FailedRemote(t *testing.T) { {"net.OpError for tcp with dial", &net.OpError{Net: "tcp", Op: "dial"}, true}, {"net.OpError for tcp with write", &net.OpError{Net: "tcp", Op: "write"}, true}, {"net.OpError for tcp with read", &net.OpError{Net: "tcp", Op: "read"}, true}, + {"net.OpError for udp with write", &net.OpError{Net: "udp", Op: "write"}, true}, + {"net.OpError for udp with read", &net.OpError{Net: "udp", Op: "read"}, false}, } for _, test := range tests { diff --git a/transport.go b/transport.go index b23b83914..f3d05364d 100644 --- a/transport.go +++ b/transport.go @@ -111,3 +111,50 @@ func (t *shimNodeAwareTransport) WriteToAddress(b []byte, addr Address) (time.Ti func (t *shimNodeAwareTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { return t.DialTimeout(addr.Addr, timeout) } + +type labelWrappedTransport struct { + label string + NodeAwareTransport +} + +var _ NodeAwareTransport = (*labelWrappedTransport)(nil) + +func (t *labelWrappedTransport) WriteToAddress(buf []byte, addr Address) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, fmt.Errorf("failed to add label header to packet: %w", err) + } + return t.NodeAwareTransport.WriteToAddress(buf, addr) +} + +func (t *labelWrappedTransport) WriteTo(buf []byte, addr string) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, err + } + return t.NodeAwareTransport.WriteTo(buf, addr) +} + +func (t *labelWrappedTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialAddressTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} + +func (t *labelWrappedTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} diff --git a/util.go b/util.go index e7be4ad88..8f609c1e0 100644 --- a/util.go +++ b/util.go @@ -96,13 +96,13 @@ func pushPullScale(interval time.Duration, n int) time.Duration { return time.Duration(multiplier) * interval } -// moveDeadNodes moves nodes that are dead and beyond the gossip to the dead interval +// moveDeadNodes moves dead and left nodes that that have not changed during the gossipToTheDeadTime interval // to the end of the slice and returns the index of the first moved node. func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int { numDead := 0 n := len(nodes) for i := 0; i < n-numDead; i++ { - if nodes[i].State != StateDead { + if !nodes[i].DeadOrLeft() { continue } diff --git a/util_test.go b/util_test.go index 0b43f2aa6..4b57f7aa0 100644 --- a/util_test.go +++ b/util_test.go @@ -179,6 +179,16 @@ func TestMoveDeadNodes(t *testing.T) { State: StateDead, StateChange: time.Now().Add(-10 * time.Second), }, + // This left node should not be moved, as its state changed + // less than the specified GossipToTheDead time ago + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-10 * time.Second), + }, + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-20 * time.Second), + }, &nodeState{ State: StateAlive, StateChange: time.Now().Add(-20 * time.Second), @@ -191,10 +201,14 @@ func TestMoveDeadNodes(t *testing.T) { State: StateAlive, StateChange: time.Now().Add(-20 * time.Second), }, + &nodeState{ + State: StateLeft, + StateChange: time.Now().Add(-20 * time.Second), + }, } idx := moveDeadNodes(nodes, (15 * time.Second)) - if idx != 4 { + if idx != 5 { t.Fatalf("bad index") } for i := 0; i < idx; i++ { @@ -205,6 +219,11 @@ func TestMoveDeadNodes(t *testing.T) { if nodes[i].State != StateDead { t.Fatalf("Bad state %d", i) } + case 3: + //Recently left node should remain at 3 + if nodes[i].State != StateLeft { + t.Fatalf("Bad State %d", i) + } default: if nodes[i].State != StateAlive { t.Fatalf("Bad state %d", i) @@ -212,7 +231,7 @@ func TestMoveDeadNodes(t *testing.T) { } } for i := idx; i < len(nodes); i++ { - if nodes[i].State != StateDead { + if !nodes[i].DeadOrLeft() { t.Fatalf("Bad state %d", i) } }