diff --git a/association.go b/association.go index 3b928820..c90d8c70 100644 --- a/association.go +++ b/association.go @@ -175,7 +175,7 @@ type Association struct { myMaxNumInboundStreams uint16 myMaxNumOutboundStreams uint16 myCookie *paramStateCookie - payloadQueue *payloadQueue + payloadQueue *receivedChunkTracker inflightQueue *payloadQueue pendingQueue *pendingQueue controlQueue *controlQueue @@ -318,7 +318,7 @@ func createAssociation(config Config) *Association { myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, - payloadQueue: newPayloadQueue(), + payloadQueue: newReceivedPacketTracker(), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), controlQueue: newControlQueue(), @@ -1378,7 +1378,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { a.name, d.tsn, d.immediateSack, len(d.userData)) a.stats.incDATAs() - canPush := a.payloadQueue.canPush(d, a.peerLastTSN) + canPush := a.payloadQueue.canPush(d.tsn, a.peerLastTSN) if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { @@ -1390,14 +1390,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives - a.payloadQueue.push(d, a.peerLastTSN) + a.payloadQueue.push(d.tsn, a.peerLastTSN) s.handleData(d) } else { // Receive buffer is full lastTSN, ok := a.payloadQueue.getLastTSNReceived() if ok && sna32LT(d.tsn, lastTSN) { a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) - a.payloadQueue.push(d, a.peerLastTSN) + a.payloadQueue.push(d.tsn, a.peerLastTSN) s.handleData(d) } else { a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) @@ -1421,7 +1421,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) // Meaning, if peerLastTSN+1 points to a chunk that is received, // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. for { - if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { + if popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { break } a.peerLastTSN++ diff --git a/association_test.go b/association_test.go index 8a50cf29..44619281 100644 --- a/association_test.go +++ b/association_test.go @@ -1310,14 +1310,7 @@ func TestHandleForwardTSN(t *testing.T) { prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 2, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.payloadQueue.push(a.peerLastTSN+2, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, @@ -1347,14 +1340,7 @@ func TestHandleForwardTSN(t *testing.T) { prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 3, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.payloadQueue.push(a.peerLastTSN+3, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, diff --git a/received_packet_tracker.go b/received_packet_tracker.go new file mode 100644 index 00000000..41d56afe --- /dev/null +++ b/received_packet_tracker.go @@ -0,0 +1,157 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" + "strings" +) + +// receivedChunkTracker tracks received chunks for maintaining ACK ranges +type receivedChunkTracker struct { + chunks map[uint32]struct{} + dupTSN []uint32 + ranges []ackRange +} + +// ackRange is a contiguous range of chunks that we have received +type ackRange struct { + start uint32 + end uint32 +} + +func newReceivedPacketTracker() *receivedChunkTracker { + return &receivedChunkTracker{chunks: make(map[uint32]struct{})} +} + +func (q *receivedChunkTracker) canPush(tsn uint32, cumulativeTSN uint32) bool { + _, ok := q.chunks[tsn] + if ok || sna32LTE(tsn, cumulativeTSN) { + return false + } + return true +} + +// push pushes a payload data. If the payload data is already in our queue or +// older than our cumulativeTSN marker, it will be recorded as duplications, +// which can later be retrieved using popDuplicates. +func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool { + _, ok := q.chunks[tsn] + if ok || sna32LTE(tsn, cumulativeTSN) { + // Found the packet, log in dups + q.dupTSN = append(q.dupTSN, tsn) + return false + } + q.chunks[tsn] = struct{}{} + + insert := true + var pos int + for pos = len(q.ranges) - 1; pos >= 0; pos-- { + if tsn == q.ranges[pos].end+1 { + q.ranges[pos].end++ + insert = false + break + } + if tsn == q.ranges[pos].start-1 { + q.ranges[pos].start-- + insert = false + break + } + if tsn > q.ranges[pos].end { + break + } + } + if insert { + // pos is at the element just before the insertion point + pos++ + q.ranges = append(q.ranges, ackRange{}) + copy(q.ranges[pos+1:], q.ranges[pos:]) + q.ranges[pos] = ackRange{start: tsn, end: tsn} + } else { + // extended element at pos, check if we can merge it with adjacent elements + if pos-1 >= 0 { + if q.ranges[pos-1].end+1 == q.ranges[pos].start { + q.ranges[pos-1] = ackRange{ + start: q.ranges[pos-1].start, + end: q.ranges[pos].end, + } + copy(q.ranges[pos:], q.ranges[pos+1:]) + q.ranges = q.ranges[:len(q.ranges)-1] + // We have merged pos and pos-1 in to pos-1, update pos to reflect that. + // Not updating this won't be an error but it's nice to maintain the invariant + pos-- + } + } + if pos+1 < len(q.ranges) { + if q.ranges[pos+1].start-1 == q.ranges[pos].end { + q.ranges[pos+1] = ackRange{ + start: q.ranges[pos].start, + end: q.ranges[pos+1].end, + } + copy(q.ranges[pos:], q.ranges[pos+1:]) + q.ranges = q.ranges[:len(q.ranges)-1] + } + } + } + return true +} + +// pop pops only if the oldest chunk's TSN matches the given TSN. +func (q *receivedChunkTracker) pop(tsn uint32) bool { + if len(q.ranges) == 0 || q.ranges[0].start != tsn { + return false + } + q.ranges[0].start++ + if q.ranges[0].start > q.ranges[0].end { + q.ranges = q.ranges[1:] + } + delete(q.chunks, tsn) + return true +} + +// popDuplicates returns an array of TSN values that were found duplicate. +func (q *receivedChunkTracker) popDuplicates() []uint32 { + dups := q.dupTSN + q.dupTSN = []uint32{} + return dups +} + +// receivedPacketTracker getGapACKBlocks returns gapAckBlocks after the cummulative TSN +func (q *receivedChunkTracker) getGapAckBlocks(cumulativeTSN uint32) []gapAckBlock { + gapAckBlocks := make([]gapAckBlock, 0, len(q.ranges)) + for _, ar := range q.ranges { + if ar.end > cumulativeTSN { + st := ar.start + if st < cumulativeTSN { + st = cumulativeTSN + 1 + } + gapAckBlocks = append(gapAckBlocks, gapAckBlock{ + start: uint16(st - cumulativeTSN), + end: uint16(ar.end - cumulativeTSN), + }) + } + } + return gapAckBlocks +} + +func (q *receivedChunkTracker) getGapAckBlocksString(cumulativeTSN uint32) string { + gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("cumTSN=%d", cumulativeTSN)) + for _, b := range gapAckBlocks { + sb.WriteString(fmt.Sprintf(",%d-%d", b.start, b.end)) + } + return sb.String() +} + +func (q *receivedChunkTracker) getLastTSNReceived() (uint32, bool) { + if len(q.ranges) == 0 { + return 0, false + } + return q.ranges[len(q.ranges)-1].end, true +} + +func (q *receivedChunkTracker) size() int { + return len(q.chunks) +} diff --git a/received_packet_tracker_test.go b/received_packet_tracker_test.go new file mode 100644 index 00000000..2affa8f1 --- /dev/null +++ b/received_packet_tracker_test.go @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReceivedPacketTrackerPushPop(t *testing.T) { + q := newReceivedPacketTracker() + for i := uint32(1); i < 100; i++ { + q.push(i, 0) + } + // leave a gap at position 100 + for i := uint32(101); i < 200; i++ { + q.push(i, 0) + } + for i := uint32(2); i < 200; i++ { + require.False(t, q.pop(i)) // all pop will fail till we pop the first tsn + } + for i := uint32(1); i < 100; i++ { + require.True(t, q.pop(i)) + } + // 101 is the smallest value now + for i := uint32(102); i < 200; i++ { + require.False(t, q.pop(i)) + } + q.push(100, 99) + for i := uint32(100); i < 200; i++ { + require.True(t, q.pop(i)) + } + + // q is empty now + require.Equal(t, q.size(), 0) + for i := uint32(0); i < 200; i++ { + require.False(t, q.pop(i)) + } +} + +func TestReceivedPacketTrackerGapACKBlocksStress(t *testing.T) { + testChunks := func(chunks []uint32, st uint32) { + if len(chunks) == 0 { + return + } + expected := make([]gapAckBlock, 0, len(chunks)) + cr := ackRange{start: chunks[0], end: chunks[0]} + for i := 1; i < len(chunks); i++ { + if cr.end+1 != chunks[i] { + expected = append(expected, gapAckBlock{ + start: uint16(cr.start - st), + end: uint16(cr.end - st), + }) + cr = ackRange{start: chunks[i], end: chunks[i]} + } else { + cr.end++ + } + } + expected = append(expected, gapAckBlock{ + start: uint16(cr.start - st), + end: uint16(cr.end - st), + }) + + q := newReceivedPacketTracker() + rand.Shuffle(len(chunks), func(i, j int) { + chunks[i], chunks[j] = chunks[j], chunks[i] + }) + for _, t := range chunks { + q.push(t, 0) + } + res := q.getGapAckBlocks(0) + require.Equal(t, expected, res, chunks) + } + chunks := make([]uint32, 0, 10) + for i := 1; i < (1 << 10); i++ { + for j := 0; j < 10; j++ { + if i&(1<