Skip to content

Commit

Permalink
feat(sfu): add data channel middlewares (#388)
Browse files Browse the repository at this point in the history
* feat(sfu): Add datachannel middlewares

* feat(sfu): Add datachannel fanout middleware

* fix grpc web sfu initialization

* Add keepalive datachannel middleware

* Simplify middleware api

* Fix tests

* Fix tests and linter issues

* Add readme and forward args to callback
  • Loading branch information
OrlandoCo authored Jan 22, 2021
1 parent f0d71e9 commit ee0517f
Show file tree
Hide file tree
Showing 15 changed files with 296 additions and 75 deletions.
9 changes: 7 additions & 2 deletions cmd/signal/allrpc/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net"
"net/http"

"github.com/pion/ion-sfu/pkg/middlewares/datachannel"

grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/pion/ion-log"
pb "github.com/pion/ion-sfu/cmd/signal/grpc/proto"
Expand All @@ -25,9 +27,12 @@ type Server struct {
}

// New create a server which support grpc/jsonrpc
func New(c sfu.Config) *Server {
func New(c sfu.Config) *Server { // Register default middlewares
s := sfu.NewSFU(c)
dc := s.NewDatachannel(sfu.APIChannelLabel)
dc.Use(datachannel.SubscriberAPI)
return &Server{
sfu: sfu.NewSFU(c),
sfu: s,
}
}

Expand Down
8 changes: 7 additions & 1 deletion cmd/signal/grpc/grpc-web/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"os"

"github.com/pion/ion-sfu/pkg/middlewares/datachannel"

log "github.com/pion/ion-log"
"github.com/pion/ion-sfu/pkg/sfu"
"github.com/spf13/viper"
Expand Down Expand Up @@ -105,7 +107,11 @@ func main() {
options.Addr = addr
options.AllowAllOrigins = true
options.UseWebSocket = true
s := server.NewWrapperedGRPCWebServer(options, sfu.NewSFU(conf.Config))

nsfu := sfu.NewSFU(conf.Config)
dc := nsfu.NewDatachannel(sfu.APIChannelLabel)
dc.Use(datachannel.SubscriberAPI)
s := server.NewWrapperedGRPCWebServer(options, nsfu)
if err := s.Serve(); err != nil {
log.Panicf("failed to serve: %v", err)
}
Expand Down
9 changes: 8 additions & 1 deletion cmd/signal/grpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/http"
"os"

"github.com/pion/ion-sfu/pkg/middlewares/datachannel"

grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/pion/ion-log"
pb "github.com/pion/ion-sfu/cmd/signal/grpc/proto"
Expand Down Expand Up @@ -135,7 +137,12 @@ func main() {
s := grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
)
pb.RegisterSFUServer(s, server.NewServer(sfu.NewSFU(conf.Config)))

nsfu := sfu.NewSFU(conf.Config)
dc := nsfu.NewDatachannel(sfu.APIChannelLabel)
dc.Use(datachannel.SubscriberAPI)

pb.RegisterSFUServer(s, server.NewServer(nsfu))
grpc_prometheus.Register(s)

go startMetrics(metricsAddr)
Expand Down
15 changes: 9 additions & 6 deletions cmd/signal/json-rpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@ import (
"fmt"
"net"
"net/http"
"os"

_ "net/http/pprof"
"os"

"github.com/gorilla/websocket"
log "github.com/pion/ion-log"
"github.com/pion/ion-sfu/cmd/signal/json-rpc/server"
"github.com/pion/ion-sfu/pkg/middlewares/datachannel"
"github.com/pion/ion-sfu/pkg/sfu"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sourcegraph/jsonrpc2"
websocketjsonrpc2 "github.com/sourcegraph/jsonrpc2/websocket"
"github.com/spf13/viper"

log "github.com/pion/ion-log"
"github.com/pion/ion-sfu/cmd/signal/json-rpc/server"
"github.com/pion/ion-sfu/pkg/sfu"
)

var (
Expand Down Expand Up @@ -126,7 +125,11 @@ func main() {
log.Init(conf.Log.Level, fixByFile, fixByFunc)

log.Infof("--- Starting SFU Node ---")

s := sfu.NewSFU(conf)
dc := s.NewDatachannel(sfu.APIChannelLabel)
dc.Use(datachannel.SubscriberAPI)

upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
Expand Down
30 changes: 30 additions & 0 deletions pkg/middlewares/datachannel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Datachannels middlewares

`ion-sfu` supports datachannels middlewares similar to the `net/http` standard library handlers.

## API

### Middleware

To create a datachannel middleware, just follow below pattern:

```go
func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor {
return sfu.ProcessFunc(func(ctx context.Context, args sfu.ProcessArgs) {
next.Process(ctx,args)
}
}
```

### Init middlewares

To initialize the middlewares you need to declare them after sfu initialization:
```go
s := sfu.NewSFU(conf)
dc := s.NewDatachannel(sfu.APIChannelLabel)
dc.Use(datachannel.KeepAlive(5*time.Second), datachannel.SubscriberAPI)
// This callback is optional
dc.OnMessage(func(ctx context.Context, msg webrtc.DataChannelMessage, in *webrtc.DataChannel, out []*webrtc.DataChannel) {
})
```
The datachannels will be negotiated on peer join in the `Subscriber` peer connection.
31 changes: 31 additions & 0 deletions pkg/middlewares/datachannel/keepalive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package datachannel

import (
"bytes"
"context"
"time"

"github.com/pion/ion-sfu/pkg/sfu"
)

func KeepAlive(timeout time.Duration) func(next sfu.MessageProcessor) sfu.MessageProcessor {
var timer *time.Timer
return func(next sfu.MessageProcessor) sfu.MessageProcessor {
return sfu.ProcessFunc(func(ctx context.Context, args sfu.ProcessArgs) {
if timer == nil {
timer = time.AfterFunc(timeout, func() {
_ = args.Peer.Close()
})
}
if args.Message.IsString && bytes.Equal(args.Message.Data, []byte("ping")) {
if !timer.Stop() {
<-timer.C
}
timer.Reset(timeout)
_ = args.DataChannel.SendText("pong")
return
}
next.Process(ctx, args)
})
}
}
17 changes: 8 additions & 9 deletions pkg/sfu/api.go → pkg/middlewares/datachannel/subscriberapi.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package sfu
package datachannel

import (
"context"
"encoding/json"

log "github.com/pion/ion-log"
"github.com/pion/ion-sfu/pkg/sfu"
"github.com/pion/webrtc/v3"
)

const (
apiChannelLabel = "ion-sfu"

videoHighQuality = "high"
videoMediumQuality = "medium"
videoLowQuality = "low"
Expand All @@ -22,14 +21,13 @@ type setRemoteMedia struct {
Audio bool `json:"audio"`
}

func handleAPICommand(s *Subscriber, dc *webrtc.DataChannel) {
dc.OnMessage(func(msg webrtc.DataChannelMessage) {
func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor {
return sfu.ProcessFunc(func(ctx context.Context, args sfu.ProcessArgs) {
srm := &setRemoteMedia{}
if err := json.Unmarshal(msg.Data, srm); err != nil {
log.Errorf("Unmarshal api command err: %v", err)
if err := json.Unmarshal(args.Message.Data, srm); err != nil {
return
}
downTracks := s.GetDownTracks(srm.StreamID)
downTracks := args.Peer.Subscriber().GetDownTracks(srm.StreamID)

for _, dt := range downTracks {
switch dt.Kind() {
Expand All @@ -51,5 +49,6 @@ func handleAPICommand(s *Subscriber, dc *webrtc.DataChannel) {
}
}
}
next.Process(ctx, args)
})
}
81 changes: 81 additions & 0 deletions pkg/sfu/datachannel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package sfu

import (
"context"

"github.com/pion/webrtc/v3"
)

type (
// Datachannel is a wrapper to define middlewares executed on defined label.
// The datachannels created will be negotiated on join to all peers that joins
// the SFU.
Datachannel struct {
label string
middlewares []func(MessageProcessor) MessageProcessor
onMessage func(ctx context.Context, args ProcessArgs, out []*webrtc.DataChannel)
}

ProcessArgs struct {
Peer *Peer
Message webrtc.DataChannelMessage
DataChannel *webrtc.DataChannel
}

Middlewares []func(MessageProcessor) MessageProcessor

MessageProcessor interface {
Process(ctx context.Context, args ProcessArgs)
}

ProcessFunc func(ctx context.Context, args ProcessArgs)

chainHandler struct {
middlewares Middlewares
Last MessageProcessor
current MessageProcessor
}
)

// Use adds the middlewares to the current Datachannel.
// The middlewares are going to be executed before the OnMessage event fires.
func (dc *Datachannel) Use(middlewares ...func(MessageProcessor) MessageProcessor) {
dc.middlewares = append(dc.middlewares, middlewares...)
}

// OnMessage sets the message callback for the datachannel, the event is fired
// after all the middlewares have processed the message.
func (dc *Datachannel) OnMessage(fn func(ctx context.Context, args ProcessArgs, out []*webrtc.DataChannel)) {
dc.onMessage = fn
}

func (p ProcessFunc) Process(ctx context.Context, args ProcessArgs) {
p(ctx, args)
}

func (mws Middlewares) Process(h MessageProcessor) MessageProcessor {
return &chainHandler{mws, h, chain(mws, h)}
}

func (mws Middlewares) ProcessFunc(h MessageProcessor) MessageProcessor {
return &chainHandler{mws, h, chain(mws, h)}
}

func newDCChain(m []func(p MessageProcessor) MessageProcessor) Middlewares {
return Middlewares(m)
}

func (c *chainHandler) Process(ctx context.Context, args ProcessArgs) {
c.current.Process(ctx, args)
}

func chain(mws []func(processor MessageProcessor) MessageProcessor, last MessageProcessor) MessageProcessor {
if len(mws) == 0 {
return last
}
h := mws[len(mws)-1](last)
for i := len(mws) - 2; i >= 0; i-- {
h = mws[i](h)
}
return h
}
5 changes: 1 addition & 4 deletions pkg/sfu/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ package sfu
import "errors"

var (
//Peer erors
errPeerConnectionInitFailed = errors.New("pc init failed")
errPtNotSupported = errors.New("payload type not supported")
errCreatingDataChannel = errors.New("failed to create data channel")
// router errors
errNoReceiverFound = errors.New("no receiver found")
// Helpers errors
errShortPacket = errors.New("packet is not large enough")
errNilPacket = errors.New("invalid nil packet")
// buffer errors
errPacketNotFound = errors.New("packet not found in cache")
errPacketTooOld = errors.New("packet not found in cache, too old")
)
25 changes: 22 additions & 3 deletions pkg/sfu/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ type SessionProvider interface {
// Peer represents a pair peer connection
type Peer struct {
sync.Mutex
id string
session *Session
provider SessionProvider
id string
session *Session
provider SessionProvider

publisher *Publisher
subscriber *Subscriber

Expand Down Expand Up @@ -80,6 +81,12 @@ func (p *Peer) Join(sid string, sdp webrtc.SessionDescription) (*webrtc.SessionD
return nil, fmt.Errorf("error creating transport: %v", err)
}

for _, dc := range p.session.datachannels {
if err := p.subscriber.AddDatachannel(p, dc); err != nil {
return nil, fmt.Errorf("error setting subscriber default dc datachannel")
}
}

p.subscriber.OnNegotiationNeeded(func() {
p.Lock()
defer p.Unlock()
Expand Down Expand Up @@ -235,3 +242,15 @@ func (p *Peer) Close() error {
}
return nil
}

func (p *Peer) Subscriber() *Subscriber {
return p.subscriber
}

func (p *Peer) Publisher() *Publisher {
return p.publisher
}

func (p *Peer) Session() *Session {
return p.session
}
3 changes: 1 addition & 2 deletions pkg/sfu/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type Publisher struct {
session *Session
candidates []webrtc.ICECandidateInit

onTrackHandler func(*webrtc.TrackRemote, *webrtc.RTPReceiver)
onICEConnectionStateChangeHandler atomic.Value // func(webrtc.ICEConnectionState)

closeOnce sync.Once
Expand Down Expand Up @@ -53,7 +52,7 @@ func NewPublisher(session *Session, id string, cfg WebRTCTransportConfig) (*Publ
})

pc.OnDataChannel(func(dc *webrtc.DataChannel) {
if dc.Label() == apiChannelLabel {
if dc.Label() == APIChannelLabel {
// terminate api data channel
return
}
Expand Down
Loading

0 comments on commit ee0517f

Please sign in to comment.