diff --git a/xload/async.go b/xload/async.go index 41092dc..00a4e8d 100644 --- a/xload/async.go +++ b/xload/async.go @@ -7,8 +7,30 @@ import ( "github.com/sourcegraph/conc/pool" ) +func processConcurrently(ctx context.Context, v any, opts *options) error { + doneCh := make(chan struct{}, 1) + defer close(doneCh) + + p := pool.New().WithErrors().WithMaxGoroutines(opts.concurrency).WithContext(ctx).WithCancelOnError() + + err := processAsync(p, opts, opts.loader, v, func() { + doneCh <- struct{}{} + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-doneCh: + return err + } +} + //nolint:funlen,nestif -func processAsync(ctx context.Context, obj any, o *options, loader Loader) error { +func processAsync(p *pool.ContextPool, o *options, loader Loader, obj any, cb func()) error { + if cb != nil { + defer cb() + } + v := reflect.ValueOf(obj) if v.Kind() != reflect.Ptr { @@ -22,8 +44,6 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error typ := value.Type() - p := pool.New().WithErrors().WithMaxGoroutines(o.concurrency) - for i := 0; i < typ.NumField(); i++ { fTyp := typ.Field(i) fVal := value.Field(i) @@ -79,7 +99,9 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error // if the struct has a key, load it // and set the value to the struct if meta.name != "" && hasDecoder(fVal) { - loadAndSet := func(original reflect.Value, fVal reflect.Value, isNilStructPtr bool) error { + loadAndSet := func( + ctx context.Context, original reflect.Value, fVal reflect.Value, isNilStructPtr bool, + ) error { val, err := loader.Load(ctx, meta.name) if err != nil { return err @@ -102,9 +124,7 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error original := value.Field(i) - p.Go(func() error { - return loadAndSet(original, fVal, isNilStructPtr) - }) + p.Go(func(ctx context.Context) error { return loadAndSet(ctx, original, fVal, isNilStructPtr) }) continue } @@ -114,7 +134,7 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error pld = PrefixLoader(meta.prefix, loader) } - err := processAsync(ctx, fVal.Interface(), o, pld) + err := processAsync(p, o, pld, fVal.Interface(), nil) if err != nil { return err } @@ -128,7 +148,7 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error return ErrInvalidPrefix } - loadAndSet := func(fVal reflect.Value) error { + loadAndSet := func(ctx context.Context, fVal reflect.Value) error { // lookup value val, err := loader.Load(ctx, meta.name) if err != nil { @@ -148,9 +168,11 @@ func processAsync(ctx context.Context, obj any, o *options, loader Loader) error return nil } - p.Go(func() error { - return loadAndSet(fVal) - }) + p.Go(func(ctx context.Context) error { return loadAndSet(ctx, fVal) }) + } + + if cb == nil { + return nil } return p.Wait() diff --git a/xload/load.go b/xload/load.go index 5e0a130..32b6cb4 100644 --- a/xload/load.go +++ b/xload/load.go @@ -50,7 +50,7 @@ func Load(ctx context.Context, v any, opts ...Option) error { o := newOptions(opts...) if o.concurrency > 1 { - return processAsync(ctx, v, o, o.loader) + return processConcurrently(ctx, v, o) } return process(ctx, v, o.tagName, o.loader)