diff --git a/association_test.go b/association_test.go index 28c3f65d..5c80cda9 100644 --- a/association_test.go +++ b/association_test.go @@ -2921,49 +2921,79 @@ func TestAssociationFastRtxWnd(t *testing.T) { } } + // intercept SACK + var lastSACK atomic.Pointer[chunkSelectiveAck] + dbConn2.remoteInboundHandler = func(buf []byte) { + p := &packet{} + require.NoError(t, p.unmarshal(true, buf)) + for _, c := range p.chunks { + if ack, ok := c.(*chunkSelectiveAck); ok { + lastSACK.Store(ack) + } + } + dbConn1.inboundHandler(buf) + } + + _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + require.NoError(t, err) + require.Eventually(t, func() bool { return lastSACK.Load() != nil }, 1*time.Second, 10*time.Millisecond) + shouldDrop.Store(true) // send packets and dropped - buf := make([]byte, 1000) - for i := 0; i < 10; i++ { + buf := make([]byte, 700) + for i := 0; i < 20; i++ { _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) require.NoError(t, err) } - require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) - // send packets to trigger fast retransmit - shouldDrop.Store(false) + require.Eventually(t, func() bool { return dropCounter.Load() >= 15 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) require.Zero(t, a1.stats.getNumFastRetrans()) require.False(t, a1.inFastRecovery) - // wait SACK - sackCh := make(chan []byte, 1) - dbConn2.remoteInboundHandler = func(buf []byte) { - p := &packet{} - require.NoError(t, p.unmarshal(true, buf)) - for _, c := range p.chunks { - if _, ok := c.(*chunkSelectiveAck); ok { - select { - case sackCh <- buf: - default: - } - return - } - } - } - // wait sack to trigger fast retransmit - for i := 0; i < 3; i++ { - _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) + // sack to trigger fast retransmit + ack := *(lastSACK.Load()) + ack.gapAckBlocks = []gapAckBlock{{start: 11}} + for i := 11; i < 14; i++ { + ack.gapAckBlocks[0].end = uint16(i) + pkt := a1.createPacket([]chunk{&ack}) + pktBuf, err := pkt.marshal(true) require.NoError(t, err) - dbConn1.inboundHandler(<-sackCh) + dbConn1.inboundHandler(pktBuf) } - // fast retransmit and new sack sent + require.Eventually(t, func() bool { a1.lock.RLock() defer a1.lock.RUnlock() return a1.inFastRecovery }, 5*time.Second, 10*time.Millisecond) require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans()) + + // 7.2.4 b) In fast-recovery AND the Cumulative TSN Ack Point advanced + // the miss indications are incremented for all TSNs reported missing + // in the SACK. + a1.lock.Lock() + lastTSN := a1.inflightQueue.chunks.Back().tsn + lastTSNMinusTwo := lastTSN - 2 + lastChunk := a1.inflightQueue.chunks.Back() + lastChunkMinusTwo, ok := a1.inflightQueue.get(lastTSNMinusTwo) + a1.lock.Unlock() + require.True(t, ok) + require.True(t, lastTSN > ack.cumulativeTSNAck+uint32(ack.gapAckBlocks[0].end)+3) + + // sack with cumAckPoint advanced, lastTSN should not be marked as missing + ack.cumulativeTSNAck++ + end := lastTSN - 1 - ack.cumulativeTSNAck + ack.gapAckBlocks = append(ack.gapAckBlocks, gapAckBlock{start: uint16(end), end: uint16(end)}) + pkt := a1.createPacket([]chunk{&ack}) + pktBuf, err := pkt.marshal(true) + require.NoError(t, err) + dbConn1.inboundHandler(pktBuf) + require.Eventually(t, func() bool { + a1.lock.Lock() + defer a1.lock.Unlock() + return lastChunkMinusTwo.missIndicator == 1 && lastChunk.missIndicator == 0 + }, 5*time.Second, 10*time.Millisecond, "last %d, last-1 %d", lastChunk.missIndicator, lastChunkMinusTwo.missIndicator) } func TestAssociationMaxTSNOffset(t *testing.T) { diff --git a/queue.go b/queue.go index a6945caf..be8eebb1 100644 --- a/queue.go +++ b/queue.go @@ -47,6 +47,10 @@ func (q *queue[T]) Front() T { return q.buf[q.head] } +func (q *queue[T]) Back() T { + return q.buf[(q.tail-1+len(q.buf))%len(q.buf)] +} + func (q *queue[T]) At(i int) T { return q.buf[(q.head+i)%(len(q.buf))] }