From 8c81aa2c5183204d78dc837f45cde3848792c1a0 Mon Sep 17 00:00:00 2001 From: OrlandoCo Date: Fri, 9 Apr 2021 10:25:07 -0500 Subject: [PATCH] feat(relay): Add interface for relaying peers (#486) * feat(relay): add relay ORTC functionality * feat(relay): keep all the logic on provider * feat(relay): add sctp transport * feat(relay): add se configuration * feat(relay): add se configuration * feat(relay): add datachannels to relaying peers * feat(sfu): add outgoing relay * feat(sfu): make relay configurable on peer join * feat(relay): fix send signal not adding meta * feat(relay): Send SR to relay peers * fix(relay): send relay sender reports only * fix(relay): send relay sender reports only * fix(relay): rename relay peer to peer * fix: Fix misc issues * feat(relay): Include codec on receive callback * fix(simulcast): Fix layer change --- go.mod | 11 +- go.sum | 28 +++ pkg/buffer/bucket.go | 7 +- pkg/buffer/bucket_test.go | 1 - pkg/buffer/buffer.go | 9 +- pkg/relay/relay.go | 496 ++++++++++++++++++++++++++++++++++++++ pkg/sfu/downtrack.go | 50 ++-- pkg/sfu/helpers.go | 2 +- pkg/sfu/peer.go | 4 +- pkg/sfu/publisher.go | 81 ++++++- pkg/sfu/receiver.go | 39 ++- pkg/sfu/session.go | 108 ++++++--- pkg/sfu/sfu.go | 9 + 13 files changed, 764 insertions(+), 81 deletions(-) create mode 100644 pkg/relay/relay.go diff --git a/go.mod b/go.mod index fb4edb46c..d8105c866 100644 --- a/go.mod +++ b/go.mod @@ -14,15 +14,16 @@ require ( github.com/improbable-eng/grpc-web v0.13.0 github.com/lucsky/cuid v1.0.2 github.com/pion/dtls/v2 v2.0.8 - github.com/pion/ice/v2 v2.0.15 + github.com/pion/ice/v2 v2.0.16 github.com/pion/ion-log v1.0.0 github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.6 github.com/pion/rtp v1.6.2 github.com/pion/sdp/v3 v3.0.4 - github.com/pion/transport v0.12.2 + github.com/pion/transport v0.12.3 github.com/pion/turn/v2 v2.0.5 - github.com/pion/webrtc/v3 v3.0.10 + github.com/pion/udp v0.1.1 // indirect + github.com/pion/webrtc/v3 v3.0.20 github.com/prometheus/client_golang v1.9.0 github.com/rs/cors v1.7.0 // indirect github.com/rs/zerolog v1.20.0 @@ -30,8 +31,10 @@ require ( github.com/sourcegraph/jsonrpc2 v0.0.0-20200429184054-15c2290dcb37 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect + golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c // indirect golang.org/x/sync v0.0.0-20201207232520-09787c993a3a - golang.org/x/sys v0.0.0-20210217090653-ed5674b6da4a // indirect + golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54 // indirect google.golang.org/grpc v1.35.0 google.golang.org/grpc/examples v0.0.0-20201209011439-fd32f6a4fefe // indirect google.golang.org/protobuf v1.25.0 diff --git a/go.sum b/go.sum index 1739cd3c9..af0955d46 100644 --- a/go.sum +++ b/go.sum @@ -292,14 +292,20 @@ github.com/pion/dtls/v2 v2.0.8 h1:reGe8rNIMfO/UAeFLqO61tl64t154Qfkr4U3Gzu1tsg= github.com/pion/dtls/v2 v2.0.8/go.mod h1:QuDII+8FVvk9Dp5t5vYIMTo7hh7uBkra+8QIm7QGm10= github.com/pion/ice/v2 v2.0.15 h1:KZrwa2ciL9od8+TUVJiYTNsCW9J5lktBjGwW1MacEnQ= github.com/pion/ice/v2 v2.0.15/go.mod h1:ZIiVGevpgAxF/cXiIVmuIUtCb3Xs4gCzCbXB6+nFkSI= +github.com/pion/ice/v2 v2.0.16 h1:K6bzD8ef9vMKbGMTHaUweHXEyuNGnvr2zdqKoLKZPn0= +github.com/pion/ice/v2 v2.0.16/go.mod h1:SJNJzC27gDZoOW0UoxIoC8Hf2PDxG28hQyNdSexDu38= github.com/pion/interceptor v0.0.9 h1:fk5hTdyLO3KURQsf/+RjMpEm4NE3yeTY9Kh97b5BvwA= github.com/pion/interceptor v0.0.9/go.mod h1:dHgEP5dtxOTf21MObuBAjJeAayPxLUAZjerGH8Xr07c= +github.com/pion/interceptor v0.0.12 h1:eC1iVneBIAQJEfaNAfDqAncJWhMDAnaXPRCJsltdokE= +github.com/pion/interceptor v0.0.12/go.mod h1:qzeuWuD/ZXvPqOnxNcnhWfkCZ2e1kwwslicyyPnhoK4= github.com/pion/ion-log v1.0.0 h1:2lJLImCmfCWCR38hLWsjQfBWe6NFz/htbqiYHwvOP/Q= github.com/pion/ion-log v1.0.0/go.mod h1:jwcla9KoB9bB/4FxYDSRJPcPYSLp5XiUUMnOLaqwl4E= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= github.com/pion/mdns v0.0.4/go.mod h1:R1sL0p50l42S5lJs91oNdUL58nm0QHrhxnSegr++qC0= +github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= +github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.6 h1:1zvwBbyd0TeEuuWftrd/4d++m+/kZSeiguxU61LFWpo= @@ -309,10 +315,14 @@ github.com/pion/rtp v1.6.2/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko github.com/pion/sctp v1.7.10/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0= github.com/pion/sctp v1.7.11 h1:UCnj7MsobLKLuP/Hh+JMiI/6W5Bs/VF45lWKgHFjSIE= github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0= +github.com/pion/sctp v1.7.12 h1:GsatLufywVruXbZZT1CKg+Jr8ZTkwiPnmUC/oO9+uuY= +github.com/pion/sctp v1.7.12/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= github.com/pion/sdp/v3 v3.0.4 h1:2Kf+dgrzJflNCSw3TV5v2VLeI0s/qkzy2r5jlR0wzf8= github.com/pion/sdp/v3 v3.0.4/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= github.com/pion/srtp/v2 v2.0.1 h1:kgfh65ob3EcnFYA4kUBvU/menCp9u7qaJLXwWgpobzs= github.com/pion/srtp/v2 v2.0.1/go.mod h1:c8NWHhhkFf/drmHTAblkdu8++lsISEBBdAuiyxgqIsE= +github.com/pion/srtp/v2 v2.0.2 h1:664iGzVmaY7KYS5M0gleY0DscRo9ReDfTxQrq4UgGoU= +github.com/pion/srtp/v2 v2.0.2/go.mod h1:VEyLv4CuxrwGY8cxM+Ng3bmVy8ckz/1t6A0q/msKOw0= github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8= @@ -321,12 +331,20 @@ github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+D github.com/pion/transport v0.12.1/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q= github.com/pion/transport v0.12.2 h1:WYEjhloRHt1R86LhUKjC5y+P52Y11/QqEUalvtzVoys= github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q= +github.com/pion/transport v0.12.3 h1:vdBfvfU/0Wq8kd2yhUMSDB/x+O4Z9MYVl2fJ5BT4JZw= +github.com/pion/transport v0.12.3/go.mod h1:OViWW9SP2peE/HbwBvARicmAVnesphkNkCVZIWJ6q9A= github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA= github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw= github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI= github.com/pion/udp v0.1.0/go.mod h1:BPELIjbwE9PRbd/zxI/KYBnbo7B6+oA6YuEaNE8lths= +github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o= +github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M= github.com/pion/webrtc/v3 v3.0.10 h1:hti6k0DeN4tbQmAZiA8v6OvdkANQGw+R3nyqk9+dnz0= github.com/pion/webrtc/v3 v3.0.10/go.mod h1:KdEZWLmBnxB2Qj4FtUb9vi1sIpqsHOisI7L6ggQBD0A= +github.com/pion/webrtc/v3 v3.0.20-0.20210401021312-0e0723e7127d h1:9HN8Fnxdp2x7H/WF1Ku00oYqpJQzb8MxVHsUL9WTzpc= +github.com/pion/webrtc/v3 v3.0.20-0.20210401021312-0e0723e7127d/go.mod h1:0eJnCpQrUMpRnvyonw4ZiWClToerpixrZ2KcoTxvX9M= +github.com/pion/webrtc/v3 v3.0.20 h1:Jj0sk45MqQdkR24E1wbFRmOzb1Lv258ot9zd2fYB/Pw= +github.com/pion/webrtc/v3 v3.0.20/go.mod h1:0eJnCpQrUMpRnvyonw4ZiWClToerpixrZ2KcoTxvX9M= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -452,6 +470,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -498,6 +518,10 @@ golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= +golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c h1:KHUzaHIpjWVlVVNh65G3hhuj3KB1HnjY6Cq5cTvRQT8= +golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -543,6 +567,10 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210217090653-ed5674b6da4a h1:m4knbKtdWq+rPB3TE+ApaRzkETZngkKdhYjvTnnRq4s= golang.org/x/sys v0.0.0-20210217090653-ed5674b6da4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54 h1:rF3Ohx8DRyl8h2zw9qojyLHLhrJpEMgyPOImREEryf0= +golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/buffer/bucket.go b/pkg/buffer/bucket.go index d1715bd0c..690fec8c1 100644 --- a/pkg/buffer/bucket.go +++ b/pkg/buffer/bucket.go @@ -10,8 +10,9 @@ const maxPktSize = 1350 type Bucket struct { buf []byte - headSN uint16 + init bool step int + headSN uint16 maxSteps int } @@ -23,6 +24,10 @@ func NewBucket(buf []byte) *Bucket { } func (b *Bucket) AddPacket(pkt []byte, sn uint16, latest bool) ([]byte, error) { + if !b.init { + b.headSN = sn - 1 + b.init = true + } if !latest { return b.set(sn, pkt) } diff --git a/pkg/buffer/bucket_test.go b/pkg/buffer/bucket_test.go index ad947f3c1..efe40f946 100644 --- a/pkg/buffer/bucket_test.go +++ b/pkg/buffer/bucket_test.go @@ -99,7 +99,6 @@ func Test_queue_edges(t *testing.T) { }, } q := NewBucket(make([]byte, 25000)) - q.headSN = 65532 for _, p := range TestPackets { p := p assert.NotNil(t, p) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 258cab4b2..dafa64ca7 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -232,14 +232,14 @@ func (b *Buffer) Read(buff []byte) (n int, err error) { } } -func (b *Buffer) ReadExtended() (ExtPacket, error) { +func (b *Buffer) ReadExtended() (*ExtPacket, error) { for { if b.closed.get() { - return ExtPacket{}, io.EOF + return nil, io.EOF } b.Lock() if b.extPackets.Len() > 0 { - extPkt := b.extPackets.PopFront().(ExtPacket) + extPkt := b.extPackets.PopFront().(*ExtPacket) b.Unlock() return extPkt, nil } @@ -275,7 +275,6 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { if b.stats.PacketCount == 0 { b.baseSN = sn b.maxSeqNo = sn - b.bucket.headSN = sn - 1 b.lastReport = arrivalTime } else if (sn-b.maxSeqNo)&0x8000 == 0 { if sn < b.maxSeqNo { @@ -357,7 +356,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { b.minPacketProbe++ } - b.extPackets.PushBack(ep) + b.extPackets.PushBack(&ep) // if first time update or the timestamp is later (factoring timestamp wrap around) if (b.latestTimestampTime == 0) || IsLaterTimestamp(p.Timestamp, b.latestTimestamp) { diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go new file mode 100644 index 000000000..f5ee3ec42 --- /dev/null +++ b/pkg/relay/relay.go @@ -0,0 +1,496 @@ +package relay + +import ( + "encoding/json" + "strings" + "sync" + + "github.com/pion/rtcp" + + "github.com/pion/ice/v2" + + "github.com/go-logr/logr" + "github.com/pion/webrtc/v3" +) + +type Provider struct { + mu sync.RWMutex + se webrtc.SettingEngine + log logr.Logger + peers map[string]*Peer + signal func(meta SignalMeta, signal []byte) ([]byte, error) + onRemote func(meta SignalMeta, receiver *webrtc.RTPReceiver, codec *webrtc.RTPCodecParameters) + iceServers []webrtc.ICEServer + onDatachannel func(meta SignalMeta, dc *webrtc.DataChannel) +} + +type Signal struct { + Metadata SignalMeta `json:"metadata"` + Encodings *webrtc.RTPCodingParameters `json:"encodings,omitempty"` + ICECandidates []webrtc.ICECandidate `json:"iceCandidates,omitempty"` + ICEParameters webrtc.ICEParameters `json:"iceParameters,omitempty"` + DTLSParameters webrtc.DTLSParameters `json:"dtlsParameters,omitempty"` + CodecParameters *webrtc.RTPCodecParameters `json:"codecParameters,omitempty"` + SCTPCapabilities *webrtc.SCTPCapabilities `json:"sctpCapabilities,omitempty"` +} + +type SignalMeta struct { + PeerID string `json:"peerId"` + StreamID string `json:"streamId"` + SessionID string `json:"sessionId"` +} + +type Peer struct { + me *webrtc.MediaEngine + id string + pid string + sid string + api *webrtc.API + ice *webrtc.ICETransport + sctp *webrtc.SCTPTransport + dtls *webrtc.DTLSTransport + provider *Provider + gatherer *webrtc.ICEGatherer + localTracks []webrtc.TrackLocal + datachannels []string +} + +func New(iceServers []webrtc.ICEServer, logger logr.Logger) *Provider { + return &Provider{ + log: logger, + peers: make(map[string]*Peer), + iceServers: iceServers, + } +} + +func (p *Provider) SetSettingEngine(se webrtc.SettingEngine) { + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + p.se = se +} + +func (p *Provider) SetSignaler(signaler func(meta SignalMeta, signal []byte) ([]byte, error)) { + p.signal = signaler +} + +func (p *Provider) OnRemoteStream(fn func(meta SignalMeta, receiver *webrtc.RTPReceiver, codec *webrtc.RTPCodecParameters)) { + p.onRemote = fn +} + +func (p *Provider) OnDatachannel(fn func(meta SignalMeta, dc *webrtc.DataChannel)) { + p.onDatachannel = fn +} + +func (p *Provider) AddDataChannels(sessionID, peerID string, labels []string) error { + var r *Peer + var err error + p.mu.RLock() + r = p.peers[peerID] + p.mu.RUnlock() + if r == nil { + r, err = p.newRelay(sessionID, peerID) + if err != nil { + return err + } + } + if r.ice.State() != webrtc.ICETransportStateNew { + r.datachannels = labels + return r.startDataChannels() + } + r.datachannels = labels + return nil +} + +func (p *Provider) Send(sessionID, peerID string, receiver *webrtc.RTPReceiver, localTrack webrtc.TrackLocal) (*Peer, *webrtc.RTPSender, error) { + p.mu.RLock() + if r, ok := p.peers[peerID]; ok { + p.mu.RUnlock() + s, err := r.send(receiver, localTrack) + return r, s, err + } + p.mu.RUnlock() + + r, err := p.newRelay(sessionID, peerID) + if err != nil { + return nil, nil, err + } + + s, err := r.send(receiver, localTrack) + return r, s, err +} + +func (p *Provider) Receive(remoteSignal []byte) ([]byte, error) { + s := Signal{} + if err := json.Unmarshal(remoteSignal, &s); err != nil { + return nil, err + } + + p.mu.RLock() + if r, ok := p.peers[s.Metadata.PeerID]; ok { + p.mu.RUnlock() + return r.receive(s) + } + p.mu.RUnlock() + + r, err := p.newRelay(s.Metadata.SessionID, s.Metadata.PeerID) + if err != nil { + return nil, err + } + + return r.receive(s) +} + +func (p *Provider) newRelay(sessionID, peerID string) (*Peer, error) { + // Prepare ICE gathering options + iceOptions := webrtc.ICEGatherOptions{ + ICEServers: p.iceServers, + } + me := webrtc.MediaEngine{} + // Create an API object + api := webrtc.NewAPI(webrtc.WithMediaEngine(&me), webrtc.WithSettingEngine(p.se)) + // Create the ICE gatherer + gatherer, err := api.NewICEGatherer(iceOptions) + if err != nil { + return nil, err + } + // Construct the ICE transport + i := api.NewICETransport(gatherer) + // Construct the DTLS transport + dtls, err := api.NewDTLSTransport(i, nil) + // Construct the SCTP transport + sctp := api.NewSCTPTransport(dtls) + if err != nil { + return nil, err + } + r := &Peer{ + me: &me, + pid: peerID, + sid: sessionID, + api: api, + ice: i, + sctp: sctp, + dtls: dtls, + provider: p, + gatherer: gatherer, + } + + p.mu.Lock() + p.peers[peerID] = r + p.mu.Unlock() + + if p.onDatachannel != nil { + sctp.OnDataChannel( + func(channel *webrtc.DataChannel) { + p.onDatachannel(SignalMeta{ + PeerID: peerID, + StreamID: r.id, + SessionID: sessionID, + }, channel) + }) + } + + i.OnConnectionStateChange(func(state webrtc.ICETransportState) { + if state == webrtc.ICETransportStateFailed || state == webrtc.ICETransportStateDisconnected { + p.mu.Lock() + delete(p.peers, peerID) + p.mu.Unlock() + if err := r.gatherer.Close(); err != nil { + p.log.Error(err, "Error closing ice gatherer", "peer_id", r.pid) + } + if err := r.ice.Stop(); err != nil { + p.log.Error(err, "Error stopping ice transport", "peer_id", r.pid) + } + if err := r.dtls.Stop(); err != nil { + p.log.Error(err, "Error stopping dtls transport", "peer_id", r.pid) + } + } + }) + + return r, nil +} + +func (r *Peer) WriteRTCP(pkts []rtcp.Packet) error { + _, err := r.dtls.WriteRTCP(pkts) + return err +} + +func (r *Peer) LocalTracks() []webrtc.TrackLocal { + return r.localTracks +} + +func (r *Peer) startDataChannels() error { + if len(r.datachannels) == 0 { + return nil + } + for idx, label := range r.datachannels { + id := uint16(idx) + dcParams := &webrtc.DataChannelParameters{ + Label: label, + ID: &id, + } + channel, err := r.api.NewDataChannel(r.sctp, dcParams) + if err != nil { + return err + } + if r.provider.onDatachannel != nil { + r.provider.onDatachannel(SignalMeta{ + PeerID: r.pid, + StreamID: r.id, + SessionID: r.sid, + }, channel) + } + } + return nil +} + +func (r *Peer) receive(s Signal) ([]byte, error) { + if r.gatherer.State() == webrtc.ICEGathererStateNew { + r.id = s.Metadata.StreamID + gatherFinished := make(chan struct{}) + r.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { + if i == nil { + close(gatherFinished) + } + }) + // Gather candidates + if err := r.gatherer.Gather(); err != nil { + return nil, err + } + <-gatherFinished + } + + var k webrtc.RTPCodecType + switch { + case strings.HasPrefix(s.CodecParameters.MimeType, "audio/"): + k = webrtc.RTPCodecTypeAudio + case strings.HasPrefix(s.CodecParameters.MimeType, "video/"): + k = webrtc.RTPCodecTypeVideo + default: + k = webrtc.RTPCodecType(0) + } + if err := r.me.RegisterCodec(*s.CodecParameters, k); err != nil { + return nil, err + } + + iceCandidates, err := r.gatherer.GetLocalCandidates() + if err != nil { + return nil, err + } + + iceParams, err := r.gatherer.GetLocalParameters() + if err != nil { + return nil, err + } + + dtlsParams, err := r.dtls.GetLocalParameters() + if err != nil { + return nil, err + } + + sctpCapabilities := r.sctp.GetCapabilities() + + localSignal := Signal{ + ICECandidates: iceCandidates, + ICEParameters: iceParams, + DTLSParameters: dtlsParams, + SCTPCapabilities: &sctpCapabilities, + } + + if err = r.ice.SetRemoteCandidates(s.ICECandidates); err != nil { + return nil, err + } + + recv, err := r.api.NewRTPReceiver(k, r.dtls) + if err != nil { + return nil, err + } + + if r.ice.State() == webrtc.ICETransportStateNew { + go func() { + iceRole := webrtc.ICERoleControlled + if err = r.ice.Start(nil, s.ICEParameters, &iceRole); err != nil { + r.provider.log.Error(err, "Start ICE error") + return + } + + if err = r.dtls.Start(s.DTLSParameters); err != nil { + r.provider.log.Error(err, "Start DTLS error") + return + } + + if s.SCTPCapabilities != nil { + if err = r.sctp.Start(*s.SCTPCapabilities); err != nil { + r.provider.log.Error(err, "Start SCTP error") + return + } + } + + if err = recv.Receive(webrtc.RTPReceiveParameters{Encodings: []webrtc.RTPDecodingParameters{ + { + webrtc.RTPCodingParameters{ + RID: s.Encodings.RID, + SSRC: s.Encodings.SSRC, + PayloadType: s.Encodings.PayloadType, + }, + }, + }}); err != nil { + r.provider.log.Error(err, "Start receiver error") + return + } + + if r.provider.onRemote != nil { + r.provider.onRemote(SignalMeta{ + PeerID: s.Metadata.PeerID, + StreamID: s.Metadata.StreamID, + SessionID: s.Metadata.SessionID, + }, recv, s.CodecParameters) + } + }() + } else { + if err = recv.Receive(webrtc.RTPReceiveParameters{Encodings: []webrtc.RTPDecodingParameters{ + { + webrtc.RTPCodingParameters{ + RID: s.Encodings.RID, + SSRC: s.Encodings.SSRC, + PayloadType: s.Encodings.PayloadType, + }, + }, + }}); err != nil { + return nil, err + } + + if r.provider.onRemote != nil { + r.provider.onRemote(SignalMeta{ + PeerID: s.Metadata.PeerID, + StreamID: s.Metadata.StreamID, + SessionID: s.Metadata.SessionID, + }, recv, s.CodecParameters) + } + } + + b, err := json.Marshal(localSignal) + if err != nil { + return nil, err + } + + return b, nil +} + +func (r *Peer) send(receiver *webrtc.RTPReceiver, localTrack webrtc.TrackLocal) (*webrtc.RTPSender, error) { + if r.gatherer.State() == webrtc.ICEGathererStateNew { + gatherFinished := make(chan struct{}) + r.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { + if i == nil { + close(gatherFinished) + } + }) + // Gather candidates + if err := r.gatherer.Gather(); err != nil { + return nil, err + } + <-gatherFinished + } + t := receiver.Track() + codec := receiver.Track().Codec() + sdr, err := r.api.NewRTPSender(localTrack, r.dtls) + r.id = t.StreamID() + if err != nil { + return nil, err + } + if err = r.me.RegisterCodec(codec, t.Kind()); err != nil { + return nil, err + } + + iceCandidates, err := r.gatherer.GetLocalCandidates() + if err != nil { + return nil, err + } + + iceParams, err := r.gatherer.GetLocalParameters() + if err != nil { + return nil, err + } + + dtlsParams, err := r.dtls.GetLocalParameters() + if err != nil { + return nil, err + } + + sctpCapabilities := r.sctp.GetCapabilities() + + signal := &Signal{ + Metadata: SignalMeta{ + PeerID: r.pid, + StreamID: r.id, + SessionID: r.sid, + }, + ICECandidates: iceCandidates, + ICEParameters: iceParams, + DTLSParameters: dtlsParams, + SCTPCapabilities: &sctpCapabilities, + CodecParameters: &codec, + Encodings: &webrtc.RTPCodingParameters{ + SSRC: t.SSRC(), + PayloadType: t.PayloadType(), + }, + } + local, err := json.Marshal(signal) + if err != nil { + return nil, err + } + + remote, err := r.provider.signal(SignalMeta{ + PeerID: r.pid, + StreamID: r.id, + SessionID: r.sid, + }, local) + if err != nil { + return nil, err + } + var remoteSignal Signal + if err = json.Unmarshal(remote, &remoteSignal); err != nil { + return nil, err + } + + if err = r.ice.SetRemoteCandidates(remoteSignal.ICECandidates); err != nil { + return nil, err + } + + if r.ice.State() == webrtc.ICETransportStateNew { + iceRole := webrtc.ICERoleControlling + if err = r.ice.Start(nil, remoteSignal.ICEParameters, &iceRole); err != nil { + return nil, err + } + + if err = r.dtls.Start(remoteSignal.DTLSParameters); err != nil { + return nil, err + } + + if remoteSignal.SCTPCapabilities != nil { + if err = r.sctp.Start(*remoteSignal.SCTPCapabilities); err != nil { + return nil, err + } + } + } + params := receiver.GetParameters() + + if err = sdr.Send(webrtc.RTPSendParameters{ + RTPParameters: params, + Encodings: []webrtc.RTPEncodingParameters{ + { + webrtc.RTPCodingParameters{ + SSRC: t.SSRC(), + PayloadType: t.PayloadType(), + RID: t.RID(), + }, + }, + }, + }); err != nil { + return nil, err + } + + if err = r.startDataChannels(); err != nil { + return nil, err + } + r.localTracks = append(r.localTracks, localTrack) + return sdr, nil +} diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 1192c9be6..87a6b9a63 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -102,7 +102,9 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, if strings.HasPrefix(d.codec.MimeType, "video/") { d.sequencer = newSequencer(d.maxTrack) } - d.onBind() + if d.onBind != nil { + d.onBind() + } d.bound.set(true) return codec, nil } @@ -144,7 +146,7 @@ func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { } // WriteRTP writes a RTP Packet to the DownTrack -func (d *DownTrack) WriteRTP(p buffer.ExtPacket) error { +func (d *DownTrack) WriteRTP(p *buffer.ExtPacket) error { if !d.enabled.get() || !d.bound.get() { return nil } @@ -196,7 +198,7 @@ func (d *DownTrack) SwitchSpatialLayer(targetLayer int64, setAsMax bool) { currentLayer == uint16(targetLayer) { return } - if err := d.receiver.SubDownTrack(d, int(targetLayer)); err == nil { + if err := d.receiver.SwitchDownTrack(d, int(targetLayer)); err == nil { atomic.StoreInt32(&d.spatialLayer, int32(targetLayer<<16)|int32(currentLayer)) atomic.StoreInt64(&d.skipFB, 4) if setAsMax { @@ -278,7 +280,7 @@ func (d *DownTrack) UpdateStats(packetLen uint32) { atomic.AddUint32(&d.packetCount, 1) } -func (d *DownTrack) writeSimpleRTP(extPkt buffer.ExtPacket) error { +func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { if d.reSync.get() { if d.Kind() == webrtc.RTPCodecTypeVideo { if !extPkt.KeyFrame { @@ -306,19 +308,20 @@ func (d *DownTrack) writeSimpleRTP(extPkt buffer.ExtPacket) error { atomic.StoreInt64(&d.lastPacketMs, extPkt.Arrival/1e6) atomic.StoreUint32(&d.lastTS, newTS) } - extPkt.Packet.PayloadType = d.payloadType - extPkt.Packet.Timestamp = newTS - extPkt.Packet.SequenceNumber = newSN - extPkt.Packet.SSRC = d.ssrc + hdr := extPkt.Packet.Header + hdr.PayloadType = d.payloadType + hdr.Timestamp = newTS + hdr.SequenceNumber = newSN + hdr.SSRC = d.ssrc - _, err := d.writeStream.WriteRTP(&extPkt.Packet.Header, extPkt.Packet.Payload) + _, err := d.writeStream.WriteRTP(&hdr, extPkt.Packet.Payload) if err != nil { Logger.Error(err, "Write packet err") } return err } -func (d *DownTrack) writeSimulcastRTP(extPkt buffer.ExtPacket) error { +func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket) error { // Check if packet SSRC is different from before // if true, the video source changed d.Lock() @@ -328,15 +331,12 @@ func (d *DownTrack) writeSimulcastRTP(extPkt buffer.ExtPacket) error { layer := atomic.LoadInt32(&d.spatialLayer) currentLayer := uint16(layer) targetLayer := uint16(layer >> 16) - if currentLayer == targetLayer && lastSSRC != 0 && !reSync { + if currentLayer == targetLayer && !reSync { d.Unlock() return nil } - if reSync && d.simulcast.lTSCalc != 0 { - d.simulcast.lTSCalc = extPkt.Arrival - } // Wait for a keyframe to sync new source - if !extPkt.KeyFrame { + if reSync && !extPkt.KeyFrame { // Packet is not a keyframe, discard it d.receiver.SendRTCP([]rtcp.Packet{ &rtcp.PictureLossIndication{SenderSSRC: d.ssrc, MediaSSRC: extPkt.Packet.SSRC}, @@ -344,12 +344,15 @@ func (d *DownTrack) writeSimulcastRTP(extPkt buffer.ExtPacket) error { d.Unlock() return nil } + if reSync && d.simulcast.lTSCalc != 0 { + d.simulcast.lTSCalc = extPkt.Arrival + } // Switch is done remove sender from previous layer // and update current layer - if currentLayer != targetLayer { - go d.receiver.DeleteDownTrack(int(currentLayer), d.peerID) + if currentLayer != targetLayer && !reSync { + d.receiver.DeleteDownTrack(int(currentLayer), d.peerID) + atomic.StoreInt32(&d.spatialLayer, int32(targetLayer)<<16|int32(targetLayer)) } - atomic.StoreInt32(&d.spatialLayer, int32(targetLayer)<<16|int32(targetLayer)) if d.simulcast.temporalSupported { if d.mime == "video/vp8" { @@ -422,12 +425,13 @@ func (d *DownTrack) writeSimulcastRTP(extPkt buffer.ExtPacket) error { // Update base d.simulcast.lTSCalc = extPkt.Arrival // Update extPkt headers - extPkt.Packet.SequenceNumber = newSN - extPkt.Packet.Timestamp = newTS - extPkt.Packet.Header.SSRC = d.ssrc - extPkt.Packet.Header.PayloadType = d.payloadType + hdr := extPkt.Packet.Header + hdr.SequenceNumber = newSN + hdr.Timestamp = newTS + hdr.SSRC = d.ssrc + hdr.PayloadType = d.payloadType - _, err := d.writeStream.WriteRTP(&extPkt.Packet.Header, extPkt.Packet.Payload) + _, err := d.writeStream.WriteRTP(&hdr, extPkt.Packet.Payload) if err != nil { Logger.Error(err, "Write packet err") } diff --git a/pkg/sfu/helpers.go b/pkg/sfu/helpers.go index 2ceb81b6f..b314780d8 100644 --- a/pkg/sfu/helpers.go +++ b/pkg/sfu/helpers.go @@ -30,7 +30,7 @@ func (a *atomicBool) get() bool { // setVp8TemporalLayer is a helper to detect and modify accordingly the vp8 payload to reflect // temporal changes in the SFU. // VP8 temporal layers implemented according https://tools.ietf.org/html/rfc7741 -func setVP8TemporalLayer(p buffer.ExtPacket, s *DownTrack) (payload []byte, picID uint16, tlz0Idx uint8, drop bool) { +func setVP8TemporalLayer(p *buffer.ExtPacket, s *DownTrack) (payload []byte, picID uint16, tlz0Idx uint8, drop bool) { pkt, ok := p.Payload.(buffer.VP8) if !ok { return p.Packet.Payload, 0, 0, false diff --git a/pkg/sfu/peer.go b/pkg/sfu/peer.go index 7105330fb..f9fd48c7d 100644 --- a/pkg/sfu/peer.go +++ b/pkg/sfu/peer.go @@ -30,6 +30,8 @@ type JoinConfig struct { NoPublish bool // If true the peer will not be allowed to subscribe to other peers in session. NoSubscribe bool + // If true it will relay all the published tracks of the peer + Relay bool } // SessionProvider provides the session to the sfu.Peer{} @@ -130,7 +132,7 @@ func (p *Peer) Join(sid, uid string, config ...JoinConfig) error { } if !conf.NoPublish { - p.publisher, err = NewPublisher(p.session, uid, cfg) + p.publisher, err = NewPublisher(p.session, uid, conf.Relay, cfg) if err != nil { return fmt.Errorf("error creating transport: %v", err) } diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index 80b9c21b1..008a15e77 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -1,8 +1,14 @@ package sfu import ( + "io" "sync" "sync/atomic" + "time" + + "github.com/pion/rtcp" + + "github.com/pion/ion-sfu/pkg/relay" "github.com/pion/webrtc/v3" ) @@ -13,6 +19,7 @@ type Publisher struct { router Router session *Session + relayPeer *relay.Peer candidates []webrtc.ICECandidateInit onICEConnectionStateChangeHandler atomic.Value // func(webrtc.ICEConnectionState) @@ -21,7 +28,7 @@ type Publisher struct { } // NewPublisher creates a new Publisher -func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publisher, error) { +func NewPublisher(session *Session, id string, relay bool, cfg WebRTCTransportConfig) (*Publisher, error) { me, err := getPublisherMediaEngine() if err != nil { Logger.Error(err, "NewPeer error", "peer_id", id) @@ -43,6 +50,7 @@ func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publ session: session, } + var relayReports sync.Once pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { Logger.V(1).Info("Peer got remote track id", "peer_id", p.id, @@ -52,9 +60,46 @@ func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publ "stream_id", track.StreamID(), ) - if r, pub := p.router.AddReceiver(receiver, track); pub { + r, pub := p.router.AddReceiver(receiver, track) + if pub { p.session.Publish(p.router, r) } + + if relay && cfg.relay != nil && pub { + codec := track.Codec() + downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: []webrtc.RTCPFeedback{{"goog-remb", ""}, {"nack", ""}, {"nack", "pli"}}, + }, r, session.bufferFactory, id, cfg.router.MaxPacketTrack) + if err != nil { + Logger.V(1).Error(err, "Create relay downtrack err", "peer_id", id) + return + } + rr, sdr, err := cfg.relay.Send(session.id, id, receiver, downTrack) + if err != nil { + Logger.V(1).Error(err, "Relay err", "peer_id", id) + return + } + + if p.relayPeer == nil { + relayReports.Do(func() { + p.relayPeer = rr + go p.relayReports() + }) + } + + downTrack.OnCloseHandler(func() { + if err := sdr.Stop(); err != nil { + Logger.V(1).Error(err, "Relay sender close err", "peer_id", id) + } + }) + + r.AddDownTrack(downTrack, true) + } + }) pc.OnDataChannel(func(dc *webrtc.DataChannel) { @@ -80,6 +125,11 @@ func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publ } }) + if relay && cfg.relay != nil { + if err = cfg.relay.AddDataChannels(session.id, id, session.getDataChannelLabels()); err != nil { + Logger.Error(err, "Add relaying data channels error") + } + } return p, nil } @@ -141,3 +191,30 @@ func (p *Publisher) AddICECandidate(candidate webrtc.ICECandidateInit) error { p.candidates = append(p.candidates, candidate) return nil } + +func (p *Publisher) relayReports() { + for { + time.Sleep(5 * time.Second) + + var r []rtcp.Packet + for _, t := range p.relayPeer.LocalTracks() { + if dt, ok := t.(*DownTrack); ok { + if !dt.bound.get() { + continue + } + r = append(r, dt.CreateSenderReport()) + } + } + + if len(r) == 0 { + continue + } + + if err := p.relayPeer.WriteRTCP(r); err != nil { + if err == io.EOF || err == io.ErrClosedPipe { + return + } + Logger.Error(err, "Sending downtrack reports err") + } + } +} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 523ce5986..4b591747d 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -2,6 +2,7 @@ package sfu import ( "io" + "math/rand" "sync" "time" @@ -22,7 +23,7 @@ type Receiver interface { SSRC(layer int) uint32 AddUpTrack(track *webrtc.TrackRemote, buffer *buffer.Buffer, bestQualityFirst bool) AddDownTrack(track *DownTrack, bestQualityFirst bool) - SubDownTrack(track *DownTrack, layer int) error + SwitchDownTrack(track *DownTrack, layer int) error GetBitrate() [3]uint64 GetMaxTemporalLayer() [3]int64 RetransmitPackets(track *DownTrack, packets []packetMeta) error @@ -34,6 +35,7 @@ type Receiver interface { // WebRTCReceiver receives a video track type WebRTCReceiver struct { + sync.Mutex rtcpMu sync.Mutex closeOnce sync.Once @@ -52,6 +54,7 @@ type WebRTCReceiver struct { upTracks [3]*webrtc.TrackRemote stats [3]*stats.Stream downTracks [3][]*DownTrack + pendingTracks [3][]*DownTrack nackWorker *workerpool.WorkerPool isSimulcast bool onCloseHandler func() @@ -106,11 +109,11 @@ func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buff layer = 0 } - w.locks[layer].Lock() + w.Lock() w.upTracks[layer] = track w.buffers[layer] = buff w.downTracks[layer] = make([]*DownTrack, 0, 10) - w.locks[layer].Unlock() + w.Unlock() subBestQuality := func(targetLayer int) { for l := 0; l < targetLayer; l++ { @@ -164,6 +167,7 @@ func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buff func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { layer := 0 if w.isSimulcast { + w.Lock() for i, t := range w.upTracks { if t != nil { layer = i @@ -172,6 +176,7 @@ func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { } } } + w.Unlock() w.locks[layer].Lock() if downTrackSubscribed(w.downTracks[layer], track) { w.locks[layer].Unlock() @@ -180,6 +185,7 @@ func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { track.SetInitialLayers(int64(layer), 2) track.maxSpatialLayer = 2 track.maxTemporalLayer = 2 + track.lastSSRC = w.SSRC(layer) track.trackType = SimulcastDownTrack track.payload = packetFactory.Get().([]byte) } else { @@ -196,14 +202,13 @@ func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { w.locks[layer].Unlock() } -func (w *WebRTCReceiver) SubDownTrack(track *DownTrack, layer int) error { - w.locks[layer].Lock() +func (w *WebRTCReceiver) SwitchDownTrack(track *DownTrack, layer int) error { if buf := w.buffers[layer]; buf != nil { - w.downTracks[layer] = append(w.downTracks[layer], track) + w.locks[layer].Lock() + w.pendingTracks[layer] = append(w.pendingTracks[layer], track) w.locks[layer].Unlock() return nil } - w.locks[layer].Unlock() return errNoReceiverFound } @@ -325,6 +330,10 @@ func (w *WebRTCReceiver) writeRTP(layer int) { go w.closeTracks() }) }() + + pli := []rtcp.Packet{ + &rtcp.PictureLossIndication{SenderSSRC: rand.Uint32(), MediaSSRC: w.SSRC(layer)}, + } var del []int for { @@ -334,14 +343,24 @@ func (w *WebRTCReceiver) writeRTP(layer int) { } w.locks[layer].Lock() + + if w.isSimulcast && len(w.pendingTracks[layer]) > 0 { + if pkt.KeyFrame { + w.downTracks[layer] = append(w.downTracks[layer], w.pendingTracks[layer]...) + w.pendingTracks[layer] = w.pendingTracks[layer][:0] + } else { + w.SendRTCP(pli) + } + } + for idx, dt := range w.downTracks[layer] { - if err := dt.WriteRTP(pkt); err == io.EOF { + if err := dt.WriteRTP(pkt); err == io.EOF || err == io.ErrClosedPipe { del = append(del, idx) } } if len(del) > 0 { - for _, idx := range del { - w.downTracks[layer][idx] = w.downTracks[layer][len(w.downTracks[layer])-1] + for i := len(del) - 1; i >= 0; i-- { + w.downTracks[layer][del[i]] = w.downTracks[layer][len(w.downTracks[layer])-1] w.downTracks[layer][len(w.downTracks[layer])-1] = nil w.downTracks[layer] = w.downTracks[layer][:len(w.downTracks[layer])-1] } diff --git a/pkg/sfu/session.go b/pkg/sfu/session.go index 3eeb08684..7ddac7f7a 100644 --- a/pkg/sfu/session.go +++ b/pkg/sfu/session.go @@ -1,6 +1,7 @@ package sfu import ( + "context" "encoding/json" "sync" "time" @@ -37,6 +38,11 @@ func NewSession(id string, bf *buffer.Factory, dcs []*Datachannel, cfg WebRTCTra } +// ID return session id +func (s *Session) ID() string { + return s.id +} + // AddPublisher adds a transport to the session func (s *Session) AddPeer(peer *Peer) { s.mu.Lock() @@ -58,36 +64,6 @@ func (s *Session) RemovePeer(pid string) { } } -func (s *Session) onMessage(origin, label string, msg webrtc.DataChannelMessage) { - dcs := s.getDataChannels(origin, label) - for _, dc := range dcs { - if msg.IsString { - if err := dc.SendText(string(msg.Data)); err != nil { - Logger.Error(err, "Sending dc message err") - } - } else { - if err := dc.Send(msg.Data); err != nil { - Logger.Error(err, "Sending dc message err") - } - } - } -} - -func (s *Session) getDataChannels(origin, label string) (dcs []*webrtc.DataChannel) { - s.mu.RLock() - defer s.mu.RUnlock() - for pid, p := range s.peers { - if origin == pid { - continue - } - - if dc, ok := p.subscriber.channels[label]; ok && dc.ReadyState() == webrtc.DataChannelStateOpen { - dcs = append(dcs, dc) - } - } - return -} - func (s *Session) AddDatachannel(owner string, dc *webrtc.DataChannel) { label := dc.Label() @@ -204,6 +180,36 @@ func (s *Session) OnClose(f func()) { s.onCloseHandler = f } +func (s *Session) setRelayedDatachannel(peerID string, datachannel *webrtc.DataChannel) { + label := datachannel.Label() + for _, dc := range s.datachannels { + dc := dc + if dc.Label == label { + mws := newDCChain(dc.middlewares) + p := mws.Process(ProcessFunc(func(ctx context.Context, args ProcessArgs) { + if dc.onMessage != nil { + dc.onMessage(ctx, args, s.getDataChannels(peerID, dc.Label)) + } + })) + s.mu.RLock() + peer := s.peers[peerID] + s.mu.RUnlock() + datachannel.OnMessage(func(msg webrtc.DataChannelMessage) { + p.Process(context.Background(), ProcessArgs{ + Peer: peer, + Message: msg, + DataChannel: datachannel, + }) + }) + } + return + } + + datachannel.OnMessage(func(msg webrtc.DataChannelMessage) { + s.onMessage(peerID, label, msg) + }) +} + func (s *Session) audioLevelObserver(audioLevelInterval int) { if audioLevelInterval <= 50 { Logger.V(0).Info("Values near/under 20ms may return unexpected values") @@ -239,7 +245,43 @@ func (s *Session) audioLevelObserver(audioLevelInterval int) { } } -// ID return session id -func (s *Session) ID() string { - return s.id +func (s *Session) onMessage(origin, label string, msg webrtc.DataChannelMessage) { + dcs := s.getDataChannels(origin, label) + for _, dc := range dcs { + if msg.IsString { + if err := dc.SendText(string(msg.Data)); err != nil { + Logger.Error(err, "Sending dc message err") + } + } else { + if err := dc.Send(msg.Data); err != nil { + Logger.Error(err, "Sending dc message err") + } + } + } +} + +func (s *Session) getDataChannels(origin, label string) (dcs []*webrtc.DataChannel) { + s.mu.RLock() + defer s.mu.RUnlock() + for pid, p := range s.peers { + if origin == pid { + continue + } + + if dc, ok := p.subscriber.channels[label]; ok && dc.ReadyState() == webrtc.DataChannelStateOpen { + dcs = append(dcs, dc) + } + } + return +} + +func (s *Session) getDataChannelLabels() []string { + s.mu.RLock() + defer s.mu.RUnlock() + res := make([]string, 0, len(s.datachannels)+len(s.fanOutDCs)) + copy(res, s.fanOutDCs) + for _, dc := range s.datachannels { + res = append(res, dc.Label) + } + return res } diff --git a/pkg/sfu/sfu.go b/pkg/sfu/sfu.go index ffb987ea9..9b076df50 100644 --- a/pkg/sfu/sfu.go +++ b/pkg/sfu/sfu.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/pion/ion-sfu/pkg/relay" + "github.com/go-logr/logr" "github.com/pion/ice/v2" "github.com/pion/ion-sfu/pkg/buffer" @@ -35,6 +37,7 @@ type WebRTCTransportConfig struct { configuration webrtc.Configuration setting webrtc.SettingEngine router RouterConfig + relay *relay.Provider } // WebRTCConfig defines parameters for ice @@ -55,6 +58,7 @@ type Config struct { WebRTC WebRTCConfig `mapstructure:"webrtc"` Router RouterConfig `mapstructure:"router"` Turn TurnConfig `mapstructure:"turn"` + Relay *relay.Provider BufferFactory *buffer.Factory } @@ -125,6 +129,7 @@ func NewWebRTCTransportConfig(c Config) WebRTCTransportConfig { }, setting: se, router: c.Router, + relay: c.Relay, } if len(c.WebRTC.Candidates.NAT1To1IPs) > 0 { @@ -173,6 +178,10 @@ func NewSFU(c Config) *SFU { bufferFactory: c.BufferFactory, } + if c.Relay != nil { + c.Relay.SetSettingEngine(w.setting) + } + if c.Turn.Enabled { ts, err := InitTurnServer(c.Turn, nil) if err != nil {