From 278dfbb00d52c5937298d7650c913728fdfd35cd Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 1 Apr 2024 17:12:54 +0530 Subject: [PATCH] Optimise payload queue for write path --- association.go | 2 +- association_test.go | 8 +-- payload_queue.go | 140 ++++++------------------------------------ payload_queue_test.go | 83 ++++--------------------- 4 files changed, 37 insertions(+), 196 deletions(-) diff --git a/association.go b/association.go index c90d8c70..18e941f6 100644 --- a/association.go +++ b/association.go @@ -2184,7 +2184,7 @@ func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) { a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment) - a.inflightQueue.pushNoCheck(c) + a.inflightQueue.push(c) } // popPendingDataChunksToSend pops chunks from the pending queues as many as diff --git a/association_test.go b/association_test.go index 44619281..1de4a252 100644 --- a/association_test.go +++ b/association_test.go @@ -1191,7 +1191,7 @@ func TestCreateForwardTSN(t *testing.T) { a.cumulativeTSNAckPoint = 9 a.advancedPeerTSNAckPoint = 10 - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + a.inflightQueue.push(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, @@ -1218,7 +1218,7 @@ func TestCreateForwardTSN(t *testing.T) { a.cumulativeTSNAckPoint = 9 a.advancedPeerTSNAckPoint = 12 - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + a.inflightQueue.push(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, @@ -1228,7 +1228,7 @@ func TestCreateForwardTSN(t *testing.T) { nSent: 1, _abandoned: true, }) - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + a.inflightQueue.push(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 11, @@ -1238,7 +1238,7 @@ func TestCreateForwardTSN(t *testing.T) { nSent: 1, _abandoned: true, }) - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + a.inflightQueue.push(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 12, diff --git a/payload_queue.go b/payload_queue.go index e5925a51..b3f52e20 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -3,15 +3,9 @@ package sctp -import ( - "fmt" - "sort" -) - type payloadQueue struct { chunkMap map[uint32]*chunkPayloadData - sorted []uint32 - dupTSN []uint32 + tsns []uint32 nBytes int } @@ -19,67 +13,35 @@ func newPayloadQueue() *payloadQueue { return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}} } -func (q *payloadQueue) updateSortedKeys() { - if q.sorted != nil { +func (q *payloadQueue) push(p *chunkPayloadData) { + if _, ok := q.chunkMap[p.tsn]; ok { return } - - q.sorted = make([]uint32, len(q.chunkMap)) - i := 0 - for k := range q.chunkMap { - q.sorted[i] = k - i++ - } - - sort.Slice(q.sorted, func(i, j int) bool { - return sna32LT(q.sorted[i], q.sorted[j]) - }) -} - -func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { - _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) { - return false - } - return true -} - -func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) - q.sorted = nil -} - -// push pushes a payload data. If the payload data is already in our queue or -// older than our cumulativeTSN marker, it will be recored as duplications, -// which can later be retrieved using popDuplicates. -func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { - _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) { - // Found the packet, log in dups - q.dupTSN = append(q.dupTSN, p.tsn) - return false + var pos int + for pos = len(q.tsns) - 1; pos >= 0; pos-- { + if q.tsns[pos] < p.tsn { + break + } } - - q.chunkMap[p.tsn] = p - q.nBytes += len(p.userData) - q.sorted = nil - return true + pos++ + q.tsns = append(q.tsns, 0) + copy(q.tsns[pos+1:], q.tsns[pos:]) + q.tsns[pos] = p.tsn } // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { - q.updateSortedKeys() - - if len(q.chunkMap) > 0 && tsn == q.sorted[0] { - q.sorted = q.sorted[1:] - if c, ok := q.chunkMap[tsn]; ok { - delete(q.chunkMap, tsn) - q.nBytes -= len(c.userData) - return c, true - } + if len(q.tsns) == 0 || q.tsns[0] != tsn { + return nil, false + } + q.tsns = q.tsns[1:] + if c, ok := q.chunkMap[tsn]; ok { + delete(q.chunkMap, tsn) + q.nBytes -= len(c.userData) + return c, true } - return nil, false } @@ -89,58 +51,6 @@ func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { return c, ok } -// popDuplicates returns an array of TSN values that were found duplicate. -func (q *payloadQueue) popDuplicates() []uint32 { - dups := q.dupTSN - q.dupTSN = []uint32{} - return dups -} - -func (q *payloadQueue) getGapAckBlocks(cumulativeTSN uint32) (gapAckBlocks []gapAckBlock) { - var b gapAckBlock - - if len(q.chunkMap) == 0 { - return []gapAckBlock{} - } - - q.updateSortedKeys() - - for i, tsn := range q.sorted { - if i == 0 { - b.start = uint16(tsn - cumulativeTSN) - b.end = b.start - continue - } - diff := uint16(tsn - cumulativeTSN) - if b.end+1 == diff { - b.end++ - } else { - gapAckBlocks = append(gapAckBlocks, gapAckBlock{ - start: b.start, - end: b.end, - }) - b.start = diff - b.end = diff - } - } - - gapAckBlocks = append(gapAckBlocks, gapAckBlock{ - start: b.start, - end: b.end, - }) - - return gapAckBlocks -} - -func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string { - gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) - str := fmt.Sprintf("cumTSN=%d", cumulativeTSN) - for _, b := range gapAckBlocks { - str += fmt.Sprintf(",%d-%d", b.start, b.end) - } - return str -} - func (q *payloadQueue) markAsAcked(tsn uint32) int { var nBytesAcked int if c, ok := q.chunkMap[tsn]; ok { @@ -154,16 +64,6 @@ func (q *payloadQueue) markAsAcked(tsn uint32) int { return nBytesAcked } -func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { - q.updateSortedKeys() - - qlen := len(q.sorted) - if qlen == 0 { - return 0, false - } - return q.sorted[qlen-1], true -} - func (q *payloadQueue) markAllToRetrasmit() { for _, c := range q.chunkMap { if c.acked || c.abandoned() { diff --git a/payload_queue_test.go b/payload_queue_test.go index 399f4148..a8401a6f 100644 --- a/payload_queue_test.go +++ b/payload_queue_test.go @@ -14,15 +14,15 @@ func makePayload(tsn uint32, nBytes int) *chunkPayloadData { } func TestPayloadQueue(t *testing.T) { - t.Run("pushNoCheck", func(t *testing.T) { + t.Run("push", func(t *testing.T) { pq := newPayloadQueue() - pq.pushNoCheck(makePayload(0, 10)) + pq.push(makePayload(0, 10)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 1, pq.size(), "item count mismatch") - pq.pushNoCheck(makePayload(1, 11)) + pq.push(makePayload(1, 11)) assert.Equal(t, 21, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 2, pq.size(), "item count mismatch") - pq.pushNoCheck(makePayload(2, 12)) + pq.push(makePayload(2, 12)) assert.Equal(t, 33, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 3, pq.size(), "item count mismatch") @@ -31,18 +31,18 @@ func TestPayloadQueue(t *testing.T) { assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") - assert.NotNil(t, pq.sorted, "should not be nil") + assert.NotNil(t, pq.tsns, "should not be nil") } } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 0, pq.size(), "item count mismatch") - pq.pushNoCheck(makePayload(3, 13)) - assert.Nil(t, pq.sorted, "should be nil") + pq.push(makePayload(3, 13)) + assert.Len(t, pq.tsns, 1) assert.Equal(t, 13, pq.getNumBytes(), "total bytes mismatch") - pq.pushNoCheck(makePayload(4, 14)) - assert.Nil(t, pq.sorted, "should be nil") + pq.push(makePayload(4, 14)) + assert.Len(t, pq.tsns, 2) assert.Equal(t, 27, pq.getNumBytes(), "total bytes mismatch") for i := uint32(3); i < 5; i++ { @@ -50,7 +50,7 @@ func TestPayloadQueue(t *testing.T) { assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") - assert.NotNil(t, pq.sorted, "should not be nil") + assert.NotNil(t, pq.tsns, "should not be nil") } } @@ -58,69 +58,10 @@ func TestPayloadQueue(t *testing.T) { assert.Equal(t, 0, pq.size(), "item count mismatch") }) - t.Run("getGapAckBlocks", func(t *testing.T) { - pq := newPayloadQueue() - pq.push(makePayload(1, 0), 0) - pq.push(makePayload(2, 0), 0) - pq.push(makePayload(3, 0), 0) - pq.push(makePayload(4, 0), 0) - pq.push(makePayload(5, 0), 0) - pq.push(makePayload(6, 0), 0) - - gab1 := []*gapAckBlock{{start: 1, end: 6}} - gab2 := pq.getGapAckBlocks(0) - assert.NotNil(t, gab2) - assert.Len(t, gab2, 1) - - assert.Equal(t, gab1[0].start, gab2[0].start) - assert.Equal(t, gab1[0].end, gab2[0].end) - - pq.push(makePayload(8, 0), 0) - pq.push(makePayload(9, 0), 0) - - gab1 = []*gapAckBlock{{start: 1, end: 6}, {start: 8, end: 9}} - gab2 = pq.getGapAckBlocks(0) - assert.NotNil(t, gab2) - assert.Len(t, gab2, 2) - - assert.Equal(t, gab1[0].start, gab2[0].start) - assert.Equal(t, gab1[0].end, gab2[0].end) - assert.Equal(t, gab1[1].start, gab2[1].start) - assert.Equal(t, gab1[1].end, gab2[1].end) - }) - - t.Run("getLastTSNReceived", func(t *testing.T) { - pq := newPayloadQueue() - - // empty queie should return false - _, ok := pq.getLastTSNReceived() - assert.False(t, ok, "should be false") - - ok = pq.push(makePayload(20, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok := pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(20), tsn, "should match") - - // append should work - ok = pq.push(makePayload(21, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok = pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(21), tsn, "should match") - - // check if sorting applied - ok = pq.push(makePayload(19, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok = pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(21), tsn, "should match") - }) - t.Run("markAllToRetrasmit", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 3; i++ { - pq.push(makePayload(uint32(i+1), 10), 0) + pq.push(makePayload(uint32(i+1), 10)) } pq.markAsAcked(2) pq.markAllToRetrasmit() @@ -139,7 +80,7 @@ func TestPayloadQueue(t *testing.T) { t.Run("reset retransmit flag on ack", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 4; i++ { - pq.push(makePayload(uint32(i+1), 10), 0) + pq.push(makePayload(uint32(i+1), 10)) } pq.markAllToRetrasmit()