From 5a5182f48f6869b0ed04fe4701e6f920ea18428b Mon Sep 17 00:00:00 2001 From: nitwhiz Date: Sat, 17 Feb 2024 15:57:53 +0100 Subject: [PATCH] fix server lock up in high load scenarios --- .github/workflows/integration-tests.yml | 27 +++++ cmd/cli/main.go | 6 +- pkg/client/client.go | 31 +++++- pkg/lock/lock_table.go | 72 ++++++++++++ pkg/locking/lock_table.go | 67 ------------ pkg/locking/table.go | 85 --------------- pkg/server/server.go | 97 +++++++++-------- pkg/server/server_command.go | 20 +++- pkg/server/server_command_handler_lock.go | 13 ++- pkg/server/server_command_handler_trylock.go | 9 +- pkg/server/server_command_handler_unlock.go | 9 +- pkg/server/server_option.go | 25 ----- pkg/server/server_options.go | 57 ++++++++++ pkg/util/util_map.go | 86 +++++++++++++++ test/integration/server_lock_test.go | 109 ++++++++++++++++++- test/integration/server_trylock_test.go | 21 +++- test/integration/test_helpers.go | 14 ++- 17 files changed, 495 insertions(+), 253 deletions(-) create mode 100644 .github/workflows/integration-tests.yml create mode 100644 pkg/lock/lock_table.go delete mode 100644 pkg/locking/lock_table.go delete mode 100644 pkg/locking/table.go delete mode 100644 pkg/server/server_option.go create mode 100644 pkg/server/server_options.go create mode 100644 pkg/util/util_map.go diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 0000000..6f48da3 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,27 @@ +name: Run integration tests + +on: [ push ] + +env: + GO_VERSION: 1.21.6 + +jobs: + integration-tests: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ${{ env.GO_VERSION }} + - name: Build + run: go build -v ./... + - name: Run integration tests + run: go test -json ./test/integration > IntegrationTestResults-${{ env.GO_VERSION }}.json + - name: Upload Go integration tests results + uses: actions/upload-artifact@v4 + with: + name: Go-results-${{ env.GO_VERSION }} + path: IntegrationTestResults-${{ env.GO_VERSION }}.json diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f34b708..6db044b 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -2,17 +2,17 @@ package main import ( "context" - "fmt" "github.com/nitwhiz/omnilock/pkg/server" + "log" ) func main() { s, err := server.New(context.Background()) if err != nil { - fmt.Println(err) + log.Println(err) return } - s.Accept() + s.Run() } diff --git a/pkg/client/client.go b/pkg/client/client.go index 4fdf1a0..ae37fb7 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -5,6 +5,8 @@ import ( "context" "github.com/nitwhiz/omnilock/pkg/id" "net" + "sync" + "time" ) type Client struct { @@ -13,6 +15,7 @@ type Client struct { conn net.Conn reader *bufio.Reader cmdChan chan<- *Command + mu *sync.Mutex } func New(ctx context.Context, conn net.Conn, cmdChan chan<- *Command) *Client { @@ -22,11 +25,24 @@ func New(ctx context.Context, conn net.Conn, cmdChan chan<- *Command) *Client { conn: conn, reader: bufio.NewReader(conn), cmdChan: cmdChan, + mu: &sync.Mutex{}, } return &c } +func (c *Client) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *Client) Lock() { + c.mu.Lock() +} + +func (c *Client) Unlock() { + c.mu.Unlock() +} + func (c *Client) GetID() uint64 { return c.id } @@ -35,7 +51,13 @@ func (c *Client) GetContext() context.Context { return c.ctx } -func (c *Client) Write(b []byte) (int, error) { +func (c *Client) Write(b []byte, timeout time.Duration) (int, error) { + err := c.conn.SetWriteDeadline(time.Now().Add(timeout)) + + if err != nil { + return 0, err + } + return c.conn.Write(b) } @@ -46,9 +68,14 @@ func (c *Client) waitForCommand() bool { return false } - c.cmdChan <- &Command{ + select { + case <-c.ctx.Done(): + return false + case c.cmdChan <- &Command{ Client: c, Cmd: cmdString[:len(cmdString)-1], + }: + break } return true diff --git a/pkg/lock/lock_table.go b/pkg/lock/lock_table.go new file mode 100644 index 0000000..b131abe --- /dev/null +++ b/pkg/lock/lock_table.go @@ -0,0 +1,72 @@ +package lock + +import ( + "context" + "github.com/nitwhiz/omnilock/pkg/util" + "sync" + "time" +) + +type Table struct { + locks *util.Map[string, uint64] +} + +func NewTable() *Table { + return &Table{ + locks: util.NewMap[string, uint64](), + } +} + +func (t *Table) TryLock(name string, clientId uint64) bool { + return t.locks.TryPut(name, clientId) +} + +func (t *Table) Lock(ctx context.Context, name string, clientId uint64) bool { + if t.TryLock(name, clientId) { + return true + } + + result := false + + wg := &sync.WaitGroup{} + isRunning := false + + go func() { + wg.Add(1) + defer wg.Done() + + isRunning = true + + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Microsecond * 250): + if t.TryLock(name, clientId) { + result = true + return + } + + break + } + } + }() + + for !isRunning { + <-time.After(time.Millisecond) + } + + wg.Wait() + + return result +} + +func (t *Table) Unlock(name string, clientId uint64) bool { + return t.locks.RemoveIf(name, func(v uint64) bool { + return clientId == v + }) +} + +func (t *Table) UnlockAll(clientId uint64) { + t.locks.RemoveByValue(clientId) +} diff --git a/pkg/locking/lock_table.go b/pkg/locking/lock_table.go deleted file mode 100644 index 691f286..0000000 --- a/pkg/locking/lock_table.go +++ /dev/null @@ -1,67 +0,0 @@ -package locking - -import ( - "context" - "github.com/nitwhiz/omnilock/pkg/client" - "time" -) - -type LockTable struct { - locks *Table[string, uint64] -} - -func NewLockTable() *LockTable { - return &LockTable{ - locks: New[string, uint64](), - } -} - -func (t *LockTable) acquireLockWithContext(c *client.Client, name string, ctx context.Context) bool { - for { - if t.TryLock(c, name) { - return true - } - - select { - case <-ctx.Done(): - return false - case <-time.After(time.Millisecond): - break - } - } -} - -func (t *LockTable) Lock(c *client.Client, name string) bool { - return t.acquireLockWithContext(c, name, c.GetContext()) -} - -func (t *LockTable) LockWithTimeout(c *client.Client, name string, timeout time.Duration) bool { - lockCtx, cancel := context.WithTimeout(c.GetContext(), timeout) - defer cancel() - - return t.acquireLockWithContext(c, name, lockCtx) -} - -func (t *LockTable) TryLock(c *client.Client, name string) bool { - return t.locks.TryPut(name, c.GetID()) -} - -func (t *LockTable) Unlock(c *client.Client, name string) bool { - cID := c.GetID() - - return t.locks.RemoveIf(name, func(v uint64) bool { - return cID == v - }) -} - -func (t *LockTable) forceUnlock(name string) { - t.locks.Remove(name) -} - -func (t *LockTable) UnlockAllForClient(c *client.Client) { - t.locks.RemoveByValue(c.GetID()) -} - -func (t *LockTable) Count() int { - return t.locks.Len() -} diff --git a/pkg/locking/table.go b/pkg/locking/table.go deleted file mode 100644 index 5416d33..0000000 --- a/pkg/locking/table.go +++ /dev/null @@ -1,85 +0,0 @@ -package locking - -import "sync" - -type keyType interface { - uint64 | string -} - -type Table[K keyType, V comparable] struct { - mu *sync.RWMutex - entries map[K]V -} - -func New[K keyType, V comparable]() *Table[K, V] { - return &Table[K, V]{ - mu: &sync.RWMutex{}, - entries: map[K]V{}, - } -} - -func (t *Table[K, V]) TryPut(k K, v V) bool { - t.mu.Lock() - defer t.mu.Unlock() - - if _, ok := t.entries[k]; ok { - return false - } - - t.entries[k] = v - - return true -} - -func (t *Table[K, V]) Exists(k K) bool { - t.mu.RLock() - defer t.mu.RUnlock() - - _, ok := t.entries[k] - - return ok -} - -func (t *Table[K, V]) Remove(k K) bool { - t.mu.Lock() - defer t.mu.Unlock() - - if _, ok := t.entries[k]; ok { - delete(t.entries, k) - - return true - } - - return false -} - -func (t *Table[K, V]) RemoveIf(k K, callback func(v V) bool) bool { - t.mu.Lock() - defer t.mu.Unlock() - - if v, ok := t.entries[k]; ok && callback(v) { - delete(t.entries, k) - - return true - } - - return false -} - -func (t *Table[K, V]) RemoveByValue(v V) { - t.mu.Lock() - defer t.mu.Unlock() - - for k, tv := range t.entries { - if tv == v { - delete(t.entries, k) - } - } -} - -func (t *Table[K, V]) Len() int { - t.mu.RLock() - defer t.mu.RUnlock() - - return len(t.entries) -} diff --git a/pkg/server/server.go b/pkg/server/server.go index b3f801c..540a07e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -2,65 +2,45 @@ package server import ( "context" - "fmt" "github.com/nitwhiz/omnilock/pkg/client" - "github.com/nitwhiz/omnilock/pkg/locking" + "github.com/nitwhiz/omnilock/pkg/lock" + "log" "net" - "runtime" "sync" "time" ) type Server struct { - wg *sync.WaitGroup - keepAlive time.Duration - listenAddr string - listener net.Listener - acceptorCount int - ctx context.Context - LockTable *locking.LockTable - cmdChan chan *client.Command - cmdHandlers map[string]CommandHandler -} - -func (s *Server) applyDefaults() { - if s.keepAlive == 0 { - WithKeepAlivePeriod(time.Second * 5)(s) - } - - if s.acceptorCount == 0 { - WithAcceptorCount(runtime.NumCPU())(s) - } - - if s.listenAddr == "" { - WithListenAddr("0.0.0.0:7194")(s) - } + options *Options + listener net.Listener + wg *sync.WaitGroup + ctx context.Context + locks *lock.Table + cmdHandlers map[string]CommandHandler + cmdChan chan *client.Command } func New(ctx context.Context, opts ...Option) (*Server, error) { s := Server{ - wg: &sync.WaitGroup{}, - keepAlive: 0, - listenAddr: "", - listener: nil, - acceptorCount: 0, - ctx: ctx, - cmdChan: make(chan *client.Command), - cmdHandlers: map[string]CommandHandler{}, - LockTable: locking.NewLockTable(), + options: &Options{}, + wg: &sync.WaitGroup{}, + ctx: ctx, + locks: lock.NewTable(), + cmdHandlers: map[string]CommandHandler{}, + cmdChan: make(chan *client.Command), } for _, withOption := range opts { - withOption(&s) + withOption(s.options) } - s.applyDefaults() + s.options.applyDefaults() listenConfig := net.ListenConfig{ - KeepAlive: s.keepAlive, + KeepAlive: s.options.keepAlive, } - listener, err := listenConfig.Listen(s.ctx, "tcp", s.listenAddr) + listener, err := listenConfig.Listen(ctx, "tcp", s.options.listenAddr) if err != nil { return nil, err @@ -81,6 +61,10 @@ func (s *Server) acceptor() { conn, err := s.listener.Accept() if err != nil { + if s.ctx.Err() != nil { + return + } + continue } @@ -92,30 +76,47 @@ func (s *Server) handleConnection(conn net.Conn) { s.wg.Add(1) defer s.wg.Done() - c := client.New(s.ctx, conn, s.cmdChan) + ctx, cancel := context.WithCancel(s.ctx) - c.ListenForCommands() + c := client.New(ctx, conn, s.cmdChan) - s.LockTable.UnlockAllForClient(c) + clientId := c.GetID() + + defer log.Printf("Client #%d disonnected.\n", clientId) + + defer func() { + c.Lock() + defer c.Unlock() + + s.locks.UnlockAll(clientId) + }() + + defer cancel() + + log.Printf("Client #%d connected from %s.\n", clientId, conn.RemoteAddr().String()) + + c.ListenForCommands() } -func (s *Server) waitForShutdown() { +func (s *Server) wait() { <-s.ctx.Done() s.wg.Wait() } -func (s *Server) Accept() { +func (s *Server) Run() { go s.startCommandListener() - for i := 0; i < s.acceptorCount; i++ { + for i := 0; i < s.options.acceptorCount; i++ { go s.acceptor() } - fmt.Println("Ready!") + log.Println("Ready to serve connections") + + <-time.After(time.Millisecond) - s.waitForShutdown() + s.wait() } func (s *Server) Write(c *client.Client, msg string) { - _, _ = c.Write([]byte(msg + "\n")) + _, _ = c.Write([]byte(msg+"\n"), s.options.clientTimeout) } diff --git a/pkg/server/server_command.go b/pkg/server/server_command.go index ac10046..0f0ab74 100644 --- a/pkg/server/server_command.go +++ b/pkg/server/server_command.go @@ -16,6 +16,19 @@ func (s *Server) initCommandHandlers() { } func (s *Server) handleCommand(cmd *client.Command) { + s.wg.Add(1) + defer s.wg.Done() + + cmd.Client.Lock() + defer cmd.Client.Unlock() + + select { + case <-cmd.Client.Done(): + return + default: + break + } + argv := strings.Split(cmd.Cmd, " ") if len(argv) < 1 { @@ -50,13 +63,12 @@ func (s *Server) startCommandListener() { for { select { - case c := <-s.cmdChan: - s.handleCommand(c) - break case <-s.ctx.Done(): _ = s.listener.Close() - return + case c := <-s.cmdChan: + go s.handleCommand(c) + break } } } diff --git a/pkg/server/server_command_handler_lock.go b/pkg/server/server_command_handler_lock.go index 964a983..f934572 100644 --- a/pkg/server/server_command_handler_lock.go +++ b/pkg/server/server_command_handler_lock.go @@ -1,8 +1,10 @@ package server import ( + "context" "errors" "github.com/nitwhiz/omnilock/pkg/client" + "log" "strconv" "time" ) @@ -30,12 +32,19 @@ func LockHandler(s *Server) CommandHandler { timeout = time.Millisecond * time.Duration(timeoutInt) } + result = false + if timeout == 0 { - result = s.LockTable.Lock(c, lockName) + result = s.locks.Lock(s.ctx, lockName, c.GetID()) } else { - result = s.LockTable.LockWithTimeout(c, lockName, timeout) + ctx, cancel := context.WithTimeout(s.ctx, timeout) + defer cancel() + + result = s.locks.Lock(ctx, lockName, c.GetID()) } + log.Printf("Client #%d requested lock '%s': %v.\n", c.GetID(), lockName, result) + return result, nil } } diff --git a/pkg/server/server_command_handler_trylock.go b/pkg/server/server_command_handler_trylock.go index bc0467b..772b9c4 100644 --- a/pkg/server/server_command_handler_trylock.go +++ b/pkg/server/server_command_handler_trylock.go @@ -3,6 +3,7 @@ package server import ( "errors" "github.com/nitwhiz/omnilock/pkg/client" + "log" ) func TryLockHandler(s *Server) CommandHandler { @@ -11,6 +12,12 @@ func TryLockHandler(s *Server) CommandHandler { return false, errors.New("not enough arguments") } - return s.LockTable.TryLock(c, argv[0]), nil + lockName := argv[0] + + result = s.locks.TryLock(lockName, c.GetID()) + + log.Printf("Client #%d tried lock '%s': %v.\n", c.GetID(), lockName, result) + + return result, nil } } diff --git a/pkg/server/server_command_handler_unlock.go b/pkg/server/server_command_handler_unlock.go index a30f132..4b5ca1f 100644 --- a/pkg/server/server_command_handler_unlock.go +++ b/pkg/server/server_command_handler_unlock.go @@ -3,6 +3,7 @@ package server import ( "errors" "github.com/nitwhiz/omnilock/pkg/client" + "log" ) func UnlockHandler(s *Server) CommandHandler { @@ -11,6 +12,12 @@ func UnlockHandler(s *Server) CommandHandler { return false, errors.New("not enough arguments") } - return s.LockTable.Unlock(c, argv[0]), nil + lockName := argv[0] + + result = s.locks.Unlock(lockName, c.GetID()) + + log.Printf("Client #%d requested unlock '%s': %v.\n", c.GetID(), lockName, result) + + return result, nil } } diff --git a/pkg/server/server_option.go b/pkg/server/server_option.go deleted file mode 100644 index fb9c733..0000000 --- a/pkg/server/server_option.go +++ /dev/null @@ -1,25 +0,0 @@ -package server - -import ( - "time" -) - -type Option func(*Server) - -func WithKeepAlivePeriod(p time.Duration) Option { - return func(s *Server) { - s.keepAlive = p - } -} - -func WithAcceptorCount(acceptorCount int) Option { - return func(s *Server) { - s.acceptorCount = acceptorCount - } -} - -func WithListenAddr(listenAddr string) Option { - return func(s *Server) { - s.listenAddr = listenAddr - } -} diff --git a/pkg/server/server_options.go b/pkg/server/server_options.go new file mode 100644 index 0000000..6151dd4 --- /dev/null +++ b/pkg/server/server_options.go @@ -0,0 +1,57 @@ +package server + +import ( + "runtime" + "time" +) + +type Options struct { + keepAlive time.Duration + acceptorCount int + listenAddr string + clientTimeout time.Duration +} + +type Option func(*Options) + +func (o *Options) applyDefaults() { + if o.keepAlive == 0 { + WithKeepAlivePeriod(time.Second * 5)(o) + } + + if o.acceptorCount == 0 { + WithAcceptorCount(runtime.NumCPU())(o) + } + + if o.listenAddr == "" { + WithListenAddr("0.0.0.0:7194")(o) + } + + if o.clientTimeout == 0 { + WithClientTimeout(time.Second * 5)(o) + } +} + +func WithKeepAlivePeriod(p time.Duration) Option { + return func(o *Options) { + o.keepAlive = p + } +} + +func WithAcceptorCount(acceptorCount int) Option { + return func(o *Options) { + o.acceptorCount = acceptorCount + } +} + +func WithListenAddr(listenAddr string) Option { + return func(o *Options) { + o.listenAddr = listenAddr + } +} + +func WithClientTimeout(clientTimeout time.Duration) Option { + return func(o *Options) { + o.clientTimeout = clientTimeout + } +} diff --git a/pkg/util/util_map.go b/pkg/util/util_map.go new file mode 100644 index 0000000..862f0bb --- /dev/null +++ b/pkg/util/util_map.go @@ -0,0 +1,86 @@ +package util + +import ( + "sync" +) + +type keyType interface { + uint64 | string +} + +// Map is a goroutine-safe map +type Map[K keyType, V comparable] struct { + mu *sync.RWMutex + entries map[K]V +} + +func NewMap[K keyType, V comparable]() *Map[K, V] { + return &Map[K, V]{ + mu: &sync.RWMutex{}, + entries: map[K]V{}, + } +} + +func (m *Map[K, V]) TryPut(k K, v V) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.entries[k]; ok { + return false + } + + m.entries[k] = v + + return true +} + +func (m *Map[K, V]) Exists(k K) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + _, ok := m.entries[k] + + return ok +} + +func (m *Map[K, V]) Remove(k K) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.entries[k]; ok { + delete(m.entries, k) + return true + } + + return false +} + +func (m *Map[K, V]) RemoveIf(k K, callback func(v V) bool) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if v, ok := m.entries[k]; ok && callback(v) { + delete(m.entries, k) + return true + } + + return false +} + +func (m *Map[K, V]) RemoveByValue(v V) { + m.mu.Lock() + defer m.mu.Unlock() + + for k, tv := range m.entries { + if tv == v { + delete(m.entries, k) + } + } +} + +func (m *Map[K, V]) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + + return len(m.entries) +} diff --git a/test/integration/server_lock_test.go b/test/integration/server_lock_test.go index ec539d8..a7cbfea 100644 --- a/test/integration/server_lock_test.go +++ b/test/integration/server_lock_test.go @@ -3,14 +3,18 @@ package integration import ( "bufio" "net" + "sync" "testing" "time" ) func TestLock(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() conn := connect(t, serverAddr) @@ -47,9 +51,12 @@ func TestLock(t *testing.T) { } func TestLockWithTimeoutAndLockoutSameConnection(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() conn := connect(t, serverAddr) @@ -115,9 +122,12 @@ func TestLockWithTimeoutAndLockoutSameConnection(t *testing.T) { } func TestLockWithTimeoutAndLockoutDifferentConnection(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() // try 1 - success @@ -189,3 +199,90 @@ func TestLockWithTimeoutAndLockoutDifferentConnection(t *testing.T) { t.Fatalf("response mismatch, expected \"%s\", got \"%s\"", expected, recvEscaped) } } + +func TestLockWithTimeoutAndConcurrency(t *testing.T) { + wg, serverAddr, cancel := startServer(t) + + defer func() { + cancel() + wg.Wait() + }() + + wg2 := &sync.WaitGroup{} + + go func() { + wg2.Add(1) + + conn1 := connect(t, serverAddr) + + defer func(conn *net.TCPConn) { + _ = conn.Close() + wg2.Done() + }(conn1) + + _, err := conn1.Write([]byte("lock test1\n")) + + if err != nil { + t.Error(err) + return + } + + r1 := bufio.NewReader(conn1) + + recv, err := r1.ReadString('\n') + + if err != nil { + t.Error(err) + return + } + + recvEscaped := escapeString(recv) + expected := "success\\x0A" + + if recvEscaped != expected { + t.Errorf("response mismatch, expected \"%s\", got \"%s\"", expected, recvEscaped) + return + } + + time.Sleep(time.Millisecond * 250) + }() + + go func() { + wg2.Add(1) + + conn1 := connect(t, serverAddr) + + defer func(conn *net.TCPConn) { + _ = conn.Close() + wg2.Done() + }(conn1) + + _, err := conn1.Write([]byte("lock test1 500\n")) + + if err != nil { + t.Error(err) + return + } + + r1 := bufio.NewReader(conn1) + + recv, err := r1.ReadString('\n') + + if err != nil { + t.Error(err) + return + } + + recvEscaped := escapeString(recv) + expected := "success\\x0A" + + if recvEscaped != expected { + t.Errorf("response mismatch, expected \"%s\", got \"%s\"", expected, recvEscaped) + return + } + }() + + time.Sleep(time.Millisecond) + + wg2.Wait() +} diff --git a/test/integration/server_trylock_test.go b/test/integration/server_trylock_test.go index 047848e..2fd7df0 100644 --- a/test/integration/server_trylock_test.go +++ b/test/integration/server_trylock_test.go @@ -8,9 +8,12 @@ import ( ) func TestTryLock(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() conn := connect(t, serverAddr) @@ -47,9 +50,12 @@ func TestTryLock(t *testing.T) { } func TestTryLockAndLockoutSameConnection(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() conn := connect(t, serverAddr) @@ -115,9 +121,12 @@ func TestTryLockAndLockoutSameConnection(t *testing.T) { } func TestTryLockAndLockoutDifferentConnection(t *testing.T) { - serverAddr, cancel := startServer(t) + wg, serverAddr, cancel := startServer(t) - defer cancel() + defer func() { + cancel() + wg.Wait() + }() // try 1 - success diff --git a/test/integration/test_helpers.go b/test/integration/test_helpers.go index b39558a..7645f0f 100644 --- a/test/integration/test_helpers.go +++ b/test/integration/test_helpers.go @@ -6,6 +6,7 @@ import ( "github.com/nitwhiz/omnilock/pkg/server" "net" "strings" + "sync" "testing" ) @@ -19,7 +20,9 @@ func connect(t *testing.T, tcpServer *net.TCPAddr) *net.TCPConn { return conn } -func startServer(t *testing.T) (*net.TCPAddr, context.CancelFunc) { +func startServer(t *testing.T) (*sync.WaitGroup, *net.TCPAddr, context.CancelFunc) { + wg := &sync.WaitGroup{} + ctx, cancel := context.WithCancel(context.Background()) s, err := server.New(ctx, server.WithListenAddr("localhost:3000")) @@ -28,11 +31,16 @@ func startServer(t *testing.T) (*net.TCPAddr, context.CancelFunc) { t.Fatal(err) } - go s.Accept() + go func() { + wg.Add(1) + defer wg.Done() + + s.Run() + }() tcpServer, err := net.ResolveTCPAddr("tcp", "localhost:3000") - return tcpServer, cancel + return wg, tcpServer, cancel } func escapeString(input string) string {