diff --git a/parallel/parallel.go b/parallel/parallel.go index 5abf81c..7eb5a0f 100644 --- a/parallel/parallel.go +++ b/parallel/parallel.go @@ -35,20 +35,24 @@ func Run[T any](ctx context.Context, limit uint, runners []Runner[T]) <-chan *Re ctx = context.Background() } + //go func() { + // select { + // case <-ctx.Done(): + // err.Store(ctx.Err()) + // case <-done: + // return + // } + //}() + + go run[T](ctx, limit, runners, results) + + return results +} +func run[T any](ctx context.Context, limit uint, runners []Runner[T], results chan *Result[T]) { err := atomic.NewError(nil) wg := sync.WaitGroup{} sem := make(chan struct{}, limit) done := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - err.Store(ctx.Err()) - case <-done: - return - } - }() - for _, v := range runners { sem <- struct{}{} if err.Load() != nil { @@ -85,8 +89,6 @@ func Run[T any](ctx context.Context, limit uint, runners []Runner[T]) <-chan *Re close(results) close(sem) }() - - return results } func Wait[T any](results <-chan *Result[T], handler func(v T) error) error { diff --git a/parallel/parallel_test.go b/parallel/parallel_test.go index 0a5df7e..4a28719 100644 --- a/parallel/parallel_test.go +++ b/parallel/parallel_test.go @@ -20,15 +20,17 @@ func TestRun1(t *testing.T) { v := i runners = append(runners, func(ctx context.Context) (result int, err error) { result = v + time.Sleep(time.Duration(v) * time.Second) return }) } values := parallel.Run[int](context.Background(), 2, runners) - + t.Log("begin") n := 0 err := parallel.Wait[int](values, func(v int) error { n += v + t.Log(v) return nil }) require.NoError(t, err)