Commit 866231b: [konnectivity-client] Close tunnel on dial failure

tallclair committed Sep 7, 2022
1 parent 41ceca0

Showing 3 changed files with 170 additions and 56 deletions.
33 changes: 27 additions & 6 deletions konnectivity-client/pkg/client/client.go
@@ -24,6 +24,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"google.golang.org/grpc"
@@ -56,6 +57,7 @@ type pendingDial struct {
 // grpcTunnel implements Tunnel
 type grpcTunnel struct {
 	stream      client.ProxyService_ProxyClient
+	clientConn  clientConn
 	pendingDial map[int64]pendingDial
 	conns       map[int64]*conn
 	pendingDialLock sync.RWMutex
@@ -68,6 +70,11 @@ type grpcTunnel struct {
 	// The done channel is closed after the tunnel has cleaned up all connections and is no longer
 	// serving.
 	done chan struct{}
+
+	// closing is an atomic bool represented as a 0 or 1, and set to true when the tunnel is being closed.
+	// closing should only be accessed through atomic methods.
+	// TODO: switch this to an atomic.Bool once the client is exclusively built with go1.19+
+	closing uint32
 }
 
 type clientConn interface {
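
The TODO in this hunk anticipates Go 1.19's atomic.Bool. A minimal sketch of that variant, assuming a go1.19+ floor; tunnel119 is a hypothetical stand-in for grpcTunnel, and only the closing flag is shown:

	// Sketch only, not part of this commit: the closing flag as an atomic.Bool.
	type tunnel119 struct {
		closing atomic.Bool // true once the tunnel begins shutting down
	}

	func (t *tunnel119) isClosing() bool { return t.closing.Load() }

	// Where the commit calls atomic.StoreUint32(&t.closing, 1), this variant
	// would call t.closing.Store(true).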
@@ -106,26 +113,27 @@ func CreateSingleUseGrpcTunnelWithContext(createCtx, tunnelCtx context.Context,
 		return nil, err
 	}
 
-	tunnel := newUnstartedTunnel(stream)
+	tunnel := newUnstartedTunnel(stream, c)
 
-	go tunnel.serve(tunnelCtx, c)
+	go tunnel.serve(tunnelCtx)
 
 	return tunnel, nil
 }
 
-func newUnstartedTunnel(stream client.ProxyService_ProxyClient) *grpcTunnel {
+func newUnstartedTunnel(stream client.ProxyService_ProxyClient, c clientConn) *grpcTunnel {
 	return &grpcTunnel{
 		stream:             stream,
+		clientConn:         c,
 		pendingDial:        make(map[int64]pendingDial),
 		conns:              make(map[int64]*conn),
 		readTimeoutSeconds: 10,
 		done:               make(chan struct{}),
 	}
 }
 
-func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
+func (t *grpcTunnel) serve(tunnelCtx context.Context) {
 	defer func() {
-		c.Close()
+		t.clientConn.Close()
 
 		// A connection in t.conns after serve() returns means
 		// we never received a CLOSE_RSP for it, so we need to
@@ -141,7 +149,7 @@ func (t *grpcTunnel) serve(tunnelCtx context.Context, c clientConn) {
 
 	for {
 		pkt, err := t.stream.Recv()
-		if err == io.EOF {
+		if err == io.EOF || t.isClosing() {
 			return
 		}
 		if err != nil || pkt == nil {
@@ -333,6 +341,9 @@ func (t *grpcTunnel) DialContext(requestCtx context.Context, protocol, address s
 		klog.V(5).InfoS("Context canceled waiting for DialResp", "ctxErr", requestCtx.Err(), "dialID", random)
 		go t.closeDial(random)
 		return nil, &dialFailure{"dial timeout, context", DialFailureContext}
+	case <-t.done:
+		klog.V(5).InfoS("Tunnel closed while waiting for DialResp", "dialID", random)
+		return nil, &dialFailure{"tunnel closed", DialFailureTunnelClosed}
 	}
 
 	return c, nil
@@ -355,6 +366,13 @@ func (t *grpcTunnel) closeDial(dialID int64) {
 	if err := t.stream.Send(req); err != nil {
 		klog.V(5).InfoS("Failed to send DIAL_CLS", "err", err, "dialID", dialID)
 	}
+
+	atomic.StoreUint32(&t.closing, 1)
+	t.clientConn.Close()
 }
+
+func (t *grpcTunnel) isClosing() bool {
+	return atomic.LoadUint32(&t.closing) != 0
+}
 
 func GetDialFailureReason(err error) (isDialFailure bool, reason DialFailureReason) {
@@ -388,4 +406,7 @@ const (
 	// DialFailureDialClosed indicates that the client received a CloseDial response, indicating the
 	// connection was closed before the dial could complete.
 	DialFailureDialClosed DialFailureReason = "dialclosed"
+	// DialFailureTunnelClosed indicates that the client connection was closed before the dial could
+	// complete.
+	DialFailureTunnelClosed DialFailureReason = "tunnelclosed"
 )
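
With the new failure reason in place, a caller can tell a tunnel torn down mid-dial apart from other dial failures via GetDialFailureReason. A minimal consumer sketch, assuming the package is imported as client from sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client (imports of context and net elided); the tunnelDialer interface, target address, and recovery comment are illustrative assumptions, not part of this commit:

	// tunnelDialer captures just the method used below; at this commit the
	// concrete *grpcTunnel provides it.
	type tunnelDialer interface {
		DialContext(ctx context.Context, protocol, address string) (net.Conn, error)
	}

	func dialThroughTunnel(ctx context.Context, tunnel tunnelDialer) (net.Conn, error) {
		conn, err := tunnel.DialContext(ctx, "tcp", "10.0.0.1:443")
		if err != nil {
			if ok, reason := client.GetDialFailureReason(err); ok && reason == client.DialFailureTunnelClosed {
				// The client connection was closed before the dial completed;
				// the single-use tunnel must be recreated before retrying.
			}
			return nil, err
		}
		return conn, nil
	}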
108 changes: 67 additions & 41 deletions konnectivity-client/pkg/client/client_test.go
@@ -50,9 +50,9 @@ func TestDial(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	_, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -83,9 +83,9 @@ func TestDialRace(t *testing.T) {
 
 	// artificially delay after calling Send, ensure handoff of result from serve to DialContext still works
 	slowStream := fakeSlowSend{s}
-	tunnel := newUnstartedTunnel(slowStream)
+	tunnel := newUnstartedTunnel(slowStream, &fakeConn{})
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	_, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -127,9 +127,9 @@ func TestData(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -183,9 +183,9 @@ func TestClose(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -224,9 +224,9 @@ func TestCloseTimeout(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	conn, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -283,9 +283,9 @@ func TestDialAfterTunnelCancelled(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(ctx, &fakeConn{})
+	go tunnel.serve(ctx)
 	go ts.serve()
 
 	_, err := tunnel.DialContext(ctx, "tcp", "127.0.0.1:80")
@@ -303,9 +303,12 @@ func TestDial_RequestContextCancelled(t *testing.T) {
 func TestDial_RequestContextCancelled(t *testing.T) {
 	defer goleakVerifyNone(t, goleak.IgnoreCurrent())
 
-	reqCtx, reqCancel := context.WithCancel(context.Background())
 	s, ps := pipe()
+	defer ps.Close()
+	defer s.Close()
+
 	ts := testServer(ps, 100)
+	reqCtx, reqCancel := context.WithCancel(context.Background())
 	ts.handlers[client.PacketType_DIAL_REQ] = func(*client.Packet) *client.Packet {
 		reqCancel()
 		return nil // don't respond
@@ -315,34 +318,45 @@ func TestDial_RequestContextCancelled(t *testing.T) {
 		close(closeCh)
 		return nil // don't respond
 	}
+	go ts.serve()
 
-	defer ps.Close()
-	defer s.Close()
+	func() {
+		// Tunnel should be shut down when the dial fails.
+		defer goleakVerifyNone(t, goleak.IgnoreCurrent())
 
-	tunnel := newUnstartedTunnel(s)
+		tunnel := newUnstartedTunnel(s, s.conn())
+		go tunnel.serve(context.Background())
 
-	go tunnel.serve(context.Background(), &fakeConn{})
-	go ts.serve()
+		_, err := tunnel.DialContext(reqCtx, "tcp", "127.0.0.1:80")
+		if err == nil {
+			t.Fatalf("Expected dial error, got none")
+		}
 
-	_, err := tunnel.DialContext(reqCtx, "tcp", "127.0.0.1:80")
-	if err == nil {
-		t.Fatalf("Expected dial error, got none")
-	}
+		isDialFailure, reason := GetDialFailureReason(err)
+		if !isDialFailure {
+			t.Errorf("Unexpected non-dial failure error: %v", err)
+		} else if reason != DialFailureContext {
+			t.Errorf("Expected DialFailureContext, got %v", reason)
+		}
 
-	isDialFailure, reason := GetDialFailureReason(err)
-	if !isDialFailure {
-		t.Errorf("Unexpected non-dial failure error: %v", err)
-	} else if reason != DialFailureContext {
-		t.Errorf("Expected DialFailureContext, got %v", reason)
-	}
+		ts.assertPacketType(0, client.PacketType_DIAL_REQ)
+		waitForDialClsStart := time.Now()
+		select {
+		case <-closeCh:
+			t.Logf("Dial closed after %#v", time.Since(waitForDialClsStart).String())
+			ts.assertPacketType(1, client.PacketType_DIAL_CLS)
+		case <-time.After(30 * time.Second):
+			t.Fatal("Timed out waiting for DIAL_CLS packet")
+		}
 
-	ts.assertPacketType(0, client.PacketType_DIAL_REQ)
-	select {
-	case <-closeCh:
-		ts.assertPacketType(1, client.PacketType_DIAL_CLS)
-	case <-time.After(30 * time.Second):
-		t.Fatal("Timed out waiting for DIAL_CLS packet")
-	}
+		waitForTunnelCloseStart := time.Now()
+		select {
+		case <-tunnel.Done():
+			t.Logf("Tunnel closed after %#v", time.Since(waitForTunnelCloseStart).String())
+		case <-time.After(30 * time.Second):
+			t.Errorf("Timed out waiting for tunnel to close")
+		}
+	}()
 }
 
 func TestDial_BackendError(t *testing.T) {
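
The second select block added above waits on tunnel.Done(), the channel that closes once serve() has cleaned up all connections (see the done field in client.go). A caller outside the tests could use the same signal to bound shutdown; a short sketch, where the one-minute deadline is an arbitrary choice:

	// Sketch: wait for the tunnel to finish shutting down, with a deadline.
	select {
	case <-tunnel.Done():
		// serve() returned and all connections were cleaned up.
	case <-time.After(time.Minute):
		// Still shutting down; log and move on rather than block forever.
	}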
@@ -365,9 +379,9 @@ func TestDial_BackendError(t *testing.T) {
 	defer ps.Close()
 	defer s.Close()
 
-	tunnel := newUnstartedTunnel(s)
+	tunnel := newUnstartedTunnel(s, s.conn())
 
-	go tunnel.serve(context.Background(), &fakeConn{})
+	go tunnel.serve(context.Background())
 	go ts.serve()
 
 	_, err := tunnel.DialContext(context.Background(), "tcp", "127.0.0.1:80")
@@ -409,8 +423,8 @@ func TestDial_Closed(t *testing.T) {
 	// Verify that the tunnel goroutines are not leaked before cleaning up the test server.
 	goleakVerifyNone(t, goleak.IgnoreCurrent())
 
-	tunnel := newUnstartedTunnel(s)
-	go tunnel.serve(context.Background(), &fakeConn{})
+	tunnel := newUnstartedTunnel(s, s.conn())
+	go tunnel.serve(context.Background())
 
 	_, err := tunnel.DialContext(context.Background(), "tcp", "127.0.0.1:80")
 	if err == nil {
@@ -446,9 +460,13 @@ type fakeStream struct {
 }
 
 type fakeConn struct {
+	stream *fakeStream
 }
 
 func (f *fakeConn) Close() error {
+	if f.stream != nil {
+		f.stream.Close()
+	}
 	return nil
 }
 
@@ -493,13 +511,21 @@ func (s *fakeStream) Recv() (*client.Packet, error) {
 	case pkt := <-s.r:
 		klog.V(4).InfoS("[DEBUG] recv", "packet", pkt)
 		return pkt, nil
-	case <-time.After(5 * time.Second):
+	case <-time.After(30 * time.Second):
 		return nil, errors.New("timeout recv")
 	}
 }
 
 func (s *fakeStream) Close() {
-	close(s.closed)
+	select {
+	case <-s.closed: // Avoid double-closing
+	default:
+		close(s.closed)
+	}
 }
+
+func (s *fakeStream) conn() *fakeConn {
+	return &fakeConn{s}
+}
 
 type proxyServer struct {
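
The guarded close above makes fakeStream.Close idempotent, which matters now that fakeConn.Close forwards to it and the tests may also close the stream directly. The same guarantee is more conventionally obtained with sync.Once, which also stays race-free if Close is ever called concurrently; a sketch only, not part of the commit:

	// onceStream is a hypothetical variant using sync.Once for idempotent close.
	type onceStream struct {
		closed    chan struct{}
		closeOnce sync.Once
	}

	func (s *onceStream) Close() {
		s.closeOnce.Do(func() { close(s.closed) })
	}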