Skip to content

Commit

Permalink
client/server: Don't block the main connection loop for transport IO
Browse files Browse the repository at this point in the history
Restructures both the client and server connection management so that
sending messages on the transport is done by a separate "sender"
goroutine. The receiving end was already split out like this.

Without this change, it is possible for a send to block if the other end
isn't reading fast enough, which then would block the main connection
loop and prevent incoming messages from being processed.

Signed-off-by: Kevin Parsons <[email protected]>
  • Loading branch information
kevpar committed Dec 8, 2020
1 parent bfba540 commit 9536df6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
51 changes: 40 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,19 @@ func (r *receiver) run(ctx context.Context, c *channel) {
}

func (c *Client) run() {
type streamCall struct {
streamID uint32
call *callRequest
}
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
requests = make(chan streamCall)
requestsFailed = make(chan streamCall)
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
)

// broadcast the shutdown error to the remaining waiters.
Expand All @@ -261,6 +267,21 @@ func (c *Client) run() {
}()
go recv.run(c.ctx, c.channel)

go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case streamCall := <-requests:
if err := c.send(streamCall.streamID, messageTypeRequest, streamCall.call.req); err != nil {
streamCall.call.errs <- err // errs is buffered so should not block.
requestsFailed <- streamCall
continue
}
}
}
}(c.ctx)

defer func() {
c.conn.Close()
c.userCloseFunc()
Expand All @@ -270,13 +291,21 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}

go func(streamID uint32, call *callRequest) {
sc := streamCall{
streamID: streamID,
call: call,
}
select {
case <-c.ctx.Done():
case requests <- sc:
}
}(streamID, call)
waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case streamCall := <-requestsFailed:
// Sending the request failed, so stop tracking this stream ID.
delete(waiters, streamCall.streamID)
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
Expand Down
43 changes: 34 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ func (c *serverConn) run(sctx context.Context) {
active int
state connState = connStateIdle
responses = make(chan response)
responseErr = make(chan error)
requests = make(chan request)
recvErr = make(chan error, 1)
shutdown = c.shutdown
Expand Down Expand Up @@ -412,6 +413,36 @@ func (c *serverConn) run(sctx context.Context) {
}
}(recvErr)

go func(responseErr chan error) {
for {
select {
// We don't want a case for c.shutdown here, as that would cause us to exit
// immediately when it is signaled, rather than waiting for any active requests
// to complete first. Instead, once all the active requests have completed,
// the main loop will return and close done, which will cause us to exit as well.
case <-done:
return
case response := <-responses:
p, err := c.server.codec.Marshal(response.resp)
if err != nil {
logrus.WithError(err).Error("failed marshaling response")
responseErr <- err
return
}

if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
responseErr <- err
return
}

// Send a nil error so that the main loop knows an active request has
// completed successfully.
responseErr <- nil
}
}
}(responseErr)

for {
newstate := state
switch {
Expand Down Expand Up @@ -449,18 +480,12 @@ func (c *serverConn) run(sctx context.Context) {
case <-done:
}
}(request.id)
case response := <-responses:
p, err := c.server.codec.Marshal(response.resp)
case err := <-responseErr:
// responseErr sends nil if no error occurred in sending the response.
// In that case we just decrement the active count and continue.
if err != nil {
logrus.WithError(err).Error("failed marshaling response")
return
}

if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
return
}

active--
case err := <-recvErr:
// TODO(stevvooe): Not wildly clear what we should do in this
Expand Down

0 comments on commit 9536df6

Please sign in to comment.