handler, s3store: Fix data race problems (#1199)
* ci: Enable data race detector in tests

* handler: Fix data race in bodyReader when stopping upload

* s3store: Prevent data race using `errgroup` package

* s3store: Fix data race in `concatUsingMultipart`

* fixup! s3store: Fix data race in `concatUsingMultipart`

* fixup! handler: Fix data race in bodyReader when stopping upload
Acconut authored Oct 4, 2024
1 parent 1f510ec commit 8d43cf2
Showing 5 changed files with 69 additions and 53 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/continuous-integration.yaml
@@ -49,8 +49,8 @@ jobs:
       -
         name: Test code
         run: |
-          go test ./pkg/...
-          go test ./internal/...
+          go test -race ./pkg/...
+          go test -race ./internal/...
         shell: bash
 
       -
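For context, the `-race` flag builds the test binaries with Go's race detector, which reports unsynchronized concurrent accesses to the same memory at runtime and fails the test run when it finds one. A minimal sketch of the kind of bug it flags (hypothetical test, not from this repository):

```go
package example

import (
	"sync"
	"testing"
)

// TestRace fails under `go test -race` because both goroutines
// increment the shared variable without any synchronization.
func TestRace(t *testing.T) {
	var shared int
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			shared++ // unsynchronized read-modify-write across goroutines
		}()
	}
	wg.Wait()
}
```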
2 changes: 1 addition & 1 deletion go.mod
@@ -30,6 +30,7 @@ require (
 	github.com/vimeo/go-util v1.4.1
 	golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
 	golang.org/x/net v0.29.0
+	golang.org/x/sync v0.8.0
 	google.golang.org/api v0.199.0
 	google.golang.org/grpc v1.67.0
 	google.golang.org/protobuf v1.34.2
@@ -95,7 +96,6 @@ require (
 	go.opentelemetry.io/otel/trace v1.29.0 // indirect
 	golang.org/x/crypto v0.27.0 // indirect
 	golang.org/x/oauth2 v0.23.0 // indirect
-	golang.org/x/sync v0.8.0 // indirect
 	golang.org/x/sys v0.25.0 // indirect
 	golang.org/x/text v0.18.0 // indirect
 	golang.org/x/time v0.6.0 // indirect
25 changes: 20 additions & 5 deletions pkg/handler/body_reader.go
@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 )
@@ -28,8 +29,11 @@ type bodyReader struct {
 	bytesCounter int64
 	ctx          *httpContext
 	reader       io.ReadCloser
-	err          error
 	onReadDone   func()
+
+	// lock protects concurrent access to err.
+	lock sync.RWMutex
+	err  error
 }
@@ -41,7 +45,10 @@ func newBodyReader(c *httpContext, maxSize int64) *bodyReader {
 }
 
 func (r *bodyReader) Read(b []byte) (int, error) {
-	if r.err != nil {
+	r.lock.RLock()
+	hasErrored := r.err != nil
+	r.lock.RUnlock()
+	if hasErrored {
 		return 0, io.EOF
 	}
 
@@ -99,28 +106,36 @@ func (r *bodyReader) Read(b []byte) (int, error) {
 
 		// Other errors are stored for retrieval with hasError, but are not returned
 		// to the consumer. We do not overwrite an error if it has been set already.
+		r.lock.Lock()
 		if r.err == nil {
 			r.err = err
 		}
+		r.lock.Unlock()
 	}
 
 	return n, nil
 }
 
-func (r bodyReader) hasError() error {
-	if r.err == io.EOF {
+func (r *bodyReader) hasError() error {
+	r.lock.RLock()
+	err := r.err
+	r.lock.RUnlock()
+
+	if err == io.EOF {
 		return nil
 	}
 
-	return r.err
+	return err
 }
 
 func (r *bodyReader) bytesRead() int64 {
 	return atomic.LoadInt64(&r.bytesCounter)
 }
 
 func (r *bodyReader) closeWithError(err error) {
+	r.lock.Lock()
+	r.err = err
+	r.lock.Unlock()
+
 	// SetReadDeadline with the current time causes concurrent reads to the body to time out,
 	// so the body will be closed sooner with less delay.
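The pattern applied here, stripped of the handler details: guard the shared field with a `sync.RWMutex`, taking the read lock on the hot path (`Read`, `hasError`) and the write lock on the rare error paths. A minimal sketch of that idea; the type and method names are illustrative, not the handler's API:

```go
package example

import (
	"io"
	"sync"
)

type guardedReader struct {
	mu  sync.RWMutex // protects err against concurrent access
	err error
}

// setErr records the first error; later calls keep the original one,
// mirroring the "do not overwrite" rule in bodyReader.
func (r *guardedReader) setErr(err error) {
	r.mu.Lock()
	if r.err == nil {
		r.err = err
	}
	r.mu.Unlock()
}

// hasError returns the stored error, treating io.EOF as a clean end.
func (r *guardedReader) hasError() error {
	r.mu.RLock()
	err := r.err
	r.mu.RUnlock()
	if err == io.EOF {
		return nil
	}
	return err
}
```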
1 change: 1 addition & 0 deletions pkg/s3store/multi_error.go
@@ -4,6 +4,7 @@ import (
 	"errors"
 )
 
+// TODO: Replace with errors.Join
 func newMultiError(errs []error) error {
 	message := "Multiple errors occurred:\n"
 	for _, err := range errs {
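The TODO refers to `errors.Join` from the standard library (Go 1.20+), which would make this custom helper unnecessary. A sketch of what the drop-in replacement looks like, assuming callers only need a combined error value:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	errs := []error{
		errors.New("part 1 failed"),
		errors.New("part 3 failed"),
	}

	// errors.Join wraps all non-nil errors into a single error; it
	// returns nil for an empty or all-nil slice, and errors.Is/As
	// still match the individual errors through the wrapper.
	err := errors.Join(errs...)
	fmt.Println(err)
}
```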
90 changes: 45 additions & 45 deletions pkg/s3store/s3store.go
@@ -88,6 +88,7 @@ import (
 	"github.com/tus/tusd/v2/internal/uid"
 	"github.com/tus/tusd/v2/pkg/handler"
 	"golang.org/x/exp/slices"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
@@ -469,8 +470,7 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 	}()
 	go partProducer.produce(producerCtx, optimalPartSize)
 
-	var wg sync.WaitGroup
-	var uploadErr error
+	var eg errgroup.Group
 
 	for {
 		// We acquire the semaphore before starting the goroutine to avoid
@@ -497,10 +497,8 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 			}
 			upload.parts = append(upload.parts, part)
 
-			wg.Add(1)
-			go func(file io.ReadSeeker, part *s3Part, closePart func() error) {
+			eg.Go(func() error {
 				defer upload.store.releaseUploadSemaphore()
-				defer wg.Done()
 
 				t := time.Now()
 				uploadPartInput := &s3.UploadPartInput{
@@ -509,39 +507,46 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re
 					UploadId:   aws.String(upload.multipartId),
 					PartNumber: aws.Int32(part.number),
 				}
-				etag, err := upload.putPartForUpload(ctx, uploadPartInput, file, part.size)
+				etag, err := upload.putPartForUpload(ctx, uploadPartInput, partfile, part.size)
 				store.observeRequestDuration(t, metricUploadPart)
-				if err != nil {
-					uploadErr = err
-				} else {
+				if err == nil {
 					part.etag = etag
 				}
-				if cerr := closePart(); cerr != nil && uploadErr == nil {
-					uploadErr = cerr
+
+				cerr := closePart()
+				if err != nil {
+					return err
 				}
-			}(partfile, part, closePart)
+				if cerr != nil {
+					return cerr
+				}
+				return nil
+			})
 		} else {
-			wg.Add(1)
-			go func(file io.ReadSeeker, closePart func() error) {
+			eg.Go(func() error {
 				defer upload.store.releaseUploadSemaphore()
-				defer wg.Done()
 
-				if err := store.putIncompletePartForUpload(ctx, upload.objectId, file); err != nil {
-					uploadErr = err
+				err := store.putIncompletePartForUpload(ctx, upload.objectId, partfile)
+				if err == nil {
+					upload.incompletePartSize = partsize
 				}
-				if cerr := closePart(); cerr != nil && uploadErr == nil {
-					uploadErr = cerr
+
+				cerr := closePart()
+				if err != nil {
+					return err
 				}
-				upload.incompletePartSize = partsize
-			}(partfile, closePart)
+				if cerr != nil {
+					return cerr
+				}
+				return nil
+			})
 		}
 
 		bytesUploaded += partsize
 		nextPartNum += 1
 	}
 
-	wg.Wait()
-
+	uploadErr := eg.Wait()
 	if uploadErr != nil {
 		return 0, uploadErr
 	}
@@ -969,47 +974,42 @@ func (upload *s3Upload) concatUsingDownload(ctx context.Context, partialUploads
 func (upload *s3Upload) concatUsingMultipart(ctx context.Context, partialUploads []handler.Upload) error {
 	store := upload.store
 
-	numPartialUploads := len(partialUploads)
-	errs := make([]error, 0, numPartialUploads)
+	upload.parts = make([]*s3Part, len(partialUploads))
 
 	// Copy partial uploads concurrently
-	var wg sync.WaitGroup
-	wg.Add(numPartialUploads)
+	var eg errgroup.Group
 	for i, partialUpload := range partialUploads {
-
 		// Part numbers must be in the range of 1 to 10000, inclusive. Since
 		// slice indexes start at 0, we add 1 to ensure that i >= 1.
 		partNumber := int32(i + 1)
 		partialS3Upload := partialUpload.(*s3Upload)
 
-		upload.parts = append(upload.parts, &s3Part{
-			number: partNumber,
-			size:   -1,
-			etag:   "",
-		})
-
-		go func(partNumber int32, sourceObject string) {
-			defer wg.Done()
-
+		eg.Go(func() error {
 			res, err := store.Service.UploadPartCopy(ctx, &s3.UploadPartCopyInput{
 				Bucket:     aws.String(store.Bucket),
 				Key:        store.keyWithPrefix(upload.objectId),
 				UploadId:   aws.String(upload.multipartId),
 				PartNumber: aws.Int32(partNumber),
-				CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(sourceObject)),
+				CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(partialS3Upload.objectId)),
			})
 			if err != nil {
-				errs = append(errs, err)
-				return
+				return err
 			}
 
-			upload.parts[partNumber-1].etag = *res.CopyPartResult.ETag
-		}(partNumber, partialS3Upload.objectId)
-	}
+			upload.parts[partNumber-1] = &s3Part{
+				number: partNumber,
+				size:   -1, // -1 is fine here because FinishUpload does not need this info.
+				etag:   *res.CopyPartResult.ETag,
+			}
 
-	wg.Wait()
+			return nil
+		})
+	}
 
-	if len(errs) > 0 {
-		return newMultiError(errs)
+	err := eg.Wait()
+	if err != nil {
+		return err
 	}
 
 	return upload.FinishUpload(ctx)
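Two races disappear in this hunk: concurrent `append` to `upload.parts` and concurrent appends to the shared `errs` slice. The fix pre-allocates the parts slice and has each goroutine write only its own index, which is safe without locks because the writes never overlap, while errgroup takes over error collection. A reduced sketch of that idea, using a hypothetical fetch function in place of `UploadPartCopy`:

```go
package example

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

func fetchAll(urls []string) ([]string, error) {
	// Pre-allocate so each goroutine writes a distinct index;
	// appending to a shared slice from multiple goroutines would
	// itself be a data race.
	results := make([]string, len(urls))

	var eg errgroup.Group
	for i, url := range urls {
		i, url := i, url // capture loop variables (needed before Go 1.22)
		eg.Go(func() error {
			// Placeholder for real work (e.g. an UploadPartCopy call).
			results[i] = fmt.Sprintf("fetched %s", url)
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return nil, err
	}
	return results, nil
}
```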
