Skip to content

Commit

Permalink
Merge branch 'master' into named_assoc
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels authored Feb 28, 2024
2 parents cfe9774 + 32ef4a1 commit 840fdcc
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 50 deletions.
36 changes: 20 additions & 16 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ import (
"github.com/pion/randutil"
)

// Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation
// assumes it will be used by DTLS over UDP, the port is only meaningful for de-multiplexing
// but more-so verification.
// Example usage: https://www.rfc-editor.org/rfc/rfc8841.html#section-13.1-2
const defaultSCTPSrcDstPort = 5000

// Use global random generator to properly seed by crypto grade random.
var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals

Expand Down Expand Up @@ -177,8 +183,8 @@ type Association struct {
cumulativeTSNAckPoint uint32
advancedPeerTSNAckPoint uint32
useForwardTSN bool
useZeroChecksum bool
requestZeroChecksum bool
sendZeroChecksum bool
recvZeroChecksum bool

// Congestion control parameters
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -326,7 +332,7 @@ func createAssociation(config Config) *Association {
handshakeCompletedCh: make(chan error),
cumulativeTSNAckPoint: tsn - 1,
advancedPeerTSNAckPoint: tsn - 1,
requestZeroChecksum: config.EnableZeroChecksum,
recvZeroChecksum: config.EnableZeroChecksum,
silentError: ErrSilentlyDiscard,
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
Expand Down Expand Up @@ -373,7 +379,7 @@ func (a *Association) init(isClient bool) {
init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize
setSupportedExtensions(&init.chunkInitCommon)

if a.requestZeroChecksum {
if a.recvZeroChecksum {
init.params = append(init.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
}

Expand All @@ -397,8 +403,8 @@ func (a *Association) sendInit() error {

outbound := &packet{}
outbound.verificationTag = a.peerVerificationTag
a.sourcePort = 5000 // Spec??
a.destinationPort = 5000 // Spec??
a.sourcePort = defaultSCTPSrcDstPort
a.destinationPort = defaultSCTPSrcDstPort
outbound.sourcePort = a.sourcePort
outbound.destinationPort = a.destinationPort

Expand Down Expand Up @@ -636,20 +642,20 @@ func (a *Association) unregisterStream(s *Stream, err error) {
func chunkMandatoryChecksum(cc []chunk) bool {
for _, c := range cc {
switch c.(type) {
case *chunkInit, *chunkInitAck, *chunkCookieEcho:
case *chunkInit, *chunkCookieEcho:
return true
}
}
return false
}

func (a *Association) marshalPacket(p *packet) ([]byte, error) {
return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks))
return p.marshal(!a.sendZeroChecksum || chunkMandatoryChecksum(p.chunks))
}

func (a *Association) unmarshalPacket(raw []byte) (*packet, error) {
p := &packet{}
if err := p.unmarshal(!a.useZeroChecksum, raw); err != nil {
if err := p.unmarshal(!a.recvZeroChecksum, raw); err != nil {
return nil, err
}
return p, nil
Expand Down Expand Up @@ -1129,7 +1135,6 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
// subtracting one from it.
a.peerLastTSN = i.initialTSN - 1

peerHasZeroChecksum := false
for _, param := range i.params {
switch v := param.(type) { // nolint:gocritic
case *paramSupportedExtensions:
Expand All @@ -1140,7 +1145,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
}
}
case *paramZeroChecksumAcceptable:
peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod
a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

Expand Down Expand Up @@ -1170,11 +1175,10 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {

initAck.params = []param{a.myCookie}

if peerHasZeroChecksum {
if a.recvZeroChecksum {
initAck.params = append(initAck.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
a.useZeroChecksum = true
}
a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum)
a.log.Debugf("[%s] sendZeroChecksum=%t (on init)", a.name, a.sendZeroChecksum)

setSupportedExtensions(&initAck.chunkInitCommon)

Expand Down Expand Up @@ -1234,11 +1238,11 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
}
}
case *paramZeroChecksumAcceptable:
a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod
a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum)
a.log.Debugf("[%s] sendZeroChecksum=%t (on initAck)", a.name, a.sendZeroChecksum)

if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
Expand Down
6 changes: 3 additions & 3 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3084,7 +3084,7 @@ func (c customLogger) Trace(string) {}
func (c customLogger) Tracef(string, ...interface{}) {}
func (c customLogger) Debug(string) {}
func (c customLogger) Debugf(format string, args ...interface{}) {
if format == "[%s] useZeroChecksum=%t (on initAck)" {
if format == "[%s] sendZeroChecksum=%t (on initAck)" {
assert.Equal(c.t, args[1], c.expectZeroChecksum)
}
}
Expand All @@ -3110,8 +3110,8 @@ func TestAssociation_ZeroChecksum(t *testing.T) {
}{
{true, true, true},
{false, false, false},
{true, false, true},
{false, true, false},
{true, false, false},
{false, true, true},
} {
a1chan, a2chan := make(chan *Association), make(chan *Association)

Expand Down
8 changes: 4 additions & 4 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func TestPacketUnmarshal(t *testing.T) {
switch {
case err != nil:
t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err)
case pkt.sourcePort != 5000:
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", 5000, pkt.sourcePort)
case pkt.destinationPort != 5000:
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", 5000, pkt.destinationPort)
case pkt.sourcePort != defaultSCTPSrcDstPort:
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.sourcePort)
case pkt.destinationPort != defaultSCTPSrcDstPort:
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.destinationPort)
case pkt.verificationTag != 0:
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag)
}
Expand Down
30 changes: 15 additions & 15 deletions rtx_timer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ func TestRtxTimer(t *testing.T) {
timerID := 0
var nCbs int32
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
// 30 : 1 (30)
// 60 : 2 (90)
// 120: 3 (210)
// 240: 4 (550) <== expected in 650 msec
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

assert.False(t, rt.isRunning(), "should not be running")
Expand All @@ -150,11 +150,11 @@ func TestRtxTimer(t *testing.T) {
var nCbs int32

rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

interval := float64(30.0)
Expand All @@ -177,11 +177,11 @@ func TestRtxTimer(t *testing.T) {
var nCbs int32

rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

interval := float64(30.0)
Expand All @@ -200,11 +200,11 @@ func TestRtxTimer(t *testing.T) {
timerID := 1
var nCbs int32
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

interval := float64(30.0)
Expand All @@ -226,12 +226,12 @@ func TestRtxTimer(t *testing.T) {
timerID := 2
var nCbs int32
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
atomic.AddInt32(&nCbs, 1)
t.Log("onRTO() called")
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

for i := 0; i < 1000; i++ {
Expand Down Expand Up @@ -305,7 +305,7 @@ func TestRtxTimer(t *testing.T) {
doneCh <- true
}
},
onRtxFailure: func(id int) {
onRtxFailure: func(_ int) {
assert.Fail(t, "timer should not fail")
},
}, 0, 0)
Expand Down Expand Up @@ -338,11 +338,11 @@ func TestRtxTimer(t *testing.T) {
doneCh := make(chan bool)

rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(id int, _ uint) {
assert.Equal(t, timerID, id, "unexpted timer ID: %d", id)
doneCh <- true
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

for i := 0; i < 10; i++ {
Expand All @@ -362,10 +362,10 @@ func TestRtxTimer(t *testing.T) {
var rtoCount int
timerID := 6
rt := newRTXTimer(timerID, &testTimerObserver{
onRTO: func(id int, nRtos uint) {
onRTO: func(_ int, _ uint) {
rtoCount++
},
onRtxFailure: func(id int) {},
onRtxFailure: func(_ int) {},
}, pathMaxRetrans, 0)

ok := rt.start(20)
Expand Down
5 changes: 5 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ func (s *Stream) SetReadDeadline(deadline time.Time) error {
t.Stop()
return
case <-t.C:
select {
case <-readTimeoutCancel:
return
default:
}
s.lock.Lock()
if s.readErr == nil {
s.readErr = ErrReadDeadlineExceeded
Expand Down
24 changes: 12 additions & 12 deletions vnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ func testRwndFull(t *testing.T, unordered bool) {
defer close(serverShutDown)
// connected UDP conn for server
conn, err := venv.net0.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, err, "should succeed") {
return
Expand Down Expand Up @@ -277,8 +277,8 @@ func testRwndFull(t *testing.T, unordered bool) {
defer close(clientShutDown)
// connected UDP conn for client
conn, err := venv.net1.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, err, "should succeed") {
return
Expand Down Expand Up @@ -435,8 +435,8 @@ func TestStreamClose(t *testing.T) {
defer close(serverShutDown)
// connected UDP conn for server
conn, innerErr := venv.net0.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, innerErr, "should succeed") {
return
Expand Down Expand Up @@ -485,8 +485,8 @@ func TestStreamClose(t *testing.T) {

// connected UDP conn for client
conn, err := venv.net1.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, err, "should succeed") {
return
Expand Down Expand Up @@ -620,8 +620,8 @@ func TestCookieEchoRetransmission(t *testing.T) {
defer close(serverShutDown)
// connected UDP conn for server
conn, err := venv.net0.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, err, "should succeed") {
return
Expand Down Expand Up @@ -650,8 +650,8 @@ func TestCookieEchoRetransmission(t *testing.T) {
defer close(clientShutDown)
// connected UDP conn for client
conn, err := venv.net1.DialUDP("udp4",
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort},
&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort},
)
if !assert.NoError(t, err, "should succeed") {
return
Expand Down

0 comments on commit 840fdcc

Please sign in to comment.