From 1d98af60f119f0634a74d8967deb2c549c30118e Mon Sep 17 00:00:00 2001
From: cnderrauber
Date: Mon, 18 Nov 2024 17:13:35 +0800
Subject: [PATCH] Add congestion control parameters to config

The loss-based congestion control performs poorly under high bandwidth,
high RTT, and packet loss, since the congestion window collapses to
1 MTU and then increases slowly after a retransmission timeout.
Fast-recovery retransmission also causes a slow exit from fast recovery
under consecutive packet loss. This change adds parameters to the
config so users can set them to get higher throughput in such cases.
---
 association.go      |  43 ++++++++++++++----
 association_test.go | 108 +++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 132 insertions(+), 19 deletions(-)

diff --git a/association.go b/association.go
index 29e9978c..38e26c4f 100644
--- a/association.go
+++ b/association.go
@@ -212,6 +212,9 @@ type Association struct {
 	partialBytesAcked    uint32
 	inFastRecovery       bool
 	fastRecoverExitPoint uint32
+	minCwnd              uint32 // Minimum congestion window
+	fastRtxWnd           uint32 // Send window for fast retransmit
+	cwndCAStep           uint32 // Step of congestion window increase at Congestion Avoidance
 
 	// RTX & Ack timer
 	rtoMgr *rtoManager
@@ -261,8 +264,16 @@ type Config struct {
 	MaxMessageSize       uint32
 	EnableZeroChecksum   bool
 	LoggerFactory        logging.LoggerFactory
+
+	// congestion control configuration
 	// RTOMax is the maximum retransmission timeout in milliseconds
 	RTOMax float64
+	// Minimum congestion window
+	MinCwnd uint32
+	// Send window for fast retransmit
+	FastRtxWnd uint32
+	// Step of congestion window increase at Congestion Avoidance
+	CwndCAStep uint32
 }
 
 // Server accepts a SCTP stream over a conn
@@ -325,6 +336,9 @@ func createAssociation(config Config) *Association {
 		netConn:              config.NetConn,
 		maxReceiveBufferSize: maxReceiveBufferSize,
 		maxMessageSize:       maxMessageSize,
+		minCwnd:              config.MinCwnd,
+		fastRtxWnd:           config.FastRtxWnd,
+		cwndCAStep:           config.CwndCAStep,
 
 		// These two max values have us not need to follow
 		// 5.1.1 where this peer may be incapable of supporting
@@ -512,7 +526,7 @@ func (a *Association) Close() error {
 	a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent())
 	a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
 	a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived())
-	a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent())
+	a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent())
 	a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
 	a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
 	a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
@@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 	if a.willRetransmitFast {
 		a.willRetransmitFast = false
 
-		toFastRetrans := []chunk{}
+		toFastRetrans := []*chunkPayloadData{}
 		fastRetransSize := commonHeaderSize
 
+		fastRetransWnd := a.MTU()
+		if fastRetransWnd < a.fastRtxWnd {
+			fastRetransWnd = a.fastRtxWnd
+		}
 		for i := 0; ; i++ {
 			c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1)
 			if !ok {
@@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 			// packet.
 			dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
-			if a.MTU() < fastRetransSize+dataChunkSize {
+			if fastRetransWnd < fastRetransSize+dataChunkSize {
 				break
 			}
 
@@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 		}
 
 		if len(toFastRetrans) > 0 {
-			raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
-			if err != nil {
-				a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
-			} else {
+			for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) {
+				raw, err := a.marshalPacket(p)
+				if err != nil {
+					a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
+					continue
+				}
 				rawPackets = append(rawPackets, raw)
 			}
 		}
@@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 {
 }
 
 func (a *Association) setCWND(cwnd uint32) {
+	if cwnd < a.minCwnd {
+		cwnd = a.minCwnd
+	}
 	atomic.StoreUint32(&a.cwnd, cwnd)
 }
 
@@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
 		//      reset partial_bytes_acked to (partial_bytes_acked - cwnd).
 		if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
 			a.partialBytesAcked -= a.CWND()
-			a.setCWND(a.CWND() + a.MTU())
+			step := a.MTU()
+			if step < a.cwndCAStep {
+				step = a.cwndCAStep
+			}
+			a.setCWND(a.CWND() + step)
 			a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
 				a.name, a.CWND(), a.ssthresh, totalBytesAcked)
 		}
diff --git a/association_test.go b/association_test.go
index a8611318..28c3f65d 100644
--- a/association_test.go
+++ b/association_test.go
@@ -1839,6 +1839,7 @@ func TestAssocCongestionControl(t *testing.T) {
 	br := test.NewBridge()
 
 	a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize)
+	a0.cwndCAStep = 2800 // 2 mtu
 	if !assert.Nil(t, err, "failed to create associations") {
 		assert.FailNow(t, "failed due to earlier error")
 	}
@@ -2735,6 +2736,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
 }
 
 func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
+	return createAssociationPairWithConfig(udpConn1, udpConn2, Config{})
+}
+
+func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) {
 	loggerFactory := logging.NewDefaultLoggerFactory()
 
 	a1Chan := make(chan interface{})
@@ -2744,10 +2749,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
 	defer cancel()
 
 	go func() {
-		a, err2 := createClientWithContext(ctx, Config{
-			NetConn:       udpConn1,
-			LoggerFactory: loggerFactory,
-		})
+		cfg := config
+		cfg.NetConn = udpConn1
+		cfg.LoggerFactory = loggerFactory
+		a, err2 := createClientWithContext(ctx, cfg)
 		if err2 != nil {
 			a1Chan <- err2
 		} else {
@@ -2756,11 +2761,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
 	}()
 
 	go func() {
-		a, err2 := createClientWithContext(ctx, Config{
-			NetConn:              udpConn2,
-			LoggerFactory:        loggerFactory,
-			MaxReceiveBufferSize: 100_000,
-		})
+		cfg := config
+		cfg.NetConn = udpConn2
+		cfg.LoggerFactory = loggerFactory
+		if cfg.MaxReceiveBufferSize == 0 {
+			cfg.MaxReceiveBufferSize = 100_000
+		}
+		a, err2 := createClientWithContext(ctx, cfg)
 		if err2 != nil {
 			a2Chan <- err2
 		} else {
@@ -2880,6 +2887,85 @@ func TestAssociationReceiveWindow(t *testing.T) {
 	cancel()
 }
 
+func TestAssociationFastRtxWnd(t *testing.T) {
+	udp1, udp2 := createUDPConnPair()
+	a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000})
+	require.NoError(t, err)
+	defer noErrorClose(t, a2.Close)
+	defer noErrorClose(t, a1.Close)
+	s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	defer noErrorClose(t, s1.Close)
+	_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	_, err = a2.AcceptStream()
+	require.NoError(t, err)
+
+	a1.rtoMgr.setRTO(1000, true)
+	// wait for the hello packet to be acked
+	time.Sleep(1 * time.Second)
+
+	require.Equal(t, a1.minCwnd, a1.CWND())
+
+	var shouldDrop atomic.Bool
+	var dropCounter atomic.Uint32
+	dbConn1, ok := udp1.(*dumbConn2)
+	require.True(t, ok)
+	dbConn2, ok := udp2.(*dumbConn2)
+	require.True(t, ok)
+	dbConn1.remoteInboundHandler = func(packet []byte) {
+		if !shouldDrop.Load() {
+			dbConn2.inboundHandler(packet)
+		} else {
+			dropCounter.Add(1)
+		}
+	}
+
+	shouldDrop.Store(true)
+	// send packets that will be dropped
+	buf := make([]byte, 1000)
+	for i := 0; i < 10; i++ {
+		_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
+		require.NoError(t, err)
+	}
+
+	require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load())
+	// stop dropping; the following packets will trigger fast retransmit
+	shouldDrop.Store(false)
+
+	require.Zero(t, a1.stats.getNumFastRetrans())
+	require.False(t, a1.inFastRecovery)
+
+	// intercept SACKs sent by a2
+	sackCh := make(chan []byte, 1)
+	dbConn2.remoteInboundHandler = func(buf []byte) {
+		p := &packet{}
+		require.NoError(t, p.unmarshal(true, buf))
+		for _, c := range p.chunks {
+			if _, ok := c.(*chunkSelectiveAck); ok {
+				select {
+				case sackCh <- buf:
+				default:
+				}
+				return
+			}
+		}
+	}
+	// feed the SACKs back to a1 to trigger fast retransmit
+	for i := 0; i < 3; i++ {
+		_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
+		require.NoError(t, err)
+		dbConn1.inboundHandler(<-sackCh)
+	}
+	// fast retransmit is triggered and a1 enters fast recovery
+	require.Eventually(t, func() bool {
+		a1.lock.RLock()
+		defer a1.lock.RUnlock()
+		return a1.inFastRecovery
+	}, 5*time.Second, 10*time.Millisecond)
+	require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans())
+}
+
 func TestAssociationMaxTSNOffset(t *testing.T) {
 	udp1, udp2 := createUDPConnPair()
 	// a1 is the association used for sending data
@@ -3489,10 +3575,10 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
 	require.NoError(t, a2.netConn.Close())
 
 	_, err = a1.OpenStream(1, PayloadTypeWebRTCString)
-	require.NoError(t, err)
+	require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))
 
 	_, err = a2.OpenStream(1, PayloadTypeWebRTCString)
-	require.NoError(t, err)
+	require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))
 
 	require.NoError(t, a1.Close())
 	require.NoError(t, a2.Close())
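
Usage note (not part of the patch): below is a minimal sketch of how a library user could wire up the new Config knobs. It assumes the existing sctp.Client constructor and pion/logging's default logger factory, both outside this diff; the package name, the helper name newTunedAssociation, and the parameter passing are hypothetical, and the byte values simply mirror the ones used in the tests above (14000 for the windows, 2800 as a "2 mtu" step), not recommendations.

package example

import (
	"net"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

// newTunedAssociation is a hypothetical helper that opens an SCTP
// association over an already-dialed conn with the congestion control
// parameters added by this patch.
func newTunedAssociation(conn net.Conn) (*sctp.Association, error) {
	return sctp.Client(sctp.Config{
		NetConn:       conn,
		LoggerFactory: logging.NewDefaultLoggerFactory(),

		// MinCwnd: setCWND never lets cwnd fall below this value, so a
		// retransmission timeout does not restart slow start from 1 MTU.
		MinCwnd: 14000,
		// FastRtxWnd: fast retransmit may send up to this many bytes per
		// SACK instead of a single MTU, so fast recovery exits sooner.
		FastRtxWnd: 14000,
		// CwndCAStep: cwnd grows by this step (at least one MTU) per
		// congestion avoidance round.
		CwndCAStep: 2800,
	})
}

All three fields default to zero, in which case the existing behavior (1 MTU minimum window, 1 MTU fast-retransmit budget, 1 MTU congestion-avoidance step) is unchanged, since each value is only used when it exceeds the MTU-based default.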