From 55a96b3e3dd696ff1b183fa39e21191cfa8c828d Mon Sep 17 00:00:00 2001 From: koyeo Date: Tue, 29 Aug 2023 17:10:34 +0800 Subject: [PATCH] feat: update parallel --- go.mod | 2 +- go.sum | 2 + parallel/parallel.go | 108 +++++++++++++++++++++++-------------------- 3 files changed, 62 insertions(+), 50 deletions(-) diff --git a/go.mod b/go.mod index 440f6f5..b242c4b 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/gozelle/async go 1.18 require ( - github.com/gozelle/multierr v1.9.10 + github.com/gozelle/multierr v0.0.1 github.com/gozelle/testify v1.8.11 ) diff --git a/go.sum b/go.sum index ba13bf2..07fbfd1 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/gozelle/go-difflib v1.0.0/go.mod h1:PcDy306aUy3I8kB5ohutXpbhkgQlygtcB github.com/gozelle/go-internal v1.9.0 h1:VVobf+WAesfyYtHklxbmzo+2IXdTw3PxiwkmwC6QTTA= github.com/gozelle/go-spew v1.1.10 h1:kFDq5IDd/Wk3l7UutORhveCt5fUdW8b9afnDYtuHmEc= github.com/gozelle/go-spew v1.1.10/go.mod h1:qwUFQBhiE5zwtTAfe/m87+73t2jRbNHwGOvs+FgYx+8= +github.com/gozelle/multierr v0.0.1 h1:JvYYvGTGusBS6IEOTOmoroY4SqW5PU2S29Yi8T4UvnE= +github.com/gozelle/multierr v0.0.1/go.mod h1:Kd/mRKyMcPywI5eYDaMeiecSeRhgAsu8nIx35ic5NvE= github.com/gozelle/multierr v1.9.10 h1:EUU22u5Yx82/mQm55fH1WF6ztp0w4U+cZl9lCzuyUvE= github.com/gozelle/multierr v1.9.10/go.mod h1:4tC7qdet8CoxU9Q/Ha2bn5obpheqU7bTB9AENXwWzaI= github.com/gozelle/pretty v0.3.1 h1:dtU7yIlzRqiMmB9TcunypDfvAC/QOP0K8y/BKKOrgRg= diff --git a/parallel/parallel.go b/parallel/parallel.go index 7eb5a0f..4248a49 100644 --- a/parallel/parallel.go +++ b/parallel/parallel.go @@ -3,7 +3,7 @@ package parallel import ( "context" "fmt" - "github.com/gozelle/atomic" + "github.com/gozelle/multierr" "runtime/debug" "sync" @@ -21,74 +21,84 @@ type Runner[T any] async.Runner[T] func Run[T any](ctx context.Context, limit uint, runners []Runner[T]) <-chan *Result[T] { - results := make(chan *Result[T], len(runners)) + ch := make(chan *Result[T], len(runners)) if limit == 0 { defer func() { - results <- &Result[T]{Error: fmt.Errorf("limit expect great than 0")} - close(results) + ch <- &Result[T]{Error: fmt.Errorf("limit expect great than 0")} + close(ch) }() - return results + return ch } if ctx == nil { ctx = context.Background() } - //go func() { - // select { - // case <-ctx.Done(): - // err.Store(ctx.Err()) - // case <-done: - // return - // } - //}() + go run[T](ctx, limit, runners, ch) - go run[T](ctx, limit, runners, results) - - return results + return ch } -func run[T any](ctx context.Context, limit uint, runners []Runner[T], results chan *Result[T]) { - err := atomic.NewError(nil) +func run[T any](ctx context.Context, limit uint, runners []Runner[T], ch chan *Result[T]) { + + errs := multierr.Errors{} wg := sync.WaitGroup{} sem := make(chan struct{}, limit) - done := make(chan struct{}) + + defer func() { + close(ch) + close(sem) + }() + for _, v := range runners { + + // achieve a blocking effect by sending semaphores to a channel with a specified capacity of "limit" + // when the channel is full, it will block here until a task is completed and frees up channel capacity sem <- struct{}{} - if err.Load() != nil { - <-sem - continue - } - wg.Add(1) - go func(runner Runner[T]) { - defer func() { - e := recover() - if e != nil { - err.Store(fmt.Errorf("%v", e)) - debug.PrintStack() - } + + // if the semaphore is acquired, prioritize checking whether the context has done. + // if it has, break out of the for loop. + select { + case <-ctx.Done(): + errs.AddError(ctx.Err()) + break + default: + // when an error occurs, the semaphores of all subsequent tasks will be directly ignored. + if errs.Error() != nil { <-sem - wg.Done() - }() - - r, e := runner(ctx) - if e != nil { - err.Store(e) - } else { - results <- &Result[T]{Value: r} + continue } - }(v) + wg.Add(1) + go func(runner Runner[T]) { + defer func() { + e := recover() + if e != nil { + errs.AddError(fmt.Errorf("%v", e)) + debug.PrintStack() + } + // the task has been executed to completion, + // release the semaphore. + <-sem + wg.Done() + }() + + r, e := runner(ctx) + if e != nil { + errs.AddError(e) + } else { + ch <- &Result[T]{Value: r} + } + }(v) + } } - go func() { - wg.Wait() - if err.Load() != nil { - results <- &Result[T]{Error: err.Load()} - } - close(done) - close(results) - close(sem) - }() + wg.Wait() + + // all tasks have been completed. + // check for any errors and ensure that the error is the last result sent to the channel. + if errs.Error() != nil { + ch <- &Result[T]{Error: errs.Error()} + } } func Wait[T any](results <-chan *Result[T], handler func(v T) error) error {