Skip to content

Commit

Permalink
Transfer participant integration test (#186)
Browse files Browse the repository at this point in the history
This also makes sure that we do not cancel the transfer request if the RPC handler gets cancelled, and avoids situation where we may return an old transfer result to a 2nd transfer request if one failed before.
  • Loading branch information
biglittlebigben authored Oct 3, 2024
1 parent fb5e210 commit a85b40a
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 44 deletions.
8 changes: 5 additions & 3 deletions pkg/sip/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ func (c *inboundCall) transferCall(ctx context.Context, transferTo string) error
return err
}

c.log.Infow("inbound call tranferred", "transferTo", transferTo)

// This is needed to actually terminate the session before a media timeout
c.Close()

Expand All @@ -760,7 +762,7 @@ func (s *Server) newInbound(id LocalTag, invite *sip.Request, inviteTx sip.Serve
invite: invite,
inviteTx: inviteTx,
cancelled: make(chan struct{}),
referDone: make(chan error, 1),
referDone: make(chan error), // Do not buffer the channel to avoid reading a result for an old request
}
c.from, _ = invite.From()
if c.from != nil {
Expand Down Expand Up @@ -1096,14 +1098,14 @@ func (c *sipInbound) handleNotify(req *sip.Request, tx sip.ServerTransaction) er
// Success
select {
case c.referDone <- nil:
default:
case <-time.After(notifyAckTimeout):
}
default:
// Failure
select {
// TODO be more specific in the reported error
case c.referDone <- psrpc.NewErrorf(psrpc.Canceled, "call transfer failed"):
default:
case <-time.After(notifyAckTimeout):
}
}
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/sip/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/netip"
"sort"
"sync"
"time"

"github.com/emiago/sipgo/sip"
"github.com/frostbyte73/core"
Expand Down Expand Up @@ -405,6 +406,8 @@ func (c *outboundCall) transferCall(ctx context.Context, transferTo string) erro
return err
}

c.log.Infow("outbound l tranferred", "transferTo", transferTo)

// This is needed to actually terminate the session before a media timeout
c.CloseWithReason(CallHangup, "call transferred")

Expand All @@ -423,7 +426,7 @@ func (c *Client) newOutbound(id LocalTag, from URI) *sipOutbound {
c: c,
id: id,
from: fromHeader,
referDone: make(chan error, 1),
referDone: make(chan error), // Do not buffer the channel to avoid reading a result for an old request
}
}

Expand Down Expand Up @@ -742,14 +745,14 @@ func (c *sipOutbound) handleNotify(req *sip.Request, tx sip.ServerTransaction) e
// Success
select {
case c.referDone <- nil:
default:
case <-time.After(notifyAckTimeout):
}
default:
// Failure
select {
// TODO be more specific in the reported error
case c.referDone <- psrpc.NewErrorf(psrpc.Canceled, "call transfer failed"):
default:
case <-time.After(notifyAckTimeout):
}
}
}
Expand Down
16 changes: 15 additions & 1 deletion pkg/sip/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@ import (
"regexp"
"strconv"
"strings"
"time"

"github.com/emiago/sipgo/sip"
"github.com/livekit/psrpc"
"github.com/pkg/errors"
)

var referIdRegexp = regexp.MustCompile(`^refer(;id=(\d+))?$`)
const (
notifyAckTimeout = 5 * time.Second
)

var (
referIdRegexp = regexp.MustCompile(`^refer(;id=(\d+))?$`)
)

