Skip to content

Commit

Permalink
Implement draft-ietf-tsvwg-sctp-zero-checksum-01
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart authored and Sean-Der committed Feb 9, 2024
1 parent 2927025 commit b3c21fb
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 48 deletions.
97 changes: 83 additions & 14 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ type Association struct {
cumulativeTSNAckPoint uint32
advancedPeerTSNAckPoint uint32
useForwardTSN bool
acceptZeroChecksum bool
useZeroChecksum bool

// Congestion control parameters
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -233,6 +235,7 @@ type Config struct {
NetConn net.Conn
MaxReceiveBufferSize uint32
MaxMessageSize uint32
UseZeroChecksum bool
LoggerFactory logging.LoggerFactory
}

Expand Down Expand Up @@ -320,6 +323,7 @@ func createAssociation(config Config) *Association {
handshakeCompletedCh: make(chan error),
cumulativeTSNAckPoint: tsn - 1,
advancedPeerTSNAckPoint: tsn - 1,
useZeroChecksum: config.UseZeroChecksum,
silentError: ErrSilentlyDiscard,
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
Expand Down Expand Up @@ -362,6 +366,9 @@ func (a *Association) init(isClient bool) {
init.initiateTag = a.myVerificationTag
init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize
setSupportedExtensions(&init.chunkInitCommon)
if a.acceptZeroChecksum {
setZeroChecksumAccepted(&init.chunkInitCommon)
}
a.storedInit = init

err := a.sendInit()
Expand Down Expand Up @@ -618,10 +625,48 @@ func (a *Association) unregisterStream(s *Stream, err error) {
s.readNotifier.Broadcast()
}

// handleInbound parses incoming raw packets
func (a *Association) handleInbound(raw []byte) error {
func needsCorrectChecksum(cc []chunk) bool {
for _, c := range cc {
switch c.(type) {
case *chunkInit:
return true
case *chunkCookieEcho:
return true
}
}
return false
}

func (a *Association) marshalPacket(p *packet) ([]byte, error) {
if !a.useZeroChecksum || needsCorrectChecksum(p.chunks) {
return p.marshal()
}
return p.marshalWithoutChecksum()
}

func (a *Association) unmarshalPacket(raw []byte) (*packet, error) {
p := &packet{}
if a.acceptZeroChecksum {
if err := p.unmarshalZeroChecksum(raw); err != nil {
return nil, err
}
if needsCorrectChecksum(p.chunks) {
if err := p.checkChecksum(raw); err != nil {
return nil, err
}
}
return p, nil
}
if err := p.unmarshal(raw); err != nil {
return nil, err
}
return p, nil
}

// handleInbound parses incoming raw packets
func (a *Association) handleInbound(raw []byte) error {
p, err := a.unmarshalPacket(raw)
if err != nil {
a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err)
return nil
}
Expand All @@ -647,7 +692,7 @@ func (a *Association) handleInbound(raw []byte) error {
// The caller should hold the lock
func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte {
for _, p := range a.getDataPacketsToRetransmit() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name)
continue
Expand All @@ -668,7 +713,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name)
a.t3RTX.start(a.rtoMgr.getRTO())
for _, p := range a.bundleDataChunksIntoPackets(chunks) {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet", a.name)
continue
Expand All @@ -683,7 +728,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs))
for _, c := range a.reconfigs {
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name)
} else {
Expand All @@ -706,7 +751,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v",
a.name, rsn, a.myNextTSN-1, sisToReset)
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name)
} else {
Expand Down Expand Up @@ -769,7 +814,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
}

if len(toFastRetrans) > 0 {
raw, err := a.createPacket(toFastRetrans).marshal()
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
Expand All @@ -787,7 +832,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
a.ackState = ackStateIdle
sack := a.createSelectiveAckChunk()
a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
raw, err := a.createPacket([]chunk{sack}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{sack}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a SACK packet", a.name)
} else {
Expand All @@ -804,7 +849,7 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b
a.willSendForwardTSN = false
if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
fwdtsn := a.createForwardTSN()
raw, err := a.createPacket([]chunk{fwdtsn}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name)
} else {
Expand All @@ -827,7 +872,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
cumulativeTSNAck: a.cumulativeTSNAckPoint,
}

raw, err := a.createPacket([]chunk{shutdown}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name)
} else {
Expand All @@ -839,7 +884,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by

shutdownAck := &chunkShutdownAck{}

raw, err := a.createPacket([]chunk{shutdownAck}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name)
} else {
Expand All @@ -851,7 +896,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by

shutdownComplete := &chunkShutdownComplete{}

raw, err := a.createPacket([]chunk{shutdownComplete}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name)
} else {
Expand All @@ -875,7 +920,7 @@ func (a *Association) gatherAbortPacket() ([]byte, error) {
abort.errorCauses = []errorCause{cause}
}

raw, err := a.createPacket([]chunk{abort}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{abort}))

