diff --git a/go.mod b/go.mod index 61f948a..f9d6865 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.13 require ( github.com/go-test/deep v1.0.7 + github.com/joeycumines/go-bigbuff v1.14.0 github.com/xlab/treeprint v1.1.0 ) diff --git a/go.sum b/go.sum index f06efe3..1c15462 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,10 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M= github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= +github.com/joeycumines/go-bigbuff v1.14.0 h1:nHql/X/YMUrV7sMQ9w0+H9T8vj0YL15lT/yOr7U0HpE= +github.com/joeycumines/go-bigbuff v1.14.0/go.mod h1:7hqtGnMDT3v+yOvHUb+hx+JSxhlTF2W9BGIkvNQizaA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/manager.go b/manager.go index e644a49..40a2b7c 100644 --- a/manager.go +++ b/manager.go @@ -18,7 +18,7 @@ package behaviortree import ( "errors" - "strings" + "github.com/joeycumines/go-bigbuff" "sync" ) @@ -33,12 +33,29 @@ type ( // manager is this package's implementation of the Manager interface manager struct { - mutex sync.Mutex - tickers []Ticker - errs []string - stopped bool + mu sync.RWMutex + once sync.Once + worker bigbuff.Worker done chan struct{} + stop chan struct{} + tickers chan managerTicker + errs []error } + + managerTicker struct { + Ticker Ticker + Done func() + } + + errManagerTicker []error + + errManagerStopped struct{ error } +) + +var ( + // ErrManagerStopped is returned by the manager implementation in this package (see also NewManager) in the case + // that Manager.Add is attempted after the manager has already started to stop. Use errors.Is to check this case. + ErrManagerStopped error = errManagerStopped{error: errors.New(`behaviortree.Manager.Add already stopped`)} ) // NewManager will construct an implementation of the Manager interface, which is a stateful set of Ticker @@ -47,9 +64,15 @@ type ( // // Note that any error (of any registered tickers) will also trigger stopping, and stopping will prevent further // Add calls from succeeding. +// +// As of v1.8.0, any (combined) ticker error returned by the Manager can now support error chaining (i.e. the use of +// errors.Is). Note that errors.Unwrap isn't supported, since there may be more than one. See also Manager.Err and +// Manager.Add. func NewManager() Manager { result := &manager{ - done: make(chan struct{}), + done: make(chan struct{}), + stop: make(chan struct{}), + tickers: make(chan managerTicker), } return result } @@ -59,83 +82,113 @@ func (m *manager) Done() <-chan struct{} { } func (m *manager) Err() error { - m.mutex.Lock() - defer m.mutex.Unlock() - return m.err() + m.mu.RLock() + defer m.mu.RUnlock() + if len(m.errs) != 0 { + return errManagerTicker(m.errs) + } + return nil } func (m *manager) Stop() { - m.mutex.Lock() - defer m.mutex.Unlock() - m.stopOnce() + m.once.Do(func() { + close(m.stop) + m.start()() + }) } func (m *manager) Add(ticker Ticker) error { if ticker == nil { return errors.New("behaviortree.Manager.Add nil ticker") } - m.mutex.Lock() - defer m.mutex.Unlock() - m.check() - if m.stopped { - if err := m.err(); err != nil { - return err + done := m.start() + select { + case <-m.stop: + default: + select { + case <-m.stop: + case m.tickers <- managerTicker{ + Ticker: ticker, + Done: done, + }: + return nil } - return errors.New("behaviortree.Manager.Add already stopped") } - m.tickers = append(m.tickers, ticker) - go func() { - <-ticker.Done() - m.mutex.Lock() - defer m.mutex.Unlock() - m.check() - }() - return nil + done() + if err := m.Err(); err != nil { + return errManagerStopped{error: err} + } + return ErrManagerStopped } -func (m *manager) err() error { - if len(m.errs) != 0 { - return errors.New(strings.Join(m.errs, " | ")) +func (m *manager) start() (done func()) { return m.worker.Do(m.run) } + +func (m *manager) run(stop <-chan struct{}) { + for { + select { + case <-stop: + select { + case <-m.stop: + select { + case <-m.done: + default: + close(m.done) + } + default: + } + return + case t := <-m.tickers: + go m.handle(t) + } } - return nil } -func (m *manager) stopOnce() { - if !m.stopped { - m.stopped = true - go m.cleanup() +func (m *manager) handle(t managerTicker) { + select { + case <-t.Ticker.Done(): + // note: this stop shouldn't be necessary, but has been retained for + // consistency, with the previous implementation) + t.Ticker.Stop() + case <-m.stop: + t.Ticker.Stop() + <-t.Ticker.Done() + } + if err := t.Ticker.Err(); err != nil { + m.mu.Lock() + m.errs = append(m.errs, err) + m.mu.Unlock() + m.Stop() } + t.Done() } -func (m *manager) finish(i int) { - m.tickers[i].Stop() - <-m.tickers[i].Done() - if err := m.tickers[i].Err(); err != nil { - m.errs = append(m.errs, err.Error()) - m.stopOnce() +func (e errManagerTicker) Error() string { + var b []byte + for i, err := range e { + if i != 0 { + b = append(b, ' ', '|', ' ') + } + b = append(b, err.Error()...) } - m.tickers[i] = m.tickers[len(m.tickers)-1] - m.tickers[len(m.tickers)-1] = nil - m.tickers = m.tickers[:len(m.tickers)-1] + return string(b) } -func (m *manager) check() { - for i := 0; i < len(m.tickers); i++ { - select { - case <-m.tickers[i].Done(): - m.finish(i) - i-- - default: +func (e errManagerTicker) Is(target error) bool { + for _, err := range e { + if errors.Is(err, target) { + return true } } + return false } -func (m *manager) cleanup() { - m.mutex.Lock() - for i := 0; i < len(m.tickers); i++ { - m.finish(i) - i-- +func (e errManagerStopped) Unwrap() error { return e.error } + +func (e errManagerStopped) Is(target error) bool { + switch target.(type) { + case errManagerStopped: + return true + default: + return false } - close(m.done) - m.mutex.Unlock() } diff --git a/manager_test.go b/manager_test.go index 1968e8f..e84aec9 100644 --- a/manager_test.go +++ b/manager_test.go @@ -20,19 +20,175 @@ import ( "errors" "runtime" "sync" + "sync/atomic" "testing" "time" ) -func TestNewManager(t *testing.T) { - startGoroutines := runtime.NumGoroutine() +func TestManager_Stop_raceCloseDone(t *testing.T) { + defer checkGoroutines(t)(false, time.Millisecond*100) + m := NewManager().(*manager) + close(m.done) + m.Stop() +} + +func TestManager_Stop_noTickers(t *testing.T) { + defer checkGoroutines(t)(false, time.Millisecond*100) + m := NewManager() + if err := m.Err(); err != nil { + t.Error(err) + } + select { + case <-m.Done(): + t.Error() + default: + } + m.Stop() + if err := m.Err(); err != nil { + t.Error(err) + } + <-m.Done() + if err := m.Err(); err != nil { + t.Error(err) + } +} + +func TestManager_Add_whileStopping(t *testing.T) { + defer checkGoroutines(t)(false, time.Millisecond*100) + m := NewManager() + for i := 0; i < 10; i++ { + if err := m.Add(NewManager()); err != nil { + t.Fatal(err) + } + } + var ( + wg sync.WaitGroup + count int64 + done int32 + stopped int32 + ) + wg.Add(8) defer func() { - time.Sleep(time.Millisecond * 100) - endGoroutines := runtime.NumGoroutine() - if startGoroutines < endGoroutines { - t.Errorf("ended with %d more goroutines", endGoroutines-startGoroutines) + atomic.AddInt64(&count, -atomic.LoadInt64(&count)) + m.Stop() + atomic.StoreInt32(&stopped, 1) + time.Sleep(time.Millisecond * 50) + atomic.StoreInt32(&done, 1) + wg.Wait() + <-m.Done() + if err := m.Err(); err != nil { + t.Error(err) + } + count := atomic.LoadInt64(&count) + t.Log(count) + if count < 15 { + t.Error(count) } }() + for i := 0; i < 8; i++ { + go func() { + defer wg.Done() + for atomic.LoadInt32(&done) == 0 { + var ( + stoppedBefore = atomic.LoadInt32(&stopped) != 0 + err = m.Add(NewManager()) + stoppedAfter = atomic.LoadInt32(&stopped) != 0 + ) + if err != nil && err != ErrManagerStopped { + t.Error(err) + } + if stoppedBefore && !stoppedAfter { + t.Error() + } + if stoppedBefore && err == nil { + t.Error() + } + atomic.AddInt64(&count, 1) + time.Sleep(time.Millisecond) + } + }() + } + time.Sleep(time.Millisecond * 20) +} + +func TestManager_Add_secondStopCase(t *testing.T) { + defer checkGoroutines(t)(false, time.Millisecond*100) + out := make(chan error) + defer close(out) + done := make(chan struct{}) + stop := make(chan struct{}) + go func() { out <- (&manager{stop: stop, done: done}).Add(mockTicker{}) }() + time.Sleep(time.Millisecond * 100) + select { + case err := <-out: + t.Fatal(err) + default: + } + close(stop) + if err := <-out; err != ErrManagerStopped { + t.Error(err) + } + <-done +} + +func TestManager_Stop_cleanupGoroutines(t *testing.T) { + check := checkGoroutines(t) + defer check(false, time.Millisecond*100) + + m := NewManager() + + { + // add a ticker then stop it, then verify that all resources (goroutines) are cleaned up + check(false, 0) + done := make(chan struct{}) + if err := m.Add(mockTicker{ + done: func() <-chan struct{} { return done }, + err: func() error { return nil }, + stop: func() {}, + }); err != nil { + t.Fatal(err) + } + check(true, 0) + close(done) + check(false, time.Millisecond*50) + if err := m.Err(); err != nil { + t.Error(err) + } + select { + case <-m.Done(): + t.Fatal() + default: + } + } + + { + // add two tickers (one multiple times), then stop one, then the other + var ( + m1 = NewManager() + m2 = NewManager() + ) + check(false, 0) + for i := 0; i < 30; i++ { + if err := m.Add(m1); err != nil { + t.Fatal(err) + } + } + if err := m.Add(m2); err != nil { + t.Fatal(err) + } + check(true, time.Millisecond*50) + gr := runtime.NumGoroutine() + m1.Stop() + check(true, time.Millisecond*50) + if diff := runtime.NumGoroutine() - gr + 1; diff > 0 { + t.Errorf("too many goroutines: +%d", diff) + } + m2.Stop() + } +} + +func TestNewManager(t *testing.T) { + defer checkGoroutines(t)(false, time.Millisecond*100) m := NewManager().(*manager) @@ -96,8 +252,21 @@ func TestNewManager(t *testing.T) { t.Fatal(err) } - if len(m.tickers) != 2 || len(m.errs) != 0 || m.stopped { - t.Fatal(m) + if d := m.Done(); d != m.done || d == nil { + t.Error(d) + } + if err := m.Err(); err != nil { + t.Error(err) + } + select { + case <-m.stop: + t.Error() + default: + } + select { + case <-m.done: + t.Error() + default: } mutex.Lock() @@ -127,22 +296,58 @@ func TestNewManager(t *testing.T) { <-m.Done() - if len(m.tickers) != 0 { - t.Error(m.tickers) + checkErrTicker := func(err error) { + t.Helper() + if err == nil || err.Error() != "some_error | other_error" { + t.Error(err) + } + if !errors.Is(err, err1) { + t.Error(err) + } + if !errors.Is(err, err2) { + t.Error(err) + } + if errors.Is(err, errors.New(`another_error`)) { + t.Error(err) + } + { + err := err + if v, ok := err.(errManagerStopped); ok { + err = v.Unwrap() + } + if v, ok := err.(errManagerTicker); !ok || len(v) != 2 { + t.Error(err) + } + } + } + checkErrStopped := func(err error) { + t.Helper() + if !errors.Is(err.(errManagerStopped), ErrManagerStopped.(errManagerStopped)) || + !(errManagerStopped{}).Is(err) || + !err.(interface{ Is(error) bool }).Is(errManagerStopped{}) { + t.Error(err) + } } - if err := m.Err(); err == nil || err.Error() != "some_error | other_error" { - t.Error(err) + checkErrTicker(m.Err()) + { + err := m.Add(mockTicker{}) + checkErrTicker(err) + checkErrStopped(err) } // does nothing m.Stop() - if err := m.Add(mockTicker{}); err == nil { - t.Error("expected error") + checkErrTicker(m.Err()) + { + err := m.Add(mockTicker{}) + checkErrTicker(err) + checkErrStopped(err) } + m.errs = nil - if err := m.Add(mockTicker{}); err == nil { + if err := m.Add(mockTicker{}); err != ErrManagerStopped { t.Error("expected error") } if err := m.Add(nil); err == nil { @@ -177,3 +382,19 @@ func (m mockTicker) Stop() { } panic("implement me") } + +func checkGoroutines(t *testing.T) func(increase bool, wait time.Duration) { + t.Helper() + start := runtime.NumGoroutine() + return func(increase bool, wait time.Duration) { + t.Helper() + time.Sleep(wait) + if now := runtime.NumGoroutine(); increase { + if start >= now { + t.Errorf("too few goroutines: -%d", start-now+1) + } + } else if start < now { + t.Errorf("too many goroutines: +%d", now-start) + } + } +} diff --git a/printer_test.go b/printer_test.go index 77a5170..9b30b5a 100644 --- a/printer_test.go +++ b/printer_test.go @@ -234,7 +234,9 @@ func TestDefaultPrinterInspector_noName(t *testing.T) { func TestTreePrinter_Fprint_emptyMeta(t *testing.T) { p := TreePrinter{ - Inspector: func(node Node, tick Tick) (meta []interface{}, value interface{}) { return []interface{}{``, ``, ``}, `` }, + Inspector: func(node Node, tick Tick) (meta []interface{}, value interface{}) { + return []interface{}{``, ``, ``}, `` + }, Formatter: DefaultPrinterFormatter, } b := new(bytes.Buffer)