diff --git a/association.go b/association.go index 29e9978c..38e26c4f 100644 --- a/association.go +++ b/association.go @@ -212,6 +212,9 @@ type Association struct { partialBytesAcked uint32 inFastRecovery bool fastRecoverExitPoint uint32 + minCwnd uint32 // Minimum congestion window + fastRtxWnd uint32 // Send window for fast retransmit + cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance // RTX & Ack timer rtoMgr *rtoManager @@ -261,8 +264,16 @@ type Config struct { MaxMessageSize uint32 EnableZeroChecksum bool LoggerFactory logging.LoggerFactory + + // congestion control configuration // RTOMax is the maximum retransmission timeout in milliseconds RTOMax float64 + // Minimum congestion window + MinCwnd uint32 + // Send window for fast retransmit + FastRtxWnd uint32 + // Step of congestion window increase at Congestion Avoidance + CwndCAStep uint32 } // Server accepts a SCTP stream over a conn @@ -325,6 +336,9 @@ func createAssociation(config Config) *Association { netConn: config.NetConn, maxReceiveBufferSize: maxReceiveBufferSize, maxMessageSize: maxMessageSize, + minCwnd: config.MinCwnd, + fastRtxWnd: config.FastRtxWnd, + cwndCAStep: config.CwndCAStep, // These two max values have us not need to follow // 5.1.1 where this peer may be incapable of supporting @@ -512,7 +526,7 @@ func (a *Association) Close() error { a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent()) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) - a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent()) + a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) @@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt if a.willRetransmitFast { a.willRetransmitFast = false - toFastRetrans := []chunk{} + toFastRetrans := []*chunkPayloadData{} fastRetransSize := commonHeaderSize + fastRetransWnd := a.MTU() + if fastRetransWnd < a.fastRtxWnd { + fastRetransWnd = a.fastRtxWnd + } for i := 0; ; i++ { c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) if !ok { @@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt // packet. dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) - if a.MTU() < fastRetransSize+dataChunkSize { + if fastRetransWnd < fastRetransSize+dataChunkSize { break } @@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt } if len(toFastRetrans) > 0 { - raw, err := a.marshalPacket(a.createPacket(toFastRetrans)) - if err != nil { - a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) - } else { + for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) + continue + } rawPackets = append(rawPackets, raw) } } @@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 { } func (a *Association) setCWND(cwnd uint32) { + if cwnd < a.minCwnd { + cwnd = a.minCwnd + } atomic.StoreUint32(&a.cwnd, cwnd) } @@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // reset partial_bytes_acked to (partial_bytes_acked - cwnd). if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 { a.partialBytesAcked -= a.CWND() - a.setCWND(a.CWND() + a.MTU()) + step := a.MTU() + if step < a.cwndCAStep { + step = a.cwndCAStep + } + a.setCWND(a.CWND() + step) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", a.name, a.CWND(), a.ssthresh, totalBytesAcked) } diff --git a/association_test.go b/association_test.go index a8611318..28c3f65d 100644 --- a/association_test.go +++ b/association_test.go @@ -1839,6 +1839,7 @@ func TestAssocCongestionControl(t *testing.T) { br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize) + a0.cwndCAStep = 2800 // 2 mtu if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } @@ -2735,6 +2736,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) { } func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) { + return createAssociationPairWithConfig(udpConn1, udpConn2, Config{}) +} + +func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) { loggerFactory := logging.NewDefaultLoggerFactory() a1Chan := make(chan interface{}) @@ -2744,10 +2749,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, defer cancel() go func() { - a, err2 := createClientWithContext(ctx, Config{ - NetConn: udpConn1, - LoggerFactory: loggerFactory, - }) + cfg := config + cfg.NetConn = udpConn1 + cfg.LoggerFactory = loggerFactory + a, err2 := createClientWithContext(ctx, cfg) if err2 != nil { a1Chan <- err2 } else { @@ -2756,11 +2761,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, }() go func() { - a, err2 := createClientWithContext(ctx, Config{ - NetConn: udpConn2, - LoggerFactory: loggerFactory, - MaxReceiveBufferSize: 100_000, - }) + cfg := config + cfg.NetConn = udpConn2 + cfg.LoggerFactory = loggerFactory + if cfg.MaxReceiveBufferSize == 0 { + cfg.MaxReceiveBufferSize = 100_000 + } + a, err2 := createClientWithContext(ctx, cfg) if err2 != nil { a2Chan <- err2 } else { @@ -2880,6 +2887,85 @@ func TestAssociationReceiveWindow(t *testing.T) { cancel() } +func TestAssociationFastRtxWnd(t *testing.T) { + udp1, udp2 := createUDPConnPair() + a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000}) + require.NoError(t, err) + defer noErrorClose(t, a2.Close) + defer noErrorClose(t, a1.Close) + s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) + require.NoError(t, err) + defer noErrorClose(t, s1.Close) + _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + require.NoError(t, err) + _, err = a2.AcceptStream() + require.NoError(t, err) + + a1.rtoMgr.setRTO(1000, true) + // ack the hello packet + time.Sleep(1 * time.Second) + + require.Equal(t, a1.minCwnd, a1.CWND()) + + var shouldDrop atomic.Bool + var dropCounter atomic.Uint32 + dbConn1, ok := udp1.(*dumbConn2) + require.True(t, ok) + dbConn2, ok := udp2.(*dumbConn2) + require.True(t, ok) + dbConn1.remoteInboundHandler = func(packet []byte) { + if !shouldDrop.Load() { + dbConn2.inboundHandler(packet) + } else { + dropCounter.Add(1) + } + } + + shouldDrop.Store(true) + // send packets and dropped + buf := make([]byte, 1000) + for i := 0; i < 10; i++ { + _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) + require.NoError(t, err) + } + + require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) + // send packets to trigger fast retransmit + shouldDrop.Store(false) + + require.Zero(t, a1.stats.getNumFastRetrans()) + require.False(t, a1.inFastRecovery) + + // wait SACK + sackCh := make(chan []byte, 1) + dbConn2.remoteInboundHandler = func(buf []byte) { + p := &packet{} + require.NoError(t, p.unmarshal(true, buf)) + for _, c := range p.chunks { + if _, ok := c.(*chunkSelectiveAck); ok { + select { + case sackCh <- buf: + default: + } + return + } + } + } + // wait sack to trigger fast retransmit + for i := 0; i < 3; i++ { + _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) + require.NoError(t, err) + dbConn1.inboundHandler(<-sackCh) + } + // fast retransmit and new sack sent + require.Eventually(t, func() bool { + a1.lock.RLock() + defer a1.lock.RUnlock() + return a1.inFastRecovery + }, 5*time.Second, 10*time.Millisecond) + require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans()) +} + func TestAssociationMaxTSNOffset(t *testing.T) { udp1, udp2 := createUDPConnPair() // a1 is the association used for sending data @@ -3489,10 +3575,10 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) { require.NoError(t, a2.netConn.Close()) _, err = a1.OpenStream(1, PayloadTypeWebRTCString) - require.NoError(t, err) + require.True(t, err == nil || errors.Is(err, ErrAssociationClosed)) _, err = a2.OpenStream(1, PayloadTypeWebRTCString) - require.NoError(t, err) + require.True(t, err == nil || errors.Is(err, ErrAssociationClosed)) require.NoError(t, a1.Close()) require.NoError(t, a2.Close())