diff --git a/cmd/signal/grpc/server/server.go b/cmd/signal/grpc/server/server.go index a287f5ead..45df5a336 100644 --- a/cmd/signal/grpc/server/server.go +++ b/cmd/signal/grpc/server/server.go @@ -149,12 +149,10 @@ func (s *SFUServer) Signal(stream pb.SFU_SignalServer) error { _, nopub := payload.Join.Config["NoPublish"] _, nosub := payload.Join.Config["NoSubscribe"] - _, relay := payload.Join.Config["Relay"] cfg := sfu.JoinConfig{ NoPublish: nopub, NoSubscribe: nosub, - Relay: relay, } err = peer.Join(payload.Join.Sid, payload.Join.Uid, cfg) if err != nil { diff --git a/pkg/middlewares/datachannel/subscriberapi.go b/pkg/middlewares/datachannel/subscriberapi.go index bcf779d84..843727795 100644 --- a/pkg/middlewares/datachannel/subscriberapi.go +++ b/pkg/middlewares/datachannel/subscriberapi.go @@ -86,7 +86,7 @@ func sendMessage(streamID string, peer sfu.Peer, layers []string, activeLayer in sfu.Logger.Error(err, "unable to marshal active layer message") } - if err := peer.SendDCMessage(sfu.APIChannelLabel, &bytes); err != nil { + if err := peer.SendDCMessage(sfu.APIChannelLabel, bytes); err != nil { sfu.Logger.Error(err, "unable to send ActiveLayerMessage to peer", "peer_id", peer.ID()) } } diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go index fa1b8f302..abd08fa01 100644 --- a/pkg/relay/relay.go +++ b/pkg/relay/relay.go @@ -2,213 +2,279 @@ package relay import ( "encoding/json" + "errors" "fmt" "math/rand" "strings" "sync" "time" - "github.com/pion/rtcp" - - "github.com/pion/ice/v2" - "github.com/go-logr/logr" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) -type Provider struct { - mu sync.RWMutex - se webrtc.SettingEngine - log logr.Logger - peers map[string]*Peer - signal func(meta SignalMeta, signal []byte) ([]byte, error) - onRemote func(meta SignalMeta, receiver *webrtc.RTPReceiver, codec *webrtc.RTPCodecParameters) - iceServers []webrtc.ICEServer - onDatachannel func(meta SignalMeta, dc *webrtc.DataChannel) -} +const ( + signalerLabel = "ion_sfu_relay_signaler" + signalerEvent = "ion_sfu_relay_event" +) -type Signal struct { - Metadata SignalMeta `json:"metadata"` +var ( + ErrRelayPeerNotReady = errors.New("relay Peer is not ready") + ErrRelayPeerSignalDone = errors.New("relay Peer signal already called") + ErrRelaySignalDCNotReady = errors.New("relay Peer data channel is not ready") +) + +type signal struct { Encodings *webrtc.RTPCodingParameters `json:"encodings,omitempty"` ICECandidates []webrtc.ICECandidate `json:"iceCandidates,omitempty"` ICEParameters webrtc.ICEParameters `json:"iceParameters,omitempty"` DTLSParameters webrtc.DTLSParameters `json:"dtlsParameters,omitempty"` - CodecParameters *webrtc.RTPCodecParameters `json:"codecParameters,omitempty"` SCTPCapabilities *webrtc.SCTPCapabilities `json:"sctpCapabilities,omitempty"` + TrackMeta *TrackMeta `json:"trackInfo,omitempty"` } -type SignalMeta struct { - PeerID string `json:"peerId"` - StreamID string `json:"streamId"` - SessionID string `json:"sessionId"` +type signalRequest struct { + ID uint32 `json:"id"` + Signal *signal `json:"signal,omitempty"` } -type Peer struct { - sync.Mutex - me *webrtc.MediaEngine - id string - pid string - sid string - api *webrtc.API - ice *webrtc.ICETransport - sctp *webrtc.SCTPTransport - dtls *webrtc.DTLSTransport - provider *Provider - senders []*webrtc.RTPSender - receivers []*webrtc.RTPReceiver - gatherer *webrtc.ICEGatherer - localTracks []webrtc.TrackLocal - dataChannels []string +type Request struct { + Event string `json:"event"` + Payload []byte `json:"payload"` } -func New(iceServers []webrtc.ICEServer, logger logr.Logger) *Provider { - return &Provider{ - log: logger, - peers: make(map[string]*Peer), - iceServers: iceServers, - } +type TrackMeta struct { + StreamID string `json:"streamId"` + TrackID string `json:"trackId"` + CodecParameters *webrtc.RTPCodecParameters `json:"codecParameters,omitempty"` } -func (p *Provider) SetSettingEngine(se webrtc.SettingEngine) { - se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) - p.se = se +type PeerConfig struct { + SettingEngine webrtc.SettingEngine + ICEServers []webrtc.ICEServer + Logger logr.Logger } -func (p *Provider) SetSignaler(signaler func(meta SignalMeta, signal []byte) ([]byte, error)) { - p.signal = signaler +type PeerMeta struct { + PeerID string `json:"peerId"` + SessionID string `json:"sessionId"` } -func (p *Provider) OnRemoteStream(fn func(meta SignalMeta, receiver *webrtc.RTPReceiver, codec *webrtc.RTPCodecParameters)) { - p.onRemote = fn +type Peer struct { + mu sync.Mutex + me *webrtc.MediaEngine + log logr.Logger + api *webrtc.API + ice *webrtc.ICETransport + meta PeerMeta + sctp *webrtc.SCTPTransport + dtls *webrtc.DTLSTransport + role *webrtc.ICERole + ready bool + senders []*webrtc.RTPSender + receivers []*webrtc.RTPReceiver + pendingSender map[uint32]func() + gatherer *webrtc.ICEGatherer + localTracks []webrtc.TrackLocal + dcIndex uint16 + signalingDC *webrtc.DataChannel + + onReady func() + onRequest func(r Request) + onDataChannel func(channel *webrtc.DataChannel) + onTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta) } -func (p *Provider) OnDatachannel(fn func(meta SignalMeta, dc *webrtc.DataChannel)) { - p.onDatachannel = fn -} +func NewPeer(meta PeerMeta, conf *PeerConfig) (*Peer, error) { + // Prepare ICE gathering options + iceOptions := webrtc.ICEGatherOptions{ + ICEServers: conf.ICEServers, + } + me := webrtc.MediaEngine{} + // Create an API object + api := webrtc.NewAPI(webrtc.WithMediaEngine(&me), webrtc.WithSettingEngine(conf.SettingEngine)) + // Create the ICE gatherer + gatherer, err := api.NewICEGatherer(iceOptions) + if err != nil { + return nil, err + } + // Construct the ICE transport + i := api.NewICETransport(gatherer) + // Construct the DTLS transport + dtls, err := api.NewDTLSTransport(i, nil) + // Construct the SCTP transport + sctp := api.NewSCTPTransport(dtls) + if err != nil { + return nil, err + } -func (p *Provider) AddDataChannels(sessionID, peerID string, labels []string) error { - var r *Peer - var err error - p.mu.RLock() - r = p.peers[peerID] - p.mu.RUnlock() - if r == nil { - r, err = p.newRelay(sessionID, peerID) - if err != nil { - return err + p := &Peer{ + me: &me, + api: api, + log: conf.Logger, + ice: i, + meta: meta, + sctp: sctp, + dtls: dtls, + gatherer: gatherer, + pendingSender: make(map[uint32]func()), + } + + sctp.OnDataChannel(func(channel *webrtc.DataChannel) { + p.mu.Lock() + defer p.mu.Unlock() + if channel.Label() == signalerLabel { + p.signalingDC = channel + channel.OnMessage(p.handleRequest) + p.ready = true + if p.onReady != nil { + p.onReady() + } + return } + + if p.onDataChannel != nil { + p.onDataChannel(channel) + } + }) + + i.OnConnectionStateChange(func(state webrtc.ICETransportState) { + if state == webrtc.ICETransportStateFailed || state == webrtc.ICETransportStateDisconnected { + if err = p.Close(); err != nil { + p.log.Error(err, "Closing relayed p error") + } + } + }) + + return p, nil +} + +// Offer is used for establish the connection of the local relay Peer +// with the remote relay Peer. +// +// If connection is successful OnReady handler will be called +func (p *Peer) Offer(signalFn func(meta PeerMeta, signal []byte) ([]byte, error)) error { + if p.gatherer.State() != webrtc.ICEGathererStateNew { + return ErrRelayPeerSignalDone } - if r.ice.State() != webrtc.ICETransportStateNew { - r.dataChannels = labels - return r.startDataChannels() + + ls := &signal{} + gatherFinished := make(chan struct{}) + p.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { + if i == nil { + close(gatherFinished) + } + }) + // Gather candidates + if err := p.gatherer.Gather(); err != nil { + return err } - r.dataChannels = labels - return nil -} + <-gatherFinished + + var err error -func (p *Provider) Send(sessionID, peerID string, receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, - localTrack webrtc.TrackLocal) (*Peer, *webrtc.RTPSender, error) { - p.mu.RLock() - if r, ok := p.peers[peerID]; ok { - p.mu.RUnlock() - s, err := r.send(receiver, remoteTrack, localTrack) - return r, s, err + if ls.ICECandidates, err = p.gatherer.GetLocalCandidates(); err != nil { + return err + } + if ls.ICEParameters, err = p.gatherer.GetLocalParameters(); err != nil { + return err } - p.mu.RUnlock() + if ls.DTLSParameters, err = p.dtls.GetLocalParameters(); err != nil { + return err + } + + sc := p.sctp.GetCapabilities() + ls.SCTPCapabilities = &sc - r, err := p.newRelay(sessionID, peerID) + role := webrtc.ICERoleControlling + p.role = &role + data, err := json.Marshal(ls) + + remoteSignal, err := signalFn(p.meta, data) if err != nil { - return nil, nil, err + return err } - s, err := r.send(receiver, remoteTrack, localTrack) - return r, s, err -} + rs := &signal{} -func (p *Provider) Receive(remoteSignal []byte) ([]byte, error) { - s := Signal{} - if err := json.Unmarshal(remoteSignal, &s); err != nil { - return nil, err + if err = json.Unmarshal(remoteSignal, rs); err != nil { + return err } - p.mu.RLock() - if r, ok := p.peers[s.Metadata.PeerID]; ok { - p.mu.RUnlock() - return r.receive(s) + if err = p.start(rs); err != nil { + return err } - p.mu.RUnlock() - r, err := p.newRelay(s.Metadata.SessionID, s.Metadata.PeerID) - if err != nil { - return nil, err + if p.signalingDC, err = p.createDataChannel(signalerLabel); err != nil { + return err } - return r.receive(s) + p.signalingDC.OnOpen(func() { + p.mu.Lock() + p.ready = true + p.mu.Unlock() + if p.onReady != nil { + p.onReady() + } + }) + p.signalingDC.OnMessage(p.handleRequest) + return nil } -func (p *Provider) newRelay(sessionID, peerID string) (*Peer, error) { - // Prepare ICE gathering options - iceOptions := webrtc.ICEGatherOptions{ - ICEServers: p.iceServers, +// Answer answers the remote Peer signal signalRequest +func (p *Peer) Answer(request []byte) ([]byte, error) { + if p.gatherer.State() != webrtc.ICEGathererStateNew { + return nil, ErrRelayPeerSignalDone } - me := webrtc.MediaEngine{} - // Create an API object - api := webrtc.NewAPI(webrtc.WithMediaEngine(&me), webrtc.WithSettingEngine(p.se)) - // Create the ICE gatherer - gatherer, err := api.NewICEGatherer(iceOptions) - if err != nil { + + ls := &signal{} + gatherFinished := make(chan struct{}) + p.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { + if i == nil { + close(gatherFinished) + } + }) + // Gather candidates + if err := p.gatherer.Gather(); err != nil { return nil, err } - // Construct the ICE transport - i := api.NewICETransport(gatherer) - // Construct the DTLS transport - dtls, err := api.NewDTLSTransport(i, nil) - // Construct the SCTP transport - sctp := api.NewSCTPTransport(dtls) - if err != nil { + <-gatherFinished + + var err error + + if ls.ICECandidates, err = p.gatherer.GetLocalCandidates(); err != nil { + return nil, err + } + if ls.ICEParameters, err = p.gatherer.GetLocalParameters(); err != nil { return nil, err } - peer := &Peer{ - me: &me, - pid: peerID, - sid: sessionID, - api: api, - ice: i, - sctp: sctp, - dtls: dtls, - provider: p, - gatherer: gatherer, + if ls.DTLSParameters, err = p.dtls.GetLocalParameters(); err != nil { + return nil, err } - p.mu.Lock() - p.peers[peerID] = peer - p.mu.Unlock() + sc := p.sctp.GetCapabilities() + ls.SCTPCapabilities = &sc - if p.onDatachannel != nil { - sctp.OnDataChannel( - func(channel *webrtc.DataChannel) { - p.onDatachannel(SignalMeta{ - PeerID: peerID, - StreamID: peer.id, - SessionID: sessionID, - }, channel) - }) + role := webrtc.ICERoleControlled + p.role = &role + + rs := &signal{} + if err = json.Unmarshal(request, rs); err != nil { + return nil, err } - i.OnConnectionStateChange(func(state webrtc.ICETransportState) { - if state == webrtc.ICETransportStateFailed || state == webrtc.ICETransportStateDisconnected { - p.mu.Lock() - delete(p.peers, peerID) - p.mu.Unlock() - if err := peer.Close(); err != nil { - p.log.Error(err, "Closing relayed peer error", "peer_id", peer.id) - } + go func() { + if err = p.start(rs); err != nil { + p.log.Error(err, "Error starting relay") } - }) + }() - return peer, nil + return json.Marshal(ls) } +// WriteRTCP sends a user provided RTCP packet to the connected Peer. If no Peer is connected the +// packet is discarded. It also runs any configured interceptors. func (p *Peer) WriteRTCP(pkts []rtcp.Packet) error { _, err := p.dtls.WriteRTCP(pkts) return err @@ -218,8 +284,55 @@ func (p *Peer) LocalTracks() []webrtc.TrackLocal { return p.localTracks } -func (p *Peer) Close() error { +// OnReady calls the callback when relay Peer is ready to start sending/receiving and creating DC +func (p *Peer) OnReady(f func()) { + p.mu.Lock() + p.onReady = f + p.mu.Unlock() +} + +// OnRequest calls the callback when Peer gets a request message from remote Peer +func (p *Peer) OnRequest(f func(r Request)) { + p.mu.Lock() + p.onRequest = f + p.mu.Unlock() +} + +// Request is used to send messages to remote Peer that will end in remote Peer. Other +// data channels if used in ion-sfu may act as middlewares or fan outs. +func (p *Peer) Request(r Request) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.signalingDC == nil { + return ErrRelaySignalDCNotReady + } + + b, err := json.Marshal(r) + if err != nil { + return err + } + return p.signalingDC.Send(b) +} + +// OnDataChannel sets an event handler which is invoked when a data +// channel message arrives from a remote Peer. +func (p *Peer) OnDataChannel(f func(channel *webrtc.DataChannel)) { + p.mu.Lock() + p.onDataChannel = f + p.mu.Unlock() +} +// OnTrack sets an event handler which is called when remote track +// arrives from a remote Peer +func (p *Peer) OnTrack(f func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta)) { + p.mu.Lock() + p.onTrack = f + p.mu.Unlock() +} + +// Close ends the relay Peer +func (p *Peer) Close() error { closeErrs := make([]error, 3+len(p.senders)+len(p.receivers)) for _, sdr := range p.senders { closeErrs = append(closeErrs, sdr.Stop()) @@ -233,211 +346,96 @@ func (p *Peer) Close() error { return joinErrs(closeErrs...) } -func (p *Peer) startDataChannels() error { - if len(p.dataChannels) == 0 { - return nil - } - for idx, label := range p.dataChannels { - id := uint16(idx) - dcParams := &webrtc.DataChannelParameters{ - Label: label, - ID: &id, - } - channel, err := p.api.NewDataChannel(p.sctp, dcParams) - if err != nil { - return err - } - if p.provider.onDatachannel != nil { - p.provider.onDatachannel(SignalMeta{ - PeerID: p.pid, - StreamID: p.id, - SessionID: p.sid, - }, channel) - } +// CreateDataChannel creates a new DataChannel object with the given label +func (p *Peer) CreateDataChannel(label string) (*webrtc.DataChannel, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.ready { + return nil, ErrRelayPeerNotReady } - return nil + + return p.createDataChannel(label) } -func (p *Peer) receive(s Signal) ([]byte, error) { - p.Lock() - defer p.Unlock() - - localSignal := Signal{} - if p.gatherer.State() == webrtc.ICEGathererStateNew { - p.id = s.Metadata.StreamID - gatherFinished := make(chan struct{}) - p.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { - if i == nil { - close(gatherFinished) - } - }) - // Gather candidates - if err := p.gatherer.Gather(); err != nil { - return nil, err - } - <-gatherFinished +func (p *Peer) createDataChannel(label string) (*webrtc.DataChannel, error) { + idx := p.dcIndex + p.dcIndex = +1 + dcParams := &webrtc.DataChannelParameters{ + Label: label, + ID: &idx, + Ordered: true, + } + return p.api.NewDataChannel(p.sctp, dcParams) +} - var err error +func (p *Peer) start(s *signal) error { + if err := p.ice.SetRemoteCandidates(s.ICECandidates); err != nil { + return err + } - localSignal.ICECandidates, err = p.gatherer.GetLocalCandidates() - if err != nil { - return nil, err - } + if err := p.ice.Start(p.gatherer, s.ICEParameters, p.role); err != nil { + return err + } - localSignal.ICEParameters, err = p.gatherer.GetLocalParameters() - if err != nil { - return nil, err - } + if err := p.dtls.Start(s.DTLSParameters); err != nil { + return err + } - localSignal.DTLSParameters, err = p.dtls.GetLocalParameters() - if err != nil { - return nil, err + if s.SCTPCapabilities != nil { + if err := p.sctp.Start(*s.SCTPCapabilities); err != nil { + return err } - - sc := p.sctp.GetCapabilities() - localSignal.SCTPCapabilities = &sc - } + return nil +} +func (p *Peer) receive(s *signal) error { var k webrtc.RTPCodecType switch { - case strings.HasPrefix(s.CodecParameters.MimeType, "audio/"): + case strings.HasPrefix(s.TrackMeta.CodecParameters.MimeType, "audio/"): k = webrtc.RTPCodecTypeAudio - case strings.HasPrefix(s.CodecParameters.MimeType, "video/"): + case strings.HasPrefix(s.TrackMeta.CodecParameters.MimeType, "video/"): k = webrtc.RTPCodecTypeVideo default: k = webrtc.RTPCodecType(0) } - if err := p.me.RegisterCodec(*s.CodecParameters, k); err != nil { - return nil, err + if err := p.me.RegisterCodec(*s.TrackMeta.CodecParameters, k); err != nil { + return err } recv, err := p.api.NewRTPReceiver(k, p.dtls) if err != nil { - return nil, err + return err } - if p.ice.State() == webrtc.ICETransportStateNew { - go func() { - iceRole := webrtc.ICERoleControlled - - if err = p.ice.SetRemoteCandidates(s.ICECandidates); err != nil { - p.provider.log.Error(err, "Start ICE error") - return - } - - if err = p.ice.Start(p.gatherer, s.ICEParameters, &iceRole); err != nil { - p.provider.log.Error(err, "Start ICE error") - return - } - - if err = p.dtls.Start(s.DTLSParameters); err != nil { - p.provider.log.Error(err, "Start DTLS error") - return - } - - if s.SCTPCapabilities != nil { - if err = p.sctp.Start(*s.SCTPCapabilities); err != nil { - p.provider.log.Error(err, "Start SCTP error") - return - } - } - - if err = recv.Receive(webrtc.RTPReceiveParameters{Encodings: []webrtc.RTPDecodingParameters{ - { - webrtc.RTPCodingParameters{ - RID: s.Encodings.RID, - SSRC: s.Encodings.SSRC, - PayloadType: s.Encodings.PayloadType, - }, - }, - }}); err != nil { - p.provider.log.Error(err, "Start receiver error") - return - } - - if p.provider.onRemote != nil { - p.provider.onRemote(SignalMeta{ - PeerID: s.Metadata.PeerID, - StreamID: s.Metadata.StreamID, - SessionID: s.Metadata.SessionID, - }, recv, s.CodecParameters) - } - }() - } else { - if err = recv.Receive(webrtc.RTPReceiveParameters{Encodings: []webrtc.RTPDecodingParameters{ - { - webrtc.RTPCodingParameters{ - RID: s.Encodings.RID, - SSRC: s.Encodings.SSRC, - PayloadType: s.Encodings.PayloadType, - }, + if err = recv.Receive(webrtc.RTPReceiveParameters{Encodings: []webrtc.RTPDecodingParameters{ + { + webrtc.RTPCodingParameters{ + RID: s.Encodings.RID, + SSRC: s.Encodings.SSRC, + PayloadType: s.Encodings.PayloadType, }, - }}); err != nil { - return nil, err - } - - if p.provider.onRemote != nil { - p.provider.onRemote(SignalMeta{ - PeerID: s.Metadata.PeerID, - StreamID: s.Metadata.StreamID, - SessionID: s.Metadata.SessionID, - }, recv, s.CodecParameters) - } + }, + }}); err != nil { + } + if p.onTrack != nil { + p.onTrack(recv.Track(), recv, s.TrackMeta) } p.receivers = append(p.receivers, recv) - b, err := json.Marshal(localSignal) - if err != nil { - return nil, err - } - return b, nil + return nil } -func (p *Peer) send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, +// Send is used to negotiate a track to the remote peer +func (p *Peer) Send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, localTrack webrtc.TrackLocal) (*webrtc.RTPSender, error) { - p.Lock() - defer p.Unlock() - - signal := &Signal{} - if p.gatherer.State() == webrtc.ICEGathererStateNew { - gatherFinished := make(chan struct{}) - p.gatherer.OnLocalCandidate(func(i *webrtc.ICECandidate) { - if i == nil { - close(gatherFinished) - } - }) - // Gather candidates - if err := p.gatherer.Gather(); err != nil { - return nil, err - } - <-gatherFinished - - var err error - - signal.ICECandidates, err = p.gatherer.GetLocalCandidates() - if err != nil { - return nil, err - } - - signal.ICEParameters, err = p.gatherer.GetLocalParameters() - if err != nil { - return nil, err - } - - signal.DTLSParameters, err = p.dtls.GetLocalParameters() - if err != nil { - return nil, err - } - - sc := p.sctp.GetCapabilities() - signal.SCTPCapabilities = &sc + p.mu.Lock() + defer p.mu.Unlock() - } codec := remoteTrack.Codec() sdr, err := p.api.NewRTPSender(localTrack, p.dtls) - p.id = remoteTrack.StreamID() if err != nil { return nil, err } @@ -445,79 +443,117 @@ func (p *Peer) send(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemot return nil, err } - signal.Metadata = SignalMeta{ - PeerID: p.pid, - StreamID: p.id, - SessionID: p.sid, + rr := rand.New(rand.NewSource(time.Now().UnixNano())) + s := signalRequest{ + ID: rr.Uint32(), + Signal: &signal{}, + } + s.Signal.TrackMeta = &TrackMeta{ + StreamID: remoteTrack.StreamID(), + TrackID: remoteTrack.ID(), + CodecParameters: &codec, } - signal.CodecParameters = &codec - rr := rand.New(rand.NewSource(time.Now().UnixNano())) - signal.Encodings = &webrtc.RTPCodingParameters{ + s.Signal.Encodings = &webrtc.RTPCodingParameters{ SSRC: webrtc.SSRC(rr.Uint32()), PayloadType: remoteTrack.PayloadType(), } - - local, err := json.Marshal(signal) + pld, err := json.Marshal(&s) if err != nil { return nil, err } - remote, err := p.provider.signal(SignalMeta{ - PeerID: p.pid, - StreamID: p.id, - SessionID: p.sid, - }, local) + req := Request{ + Event: signalerEvent, + Payload: pld, + } + + msg, err := json.Marshal(req) if err != nil { return nil, err } - var remoteSignal Signal - if err = json.Unmarshal(remote, &remoteSignal); err != nil { + + if err = p.signalingDC.Send(msg); err != nil { return nil, err } - if p.ice.State() == webrtc.ICETransportStateNew { - if err = p.ice.SetRemoteCandidates(remoteSignal.ICECandidates); err != nil { - return nil, err - } - iceRole := webrtc.ICERoleControlling - if err = p.ice.Start(p.gatherer, remoteSignal.ICEParameters, &iceRole); err != nil { - return nil, err - } + params := receiver.GetParameters() - if err = p.dtls.Start(remoteSignal.DTLSParameters); err != nil { - return nil, err + p.pendingSender[s.ID] = func() { + if err = sdr.Send(webrtc.RTPSendParameters{ + RTPParameters: params, + Encodings: []webrtc.RTPEncodingParameters{ + { + webrtc.RTPCodingParameters{ + SSRC: s.Signal.Encodings.SSRC, + PayloadType: s.Signal.Encodings.PayloadType, + }, + }, + }, + }); err != nil { + p.log.Error(err, "Send RTPSender failed") } + } + p.localTracks = append(p.localTracks, localTrack) + p.senders = append(p.senders, sdr) + return sdr, nil +} - if remoteSignal.SCTPCapabilities != nil { - if err = p.sctp.Start(*remoteSignal.SCTPCapabilities); err != nil { - return nil, err - } +func (p *Peer) handleRequest(msg webrtc.DataChannelMessage) { + mr := &Request{} + if err := json.Unmarshal(msg.Data, mr); err != nil { + p.log.Error(err, "Error marshaling remote message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) + return + } + + if mr.Event != signalerEvent { + p.mu.Lock() + if p.onRequest != nil { + p.onRequest(*mr) } + p.mu.Unlock() + return + } + + r := &signalRequest{} + if err := json.Unmarshal(mr.Payload, r); err != nil { + p.log.Error(err, "Error marshaling remote message", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID) + return + } - if err = p.startDataChannels(); err != nil { - return nil, err + p.mu.Lock() + defer p.mu.Unlock() + + if r.Signal == nil { + if f, ok := p.pendingSender[r.ID]; ok { + f() } + return } - params := receiver.GetParameters() - if err = sdr.Send(webrtc.RTPSendParameters{ - RTPParameters: params, - Encodings: []webrtc.RTPEncodingParameters{ - { - webrtc.RTPCodingParameters{ - SSRC: signal.Encodings.SSRC, - PayloadType: signal.Encodings.PayloadType, - RID: remoteTrack.RID(), - }, - }, - }, - }); err != nil { - return nil, err + if err := p.receive(r.Signal); err != nil { + return + } + rr := &signalRequest{ + ID: r.ID, + } + d, err := json.Marshal(rr) + if err != nil { + p.log.Error(err, "Error marshaling remote signalRequest", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") + return + } + req := Request{ + Event: signalerEvent, + Payload: d, + } + d, err = json.Marshal(req) + if err != nil { + p.log.Error(err, "Error marshaling response Request", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") + return + } + if err = p.signalingDC.Send(d); err != nil { + p.log.Error(err, "Error sending response", "peer_id", p.meta.PeerID, "session_id", p.meta.SessionID, "stream_id") } - p.localTracks = append(p.localTracks, localTrack) - p.senders = append(p.senders, sdr) - return sdr, nil } func joinErrs(errs ...error) error { diff --git a/pkg/sfu/errors.go b/pkg/sfu/errors.go index 3dc6fd207..89195c91a 100644 --- a/pkg/sfu/errors.go +++ b/pkg/sfu/errors.go @@ -11,7 +11,7 @@ var ( // Helpers errors errShortPacket = errors.New("packet is not large enough") errNilPacket = errors.New("invalid nil packet") - // Downtrack errors + ErrSpatialNotSupported = errors.New("current track does not support simulcast/SVC") ErrSpatialLayerBusy = errors.New("a spatial layer change is in progress, try latter") ) diff --git a/pkg/sfu/peer.go b/pkg/sfu/peer.go index 3eac90f86..e8d21cc16 100644 --- a/pkg/sfu/peer.go +++ b/pkg/sfu/peer.go @@ -30,7 +30,7 @@ type Peer interface { Publisher() *Publisher Subscriber() *Subscriber Close() error - SendDCMessage(label string, msg *[]byte) error + SendDCMessage(label string, msg []byte) error } // JoinConfig allow adding more control to the peers joining a SessionLocal. @@ -39,8 +39,6 @@ type JoinConfig struct { NoPublish bool // If true the peer will not be allowed to subscribe to other peers in SessionLocal. NoSubscribe bool - // If true it will Relay all the published tracks of the peer - Relay bool } // SessionProvider provides the SessionLocal to the sfu.Peer @@ -144,7 +142,7 @@ func (p *PeerLocal) Join(sid, uid string, config ...JoinConfig) error { } if !conf.NoPublish { - p.publisher, err = NewPublisher(p.session, uid, conf.Relay, &cfg) + p.publisher, err = NewPublisher(uid, p.session, &cfg) if err != nil { return fmt.Errorf("error creating transport: %v", err) } @@ -247,7 +245,7 @@ func (p *PeerLocal) Trickle(candidate webrtc.ICECandidateInit, target int) error return nil } -func (p *PeerLocal) SendDCMessage(label string, msg *[]byte) error { +func (p *PeerLocal) SendDCMessage(label string, msg []byte) error { if p.subscriber == nil { return fmt.Errorf("no subscriber for this peer") } @@ -257,7 +255,7 @@ func (p *PeerLocal) SendDCMessage(label string, msg *[]byte) error { return fmt.Errorf("data channel %s doesn't exist", label) } - if err := dc.SendText(string(*msg)); err != nil { + if err := dc.SendText(string(msg)); err != nil { return fmt.Errorf("failed to send message: %v", err) } return nil diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index 4e2fb95b7..0e86039b7 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -1,25 +1,27 @@ package sfu import ( + "fmt" "io" "sync" "sync/atomic" "time" - "github.com/pion/rtcp" - "github.com/pion/ion-sfu/pkg/relay" - + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) type Publisher struct { - id string - pc *webrtc.PeerConnection + mu sync.Mutex + id string + pc *webrtc.PeerConnection + cfg *WebRTCTransportConfig router Router session Session - relayPeer *relay.Peer + tracks []publisherTracks + relayPeer []*relay.Peer candidates []webrtc.ICECandidateInit onICEConnectionStateChangeHandler atomic.Value // func(webrtc.ICEConnectionState) @@ -27,8 +29,13 @@ type Publisher struct { closeOnce sync.Once } +type publisherTracks struct { + track *webrtc.TrackRemote + receiver Receiver +} + // NewPublisher creates a new Publisher -func NewPublisher(session Session, id string, relay bool, cfg *WebRTCTransportConfig) (*Publisher, error) { +func NewPublisher(id string, session Session, cfg *WebRTCTransportConfig) (*Publisher, error) { me, err := getPublisherMediaEngine() if err != nil { Logger.Error(err, "NewPeer error", "peer_id", id) @@ -46,11 +53,11 @@ func NewPublisher(session Session, id string, relay bool, cfg *WebRTCTransportCo p := &Publisher{ id: id, pc: pc, + cfg: cfg, router: newRouter(id, pc, session, cfg), session: session, } - var relayReports sync.Once pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { Logger.V(1).Info("Peer got remote track id", "peer_id", p.id, @@ -63,43 +70,15 @@ func NewPublisher(session Session, id string, relay bool, cfg *WebRTCTransportCo r, pub := p.router.AddReceiver(receiver, track) if pub { p.session.Publish(p.router, r) - } - - if relay && cfg.Relay != nil && pub { - codec := track.Codec() - downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ - MimeType: codec.MimeType, - ClockRate: codec.ClockRate, - Channels: codec.Channels, - SDPFmtpLine: codec.SDPFmtpLine, - RTCPFeedback: []webrtc.RTCPFeedback{{"goog-remb", ""}, {"nack", ""}, {"nack", "pli"}}, - }, r, cfg.BufferFactory, id, cfg.Router.MaxPacketTrack) - if err != nil { - Logger.V(1).Error(err, "Create Relay downtrack err", "peer_id", id) - return - } - rr, sdr, err := cfg.Relay.Send(session.ID(), id, receiver, track, downTrack) - if err != nil { - Logger.V(1).Error(err, "Relay err", "peer_id", id) - return - } - - if p.relayPeer == nil { - relayReports.Do(func() { - p.relayPeer = rr - go p.relayReports() - }) - } - - downTrack.OnCloseHandler(func() { - if err := sdr.Stop(); err != nil { - Logger.V(1).Error(err, "Relay sender close err", "peer_id", id) + p.mu.Lock() + p.tracks = append(p.tracks, publisherTracks{track, r}) + for _, rp := range p.relayPeer { + if err = p.createRelayTrack(track, r, rp); err != nil { + Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) } - }) - - r.AddDownTrack(downTrack, true) + } + p.mu.Unlock() } - }) pc.OnDataChannel(func(dc *webrtc.DataChannel) { @@ -125,11 +104,6 @@ func NewPublisher(session Session, id string, relay bool, cfg *WebRTCTransportCo } }) - if relay && cfg.Relay != nil { - if err = cfg.Relay.AddDataChannels(session.ID(), id, session.GetDataChannelLabels()); err != nil { - Logger.Error(err, "Add relaying data channels error") - } - } return p, nil } @@ -163,10 +137,14 @@ func (p *Publisher) GetRouter() Router { // Close peer func (p *Publisher) Close() { p.closeOnce.Do(func() { - if p.relayPeer != nil { - if err := p.relayPeer.Close(); err != nil { - Logger.Error(err, "relay peer transport close err") + if len(p.relayPeer) > 0 { + p.mu.Lock() + for _, rp := range p.relayPeer { + if err := rp.Close(); err != nil { + Logger.Error(err, "Closing relay peer transport.") + } } + p.mu.Unlock() } p.router.Stop() if err := p.pc.Close(); err != nil { @@ -192,6 +170,45 @@ func (p *Publisher) PeerConnection() *webrtc.PeerConnection { return p.pc } +func (p *Publisher) Relay(ice []webrtc.ICEServer) (*relay.Peer, error) { + rp, err := relay.NewPeer(relay.PeerMeta{ + PeerID: p.id, + SessionID: p.session.ID(), + }, &relay.PeerConfig{ + SettingEngine: p.cfg.Setting, + ICEServers: ice, + Logger: Logger, + }) + if err != nil { + return nil, fmt.Errorf("relay: %w", err) + } + + rp.OnReady(func() { + for _, lbl := range p.session.GetDataChannelLabels() { + if _, err := rp.CreateDataChannel(lbl); err != nil { + Logger.V(1).Error(err, "Creating data channels.", "peer_id", p.id) + } + } + + p.mu.Lock() + for _, tp := range p.tracks { + if err = p.createRelayTrack(tp.track, tp.receiver, rp); err != nil { + Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) + } + } + p.relayPeer = append(p.relayPeer, rp) + + go p.relayReports(rp) + p.mu.Unlock() + }) + + if err = rp.Offer(p.cfg.Relay); err != nil { + return nil, fmt.Errorf("relay: %w", err) + } + + return rp, nil +} + // AddICECandidate to peer connection func (p *Publisher) AddICECandidate(candidate webrtc.ICECandidateInit) error { if p.pc.RemoteDescription() != nil { @@ -201,12 +218,42 @@ func (p *Publisher) AddICECandidate(candidate webrtc.ICECandidateInit) error { return nil } -func (p *Publisher) relayReports() { +func (p *Publisher) createRelayTrack(track *webrtc.TrackRemote, receiver Receiver, rp *relay.Peer) error { + codec := track.Codec() + downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: []webrtc.RTCPFeedback{{"nack", ""}, {"nack", "pli"}}, + }, receiver, p.cfg.BufferFactory, p.id, p.cfg.Router.MaxPacketTrack) + if err != nil { + Logger.V(1).Error(err, "Create Relay downtrack err", "peer_id", p.id) + return err + } + + sdr, err := rp.Send(receiver.(*WebRTCReceiver).receiver, track, downTrack) + if err != nil { + Logger.V(1).Error(err, "Relaying track.", "peer_id", p.id) + return fmt.Errorf("relay: %w", err) + } + + downTrack.OnCloseHandler(func() { + if err = sdr.Stop(); err != nil { + Logger.V(1).Error(err, "Stopping relay sender.", "peer_id", p.id) + } + }) + + receiver.AddDownTrack(downTrack, true) + return nil +} + +func (p *Publisher) relayReports(rp *relay.Peer) { for { time.Sleep(5 * time.Second) var r []rtcp.Packet - for _, t := range p.relayPeer.LocalTracks() { + for _, t := range rp.LocalTracks() { if dt, ok := t.(*DownTrack); ok { if !dt.bound.get() { continue @@ -219,7 +266,7 @@ func (p *Publisher) relayReports() { continue } - if err := p.relayPeer.WriteRTCP(r); err != nil { + if err := rp.WriteRTCP(r); err != nil { if err == io.EOF || err == io.ErrClosedPipe { return } diff --git a/pkg/sfu/sfu.go b/pkg/sfu/sfu.go index 48e6965c9..595262432 100644 --- a/pkg/sfu/sfu.go +++ b/pkg/sfu/sfu.go @@ -38,7 +38,7 @@ type WebRTCTransportConfig struct { Configuration webrtc.Configuration Setting webrtc.SettingEngine Router RouterConfig - Relay *relay.Provider + Relay func(meta relay.PeerMeta, signal []byte) ([]byte, error) BufferFactory *buffer.Factory } @@ -68,7 +68,7 @@ type Config struct { WebRTC WebRTCConfig `mapstructure:"webrtc"` Router RouterConfig `mapstructure:"Router"` Turn TurnConfig `mapstructure:"turn"` - Relay *relay.Provider + Relay func(meta relay.PeerMeta, signal []byte) ([]byte, error) BufferFactory *buffer.Factory TurnAuth func(username string, realm string, srcAddr net.Addr) ([]byte, bool) } @@ -210,10 +210,6 @@ func NewSFU(c Config) *SFU { withStats: c.Router.WithStats, } - if c.Relay != nil { - c.Relay.SetSettingEngine(w.Setting) - } - if c.Turn.Enabled { ts, err := InitTurnServer(c.Turn, c.TurnAuth) if err != nil {