Skip to content

Commit

Permalink
Add receive chunk tracker for better received chunk handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Apr 1, 2024
1 parent e4788a9 commit c04adbf
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 22 deletions.
12 changes: 6 additions & 6 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ type Association struct {
myMaxNumInboundStreams uint16
myMaxNumOutboundStreams uint16
myCookie *paramStateCookie
payloadQueue *payloadQueue
payloadQueue *receivedChunkTracker
inflightQueue *payloadQueue
pendingQueue *pendingQueue
controlQueue *controlQueue
Expand Down Expand Up @@ -318,7 +318,7 @@ func createAssociation(config Config) *Association {
myMaxNumOutboundStreams: math.MaxUint16,
myMaxNumInboundStreams: math.MaxUint16,

payloadQueue: newPayloadQueue(),
payloadQueue: newReceivedPacketTracker(),
inflightQueue: newPayloadQueue(),
pendingQueue: newPendingQueue(),
controlQueue: newControlQueue(),
Expand Down Expand Up @@ -1378,7 +1378,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
a.name, d.tsn, d.immediateSack, len(d.userData))
a.stats.incDATAs()

canPush := a.payloadQueue.canPush(d, a.peerLastTSN)
canPush := a.payloadQueue.canPush(d.tsn, a.peerLastTSN)
if canPush {
s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
if s == nil {
Expand All @@ -1390,14 +1390,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {

if a.getMyReceiverWindowCredit() > 0 {
// Pass the new chunk to stream level as soon as it arrives
a.payloadQueue.push(d, a.peerLastTSN)
a.payloadQueue.push(d.tsn, a.peerLastTSN)
s.handleData(d)
} else {
// Receive buffer is full
lastTSN, ok := a.payloadQueue.getLastTSNReceived()
if ok && sna32LT(d.tsn, lastTSN) {
a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
a.payloadQueue.push(d, a.peerLastTSN)
a.payloadQueue.push(d.tsn, a.peerLastTSN)
s.handleData(d)
} else {
a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
Expand All @@ -1421,7 +1421,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool)
// Meaning, if peerLastTSN+1 points to a chunk that is received,
// advance peerLastTSN until peerLastTSN+1 points to unreceived chunk.
for {
if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk {
if popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk {
break
}
a.peerLastTSN++
Expand Down
18 changes: 2 additions & 16 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1310,14 +1310,7 @@ func TestHandleForwardTSN(t *testing.T) {
prevTSN := a.peerLastTSN

// this chunk is blocked by the missing chunk at tsn=1
a.payloadQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: a.peerLastTSN + 2,
streamIdentifier: 0,
streamSequenceNumber: 1,
userData: []byte("ABC"),
}, a.peerLastTSN)
a.payloadQueue.push(a.peerLastTSN+2, a.peerLastTSN)

fwdtsn := &chunkForwardTSN{
newCumulativeTSN: a.peerLastTSN + 1,
Expand Down Expand Up @@ -1347,14 +1340,7 @@ func TestHandleForwardTSN(t *testing.T) {
prevTSN := a.peerLastTSN

// this chunk is blocked by the missing chunk at tsn=1
a.payloadQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: a.peerLastTSN + 3,
streamIdentifier: 0,
streamSequenceNumber: 1,
userData: []byte("ABC"),
}, a.peerLastTSN)
a.payloadQueue.push(a.peerLastTSN+3, a.peerLastTSN)

fwdtsn := &chunkForwardTSN{
newCumulativeTSN: a.peerLastTSN + 1,
Expand Down
157 changes: 157 additions & 0 deletions received_packet_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package sctp

import (
"fmt"
"strings"
)

// receivedChunkTracker tracks received chunks for maintaining ACK ranges
type receivedChunkTracker struct {
chunks map[uint32]struct{}
dupTSN []uint32
ranges []ackRange
}

// ackRange is a contiguous range of chunks that we have received
type ackRange struct {
start uint32
end uint32
}

func newReceivedPacketTracker() *receivedChunkTracker {
return &receivedChunkTracker{chunks: make(map[uint32]struct{})}
}

func (q *receivedChunkTracker) canPush(tsn uint32, cumulativeTSN uint32) bool {
_, ok := q.chunks[tsn]
if ok || sna32LTE(tsn, cumulativeTSN) {
return false
}
return true
}

// push pushes a payload data. If the payload data is already in our queue or
// older than our cumulativeTSN marker, it will be recorded as duplications,
// which can later be retrieved using popDuplicates.
func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool {

Check failure on line 39 in received_packet_tracker.go

View workflow job for this annotation

GitHub Actions / lint / Go

(*receivedChunkTracker).push - result 0 (bool) is never used (unparam)
_, ok := q.chunks[tsn]
if ok || sna32LTE(tsn, cumulativeTSN) {
// Found the packet, log in dups
q.dupTSN = append(q.dupTSN, tsn)
return false
}

Check warning on line 45 in received_packet_tracker.go

View check run for this annotation

Codecov / codecov/patch

received_packet_tracker.go#L42-L45

Added lines #L42 - L45 were not covered by tests
q.chunks[tsn] = struct{}{}

insert := true
var pos int
for pos = len(q.ranges) - 1; pos >= 0; pos-- {
if tsn == q.ranges[pos].end+1 {
q.ranges[pos].end++
insert = false
break
}
if tsn == q.ranges[pos].start-1 {
q.ranges[pos].start--
insert = false
break
}
if tsn > q.ranges[pos].end {
break
}
}
if insert {
// pos is at the element just before the insertion point
pos++
q.ranges = append(q.ranges, ackRange{})
copy(q.ranges[pos+1:], q.ranges[pos:])
q.ranges[pos] = ackRange{start: tsn, end: tsn}
} else {
// extended element at pos, check if we can merge it with adjacent elements
if pos-1 >= 0 {
if q.ranges[pos-1].end+1 == q.ranges[pos].start {
q.ranges[pos-1] = ackRange{
start: q.ranges[pos-1].start,
end: q.ranges[pos].end,
}
copy(q.ranges[pos:], q.ranges[pos+1:])
q.ranges = q.ranges[:len(q.ranges)-1]
// We have merged pos and pos-1 in to pos-1, update pos to reflect that.
// Not updating this won't be an error but it's nice to maintain the invariant
pos--
}
}
if pos+1 < len(q.ranges) {
if q.ranges[pos+1].start-1 == q.ranges[pos].end {
q.ranges[pos+1] = ackRange{
start: q.ranges[pos].start,
end: q.ranges[pos+1].end,
}
copy(q.ranges[pos:], q.ranges[pos+1:])
q.ranges = q.ranges[:len(q.ranges)-1]
}

Check warning on line 94 in received_packet_tracker.go

View check run for this annotation

Codecov / codecov/patch

received_packet_tracker.go#L88-L94

Added lines #L88 - L94 were not covered by tests
}
}
return true
}

// pop pops only if the oldest chunk's TSN matches the given TSN.
func (q *receivedChunkTracker) pop(tsn uint32) bool {
if len(q.ranges) == 0 || q.ranges[0].start != tsn {
return false
}
q.ranges[0].start++
if q.ranges[0].start > q.ranges[0].end {
q.ranges = q.ranges[1:]
}
delete(q.chunks, tsn)
return true
}

// popDuplicates returns an array of TSN values that were found duplicate.
func (q *receivedChunkTracker) popDuplicates() []uint32 {
dups := q.dupTSN
q.dupTSN = []uint32{}
return dups
}

// receivedPacketTracker getGapACKBlocks returns gapAckBlocks after the cummulative TSN
func (q *receivedChunkTracker) getGapAckBlocks(cumulativeTSN uint32) []gapAckBlock {
gapAckBlocks := make([]gapAckBlock, 0, len(q.ranges))
for _, ar := range q.ranges {
if ar.end > cumulativeTSN {
st := ar.start
if st < cumulativeTSN {
st = cumulativeTSN + 1
}
gapAckBlocks = append(gapAckBlocks, gapAckBlock{
start: uint16(st - cumulativeTSN),
end: uint16(ar.end - cumulativeTSN),
})
}
}
return gapAckBlocks
}

func (q *receivedChunkTracker) getGapAckBlocksString(cumulativeTSN uint32) string {
gapAckBlocks := q.getGapAckBlocks(cumulativeTSN)
sb := strings.Builder{}
sb.WriteString(fmt.Sprintf("cumTSN=%d", cumulativeTSN))
for _, b := range gapAckBlocks {
sb.WriteString(fmt.Sprintf(",%d-%d", b.start, b.end))
}
return sb.String()
}

func (q *receivedChunkTracker) getLastTSNReceived() (uint32, bool) {
if len(q.ranges) == 0 {
return 0, false
}
return q.ranges[len(q.ranges)-1].end, true
}

func (q *receivedChunkTracker) size() int {
return len(q.chunks)
}
125 changes: 125 additions & 0 deletions received_packet_tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package sctp

import (
"fmt"
"math/rand"
"testing"

"github.com/stretchr/testify/require"
)

func TestReceivedPacketTrackerPushPop(t *testing.T) {
q := newReceivedPacketTracker()
for i := uint32(1); i < 100; i++ {
q.push(i, 0)
}
// leave a gap at position 100
for i := uint32(101); i < 200; i++ {
q.push(i, 0)
}
for i := uint32(2); i < 200; i++ {
require.False(t, q.pop(i)) // all pop will fail till we pop the first tsn
}
for i := uint32(1); i < 100; i++ {
require.True(t, q.pop(i))
}
// 101 is the smallest value now
for i := uint32(102); i < 200; i++ {
require.False(t, q.pop(i))
}
q.push(100, 99)
for i := uint32(100); i < 200; i++ {
require.True(t, q.pop(i))
}

// q is empty now
require.Equal(t, q.size(), 0)
for i := uint32(0); i < 200; i++ {
require.False(t, q.pop(i))
}
}

func TestReceivedPacketTrackerGapACKBlocksStress(t *testing.T) {
testChunks := func(chunks []uint32, st uint32) {
if len(chunks) == 0 {
return
}
expected := make([]gapAckBlock, 0, len(chunks))
cr := ackRange{start: chunks[0], end: chunks[0]}
for i := 1; i < len(chunks); i++ {
if cr.end+1 != chunks[i] {
expected = append(expected, gapAckBlock{
start: uint16(cr.start - st),
end: uint16(cr.end - st),
})
cr = ackRange{start: chunks[i], end: chunks[i]}
} else {
cr.end++
}
}
expected = append(expected, gapAckBlock{
start: uint16(cr.start - st),
end: uint16(cr.end - st),
})

q := newReceivedPacketTracker()
rand.Shuffle(len(chunks), func(i, j int) {
chunks[i], chunks[j] = chunks[j], chunks[i]
})
for _, t := range chunks {
q.push(t, 0)
}
res := q.getGapAckBlocks(0)
require.Equal(t, expected, res, chunks)
}
chunks := make([]uint32, 0, 10)
for i := 1; i < (1 << 10); i++ {
for j := 0; j < 10; j++ {
if i&(1<<j) != 0 {
chunks = append(chunks, uint32(j+1))
}
}
testChunks(chunks, 0)
chunks = chunks[:0]
}
}

func TestReceivedPacketTrackerGapACKBlocksStress2(t *testing.T) {

Check failure on line 90 in received_packet_tracker_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

unnecessary leading newline (whitespace)

Check failure on line 91 in received_packet_tracker_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
tests := []struct {
chunks []uint32
cummulativeTSN uint32
result []gapAckBlock
}{
{
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000},
cummulativeTSN: 3,
result: []gapAckBlock{{1, 1}, {4, 5}, {10000 - 3, 10000 - 3}},
},
{
chunks: []uint32{3, 5, 1, 2, 7, 8, 10000},
cummulativeTSN: 3,
result: []gapAckBlock{{2, 2}, {4, 5}, {10000 - 3, 10000 - 3}},
},
{
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000},
cummulativeTSN: 0,
result: []gapAckBlock{{1, 4}, {7, 8}, {10000, 10000}},
},
}

for i, tc := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
q := newReceivedPacketTracker()
for _, t := range tc.chunks {
q.push(t, 0)
}
res := q.getGapAckBlocks(tc.cummulativeTSN)
require.Equal(t, tc.result, res)
})
}

Check failure on line 124 in received_packet_tracker_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
}

Check failure on line 125 in received_packet_tracker_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

unnecessary trailing newline (whitespace)

0 comments on commit c04adbf

Please sign in to comment.