Add congestion control parameters to config #354

Merged
1 commit, merged on Nov 20, 2024
association.go (43 changes: 35 additions & 8 deletions)
@@ -212,6 +212,9 @@
partialBytesAcked uint32
inFastRecovery bool
fastRecoverExitPoint uint32
minCwnd uint32 // Minimum congestion window
fastRtxWnd uint32 // Send window for fast retransmit
cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance

// RTX & Ack timer
rtoMgr *rtoManager
@@ -261,8 +264,16 @@
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory

// congestion control configuration
// RTOMax is the maximum retransmission timeout in milliseconds
RTOMax float64
// Minimum congestion window
MinCwnd uint32
// Send window for fast retransmit
FastRtxWnd uint32
// Step of congestion window increase at Congestion Avoidance
CwndCAStep uint32
}

// Server accepts a SCTP stream over a conn
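The new knobs are plain Config fields, so they can be set wherever an association is created via Client or Server. A minimal usage sketch; the helper name and the numeric values are illustrative only, and it assumes an already-established net.Conn:

package example

import (
	"net"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

// dialWithCongestionTuning is a hypothetical helper showing how the new
// congestion control fields are passed alongside the existing ones.
func dialWithCongestionTuning(conn net.Conn) (*sctp.Association, error) {
	return sctp.Client(sctp.Config{
		NetConn:       conn,
		LoggerFactory: logging.NewDefaultLoggerFactory(),

		RTOMax:     5000,  // cap the retransmission timeout at 5000 ms
		MinCwnd:    14000, // lower bound for the congestion window, in bytes
		FastRtxWnd: 14000, // send window used for fast retransmit, in bytes
		CwndCAStep: 2800,  // cwnd increase per update during congestion avoidance, in bytes
	})
}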
@@ -325,6 +336,9 @@
netConn: config.NetConn,
maxReceiveBufferSize: maxReceiveBufferSize,
maxMessageSize: maxMessageSize,
minCwnd: config.MinCwnd,
fastRtxWnd: config.FastRtxWnd,
cwndCAStep: config.CwndCAStep,

// These two max values have us not need to follow
// 5.1.1 where this peer may be incapable of supporting
@@ -512,7 +526,7 @@
a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent())
a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived())
a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent())
a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent())
a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
@@ -803,9 +817,13 @@
if a.willRetransmitFast {
a.willRetransmitFast = false

toFastRetrans := []chunk{}
toFastRetrans := []*chunkPayloadData{}
fastRetransSize := commonHeaderSize

fastRetransWnd := a.MTU()
if fastRetransWnd < a.fastRtxWnd {
fastRetransWnd = a.fastRtxWnd
}
for i := 0; ; i++ {
c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1)
if !ok {
@@ -831,7 +849,7 @@
// packet.

dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
if a.MTU() < fastRetransSize+dataChunkSize {
if fastRetransWnd < fastRetransSize+dataChunkSize {
break
}

@@ -845,10 +863,12 @@
}

if len(toFastRetrans) > 0 {
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) {
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
continue

}
rawPackets = append(rawPackets, raw)
}
}
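Two behavioural changes land in this hunk: the fast-retransmit budget grows from a single MTU to the larger of the MTU and FastRtxWnd, and the collected chunks are now bundled into as many packets as that budget requires rather than being forced into one. A small sketch of the budget computation, mirroring the code above (the helper name is made up for illustration):

// effectiveFastRtxWnd mirrors the logic above: the configured FastRtxWnd only
// takes effect when it exceeds one MTU, otherwise a single MTU is used.
func effectiveFastRtxWnd(mtu, fastRtxWnd uint32) uint32 {
	if fastRtxWnd > mtu {
		return fastRtxWnd
	}
	return mtu
}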
@@ -1115,6 +1135,9 @@
}

func (a *Association) setCWND(cwnd uint32) {
if cwnd < a.minCwnd {
cwnd = a.minCwnd
}
atomic.StoreUint32(&a.cwnd, cwnd)
}

@@ -1720,7 +1743,11 @@
// reset partial_bytes_acked to (partial_bytes_acked - cwnd).
if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
a.partialBytesAcked -= a.CWND()
a.setCWND(a.CWND() + a.MTU())
step := a.MTU()
if step < a.cwndCAStep {
step = a.cwndCAStep
}
a.setCWND(a.CWND() + step)
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
a.name, a.CWND(), a.ssthresh, totalBytesAcked)
}
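The congestion-avoidance increment follows the same pattern: once partial_bytes_acked reaches a full cwnd, the window grows by the larger of one MTU and CwndCAStep, and setCWND then enforces the MinCwnd floor. A rough sketch of the combined update, with an illustrative helper name:

// nextCwndCA sketches the congestion avoidance update above: grow cwnd by
// max(MTU, CwndCAStep), then apply the MinCwnd floor enforced by setCWND.
func nextCwndCA(cwnd, mtu, cwndCAStep, minCwnd uint32) uint32 {
	step := mtu
	if cwndCAStep > step {
		step = cwndCAStep
	}
	next := cwnd + step
	if next < minCwnd {
		next = minCwnd
	}
	return next
}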
association_test.go (108 changes: 97 additions & 11 deletions)
@@ -1839,6 +1839,7 @@ func TestAssocCongestionControl(t *testing.T) {
br := test.NewBridge()

a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize)
a0.cwndCAStep = 2800 // 2 mtu
if !assert.Nil(t, err, "failed to create associations") {
assert.FailNow(t, "failed due to earlier error")
}
@@ -2735,6 +2736,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
}

func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
return createAssociationPairWithConfig(udpConn1, udpConn2, Config{})
}

func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) {
loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan interface{})
@@ -2744,10 +2749,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
defer cancel()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn1,
LoggerFactory: loggerFactory,
})
cfg := config
cfg.NetConn = udpConn1
cfg.LoggerFactory = loggerFactory
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a1Chan <- err2
} else {
@@ -2756,11 +2761,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
}()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn2,
LoggerFactory: loggerFactory,
MaxReceiveBufferSize: 100_000,
})
cfg := config
cfg.NetConn = udpConn2
cfg.LoggerFactory = loggerFactory
if cfg.MaxReceiveBufferSize == 0 {
cfg.MaxReceiveBufferSize = 100_000
}
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a2Chan <- err2
} else {
@@ -2880,6 +2887,85 @@ func TestAssociationReceiveWindow(t *testing.T) {
cancel()
}

func TestAssociationFastRtxWnd(t *testing.T) {
udp1, udp2 := createUDPConnPair()
a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000})
require.NoError(t, err)
defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = a2.AcceptStream()
require.NoError(t, err)

a1.rtoMgr.setRTO(1000, true)
// ack the hello packet
time.Sleep(1 * time.Second)

require.Equal(t, a1.minCwnd, a1.CWND())

var shouldDrop atomic.Bool
var dropCounter atomic.Uint32
dbConn1, ok := udp1.(*dumbConn2)
require.True(t, ok)
dbConn2, ok := udp2.(*dumbConn2)
require.True(t, ok)
dbConn1.remoteInboundHandler = func(packet []byte) {
if !shouldDrop.Load() {
dbConn2.inboundHandler(packet)
} else {
dropCounter.Add(1)
}
}

shouldDrop.Store(true)
// send packets that will be dropped
buf := make([]byte, 1000)
for i := 0; i < 10; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
}

require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load())
// send packets to trigger fast retransmit
shouldDrop.Store(false)

require.Zero(t, a1.stats.getNumFastRetrans())
require.False(t, a1.inFastRecovery)

// wait SACK
sackCh := make(chan []byte, 1)
dbConn2.remoteInboundHandler = func(buf []byte) {
p := &packet{}
require.NoError(t, p.unmarshal(true, buf))
for _, c := range p.chunks {
if _, ok := c.(*chunkSelectiveAck); ok {
select {
case sackCh <- buf:
default:
}
return
}
}
}
// wait for a SACK and deliver it to trigger fast retransmit
for i := 0; i < 3; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
dbConn1.inboundHandler(<-sackCh)
}
// fast retransmit and new sack sent
require.Eventually(t, func() bool {
a1.lock.RLock()
defer a1.lock.RUnlock()
return a1.inFastRecovery
}, 5*time.Second, 10*time.Millisecond)
require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans())
}

func TestAssociationMaxTSNOffset(t *testing.T) {
udp1, udp2 := createUDPConnPair()
// a1 is the association used for sending data
@@ -3489,10 +3575,10 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
require.NoError(t, a2.netConn.Close())

_, err = a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))

_, err = a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
require.True(t, err == nil || errors.Is(err, ErrAssociationClosed))

require.NoError(t, a1.Close())
require.NoError(t, a2.Close())