From 5384b60675e5189248815f3ece8ae20bbad6342f Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Fri, 23 Feb 2024 18:45:42 -0500 Subject: [PATCH] Various cleanup and fixes --- association.go | 199 ++++++++++++++++++++++++++------------- association_stats.go | 30 ++++++ association_test.go | 9 +- chunk_payload_data.go | 4 + control_queue.go | 16 +++- packet_test.go | 8 +- payload_queue.go | 49 +++++++++- pending_queue.go | 19 +++- reassembly_queue.go | 35 ++++++- reassembly_queue_test.go | 38 ++++++++ rtx_timer.go | 30 ++++-- vnet_test.go | 24 ++--- 12 files changed, 360 insertions(+), 101 deletions(-) diff --git a/association.go b/association.go index cf493ef7..22326914 100644 --- a/association.go +++ b/association.go @@ -19,6 +19,14 @@ import ( "github.com/pion/randutil" ) +// Port 5000 shows up in examples for SDPs +// used by WebRTC. Since this implementation +// assumes it will be used by DTLS over UDP, +// the port is only meaningful for de-multiplexing +// but more-so verification. +// Example usage: https://www.rfc-editor.org/rfc/rfc8841.html#section-13.1-2 +const defaultSCTPSrcDstPort = 5000 + // Use global random generator to properly seed by crypto grade random. var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals @@ -131,6 +139,8 @@ func getAssociationStateString(a uint32) string { // // Note: No "CLOSED" state is illustrated since if a // association is "CLOSED" its TCB SHOULD be removed. +// Note: By nature of an Association being constructed with one net.Conn, +// it is not a multi-home supporting implementation of SCTP. type Association struct { bytesReceived uint64 bytesSent uint64 @@ -177,7 +187,7 @@ type Association struct { cumulativeTSNAckPoint uint32 advancedPeerTSNAckPoint uint32 useForwardTSN bool - useZeroChecksum bool + useZeroChecksum uint32 requestZeroChecksum bool // Congestion control parameters @@ -232,6 +242,7 @@ type Association struct { // Config collects the arguments to createAssociation construction into // a single structure type Config struct { + Name string NetConn net.Conn MaxReceiveBufferSize uint32 MaxMessageSize uint32 @@ -296,11 +307,17 @@ func createAssociation(config Config) *Association { tsn := globalMathRandomGenerator.Uint32() a := &Association{ - netConn: config.NetConn, - maxReceiveBufferSize: maxReceiveBufferSize, - maxMessageSize: maxMessageSize, + netConn: config.NetConn, + maxReceiveBufferSize: maxReceiveBufferSize, + maxMessageSize: maxMessageSize, + + // These two max values have us not need to follow + // 5.1.1 where this peer may be incapable of supporting + // the requested amount of outbound streams from the other + // peer. myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, + payloadQueue: newPayloadQueue(), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), @@ -327,9 +344,12 @@ func createAssociation(config Config) *Association { silentError: ErrSilentlyDiscard, stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), + name: config.Name, } - a.name = fmt.Sprintf("%p", a) + if a.name == "" { + a.name = fmt.Sprintf("%p", a) + } // RFC 4690 Sec 7.2.1 // o The initial cwnd before DATA transmission or after a sufficiently @@ -358,13 +378,15 @@ func (a *Association) init(isClient bool) { go a.writeLoop() if isClient { - a.setState(cookieWait) - init := &chunkInit{} - init.initialTSN = a.myNextTSN - init.numOutboundStreams = a.myMaxNumOutboundStreams - init.numInboundStreams = a.myMaxNumInboundStreams - init.initiateTag = a.myVerificationTag - init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + init := &chunkInit{ + chunkInitCommon: chunkInitCommon{ + initialTSN: a.myNextTSN, + numOutboundStreams: a.myMaxNumOutboundStreams, + numInboundStreams: a.myMaxNumInboundStreams, + initiateTag: a.myVerificationTag, + advertisedReceiverWindowCredit: a.maxReceiveBufferSize, + }, + } setSupportedExtensions(&init.chunkInitCommon) if a.requestZeroChecksum { @@ -378,7 +400,9 @@ func (a *Association) init(isClient bool) { a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) } + // After sending the INIT chunk, "A" starts the T1-init timer and enters the COOKIE-WAIT state. a.t1Init.start(a.rtoMgr.getRTO()) + a.setState(cookieWait) } } @@ -389,10 +413,14 @@ func (a *Association) sendInit() error { return ErrInitNotStoredToSend } - outbound := &packet{} + a.sourcePort = defaultSCTPSrcDstPort + a.destinationPort = defaultSCTPSrcDstPort + outbound := &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + } outbound.verificationTag = a.peerVerificationTag - a.sourcePort = 5000 // Spec?? - a.destinationPort = 5000 // Spec?? outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort @@ -468,8 +496,11 @@ func (a *Association) Close() error { <-a.readLoopCloseCh a.log.Debugf("[%s] association closed", a.name) + a.log.Debugf("[%s] stats nPackets (in) : %d", a.name, a.stats.getNumPackets()) + a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent()) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) + a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) @@ -588,6 +619,7 @@ loop: break loop } atomic.AddUint64(&a.bytesSent, uint64(len(raw))) + a.stats.incPacketsSent() } if !ok { @@ -638,12 +670,12 @@ func chunkMandatoryChecksum(cc []chunk) bool { } func (a *Association) marshalPacket(p *packet) ([]byte, error) { - return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks)) + return p.marshal(!a.getUseZeroChecksum() || chunkMandatoryChecksum(p.chunks)) } func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { p := &packet{} - if !a.useZeroChecksum { + if !a.getUseZeroChecksum() { if err := p.unmarshal(true, raw); err != nil { return nil, err } @@ -653,13 +685,18 @@ func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { if err := p.unmarshal(false, raw); err != nil { return nil, err } - if chunkMandatoryChecksum(p.chunks) { - if err := p.unmarshal(true, raw); err != nil { - return nil, err - } + if !chunkMandatoryChecksum(p.chunks) { + return p, nil } - return p, nil + // nolint:godox + // TODO: This feels inefficient to unmarshal again + checkedP := &packet{} + if err := checkedP.unmarshal(true, raw); err != nil { + return nil, err + } + + return checkedP, nil } // handleInbound parses incoming raw packets @@ -675,7 +712,7 @@ func (a *Association) handleInbound(raw []byte) error { return nil } - a.handleChunkStart() + a.handleChunksStart() for _, c := range p.chunks { if err := a.handleChunk(p, c); err != nil { @@ -683,7 +720,7 @@ func (a *Association) handleInbound(raw []byte) error { } } - a.handleChunkEnd() + a.handleChunksEnd() return nil } @@ -1094,6 +1131,18 @@ func (a *Association) SRTT() float64 { return a.srtt.Load().(float64) //nolint:forcetypeassert } +func (a *Association) setUseZeroChecksum(val bool) { + toStore := 0 + if val { + toStore = 1 + } + atomic.StoreUint32(&a.useZeroChecksum, uint32(toStore)) +} + +func (a *Association) getUseZeroChecksum() bool { + return atomic.LoadUint32(&a.useZeroChecksum) == 1 +} + func setSupportedExtensions(init *chunkInitCommon) { // nolint:godox // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 @@ -1123,7 +1172,10 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state)) } - // Should we be setting any of these permanently until we've ACKed further? + // NOTE: Setting these prior to a reception of a COOKIE ECHO chunk containing + // our cookie is not compliant with https://www.rfc-editor.org/rfc/rfc9260#section-5.1-2.2.3. + // It makes us more vulnerable to resource attacks, albeit minimally so. + // https://www.rfc-editor.org/rfc/rfc9260#sec_handle_stream_parameters a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = i.initiateTag @@ -1155,21 +1207,28 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) } - outbound := &packet{} - outbound.verificationTag = a.peerVerificationTag - outbound.sourcePort = a.sourcePort - outbound.destinationPort = a.destinationPort + outbound := &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + } - initAck := &chunkInitAck{} + a.log.Debug("sending INIT ACK") - initAck.initialTSN = a.myNextTSN - initAck.numOutboundStreams = a.myMaxNumOutboundStreams - initAck.numInboundStreams = a.myMaxNumInboundStreams - initAck.initiateTag = a.myVerificationTag - initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + initAck := &chunkInitAck{ + chunkInitCommon: chunkInitCommon{ + initialTSN: a.myNextTSN, + numOutboundStreams: a.myMaxNumOutboundStreams, + numInboundStreams: a.myMaxNumInboundStreams, + initiateTag: a.myVerificationTag, + advertisedReceiverWindowCredit: a.maxReceiveBufferSize, + }, + } if a.myCookie == nil { var err error + // NOTE: This generation process is not compliant with + // 5.1.3. Generating State Cookie (https://www.rfc-editor.org/rfc/rfc4960#section-5.1.3) if a.myCookie, err = newRandomStateCookie(); err != nil { return nil, err } @@ -1179,9 +1238,9 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { if peerHasZeroChecksum { initAck.params = append(initAck.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) - a.useZeroChecksum = true + a.setUseZeroChecksum(true) } - a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum) + a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.getUseZeroChecksum()) setSupportedExtensions(&initAck.chunkInitCommon) @@ -1241,11 +1300,11 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { } } case *paramZeroChecksumAcceptable: - a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod + a.setUseZeroChecksum(v.edmid == dtlsErrorDetectionMethod) } } - a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum) + a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.getUseZeroChecksum()) if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) @@ -1310,6 +1369,8 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { return nil } + // RFC wise, these do not seem to belong here, but removing them + // causes TestCookieEchoRetransmission to break a.t1Init.stop() a.storedInit = nil @@ -1317,6 +1378,8 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { a.storedCookieEcho = nil a.setState(established) + // nolint:godox + // TODO: add COMMUNICATION UP? a.handshakeCompletedCh <- nil } @@ -1345,6 +1408,8 @@ func (a *Association) handleCookieAck() { a.storedCookieEcho = nil a.setState(established) + // nolint:godox + // TODO: add COMMUNICATION UP? a.handshakeCompletedCh <- nil } @@ -1358,9 +1423,9 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { - // silentely discard the data. (sender will retry on T3-rtx timeout) + // silently discard the data. (sender will retry on T3-rtx timeout) // see pion/sctp#30 - a.log.Debugf("discard %d", d.streamSequenceNumber) + a.log.Debugf("[%s] discard %d", a.name, d.streamSequenceNumber) return nil } @@ -1740,7 +1805,6 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { a.name, d.cumulativeTSNAck, a.cumulativeTSNAckPoint) - return nil } @@ -1821,7 +1885,6 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { } a.postprocessSack(state, cumTSNAckPointAdvanced) - return nil } @@ -2383,15 +2446,17 @@ func pack(p *packet) []*packet { return []*packet{p} } -func (a *Association) handleChunkStart() { +func (a *Association) handleChunksStart() { a.lock.Lock() defer a.lock.Unlock() + a.stats.incPackets() + a.delayedAckTriggered = false a.immediateAckTriggered = false } -func (a *Association) handleChunkEnd() { +func (a *Association) handleChunksEnd() { a.lock.Lock() defer a.lock.Unlock() @@ -2414,38 +2479,28 @@ func (a *Association) handleChunk(p *packet, c chunk) error { var err error if _, err = c.check(); err != nil { - a.log.Errorf("[ %s ] failed validating chunk: %s ", a.name, err) + a.log.Errorf("[%s] failed validating chunk: %s ", a.name, err) return nil } isAbort := false switch c := c.(type) { + // Note: We do not do the following for chunkInit, chunkInitAck, and chunkCookieEcho: + // If an endpoint receives an INIT, INIT ACK, or COOKIE ECHO chunk but decides not to establish the + // new association due to missing mandatory parameters in the received INIT or INIT ACK chunk, invalid + // parameter values, or lack of local resources, it SHOULD respond with an ABORT chunk. + + // BEGIN HANDSHAKE case *chunkInit: packets, err = a.handleInit(p, c) - case *chunkInitAck: err = a.handleInitAck(p, c) - - case *chunkAbort: - isAbort = true - err = a.handleAbort(c) - - case *chunkError: - var errStr string - for _, e := range c.errorCauses { - errStr += fmt.Sprintf("(%s)", e) - } - a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) - - case *chunkHeartbeat: - packets = a.handleHeartbeat(c) - case *chunkCookieEcho: packets = a.handleCookieEcho(c) - case *chunkCookieAck: a.handleCookieAck() + // END HANDSHAKE case *chunkPayloadData: packets = a.handleData(c) @@ -2453,18 +2508,36 @@ func (a *Association) handleChunk(p *packet, c chunk) error { case *chunkSelectiveAck: err = a.handleSack(c) + // nolint:godox + // TODO: chunkHeartbeatAck not handled? + case *chunkHeartbeat: + packets = a.handleHeartbeat(c) + + case *chunkAbort: + isAbort = true + err = a.handleAbort(c) + + case *chunkError: + var errStr string + for _, e := range c.errorCauses { + errStr += fmt.Sprintf("(%s)", e) + } + a.log.Debugf("Error chunk, with following errors: %s", errStr) + case *chunkReconfig: packets, err = a.handleReconfig(c) case *chunkForwardTSN: packets = a.handleForwardTSN(c) + // BEGIN SHUTDOWN case *chunkShutdown: a.handleShutdown(c) case *chunkShutdownAck: a.handleShutdownAck(c) case *chunkShutdownComplete: err = a.handleShutdownComplete(c) + // END SHUTDOWN default: err = ErrChunkTypeUnhandled diff --git a/association_stats.go b/association_stats.go index 60883c47..00214e20 100644 --- a/association_stats.go +++ b/association_stats.go @@ -8,13 +8,32 @@ import ( ) type associationStats struct { + nPackets uint64 + nPacketsSent uint64 nDATAs uint64 nSACKs uint64 + nSACKsSent uint64 nT3Timeouts uint64 nAckTimeouts uint64 nFastRetrans uint64 } +func (s *associationStats) incPackets() { + atomic.AddUint64(&s.nPackets, 1) +} + +func (s *associationStats) getNumPackets() uint64 { + return atomic.LoadUint64(&s.nPackets) +} + +func (s *associationStats) incPacketsSent() { + atomic.AddUint64(&s.nPacketsSent, 1) +} + +func (s *associationStats) getNumPacketsSent() uint64 { + return atomic.LoadUint64(&s.nPacketsSent) +} + func (s *associationStats) incDATAs() { atomic.AddUint64(&s.nDATAs, 1) } @@ -31,6 +50,14 @@ func (s *associationStats) getNumSACKs() uint64 { return atomic.LoadUint64(&s.nSACKs) } +func (s *associationStats) incSACKsSent() { + atomic.AddUint64(&s.nSACKsSent, 1) +} + +func (s *associationStats) getNumSACKsSent() uint64 { + return atomic.LoadUint64(&s.nSACKsSent) +} + func (s *associationStats) incT3Timeouts() { atomic.AddUint64(&s.nT3Timeouts, 1) } @@ -56,8 +83,11 @@ func (s *associationStats) getNumFastRetrans() uint64 { } func (s *associationStats) reset() { + atomic.StoreUint64(&s.nPackets, 0) + atomic.StoreUint64(&s.nPacketsSent, 0) atomic.StoreUint64(&s.nDATAs, 0) atomic.StoreUint64(&s.nSACKs, 0) + atomic.StoreUint64(&s.nSACKsSent, 0) atomic.StoreUint64(&s.nT3Timeouts, 0) atomic.StoreUint64(&s.nAckTimeouts, 0) atomic.StoreUint64(&s.nFastRetrans, 0) diff --git a/association_test.go b/association_test.go index 2b6fb733..fe7de7ba 100644 --- a/association_test.go +++ b/association_test.go @@ -257,6 +257,7 @@ func createNewAssociationPair(br *test.Bridge, ackMode int, recvBufSize uint32) go func() { a0, err0 = Client(Config{ + Name: "a0", NetConn: br.GetConn0(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, @@ -264,7 +265,11 @@ func createNewAssociationPair(br *test.Bridge, ackMode int, recvBufSize uint32) handshake0Ch <- true }() go func() { - a1, err1 = Client(Config{ + // we could have two "client"s here but it's more + // standard to have one peer starting initialization and + // another waiting for the initialization to be requested (INIT). + a1, err1 = Server(Config{ + Name: "a1", NetConn: br.GetConn1(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, @@ -1752,7 +1757,7 @@ func TestAssocT3RtxTimer(t *testing.T) { } func TestAssocCongestionControl(t *testing.T) { - // sbuf - large enobh not to be bundled + // sbuf - large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xcc) diff --git a/chunk_payload_data.go b/chunk_payload_data.go index b6f1b614..a6e50db9 100644 --- a/chunk_payload_data.go +++ b/chunk_payload_data.go @@ -206,3 +206,7 @@ func (p *chunkPayloadData) setAllInflight() { } } } + +func (p *chunkPayloadData) isFragmented() bool { + return !(p.head == nil && p.beginningFragment && p.endingFragment) +} diff --git a/control_queue.go b/control_queue.go index 5c417bf0..74dbfb50 100644 --- a/control_queue.go +++ b/control_queue.go @@ -3,9 +3,14 @@ package sctp +import ( + "sync" +) + // control queue type controlQueue struct { + mu sync.RWMutex queue []*packet } @@ -14,19 +19,28 @@ func newControlQueue() *controlQueue { } func (q *controlQueue) push(c *packet) { + q.mu.Lock() q.queue = append(q.queue, c) + q.mu.Unlock() } func (q *controlQueue) pushAll(packets []*packet) { + q.mu.Lock() q.queue = append(q.queue, packets...) + q.mu.Unlock() } func (q *controlQueue) popAll() []*packet { + q.mu.Lock() packets := q.queue q.queue = []*packet{} + q.mu.Unlock() return packets } func (q *controlQueue) size() int { - return len(q.queue) + q.mu.RLock() + size := len(q.queue) + q.mu.RUnlock() + return size } diff --git a/packet_test.go b/packet_test.go index 1a270e53..40557f3d 100644 --- a/packet_test.go +++ b/packet_test.go @@ -20,10 +20,10 @@ func TestPacketUnmarshal(t *testing.T) { switch { case err != nil: t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) - case pkt.sourcePort != 5000: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", 5000, pkt.sourcePort) - case pkt.destinationPort != 5000: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", 5000, pkt.destinationPort) + case pkt.sourcePort != defaultSCTPSrcDstPort: + t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.sourcePort) + case pkt.destinationPort != defaultSCTPSrcDstPort: + t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.destinationPort) case pkt.verificationTag != 0: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag) } diff --git a/payload_queue.go b/payload_queue.go index e5925a51..8f1e375e 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -6,9 +6,11 @@ package sctp import ( "fmt" "sort" + "sync" ) type payloadQueue struct { + mu sync.RWMutex chunkMap map[uint32]*chunkPayloadData sorted []uint32 dupTSN []uint32 @@ -19,7 +21,15 @@ func newPayloadQueue() *payloadQueue { return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}} } +//nolint:unused func (q *payloadQueue) updateSortedKeys() { + q.mu.Lock() + defer q.mu.Unlock() + + q.updateSortedKeysWithLock() +} + +func (q *payloadQueue) updateSortedKeysWithLock() { if q.sorted != nil { return } @@ -37,6 +47,9 @@ func (q *payloadQueue) updateSortedKeys() { } func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { + q.mu.RLock() + defer q.mu.RUnlock() + _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { return false @@ -45,6 +58,9 @@ func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { } func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { + q.mu.Lock() + defer q.mu.Unlock() + q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) q.sorted = nil @@ -54,6 +70,9 @@ func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { // older than our cumulativeTSN marker, it will be recored as duplications, // which can later be retrieved using popDuplicates. func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { + q.mu.Lock() + defer q.mu.Unlock() + _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { // Found the packet, log in dups @@ -69,7 +88,10 @@ func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { - q.updateSortedKeys() + q.mu.Lock() + defer q.mu.Unlock() + + q.updateSortedKeysWithLock() if len(q.chunkMap) > 0 && tsn == q.sorted[0] { q.sorted = q.sorted[1:] @@ -85,25 +107,34 @@ func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { // get returns reference to chunkPayloadData with the given TSN value. func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { + q.mu.RLock() + defer q.mu.RUnlock() + c, ok := q.chunkMap[tsn] return c, ok } // popDuplicates returns an array of TSN values that were found duplicate. func (q *payloadQueue) popDuplicates() []uint32 { + q.mu.Lock() + defer q.mu.Unlock() + dups := q.dupTSN q.dupTSN = []uint32{} return dups } func (q *payloadQueue) getGapAckBlocks(cumulativeTSN uint32) (gapAckBlocks []gapAckBlock) { + q.mu.Lock() + defer q.mu.Unlock() + var b gapAckBlock if len(q.chunkMap) == 0 { return []gapAckBlock{} } - q.updateSortedKeys() + q.updateSortedKeysWithLock() for i, tsn := range q.sorted { if i == 0 { @@ -155,7 +186,10 @@ func (q *payloadQueue) markAsAcked(tsn uint32) int { } func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { - q.updateSortedKeys() + q.mu.Lock() + defer q.mu.Unlock() + + q.updateSortedKeysWithLock() qlen := len(q.sorted) if qlen == 0 { @@ -165,6 +199,9 @@ func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { } func (q *payloadQueue) markAllToRetrasmit() { + q.mu.Lock() + defer q.mu.Unlock() + for _, c := range q.chunkMap { if c.acked || c.abandoned() { continue @@ -174,9 +211,15 @@ func (q *payloadQueue) markAllToRetrasmit() { } func (q *payloadQueue) getNumBytes() int { + q.mu.RLock() + defer q.mu.RUnlock() + return q.nBytes } func (q *payloadQueue) size() int { + q.mu.RLock() + defer q.mu.RUnlock() + return len(q.chunkMap) } diff --git a/pending_queue.go b/pending_queue.go index 6f70fea2..71e4b936 100644 --- a/pending_queue.go +++ b/pending_queue.go @@ -5,11 +5,14 @@ package sctp import ( "errors" + "sync" + "sync/atomic" ) // pendingBaseQueue type pendingBaseQueue struct { + mu sync.RWMutex queue []*chunkPayloadData } @@ -18,10 +21,14 @@ func newPendingBaseQueue() *pendingBaseQueue { } func (q *pendingBaseQueue) push(c *chunkPayloadData) { + q.mu.Lock() q.queue = append(q.queue, c) + q.mu.Unlock() } func (q *pendingBaseQueue) pop() *chunkPayloadData { + q.mu.Lock() + defer q.mu.Unlock() if len(q.queue) == 0 { return nil } @@ -31,6 +38,8 @@ func (q *pendingBaseQueue) pop() *chunkPayloadData { } func (q *pendingBaseQueue) get(i int) *chunkPayloadData { + q.mu.RLock() + defer q.mu.RUnlock() if len(q.queue) == 0 || i < 0 || i >= len(q.queue) { return nil } @@ -38,6 +47,8 @@ func (q *pendingBaseQueue) get(i int) *chunkPayloadData { } func (q *pendingBaseQueue) size() int { + q.mu.RLock() + defer q.mu.RUnlock() return len(q.queue) } @@ -46,7 +57,7 @@ func (q *pendingBaseQueue) size() int { type pendingQueue struct { unorderedQueue *pendingBaseQueue orderedQueue *pendingBaseQueue - nBytes int + nBytes uint64 selected bool unorderedIsSelected bool } @@ -71,7 +82,7 @@ func (q *pendingQueue) push(c *chunkPayloadData) { } else { q.orderedQueue.push(c) } - q.nBytes += len(c.userData) + atomic.AddUint64(&q.nBytes, uint64(len(c.userData))) } func (q *pendingQueue) peek() *chunkPayloadData { @@ -129,12 +140,12 @@ func (q *pendingQueue) pop(c *chunkPayloadData) error { } } } - q.nBytes -= len(c.userData) + atomic.AddUint64(&q.nBytes, -uint64(len(c.userData))) return nil } func (q *pendingQueue) getNumBytes() int { - return q.nBytes + return int(atomic.LoadUint64(&q.nBytes)) } func (q *pendingQueue) size() int { diff --git a/reassembly_queue.go b/reassembly_queue.go index c23e6991..bff157ff 100644 --- a/reassembly_queue.go +++ b/reassembly_queue.go @@ -7,6 +7,7 @@ import ( "errors" "io" "sort" + "sync" "sync/atomic" ) @@ -100,6 +101,7 @@ func (set *chunkSet) isComplete() bool { } type reassemblyQueue struct { + mu sync.RWMutex si uint16 nextSSN uint16 // expected SSN for next ordered chunk ordered []*chunkSet @@ -125,6 +127,9 @@ func newReassemblyQueue(si uint16) *reassemblyQueue { } func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { + r.mu.Lock() + defer r.mu.Unlock() + var cset *chunkSet if chunk.streamIdentifier != r.si { @@ -156,10 +161,19 @@ func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { } // Check if a chunkSet with the SSN already exists - for _, set := range r.ordered { - if set.ssn == chunk.streamSequenceNumber { - cset = set - break + if chunk.isFragmented() { + for _, set := range r.ordered { + // nolint:godox + // TODO: add caution around SSN wrapping here... this helps a little bit + // by ensuring we don't add to an unfragmented cset (1 chunk). + + // nolint:godox + // TODO: this slice can get pretty big; it may be worth maintaining a map + // for O(1) lookups at the cost of 2x memory. + if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() { + cset = set + break + } } } @@ -177,6 +191,7 @@ func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { return cset.push(chunk) } +// assumes lock is held func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { startIdx := -1 nChunks := 0 @@ -235,6 +250,9 @@ func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { } func (r *reassemblyQueue) isReadable() bool { + r.mu.RLock() + defer r.mu.RUnlock() + // Check unordered first if len(r.unordered) > 0 { // The chunk sets in r.unordered should all be complete. @@ -254,6 +272,9 @@ func (r *reassemblyQueue) isReadable() bool { } func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { + r.mu.Lock() + defer r.mu.Unlock() + var cset *chunkSet // Check unordered first switch { @@ -297,6 +318,9 @@ func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, erro } func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { + r.mu.Lock() + defer r.mu.Unlock() + // Use lastSSN to locate a chunkSet then remove it if the set has // not been complete keep := []*chunkSet{} @@ -321,6 +345,9 @@ func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { } func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { + r.mu.Lock() + defer r.mu.Unlock() + // Remove all fragments in the unordered sets that contains chunks // equal to or older than `newCumulativeTSN`. // We know all sets in the r.unordered are complete ones. diff --git a/reassembly_queue_test.go b/reassembly_queue_test.go index 91fd5305..02478f45 100644 --- a/reassembly_queue_test.go +++ b/reassembly_queue_test.go @@ -464,6 +464,44 @@ func TestReassemblyQueue(t *testing.T) { assert.Equal(t, 1, len(rq.unorderedChunks), "there should be one chunk kept") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") }) + + t.Run("fragmented and unfragmented chunks with the same ssn", func(t *testing.T) { + rq := newReassemblyQueue(0) + + orgPpi := PayloadTypeWebRTCBinary + + var chunk *chunkPayloadData + var complete bool + var ssn uint16 = 6 + + chunk = &chunkPayloadData{ + payloadType: orgPpi, + tsn: 12, + beginningFragment: true, + endingFragment: true, + streamSequenceNumber: ssn, + userData: []byte("DEF"), + } + + complete = rq.push(chunk) + assert.True(t, complete, "chunk set should be complete") + assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") + + chunk = &chunkPayloadData{ + payloadType: orgPpi, + beginningFragment: true, + tsn: 11, + streamSequenceNumber: ssn, + userData: []byte("ABC"), + } + + complete = rq.push(chunk) + assert.False(t, complete, "chunk set should not be complete yet") + assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") + + assert.Equal(t, 2, len(rq.ordered), "there should be two chunks") + assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") + }) } func TestChunkSet(t *testing.T) { diff --git a/rtx_timer.go b/rtx_timer.go index ceb44301..67158cba 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -10,14 +10,28 @@ import ( ) const ( - rtoInitial float64 = 1.0 * 1000 // msec - rtoMin float64 = 1.0 * 1000 // msec - rtoMax float64 = 60.0 * 1000 // msec - rtoAlpha float64 = 0.125 - rtoBeta float64 = 0.25 - maxInitRetrans uint = 8 - pathMaxRetrans uint = 5 - noMaxRetrans uint = 0 + // RTO.Initial + rtoInitial float64 = 1.0 * 1000 // msec + + // RTO.Min + rtoMin float64 = 1.0 * 1000 // msec + + // RTO.Max + rtoMax float64 = 60.0 * 1000 // msec + + // RTO.Alpha + rtoAlpha float64 = 0.125 + + // RTO.Beta + rtoBeta float64 = 0.25 + + // Max.Init.Retransmits: + maxInitRetrans uint = 8 + + // Path.Max.Retrans + pathMaxRetrans uint = 5 + + noMaxRetrans uint = 0 ) // rtoManager manages Rtx timeout values. diff --git a/vnet_test.go b/vnet_test.go index 6f225bbd..90152171 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -202,8 +202,8 @@ func testRwndFull(t *testing.T, unordered bool) { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -277,8 +277,8 @@ func testRwndFull(t *testing.T, unordered bool) { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -435,8 +435,8 @@ func TestStreamClose(t *testing.T) { defer close(serverShutDown) // connected UDP conn for server conn, innerErr := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, innerErr, "should succeed") { return @@ -485,8 +485,8 @@ func TestStreamClose(t *testing.T) { // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -620,8 +620,8 @@ func TestCookieEchoRetransmission(t *testing.T) { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -650,8 +650,8 @@ func TestCookieEchoRetransmission(t *testing.T) { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return