diff --git a/association.go b/association.go index 5ec597e9..466e620d 100644 --- a/association.go +++ b/association.go @@ -237,6 +237,8 @@ type Config struct { MaxMessageSize uint32 EnableZeroChecksum bool LoggerFactory logging.LoggerFactory + // RTOMax is the maximum retransmission timeout in milliseconds + RTOMax float64 } // Server accepts a SCTP stream over a conn @@ -312,7 +314,7 @@ func createAssociation(config Config) *Association { myNextRSN: tsn, minTSN2MeasureRTT: tsn, state: closed, - rtoMgr: newRTOManager(), + rtoMgr: newRTOManager(config.RTOMax), streams: map[uint16]*Stream{}, reconfigs: map[uint32]*chunkReconfig{}, reconfigRequests: map[uint32]*paramOutgoingResetRequest{}, @@ -340,11 +342,11 @@ func createAssociation(config Config) *Association { a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) a.srtt.Store(float64(0)) - a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans) - a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans) - a.t2Shutdown = newRTXTimer(timerT2Shutdown, a, noMaxRetrans) // retransmit forever - a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans) // retransmit forever - a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans) // retransmit forever + a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans, config.RTOMax) + a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans, config.RTOMax) + a.t2Shutdown = newRTXTimer(timerT2Shutdown, a, noMaxRetrans, config.RTOMax) + a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans, config.RTOMax) + a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans, config.RTOMax) a.ackTimer = newAckTimer(a) return a diff --git a/rtx_timer.go b/rtx_timer.go index ceb44301..13861813 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -12,7 +12,7 @@ import ( const ( rtoInitial float64 = 1.0 * 1000 // msec rtoMin float64 = 1.0 * 1000 // msec - rtoMax float64 = 60.0 * 1000 // msec + defaultRTOMax float64 = 60.0 * 1000 // msec rtoAlpha float64 = 0.125 rtoBeta float64 = 0.25 maxInitRetrans uint = 8 @@ -28,13 +28,20 @@ type rtoManager struct { rto float64 noUpdate bool mutex sync.RWMutex + rtoMax float64 } // newRTOManager creates a new rtoManager. -func newRTOManager() *rtoManager { - return &rtoManager{ - rto: rtoInitial, +func newRTOManager(rtoMax float64) *rtoManager { + mgr := rtoManager{ + rto: rtoInitial, + rtoMax: rtoMax, } + if mgr.rtoMax == 0 { + mgr.rtoMax = defaultRTOMax + } + return &mgr + } // setNewRTT takes a newly measured RTT then adjust the RTO in msec. @@ -55,7 +62,7 @@ func (m *rtoManager) setNewRTT(rtt float64) float64 { m.rttvar = (1-rtoBeta)*m.rttvar + rtoBeta*(math.Abs(m.srtt-rtt)) m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt } - m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), rtoMax) + m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), m.rtoMax) return m.srtt } @@ -106,6 +113,7 @@ type rtxTimer struct { stopFunc stopTimerLoop closed bool mutex sync.RWMutex + rtoMax float64 } type stopTimerLoop func() @@ -113,12 +121,19 @@ type stopTimerLoop func() // newRTXTimer creates a new retransmission timer. // if maxRetrans is set to 0, it will keep retransmitting until stop() is called. // (it will never make onRetransmissionFailure() callback. -func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint) *rtxTimer { - return &rtxTimer{ +func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, + rtoMax float64) *rtxTimer { + + timer := rtxTimer{ id: id, observer: observer, maxRetrans: maxRetrans, + rtoMax: rtoMax, + } + if timer.rtoMax == 0 { + timer.rtoMax = defaultRTOMax } + return &timer } // start starts the timer. @@ -148,7 +163,7 @@ func (t *rtxTimer) start(rto float64) bool { canceling := false for !canceling { - timeout := calculateNextTimeout(rto, nRtos) + timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) timer := time.NewTimer(time.Duration(timeout) * time.Millisecond) select { @@ -208,7 +223,7 @@ func (t *rtxTimer) isRunning() bool { return (t.stopFunc != nil) } -func calculateNextTimeout(rto float64, nRtos uint) float64 { +func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 { // RFC 4096 sec 6.3.3. Handle T3-rtx Expiration // E2) For the destination address for which the timer expires, set RTO // <- RTO * 2 ("back off the timer"). The maximum value discussed diff --git a/rtx_timer_test.go b/rtx_timer_test.go index df47a6e9..678103df 100644 --- a/rtx_timer_test.go +++ b/rtx_timer_test.go @@ -14,7 +14,7 @@ import ( func TestRTOManager(t *testing.T) { t.Run("initial values", func(t *testing.T) { - m := newRTOManager() + m := newRTOManager(0) assert.Equal(t, rtoInitial, m.rto, "should be rtoInitial") assert.Equal(t, rtoInitial, m.getRTO(), "should be rtoInitial") assert.Equal(t, float64(0), m.srtt, "should be 0") @@ -23,7 +23,7 @@ func TestRTOManager(t *testing.T) { t.Run("RTO calculation (small RTT)", func(t *testing.T) { var rto float64 - m := newRTOManager() + m := newRTOManager(0) exp := []int32{ 1800, 1500, @@ -41,7 +41,7 @@ func TestRTOManager(t *testing.T) { t.Run("RTO calculation (large RTT)", func(t *testing.T) { var rto float64 - m := newRTOManager() + m := newRTOManager(0) exp := []int32{ 60000, // capped at RTO.Max 60000, // capped at RTO.Max @@ -59,22 +59,33 @@ func TestRTOManager(t *testing.T) { t.Run("calculateNextTimeout", func(t *testing.T) { var rto float64 - rto = calculateNextTimeout(1.0, 0) + rto = calculateNextTimeout(1.0, 0, defaultRTOMax) assert.Equal(t, float64(1), rto, "should match") - rto = calculateNextTimeout(1.0, 1) + rto = calculateNextTimeout(1.0, 1, defaultRTOMax) assert.Equal(t, float64(2), rto, "should match") - rto = calculateNextTimeout(1.0, 2) + rto = calculateNextTimeout(1.0, 2, defaultRTOMax) assert.Equal(t, float64(4), rto, "should match") - rto = calculateNextTimeout(1.0, 30) + rto = calculateNextTimeout(1.0, 30, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") - rto = calculateNextTimeout(1.0, 63) + rto = calculateNextTimeout(1.0, 63, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") - rto = calculateNextTimeout(1.0, 64) + rto = calculateNextTimeout(1.0, 64, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") }) + t.Run("calculateNextTimeout w/ RTOMax", func(t *testing.T) { + var rto float64 + rto = calculateNextTimeout(1.0, 0, 2.0) + assert.Equal(t, 1.0, rto, "should match") + rto = calculateNextTimeout(1.5, 1, 2.0) + assert.Equal(t, 2.0, rto, "should match") + rto = calculateNextTimeout(1.0, 10, 2.0) + assert.Equal(t, 2.0, rto, "should match") + rto = calculateNextTimeout(1.0, 31, 1000.0) + assert.Equal(t, 1000.0, rto, "should match") + }) t.Run("reset", func(t *testing.T) { - m := newRTOManager() + m := newRTOManager(0) for i := 0; i < 10; i++ { m.setNewRTT(200) } @@ -118,7 +129,7 @@ func TestRtxTimer(t *testing.T) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) assert.False(t, rt.isRunning(), "should not be running") @@ -144,7 +155,7 @@ func TestRtxTimer(t *testing.T) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) @@ -171,7 +182,7 @@ func TestRtxTimer(t *testing.T) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) @@ -194,7 +205,7 @@ func TestRtxTimer(t *testing.T) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) @@ -221,7 +232,7 @@ func TestRtxTimer(t *testing.T) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) for i := 0; i < 1000; i++ { ok := rt.start(30) @@ -253,7 +264,7 @@ func TestRtxTimer(t *testing.T) { t.Logf("onRtxFailure: elapsed=%.03f\n", elapsed) doneCh <- true }, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) // RTO(msec) Total(msec) // 10 10 1st RTO @@ -297,7 +308,7 @@ func TestRtxTimer(t *testing.T) { onRtxFailure: func(id int) { assert.Fail(t, "timer should not fail") }, - }, 0) + }, 0, 0) // RTO(msec) Total(msec) // 10 10 1st RTO @@ -332,7 +343,7 @@ func TestRtxTimer(t *testing.T) { doneCh <- true }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) for i := 0; i < 10; i++ { rt.stop() @@ -355,7 +366,7 @@ func TestRtxTimer(t *testing.T) { rtoCount++ }, onRtxFailure: func(id int) {}, - }, pathMaxRetrans) + }, pathMaxRetrans, 0) ok := rt.start(20) assert.True(t, ok, "should be accepted")