Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement draft-ietf-tsvwg-sctp-zero-checksum-01 #284

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 73 additions & 14 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@
cumulativeTSNAckPoint uint32
advancedPeerTSNAckPoint uint32
useForwardTSN bool
useZeroChecksum bool
requestZeroChecksum bool

// Congestion control parameters
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -233,6 +235,7 @@
NetConn net.Conn
MaxReceiveBufferSize uint32
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory
}

Expand Down Expand Up @@ -320,6 +323,7 @@
handshakeCompletedCh: make(chan error),
cumulativeTSNAckPoint: tsn - 1,
advancedPeerTSNAckPoint: tsn - 1,
requestZeroChecksum: config.EnableZeroChecksum,
silentError: ErrSilentlyDiscard,
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
Expand Down Expand Up @@ -362,6 +366,11 @@
init.initiateTag = a.myVerificationTag
init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize
setSupportedExtensions(&init.chunkInitCommon)

if a.requestZeroChecksum {
init.params = append(init.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
}

a.storedInit = init

err := a.sendInit()
Expand Down Expand Up @@ -618,10 +627,45 @@
s.readNotifier.Broadcast()
}

func chunkMandatoryChecksum(cc []chunk) bool {
for _, c := range cc {
switch c.(type) {
case *chunkInit, *chunkInitAck, *chunkCookieEcho:
return true
}
}
return false
}

func (a *Association) marshalPacket(p *packet) ([]byte, error) {
return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks))
}

func (a *Association) unmarshalPacket(raw []byte) (*packet, error) {
p := &packet{}
if !a.useZeroChecksum {
if err := p.unmarshal(true, raw); err != nil {
return nil, err
}
return p, nil
}

if err := p.unmarshal(false, raw); err != nil {
return nil, err
}

Check warning on line 655 in association.go

View check run for this annotation

Codecov / codecov/patch

association.go#L654-L655

Added lines #L654 - L655 were not covered by tests
if chunkMandatoryChecksum(p.chunks) {
if err := p.unmarshal(true, raw); err != nil {
return nil, err
}

Check warning on line 659 in association.go

View check run for this annotation

Codecov / codecov/patch

association.go#L658-L659

Added lines #L658 - L659 were not covered by tests
}

return p, nil
}

// handleInbound parses incoming raw packets
func (a *Association) handleInbound(raw []byte) error {
p := &packet{}
if err := p.unmarshal(raw); err != nil {
p, err := a.unmarshalPacket(raw)
if err != nil {
a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err)
return nil
}
Expand All @@ -647,7 +691,7 @@
// The caller should hold the lock
func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte {
for _, p := range a.getDataPacketsToRetransmit() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name)
continue
Expand All @@ -668,7 +712,7 @@
a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name)
a.t3RTX.start(a.rtoMgr.getRTO())
for _, p := range a.bundleDataChunksIntoPackets(chunks) {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet", a.name)
continue
Expand All @@ -683,7 +727,7 @@
a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs))
for _, c := range a.reconfigs {
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name)
} else {
Expand All @@ -706,7 +750,7 @@
a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v",
a.name, rsn, a.myNextTSN-1, sisToReset)
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name)
} else {
Expand Down Expand Up @@ -769,7 +813,7 @@
}

if len(toFastRetrans) > 0 {
raw, err := a.createPacket(toFastRetrans).marshal()
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
Expand All @@ -787,7 +831,7 @@
a.ackState = ackStateIdle
sack := a.createSelectiveAckChunk()
a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
raw, err := a.createPacket([]chunk{sack}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{sack}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a SACK packet", a.name)
} else {
Expand All @@ -804,7 +848,7 @@
a.willSendForwardTSN = false
if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
fwdtsn := a.createForwardTSN()
raw, err := a.createPacket([]chunk{fwdtsn}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name)
} else {
Expand All @@ -827,7 +871,7 @@
cumulativeTSNAck: a.cumulativeTSNAckPoint,
}

raw, err := a.createPacket([]chunk{shutdown}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name)
} else {
Expand All @@ -839,7 +883,7 @@

shutdownAck := &chunkShutdownAck{}

raw, err := a.createPacket([]chunk{shutdownAck}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name)
} else {
Expand All @@ -851,7 +895,7 @@

shutdownComplete := &chunkShutdownComplete{}

raw, err := a.createPacket([]chunk{shutdownComplete}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name)
} else {
Expand All @@ -875,7 +919,7 @@
abort.errorCauses = []errorCause{cause}
}

