From 1bdeef256f2943f28f23d37b621be7d1a0a49572 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Tue, 7 Feb 2023 12:40:27 +0900 Subject: [PATCH] Fix rollover estimation after SetROC SRTP context estimated the wrong ROC when the sequence number is overflown during burst packet loss, and SetROC gives the correct ROC. --- context.go | 3 ++- srtp_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/context.go b/context.go index a5b694e..438a51c 100644 --- a/context.go +++ b/context.go @@ -201,7 +201,8 @@ func (c *Context) ROC(ssrc uint32) (uint32, bool) { // SetROC sets SRTP rollover counter value of specified SSRC. func (c *Context) SetROC(ssrc uint32, roc uint32) { s := c.getSRTPSSRCState(ssrc) - s.index = uint64(roc)<<16 | (s.index & (seqNumMax - 1)) + s.index = uint64(roc) << 16 + s.rolloverHasProcessed = false } // Index returns SRTCP index value of specified SSRC. diff --git a/srtp_test.go b/srtp_test.go index 91091c5..e239614 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -774,3 +774,78 @@ func TestRTPMaxPackets(t *testing.T) { }) } } + +func TestRTPBurstLossWithSetROC(t *testing.T) { + profiles := map[string]ProtectionProfile{ + "CTR": profileCTR, + "GCM": profileGCM, + } + for name, profile := range profiles { + profile := profile + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + encryptContext, err := buildTestContext(profile) + if err != nil { + t.Fatal(err) + } + + type packetWithROC struct { + pkt rtp.Packet + enc []byte + raw []byte + + roc uint32 + } + + var pkts []*packetWithROC + encryptContext.SetROC(1, 3) + for i := 0x8C00; i < 0x20400; i += 0x100 { + p := &packetWithROC{ + pkt: rtp.Packet{ + Payload: []byte{ + byte(i >> 16), + byte(i >> 8), + byte(i), + }, + Header: rtp.Header{ + Marker: true, + SSRC: 1, + SequenceNumber: uint16(i), + }, + }, + } + b, errMarshal := p.pkt.Marshal() + if errMarshal != nil { + t.Fatal(errMarshal) + } + p.raw = b + enc, errEnc := encryptContext.EncryptRTP(nil, b, nil) + if errEnc != nil { + t.Fatal(errEnc) + } + p.roc, _ = encryptContext.ROC(1) + if 0x9000 < i && i < 0x20100 { + continue + } + p.enc = enc + pkts = append(pkts, p) + } + + decryptContext, err := buildTestContext(profile) + if err != nil { + t.Fatal(err) + } + + for _, p := range pkts { + decryptContext.SetROC(1, p.roc) + pkt, err := decryptContext.DecryptRTP(nil, p.enc, nil) + if err != nil { + t.Errorf("roc=%d, seq=%d: %v", p.roc, p.pkt.SequenceNumber, err) + continue + } + assert.Equal(p.raw, pkt) + } + }) + } +}