diff --git a/pkg/media/input.go b/pkg/media/input.go index aa6a8b4e..c8beb621 100644 --- a/pkg/media/input.go +++ b/pkg/media/input.go @@ -35,6 +35,7 @@ import ( type Source interface { GetSources() []*gst.Element + ValidateCaps(*gst.Caps) error Start(ctx context.Context) error Close() error } @@ -126,6 +127,28 @@ func (i *Input) Close() error { } func (i *Input) onPadAdded(_ *gst.Element, pad *gst.Pad) { + var err error + + defer func() { + if err != nil { + msg := gst.NewErrorMessage(i.bin.Element, err, err.Error(), nil) + i.bin.Element.GetBus().Post(msg) + } + }() + + typefind, err := i.bin.GetElementByName("typefind") + if err == nil && typefind != nil { + var caps interface{} + caps, err = typefind.GetProperty("caps") + if err == nil && caps != nil { + err = i.source.ValidateCaps(caps.(*gst.Caps)) + if err != nil { + logger.Infow("input caps validation failed", "error", err) + return + } + } + } + // surface callback for first audio and video pads, plug in fakesink on the rest i.lock.Lock() newPad := false @@ -156,11 +179,16 @@ func (i *Input) onPadAdded(_ *gst.Element, pad *gst.Pad) { } pad = ghostPad.Pad } else { - sink, err := gst.NewElement("fakesink") + var sink *gst.Element + + sink, err = gst.NewElement("fakesink") if err != nil { logger.Errorw("failed to create fakesink", err) + return } - pads, err := sink.GetSinkPads() + var pads []*gst.Pad + + pads, err = sink.GetSinkPads() pad.Link(pads[0]) return } diff --git a/pkg/media/rtmp/appsrc.go b/pkg/media/rtmp/appsrc.go index f57353bb..34545365 100644 --- a/pkg/media/rtmp/appsrc.go +++ b/pkg/media/rtmp/appsrc.go @@ -111,6 +111,23 @@ func (s *RTMPRelaySource) GetSources() []*gst.Element { return []*gst.Element{s.flvSrc.Element} } +func (s *RTMPRelaySource) ValidateCaps(caps *gst.Caps) error { + if caps.GetSize() == 0 { + return errors.ErrUnsupportedDecodeFormat + } + + str := caps.GetStructureAt(0) + if str == nil { + return errors.ErrUnsupportedDecodeFormat + } + + if str.Name() != "video/x-flv" { + return errors.ErrUnsupportedDecodeFormat + } + + return nil +} + type appSrcWriter struct { appSrc *app.Source eos *atomic.Bool diff --git a/pkg/media/urlpull/source.go b/pkg/media/urlpull/source.go index 772cadbb..a896046d 100644 --- a/pkg/media/urlpull/source.go +++ b/pkg/media/urlpull/source.go @@ -23,6 +23,19 @@ import ( "github.com/livekit/ingress/pkg/params" ) +var ( + supportedMimeTypes = []string{ + "audio/x-m4a", + "application/x-hls", + "video/quicktime", + "video/x-matroska", + "video/webm", + "audio/ogg", + "application/x-id3", + "audio/mpeg", + } +) + type URLSource struct { params *params.Params src *gst.Element @@ -85,6 +98,25 @@ func (u *URLSource) GetSources() []*gst.Element { } } +func (s *URLSource) ValidateCaps(caps *gst.Caps) error { + if caps.GetSize() == 0 { + return errors.ErrUnsupportedDecodeFormat + } + + str := caps.GetStructureAt(0) + if str == nil { + return errors.ErrUnsupportedDecodeFormat + } + + for _, mime := range supportedMimeTypes { + if str.Name() == mime { + return nil + } + } + + return errors.ErrUnsupportedDecodeFormat +} + func (u *URLSource) Start(ctx context.Context) error { return nil } diff --git a/pkg/media/whip/whipsrc.go b/pkg/media/whip/whipsrc.go index 08663972..1b5dc6b1 100644 --- a/pkg/media/whip/whipsrc.go +++ b/pkg/media/whip/whipsrc.go @@ -102,6 +102,10 @@ func (s *WHIPSource) GetSources() []*gst.Element { return ret } +func (s *WHIPSource) ValidateCaps(*gst.Caps) error { + return nil +} + func (s *WHIPSource) getRelayUrl(kind types.StreamKind) string { return fmt.Sprintf("%s/%s", s.params.RelayUrl, kind) }