Skip to content

Commit

Permalink
Fix maximum packet count handling
Browse files Browse the repository at this point in the history
Accoring to RFC3711 section 9.2, SRTP/SRTCP session must not
wrap SRTP ROC and SRTCP index without changing the master key.

Also fix Context.SetROC() with ROC>0xffff.
  • Loading branch information
at-wat committed Jan 4, 2023
1 parent 994fbcd commit 5ce39f6
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 109 deletions.
10 changes: 5 additions & 5 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const (
labelSRTCPSalt = 0x05

maxSequenceNumber = 65535
maxROC = (1 << 32) - 1

seqNumMedian = 1 << 15
seqNumMax = 1 << 16
Expand Down Expand Up @@ -60,8 +61,7 @@ type Context struct {
// Passing multiple options which set the same parameter let the last one valid.
// Following example create SRTP Context with replay protection with window size of 256.
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.keyLen()
if err != nil {
Expand Down Expand Up @@ -112,7 +112,7 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts
}

// https://tools.ietf.org/html/rfc3550#appendix-A.1
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, int32) {
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
seq := int32(sequenceNumber)
localRoc := uint32(s.index >> 16)
localSeq := int32(s.index & (seqNumMax - 1))
Expand Down Expand Up @@ -147,7 +147,7 @@ func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, int32)
}
}

return guessRoc, difference
return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
}

func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
Expand Down Expand Up @@ -201,7 +201,7 @@ func (c *Context) ROC(ssrc uint32) (uint32, bool) {
// SetROC sets SRTP rollover counter value of specified SSRC.
func (c *Context) SetROC(ssrc uint32, roc uint32) {
s := c.getSRTPSSRCState(ssrc)
s.index = uint64(roc<<16) | (s.index & (seqNumMax - 1))
s.index = uint64(roc)<<16 | (s.index & (seqNumMax - 1))
}

// Index returns SRTCP index value of specified SSRC.
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var (
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
errExceededMaxPackets = errors.New("exceeded the maximum number of packets")

errStreamNotInited = errors.New("stream has not been inited, unable to close")
errStreamAlreadyClosed = errors.New("stream is already closed")
Expand Down
4 changes: 2 additions & 2 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type ContextOption func(*Context) error
func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive
return func(c *Context) error {
c.newSRTPReplayDetector = func() replaydetector.ReplayDetector {
return replaydetector.WithWrap(windowSize, maxSequenceNumber)
return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber)
}
return nil
}
Expand All @@ -21,7 +21,7 @@ func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive
func SRTCPReplayProtection(windowSize uint) ContextOption {
return func(c *Context) error {
c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector {
return replaydetector.WithWrap(windowSize, maxSRTCPIndex)
return replaydetector.New(windowSize, maxSRTCPIndex)
}
return nil
}
Expand Down
11 changes: 8 additions & 3 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,16 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)

if s.srtcpIndex >= maxSRTCPIndex {
// ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key
// (whichever occurs before), the key management MUST be called to provide new master key(s)
// (previously stored and used keys MUST NOT be used again), or the session MUST be terminated.
// https://www.rfc-editor.org/rfc/rfc3711#section-9.2
return nil, errExceededMaxPackets
}

// We roll over early because MSB is used for marking as encrypted
s.srtcpIndex++
if s.srtcpIndex > maxSRTCPIndex {
s.srtcpIndex = 0
}

return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc)
}
Expand Down
228 changes: 158 additions & 70 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type rtcpTestCase struct {
packets []rtcpTestPacket
}

func rtcpTestCasesSingle() map[string]rtcpTestCase {
func rtcpTestCases() map[string]rtcpTestCase {
return map[string]rtcpTestCase{
"AEAD_AES_128_GCM": {
algo: ProtectionProfileAeadAes128Gcm,
Expand Down Expand Up @@ -112,72 +112,6 @@ func rtcpTestCasesSingle() map[string]rtcpTestCase {
}
}

func rtcpTestCases() map[string]rtcpTestCase {
single := rtcpTestCasesSingle()
return map[string]rtcpTestCase{
"AEAD_AES_128_GCM": single["AEAD_AES_128_GCM"],
"AES_128_CM_HMAC_SHA1_80": {
algo: ProtectionProfileAes128CmHmacSha1_80,
masterKey: single["AES_128_CM_HMAC_SHA1_80"].masterKey,
masterSalt: single["AES_128_CM_HMAC_SHA1_80"].masterSalt,
packets: []rtcpTestPacket{
single["AES_128_CM_HMAC_SHA1_80"].packets[0],
single["AES_128_CM_HMAC_SHA1_80"].packets[1],
{
ssrc: 0x11111111,
index: 0x7ffffffe, // Upper boundary of index
pktType: rtcp.TypeSenderReport,
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5,
0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40,
0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a,
0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4,
0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77,
0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed,
0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde,
0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75,
0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b,
0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e,
0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58,
0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d,
},
},
{
ssrc: 0x11111111,
index: 0x7fffffff, // Will be wrapped to 0
pktType: rtcp.TypeSenderReport,
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5,
0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40,
0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a,
0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4,
0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77,
0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
0x80, 0x00, 0x00, 0x00, 0x7d, 0x51, 0xf8, 0x0e,
0x56, 0x40, 0x72, 0x7b, 0x9e, 0x02,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed,
0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40,
0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4,
0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74,
0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22,
0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28,
},
},
},
},
}
}

func TestRTCPLifecycle(t *testing.T) {
options := map[string][]ContextOption{
"Default": {},
Expand Down Expand Up @@ -371,7 +305,7 @@ func TestRTCPInvalidAuthTag(t *testing.T) {
}

func TestRTCPReplayDetectorSeparation(t *testing.T) {
for caseName, testCase := range rtcpTestCasesSingle() {
for caseName, testCase := range rtcpTestCases() {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
assert := assert.New(t)
Expand Down Expand Up @@ -409,7 +343,7 @@ func getRTCPIndex(encrypted []byte, authTagLen int) uint32 {
}

func TestEncryptRTCPSeparation(t *testing.T) {
for caseName, testCase := range rtcpTestCasesSingle() {
for caseName, testCase := range rtcpTestCases() {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
assert := assert.New(t)
Expand Down Expand Up @@ -462,7 +396,7 @@ func TestEncryptRTCPSeparation(t *testing.T) {
}

func TestRTCPDecryptShortenedPacket(t *testing.T) {
for caseName, testCase := range rtcpTestCasesSingle() {
for caseName, testCase := range rtcpTestCases() {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
pkt := testCase.packets[0]
Expand All @@ -479,3 +413,157 @@ func TestRTCPDecryptShortenedPacket(t *testing.T) {
})
}
}

func TestRTCPMaxPackets(t *testing.T) {
const ssrc = 0x11111111
testCases := map[string]rtcpTestCase{
"AEAD_AES_128_GCM": {
algo: ProtectionProfileAeadAes128Gcm,
masterKey: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f},
masterSalt: []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab},
packets: []rtcpTestPacket{
{
pktType: rtcp.TypeSenderReport,
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x02, 0xb6, 0xc1, 0x47, 0x92, 0xbe, 0xf0, 0xae,
0xd9, 0x40, 0xa5, 0x1c, 0xbe, 0xec, 0xaf, 0xfc,
0x7d, 0x86, 0x3b, 0xbb, 0x93, 0x0c, 0xb0, 0xd4,
0xea, 0x4a, 0x3c, 0x5b, 0xd1, 0xd5, 0x47, 0xb1,
0x1a, 0x61, 0xae, 0xa6, 0x1a, 0x0c, 0xb9, 0x14,
0xa5, 0x16, 0x08, 0xe4, 0xfb, 0x0d, 0x15, 0xba,
0x7f, 0x70, 0x2b, 0xb8, 0x99, 0x97, 0x91, 0xfd,
0x53, 0x03, 0xcd, 0x57, 0xbb, 0x8f, 0x93, 0xbe,
0xff, 0xff, 0xff, 0xff,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde,
0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75,
0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b,
0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e,
0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58,
0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d,
},
},
{
pktType: rtcp.TypeSenderReport,
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x77, 0x47, 0x0c, 0x21, 0xc2, 0xcd, 0x33, 0xa7,
0x5a, 0x81, 0xb5, 0xb5, 0x8f, 0xe2, 0x34, 0x28,
0x11, 0xa8, 0xa3, 0x34, 0xf8, 0x9d, 0xfc, 0xd8,
0xcb, 0x87, 0xe2, 0x51, 0x8e, 0xae, 0xdb, 0xfd,
0x9d, 0xf1, 0xfa, 0x18, 0xe2, 0xdc, 0x0a, 0xd4,
0xe3, 0x06, 0x18, 0xff, 0xf7, 0x27, 0x92, 0x1f,
0x28, 0xcd, 0x3c, 0xf8, 0xa4, 0x0a, 0x2b, 0xbb,
0x5b, 0x1f, 0x4d, 0x1f, 0xef, 0x0e, 0xc4, 0x91,
0x80, 0x00, 0x00, 0x01,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed,
0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40,
0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4,
0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74,
0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22,
0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28,
},
},
},
},
"AES_128_CM_HMAC_SHA1_80": {
algo: ProtectionProfileAes128CmHmacSha1_80,
masterKey: []byte{0xfd, 0xa6, 0x25, 0x95, 0xd7, 0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f, 0x34},
masterSalt: []byte{0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23},
packets: []rtcpTestPacket{
{
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5,
0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40,
0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a,
0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4,
0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77,
0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad,
0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed,
0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde,
0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75,
0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b,
0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e,
0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58,
0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d,
},
},
{
encrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0x12, 0x71, 0x75, 0x7a, 0xb0, 0xfd, 0x80, 0xcb,
0x26, 0xbb, 0x54, 0x5a, 0x1c, 0x0e, 0x98, 0x09,
0xbe, 0x60, 0x23, 0xd8, 0xe6, 0x6e, 0x68, 0xe8,
0x6e, 0x9c, 0xb2, 0x7e, 0x02, 0xa7, 0xab, 0xfe,
0xb3, 0xf4, 0x4c, 0x13, 0xc3, 0xac, 0x97, 0x2c,
0x35, 0x91, 0xbb, 0x37, 0x9c, 0x86, 0x28, 0x85,
0x80, 0x00, 0x00, 0x01, 0x89, 0x76, 0x07, 0xca,
0xd9, 0xc4, 0xcb, 0xca, 0x66, 0xab,
},
decrypted: []byte{
0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11,
0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed,
0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40,
0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4,
0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74,
0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22,
0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28,
},
},
},
},
}

for caseName, testCase := range testCases {
testCase := testCase
t.Run(caseName, func(t *testing.T) {
assert := assert.New(t)
encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo)
if err != nil {
t.Errorf("CreateContext failed: %v", err)
}

decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10))
if err != nil {
t.Errorf("CreateContext failed: %v", err)
}

// Upper boundary of index
encryptContext.SetIndex(ssrc, 0x7ffffffe)

decryptResult, err := decryptContext.DecryptRTCP(nil, testCase.packets[0].encrypted, nil)
if err != nil {
t.Error(err)
}
assert.Equal(testCase.packets[0].decrypted, decryptResult, "RTCP failed to decrypt")

encryptResult, err := encryptContext.EncryptRTCP(nil, testCase.packets[0].decrypted, nil)
if err != nil {
t.Error(err)
}
assert.Equal(testCase.packets[0].encrypted, encryptResult, "RTCP failed to encrypt")

// Next packet will exceeds the maximum packet count
_, err = decryptContext.DecryptRTCP(nil, testCase.packets[1].encrypted, nil)
if !errors.Is(err, errDuplicated) {
t.Errorf("Expected error: '%v', got: '%v'", errDuplicated, err)
}

_, err = encryptContext.EncryptRTCP(nil, testCase.packets[1].decrypted, nil)
if !errors.Is(err, errExceededMaxPackets) {
t.Errorf("Expected error: '%v', got: '%v'", errExceededMaxPackets, err)
}
})
}
}
15 changes: 12 additions & 3 deletions srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) {
s := c.getSRTPSSRCState(header.SSRC)

markAsValid, ok := s.replayDetector.Check(uint64(header.SequenceNumber))
roc, diff, _ := s.nextRolloverCount(header.SequenceNumber)
markAsValid, ok := s.replayDetector.Check(
(uint64(roc) << 16) | uint64(header.SequenceNumber),
)
if !ok {
return nil, &duplicatedError{
Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber),
Expand All @@ -20,7 +23,6 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
return nil, err
}
dst = growBufferSize(dst, len(ciphertext)-authTagLen)
roc, diff := s.nextRolloverCount(header.SequenceNumber)

dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc)
if err != nil {
Expand Down Expand Up @@ -67,7 +69,14 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) (
// Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload.
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) {
s := c.getSRTPSSRCState(header.SSRC)
roc, diff := s.nextRolloverCount(header.SequenceNumber)
roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber)
if ovf {
// ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key
// (whichever occurs before), the key management MUST be called to provide new master key(s)
// (previously stored and used keys MUST NOT be used again), or the session MUST be terminated.
// https://www.rfc-editor.org/rfc/rfc3711#section-9.2
return nil, errExceededMaxPackets
}
s.updateRolloverCount(header.SequenceNumber, diff)

return c.cipher.encryptRTP(dst, header, payload, roc)
Expand Down
Loading

0 comments on commit 5ce39f6

Please sign in to comment.