From 30f6940e6a96718eb2d6588f3c35fa8b93425825 Mon Sep 17 00:00:00 2001 From: Aleksandr Alekseev Date: Wed, 2 Oct 2024 18:34:09 +0300 Subject: [PATCH] Add ResponderStreamsFilter to NACKs Allows user to selectively disable NACKs on some streams --- pkg/nack/generator_interceptor.go | 4 +- pkg/nack/generator_interceptor_test.go | 59 ++++++++++++++++++- pkg/nack/generator_option.go | 9 +++ pkg/nack/responder_interceptor.go | 10 ++-- pkg/nack/responder_interceptor_test.go | 81 ++++++++++++++++++++++++++ pkg/nack/responder_option.go | 13 ++++- 6 files changed, 169 insertions(+), 7 deletions(-) diff --git a/pkg/nack/generator_interceptor.go b/pkg/nack/generator_interceptor.go index 051e424c..ab2bb2c5 100644 --- a/pkg/nack/generator_interceptor.go +++ b/pkg/nack/generator_interceptor.go @@ -21,6 +21,7 @@ type GeneratorInterceptorFactory struct { // NewInterceptor constructs a new ReceiverInterceptor func (g *GeneratorInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { i := &GeneratorInterceptor{ + streamsFilter: streamSupportNack, size: 512, skipLastN: 0, maxNacksPerPacket: 0, @@ -47,6 +48,7 @@ func (g *GeneratorInterceptorFactory) NewInterceptor(_ string) (interceptor.Inte // GeneratorInterceptor interceptor generates nack feedback messages. type GeneratorInterceptor struct { interceptor.NoOp + streamsFilter func(info *interceptor.StreamInfo) bool size uint16 skipLastN uint16 maxNacksPerPacket uint16 @@ -86,7 +88,7 @@ func (n *GeneratorInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) int // BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method // will be called once per rtp packet. func (n *GeneratorInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { - if !streamSupportNack(info) { + if !n.streamsFilter(info) { return reader } diff --git a/pkg/nack/generator_interceptor_test.go b/pkg/nack/generator_interceptor_test.go index 46a812ae..e303c59a 100644 --- a/pkg/nack/generator_interceptor_test.go +++ b/pkg/nack/generator_interceptor_test.go @@ -43,7 +43,7 @@ func TestGeneratorInterceptor(t *testing.T) { case r := <-stream.ReadRTP(): assert.NoError(t, r.Err) assert.Equal(t, seqNum, r.Packet.SequenceNumber) - case <-time.After(10 * time.Millisecond): + case <-time.After(50 * time.Millisecond): t.Fatal("receiver rtp packet not found") } } @@ -76,3 +76,60 @@ func TestGeneratorInterceptor_InvalidSize(t *testing.T) { _, err := f.NewInterceptor("") assert.Error(t, err, ErrInvalidSize) } + +func TestGeneratorInterceptor_StreamFilter(t *testing.T) { + const interval = time.Millisecond * 10 + f, err := NewGeneratorInterceptor( + GeneratorSize(64), + GeneratorSkipLastN(2), + GeneratorInterval(interval), + GeneratorLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + GeneratorStreamsFilter(func(info *interceptor.StreamInfo) bool { + return info.SSRC != 1 // enable nacks only for ssrc 2 + }), + ) + assert.NoError(t, err) + + i, err := f.NewInterceptor("") + assert.NoError(t, err) + + streamWithoutNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + assert.NoError(t, streamWithoutNacks.Close()) + }() + + streamWithNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 2, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + assert.NoError(t, streamWithNacks.Close()) + }() + + for _, seqNum := range []uint16{10, 11, 12, 14, 16, 18} { + streamWithNacks.ReceiveRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}}) + streamWithoutNacks.ReceiveRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}}) + } + + time.Sleep(interval * 2) // wait for at least 2 nack packets + + // both test streams receive RTCP packets about both test streams (as they both call BindRTCPWriter), so we + // can check only one + rtcpStream := streamWithNacks.WrittenRTCP() + + for { + select { + case pkts := <-rtcpStream: + for _, pkt := range pkts { + if nack, isNack := pkt.(*rtcp.TransportLayerNack); isNack { + assert.NotEqual(t, uint32(1), nack.MediaSSRC) // check there are no nacks for ssrc 1 + } + } + default: + return + } + } +} diff --git a/pkg/nack/generator_option.go b/pkg/nack/generator_option.go index 346bc999..5403e3ee 100644 --- a/pkg/nack/generator_option.go +++ b/pkg/nack/generator_option.go @@ -6,6 +6,7 @@ package nack import ( "time" + "github.com/pion/interceptor" "github.com/pion/logging" ) @@ -54,3 +55,11 @@ func GeneratorInterval(interval time.Duration) GeneratorOption { return nil } } + +// GeneratorStreamsFilter sets filter for generator streams +func GeneratorStreamsFilter(filter func(info *interceptor.StreamInfo) bool) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.streamsFilter = filter + return nil + } +} diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 8f74952d..03b084de 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -24,9 +24,10 @@ type packetFactory interface { // NewInterceptor constructs a new ResponderInterceptor func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { i := &ResponderInterceptor{ - size: 1024, - log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"), - streams: map[uint32]*localStream{}, + streamsFilter: streamSupportNack, + size: 1024, + log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"), + streams: map[uint32]*localStream{}, } for _, opt := range r.opts { @@ -49,6 +50,7 @@ func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Inte // ResponderInterceptor responds to nack feedback messages type ResponderInterceptor struct { interceptor.NoOp + streamsFilter func(info *interceptor.StreamInfo) bool size uint16 log logging.LeveledLogger packetFactory packetFactory @@ -99,7 +101,7 @@ func (n *ResponderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) int // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method // will be called once per rtp packet. func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { - if !streamSupportNack(info) { + if !n.streamsFilter(info) { return writer } diff --git a/pkg/nack/responder_interceptor_test.go b/pkg/nack/responder_interceptor_test.go index 360e142b..9eb5b232 100644 --- a/pkg/nack/responder_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -150,3 +150,84 @@ func TestResponderInterceptor_Race(t *testing.T) { } } } + +func TestResponderInterceptor_StreamFilter(t *testing.T) { + f, err := NewResponderInterceptor( + ResponderSize(8), + ResponderLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + ResponderStreamsFilter(func(info *interceptor.StreamInfo) bool { + return info.SSRC != 1 // enable nacks only for ssrc 2 + })) + + require.NoError(t, err) + + i, err := f.NewInterceptor("") + require.NoError(t, err) + + streamWithoutNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + require.NoError(t, streamWithoutNacks.Close()) + }() + + streamWithNacks := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 2, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + require.NoError(t, streamWithNacks.Close()) + }() + + for _, seqNum := range []uint16{10, 11, 12, 14, 15} { + require.NoError(t, streamWithoutNacks.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum, SSRC: 1}})) + require.NoError(t, streamWithNacks.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum, SSRC: 2}})) + + select { + case p := <-streamWithoutNacks.WrittenRTP(): + require.Equal(t, seqNum, p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + + select { + case p := <-streamWithNacks.WrittenRTP(): + require.Equal(t, seqNum, p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + } + + streamWithoutNacks.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 1, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, + }) + + streamWithNacks.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 2, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, + }) + + select { + case <-streamWithNacks.WrittenRTP(): + case <-time.After(10 * time.Millisecond): + t.Fatal("nack response expected") + } + + select { + case <-streamWithoutNacks.WrittenRTP(): + t.Fatal("no nack response expected") + case <-time.After(10 * time.Millisecond): + } +} diff --git a/pkg/nack/responder_option.go b/pkg/nack/responder_option.go index 0aaa7577..24c7c469 100644 --- a/pkg/nack/responder_option.go +++ b/pkg/nack/responder_option.go @@ -3,7 +3,10 @@ package nack -import "github.com/pion/logging" +import ( + "github.com/pion/interceptor" + "github.com/pion/logging" +) // ResponderOption can be used to configure ResponderInterceptor type ResponderOption func(s *ResponderInterceptor) error @@ -33,3 +36,11 @@ func DisableCopy() ResponderOption { return nil } } + +// ResponderStreamsFilter sets filter for local streams +func ResponderStreamsFilter(filter func(info *interceptor.StreamInfo) bool) ResponderOption { + return func(r *ResponderInterceptor) error { + r.streamsFilter = filter + return nil + } +}