diff --git a/association.go b/association.go index 46881b1d..2198fc6f 100644 --- a/association.go +++ b/association.go @@ -105,11 +105,11 @@ const ( avgChunkSize = 500 // minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size - // see Association.getMaxTSNOffset + // see getMaxTSNOffset minTSNOffset = 2000 // maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size - // see Association.getMaxTSNOffset + // see getMaxTSNOffset maxTSNOffset = 40000 // maxReconfigRequests is the maximum number of reconfig requests we will keep outstanding maxReconfigRequests = 1000 @@ -166,7 +166,6 @@ type Association struct { state uint32 initialTSN uint32 myNextTSN uint32 // nextTSN - peerLastTSN uint32 // lastRcvdTSN minTSN2MeasureRTT uint32 // for RTT measurement willSendForwardTSN bool willRetransmitFast bool @@ -190,7 +189,7 @@ type Association struct { myMaxNumInboundStreams uint16 myMaxNumOutboundStreams uint16 myCookie *paramStateCookie - payloadQueue *payloadQueue + payloadQueue *receivePayloadQueue inflightQueue *payloadQueue pendingQueue *pendingQueue controlQueue *controlQueue @@ -333,7 +332,7 @@ func createAssociation(config Config) *Association { myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, - payloadQueue: newPayloadQueue(), + payloadQueue: newReceivePayloadQueue(getMaxTSNOffset(maxReceiveBufferSize)), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), controlQueue: newControlQueue(), @@ -1071,6 +1070,11 @@ func min32(a, b uint32) uint32 { return b } +// peerLastTSN return last received cumulative TSN +func (a *Association) peerLastTSN() uint32 { + return a.payloadQueue.getcumulativeTSN() +} + // setState atomically sets the state of the Association. // The caller should hold the lock. func (a *Association) setState(newState uint32) { @@ -1127,13 +1131,11 @@ func (a *Association) SRTT() float64 { } // getMaxTSNOffset returns the maximum offset over the current cummulative TSN that -// we are willing to enqueue. Limiting the maximum offset limits the number of -// tsns we have in the payloadQueue map. This ensures that we don't use too much space in -// the map itself. This also ensures that we keep the bytes utilized in the receive +// we are willing to enqueue. This ensures that we keep the bytes utilized in the receive // buffer within a small multiple of the user provided max receive buffer size. -func (a *Association) getMaxTSNOffset() uint32 { +func getMaxTSNOffset(maxReceiveBufferSize uint32) uint32 { // 4 is a magic number here. There is no theory behind this. - offset := (a.maxReceiveBufferSize * 4) / avgChunkSize + offset := (maxReceiveBufferSize * 4) / avgChunkSize if offset < minTSNOffset { offset = minTSNOffset } @@ -1186,7 +1188,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // is set initially by taking the peer's initial TSN, // received in the INIT or INIT ACK chunk, and // subtracting one from it. - a.peerLastTSN = i.initialTSN - 1 + a.payloadQueue.init(i.initialTSN - 1) for _, param := range i.params { switch v := param.(type) { // nolint:gocritic @@ -1260,7 +1262,7 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = i.initiateTag - a.peerLastTSN = i.initialTSN - 1 + a.payloadQueue.init(i.initialTSN - 1) if a.sourcePort != p.destinationPort || a.destinationPort != p.sourcePort { a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) @@ -1411,7 +1413,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { a.name, d.tsn, d.immediateSack, len(d.userData)) a.stats.incDATAs() - canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset()) + canPush := a.payloadQueue.canPush(d.tsn) if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { @@ -1423,14 +1425,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives - a.payloadQueue.push(d, a.peerLastTSN) + a.payloadQueue.push(d.tsn) s.handleData(d) } else { // Receive buffer is full lastTSN, ok := a.payloadQueue.getLastTSNReceived() if ok && sna32LT(d.tsn, lastTSN) { a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) - a.payloadQueue.push(d, a.peerLastTSN) + a.payloadQueue.push(d.tsn) s.handleData(d) } else { a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) @@ -1454,10 +1456,9 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) // Meaning, if peerLastTSN+1 points to a chunk that is received, // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. for { - if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { + if popOk := a.payloadQueue.pop(false); !popOk { break } - a.peerLastTSN++ for _, rstReq := range a.reconfigRequests { resp := a.resetStreamsIfAny(rstReq) @@ -1470,7 +1471,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) hasPacketLoss := (a.payloadQueue.size() > 0) if hasPacketLoss { - a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString(a.peerLastTSN)) + a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString()) } if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay { @@ -2068,8 +2069,8 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // duplicate may indicate the previous SACK was lost in the network. a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d", - a.name, c.newCumulativeTSN, a.peerLastTSN) - if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) { + a.name, c.newCumulativeTSN, a.peerLastTSN()) + if sna32LTE(c.newCumulativeTSN, a.peerLastTSN()) { a.log.Tracef("[%s] sending ack on Forward TSN", a.name) a.ackState = ackStateImmediate a.ackTimer.stop() @@ -2088,9 +2089,8 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // chunk, // Advance peerLastTSN - for sna32LT(a.peerLastTSN, c.newCumulativeTSN) { - a.payloadQueue.pop(a.peerLastTSN + 1) // may not exist - a.peerLastTSN++ + for sna32LT(a.peerLastTSN(), c.newCumulativeTSN) { + a.payloadQueue.pop(true) // may not exist } // Report new peerLastTSN value and abandoned largest SSN value to @@ -2143,7 +2143,7 @@ func (a *Association) handleReconfigParam(raw param) (*packet, error) { switch p := raw.(type) { case *paramOutgoingResetRequest: a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name) - if a.peerLastTSN < p.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { + if a.peerLastTSN() < p.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { // We have too many reconfig requests outstanding. Drop the request and let // the peer retransmit. A well behaved peer should only have 1 outstanding // reconfig request. @@ -2189,9 +2189,9 @@ func (a *Association) handleReconfigParam(raw param) (*packet, error) { // The caller should hold the lock. func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { result := reconfigResultSuccessPerformed - if sna32LTE(p.senderLastTSN, a.peerLastTSN) { + if sna32LTE(p.senderLastTSN, a.peerLastTSN()) { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", - a.name, p.senderLastTSN, a.peerLastTSN) + a.name, p.senderLastTSN, a.peerLastTSN()) for _, id := range p.streamIdentifiers { s, ok := a.streams[id] if !ok { @@ -2206,7 +2206,7 @@ func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { delete(a.reconfigRequests, p.reconfigRequestSequenceNumber) } else { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", - a.name, p.senderLastTSN, a.peerLastTSN) + a.name, p.senderLastTSN, a.peerLastTSN()) result = reconfigResultInProgress } @@ -2280,7 +2280,7 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 break // would exceeds cwnd } - if dataLen > a.rwnd { + if dataLen > a.RWND() { break // no more rwnd } @@ -2454,10 +2454,10 @@ func (a *Association) generateNextRSN() uint32 { func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { sack := &chunkSelectiveAck{} - sack.cumulativeTSNAck = a.peerLastTSN + sack.cumulativeTSNAck = a.peerLastTSN() sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() sack.duplicateTSN = a.payloadQueue.popDuplicates() - sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks(a.peerLastTSN) + sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks() return sack } diff --git a/association_test.go b/association_test.go index 0784f515..dd833351 100644 --- a/association_test.go +++ b/association_test.go @@ -1283,10 +1283,10 @@ func TestHandleForwardTSN(t *testing.T) { LoggerFactory: loggerFactory, }) a.useForwardTSN = true - prevTSN := a.peerLastTSN + prevTSN := a.peerLastTSN() fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN + 3, + newCumulativeTSN: prevTSN + 3, streams: []chunkForwardTSNStream{{identifier: 0, sequence: 0}}, } @@ -1296,7 +1296,7 @@ func TestHandleForwardTSN(t *testing.T) { delayedAckTriggered := a.delayedAckTriggered immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() - assert.Equal(t, a.peerLastTSN, prevTSN+3, "peerLastTSN should advance by 3 ") + assert.Equal(t, a.peerLastTSN(), prevTSN+3, "peerLastTSN should advance by 3 ") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") @@ -1308,20 +1308,13 @@ func TestHandleForwardTSN(t *testing.T) { LoggerFactory: loggerFactory, }) a.useForwardTSN = true - prevTSN := a.peerLastTSN + prevTSN := a.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 2, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.payloadQueue.push(a.peerLastTSN() + 2) fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN + 1, + newCumulativeTSN: a.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, @@ -1333,7 +1326,7 @@ func TestHandleForwardTSN(t *testing.T) { delayedAckTriggered := a.delayedAckTriggered immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() - assert.Equal(t, a.peerLastTSN, prevTSN+2, "peerLastTSN should advance by 3") + assert.Equal(t, a.peerLastTSN(), prevTSN+2, "peerLastTSN should advance by 3") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") @@ -1345,20 +1338,13 @@ func TestHandleForwardTSN(t *testing.T) { LoggerFactory: loggerFactory, }) a.useForwardTSN = true - prevTSN := a.peerLastTSN + prevTSN := a.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 3, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.payloadQueue.push(a.peerLastTSN() + 3) fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN + 1, + newCumulativeTSN: a.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, @@ -1369,7 +1355,7 @@ func TestHandleForwardTSN(t *testing.T) { a.lock.Lock() immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() - assert.Equal(t, a.peerLastTSN, prevTSN+1, "peerLastTSN should advance by 1") + assert.Equal(t, a.peerLastTSN(), prevTSN+1, "peerLastTSN should advance by 1") assert.True(t, immediateAckTriggered, "immediate sack should be triggered") assert.Nil(t, p, "should return nil") @@ -1381,10 +1367,10 @@ func TestHandleForwardTSN(t *testing.T) { LoggerFactory: loggerFactory, }) a.useForwardTSN = true - prevTSN := a.peerLastTSN + prevTSN := a.peerLastTSN() fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN, // old TSN + newCumulativeTSN: a.peerLastTSN(), // old TSN streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, @@ -1395,7 +1381,7 @@ func TestHandleForwardTSN(t *testing.T) { a.lock.Lock() ackState := a.ackState a.lock.Unlock() - assert.Equal(t, a.peerLastTSN, prevTSN, "peerLastTSN should not advance") + assert.Equal(t, a.peerLastTSN(), prevTSN, "peerLastTSN should not advance") assert.Equal(t, ackStateImmediate, ackState, "sack should be requested") assert.Nil(t, p, "should return nil") }) @@ -1690,7 +1676,7 @@ func TestAssocCreateNewStream(t *testing.T) { toBeIgnored := &chunkPayloadData{ beginningFragment: true, endingFragment: true, - tsn: a.peerLastTSN + 1, + tsn: a.peerLastTSN() + 1, streamIdentifier: newSI, userData: []byte("ABC"), } @@ -2482,7 +2468,7 @@ func TestAssocHandleInit(t *testing.T) { return } assert.NoError(t, err, "should succeed") - assert.Equal(t, init.initialTSN-1, a.peerLastTSN, "should match") + assert.Equal(t, init.initialTSN-1, a.peerLastTSN(), "should match") assert.Equal(t, uint16(1001), a.myMaxNumOutboundStreams, "should match") assert.Equal(t, uint16(1002), a.myMaxNumInboundStreams, "should match") assert.Equal(t, uint32(5678), a.peerVerificationTag, "should match") diff --git a/go.mod b/go.mod index 2f004cae..42fa5eb4 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/pion/sctp require ( + github.com/gammazero/deque v0.2.1 github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/transport/v3 v3.0.2 diff --git a/go.sum b/go.sum index 25392e78..92a3c047 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/payload_queue.go b/payload_queue.go index a0b1b26f..22bfb73f 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -3,81 +3,28 @@ package sctp -import ( - "fmt" - "sort" -) +import "github.com/gammazero/deque" type payloadQueue struct { - chunkMap map[uint32]*chunkPayloadData - sorted []uint32 - dupTSN []uint32 - nBytes int + chunks *deque.Deque[*chunkPayloadData] + nBytes int } func newPayloadQueue() *payloadQueue { - return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}} -} - -func (q *payloadQueue) updateSortedKeys() { - if q.sorted != nil { - return - } - - q.sorted = make([]uint32, len(q.chunkMap)) - i := 0 - for k := range q.chunkMap { - q.sorted[i] = k - i++ - } - - sort.Slice(q.sorted, func(i, j int) bool { - return sna32LT(q.sorted[i], q.sorted[j]) - }) -} - -func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool { - _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) { - return false - } - return true + return &payloadQueue{chunks: deque.New[*chunkPayloadData](128)} } func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { - q.chunkMap[p.tsn] = p + q.chunks.PushBack(p) q.nBytes += len(p.userData) - q.sorted = nil -} - -// push pushes a payload data. If the payload data is already in our queue or -// older than our cumulativeTSN marker, it will be recored as duplications, -// which can later be retrieved using popDuplicates. -func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { - _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) { - // Found the packet, log in dups - q.dupTSN = append(q.dupTSN, p.tsn) - return false - } - - q.chunkMap[p.tsn] = p - q.nBytes += len(p.userData) - q.sorted = nil - return true } // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { - q.updateSortedKeys() - - if len(q.chunkMap) > 0 && tsn == q.sorted[0] { - q.sorted = q.sorted[1:] - if c, ok := q.chunkMap[tsn]; ok { - delete(q.chunkMap, tsn) - q.nBytes -= len(c.userData) - return c, true - } + if q.chunks.Len() > 0 && tsn == q.chunks.Front().tsn { + c := q.chunks.PopFront() + q.nBytes -= len(c.userData) + return c, true } return nil, false @@ -85,65 +32,20 @@ func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { // get returns reference to chunkPayloadData with the given TSN value. func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { - c, ok := q.chunkMap[tsn] - return c, ok -} - -// popDuplicates returns an array of TSN values that were found duplicate. -func (q *payloadQueue) popDuplicates() []uint32 { - dups := q.dupTSN - q.dupTSN = []uint32{} - return dups -} - -func (q *payloadQueue) getGapAckBlocks(cumulativeTSN uint32) (gapAckBlocks []gapAckBlock) { - var b gapAckBlock - - if len(q.chunkMap) == 0 { - return []gapAckBlock{} + len := q.chunks.Len() + if len == 0 { + return nil, false } - - q.updateSortedKeys() - - for i, tsn := range q.sorted { - if i == 0 { - b.start = uint16(tsn - cumulativeTSN) - b.end = b.start - continue - } - diff := uint16(tsn - cumulativeTSN) - if b.end+1 == diff { - b.end++ - } else { - gapAckBlocks = append(gapAckBlocks, gapAckBlock{ - start: b.start, - end: b.end, - }) - b.start = diff - b.end = diff - } + head := q.chunks.Front().tsn + if tsn < head || int(tsn-head) >= len { + return nil, false } - - gapAckBlocks = append(gapAckBlocks, gapAckBlock{ - start: b.start, - end: b.end, - }) - - return gapAckBlocks -} - -func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string { - gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) - str := fmt.Sprintf("cumTSN=%d", cumulativeTSN) - for _, b := range gapAckBlocks { - str += fmt.Sprintf(",%d-%d", b.start, b.end) - } - return str + return q.chunks.At(int(tsn - head)), true } func (q *payloadQueue) markAsAcked(tsn uint32) int { var nBytesAcked int - if c, ok := q.chunkMap[tsn]; ok { + if c, ok := q.get(tsn); ok { c.acked = true c.retransmit = false nBytesAcked = len(c.userData) @@ -154,18 +56,9 @@ func (q *payloadQueue) markAsAcked(tsn uint32) int { return nBytesAcked } -func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { - q.updateSortedKeys() - - qlen := len(q.sorted) - if qlen == 0 { - return 0, false - } - return q.sorted[qlen-1], true -} - func (q *payloadQueue) markAllToRetrasmit() { - for _, c := range q.chunkMap { + for i := 0; i < q.chunks.Len(); i++ { + c := q.chunks.At(i) if c.acked || c.abandoned() { continue } @@ -178,5 +71,5 @@ func (q *payloadQueue) getNumBytes() int { } func (q *payloadQueue) size() int { - return len(q.chunkMap) + return q.chunks.Len() } diff --git a/payload_queue_test.go b/payload_queue_test.go index 399f4148..a5982b48 100644 --- a/payload_queue_test.go +++ b/payload_queue_test.go @@ -31,7 +31,6 @@ func TestPayloadQueue(t *testing.T) { assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") - assert.NotNil(t, pq.sorted, "should not be nil") } } @@ -39,10 +38,8 @@ func TestPayloadQueue(t *testing.T) { assert.Equal(t, 0, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(3, 13)) - assert.Nil(t, pq.sorted, "should be nil") assert.Equal(t, 13, pq.getNumBytes(), "total bytes mismatch") pq.pushNoCheck(makePayload(4, 14)) - assert.Nil(t, pq.sorted, "should be nil") assert.Equal(t, 27, pq.getNumBytes(), "total bytes mismatch") for i := uint32(3); i < 5; i++ { @@ -50,7 +47,6 @@ func TestPayloadQueue(t *testing.T) { assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") - assert.NotNil(t, pq.sorted, "should not be nil") } } @@ -58,69 +54,10 @@ func TestPayloadQueue(t *testing.T) { assert.Equal(t, 0, pq.size(), "item count mismatch") }) - t.Run("getGapAckBlocks", func(t *testing.T) { - pq := newPayloadQueue() - pq.push(makePayload(1, 0), 0) - pq.push(makePayload(2, 0), 0) - pq.push(makePayload(3, 0), 0) - pq.push(makePayload(4, 0), 0) - pq.push(makePayload(5, 0), 0) - pq.push(makePayload(6, 0), 0) - - gab1 := []*gapAckBlock{{start: 1, end: 6}} - gab2 := pq.getGapAckBlocks(0) - assert.NotNil(t, gab2) - assert.Len(t, gab2, 1) - - assert.Equal(t, gab1[0].start, gab2[0].start) - assert.Equal(t, gab1[0].end, gab2[0].end) - - pq.push(makePayload(8, 0), 0) - pq.push(makePayload(9, 0), 0) - - gab1 = []*gapAckBlock{{start: 1, end: 6}, {start: 8, end: 9}} - gab2 = pq.getGapAckBlocks(0) - assert.NotNil(t, gab2) - assert.Len(t, gab2, 2) - - assert.Equal(t, gab1[0].start, gab2[0].start) - assert.Equal(t, gab1[0].end, gab2[0].end) - assert.Equal(t, gab1[1].start, gab2[1].start) - assert.Equal(t, gab1[1].end, gab2[1].end) - }) - - t.Run("getLastTSNReceived", func(t *testing.T) { - pq := newPayloadQueue() - - // empty queie should return false - _, ok := pq.getLastTSNReceived() - assert.False(t, ok, "should be false") - - ok = pq.push(makePayload(20, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok := pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(20), tsn, "should match") - - // append should work - ok = pq.push(makePayload(21, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok = pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(21), tsn, "should match") - - // check if sorting applied - ok = pq.push(makePayload(19, 0), 0) - assert.True(t, ok, "should be true") - tsn, ok = pq.getLastTSNReceived() - assert.True(t, ok, "should be false") - assert.Equal(t, uint32(21), tsn, "should match") - }) - t.Run("markAllToRetrasmit", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 3; i++ { - pq.push(makePayload(uint32(i+1), 10), 0) + pq.pushNoCheck(makePayload(uint32(i+1), 10)) } pq.markAsAcked(2) pq.markAllToRetrasmit() @@ -139,7 +76,7 @@ func TestPayloadQueue(t *testing.T) { t.Run("reset retransmit flag on ack", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 4; i++ { - pq.push(makePayload(uint32(i+1), 10), 0) + pq.pushNoCheck(makePayload(uint32(i+1), 10)) } pq.markAllToRetrasmit() diff --git a/receive_payload_queue.go b/receive_payload_queue.go new file mode 100644 index 00000000..e3c36f31 --- /dev/null +++ b/receive_payload_queue.go @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" +) + +type receivePayloadQueue struct { + tailTSN uint32 + chunkSize int + tsnBitmask []uint64 + dupTSN []uint32 + maxTSNOffset uint32 + + cumulativeTSN uint32 +} + +func newReceivePayloadQueue(maxTSNOffset uint32) *receivePayloadQueue { + maxTSNOffset = ((maxTSNOffset + 63) / 64) * 64 + return &receivePayloadQueue{ + tsnBitmask: make([]uint64, maxTSNOffset/64), + maxTSNOffset: maxTSNOffset, + } +} + +func (q *receivePayloadQueue) init(cumulativeTSN uint32) { + q.cumulativeTSN = cumulativeTSN + q.tailTSN = cumulativeTSN + q.chunkSize = 0 + for i := range q.tsnBitmask { + q.tsnBitmask[i] = 0 + } + q.dupTSN = q.dupTSN[:0] +} + +func (q *receivePayloadQueue) hasChunk(tsn uint32) bool { + if q.chunkSize == 0 || sna32LTE(tsn, q.cumulativeTSN) || sna32GT(tsn, q.tailTSN) { + return false + } + + index, offset := int(tsn/64)%len(q.tsnBitmask), tsn%64 + return q.tsnBitmask[index]&(1< endTSN { + b.end = uint16(endTSN - q.cumulativeTSN) + gapAckBlocks = append(gapAckBlocks, gapAckBlock{ + start: b.start, + end: b.end, + }) + break + } + } + } + } + return gapAckBlocks +} + +func (q *receivePayloadQueue) getGapAckBlocksString() string { + gapAckBlocks := q.getGapAckBlocks() + str := fmt.Sprintf("cumTSN=%d", q.cumulativeTSN) + for _, b := range gapAckBlocks { + str += fmt.Sprintf(",%d-%d", b.start, b.end) + } + return str +} + +func (q *receivePayloadQueue) getLastTSNReceived() (uint32, bool) { + if q.chunkSize == 0 { + return 0, false + } + return q.tailTSN, true +} + +func (q *receivePayloadQueue) getcumulativeTSN() uint32 { + return q.cumulativeTSN +} + +func (q *receivePayloadQueue) size() int { + return q.chunkSize +} + +// get first non-zero bit index from val's bit range [start, end) +func getFirstNonZeroBit(val uint64, start, end int) (int, bool) { + // check all zero bit + if (val<<(64-end))>>(64-end+start) == 0 { + return 0, false + } + + for i := start; i < end; i++ { + if val&(1<>start)+1)<<(start+64-end) == 0 { + return 0, false + } + + for i := start; i < end; i++ { + if val&(1< +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReceivePayloadQueue(t *testing.T) { + maxOffset := uint32(512) + q := newReceivePayloadQueue(maxOffset) + initTSN := uint32(math.MaxUint32 - 10) + q.init(initTSN - 2) + assert.Equal(t, initTSN-2, q.getcumulativeTSN()) + assert.Zero(t, q.size()) + _, ok := q.getLastTSNReceived() + assert.False(t, ok) + assert.Empty(t, q.getGapAckBlocks()) + // force pop empy queue to advance cumulative TSN + assert.False(t, q.pop(true)) + assert.Equal(t, initTSN-1, q.getcumulativeTSN()) + assert.Zero(t, q.size()) + assert.Empty(t, q.getGapAckBlocks()) + + assert.True(t, q.push(initTSN)) + assert.False(t, q.canPush(initTSN-1)) + assert.True(t, q.canPush(initTSN+maxOffset-1)) + assert.False(t, q.canPush(initTSN+maxOffset)) + assert.False(t, q.push(initTSN+maxOffset)) + assert.Equal(t, 1, q.size()) + + gaps := q.getGapAckBlocks() + assert.EqualValues(t, []gapAckBlock{{start: uint16(1), end: uint16(1)}}, gaps) + + nextTSN := initTSN + maxOffset - 1 + assert.True(t, q.push(nextTSN)) + assert.Equal(t, 2, q.size()) + lastTSN, ok := q.getLastTSNReceived() + assert.True(t, lastTSN == nextTSN && ok, "lastTSN:%d, ok:%t", lastTSN, ok) + assert.True(t, q.hasChunk(nextTSN)) + + assert.True(t, q.pop(false)) + assert.Equal(t, 1, q.size()) + assert.Equal(t, initTSN, q.cumulativeTSN) + assert.False(t, q.pop(false)) + assert.Equal(t, initTSN, q.cumulativeTSN) + + size := q.size() + // push tsn with two gap + // tsnRange [[start,end]...] + tsnRange := [][]uint32{ + {initTSN + 5, initTSN + 6}, + {initTSN + 9, initTSN + 140}, + } + range0, range1 := tsnRange[0], tsnRange[1] + for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { + assert.True(t, q.push(tsn)) + assert.False(t, q.pop(false)) + assert.True(t, q.hasChunk(tsn)) + } + size += int(range0[1] - range0[0] + 1) + + for tsn := range1[0]; sna32LTE(tsn, range1[1]); tsn++ { + assert.True(t, q.push(tsn)) + assert.False(t, q.pop(false)) + assert.True(t, q.hasChunk(tsn)) + } + size += int(range1[1] - range1[0] + 1) + + assert.Equal(t, size, q.size()) + gaps = q.getGapAckBlocks() + assert.EqualValues(t, []gapAckBlock{ + {start: uint16(range0[0] - initTSN), end: uint16(range0[1] - initTSN)}, + {start: uint16(range1[0] - initTSN), end: uint16(range1[1] - initTSN)}, + {start: uint16(nextTSN - initTSN), end: uint16(nextTSN - initTSN)}, + }, gaps) + + // push duplicate tsns + assert.False(t, q.push(initTSN-2)) + assert.False(t, q.push(range0[0])) + assert.False(t, q.push(range0[0])) + assert.False(t, q.push(nextTSN)) + assert.False(t, q.push(initTSN+maxOffset+1)) + duplicates := q.popDuplicates() + assert.EqualValues(t, []uint32{initTSN - 2, range0[0], range0[0], nextTSN}, duplicates) + + // force pop to advance cumulativeTSN to fill the gap [initTSN, initTSN+4] + for tsn := initTSN + 1; sna32LT(tsn, range0[0]); tsn++ { + assert.False(t, q.pop(true)) + assert.Equal(t, size, q.size()) + assert.Equal(t, tsn, q.cumulativeTSN) + } + + for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { + assert.True(t, q.pop(false)) + assert.Equal(t, tsn, q.getcumulativeTSN()) + } + assert.False(t, q.pop(false)) + cumulativeTSN := q.getcumulativeTSN() + assert.Equal(t, range0[1], cumulativeTSN) + gaps = q.getGapAckBlocks() + assert.EqualValues(t, []gapAckBlock{ + {start: uint16(range1[0] - range0[1]), end: uint16(range1[1] - range0[1])}, + {start: uint16(nextTSN - range0[1]), end: uint16(nextTSN - range0[1])}, + }, gaps) + + // fill the gap with received tsn + for tsn := range0[1] + 1; sna32LT(tsn, range1[0]); tsn++ { + assert.True(t, q.push(tsn), tsn) + } + for tsn := range0[1] + 1; sna32LTE(tsn, range1[1]); tsn++ { + assert.True(t, q.pop(false)) + assert.Equal(t, tsn, q.getcumulativeTSN()) + } + assert.False(t, q.pop(false)) + assert.Equal(t, range1[1], q.getcumulativeTSN()) + gaps = q.getGapAckBlocks() + assert.EqualValues(t, []gapAckBlock{ + {start: uint16(nextTSN - range1[1]), end: uint16(nextTSN - range1[1])}, + }, gaps) + + // gap block cross end tsn + endTSN := maxOffset - 1 + for tsn := nextTSN + 1; sna32LTE(tsn, endTSN); tsn++ { + assert.True(t, q.push(tsn)) + } + gaps = q.getGapAckBlocks() + assert.EqualValues(t, []gapAckBlock{ + {start: uint16(nextTSN - range1[1]), end: uint16(endTSN - range1[1])}, + }, gaps) + + assert.NotEmpty(t, q.getGapAckBlocksString()) +} + +func TestBitfunc(t *testing.T) { + idx, ok := getFirstNonZeroBit(0xf, 0, 20) + assert.True(t, ok) + assert.Equal(t, 0, idx) + _, ok = getFirstNonZeroBit(0xf<<20, 0, 20) + assert.False(t, ok) + idx, ok = getFirstNonZeroBit(0xf<<20, 5, 25) + assert.True(t, ok) + assert.Equal(t, 20, idx) + _, ok = getFirstNonZeroBit(0xf<<20, 30, 40) + assert.False(t, ok) + _, ok = getFirstNonZeroBit(0, 0, 64) + assert.False(t, ok) + + idx, ok = getFirstZeroBit(0xf, 0, 20) + assert.True(t, ok) + assert.Equal(t, 4, idx) + idx, ok = getFirstZeroBit(0xf<<20, 0, 20) + assert.True(t, ok) + assert.Equal(t, 0, idx) + _, ok = getFirstZeroBit(0xf<<20, 20, 24) + assert.False(t, ok) + idx, ok = getFirstZeroBit(0xf<<20, 30, 40) + assert.True(t, ok) + assert.Equal(t, 30, idx) + _, ok = getFirstZeroBit(math.MaxUint64, 0, 64) + assert.False(t, ok) +}