Skip to content

Commit

Permalink
Add blocking write mode for association
Browse files Browse the repository at this point in the history
Block write is useful for cases like sfu to
forward reliab message between clients and
detach the data channel to use it as normal
golang connections.
  • Loading branch information
cnderrauber committed Dec 9, 2024
1 parent ce7e8bc commit abb6bea
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 10 deletions.
46 changes: 44 additions & 2 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/pion/logging"
"github.com/pion/randutil"
"github.com/pion/transport/v3/deadline"
)

// Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation
Expand Down Expand Up @@ -251,6 +252,10 @@ type Association struct {
delayedAckTriggered bool
immediateAckTriggered bool

blockWrite bool
writePending bool
writeNotify chan struct{}

name string
log logging.LeveledLogger
}
Expand All @@ -264,6 +269,7 @@ type Config struct {
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory
BlockWrite bool

// congestion control configuration
// RTOMax is the maximum retransmission timeout in milliseconds
Expand Down Expand Up @@ -375,6 +381,8 @@ func createAssociation(config Config) *Association {
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
name: config.Name,
blockWrite: config.BlockWrite,
writeNotify: make(chan struct{}, 1),
}

if a.name == "" {
Expand Down Expand Up @@ -675,6 +683,20 @@ func (a *Association) awakeWriteLoop() {
}
}

func (a *Association) isBlockWrite() bool {
return a.blockWrite
}

// Mark the association is writable and unblock the waiting write,
// the caller should hold the association write lock.
func (a *Association) notifyBlockWritable() {
a.writePending = false
select {
case a.writeNotify <- struct{}{}:
default:
}
}

// unregisterStream un-registers a stream from the association
// The caller should hold the association write lock.
func (a *Association) unregisterStream(s *Stream, err error) {
Expand Down Expand Up @@ -1555,6 +1577,7 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream
reassemblyQueue: newReassemblyQueue(streamIdentifier),
log: a.log,
name: fmt.Sprintf("%d:%s", streamIdentifier, a.name),
writeDeadline: deadline.New(),
}

s.readNotifier = sync.NewCond(&s.lock)
Expand Down Expand Up @@ -2338,6 +2361,11 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1
}
}

if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 {
a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name)
a.notifyBlockWritable()
}

return chunks, sisToReset
}

Expand Down Expand Up @@ -2375,21 +2403,35 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []
}

// sendPayloadData sends the data chunks.
func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error {
func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error {
a.lock.Lock()
defer a.lock.Unlock()

state := a.getState()
if state != established {
a.lock.Unlock()
return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist,
getAssociationStateString(state))
}

if a.blockWrite {
for a.writePending {
a.lock.Unlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-a.writeNotify:
a.lock.Lock()
}
}
a.writePending = true
}

// Push the chunks into the pending queue first.
for _, c := range chunks {
a.pendingQueue.push(c)
}

a.lock.Unlock()
a.awakeWriteLoop()
return nil
}
Expand Down
95 changes: 93 additions & 2 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2845,7 +2845,7 @@ func TestAssociationReceiveWindow(t *testing.T) {

done := make(chan bool)
go func() {
chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
chunk := chunks[0]
// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
Expand Down Expand Up @@ -3016,7 +3016,7 @@ func TestAssociationMaxTSNOffset(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint16(1), s2.streamIdentifier)

chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
sendChunk := func(tsn uint32) {
chunk := chunks[0]
Expand Down Expand Up @@ -3616,3 +3616,94 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
require.Equal(t, 0, len(a1.streams))
require.Equal(t, 0, len(a2.streams))
}

func TestAssociation_BlockWrite(t *testing.T) {
checkGoroutineLeaks(t)

conn1, conn2 := createUDPConnPair()
a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{BlockWrite: true, MaxReceiveBufferSize: 4000})
require.NoError(t, err)

defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
s2, err := a2.AcceptStream()
require.NoError(t, err)

data := make([]byte, 4000)
n, err := s2.Read(data)
require.NoError(t, err)
require.Equal(t, "hello", string(data[:n]))

// Write should block until data is sent
dbConn1, ok := conn1.(*dumbConn2)
require.True(t, ok)
dbConn2, ok := conn2.(*dumbConn2)
require.True(t, ok)

dbConn1.remoteInboundHandler = dbConn2.inboundHandler

_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.NoError(t, err)

// test write deadline
// a2's awnd is 0, so write should be blocked
require.NoError(t, s1.SetWriteDeadline(time.Now().Add(100*time.Millisecond)))
_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err)

// test write deadline cancel
require.NoError(t, s1.SetWriteDeadline(time.Time{}))
var deadLineCanceled atomic.Bool
writeCanceled := make(chan struct{}, 2)
// both write should be blocked and canceled by deadline
go func() {
_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err1)
require.True(t, deadLineCanceled.Load())
writeCanceled <- struct{}{}
}()
go func() {
_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err1)
require.True(t, deadLineCanceled.Load())
writeCanceled <- struct{}{}
}()
time.Sleep(100 * time.Millisecond)
deadLineCanceled.Store(true)
require.NoError(t, s1.SetWriteDeadline(time.Now().Add(-1*time.Second)))
<-writeCanceled
<-writeCanceled
require.NoError(t, s1.SetWriteDeadline(time.Time{}))

rn, rerr := s2.Read(data)
require.NoError(t, rerr)
require.Equal(t, 4000, rn)

// slow reader and fast writer, make sure all write is blocked
go func() {
for {
bytes := make([]byte, 4000)
rn, rerr = s2.Read(bytes)
if errors.Is(rerr, io.EOF) {
return
}
require.NoError(t, rerr)
require.Equal(t, 4000, rn)
time.Sleep(5 * time.Millisecond)
}
}()

for i := 0; i < 10; i++ {
_, err = s1.Write(data)
require.NoError(t, err)
// bufferedAmount should not exceed RWND+message size (inflight + pending)
require.LessOrEqual(t, s1.BufferedAmount(), uint64(4000*2))
}
}
43 changes: 37 additions & 6 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/transport/v3/deadline"
)

const (
Expand Down Expand Up @@ -65,6 +66,8 @@ type Stream struct {
readNotifier *sync.Cond
readErr error
readTimeoutCancel chan struct{}
writeDeadline *deadline.Deadline
writeLock sync.Mutex
unordered bool
reliabilityType byte
reliabilityValue uint32
Expand Down Expand Up @@ -272,16 +275,44 @@ func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error)
return 0, ErrStreamClosed
}

chunks := s.packetize(p, ppi)
// the send could fail if the association is blocked for writing (timeout), it will left a hole
// in the stream sequence number space, so we need to lock the write to avoid concurrent send and decrement
// the sequence number in case of failure
if s.association.isBlockWrite() {
s.writeLock.Lock()
}
chunks, unordered := s.packetize(p, ppi)
n := len(p)
err := s.association.sendPayloadData(chunks)
err := s.association.sendPayloadData(s.writeDeadline, chunks)
if err != nil {
return n, ErrStreamClosed
s.lock.Lock()
s.bufferedAmount -= uint64(n)
if !unordered {
s.sequenceNumber--
}
s.lock.Unlock()
}
if s.association.isBlockWrite() {
s.writeLock.Unlock()
}
return n, err
}

// SetWriteDeadline sets the write deadline in an identical way to net.Conn, it will only work for blocking writes
func (s *Stream) SetWriteDeadline(deadline time.Time) error {
s.writeDeadline.Set(deadline)
return nil
}

// SetDeadline sets the read and write deadlines in an identical way to net.Conn
func (s *Stream) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
return n, nil
return s.SetWriteDeadline(t)
}

func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) {
s.lock.Lock()
defer s.lock.Unlock()

Expand Down Expand Up @@ -336,7 +367,7 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPa
s.bufferedAmount += uint64(len(raw))
s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)

return chunks
return chunks, unordered
}

// Close closes the write-direction of the stream.
Expand Down

0 comments on commit abb6bea

Please sign in to comment.