-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add receive chunk tracker for better received chunk handling
- Loading branch information
Showing
4 changed files
with
290 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> | ||
// SPDX-License-Identifier: MIT | ||
|
||
package sctp | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// receivedChunkTracker tracks received chunks for maintaining ACK ranges | ||
type receivedChunkTracker struct { | ||
chunks map[uint32]struct{} | ||
dupTSN []uint32 | ||
ranges []ackRange | ||
} | ||
|
||
// ackRange is a contiguous range of chunks that we have received | ||
type ackRange struct { | ||
start uint32 | ||
end uint32 | ||
} | ||
|
||
func newReceivedPacketTracker() *receivedChunkTracker { | ||
return &receivedChunkTracker{chunks: make(map[uint32]struct{})} | ||
} | ||
|
||
func (q *receivedChunkTracker) canPush(tsn uint32, cumulativeTSN uint32) bool { | ||
_, ok := q.chunks[tsn] | ||
if ok || sna32LTE(tsn, cumulativeTSN) { | ||
return false | ||
} | ||
return true | ||
} | ||
|
||
// push pushes a payload data. If the payload data is already in our queue or | ||
// older than our cumulativeTSN marker, it will be recorded as duplications, | ||
// which can later be retrieved using popDuplicates. | ||
func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool { | ||
_, ok := q.chunks[tsn] | ||
if ok || sna32LTE(tsn, cumulativeTSN) { | ||
// Found the packet, log in dups | ||
q.dupTSN = append(q.dupTSN, tsn) | ||
return false | ||
} | ||
q.chunks[tsn] = struct{}{} | ||
|
||
insert := true | ||
var pos int | ||
for pos = len(q.ranges) - 1; pos >= 0; pos-- { | ||
if tsn == q.ranges[pos].end+1 { | ||
q.ranges[pos].end++ | ||
insert = false | ||
break | ||
} | ||
if tsn == q.ranges[pos].start-1 { | ||
q.ranges[pos].start-- | ||
insert = false | ||
break | ||
} | ||
if tsn > q.ranges[pos].end { | ||
break | ||
} | ||
} | ||
if insert { | ||
// pos is at the element just before the insertion point | ||
pos++ | ||
q.ranges = append(q.ranges, ackRange{}) | ||
copy(q.ranges[pos+1:], q.ranges[pos:]) | ||
q.ranges[pos] = ackRange{start: tsn, end: tsn} | ||
} else { | ||
// extended element at pos, check if we can merge it with adjacent elements | ||
if pos-1 >= 0 { | ||
if q.ranges[pos-1].end+1 == q.ranges[pos].start { | ||
q.ranges[pos-1] = ackRange{ | ||
start: q.ranges[pos-1].start, | ||
end: q.ranges[pos].end, | ||
} | ||
copy(q.ranges[pos:], q.ranges[pos+1:]) | ||
q.ranges = q.ranges[:len(q.ranges)-1] | ||
// We have merged pos and pos-1 in to pos-1, update pos to reflect that. | ||
// Not updating this won't be an error but it's nice to maintain the invariant | ||
pos-- | ||
} | ||
} | ||
if pos+1 < len(q.ranges) { | ||
if q.ranges[pos+1].start-1 == q.ranges[pos].end { | ||
q.ranges[pos+1] = ackRange{ | ||
start: q.ranges[pos].start, | ||
end: q.ranges[pos+1].end, | ||
} | ||
copy(q.ranges[pos:], q.ranges[pos+1:]) | ||
q.ranges = q.ranges[:len(q.ranges)-1] | ||
} | ||
} | ||
} | ||
return true | ||
} | ||
|
||
// pop pops only if the oldest chunk's TSN matches the given TSN. | ||
func (q *receivedChunkTracker) pop(tsn uint32) bool { | ||
if len(q.ranges) == 0 || q.ranges[0].start != tsn { | ||
return false | ||
} | ||
q.ranges[0].start++ | ||
if q.ranges[0].start > q.ranges[0].end { | ||
q.ranges = q.ranges[1:] | ||
} | ||
delete(q.chunks, tsn) | ||
return true | ||
} | ||
|
||
// popDuplicates returns an array of TSN values that were found duplicate. | ||
func (q *receivedChunkTracker) popDuplicates() []uint32 { | ||
dups := q.dupTSN | ||
q.dupTSN = []uint32{} | ||
return dups | ||
} | ||
|
||
// receivedPacketTracker getGapACKBlocks returns gapAckBlocks after the cummulative TSN | ||
func (q *receivedChunkTracker) getGapAckBlocks(cumulativeTSN uint32) []gapAckBlock { | ||
gapAckBlocks := make([]gapAckBlock, 0, len(q.ranges)) | ||
for _, ar := range q.ranges { | ||
if ar.end > cumulativeTSN { | ||
st := ar.start | ||
if st < cumulativeTSN { | ||
st = cumulativeTSN + 1 | ||
} | ||
gapAckBlocks = append(gapAckBlocks, gapAckBlock{ | ||
start: uint16(st - cumulativeTSN), | ||
end: uint16(ar.end - cumulativeTSN), | ||
}) | ||
} | ||
} | ||
return gapAckBlocks | ||
} | ||
|
||
func (q *receivedChunkTracker) getGapAckBlocksString(cumulativeTSN uint32) string { | ||
gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) | ||
sb := strings.Builder{} | ||
sb.WriteString(fmt.Sprintf("cumTSN=%d", cumulativeTSN)) | ||
for _, b := range gapAckBlocks { | ||
sb.WriteString(fmt.Sprintf(",%d-%d", b.start, b.end)) | ||
} | ||
return sb.String() | ||
} | ||
|
||
func (q *receivedChunkTracker) getLastTSNReceived() (uint32, bool) { | ||
if len(q.ranges) == 0 { | ||
return 0, false | ||
} | ||
return q.ranges[len(q.ranges)-1].end, true | ||
} | ||
|
||
func (q *receivedChunkTracker) size() int { | ||
return len(q.chunks) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> | ||
// SPDX-License-Identifier: MIT | ||
|
||
package sctp | ||
|
||
import ( | ||
"fmt" | ||
"math/rand" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestReceivedPacketTrackerPushPop(t *testing.T) { | ||
q := newReceivedPacketTracker() | ||
for i := uint32(1); i < 100; i++ { | ||
q.push(i, 0) | ||
} | ||
// leave a gap at position 100 | ||
for i := uint32(101); i < 200; i++ { | ||
q.push(i, 0) | ||
} | ||
for i := uint32(2); i < 200; i++ { | ||
require.False(t, q.pop(i)) // all pop will fail till we pop the first tsn | ||
} | ||
for i := uint32(1); i < 100; i++ { | ||
require.True(t, q.pop(i)) | ||
} | ||
// 101 is the smallest value now | ||
for i := uint32(102); i < 200; i++ { | ||
require.False(t, q.pop(i)) | ||
} | ||
q.push(100, 99) | ||
for i := uint32(100); i < 200; i++ { | ||
require.True(t, q.pop(i)) | ||
} | ||
|
||
// q is empty now | ||
require.Equal(t, q.size(), 0) | ||
for i := uint32(0); i < 200; i++ { | ||
require.False(t, q.pop(i)) | ||
} | ||
} | ||
|
||
func TestReceivedPacketTrackerGapACKBlocksStress(t *testing.T) { | ||
testChunks := func(chunks []uint32, st uint32) { | ||
if len(chunks) == 0 { | ||
return | ||
} | ||
expected := make([]gapAckBlock, 0, len(chunks)) | ||
cr := ackRange{start: chunks[0], end: chunks[0]} | ||
for i := 1; i < len(chunks); i++ { | ||
if cr.end+1 != chunks[i] { | ||
expected = append(expected, gapAckBlock{ | ||
start: uint16(cr.start - st), | ||
end: uint16(cr.end - st), | ||
}) | ||
cr = ackRange{start: chunks[i], end: chunks[i]} | ||
} else { | ||
cr.end++ | ||
} | ||
} | ||
expected = append(expected, gapAckBlock{ | ||
start: uint16(cr.start - st), | ||
end: uint16(cr.end - st), | ||
}) | ||
|
||
q := newReceivedPacketTracker() | ||
rand.Shuffle(len(chunks), func(i, j int) { | ||
chunks[i], chunks[j] = chunks[j], chunks[i] | ||
}) | ||
for _, t := range chunks { | ||
q.push(t, 0) | ||
} | ||
res := q.getGapAckBlocks(0) | ||
require.Equal(t, expected, res, chunks) | ||
} | ||
chunks := make([]uint32, 0, 10) | ||
for i := 1; i < (1 << 10); i++ { | ||
for j := 0; j < 10; j++ { | ||
if i&(1<<j) != 0 { | ||
chunks = append(chunks, uint32(j+1)) | ||
} | ||
} | ||
testChunks(chunks, 0) | ||
chunks = chunks[:0] | ||
} | ||
} | ||
|
||
func TestReceivedPacketTrackerGapACKBlocksStress2(t *testing.T) { | ||
|
||
tests := []struct { | ||
chunks []uint32 | ||
cummulativeTSN uint32 | ||
result []gapAckBlock | ||
}{ | ||
{ | ||
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000}, | ||
cummulativeTSN: 3, | ||
result: []gapAckBlock{{1, 1}, {4, 5}, {10000 - 3, 10000 - 3}}, | ||
}, | ||
{ | ||
chunks: []uint32{3, 5, 1, 2, 7, 8, 10000}, | ||
cummulativeTSN: 3, | ||
result: []gapAckBlock{{2, 2}, {4, 5}, {10000 - 3, 10000 - 3}}, | ||
}, | ||
{ | ||
chunks: []uint32{3, 4, 1, 2, 7, 8, 10000}, | ||
cummulativeTSN: 0, | ||
result: []gapAckBlock{{1, 4}, {7, 8}, {10000, 10000}}, | ||
}, | ||
} | ||
|
||
for i, tc := range tests { | ||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { | ||
q := newReceivedPacketTracker() | ||
for _, t := range tc.chunks { | ||
q.push(t, 0) | ||
} | ||
res := q.getGapAckBlocks(tc.cummulativeTSN) | ||
require.Equal(t, tc.result, res) | ||
}) | ||
} | ||
|
||
} | ||