diff --git a/cmd/signal/allrpc/main.go b/cmd/signal/allrpc/main.go index 47527ac21..7010a0cbc 100644 --- a/cmd/signal/allrpc/main.go +++ b/cmd/signal/allrpc/main.go @@ -112,11 +112,11 @@ func parse() bool { } func getEnv(key string) string { - if value, exists := os.LookupEnv(key); exists { + if value, exists := os.LookupEnv(key); exists { return value - } + } - return "" + return "" } func main() { diff --git a/config.toml b/config.toml index 2b516902c..a13b2b1f3 100644 --- a/config.toml +++ b/config.toml @@ -10,8 +10,8 @@ withstats = false # Limit the remb bandwidth in kbps # zero means no limits maxbandwidth = 1500 -# max buffer time by ms for video tracks -maxbuffertime = 1000 +# max number of video tracks packets the SFU will keep track +maxpackettrack = 500 # Sets the audio level volume threshold. # Values from [0-127] where 0 is the loudest. # Audio levels are read from rtp extension header according to: diff --git a/pkg/buffer/bucket.go b/pkg/buffer/bucket.go new file mode 100644 index 000000000..9f9df5f97 --- /dev/null +++ b/pkg/buffer/bucket.go @@ -0,0 +1,107 @@ +package buffer + +import ( + "encoding/binary" + "math" +) + +const maxPktSize = 1350 + +type Bucket struct { + buf []byte + + headSN uint16 + step int + maxSteps int +} + +func NewBucket(buf []byte) *Bucket { + return &Bucket{ + buf: buf, + maxSteps: int(math.Floor(float64(len(buf))/float64(maxPktSize))) - 1, + } +} + +func (b *Bucket) AddPacket(pkt []byte, sn uint16, latest bool) ([]byte, error) { + if !latest { + return b.set(sn, pkt) + } + diff := sn - b.headSN + b.headSN = sn + for i := uint16(1); i < diff; i++ { + b.step++ + if b.step >= b.maxSteps { + b.step = 0 + } + } + return b.push(pkt), nil +} + +func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) { + p := b.get(sn) + if p == nil { + err = errPacketNotFound + return + } + i = len(p) + if cap(buf) < i { + err = errBufferTooSmall + return + } + if len(buf) < i { + buf = buf[:i] + } + copy(buf, p) + return +} + +func (b *Bucket) push(pkt []byte) []byte { + binary.BigEndian.PutUint16(b.buf[b.step*maxPktSize:], uint16(len(pkt))) + off := b.step*maxPktSize + 2 + copy(b.buf[off:], pkt) + b.step++ + if b.step > b.maxSteps { + b.step = 0 + } + return b.buf[off : off+len(pkt)] +} + +func (b *Bucket) get(sn uint16) []byte { + pos := b.step - int(b.headSN-sn+1) + if pos < 0 { + if pos*-1 > b.maxSteps+1 { + return nil + } + pos = b.maxSteps + pos + 1 + } + off := pos * maxPktSize + if off > len(b.buf) { + return nil + } + if binary.BigEndian.Uint16(b.buf[off+4:off+6]) != sn { + return nil + } + sz := int(binary.BigEndian.Uint16(b.buf[off : off+2])) + return b.buf[off+2 : off+2+sz] +} + +func (b *Bucket) set(sn uint16, pkt []byte) ([]byte, error) { + if b.headSN-sn >= uint16(b.maxSteps+1) { + return nil, errPacketTooOld + } + pos := b.step - int(b.headSN-sn+1) + if pos < 0 { + pos = b.maxSteps + pos + 1 + } + off := pos * maxPktSize + if off > len(b.buf) || off < 0 { + return nil, errPacketTooOld + } + // Do not overwrite if packet exist + if binary.BigEndian.Uint16(b.buf[off+4:off+6]) == sn { + return b.buf[off+2 : off+2+len(pkt)], nil + } + binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt))) + copy(b.buf[off+2:], pkt) + return b.buf[off+2 : off+2+len(pkt)], nil +} diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index c4045d9a4..75c8502d3 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -9,7 +9,6 @@ import ( "time" "github.com/gammazero/deque" - log "github.com/pion/ion-log" "github.com/pion/rtcp" "github.com/pion/rtp" @@ -40,22 +39,23 @@ type ExtPacket struct { // Buffer contains all packets type Buffer struct { sync.Mutex - pool *sync.Pool - nacker *nackQueue - packetQueue *PacketQueue - codecType webrtc.RTPCodecType - extPackets deque.Deque - pPackets []pendingPackets - closeOnce sync.Once - mediaSSRC uint32 - clockRate uint32 - maxBitrate uint64 - lastReport int64 - twccExt uint8 - audioExt uint8 - bound bool - closed atomicBool - mime string + bucket *Bucket + nacker *nackQueue + videoPool *sync.Pool + audioPool *sync.Pool + codecType webrtc.RTPCodecType + extPackets deque.Deque + pPackets []pendingPackets + closeOnce sync.Once + mediaSSRC uint32 + clockRate uint32 + maxBitrate uint64 + lastReport int64 + twccExt uint8 + audioExt uint8 + bound bool + closed atomicBool + mime string // supported feedbacks remb bool @@ -101,15 +101,15 @@ type Stats struct { // BufferOptions provides configuration options for the buffer type Options struct { - BufferTime int MaxBitRate uint64 } // NewBuffer constructs a new Buffer -func NewBuffer(ssrc uint32, pp *sync.Pool) *Buffer { +func NewBuffer(ssrc uint32, vp, ap *sync.Pool) *Buffer { b := &Buffer{ mediaSSRC: ssrc, - pool: pp, + videoPool: vp, + audioPool: ap, } b.extPackets.SetMinCapacity(7) return b @@ -126,8 +126,10 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, o Options) { switch { case strings.HasPrefix(b.mime, "audio/"): b.codecType = webrtc.RTPCodecTypeAudio + b.bucket = NewBucket(b.audioPool.Get().([]byte)) case strings.HasPrefix(b.mime, "video/"): b.codecType = webrtc.RTPCodecTypeVideo + b.bucket = NewBucket(b.videoPool.Get().([]byte)) default: b.codecType = webrtc.RTPCodecType(0) } @@ -140,7 +142,6 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, o Options) { } if b.codecType == webrtc.RTPCodecTypeVideo { - b.packetQueue = NewPacketQueue(b.pool, 500) for _, fb := range codec.RTCPFeedback { switch fb.Type { case webrtc.TypeRTCPFBGoogREMB: @@ -156,7 +157,6 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, o Options) { } } } else if b.codecType == webrtc.RTPCodecTypeAudio { - b.packetQueue = NewPacketQueue(b.pool, 50) for _, h := range params.HeaderExtensions { if h.URI == sdp.AudioLevelURI { b.audioLevel = true @@ -244,9 +244,14 @@ func (b *Buffer) Close() error { defer b.Unlock() b.closeOnce.Do(func() { + if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeVideo { + b.videoPool.Put(b.bucket.buf) + } + if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeAudio { + b.audioPool.Put(b.bucket.buf) + } b.closed.set(true) b.onClose() - b.packetQueue.Close() }) return nil } @@ -261,7 +266,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { if b.stats.PacketCount == 0 { b.baseSN = sn b.maxSeqNo = sn - b.packetQueue.headSN = sn - 1 + b.bucket.headSN = sn - 1 b.lastReport = arrivalTime } else if (sn-b.maxSeqNo)&0x8000 == 0 { if sn < b.maxSeqNo { @@ -296,7 +301,12 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { b.stats.PacketCount++ var p rtp.Packet - if err := p.Unmarshal(b.packetQueue.AddPacket(pkt, sn, sn == b.maxSeqNo)); err != nil { + pb, err := b.bucket.AddPacket(pkt, sn, sn == b.maxSeqNo) + if err != nil { + log.Errorf("buffer write err: %v", err) + return + } + if err = p.Unmarshal(pb); err != nil { return } @@ -493,7 +503,7 @@ func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) { if b.closed.get() { return 0, io.EOF } - return b.packetQueue.GetPacket(buff, sn) + return b.bucket.GetPacket(buff, sn) } // Bitrate returns the current publisher stream bitrate. diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go index 67de83e6e..a3762679a 100644 --- a/pkg/buffer/buffer_test.go +++ b/pkg/buffer/buffer_test.go @@ -57,7 +57,6 @@ func TestNewBuffer(t *testing.T) { name: "Must not be nil and add packets in sequence", args: args{ options: Options{ - BufferTime: 1000, MaxBitRate: 1e6, }, }, @@ -93,7 +92,7 @@ func TestNewBuffer(t *testing.T) { return make([]byte, 1500) }, } - buff := NewBuffer(123, pool) + buff := NewBuffer(123, pool, pool) buff.codecType = webrtc.RTPCodecTypeVideo assert.NotNil(t, buff) assert.NotNil(t, TestPackets) diff --git a/pkg/buffer/errors.go b/pkg/buffer/errors.go index 90d2119e9..e2e7fe125 100644 --- a/pkg/buffer/errors.go +++ b/pkg/buffer/errors.go @@ -5,5 +5,5 @@ import "errors" var ( errPacketNotFound = errors.New("packet not found in cache") errBufferTooSmall = errors.New("buffer too small") - errExtNotFound = errors.New("ext not found") + errPacketTooOld = errors.New("received packet too old") ) diff --git a/pkg/buffer/factory.go b/pkg/buffer/factory.go index 07de72532..0b40f6315 100644 --- a/pkg/buffer/factory.go +++ b/pkg/buffer/factory.go @@ -9,16 +9,22 @@ import ( type Factory struct { sync.RWMutex - packetPool *sync.Pool + videoPool *sync.Pool + audioPool *sync.Pool rtpBuffers map[uint32]*Buffer rtcpReaders map[uint32]*RTCPReader } -func NewBufferFactory() *Factory { +func NewBufferFactory(trackingPackets int) *Factory { return &Factory{ - packetPool: &sync.Pool{ + videoPool: &sync.Pool{ New: func() interface{} { - return make([]byte, 1500) + return make([]byte, trackingPackets*maxPktSize) + }, + }, + audioPool: &sync.Pool{ + New: func() interface{} { + return make([]byte, maxPktSize*25) }, }, rtpBuffers: make(map[uint32]*Buffer), @@ -46,7 +52,7 @@ func (f *Factory) GetOrNew(packetType packetio.BufferPacketType, ssrc uint32) io if reader, ok := f.rtpBuffers[ssrc]; ok { return reader } - buffer := NewBuffer(ssrc, f.packetPool) + buffer := NewBuffer(ssrc, f.videoPool, f.audioPool) f.rtpBuffers[ssrc] = buffer buffer.OnClose(func() { f.Lock() diff --git a/pkg/buffer/queue_test.go b/pkg/buffer/queue_test.go index 810b70883..909c1b58b 100644 --- a/pkg/buffer/queue_test.go +++ b/pkg/buffer/queue_test.go @@ -1,12 +1,10 @@ package buffer import ( - "sync" "testing" - "github.com/stretchr/testify/assert" - "github.com/pion/rtp" + "github.com/stretchr/testify/assert" ) var TestPackets = []*rtp.Packet{ @@ -43,9 +41,7 @@ var TestPackets = []*rtp.Packet{ } func Test_queue(t *testing.T) { - q := NewPacketQueue(&sync.Pool{New: func() interface{} { - return make([]byte, 1500) - }}, 500) + q := NewBucket(make([]byte, maxPktSize*50)) for _, p := range TestPackets { p := p @@ -58,7 +54,7 @@ func Test_queue(t *testing.T) { var expectedSN uint16 expectedSN = 6 np := rtp.Packet{} - buff := make([]byte, 1500) + buff := make([]byte, maxPktSize) i, err := q.GetPacket(buff, 6) assert.NoError(t, err) err = np.Unmarshal(buff[:i]) @@ -74,30 +70,12 @@ func Test_queue(t *testing.T) { assert.NoError(t, err) expectedSN = 8 q.AddPacket(buf, 8, false) + q.AddPacket(buf, 8, false) i, err = q.GetPacket(buff, expectedSN) assert.NoError(t, err) err = np.Unmarshal(buff[:i]) assert.NoError(t, err) assert.Equal(t, expectedSN, np.SequenceNumber) - assert.NotPanics(t, q.Close) -} - -func Test_queue_disorder(t *testing.T) { - d := []byte("dummy data") - q := NewPacketQueue(&sync.Pool{New: func() interface{} { - return make([]byte, 1500) - }}, 500) - - q.headSN = 25745 - q.AddPacket(d, 25746, true) - dd := q.AddPacket(d, 25743, false) - assert.NotNil(t, dd) - assert.Equal(t, dd, d) - assert.Equal(t, 4, q.size) - dd = q.AddPacket(d, 25745, false) - assert.NotNil(t, dd) - assert.Equal(t, dd, d) - assert.Equal(t, 4, q.size) } func Test_queue_edges(t *testing.T) { @@ -118,9 +96,7 @@ func Test_queue_edges(t *testing.T) { }, }, } - q := NewPacketQueue(&sync.Pool{New: func() interface{} { - return make([]byte, 1500) - }}, 500) + q := NewBucket(make([]byte, 25000)) q.headSN = 65532 for _, p := range TestPackets { p := p @@ -137,7 +113,7 @@ func Test_queue_edges(t *testing.T) { var expectedSN uint16 expectedSN = 65534 np := rtp.Packet{} - buff := make([]byte, 1500) + buff := make([]byte, maxPktSize) i, err := q.GetPacket(buff, expectedSN) assert.NoError(t, err) err = np.Unmarshal(buff[:i]) @@ -157,5 +133,4 @@ func Test_queue_edges(t *testing.T) { err = np.Unmarshal(buff[:i]) assert.NoError(t, err) assert.Equal(t, expectedSN+1, np.SequenceNumber) - assert.NotPanics(t, q.Close) } diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index efcb8b1ac..278f9eafd 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -26,17 +26,19 @@ const ( // and SVC Publisher. type DownTrack struct { sync.Mutex - id string - peerID string - bound atomicBool - mime string - ssrc uint32 - streamID string - payloadType uint8 - sequencer *sequencer - trackType DownTrackType - skipFB int64 - payload []byte + id string + peerID string + bound atomicBool + mime string + ssrc uint32 + streamID string + maxTrack int + payloadType uint8 + sequencer *sequencer + trackType DownTrackType + bufferFactory *buffer.Factory + skipFB int64 + payload []byte spatialLayer int32 temporalLayer int32 @@ -69,13 +71,15 @@ type DownTrack struct { } // NewDownTrack returns a DownTrack. -func NewDownTrack(c webrtc.RTPCodecCapability, r Receiver, peerID string) (*DownTrack, error) { +func NewDownTrack(c webrtc.RTPCodecCapability, r Receiver, bf *buffer.Factory, peerID string, mt int) (*DownTrack, error) { return &DownTrack{ - id: r.TrackID(), - peerID: peerID, - streamID: r.StreamID(), - receiver: r, - codec: c, + id: r.TrackID(), + peerID: peerID, + maxTrack: mt, + streamID: r.StreamID(), + bufferFactory: bf, + receiver: r, + codec: c, }, nil } @@ -91,13 +95,13 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.mime = strings.ToLower(codec.MimeType) d.reSync.set(true) d.enabled.set(true) - if rr := bufferFactory.GetOrNew(packetio.RTCPBufferPacket, uint32(t.SSRC())).(*buffer.RTCPReader); rr != nil { + if rr := d.bufferFactory.GetOrNew(packetio.RTCPBufferPacket, uint32(t.SSRC())).(*buffer.RTCPReader); rr != nil { rr.OnPacket(func(pkt []byte) { d.handleRTCP(pkt) }) } if strings.HasPrefix(d.codec.MimeType, "video/") { - d.sequencer = newSequencer() + d.sequencer = newSequencer(d.maxTrack) } d.onBind() d.bound.set(true) @@ -270,6 +274,11 @@ func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { } } +func (d *DownTrack) UpdateStats(packetLen uint32) { + atomic.AddUint32(&d.octetCount, packetLen) + atomic.AddUint32(&d.packetCount, 1) +} + func (d *DownTrack) writeSimpleRTP(extPkt buffer.ExtPacket) error { if d.reSync.get() { if d.Kind() == webrtc.RTPCodecTypeVideo { @@ -286,8 +295,7 @@ func (d *DownTrack) writeSimpleRTP(extPkt buffer.ExtPacket) error { d.reSync.set(false) } - atomic.AddUint32(&d.octetCount, uint32(len(extPkt.Packet.Payload))) - atomic.AddUint32(&d.packetCount, 1) + d.UpdateStats(uint32(len(extPkt.Packet.Payload))) newSN := extPkt.Packet.SequenceNumber - d.snOffset newTS := extPkt.Packet.Timestamp - d.tsOffset diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index fd332f84c..c67cfe0f7 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -41,7 +41,7 @@ func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publ id: id, pc: pc, session: session, - router: newRouter(pc, id, session, cfg.router), + router: newRouter(id, pc, session, cfg.router), } pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 8b92426d7..e73d4bc11 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -238,6 +238,8 @@ func (w *WebRTCReceiver) RetransmitPackets(track *DownTrack, packets []packetMet } pkt.Header.SequenceNumber = meta.targetSeqNo pkt.Header.Timestamp = meta.timestamp + pkt.Header.SSRC = track.ssrc + pkt.Header.PayloadType = track.payloadType if track.simulcast.temporalSupported { switch track.mime { case "video/vp8": @@ -252,6 +254,8 @@ func (w *WebRTCReceiver) RetransmitPackets(track *DownTrack, packets []packetMet if _, err = track.writeStream.WriteRTP(&pkt.Header, pkt.Payload); err != nil { log.Errorf("Writing rtx packet err: %v", err) + } else { + track.UpdateStats(uint32(i)) } packetFactory.Put(pktBuff) diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go index 4dac240ad..4d80daf83 100644 --- a/pkg/sfu/router.go +++ b/pkg/sfu/router.go @@ -23,7 +23,7 @@ type Router interface { type RouterConfig struct { WithStats bool `mapstructure:"withstats"` MaxBandwidth uint64 `mapstructure:"maxbandwidth"` - MaxBufferTime int `mapstructure:"maxbuffertime"` + MaxPacketTrack int `mapstructure:"maxpackettrack"` AudioLevelInterval int `mapstructure:"audiolevelinterval"` AudioLevelThreshold uint8 `mapstructure:"audiolevelthreshold"` AudioLevelFilter int `mapstructure:"audiolevelfilter"` @@ -32,29 +32,31 @@ type RouterConfig struct { type router struct { sync.RWMutex - id string - twcc *twcc.Responder - peer *webrtc.PeerConnection - stats map[uint32]*stats.Stream - rtcpCh chan []rtcp.Packet - stopCh chan struct{} - config RouterConfig - session *Session - receivers map[string]Receiver + id string + twcc *twcc.Responder + peer *webrtc.PeerConnection + stats map[uint32]*stats.Stream + rtcpCh chan []rtcp.Packet + stopCh chan struct{} + config RouterConfig + session *Session + receivers map[string]Receiver + bufferFactory *buffer.Factory } // newRouter for routing rtp/rtcp packets -func newRouter(peer *webrtc.PeerConnection, id string, session *Session, config RouterConfig) Router { +func newRouter(id string, peer *webrtc.PeerConnection, session *Session, config RouterConfig) Router { ch := make(chan []rtcp.Packet, 10) r := &router{ - id: id, - peer: peer, - rtcpCh: ch, - stopCh: make(chan struct{}), - config: config, - session: session, - receivers: make(map[string]Receiver), - stats: make(map[uint32]*stats.Stream), + id: id, + peer: peer, + rtcpCh: ch, + stopCh: make(chan struct{}), + config: config, + session: session, + receivers: make(map[string]Receiver), + stats: make(map[uint32]*stats.Stream), + bufferFactory: session.BufferFactory(), } if config.WithStats { @@ -85,7 +87,7 @@ func (r *router) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRe trackID := track.ID() rid := track.RID() - buff, rtcpReader := bufferFactory.GetBufferPair(uint32(track.SSRC())) + buff, rtcpReader := r.bufferFactory.GetBufferPair(uint32(track.SSRC())) buff.OnFeedback(func(fb []rtcp.Packet) { r.rtcpCh <- fb @@ -176,7 +178,6 @@ func (r *router) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRe recv.AddUpTrack(track, buff) buff.Bind(receiver.GetParameters(), buffer.Options{ - BufferTime: r.config.MaxBufferTime, MaxBitRate: r.config.MaxBandwidth, }) @@ -233,7 +234,7 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { Channels: codec.Channels, SDPFmtpLine: codec.SDPFmtpLine, RTCPFeedback: []webrtc.RTCPFeedback{{"goog-remb", ""}, {"nack", ""}, {"nack", "pli"}}, - }, recv, sub.id) + }, recv, r.bufferFactory, sub.id, r.config.MaxPacketTrack) if err != nil { return err } diff --git a/pkg/sfu/sequencer.go b/pkg/sfu/sequencer.go index 53f286dfb..0f9e2d7a2 100644 --- a/pkg/sfu/sequencer.go +++ b/pkg/sfu/sequencer.go @@ -8,8 +8,6 @@ import ( ) const ( - maxPacketMetaHistory = 500 - ignoreRetransmission = 100 // Ignore packet retransmission after ignoreRetransmission milliseconds ) @@ -50,15 +48,18 @@ func (p packetMeta) getVP8PayloadMeta() (uint8, uint16) { type sequencer struct { sync.Mutex init bool - seq [500]packetMeta + max int + seq []packetMeta step int headSN uint16 startTime int64 } -func newSequencer() *sequencer { +func newSequencer(maxTrack int) *sequencer { return &sequencer{ startTime: time.Now().UnixNano() / 1e6, + max: maxTrack, + seq: make([]packetMeta, maxTrack), } } @@ -75,7 +76,7 @@ func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer uint8, head b inc := offSn - n.headSN for i := uint16(1); i < inc; i++ { n.step++ - if n.step >= maxPacketMetaHistory { + if n.step >= n.max { n.step = 0 } } @@ -84,11 +85,11 @@ func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer uint8, head b } else { step = n.step - int(n.headSN-offSn) if step < 0 { - if step*-1 >= maxPacketMetaHistory { + if step*-1 >= n.max { log.Warnf("old packet received, can not be sequenced, head: %d received: %d", sn, offSn) return nil } - step = maxPacketMetaHistory + step + step = n.max + step } } n.seq[n.step] = packetMeta{ @@ -98,7 +99,7 @@ func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer uint8, head b layer: layer, } n.step++ - if n.step >= maxPacketMetaHistory { + if n.step >= n.max { n.step = 0 } return &n.seq[n.step] @@ -111,10 +112,10 @@ func (n *sequencer) getSeqNoPairs(seqNo []uint16) []packetMeta { for _, sn := range seqNo { step := n.step - int(n.headSN-sn) - 1 if step < 0 { - if step*-1 >= maxPacketMetaHistory { + if step*-1 >= n.max { continue } - step = maxPacketMetaHistory + step + step = n.max + step } seq := &n.seq[step] if seq.targetSeqNo == sn { diff --git a/pkg/sfu/sequencer_test.go b/pkg/sfu/sequencer_test.go index 170d5e86a..3d1c2b491 100644 --- a/pkg/sfu/sequencer_test.go +++ b/pkg/sfu/sequencer_test.go @@ -9,7 +9,7 @@ import ( ) func Test_sequencer(t *testing.T) { - seq := newSequencer() + seq := newSequencer(500) off := uint16(15) for i := uint16(1); i < 520; i++ { @@ -67,7 +67,7 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - n := newSequencer() + n := newSequencer(500) for _, i := range tt.fields.input { n.push(i, i+tt.fields.offset, 123, 3, true) diff --git a/pkg/sfu/session.go b/pkg/sfu/session.go index 7a8bcd07c..088c990f3 100644 --- a/pkg/sfu/session.go +++ b/pkg/sfu/session.go @@ -5,6 +5,8 @@ import ( "sync" "time" + "github.com/pion/ion-sfu/pkg/buffer" + log "github.com/pion/ion-log" "github.com/pion/webrtc/v3" ) @@ -19,15 +21,17 @@ type Session struct { fanOutDCs []string datachannels []*Datachannel audioObserver *audioLevel + bufferFactory *buffer.Factory onCloseHandler func() } // NewSession creates a new session -func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) *Session { +func NewSession(id string, bf *buffer.Factory, dcs []*Datachannel, cfg WebRTCTransportConfig) *Session { s := &Session{ id: id, peers: make(map[string]*Peer), datachannels: dcs, + bufferFactory: bf, audioObserver: newAudioLevel(cfg.router.AudioLevelThreshold, cfg.router.AudioLevelInterval, cfg.router.AudioLevelFilter), } go s.audioLevelObserver(cfg.router.AudioLevelInterval) @@ -181,6 +185,11 @@ func (s *Session) Subscribe(peer *Peer) { peer.subscriber.negotiate() } +// BufferFactory returns current session buffer factory +func (s *Session) BufferFactory() *buffer.Factory { + return s.bufferFactory +} + // Transports returns peers in this session func (s *Session) Peers() []*Peer { s.mu.RLock() diff --git a/pkg/sfu/sfu.go b/pkg/sfu/sfu.go index 755857f26..0d7e82d26 100644 --- a/pkg/sfu/sfu.go +++ b/pkg/sfu/sfu.go @@ -48,25 +48,26 @@ type Config struct { Ballast int64 `mapstructure:"ballast"` WithStats bool `mapstructure:"withstats"` } `mapstructure:"sfu"` - WebRTC WebRTCConfig `mapstructure:"webrtc"` - Log log.Config `mapstructure:"log"` - Router RouterConfig `mapstructure:"router"` - Turn TurnConfig `mapstructure:"turn"` + WebRTC WebRTCConfig `mapstructure:"webrtc"` + Log log.Config `mapstructure:"log"` + Router RouterConfig `mapstructure:"router"` + Turn TurnConfig `mapstructure:"turn"` + BufferFactory *buffer.Factory } var ( - bufferFactory *buffer.Factory packetFactory *sync.Pool ) // SFU represents an sfu instance type SFU struct { sync.RWMutex - webrtc WebRTCTransportConfig - turn *turn.Server - sessions map[string]*Session - datachannels []*Datachannel - withStats bool + webrtc WebRTCTransportConfig + turn *turn.Server + sessions map[string]*Session + datachannels []*Datachannel + bufferFactory *buffer.Factory + withStats bool } // NewWebRTCTransportConfig parses our settings and returns a usable WebRTCTransportConfig for creating PeerConnections @@ -104,7 +105,7 @@ func NewWebRTCTransportConfig(c Config) WebRTCTransportConfig { } } - se.BufferFactory = bufferFactory.GetOrNew + se.BufferFactory = c.BufferFactory.GetOrNew sdpSemantics := webrtc.SDPSemanticsUnifiedPlan switch c.WebRTC.SDPSemantics { @@ -140,8 +141,6 @@ func NewWebRTCTransportConfig(c Config) WebRTCTransportConfig { } func init() { - // Init buffer factory - bufferFactory = buffer.NewBufferFactory() // Init packet factory packetFactory = &sync.Pool{ New: func() interface{} { @@ -157,12 +156,17 @@ func NewSFU(c Config) *SFU { // Init ballast ballast := make([]byte, c.SFU.Ballast*1024*1024) + if c.BufferFactory == nil { + c.BufferFactory = buffer.NewBufferFactory(c.Router.MaxPacketTrack) + } + w := NewWebRTCTransportConfig(c) sfu := &SFU{ - webrtc: w, - sessions: make(map[string]*Session), - withStats: c.Router.WithStats, + webrtc: w, + sessions: make(map[string]*Session), + withStats: c.Router.WithStats, + bufferFactory: c.BufferFactory, } if c.Turn.Enabled { @@ -179,7 +183,7 @@ func NewSFU(c Config) *SFU { // NewSession creates a new session instance func (s *SFU) newSession(id string) *Session { - session := NewSession(id, s.datachannels, s.webrtc) + session := NewSession(id, s.bufferFactory, s.datachannels, s.webrtc) session.OnClose(func() { s.Lock() diff --git a/pkg/sfu/sfu_test.go b/pkg/sfu/sfu_test.go index 0e8519f54..2dbbdeed7 100644 --- a/pkg/sfu/sfu_test.go +++ b/pkg/sfu/sfu_test.go @@ -135,7 +135,7 @@ func TestSFU_SessionScenarios(t *testing.T) { fixByFile := []string{"asm_amd64.s", "proc.go", "icegatherer.go", "jsonrpc2"} fixByFunc := []string{"Handle"} log.Init("trace", fixByFile, fixByFunc) - sfu := NewSFU(Config{Log: log.Config{Level: "trace"}}) + sfu := NewSFU(Config{Log: log.Config{Level: "trace"}, Router: RouterConfig{MaxPacketTrack: 200}}) sfu.NewDatachannel(APIChannelLabel) tests := []struct { name string