Skip to content

Commit

Permalink
Pass a context to the calling function for cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
janos committed May 27, 2021
1 parent d1b997f commit f856ab8
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 20 deletions.
37 changes: 27 additions & 10 deletions singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,43 @@ type Group struct {
mu sync.Mutex // protects calls
}

// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// Passed context terminates the execution of Do function, not the passed
// function fn. If there are multiple callers, context passed to one caller
// does not effect the execution and returned values of others.
// Do executes and returns the results of the given function, making sure that
// only one execution is in-flight for a given key at a time. If a duplicate
// comes in, the duplicate caller waits for the original to complete and
// receives the same results.
//
// The context passed to the fn function is a new context which is canceled when
// contexts from all callers are canceled, so that no caller is expecting the
// result. If there are multiple callers, context passed to one caller does not
// effect the execution and returned values of others.
//
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(ctx context.Context, key string, fn func() (interface{}, error)) (v interface{}, shared bool, err error) {
func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, shared bool, err error) {
g.mu.Lock()
if g.calls == nil {
g.calls = make(map[string]*call)
}

if c, ok := g.calls[key]; ok {
c.shared = true
c.counter++
g.mu.Unlock()

return g.wait(ctx, key, c)
}

callCtx, cancel := context.WithCancel(context.Background())

c := &call{
done: make(chan struct{}),
done: make(chan struct{}),
cancel: cancel,
counter: 1,
}
g.calls[key] = c
g.mu.Unlock()

go func() {
c.val, c.err = fn()
c.val, c.err = fn(callCtx)
close(c.done)
}()

Expand All @@ -65,6 +73,10 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s
err = ctx.Err()
}
g.mu.Lock()
c.counter--
if c.counter == 0 {
c.cancel()
}
if !c.forgotten {
delete(g.calls, key)
}
Expand Down Expand Up @@ -99,4 +111,9 @@ type call struct {

// shared indicates if results val and err are passed to multiple callers.
shared bool

// Number of callers that are waiting for the result.
counter int
// Cancel function for the context passed to the executing function.
cancel context.CancelFunc
}
157 changes: 147 additions & 10 deletions singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
package singleflight_test

