agent: harden dispatcher concurrency (#484)
* make nextTransactionID an atomic variable

This commit has 2 main benefits:
1) It makes it impossible to access nextTransactionID non-atomically.
2) It fixes a small bug where we would have racy (albeit atomic)
   accesses to nextTransactionID. Consider the following interleaving:
   Dispatcher Calls #1 and #2:
        Read nextTransactionID

   Dispatcher Calls #1 and #2:
        Bump nextTransactionID *locally* and then write it back. The
        same value is written back twice.

   Dispatcher Calls #1 and #2:
        Send a message with the newly minted transaction ID, x. Note
        that *two* messages are sent with x, so two responses will come
        back.

   First response arrives:
        Its entry is deleted from the dispatcher's waiters hash map.

   Second response arrives:
        A message with id x is received, but there is no record of it,
        because the entry was deleted when the first response arrived.

   The solution is to use an atomic read-modify-write operation in
   the form of .Add(1); see the first sketch below.

* protect disp.waiters with mutex

disp.Call can be called from multiple threads (the main disp.run()
thread and the healthchecker thread), so access needs to be guarded
with a mutex, as the underlying map is not thread-safe; see the second
sketch below.

* rename nextTransactionID to lastTransactionID
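
For illustration only, here is a standalone sketch (not code from this
repository; racyID and uniqueID are made-up helper names) contrasting
the racy read-then-add pattern with the single atomic read-modify-write:

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    func main() {
        // Racy pattern (what the old code effectively did): a separate
        // Load and Add. Two goroutines can both Load the same value and
        // end up sending two messages with the same transaction ID.
        var next uint64
        racyID := func() uint64 {
            id := atomic.LoadUint64(&next)
            atomic.AddUint64(&next, 1)
            return id
        }

        // Fixed pattern: a single atomic read-modify-write. Add returns
        // the new value, so concurrent callers always get distinct IDs.
        var last atomic.Uint64
        uniqueID := func() uint64 { return last.Add(1) }

        var wg sync.WaitGroup
        for i := 0; i < 2; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                fmt.Println("racy:", racyID(), "unique:", uniqueID())
            }()
        }
        wg.Wait()
    }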
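
And a minimal sketch of the mutex-guarded map pattern (the registry type
and the channel-based waiter are illustrative stand-ins, not the real
util.SignalSender/MonitorResult API): every access to the map takes the
lock, which is what the new registerWaiter helper does in the diff below.

    package main

    import (
        "fmt"
        "sync"
    )

    // waiterRegistry is an illustrative stand-in for the dispatcher's
    // waiters map: a plain Go map guarded by a mutex, because Go maps
    // are not safe for concurrent use by multiple goroutines.
    type waiterRegistry struct {
        mu      sync.Mutex
        waiters map[uint64]chan string
    }

    func (r *waiterRegistry) register(id uint64, ch chan string) {
        r.mu.Lock()
        defer r.mu.Unlock()
        r.waiters[id] = ch
    }

    func (r *waiterRegistry) notify(id uint64, msg string) {
        r.mu.Lock()
        ch, ok := r.waiters[id]
        delete(r.waiters, id)
        r.mu.Unlock()
        if ok {
            ch <- msg
        }
    }

    func main() {
        r := &waiterRegistry{waiters: make(map[uint64]chan string)}
        ch := make(chan string, 1)
        r.register(1, ch)
        r.notify(1, "response for transaction 1")
        fmt.Println(<-ch)
    }
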
fprasx authored Aug 21, 2023
1 parent 3e905ed commit 61be293
Showing 2 changed files with 26 additions and 13 deletions.
35 changes: 24 additions & 11 deletions pkg/agent/dispatcher.go
@@ -7,6 +7,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"sync"
 	"sync/atomic"
 
 	"go.uber.org/zap"
@@ -45,12 +46,16 @@ type Dispatcher struct {
 	// message and will send it down the SignalSender so the original sender can use it.
 	waiters map[uint64]util.SignalSender[*MonitorResult]
 
+	// lock guards mutating the waiters field. conn, logger, and nextTransactionID
+	// are all thread safe. server and protoVersion are never modified.
+	lock sync.Mutex
+
 	// The InformantServer that this dispatcher is part of
 	server *InformantServer
 
-	// nextTransactionID is the current transaction id. When we need a new one
+	// lastTransactionID is the last transaction id. When we need a new one
 	// we simply bump it and take the new number.
-	nextTransactionID uint64
+	lastTransactionID atomic.Uint64
 
 	logger *zap.Logger
 
@@ -59,7 +64,7 @@
 
 // Create a new Dispatcher, establishing a connection with the informant.
 // Note that this does not immediately start the Dispatcher. Call Run() to start it.
-func NewDispatcher(logger *zap.Logger, addr string, parent *InformantServer) (disp Dispatcher, _ error) {
+func NewDispatcher(logger *zap.Logger, addr string, parent *InformantServer) (disp *Dispatcher, _ error) {
 	ctx := context.TODO()
 
 	logger.Info("connecting via websocket", zap.String("addr", addr))
@@ -81,25 +86,26 @@ func NewDispatcher(logger *zap.Logger, addr string, parent *InformantServer) (di
 		},
 	)
 	if err != nil {
-		return Dispatcher{}, fmt.Errorf("error sending protocol range to monitor: %w", err)
+		return nil, fmt.Errorf("error sending protocol range to monitor: %w", err)
 	}
 	var version api.MonitorProtocolResponse
 	err = wsjson.Read(ctx, c, &version)
 	if err != nil {
-		return Dispatcher{}, fmt.Errorf("error reading monitor response during protocol handshake: %w", err)
+		return nil, fmt.Errorf("error reading monitor response during protocol handshake: %w", err)
 	}
 	if version.Error != nil {
-		return Dispatcher{}, fmt.Errorf("monitor returned error during protocol handshake: %q", *version.Error)
+		return nil, fmt.Errorf("monitor returned error during protocol handshake: %q", *version.Error)
 	}
 	logger.Info("negotiated protocol version with monitor", zap.String("version", version.Version.String()))
 
-	disp = Dispatcher{
+	disp = &Dispatcher{
 		conn: c,
 		waiters: make(map[uint64]util.SignalSender[*MonitorResult]),
-		nextTransactionID: 0,
+		lastTransactionID: atomic.Uint64{},
 		logger: logger.Named("dispatcher"),
 		protoVersion: version.Version,
 		server: parent,
+		lock: sync.Mutex{},
 	}
 	return disp, nil
 }
@@ -119,17 +125,24 @@ func (disp *Dispatcher) send(ctx context.Context, id uint64, message any) error
 	return wsjson.Write(ctx, disp.conn, &raw)
 }
 
+// registerWaiter registers a util.SignalSender to get notified when a
+// message with the given id arrives.
+func (disp *Dispatcher) registerWaiter(id uint64, sender util.SignalSender[*MonitorResult]) {
+	disp.lock.Lock()
+	defer disp.lock.Unlock()
+	disp.waiters[id] = sender
+}
+
 // Make a request to the monitor. The dispatcher will handle returning a response
 // on the provided SignalSender. The value passed into message must be a valid value
 // to send to the monitor. See the docs for SerializeInformantMessage.
 func (disp *Dispatcher) Call(ctx context.Context, sender util.SignalSender[*MonitorResult], message any) error {
-	id := atomic.LoadUint64(&disp.nextTransactionID)
-	atomic.AddUint64(&disp.nextTransactionID, 1)
+	id := disp.lastTransactionID.Add(1)
 	err := disp.send(ctx, id, message)
 	if err != nil {
 		disp.logger.Error("failed to send message", zap.Any("message", message), zap.Error(err))
 	}
-	disp.waiters[id] = sender
+	disp.registerWaiter(id, sender)
 	return nil
 }
 
4 changes: 2 additions & 2 deletions pkg/agent/informant.go
@@ -425,7 +425,7 @@ func (s *InformantServer) RegisterWithInformant(ctx context.Context, logger *zap
 	)
 	// pre-declare disp so that err get's assigned to err from enclosing scope,
 	// overwriting original request error.
-	var disp Dispatcher
+	var disp *Dispatcher
 	disp, err = NewDispatcher(logger, addr, s)
 	// If the error is not nil, it will get handled below
 	if err == nil {
@@ -437,7 +437,7 @@ func (s *InformantServer) RegisterWithInformant(ctx context.Context, logger *zap
 
 		connectedToMonitor = true
 		s.informantIsMonitor = true
-		s.dispatcher = &disp
+		s.dispatcher = disp
 		s.mode = InformantServerRunning
 		s.updatedInformant.Send()
 		if s.runner.server == s {
