Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ShutdownWithContext and ctx.Done() exist race. #1908

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,8 @@ func (s *Server) Shutdown() error {
//
// ShutdownWithContext does not close keepalive connections so it's recommended to set ReadTimeout and IdleTimeout
// to something else than 0.
//
// When ShutdownWithContext returns errors, any operation to the Server is unavailable.
func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand All @@ -1898,11 +1900,7 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
return nil
}

for _, ln := range s.ln {
if err = ln.Close(); err != nil {
return err
}
}
lnerr := s.closeListenersLocked()

if s.done != nil {
close(s.done)
Expand All @@ -1913,28 +1911,25 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
// Now we just have to wait until all workers are done or timeout.
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
END:

for {
s.closeIdleConns()

if open := atomic.LoadInt32(&s.open); open == 0 {
break
// There may be a pending request to call ctx.Done(). Therefore, we only set it to nil when open == 0.
s.done = nil
return lnerr
}
// This is not an optimal solution but using a sync.WaitGroup
// here causes data races as it's hard to prevent Add() to be called
// while Wait() is waiting.
select {
case <-ctx.Done():
err = ctx.Err()
break END
return ctx.Err()
case <-ticker.C:
continue
}
}

s.done = nil
s.ln = nil
return err
}

func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
Expand Down Expand Up @@ -2749,15 +2744,7 @@ func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) {
// Note: Because creating a new channel for every request is just too expensive, so
// RequestCtx.s.done is only closed when the server is shutting down.
func (ctx *RequestCtx) Done() <-chan struct{} {
// fix use new variables to prevent panic caused by modifying the original done chan to nil.
done := ctx.s.done

if done == nil {
done = make(chan struct{}, 1)
done <- struct{}{}
return done
}
return done
return ctx.s.done
}

// Err returns a non-nil error value after Done is closed,
Expand Down Expand Up @@ -2934,6 +2921,17 @@ func (s *Server) closeIdleConns() {
s.idleConnsMu.Unlock()
}

func (s *Server) closeListenersLocked() error {
var err error
for _, ln := range s.ln {
if cerr := ln.Close(); cerr != nil && err == nil {
err = cerr
}
}
s.ln = nil
return err
}

// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
Expand Down
46 changes: 46 additions & 0 deletions server_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//go:build race

package fasthttp

import (
"context"
"github.com/valyala/fasthttp/fasthttputil"
"math"
"testing"
)

func TestServerDoneRace(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
for i := 0; i < math.MaxInt; i++ {
ctx.Done()
}
},
}

ln := fasthttputil.NewInmemoryListener()
defer ln.Close()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()

c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: go.dev\r\nContent-Length: 3\r\n\r\nABC" +
"\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
"GET / HTTP/1.1\r\nHost: go.dev\r\n\r\n")); err != nil {
t.Fatal(err)
}
ctx, cancelFunc := context.WithCancel(context.Background())
cancelFunc()

s.ShutdownWithContext(ctx)
}