Merge pull request #398 from tallclair/client-dial-cls
[konnectivity-client] Ensure grpc tunnel is closed on dial failure
k8s-ci-robot authored Sep 7, 2022
2 parents aedf2cf + 529507c commit 4271084
Showing 4 changed files with 402 additions and 110 deletions.
konnectivity-client/pkg/client/client.go (220 changes: 167 additions & 53 deletions)
@@ -24,6 +24,7 @@ import (
"math/rand"
"net"
"sync"
"sync/atomic"
"time"

"google.golang.org/grpc"
@@ -42,7 +43,7 @@ type Tunnel interface {
}

type dialResult struct {
- err string
+ err *dialFailure
connid int64
}

@@ -53,13 +54,70 @@ type pendingDial struct {
cancelCh <-chan struct{}
}

+ // TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+ type pendingDialManager struct {
+ pendingDials map[int64]pendingDial
+ mutex sync.RWMutex
+ }
+
+ func (p *pendingDialManager) add(dialID int64, pd pendingDial) {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+ p.pendingDials[dialID] = pd
+ }
+
+ func (p *pendingDialManager) remove(dialID int64) {
+ p.mutex.Lock()
+ defer p.mutex.Unlock()
+ delete(p.pendingDials, dialID)
+ }
+
+ func (p *pendingDialManager) get(dialID int64) (pendingDial, bool) {
+ p.mutex.RLock()
+ defer p.mutex.RUnlock()
+ pd, ok := p.pendingDials[dialID]
+ return pd, ok
+ }
+
+ // TODO: Replace with a generic implementation once it is safe to assume the client is built with go1.18+
+ type connectionManager struct {
+ conns map[int64]*conn
+ mutex sync.RWMutex
+ }
+
+ func (cm *connectionManager) add(connID int64, c *conn) {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ cm.conns[connID] = c
+ }
+
+ func (cm *connectionManager) remove(connID int64) {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ delete(cm.conns, connID)
+ }
+
+ func (cm *connectionManager) get(connID int64) (*conn, bool) {
+ cm.mutex.RLock()
+ defer cm.mutex.RUnlock()
+ c, ok := cm.conns[connID]
+ return c, ok
+ }
+
+ func (cm *connectionManager) closeAll() {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ for _, conn := range cm.conns {
+ close(conn.readCh)
+ }
+ }
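
Note on the two TODOs above: both manager types wrap a mutex-guarded map with the same add/remove/get surface, so once go1.18+ can be assumed they could collapse into one type-parameterized container. A minimal sketch, assuming generics are available (the syncedMap name and package are hypothetical, not part of this commit):

package example

import "sync"

// syncedMap is a hypothetical generic replacement for the locked-map pattern.
type syncedMap[K comparable, V any] struct {
    mutex sync.RWMutex
    m     map[K]V
}

func (s *syncedMap[K, V]) add(key K, value V) {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    s.m[key] = value
}

func (s *syncedMap[K, V]) remove(key K) {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    delete(s.m, key)
}

func (s *syncedMap[K, V]) get(key K) (V, bool) {
    s.mutex.RLock()
    defer s.mutex.RUnlock()
    v, ok := s.m[key]
    return v, ok
}

connectionManager.closeAll would remain a type-specific method (or a small wrapper), since it reaches into conn.readCh.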

// grpcTunnel implements Tunnel
type grpcTunnel struct {
- stream client.ProxyService_ProxyClient
- pendingDial map[int64]pendingDial
- conns map[int64]*conn
- pendingDialLock sync.RWMutex
- connsLock sync.RWMutex
+ stream client.ProxyService_ProxyClient
+ clientConn clientConn
+ pendingDial pendingDialManager
+ conns connectionManager

// The tunnel will be closed if the caller fails to read via conn.Read()
// more than readTimeoutSeconds after a packet has been received.
@@ -68,6 +126,11 @@ type grpcTunnel struct {
// The done channel is closed after the tunnel has cleaned up all connections and is no longer
// serving.
done chan struct{}

+ // closing is an atomic bool represented as a 0 or 1, and set to true when the tunnel is being closed.
+ // closing should only be accessed through atomic methods.
+ // TODO: switch this to an atomic.Bool once the client is exclusively built with go1.19+
+ closing uint32
}

type clientConn interface {
@@ -106,42 +169,39 @@ func CreateSingleUseGrpcTunnelWithContext(createCtx, tunnelCtx context.Context,
return nil, err
}

- tunnel := newUnstartedTunnel(stream)
+ tunnel := newUnstartedTunnel(stream, c)

- go tunnel.serve(tunnelCtx, c)
+ go tunnel.serve(tunnelCtx)

return tunnel, nil
}

- func newUnstartedTunnel(stream client.ProxyService_ProxyClient) *grpcTunnel {
+ func newUnstartedTunnel(stream client.ProxyService_ProxyClient, c clientConn) *grpcTunnel {
return &grpcTunnel{
stream: stream,
- pendingDial: make(map[int64]pendingDial),
- conns: make(map[int64]*conn),
+ clientConn: c,
+ pendingDial: pendingDialManager{pendingDials: make(map[int64]pendingDial)},
+ conns: connectionManager{conns: make(map[int64]*conn)},
readTimeoutSeconds: 10,
done: make(chan struct{}),
}
}

- func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
+ func (t *grpcTunnel) serve(tunnelCtx context.Context) {
defer func() {
- c.Close()
+ t.clientConn.Close()

// A connection in t.conns after serve() returns means
// we never received a CLOSE_RSP for it, so we need to
// close any channels remaining for these connections.
- t.connsLock.Lock()
- for _, conn := range t.conns {
- close(conn.readCh)
- }
- t.connsLock.Unlock()
+ t.conns.closeAll()

close(t.done)
}()

for {
pkt, err := t.stream.Recv()
- if err == io.EOF {
+ if err == io.EOF || t.isClosing() {
return
}
if err != nil || pkt == nil {
@@ -154,28 +214,29 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
switch pkt.Type {
case client.PacketType_DIAL_RSP:
resp := pkt.GetDialResponse()
- t.pendingDialLock.RLock()
- pendingDial, ok := t.pendingDial[resp.Random]
- t.pendingDialLock.RUnlock()
+ pendingDial, ok := t.pendingDial.get(resp.Random)

if !ok {
+ // If the DIAL_RSP does not match a pending dial, it means one of two things:
+ // 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible)
+ // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
+ //
+ // In either scenario, we should return here and close the tunnel as it is no longer needed.
klog.V(1).InfoS("DialResp not recognized; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random)
return
} else {
- result := dialResult{
- err: resp.Error,
- connid: resp.ConnectID,
+ result := dialResult{connid: resp.ConnectID}
+ if resp.Error != "" {
+ result.err = &dialFailure{resp.Error, DialFailureEndpoint}
}
select {
// try to send to the result channel
case pendingDial.resultCh <- result:
// unblock if the cancel channel is closed
case <-pendingDial.cancelCh:
- // If there are no readers of the pending dial channel above, it means one of two things:
- // 1. There was a second DIAL_RSP for the connection request (this is very unlikely but possible)
- // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
- //
- // In either scenario, we should return here as this tunnel is no longer needed.
+ // Note: this condition can only be hit by a race condition where the
+ // DialContext() returns early (timeout) after the pendingDial is already
+ // fetched here, but before the result is sent.
klog.V(1).InfoS("Pending dial has been cancelled; dropped", "connectionID", resp.ConnectID, "dialID", resp.Random)
return
case <-tunnelCtx.Done():
@@ -189,12 +250,36 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
return
}

+ case client.PacketType_DIAL_CLS:
+ resp := pkt.GetCloseDial()
+ pendingDial, ok := t.pendingDial.get(resp.Random)
+
+ if !ok {
+ // If the DIAL_CLS does not match a pending dial, it means one of two things:
+ // 1. There was a DIAL_CLS received after a DIAL_RSP (unlikely but possible)
+ // 2. grpcTunnel.DialContext() returned early due to a dial timeout or the client canceling the context
+ //
+ // In either scenario, we should return here and close the tunnel as it is no longer needed.
+ klog.V(1).InfoS("DIAL_CLS after dial finished", "dialID", resp.Random)
+ } else {
+ result := dialResult{
+ err: &dialFailure{"dial closed", DialFailureDialClosed},
+ }
+ select {
+ case pendingDial.resultCh <- result:
+ case <-pendingDial.cancelCh:
+ // Note: this condition can only be hit by a race condition where the
+ // DialContext() returns early (timeout) after the pendingDial is already
+ // fetched here, but before the result is sent.
+ case <-tunnelCtx.Done():
+ }
+ }
+ return // Stop serving & close the tunnel.
+
case client.PacketType_DATA:
resp := pkt.GetData()
// TODO: flow control
- t.connsLock.RLock()
- conn, ok := t.conns[resp.ConnectID]
- t.connsLock.RUnlock()
+ conn, ok := t.conns.get(resp.ConnectID)

if ok {
timer := time.NewTimer((time.Duration)(t.readTimeoutSeconds) * time.Second)
@@ -210,19 +295,16 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
} else {
klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID)
}

case client.PacketType_CLOSE_RSP:
resp := pkt.GetCloseResponse()
- t.connsLock.RLock()
- conn, ok := t.conns[resp.ConnectID]
- t.connsLock.RUnlock()
+ conn, ok := t.conns.get(resp.ConnectID)

if ok {
close(conn.readCh)
conn.closeCh <- resp.Error
close(conn.closeCh)
- t.connsLock.Lock()
- delete(t.conns, resp.ConnectID)
- t.connsLock.Unlock()
+ t.conns.remove(resp.ConnectID)
return
}
klog.V(1).InfoS("connection not recognized", "connectionID", resp.ConnectID)
@@ -252,14 +334,8 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
// This channel MUST NOT be buffered. The sender needs to know when we are not receiving things, so they can abort.
resCh := make(chan dialResult)

- t.pendingDialLock.Lock()
- t.pendingDial[random] = pendingDial{resultCh: resCh, cancelCh: cancelCh}
- t.pendingDialLock.Unlock()
- defer func() {
- t.pendingDialLock.Lock()
- delete(t.pendingDial, random)
- t.pendingDialLock.Unlock()
- }()
+ t.pendingDial.add(random, pendingDial{resultCh: resCh, cancelCh: cancelCh})
+ defer t.pendingDial.remove(random)

req := &client.Packet{
Type: client.PacketType_DIAL_REQ,
@@ -280,25 +356,32 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s

klog.V(5).Infoln("DIAL_REQ sent to proxy server")

- c := &conn{stream: t.stream, random: random}
+ c := &conn{
+ stream: t.stream,
+ random: random,
+ closeTunnel: t.closeTunnel,
+ }

select {
case res := <-resCh:
if res.err != "" {
return nil, &dialFailure{res.err, DialFailureEndpoint}
if res.err != nil {
return nil, res.err
}
c.connID = res.connid
c.readCh = make(chan []byte, 10)
c.closeCh = make(chan string, 1)
- t.connsLock.Lock()
- t.conns[res.connid] = c
- t.connsLock.Unlock()
+ t.conns.add(res.connid, c)
case <-time.After(30 * time.Second):
klog.V(5).InfoS("Timed out waiting for DialResp", "dialID", random)
+ go t.closeDial(random)
return nil, &dialFailure{"dial timeout, backstop", DialFailureTimeout}
case <-requestCtx.Done():
klog.V(5).InfoS("Context canceled waiting for DialResp", "ctxErr", requestCtx.Err(), "dialID", random)
+ go t.closeDial(random)
return nil, &dialFailure{"dial timeout, context", DialFailureContext}
+ case <-t.done:
+ klog.V(5).InfoS("Tunnel closed while waiting for DialResp", "dialID", random)
+ return nil, &dialFailure{"tunnel closed", DialFailureTunnelClosed}
}

return c, nil
@@ -308,6 +391,31 @@ func (t *grpcTunnel) Done() <-chan struct{} {
return t.done
}

+ // Send a best-effort DIAL_CLS request for the given dial ID.
+ func (t *grpcTunnel) closeDial(dialID int64) {
+ req := &client.Packet{
+ Type: client.PacketType_DIAL_CLS,
+ Payload: &client.Packet_CloseDial{
+ CloseDial: &client.CloseDial{
+ Random: dialID,
+ },
+ },
+ }
+ if err := t.stream.Send(req); err != nil {
+ klog.V(5).InfoS("Failed to send DIAL_CLS", "err", err, "dialID", dialID)
+ }
+ t.closeTunnel()
+ }
+
+ func (t *grpcTunnel) closeTunnel() {
+ atomic.StoreUint32(&t.closing, 1)
+ t.clientConn.Close()
+ }
+
+ func (t *grpcTunnel) isClosing() bool {
+ return atomic.LoadUint32(&t.closing) != 0
+ }
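
Note: the closing flag is the pre-go1.19 idiom for an atomic boolean. A minimal sketch of the same two methods per the TODO on the field, assuming go1.19+ (the trimmed-down struct and clientConn stub here are for illustration only):

package example

import "sync/atomic"

// clientConn matches only the surface used below; the real interface
// lives in this package.
type clientConn interface{ Close() error }

type grpcTunnel struct {
    clientConn clientConn
    closing    atomic.Bool // replaces `closing uint32` plus the sync/atomic calls
    // ...remaining fields unchanged...
}

func (t *grpcTunnel) closeTunnel() {
    t.closing.Store(true)
    t.clientConn.Close()
}

func (t *grpcTunnel) isClosing() bool {
    return t.closing.Load()
}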

func GetDialFailureReason(err error) (isDialFailure bool, reason DialFailureReason) {
var df *dialFailure
if errors.As(err, &df) {
@@ -336,4 +444,10 @@ const (
DialFailureContext DialFailureReason = "context"
// DialFailureEndpoint indicates that the konnectivity-agent was unable to reach the backend endpoint.
DialFailureEndpoint DialFailureReason = "endpoint"
+ // DialFailureDialClosed indicates that the client received a CloseDial response, indicating the
+ // connection was closed before the dial could complete.
+ DialFailureDialClosed DialFailureReason = "dialclosed"
+ // DialFailureTunnelClosed indicates that the client connection was closed before the dial could
+ // complete.
+ DialFailureTunnelClosed DialFailureReason = "tunnelclosed"
)
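
Note: the two added reasons let callers distinguish a dial aborted by DIAL_CLS or by tunnel shutdown from an ordinary endpoint failure. A caller-side sketch of classifying DialContext errors with GetDialFailureReason (the import path, package name, and dialViaTunnel helper are assumptions for illustration; Tunnel is assumed to expose DialContext as in this file):

package example

import (
    "context"
    "net"

    "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client"
)

// dialViaTunnel dials through an established single-use tunnel and
// inspects any dial failure to decide whether a retry needs a new tunnel.
func dialViaTunnel(ctx context.Context, tun client.Tunnel, addr string) (net.Conn, error) {
    conn, err := tun.DialContext(ctx, "tcp", addr)
    if err != nil {
        if isDialFailure, reason := client.GetDialFailureReason(err); isDialFailure {
            switch reason {
            case client.DialFailureDialClosed, client.DialFailureTunnelClosed:
                // The single-use tunnel is gone; retrying requires a fresh
                // tunnel (e.g. CreateSingleUseGrpcTunnelWithContext).
            case client.DialFailureTimeout, client.DialFailureContext:
                // The dial was abandoned locally; per this commit the tunnel
                // is also closed via closeDial, so it cannot be reused either.
            }
        }
        return nil, err
    }
    return conn, nil
}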