diff --git a/association.go b/association.go index 9b2aad55..7942795c 100644 --- a/association.go +++ b/association.go @@ -101,6 +101,16 @@ const ( acceptChSize = 16 ) +func getPort(addr net.Addr) uint16 { + switch addr := addr.(type) { + case *net.UDPAddr: + return uint16(addr.Port) + case *net.TCPAddr: + return uint16(addr.Port) + } + return 0 +} + func getAssociationStateString(a uint32) string { switch a { case closed: @@ -337,6 +347,10 @@ func createAssociation(config Config) *Association { log: config.LoggerFactory.NewLogger("sctp"), } + if config.NetConn != nil { + a.sourcePort = getPort(config.NetConn.LocalAddr()) + a.destinationPort = getPort(config.NetConn.RemoteAddr()) + } a.name = fmt.Sprintf("%p", a) // RFC 4690 Sec 7.2.1 @@ -399,8 +413,12 @@ func (a *Association) sendInit() error { outbound := &packet{} outbound.verificationTag = a.peerVerificationTag - a.sourcePort = defaultSCTPSrcDstPort - a.destinationPort = defaultSCTPSrcDstPort + if a.sourcePort == 0 { + a.sourcePort = defaultSCTPSrcDstPort + } + if a.destinationPort == 0 { + a.destinationPort = defaultSCTPSrcDstPort + } outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort