From 99301e5d0653f45c61012dbf67599f8f40184cc1 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Tue, 26 Nov 2024 11:33:28 +0800 Subject: [PATCH] Add blocking write mode for association Block write is useful for cases like sfu to forward reliab message between clients and detach the data channel to use it as normal golang connections. --- association.go | 46 +++++++++++++++++++++- association_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++- stream.go | 43 +++++++++++++++++--- 3 files changed, 174 insertions(+), 10 deletions(-) diff --git a/association.go b/association.go index 93836254..559f9880 100644 --- a/association.go +++ b/association.go @@ -17,6 +17,7 @@ import ( "github.com/pion/logging" "github.com/pion/randutil" + "github.com/pion/transport/v3/deadline" ) // Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation @@ -251,6 +252,10 @@ type Association struct { delayedAckTriggered bool immediateAckTriggered bool + blockWrite bool + writePending bool + writeNotify chan struct{} + name string log logging.LeveledLogger } @@ -264,6 +269,7 @@ type Config struct { MaxMessageSize uint32 EnableZeroChecksum bool LoggerFactory logging.LoggerFactory + BlockWrite bool // congestion control configuration // RTOMax is the maximum retransmission timeout in milliseconds @@ -375,6 +381,8 @@ func createAssociation(config Config) *Association { stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), name: config.Name, + blockWrite: config.BlockWrite, + writeNotify: make(chan struct{}, 1), } if a.name == "" { @@ -675,6 +683,20 @@ func (a *Association) awakeWriteLoop() { } } +func (a *Association) isBlockWrite() bool { + return a.blockWrite +} + +// Mark the association is writable and unblock the waiting write, +// the caller should hold the association write lock. +func (a *Association) notifyBlockWritable() { + a.writePending = false + select { + case a.writeNotify <- struct{}{}: + default: + } +} + // unregisterStream un-registers a stream from the association // The caller should hold the association write lock. func (a *Association) unregisterStream(s *Stream, err error) { @@ -1555,6 +1577,7 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream reassemblyQueue: newReassemblyQueue(streamIdentifier), log: a.log, name: fmt.Sprintf("%d:%s", streamIdentifier, a.name), + writeDeadline: deadline.New(), } s.readNotifier = sync.NewCond(&s.lock) @@ -2338,6 +2361,11 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 } } + if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 { + a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name) + a.notifyBlockWritable() + } + return chunks, sisToReset } @@ -2375,21 +2403,35 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) [] } // sendPayloadData sends the data chunks. -func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error { +func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error { a.lock.Lock() - defer a.lock.Unlock() state := a.getState() if state != established { + a.lock.Unlock() return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, getAssociationStateString(state)) } + if a.blockWrite { + for a.writePending { + a.lock.Unlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-a.writeNotify: + a.lock.Lock() + } + } + a.writePending = true + } + // Push the chunks into the pending queue first. for _, c := range chunks { a.pendingQueue.push(c) } + a.lock.Unlock() a.awakeWriteLoop() return nil } diff --git a/association_test.go b/association_test.go index def1c01c..44a71acf 100644 --- a/association_test.go +++ b/association_test.go @@ -2845,7 +2845,7 @@ func TestAssociationReceiveWindow(t *testing.T) { done := make(chan bool) go func() { - chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) + chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) chunks = chunks[:1] chunk := chunks[0] // Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue @@ -3016,7 +3016,7 @@ func TestAssociationMaxTSNOffset(t *testing.T) { require.NoError(t, err) require.Equal(t, uint16(1), s2.streamIdentifier) - chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) + chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) chunks = chunks[:1] sendChunk := func(tsn uint32) { chunk := chunks[0] @@ -3616,3 +3616,94 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) { require.Equal(t, 0, len(a1.streams)) require.Equal(t, 0, len(a2.streams)) } + +func TestAssociation_BlockWrite(t *testing.T) { + checkGoroutineLeaks(t) + + conn1, conn2 := createUDPConnPair() + a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{BlockWrite: true, MaxReceiveBufferSize: 4000}) + require.NoError(t, err) + + defer noErrorClose(t, a2.Close) + defer noErrorClose(t, a1.Close) + s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) + require.NoError(t, err) + defer noErrorClose(t, s1.Close) + _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + require.NoError(t, err) + s2, err := a2.AcceptStream() + require.NoError(t, err) + + data := make([]byte, 4000) + n, err := s2.Read(data) + require.NoError(t, err) + require.Equal(t, "hello", string(data[:n])) + + // Write should block until data is sent + dbConn1, ok := conn1.(*dumbConn2) + require.True(t, ok) + dbConn2, ok := conn2.(*dumbConn2) + require.True(t, ok) + + dbConn1.remoteInboundHandler = dbConn2.inboundHandler + + _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) + require.NoError(t, err) + _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) + require.NoError(t, err) + + // test write deadline + // a2's awnd is 0, so write should be blocked + require.NoError(t, s1.SetWriteDeadline(time.Now().Add(100*time.Millisecond))) + _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + + // test write deadline cancel + require.NoError(t, s1.SetWriteDeadline(time.Time{})) + var deadLineCanceled atomic.Bool + writeCanceled := make(chan struct{}, 2) + // both write should be blocked and canceled by deadline + go func() { + _, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary) + require.ErrorIs(t, err, context.DeadlineExceeded, err1) + require.True(t, deadLineCanceled.Load()) + writeCanceled <- struct{}{} + }() + go func() { + _, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary) + require.ErrorIs(t, err, context.DeadlineExceeded, err1) + require.True(t, deadLineCanceled.Load()) + writeCanceled <- struct{}{} + }() + time.Sleep(100 * time.Millisecond) + deadLineCanceled.Store(true) + require.NoError(t, s1.SetWriteDeadline(time.Now().Add(-1*time.Second))) + <-writeCanceled + <-writeCanceled + require.NoError(t, s1.SetWriteDeadline(time.Time{})) + + rn, rerr := s2.Read(data) + require.NoError(t, rerr) + require.Equal(t, 4000, rn) + + // slow reader and fast writer, make sure all write is blocked + go func() { + for { + bytes := make([]byte, 4000) + rn, rerr = s2.Read(bytes) + if errors.Is(rerr, io.EOF) { + return + } + require.NoError(t, rerr) + require.Equal(t, 4000, rn) + time.Sleep(5 * time.Millisecond) + } + }() + + for i := 0; i < 10; i++ { + _, err = s1.Write(data) + require.NoError(t, err) + // bufferedAmount should not exceed RWND+message size (inflight + pending) + require.LessOrEqual(t, s1.BufferedAmount(), uint64(4000*2)) + } +} diff --git a/stream.go b/stream.go index 47e06f35..ed06d69a 100644 --- a/stream.go +++ b/stream.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pion/logging" + "github.com/pion/transport/v3/deadline" ) const ( @@ -65,6 +66,8 @@ type Stream struct { readNotifier *sync.Cond readErr error readTimeoutCancel chan struct{} + writeDeadline *deadline.Deadline + writeLock sync.Mutex unordered bool reliabilityType byte reliabilityValue uint32 @@ -272,16 +275,44 @@ func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) return 0, ErrStreamClosed } - chunks := s.packetize(p, ppi) + // the send could fail if the association is blocked for writing (timeout), it will left a hole + // in the stream sequence number space, so we need to lock the write to avoid concurrent send and decrement + // the sequence number in case of failure + if s.association.isBlockWrite() { + s.writeLock.Lock() + } + chunks, unordered := s.packetize(p, ppi) n := len(p) - err := s.association.sendPayloadData(chunks) + err := s.association.sendPayloadData(s.writeDeadline, chunks) if err != nil { - return n, ErrStreamClosed + s.lock.Lock() + s.bufferedAmount -= uint64(n) + if !unordered { + s.sequenceNumber-- + } + s.lock.Unlock() + } + if s.association.isBlockWrite() { + s.writeLock.Unlock() + } + return n, err +} + +// SetWriteDeadline sets the write deadline in an identical way to net.Conn, it will only work for blocking writes +func (s *Stream) SetWriteDeadline(deadline time.Time) error { + s.writeDeadline.Set(deadline) + return nil +} + +// SetDeadline sets the read and write deadlines in an identical way to net.Conn +func (s *Stream) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err } - return n, nil + return s.SetWriteDeadline(t) } -func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData { +func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) { s.lock.Lock() defer s.lock.Unlock() @@ -336,7 +367,7 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPa s.bufferedAmount += uint64(len(raw)) s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) - return chunks + return chunks, unordered } // Close closes the write-direction of the stream.