return raw, err
}
Expand All @@ -900,7 +945,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) {

if a.controlQueue.size() > 0 {
for _, p := range a.controlQueue.popAll() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a control packet", a.name)
continue
Expand Down Expand Up @@ -1061,6 +1106,10 @@ func setSupportedExtensions(init *chunkInitCommon) {
})
}

func setZeroChecksumAccepted(init *chunkInitCommon) {
init.params = append(init.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
}

// The caller should hold the lock.
func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
state := a.getState()
Expand Down Expand Up @@ -1092,6 +1141,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
// subtracting one from it.
a.peerLastTSN = i.initialTSN - 1

peerAcceptsZeroChecksum := false
for _, param := range i.params {
switch v := param.(type) { // nolint:gocritic
case *paramSupportedExtensions:
Expand All @@ -1101,8 +1151,15 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
peerAcceptsZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}
if a.useZeroChecksum && peerAcceptsZeroChecksum {
a.useZeroChecksum = true
} else {
a.useZeroChecksum = false
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name)
}
Expand All @@ -1129,6 +1186,10 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {

initAck.params = []param{a.myCookie}

if a.acceptZeroChecksum {
setZeroChecksumAccepted(&initAck.chunkInitCommon)
}

setSupportedExtensions(&initAck.chunkInitCommon)

outbound.chunks = []chunk{initAck}
Expand Down Expand Up @@ -1174,6 +1235,7 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
a.t1Init.stop()
a.storedInit = nil

peerAcceptsZeroChecksum := false
var cookieParam *paramStateCookie
for _, param := range i.params {
switch v := param.(type) {
Expand All @@ -1186,8 +1248,15 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
peerAcceptsZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}
if a.useZeroChecksum && peerAcceptsZeroChecksum {
a.useZeroChecksum = true
} else {
a.useZeroChecksum = false
}
if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
}
Expand Down
4 changes: 2 additions & 2 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,7 @@ func TestAssocAbort(t *testing.T) {
errorCauseHeader: errorCauseHeader{code: protocolViolation},
}},
}
packet, err := a0.createPacket([]chunk{abort}).marshal()
packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort}))
assert.NoError(t, err)

_, _, err = establishSessionPair(br, a0, a1, si)
Expand Down Expand Up @@ -2964,7 +2964,7 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) {
}()
}

packet, err := testCase.inputPacket.marshal()
packet, err := a.marshalPacket(testCase.inputPacket)
assert.NoError(t, err)
_, err = charlieConn.Write(packet)
assert.NoError(t, err)
Expand Down
33 changes: 28 additions & 5 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var (
ErrChecksumMismatch = errors.New("checksum mismatch theirs")
)

func (p *packet) unmarshal(raw []byte) error {
func (p *packet) unmarshalWithoutCheckingChecksum(raw []byte) error {
if len(raw) < packetHeaderSize {
return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize)
}
Expand Down Expand Up @@ -125,6 +125,10 @@ func (p *packet) unmarshal(raw []byte) error {
chunkValuePadding := getPadding(c.valueLength())
offset += chunkHeaderSize + c.valueLength() + chunkValuePadding
}
return nil
}

func (p *packet) checkChecksum(raw []byte) error {
theirChecksum := binary.LittleEndian.Uint32(raw[8:])
ourChecksum := generatePacketChecksum(raw)
if theirChecksum != ourChecksum {
Expand All @@ -133,7 +137,30 @@ func (p *packet) unmarshal(raw []byte) error {
return nil
}

func (p *packet) unmarshalZeroChecksum(raw []byte) error {
return p.unmarshalWithoutCheckingChecksum(raw)
}

func (p *packet) unmarshal(raw []byte) error {
if err := p.unmarshalWithoutCheckingChecksum(raw); err != nil {
return err
}
return p.checkChecksum(raw)
}

func (p *packet) marshal() ([]byte, error) {
raw, err := p.marshalWithoutChecksum()
if err != nil {
return nil, err
}

// Checksum is already in BigEndian
// Using LittleEndian.PutUint32 stops it from being flipped
binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw))
return raw, nil
}

func (p *packet) marshalWithoutChecksum() ([]byte, error) {
raw := make([]byte, packetHeaderSize)

// Populate static headers
Expand All @@ -155,10 +182,6 @@ func (p *packet) marshal() ([]byte, error) {
raw = append(raw, make([]byte, paddingNeeded)...)
}
}

// Checksum is already in BigEndian
// Using LittleEndian.PutUint32 stops it from being flipped
binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw))
return raw, nil
}

Expand Down
2 changes: 2 additions & 0 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ func buildParam(t paramType, rawParam []byte) (param, error) {
return (&paramOutgoingResetRequest{}).unmarshal(rawParam)
case reconfigResp:
return (&paramReconfigResponse{}).unmarshal(rawParam)
case zeroChecksumAcceptable:
return (&paramZeroChecksumAcceptable{}).unmarshal(rawParam)
default:
return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, t)
}
Expand Down
Loading

0 comments on commit b3c21fb

Please sign in to comment.