From 29d08116120e6382bdcc948e036530fd3a2de6ee Mon Sep 17 00:00:00 2001 From: sukun Date: Sat, 2 Dec 2023 20:18:00 +0530 Subject: [PATCH 1/3] Check deadline timer before signaling --- stream.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/stream.go b/stream.go index b12bb24c..eeb585e7 100644 --- a/stream.go +++ b/stream.go @@ -175,6 +175,11 @@ func (s *Stream) SetReadDeadline(deadline time.Time) error { t.Stop() return case <-t.C: + select { + case <-readTimeoutCancel: + return + default: + } s.lock.Lock() if s.readErr == nil { s.readErr = ErrReadDeadlineExceeded From f315ba579466893f6275523a98bd52819c0ed273 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 28 Feb 2024 10:05:50 -0500 Subject: [PATCH 2/3] DRY and explain use of port 5000 --- association.go | 10 ++++++++-- packet_test.go | 8 ++++---- rtx_timer_test.go | 30 +++++++++++++++--------------- vnet_test.go | 24 ++++++++++++------------ 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/association.go b/association.go index 47c0d94e..b4056a86 100644 --- a/association.go +++ b/association.go @@ -19,6 +19,12 @@ import ( "github.com/pion/randutil" ) +// Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation +// assumes it will be used by DTLS over UDP, the port is only meaningful for de-multiplexing +// but more-so verification. +// Example usage: https://www.rfc-editor.org/rfc/rfc8841.html#section-13.1-2 +const defaultSCTPSrcDstPort = 5000 + // Use global random generator to properly seed by crypto grade random. var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals @@ -393,8 +399,8 @@ func (a *Association) sendInit() error { outbound := &packet{} outbound.verificationTag = a.peerVerificationTag - a.sourcePort = 5000 // Spec?? - a.destinationPort = 5000 // Spec?? + a.sourcePort = defaultSCTPSrcDstPort + a.destinationPort = defaultSCTPSrcDstPort outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort diff --git a/packet_test.go b/packet_test.go index 1a270e53..40557f3d 100644 --- a/packet_test.go +++ b/packet_test.go @@ -20,10 +20,10 @@ func TestPacketUnmarshal(t *testing.T) { switch { case err != nil: t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) - case pkt.sourcePort != 5000: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", 5000, pkt.sourcePort) - case pkt.destinationPort != 5000: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", 5000, pkt.destinationPort) + case pkt.sourcePort != defaultSCTPSrcDstPort: + t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.sourcePort) + case pkt.destinationPort != defaultSCTPSrcDstPort: + t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.destinationPort) case pkt.verificationTag != 0: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag) } diff --git a/rtx_timer_test.go b/rtx_timer_test.go index 678103df..2a7da5af 100644 --- a/rtx_timer_test.go +++ b/rtx_timer_test.go @@ -120,7 +120,7 @@ func TestRtxTimer(t *testing.T) { timerID := 0 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) // 30 : 1 (30) // 60 : 2 (90) @@ -128,7 +128,7 @@ func TestRtxTimer(t *testing.T) { // 240: 4 (550) <== expected in 650 msec assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) assert.False(t, rt.isRunning(), "should not be running") @@ -150,11 +150,11 @@ func TestRtxTimer(t *testing.T) { var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) @@ -177,11 +177,11 @@ func TestRtxTimer(t *testing.T) { var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) @@ -200,11 +200,11 @@ func TestRtxTimer(t *testing.T) { timerID := 1 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) @@ -226,12 +226,12 @@ func TestRtxTimer(t *testing.T) { timerID := 2 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) t.Log("onRTO() called") assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) for i := 0; i < 1000; i++ { @@ -305,7 +305,7 @@ func TestRtxTimer(t *testing.T) { doneCh <- true } }, - onRtxFailure: func(id int) { + onRtxFailure: func(_ int) { assert.Fail(t, "timer should not fail") }, }, 0, 0) @@ -338,11 +338,11 @@ func TestRtxTimer(t *testing.T) { doneCh := make(chan bool) rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(id int, _ uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) doneCh <- true }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) for i := 0; i < 10; i++ { @@ -362,10 +362,10 @@ func TestRtxTimer(t *testing.T) { var rtoCount int timerID := 6 rt := newRTXTimer(timerID, &testTimerObserver{ - onRTO: func(id int, nRtos uint) { + onRTO: func(_ int, _ uint) { rtoCount++ }, - onRtxFailure: func(id int) {}, + onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) ok := rt.start(20) diff --git a/vnet_test.go b/vnet_test.go index 6f225bbd..90152171 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -202,8 +202,8 @@ func testRwndFull(t *testing.T, unordered bool) { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -277,8 +277,8 @@ func testRwndFull(t *testing.T, unordered bool) { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -435,8 +435,8 @@ func TestStreamClose(t *testing.T) { defer close(serverShutDown) // connected UDP conn for server conn, innerErr := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, innerErr, "should succeed") { return @@ -485,8 +485,8 @@ func TestStreamClose(t *testing.T) { // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -620,8 +620,8 @@ func TestCookieEchoRetransmission(t *testing.T) { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return @@ -650,8 +650,8 @@ func TestCookieEchoRetransmission(t *testing.T) { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return From 32ef4a16c61db9e4d23750d2d09482f5c1681830 Mon Sep 17 00:00:00 2001 From: Michael Tuexen Date: Wed, 28 Feb 2024 17:03:51 +0100 Subject: [PATCH 3/3] Fix zero checksum usage zero checksum acceptance was implemented as a symmetrical feature, but this is not the way it is specified. Each side announces independently from the peer, that it can accept packets with an incorrect zero checksum. In particular, it is completely valid to send a packet containing an INIT ACK chunk with an incorrect zero checksum, as long as the peer announced the support in the INIT chunk. --- association.go | 26 ++++++++++++-------------- association_test.go | 6 +++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/association.go b/association.go index b4056a86..9b2aad55 100644 --- a/association.go +++ b/association.go @@ -183,8 +183,8 @@ type Association struct { cumulativeTSNAckPoint uint32 advancedPeerTSNAckPoint uint32 useForwardTSN bool - useZeroChecksum bool - requestZeroChecksum bool + sendZeroChecksum bool + recvZeroChecksum bool // Congestion control parameters maxReceiveBufferSize uint32 @@ -331,7 +331,7 @@ func createAssociation(config Config) *Association { handshakeCompletedCh: make(chan error), cumulativeTSNAckPoint: tsn - 1, advancedPeerTSNAckPoint: tsn - 1, - requestZeroChecksum: config.EnableZeroChecksum, + recvZeroChecksum: config.EnableZeroChecksum, silentError: ErrSilentlyDiscard, stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), @@ -375,7 +375,7 @@ func (a *Association) init(isClient bool) { init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize setSupportedExtensions(&init.chunkInitCommon) - if a.requestZeroChecksum { + if a.recvZeroChecksum { init.params = append(init.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) } @@ -638,7 +638,7 @@ func (a *Association) unregisterStream(s *Stream, err error) { func chunkMandatoryChecksum(cc []chunk) bool { for _, c := range cc { switch c.(type) { - case *chunkInit, *chunkInitAck, *chunkCookieEcho: + case *chunkInit, *chunkCookieEcho: return true } } @@ -646,12 +646,12 @@ func chunkMandatoryChecksum(cc []chunk) bool { } func (a *Association) marshalPacket(p *packet) ([]byte, error) { - return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks)) + return p.marshal(!a.sendZeroChecksum || chunkMandatoryChecksum(p.chunks)) } func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { p := &packet{} - if err := p.unmarshal(!a.useZeroChecksum, raw); err != nil { + if err := p.unmarshal(!a.recvZeroChecksum, raw); err != nil { return nil, err } return p, nil @@ -1131,7 +1131,6 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // subtracting one from it. a.peerLastTSN = i.initialTSN - 1 - peerHasZeroChecksum := false for _, param := range i.params { switch v := param.(type) { // nolint:gocritic case *paramSupportedExtensions: @@ -1142,7 +1141,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { } } case *paramZeroChecksumAcceptable: - peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod + a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } @@ -1172,11 +1171,10 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { initAck.params = []param{a.myCookie} - if peerHasZeroChecksum { + if a.recvZeroChecksum { initAck.params = append(initAck.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) - a.useZeroChecksum = true } - a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum) + a.log.Debugf("[%s] sendZeroChecksum=%t (on init)", a.name, a.sendZeroChecksum) setSupportedExtensions(&initAck.chunkInitCommon) @@ -1236,11 +1234,11 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { } } case *paramZeroChecksumAcceptable: - a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod + a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } - a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum) + a.log.Debugf("[%s] sendZeroChecksum=%t (on initAck)", a.name, a.sendZeroChecksum) if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) diff --git a/association_test.go b/association_test.go index a12503d7..870968db 100644 --- a/association_test.go +++ b/association_test.go @@ -3082,7 +3082,7 @@ func (c customLogger) Trace(string) {} func (c customLogger) Tracef(string, ...interface{}) {} func (c customLogger) Debug(string) {} func (c customLogger) Debugf(format string, args ...interface{}) { - if format == "[%s] useZeroChecksum=%t (on initAck)" { + if format == "[%s] sendZeroChecksum=%t (on initAck)" { assert.Equal(c.t, args[1], c.expectZeroChecksum) } } @@ -3108,8 +3108,8 @@ func TestAssociation_ZeroChecksum(t *testing.T) { }{ {true, true, true}, {false, false, false}, - {true, false, true}, - {false, true, false}, + {true, false, false}, + {false, true, true}, } { a1chan, a2chan := make(chan *Association), make(chan *Association)