diff --git a/association.go b/association.go index 38e26c4f..93836254 100644 --- a/association.go +++ b/association.go @@ -1755,7 +1755,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { } // The caller should hold the lock. -func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cumTSNAckPointAdvanced bool) error { +func (a *Association) processFastRetransmission(cumTSNAckPoint uint32, gapAckBlocks []gapAckBlock, htna uint32, cumTSNAckPointAdvanced bool) error { // HTNA algorithm - RFC 4960 Sec 7.2.4 // Increment missIndicator of each chunks that the SACK reported missing // when either of the following is met: @@ -1772,7 +1772,10 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cum maxTSN = htna } else { // b) increment for all TSNs reported missing - maxTSN = cumTSNAckPoint + uint32(a.inflightQueue.size()) + 1 + maxTSN = cumTSNAckPoint + if len(gapAckBlocks) > 0 { + maxTSN += uint32(gapAckBlocks[len(gapAckBlocks)-1].end) + } } for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ { @@ -1882,7 +1885,7 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { a.setRWND(d.advertisedReceiverWindowCredit - bytesOutstanding) } - err = a.processFastRetransmission(d.cumulativeTSNAck, htna, cumTSNAckPointAdvanced) + err = a.processFastRetransmission(d.cumulativeTSNAck, d.gapAckBlocks, htna, cumTSNAckPointAdvanced) if err != nil { return err } diff --git a/association_test.go b/association_test.go index 28c3f65d..def1c01c 100644 --- a/association_test.go +++ b/association_test.go @@ -2921,49 +2921,79 @@ func TestAssociationFastRtxWnd(t *testing.T) { } } + // intercept SACK + var lastSACK atomic.Pointer[chunkSelectiveAck] + dbConn2.remoteInboundHandler = func(buf []byte) { + p := &packet{} + require.NoError(t, p.unmarshal(true, buf)) + for _, c := range p.chunks { + if ack, aok := c.(*chunkSelectiveAck); aok { + lastSACK.Store(ack) + } + } + dbConn1.inboundHandler(buf) + } + + _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + require.NoError(t, err) + require.Eventually(t, func() bool { return lastSACK.Load() != nil }, 1*time.Second, 10*time.Millisecond) + shouldDrop.Store(true) // send packets and dropped - buf := make([]byte, 1000) - for i := 0; i < 10; i++ { + buf := make([]byte, 700) + for i := 0; i < 20; i++ { _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) require.NoError(t, err) } - require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) - // send packets to trigger fast retransmit - shouldDrop.Store(false) + require.Eventually(t, func() bool { return dropCounter.Load() >= 15 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) require.Zero(t, a1.stats.getNumFastRetrans()) require.False(t, a1.inFastRecovery) - // wait SACK - sackCh := make(chan []byte, 1) - dbConn2.remoteInboundHandler = func(buf []byte) { - p := &packet{} - require.NoError(t, p.unmarshal(true, buf)) - for _, c := range p.chunks { - if _, ok := c.(*chunkSelectiveAck); ok { - select { - case sackCh <- buf: - default: - } - return - } - } - } - // wait sack to trigger fast retransmit - for i := 0; i < 3; i++ { - _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) - require.NoError(t, err) - dbConn1.inboundHandler(<-sackCh) + // sack to trigger fast retransmit + ack := *(lastSACK.Load()) + ack.gapAckBlocks = []gapAckBlock{{start: 11}} + for i := 11; i < 14; i++ { + ack.gapAckBlocks[0].end = uint16(i) + pkt := a1.createPacket([]chunk{&ack}) + pktBuf, err1 := pkt.marshal(true) + require.NoError(t, err1) + dbConn1.inboundHandler(pktBuf) } - // fast retransmit and new sack sent + require.Eventually(t, func() bool { a1.lock.RLock() defer a1.lock.RUnlock() return a1.inFastRecovery }, 5*time.Second, 10*time.Millisecond) require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans()) + + // 7.2.4 b) In fast-recovery AND the Cumulative TSN Ack Point advanced + // the miss indications are incremented for all TSNs reported missing + // in the SACK. + a1.lock.Lock() + lastTSN := a1.inflightQueue.chunks.Back().tsn + lastTSNMinusTwo := lastTSN - 2 + lastChunk := a1.inflightQueue.chunks.Back() + lastChunkMinusTwo, ok := a1.inflightQueue.get(lastTSNMinusTwo) + a1.lock.Unlock() + require.True(t, ok) + require.True(t, lastTSN > ack.cumulativeTSNAck+uint32(ack.gapAckBlocks[0].end)+3) + + // sack with cumAckPoint advanced, lastTSN should not be marked as missing + ack.cumulativeTSNAck++ + end := lastTSN - 1 - ack.cumulativeTSNAck + ack.gapAckBlocks = append(ack.gapAckBlocks, gapAckBlock{start: uint16(end), end: uint16(end)}) + pkt := a1.createPacket([]chunk{&ack}) + pktBuf, err := pkt.marshal(true) + require.NoError(t, err) + dbConn1.inboundHandler(pktBuf) + require.Eventually(t, func() bool { + a1.lock.Lock() + defer a1.lock.Unlock() + return lastChunkMinusTwo.missIndicator == 1 && lastChunk.missIndicator == 0 + }, 5*time.Second, 10*time.Millisecond) } func TestAssociationMaxTSNOffset(t *testing.T) { diff --git a/queue.go b/queue.go index a6945caf..be8eebb1 100644 --- a/queue.go +++ b/queue.go @@ -47,6 +47,10 @@ func (q *queue[T]) Front() T { return q.buf[q.head] } +func (q *queue[T]) Back() T { + return q.buf[(q.tail-1+len(q.buf))%len(q.buf)] +} + func (q *queue[T]) At(i int) T { return q.buf[(q.head+i)%(len(q.buf))] }