Skip to content

Commit

Permalink
Fix batched streams [#3439]
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Oct 4, 2023
1 parent 3d6b0b1 commit e3960e4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
11 changes: 7 additions & 4 deletions pkg/api/v3/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ package api
import "context"

type BatchData struct {
values map[any]any
context context.Context
values map[any]any
}

func (d *BatchData) Get(k any) any { return d.values[k] }
func (d *BatchData) Put(k, v any) { d.values[k] = v }
func (d *BatchData) Context() context.Context { return d.context }
func (d *BatchData) Get(k any) any { return d.values[k] }
func (d *BatchData) Put(k, v any) { d.values[k] = v }

type contextKeyBatch struct{}

Expand All @@ -29,9 +31,10 @@ func ContextWithBatchData(ctx context.Context) (context.Context, context.CancelF
return ctx, func() {}, v
}

ctx, cancel := context.WithCancel(ctx)
bd := new(BatchData)
bd.context = ctx
bd.values = map[any]any{}
ctx, cancel := context.WithCancel(ctx)
ctx = context.WithValue(ctx, contextKeyBatch{}, bd)
return ctx, cancel, bd
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/api/v3/p2p/dial_self.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (d *selfDialer) Dial(ctx context.Context, addr multiaddr.Multiaddr) (messag
}

func handleLocally(ctx context.Context, service *serviceHandler) message.Stream {
if batch := api.GetBatchData(ctx); batch != nil {
ctx = batch.Context()
}
p, q := message.DuplexPipe(ctx)
go func() {
// Panic protection
Expand Down
9 changes: 8 additions & 1 deletion pkg/api/v3/p2p/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ func openStreamFor(ctx context.Context, host dialerHost, peer peer.ID, sa *api.S
}

// Close the stream when the context is canceled
go func() { <-ctx.Done(); _ = conn.Close() }()
go func() {
ctx := ctx
if batch := api.GetBatchData(ctx); batch != nil {
ctx = batch.Context()
}
<-ctx.Done()
_ = conn.Close()
}()

s := new(stream)
s.peer = peer
Expand Down

0 comments on commit e3960e4

Please sign in to comment.