import (
"bytes"
"context"
"errors"
"fmt"
"runtime/pprof"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand All @@ -21,7 +25,7 @@ func TestDo(t *testing.T) {
var g singleflight.Group

want := "val"
got, shared, err := g.Do(context.Background(), "key", func() (interface{}, error) {
got, shared, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
return want, nil
})
if err != nil {
Expand All @@ -38,7 +42,7 @@ func TestDo(t *testing.T) {
func TestDo_error(t *testing.T) {
var g singleflight.Group
wantErr := errors.New("test error")
got, _, err := g.Do(context.Background(), "key", func() (interface{}, error) {
got, _, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
return nil, wantErr
})
if err != wantErr {
Expand All @@ -64,7 +68,7 @@ func TestDo_multipleCalls(t *testing.T) {
for i := 0; i < n; i++ {
go func(i int) {
defer wg.Done()
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func() (interface{}, error) {
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
atomic.AddInt32(&counter, 1)
time.Sleep(100 * time.Millisecond)
return want, nil
Expand Down Expand Up @@ -95,7 +99,7 @@ func TestDo_callRemoval(t *testing.T) {

wantPrefix := "val"
counter := 0
fn := func() (interface{}, error) {
fn := func(_ context.Context) (interface{}, error) {
counter++
return wantPrefix + strconv.Itoa(counter), nil
}
Expand Down Expand Up @@ -124,6 +128,9 @@ func TestDo_callRemoval(t *testing.T) {
}

func TestDo_cancelContext(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group

want := "val"
Expand All @@ -133,8 +140,11 @@ func TestDo_cancelContext(t *testing.T) {
cancel()
}()
start := time.Now()
got, shared, err := g.Do(ctx, "key", func() (interface{}, error) {
time.Sleep(time.Second)
got, shared, err := g.Do(ctx, "key", func(_ context.Context) (interface{}, error) {
select {
case <-time.After(time.Second):
case <-done:
}
return want, nil
})
if d := time.Since(start); d < 100*time.Microsecond || d > time.Second {
Expand All @@ -152,11 +162,17 @@ func TestDo_cancelContext(t *testing.T) {
}

func TestDo_cancelContextSecond(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group

want := "val"
fn := func() (interface{}, error) {
time.Sleep(time.Second)
fn := func(_ context.Context) (interface{}, error) {
select {
case <-time.After(time.Second):
case <-done:
}
return want, nil
}
go func() {
Expand Down Expand Up @@ -186,16 +202,22 @@ func TestDo_cancelContextSecond(t *testing.T) {
}

func TestForget(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group

wantPrefix := "val"
var counter uint64
firstCall := make(chan struct{})
fn := func() (interface{}, error) {
fn := func(_ context.Context) (interface{}, error) {
c := atomic.AddUint64(&counter, 1)
if c == 1 {
close(firstCall)
time.Sleep(time.Second)
select {
case <-time.After(time.Second):
case <-done:
}
}
return wantPrefix + strconv.FormatUint(c, 10), nil
}
Expand All @@ -220,3 +242,118 @@ func TestForget(t *testing.T) {
t.Errorf("got value %v, want %v", got, want)
}
}

func TestDo_multipleCallsCanceled(t *testing.T) {
const n = 5

for lastCall := 0; lastCall < n; lastCall++ {
lastCall := lastCall
t.Run(fmt.Sprintf("last call %v of %v", lastCall, n), func(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group

var counter int32

fnCalled := make(chan struct{})
fnErrChan := make(chan error)
var mu sync.Mutex
contexts := make([]context.Context, n)
cancelFuncs := make([]context.CancelFunc, n)
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func(i int) {
defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
mu.Lock()
contexts[i] = ctx
cancelFuncs[i] = cancel
mu.Unlock()
_, _, _ = g.Do(ctx, "key", func(ctx context.Context) (interface{}, error) {
atomic.AddInt32(&counter, 1)
close(fnCalled)
var err error
select {
case <-ctx.Done():
err = ctx.Err()
if err == nil {
err = errors.New("got unexpected <nil> error from context")
}
case <-time.After(10 * time.Second):
err = errors.New("unexpected timeout, context not canceled")
case <-done:
}

fnErrChan <- err

return nil, nil
})
}(i)
}
select {
case <-fnCalled:
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for function to be called")
}

// Ensure that n goroutines are waiting at the select case in Group.wait.
// Update the line number on changes.
waitStacks(t, "resenje.org/singleflight/singleflight.go:68", n, 2*time.Second)

// cancel all but one calls
for i := 0; i < n; i++ {
if i == lastCall {
continue
}
mu.Lock()
cancelFuncs[i]()
<-contexts[i].Done()
mu.Unlock()
}

select {
case err := <-fnErrChan:
t.Fatalf("got unexpected error in function: %v", err)
default:
}

// Ensure that only the last goroutine is waiting at the select case in Group.wait.
// Update the line number on changes.
waitStacks(t, "resenje.org/singleflight/singleflight.go:68", 1, 2*time.Second)

mu.Lock()
cancelFuncs[lastCall]()
mu.Unlock()

wg.Wait()

select {
case err := <-fnErrChan:
if err != context.Canceled {
t.Fatalf("got unexpected error in function %v, want %v", err, context.Canceled)
}
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for the error")
}
})
}
}

func waitStacks(t *testing.T, loc string, count int, timeout time.Duration) {
t.Helper()

for deadline := time.Now().Add(timeout); time.Now().Before(deadline); {
// Ensure that exact n goroutines are waiting at the desired stack trace.
var buf bytes.Buffer
if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil {
t.Fatal(err)
}
c := strings.Count(buf.String(), loc)
if c == count {
break
}
time.Sleep(10 * time.Millisecond)
}
}

2 comments on commit f856ab8

@aloknerurkar
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@janos
Copy link
Owner Author

@janos janos commented on f856ab8 May 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Thank you, Alok! :)

Please sign in to comment.