type ErrorStatus struct {
StatusCode int
Expand Down Expand Up @@ -171,6 +178,13 @@ func parseNotifyBody(body string) (int, error) {

func handleNotify(req *sip.Request) (method sip.RequestMethod, cseq uint32, status int, err error) {
event := req.GetHeader("Event")
if event == nil {
event = req.GetHeader("o")
}
if event == nil {
return "", 0, 0, psrpc.NewErrorf(psrpc.MalformedRequest, "no event in NOTIFY request")
}

var cseq64 uint64

if m := referIdRegexp.FindStringSubmatch(strings.ToLower(event.Value())); len(m) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sip/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (s *Service) CreateSIPParticipantAffinity(ctx context.Context, req *rpc.Int
func (s *Service) TransferSIPParticipant(ctx context.Context, req *rpc.InternalTransferSIPParticipantRequest) (*emptypb.Empty, error) {
s.log.Infow("transfering SIP call", "callID", req.SipCallId, "transferTo", req.TransferTo)

ctx, done := context.WithTimeout(ctx, 30*time.Second)
ctx, done := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer done()

// Look for call both in client (outbound) and server (inbound)
Expand Down
145 changes: 127 additions & 18 deletions pkg/siptest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/at-wat/ebml-go/webm"
"github.com/emiago/sipgo"
"github.com/emiago/sipgo/sip"
"github.com/frostbyte73/core"
"github.com/icholy/digest"
"github.com/pion/sdp/v3"

Expand All @@ -56,6 +57,7 @@ type ClientConfig struct {
OnBye func()
OnMediaTimeout func()
OnDTMF func(ev dtmf.Event)
OnRefer func(req *sip.Request)
Codec string
}

Expand Down Expand Up @@ -146,6 +148,14 @@ func NewClient(id string, conf ClientConfig) (*Client, error) {
default:
}
})
cli.sipServer.OnRefer(func(req *sip.Request, tx sip.ServerTransaction) {
if conf.OnRefer != nil {
conf.OnRefer(req)
}

err = tx.Respond(sip.NewResponseFromRequest(req, 202, "Accepted", nil))
tx.Terminate()
})

return cli, nil
}
Expand All @@ -168,7 +178,8 @@ type Client struct {
inviteReq *sip.Request
inviteResp *sip.Response
recordHandler atomic.Pointer[rtp.Handler]
closed atomic.Bool
lastCSeq atomic.Uint32
closed core.Fuse
}

func (c *Client) LocalIP() string {
Expand All @@ -183,23 +194,22 @@ func (c *Client) RemoteHeaders() []sip.Header {
}

func (c *Client) Close() {
if !c.closed.CompareAndSwap(false, true) {
return
}
if c.mediaConn != nil {
c.mediaConn.Close()
}
if c.inviteResp != nil {
c.sendBye()
c.inviteReq = nil
c.inviteResp = nil
}
if c.sipClient != nil {
c.sipClient.Close()
}
if c.sipServer != nil {
c.sipServer.Close()
}
c.closed.Once(func() {
if c.mediaConn != nil {
c.mediaConn.Close()
}
if c.inviteResp != nil {
c.sendBye()
c.inviteReq = nil
c.inviteResp = nil
}
if c.sipClient != nil {
c.sipClient.Close()
}
if c.sipServer != nil {
c.sipServer.Close()
}
})
}

func (c *Client) setupRTPReceiver() {
Expand Down Expand Up @@ -322,6 +332,11 @@ func (c *Client) Dial(ip string, uri string, number string, headers map[string]s
}
c.inviteReq = req
c.inviteResp = resp

if h, ok := req.CSeq(); ok {
c.lastCSeq.Store(h.SeqNo)
}

c.mediaConn.SetDestAddr(dstAddr)
c.log.Debug("client connected", "media-dst", dstAddr)
return nil
Expand Down Expand Up @@ -361,6 +376,10 @@ func (c *Client) sendBye() {
req := sip.NewByeRequest(c.inviteReq, c.inviteResp, nil)
req.AppendHeader(sip.NewHeader("User-Agent", "LiveKit"))

cseq := c.lastCSeq.Add(1)
cseqH, _ := req.CSeq()
cseqH.SeqNo = cseq

tx, err := c.sipClient.TransactionRequest(req)
if err != nil {
return
Expand All @@ -381,6 +400,96 @@ func (c *Client) SendDTMF(digits string) error {
return dtmf.Write(context.Background(), c.audioOut, c.mediaDTMF, c.mediaAudio.GetCurrentTimestamp(), digits)
}

func (c *Client) SendNotify(eventReq *sip.Request, notifyStatus string) error {
var recipient *sip.Uri

if contact, ok := eventReq.Contact(); ok {
recipient = &contact.Address
} else if from, ok := eventReq.From(); ok {
recipient = &from.Address
} else {
return errors.New("missing destination address")
}

req := sip.NewRequest(sip.NOTIFY, recipient)

req.SipVersion = eventReq.SipVersion
sip.CopyHeaders("Via", eventReq, req)

if len(eventReq.GetHeaders("Route")) > 0 {
sip.CopyHeaders("Route", eventReq, req)
} else {
hdrs := c.inviteResp.GetHeaders("Record-Route")
for i := len(hdrs) - 1; i >= 0; i-- {
rrh, ok := hdrs[i].(*sip.RecordRouteHeader)
if !ok {
continue
}

h := rrh.Clone()
req.AppendHeader(h)
}
}

maxForwardsHeader := sip.MaxForwardsHeader(70)
req.AppendHeader(&maxForwardsHeader)

if to, ok := eventReq.To(); ok {
req.AppendHeader((*sip.FromHeader)(to))
} else {
return errors.New("missing To header in REFER request")
}

if from, ok := eventReq.From(); ok {
req.AppendHeader((*sip.ToHeader)(from))
} else {
return errors.New("missing From header in REFER request")
}

if callId, ok := eventReq.CallID(); ok {
req.AppendHeader(callId)
}

ct := sip.ContentTypeHeader("message/sipfrag")
req.AppendHeader(&ct)

cseq := c.lastCSeq.Add(1)
cseqH := &sip.CSeqHeader{
SeqNo: cseq,
MethodName: sip.NOTIFY,
}
req.AppendHeader(cseqH)

req.SetTransport(eventReq.Transport())
req.SetSource(eventReq.Destination())
req.SetDestination(eventReq.Source())

if eventCSeq, ok := eventReq.CSeq(); ok {
req.AppendHeader(sip.NewHeader("Event", fmt.Sprintf("refer;id=%d", eventCSeq.SeqNo)))
} else {
return errors.New("missing CSeq header in REFER request")
}

req.SetBody([]byte(notifyStatus))

tx, err := c.sipClient.TransactionRequest(req)
if err != nil {
return err
}
defer tx.Terminate()

resp, err := getResponse(tx)
if err != nil {
return err
}

if resp.StatusCode != sip.StatusOK {
return fmt.Errorf("NOTIFY failed with status %d", resp.StatusCode)
}

return nil
}

func (c *Client) createOffer() ([]byte, error) {
sessionId := rand.Uint64()

Expand Down
Loading

0 comments on commit a85b40a

Please sign in to comment.