[konnectivity-client] Always use defer for mutex.Unlock
tallclair committed Sep 7, 2022
1 parent 6539562 commit 529507c
Showing 2 changed files with 78 additions and 42 deletions.
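The commit replaces inline Lock/Unlock pairs with small helper types whose methods always release the mutex via defer. As a minimal illustration of why that matters (not code from this commit; the counter type below is hypothetical), a deferred Unlock runs on every exit path, including panics and early returns, while an inline Unlock can be skipped, leaving the mutex held:

package example

import "sync"

type counter struct {
	mu sync.Mutex
	n  int
}

// Inline unlock: if anything between Lock and Unlock panics or returns early,
// the mutex stays held and later callers block forever.
func (c *counter) incRisky() {
	c.mu.Lock()
	c.n++
	c.mu.Unlock()
}

// Deferred unlock: released on every exit path from the method.
func (c *counter) incSafe() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.n++
}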
112 changes: 73 additions & 39 deletions konnectivity-client/pkg/client/client.go
@@ -54,14 +54,70 @@ type pendingDial struct {
cancelCh <-chan struct{}
}

+ // TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+ type pendingDialManager struct {
+ pendingDials map[int64]pendingDial
+ mutex sync.RWMutex
+ }
+
+ func (p *pendingDialManager) add(dialID int64, pd pendingDial) {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+ p.pendingDials[dialID] = pd
+ }
+
+ func (p *pendingDialManager) remove(dialID int64) {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+ delete(p.pendingDials, dialID)
+ }
+
+ func (p *pendingDialManager) get(dialID int64) (pendingDial, bool) {
+ p.mutex.RLock()
+ defer p.mutex.RUnlock()
+ pd, ok := p.pendingDials[dialID]
+ return pd, ok
+ }
+
+ // TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+ type connectionManager struct {
+ conns map[int64]*conn
+ mutex sync.RWMutex
+ }
+
+ func (cm *connectionManager) add(connID int64, c *conn) {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ cm.conns[connID] = c
+ }
+
+ func (cm *connectionManager) remove(connID int64) {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ delete(cm.conns, connID)
+ }
+
+ func (cm *connectionManager) get(connID int64) (*conn, bool) {
+ cm.mutex.RLock()
+ defer cm.mutex.RUnlock()
+ c, ok := cm.conns[connID]
+ return c, ok
+ }
+
+ func (cm *connectionManager) closeAll() {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ for _, conn := range cm.conns {
+ close(conn.readCh)
+ }
+ }

// grpcTunnel implements Tunnel
type grpcTunnel struct {
- stream client.ProxyService_ProxyClient
- clientConn clientConn
- pendingDial map[int64]pendingDial
- conns map[int64]*conn
- pendingDialLock sync.RWMutex
- connsLock sync.RWMutex
+ stream client.ProxyService_ProxyClient
+ clientConn clientConn
+ pendingDial pendingDialManager
+ conns connectionManager

// The tunnel will be closed if the caller fails to read via conn.Read()
// more than readTimeoutSeconds after a packet has been received.
@@ -124,8 +180,8 @@ func newUnstartedTunnel(stream client.ProxyService_ProxyClient, c clientConn) *g
return &grpcTunnel{
stream: stream,
clientConn: c,
- pendingDial: make(map[int64]pendingDial),
- conns: make(map[int64]*conn),
+ pendingDial: pendingDialManager{pendingDials: make(map[int64]pendingDial)},
+ conns: connectionManager{conns: make(map[int64]*conn)},
readTimeoutSeconds: 10,
done: make(chan struct{}),
}
@@ -138,11 +194,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) {
// A connection in t.conns after serve() returns means
// we never received a CLOSE_RSP for it, so we need to
// close any channels remaining for these connections.
- t.connsLock.Lock()
- for _, conn := range t.conns {
- close(conn.readCh)
- }
- t.connsLock.Unlock()
+ t.conns.closeAll()

close(t.done)
}()
@@ -162,9 +214,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) {
switch pkt.Type {
case client.PacketType_DIAL_RSP:
resp := pkt.GetDialResponse()
- t.pendingDialLock.RLock()
- pendingDial, ok := t.pendingDial[resp.Random]
- t.pendingDialLock.RUnlock()
+ pendingDial, ok := t.pendingDial.get(resp.Random)

if !ok {
// If the DIAL_RSP does not match a pending dial, it means one of two things:
@@ -202,9 +252,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) {

case client.PacketType_DIAL_CLS:
resp := pkt.GetCloseDial()
- t.pendingDialLock.RLock()
- pendingDial, ok := t.pendingDial[resp.Random]
- t.pendingDialLock.RUnlock()
+ pendingDial, ok := t.pendingDial.get(resp.Random)

if !ok {
// If the DIAL_CLS does not match a pending dial, it means one of two things:
@@ -231,9 +279,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) {
case client.PacketType_DATA:
resp := pkt.GetData()
// TODO: flow control
- t.connsLock.RLock()
- conn, ok := t.conns[resp.ConnectID]
- t.connsLock.RUnlock()
+ conn, ok := t.conns.get(resp.ConnectID)

if ok {
timer := time.NewTimer((time.Duration)(t.readTimeoutSeconds) * time.Second)
@@ -252,17 +298,13 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context) {

case client.PacketType_CLOSE_RSP:
resp := pkt.GetCloseResponse()
- t.connsLock.RLock()
- conn, ok := t.conns[resp.ConnectID]
- t.connsLock.RUnlock()
+ conn, ok := t.conns.get(resp.ConnectID)

if ok {
close(conn.readCh)
conn.closeCh <- resp.Error
close(conn.closeCh)
- t.connsLock.Lock()
- delete(t.conns, resp.ConnectID)
- t.connsLock.Unlock()
+ t.conns.remove(resp.ConnectID)
return
}
klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID)
@@ -292,14 +334,8 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
// This channel MUST NOT be buffered. The sender needs to know when we are not receiving things, so they can abort.
resCh := make(chan dialResult)

- t.pendingDialLock.Lock()
- t.pendingDial[random] = pendingDial{resultCh: resCh, cancelCh: cancelCh}
- t.pendingDialLock.Unlock()
- defer func() {
- t.pendingDialLock.Lock()
- delete(t.pendingDial, random)
- t.pendingDialLock.Unlock()
- }()
+ t.pendingDial.add(random, pendingDial{resultCh: resCh, cancelCh: cancelCh})
+ defer t.pendingDial.remove(random)

req := &client.Packet{
Type: client.PacketType_DIAL_REQ,
@@ -334,9 +370,7 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
c.connID = res.connid
c.readCh = make(chan []byte, 10)
c.closeCh = make(chan string, 1)
- t.connsLock.Lock()
- t.conns[res.connid] = c
- t.connsLock.Unlock()
+ t.conns.add(res.connid, c)
case <-time.After(30 * time.Second):
klog.V(5).InfoS("Timed out waiting for DialResp", "dialID", random)
go t.closeDial(random)
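Both new manager types carry a TODO about switching to a generic implementation once go1.18+ can be assumed. A rough sketch of what that might look like, under the assumption that one type replaces both managers (the syncedMap name and constructor below are hypothetical, not part of konnectivity-client):

package example

import "sync"

// syncedMap is a mutex-guarded map; every accessor releases the lock via defer.
type syncedMap[K comparable, V any] struct {
	mu sync.RWMutex
	m  map[K]V
}

func newSyncedMap[K comparable, V any]() *syncedMap[K, V] {
	return &syncedMap[K, V]{m: make(map[K]V)}
}

func (s *syncedMap[K, V]) add(key K, value V) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.m[key] = value
}

func (s *syncedMap[K, V]) remove(key K) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.m, key)
}

func (s *syncedMap[K, V]) get(key K) (V, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	v, ok := s.m[key]
	return v, ok
}

syncedMap[int64, pendingDial] and syncedMap[int64, *conn] would then cover the add/remove/get methods of pendingDialManager and connectionManager; closeAll touches conn internals, so it would stay connection-specific.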
8 changes: 5 additions & 3 deletions konnectivity-client/pkg/client/client_test.go
@@ -566,9 +566,11 @@ func (s *proxyServer) serve() {
return
}

- s.packetsLock.Lock()
- s.packets = append(s.packets, pkt)
- s.packetsLock.Unlock()
+ func() {
+ s.packetsLock.Lock()
+ defer s.packetsLock.Unlock()
+ s.packets = append(s.packets, pkt)
+ }()

if handler, ok := s.handlers[pkt.Type]; ok {
req := handler(pkt)
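The test change wraps the locked append in an immediately invoked function literal. Since defer is function-scoped, without the wrapper the deferred Unlock would not run until serve() returns; the anonymous function releases packetsLock as soon as the append completes. A standalone sketch of the same pattern inside a loop (names here are hypothetical, not from the test):

package example

import "sync"

var (
	mu  sync.Mutex
	out []int
)

func collect(items []int) {
	for _, item := range items {
		func() {
			mu.Lock()
			defer mu.Unlock() // runs when this function literal returns, not when collect returns
			out = append(out, item)
		}()
		// mu is not held here, so other goroutines can acquire it between iterations.
	}
}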
