Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SRTP #232

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/media/dtmf/dtmf.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ func Decode(data []byte) (Event, error) {
}, nil
}

func DecodeRTP(p *rtp.Packet) (Event, bool) {
if !p.Marker {
func DecodeRTP(h *rtp.Header, payload []byte) (Event, bool) {
if !h.Marker {
return Event{}, false
}
ev, err := Decode(p.Payload)
ev, err := Decode(payload)
if err != nil {
return Event{}, false
}
Expand Down
22 changes: 12 additions & 10 deletions pkg/media/rtp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package rtp

import (
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -131,9 +132,11 @@ func (c *Conn) Listen(portMin, portMax int, listenAddr string) error {
if listenAddr == "" {
listenAddr = "0.0.0.0"
}

var err error
c.conn, err = ListenUDPPortRange(portMin, portMax, net.ParseIP(listenAddr))
ip, err := netip.ParseAddr(listenAddr)
if err != nil {
return err
}
c.conn, err = ListenUDPPortRange(portMin, portMax, ip)
if err != nil {
return err
}
Expand Down Expand Up @@ -167,24 +170,23 @@ func (c *Conn) readLoop() {
close(c.received)
}
if h := c.onRTP.Load(); h != nil {
_ = (*h).HandleRTP(&p)
_ = (*h).HandleRTP(&p.Header, p.Payload)
}
}
}

func (c *Conn) WriteRTP(p *rtp.Packet) error {
func (c *Conn) WriteRTP(h *rtp.Header, payload []byte) (int, error) {
addr := c.dest.Load()
if addr == nil {
return nil
return 0, nil
}
data, err := p.Marshal()
data, err := (&rtp.Packet{Header: *h, Payload: payload}).Marshal()
if err != nil {
return err
return 0, err
}
c.wmu.Lock()
defer c.wmu.Unlock()
_, err = c.conn.WriteToUDP(data, addr)
return err
return c.conn.WriteToUDP(data, addr)
}

func (c *Conn) ReadRTP() (*rtp.Packet, *net.UDPAddr, error) {
Expand Down
11 changes: 6 additions & 5 deletions pkg/media/rtp/jitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ package rtp
import (
"time"

"github.com/livekit/server-sdk-go/v2/pkg/jitter"
"github.com/pion/rtp"

"github.com/livekit/server-sdk-go/v2/pkg/jitter"
)

const (
Expand All @@ -41,11 +42,11 @@ type jitterHandler struct {
buf *jitter.Buffer
}

func (h *jitterHandler) HandleRTP(p *rtp.Packet) error {
h.buf.Push(p)
func (r *jitterHandler) HandleRTP(h *rtp.Header, payload []byte) error {
r.buf.Push(&rtp.Packet{Header: *h, Payload: payload})
var last error
for _, p := range h.buf.Pop(false) {
if err := h.h.HandleRTP(p); err != nil {
for _, p := range r.buf.Pop(false) {
if err := r.h.HandleRTP(&p.Header, p.Payload); err != nil {
last = err
}
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/media/rtp/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ import (
"errors"
"math/rand"
"net"
"net/netip"
)

var ListenErr = errors.New("failed to listen on udp port")

func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) {
func ListenUDPPortRange(portMin, portMax int, ip netip.Addr) (*net.UDPConn, error) {
if portMin == 0 && portMax == 0 {
return net.ListenUDP("udp", &net.UDPAddr{
IP: IP,
IP: ip.AsSlice(),
Port: 0,
})
}
Expand All @@ -48,7 +49,7 @@ func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) {
portCurrent := portStart

for {
c, e := net.ListenUDP("udp", &net.UDPAddr{IP: IP, Port: portCurrent})
c, e := net.ListenUDP("udp", &net.UDPAddr{IP: ip.AsSlice(), Port: portCurrent})
if e == nil {
return c, nil
}
Expand Down
18 changes: 9 additions & 9 deletions pkg/media/rtp/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@ type Mux struct {

// HandleRTP selects a Handler based on payload type.
// Types can be registered with Register. If no handler is set, a default one will be used.
func (m *Mux) HandleRTP(p *rtp.Packet) error {
func (m *Mux) HandleRTP(h *rtp.Header, payload []byte) error {
if m == nil {
return nil
}
var h Handler
var r Handler
m.mu.RLock()
if p.PayloadType < byte(len(m.static)) {
h = m.static[p.PayloadType]
if h.PayloadType < byte(len(m.static)) {
r = m.static[h.PayloadType]
} else {
h = m.dynamic[p.PayloadType]
r = m.dynamic[h.PayloadType]
}
if h == nil {
h = m.def
if r == nil {
r = m.def
}
m.mu.RUnlock()
if h == nil {
if r == nil {
return nil
}
return h.HandleRTP(p)
return r.HandleRTP(h, payload)
}

// SetDefault sets a default RTP handler.
Expand Down
49 changes: 25 additions & 24 deletions pkg/media/rtp/rtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package rtp
import (
"fmt"
"math/rand/v2"
"slices"
"sync"

"github.com/pion/interceptor"
Expand All @@ -31,21 +32,21 @@ type BytesFrame interface {
}

type Writer interface {
WriteRTP(p *rtp.Packet) error
WriteRTP(h *rtp.Header, payload []byte) (int, error)
}

type Reader interface {
ReadRTP() (*rtp.Packet, interceptor.Attributes, error)
}

type Handler interface {
HandleRTP(p *rtp.Packet) error
HandleRTP(h *rtp.Header, payload []byte) error
}

type HandlerFunc func(p *rtp.Packet) error
type HandlerFunc func(h *rtp.Header, payload []byte) error

func (fnc HandlerFunc) HandleRTP(p *rtp.Packet) error {
return fnc(p)
func (fnc HandlerFunc) HandleRTP(h *rtp.Header, payload []byte) error {
return fnc(h, payload)
}

func HandleLoop(r Reader, h Handler) error {
Expand All @@ -54,7 +55,7 @@ func HandleLoop(r Reader, h Handler) error {
if err != nil {
return err
}
err = h.HandleRTP(p)
err = h.HandleRTP(&p.Header, p.Payload)
if err != nil {
return err
}
Expand All @@ -64,26 +65,27 @@ func HandleLoop(r Reader, h Handler) error {
// Buffer is a Writer that clones and appends RTP packets into a slice.
type Buffer []*Packet

func (b *Buffer) WriteRTP(p *Packet) error {
p2 := p.Clone()
*b = append(*b, p2)
func (b *Buffer) WriteRTP(h *rtp.Header, payload []byte) error {
*b = append(*b, &rtp.Packet{
Header: *h,
Payload: slices.Clone(payload),
})
return nil
}

// NewSeqWriter creates an RTP writer that automatically increments the sequence number.
func NewSeqWriter(w Writer) *SeqWriter {
s := &SeqWriter{w: w}
s.p = rtp.Packet{
Header: rtp.Header{
Version: 2,
SSRC: rand.Uint32(),
SequenceNumber: 0,
},
s.h = rtp.Header{
Version: 2,
SSRC: rand.Uint32(),
SequenceNumber: 0,
}
return s
}

type Packet = rtp.Packet
type Header = rtp.Header

type Event struct {
Type byte
Expand All @@ -95,20 +97,19 @@ type Event struct {
type SeqWriter struct {
mu sync.Mutex
w Writer
p Packet
h Header
}

func (s *SeqWriter) WriteEvent(ev *Event) error {
s.mu.Lock()
defer s.mu.Unlock()
s.p.PayloadType = ev.Type
s.p.Payload = ev.Payload
s.p.Marker = ev.Marker
s.p.Timestamp = ev.Timestamp
if err := s.w.WriteRTP(&s.p); err != nil {
s.h.PayloadType = ev.Type
s.h.Marker = ev.Marker
s.h.Timestamp = ev.Timestamp
if _, err := s.w.WriteRTP(&s.h, ev.Payload); err != nil {
return err
}
s.p.Header.SequenceNumber++
s.h.SequenceNumber++
return nil
}

Expand Down Expand Up @@ -211,6 +212,6 @@ func (s *MediaStreamIn[T]) String() string {
return fmt.Sprintf("RTP(%d) -> %s", s.Writer.SampleRate(), s.Writer)
}

func (s *MediaStreamIn[T]) HandleRTP(p *rtp.Packet) error {
return s.Writer.WriteSample(T(p.Payload))
func (s *MediaStreamIn[T]) HandleRTP(_ *rtp.Header, payload []byte) error {
return s.Writer.WriteSample(T(payload))
}
Loading
Loading