Skip to content

Commit

Permalink
fix: deadlock when reopening the broker connection
Browse files Browse the repository at this point in the history
  • Loading branch information
childe committed Jul 16, 2024
1 parent 1003301 commit 45ab603
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 13 deletions.
39 changes: 30 additions & 9 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"os"
"sync"
"syscall"
"time"
)

Expand All @@ -24,7 +25,8 @@ type Broker struct {

correlationID uint32

mux sync.Mutex
mux sync.Mutex
closeLock sync.Mutex
}

var (
Expand Down Expand Up @@ -66,11 +68,15 @@ func (broker *Broker) getHighestAvailableAPIVersion(apiKey uint16) uint16 {

// createConn creates a new connection to the broker and then performs SASL authentication if needed
func (broker *Broker) createConn() error {
// TODO split to use defer to unlock
broker.mux.Lock()
if conn, err := newConn(broker.GetAddress(), broker.config); err != nil {
broker.mux.Unlock()
return err
} else {
broker.conn = conn
}
broker.mux.Unlock()

clientID := "healer-init"
apiVersionsResponse, err := broker.requestAPIVersions(clientID)
Expand Down Expand Up @@ -186,12 +192,14 @@ func (broker *Broker) String() string {
// Close closes the connection to the broker
func (broker *Broker) Close() {
logger.Info("close broker", "broker", broker.String())
broker.mux.Lock()

broker.closeLock.Lock()
defer broker.closeLock.Unlock()

if broker.conn != nil {
broker.conn.Close()
broker.conn = nil
}
broker.mux.Unlock()
}
func (broker *Broker) ensureOpen() (err error) {
if broker.conn != nil {
Expand All @@ -215,7 +223,7 @@ func (broker *Broker) Request(r Request) (ReadParser, error) {
r.SetVersion(version)
rp, err := broker.request(r.Encode(version), timeout)
if err != nil {
return nil, fmt.Errorf("requst of %d(%d) to %s error: %w", r.API(), version, broker.GetAddress(), err)
return nil, fmt.Errorf("request of %d(%d) to %s error: %w", r.API(), version, broker.GetAddress(), err)
}
rp.version = version
rp.api = r.API()
Expand All @@ -224,18 +232,26 @@ func (broker *Broker) Request(r Request) (ReadParser, error) {
}

// RequestAndGet sends a request to the broker and returns the response
func (broker *Broker) RequestAndGet(r Request) (Response, error) {
func (broker *Broker) RequestAndGet(r Request) (resp Response, err error) {
if err := broker.ensureOpen(); err != nil {
return nil, err
}

broker.mux.Lock()
defer broker.mux.Unlock()

broker.ensureOpen()
defer func() {
if os.IsTimeout(err) || errors.Is(err, io.EOF) || errors.Is(err, syscall.EPIPE) {
broker.Close()
}
}()

rp, err := broker.Request(r)
if err != nil {
return nil, err
}

resp, err := rp.ReadAndParse()
resp, err = rp.ReadAndParse()
if err != nil {
return nil, err
}
Expand All @@ -249,7 +265,9 @@ func (broker *Broker) request(payload []byte, timeout int) (defaultReadParser, e
correlationID := binary.BigEndian.Uint32(payload[8:])
logger.V(5).Info("request info", "length", len(payload), "api", api, "apiVersion", apiVersion, "correlationID", correlationID, "timeout", timeout)

io.Copy(broker.conn, bytes.NewBuffer(payload))
if _, err := io.Copy(broker.conn, bytes.NewBuffer(payload)); err != nil {
return defaultReadParser{}, err
}

rp := defaultReadParser{
broker: broker,
Expand All @@ -261,7 +279,6 @@ func (broker *Broker) request(payload []byte, timeout int) (defaultReadParser, e
func (broker *Broker) requestStreamingly(payload []byte, timeout int) (r io.Reader, responseLength uint32, err error) {
defer func() {
if err != nil {
logger.Info("requestStreamingly error", "error", err)
broker.Close()
}
}()
Expand Down Expand Up @@ -337,6 +354,10 @@ func (broker *Broker) requestOffsets(clientID, topic string, partitionIDs []int3
}

func (broker *Broker) requestFetchStreamingly(fetchRequest *FetchRequest) (r io.Reader, responseLength uint32, err error) {
if err := broker.ensureOpen(); err != nil {
return nil, 0, err
}

broker.mux.Lock()
defer broker.mux.Unlock()

Expand Down
46 changes: 45 additions & 1 deletion broker_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package healer

import (
"errors"
"io"
"net"
"testing"

Expand All @@ -11,7 +13,7 @@ import (
func TestNewBroker(t *testing.T) {
mockey.PatchConvey("TestNewBroker", t, func() {
mockey.Mock((*Broker).requestAPIVersions).Return(APIVersionsResponse{}, nil).Build()
mockey.Mock((*net.Dialer).Dial).Return(nil, nil).Build()
mockey.Mock((*net.Dialer).Dial).Return(&MockConn{}, nil).Build()
broker, err := NewBroker("127.0.0.1:9092", 0, DefaultBrokerConfig())
convey.So(err, convey.ShouldBeNil)
convey.So(broker, convey.ShouldNotBeNil)
Expand Down Expand Up @@ -44,3 +46,45 @@ func TestGetHighestAvailableAPIVersion(t *testing.T) {
convey.So(got, convey.ShouldEqual, 0)
})
}

// TestReopenConn checks that a broker whose read fails with io.EOF closes its
// connection and opens a fresh one on the next request.
func TestReopenConn(t *testing.T) {
	mockey.PatchConvey("conn EOF and reopen new conn", t, func() {
		// Stub out everything that would touch the network or decode real payloads.
		mockey.Mock(newAPIVersionsResponse).Return(APIVersionsResponse{}, nil).Build()
		mockey.Mock((*net.Dialer).Dial).Return(&MockConn{}, nil).Build()
		mockey.Mock(NewMetadataResponse).Return(MetadataResponse{}, nil).Build()

		// Pass-through mocks: each forwards to the real method so that the
		// number of invocations can be asserted below.
		origClose := (*Broker).Close
		closeSpy := mockey.Mock((*Broker).Close).
			To(func(b *Broker) { origClose(b) }).
			Origin(&origClose).
			Build()

		origEnsureOpen := (*Broker).ensureOpen
		ensureOpenSpy := mockey.Mock((*Broker).ensureOpen).
			To(func(b *Broker) error { return origEnsureOpen(b) }).
			Origin(&origEnsureOpen).
			Build()

		origCreateConn := (*Broker).createConn
		createConnSpy := mockey.Mock((*Broker).createConn).
			To(func(b *Broker) error { return origCreateConn(b) }).
			Origin(&origCreateConn).
			Build()

		b, err := NewBroker("127.0.0.1:9092", 0, DefaultBrokerConfig())
		convey.So(err, convey.ShouldBeNil)
		convey.So(b, convey.ShouldNotBeNil)

		convey.So(ensureOpenSpy.Times(), convey.ShouldEqual, 1)
		convey.So(createConnSpy.Times(), convey.ShouldEqual, 1)

		// The first read reports io.EOF; every later read yields an empty payload.
		alreadyFailed := false
		mockey.Mock(defaultReadParser.Read).
			To(func() ([]byte, error) {
				if alreadyFailed {
					return make([]byte, 0), nil
				}
				alreadyFailed = true
				return nil, io.EOF
			}).Build()

		metadataReq := NewMetadataRequest("test-clientID", []string{"test-topic"})

		// First attempt surfaces the EOF to the caller.
		_, err = b.RequestAndGet(metadataReq)
		convey.So(errors.Is(err, io.EOF), convey.ShouldBeTrue)

		// Second attempt succeeds.
		_, err = b.RequestAndGet(metadataReq)
		convey.So(err, convey.ShouldBeNil)

		convey.So(closeSpy.Times(), convey.ShouldEqual, 1)
		convey.So(ensureOpenSpy.Times(), convey.ShouldEqual, 4)
		convey.So(createConnSpy.Times(), convey.ShouldEqual, 2)
	})
}
50 changes: 50 additions & 0 deletions conn_mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package healer

import (
"net"
"time"
)

// MockConn is a test double for net.Conn.
//
// Each Mock* field optionally overrides the method of the same name. When a
// field is nil the method falls back to a harmless default: Read and Write
// report that the whole buffer was processed, the address getters return nil,
// and every remaining method returns a nil error. The zero value is therefore
// ready to use.
type MockConn struct {
	MockRead             func(p []byte) (n int, err error)
	MockWrite            func(p []byte) (n int, err error)
	MockClose            func() error
	MockLocalAddr        func() net.Addr
	MockRemoteAddr       func() net.Addr
	MockSetDeadline      func(t time.Time) error
	MockSetReadDeadline  func(t time.Time) error
	MockSetWriteDeadline func(t time.Time) error
}

// Compile-time check that *MockConn satisfies net.Conn.
var _ net.Conn = (*MockConn)(nil)

// Read calls MockRead when it is set; otherwise it reports len(p) bytes read
// without modifying p.
func (m *MockConn) Read(p []byte) (n int, err error) {
	if m != nil && m.MockRead != nil {
		return m.MockRead(p)
	}
	return len(p), nil
}

// Write calls MockWrite when it is set; otherwise it reports len(p) bytes
// written and discards the data.
func (m *MockConn) Write(p []byte) (n int, err error) {
	if m != nil && m.MockWrite != nil {
		return m.MockWrite(p)
	}
	return len(p), nil
}

// Close calls MockClose when it is set; otherwise it returns nil.
func (m *MockConn) Close() error {
	if m != nil && m.MockClose != nil {
		return m.MockClose()
	}
	return nil
}

// LocalAddr calls MockLocalAddr when it is set; otherwise it returns nil.
func (m *MockConn) LocalAddr() net.Addr {
	if m != nil && m.MockLocalAddr != nil {
		return m.MockLocalAddr()
	}
	return nil
}

// RemoteAddr calls MockRemoteAddr when it is set; otherwise it returns nil.
func (m *MockConn) RemoteAddr() net.Addr {
	if m != nil && m.MockRemoteAddr != nil {
		return m.MockRemoteAddr()
	}
	return nil
}

// SetDeadline calls MockSetDeadline when it is set; otherwise it returns nil.
func (m *MockConn) SetDeadline(t time.Time) error {
	if m != nil && m.MockSetDeadline != nil {
		return m.MockSetDeadline(t)
	}
	return nil
}

// SetReadDeadline calls MockSetReadDeadline when it is set; otherwise it
// returns nil.
func (m *MockConn) SetReadDeadline(t time.Time) error {
	if m != nil && m.MockSetReadDeadline != nil {
		return m.MockSetReadDeadline(t)
	}
	return nil
}

// SetWriteDeadline calls MockSetWriteDeadline when it is set; otherwise it
// returns nil.
func (m *MockConn) SetWriteDeadline(t time.Time) error {
	if m != nil && m.MockSetWriteDeadline != nil {
		return m.MockSetWriteDeadline(t)
	}
	return nil
}
5 changes: 2 additions & 3 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func (p defaultReadParser) ReadAndParse() (Response, error) {

// Read reads a whole response from the broker: it first reads the 4-byte length of the response, then reads the rest of the response data
func (p defaultReadParser) Read() ([]byte, error) {
//TODO use LimitedReader
l := 0
responseLengthBuf := make([]byte, 4)
for {
Expand All @@ -50,7 +51,6 @@ func (p defaultReadParser) Read() ([]byte, error) {
}
length, err := p.broker.conn.Read(responseLengthBuf[l:])
if err != nil {
p.broker.Close()
return nil, err
}

Expand All @@ -69,7 +69,6 @@ func (p defaultReadParser) Read() ([]byte, error) {
}
length, err := p.broker.conn.Read(resp[4+readLength:])
if err != nil {
p.broker.Close()
return nil, err
}

Expand All @@ -82,7 +81,7 @@ func (p defaultReadParser) Read() ([]byte, error) {
}
}
copy(resp[0:4], responseLengthBuf)
logger.V(5).Info("response info", "length", len(resp), "CorrelationID", binary.BigEndian.Uint32(resp[4:]))
// logger.V(5).Info("response info", "length", len(resp), "CorrelationID", binary.BigEndian.Uint32(resp[4:]))
return resp, nil
}

Expand Down
6 changes: 6 additions & 0 deletions simple_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ func (c *SimpleConsumer) consumeLoop(messages chan *FullMessage) {
c.consumeLoopWg.Add(1)
defer c.consumeLoopWg.Done()

defer logger.Info("consume loop exit", "topic", c.topic, "partitionID", c.partitionID)

wg := &sync.WaitGroup{}
for !c.stop {
innerMessages := make(chan *FullMessage, 1)
Expand All @@ -413,6 +415,8 @@ func (c *SimpleConsumer) consumeLoop(messages chan *FullMessage) {
return
}
logger.Error(err, "failed to fetch")
time.Sleep(time.Millisecond * time.Duration(c.config.RetryBackOffMS))
continue
}

//decode
Expand Down Expand Up @@ -480,6 +484,8 @@ func (c *SimpleConsumer) consumeMessages(innerMessages chan *FullMessage, messag
break
}
}
} else if message.Error == KafkaError(74) {
c.refreshPartiton()
}
return
} else {
Expand Down

0 comments on commit 45ab603

Please sign in to comment.