raw, err := a.createPacket([]chunk{abort}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{abort}))

return raw, err
}
Expand All @@ -900,7 +944,7 @@

if a.controlQueue.size() > 0 {
for _, p := range a.controlQueue.popAll() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a control packet", a.name)
continue
Expand Down Expand Up @@ -1092,6 +1136,7 @@
// subtracting one from it.
a.peerLastTSN = i.initialTSN - 1

peerHasZeroChecksum := false
for _, param := range i.params {
switch v := param.(type) { // nolint:gocritic
case *paramSupportedExtensions:
Expand All @@ -1101,8 +1146,11 @@
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name)
}
Expand All @@ -1129,6 +1177,12 @@

initAck.params = []param{a.myCookie}

if peerHasZeroChecksum {
initAck.params = append(initAck.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
a.useZeroChecksum = true
}
a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum)

setSupportedExtensions(&initAck.chunkInitCommon)

outbound.chunks = []chunk{initAck}
Expand Down Expand Up @@ -1186,8 +1240,13 @@
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum)

if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
}
Expand Down
90 changes: 87 additions & 3 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ func TestAssocT1CookieTimer(t *testing.T) {
// Drop all COOKIE-ECHO
br.Filter(0, func(raw []byte) bool {
p := &packet{}
err := p.unmarshal(raw)
err := p.unmarshal(true, raw)
if !assert.Nil(t, err, "failed to parse packet") {
return false // drop
}
Expand Down Expand Up @@ -2285,7 +2285,7 @@ func TestAssocAbort(t *testing.T) {
errorCauseHeader: errorCauseHeader{code: protocolViolation},
}},
}
packet, err := a0.createPacket([]chunk{abort}).marshal()
packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort}))
assert.NoError(t, err)

_, _, err = establishSessionPair(br, a0, a1, si)
Expand Down Expand Up @@ -2964,7 +2964,7 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) {
}()
}

packet, err := testCase.inputPacket.marshal()
packet, err := a.marshalPacket(testCase.inputPacket)
assert.NoError(t, err)
_, err = charlieConn.Write(packet)
assert.NoError(t, err)
Expand Down Expand Up @@ -3072,3 +3072,87 @@ loop:
assert.Error(t, err1, "context canceled")
assert.Error(t, err2, "context canceled")
}

type customLogger struct {
expectZeroChecksum bool
t *testing.T
}

func (c customLogger) Trace(string) {}
func (c customLogger) Tracef(string, ...interface{}) {}
func (c customLogger) Debug(string) {}
func (c customLogger) Debugf(format string, args ...interface{}) {
if format == "[%s] useZeroChecksum=%t (on initAck)" {
assert.Equal(c.t, args[1], c.expectZeroChecksum)
}
}
func (c customLogger) Info(string) {}
func (c customLogger) Infof(string, ...interface{}) {}
func (c customLogger) Warn(string) {}
func (c customLogger) Warnf(string, ...interface{}) {}
func (c customLogger) Error(string) {}
func (c customLogger) Errorf(string, ...interface{}) {}

func (c customLogger) NewLogger(string) logging.LeveledLogger {
return c
}

func TestAssociation_ZeroChecksum(t *testing.T) {
checkGoroutineLeaks(t)

lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

for _, testCase := range []struct {
clientZeroChecksum, serverZeroChecksum, expectChecksumEnabled bool
}{
{true, true, true},
{false, false, false},
{true, false, true},
{false, true, false},
} {
a1chan, a2chan := make(chan *Association), make(chan *Association)

udp1, udp2 := createUDPConnPair()

go func() {
a1, err := Client(Config{
NetConn: udp1,
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
EnableZeroChecksum: testCase.clientZeroChecksum,
})
assert.NoError(t, err)
a1chan <- a1
}()

go func() {
a2, err := Server(Config{
NetConn: udp2,
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
EnableZeroChecksum: testCase.serverZeroChecksum,
})
assert.NoError(t, err)
a2chan <- a2
}()

a1, a2 := <-a1chan, <-a2chan

writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

testData := []byte("test")
_, err = writeStream.Write(testData)
require.NoError(t, err)

buf := make([]byte, len(testData))
_, err = readStream.Read(buf)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

require.NoError(t, a1.Close())
require.NoError(t, a2.Close())
}
}
Loading
Loading