diff --git a/association.go b/association.go index 29e9978c..1f344e12 100644 --- a/association.go +++ b/association.go @@ -212,6 +212,9 @@ type Association struct { partialBytesAcked uint32 inFastRecovery bool fastRecoverExitPoint uint32 + minCwnd uint32 // Minimum congestion window + fastRtxWnd uint32 // Send window for fast retransmit + cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance // RTX & Ack timer rtoMgr *rtoManager @@ -261,8 +264,16 @@ type Config struct { MaxMessageSize uint32 EnableZeroChecksum bool LoggerFactory logging.LoggerFactory + + // congestion control configuration // RTOMax is the maximum retransmission timeout in milliseconds RTOMax float64 + // Minimum congestion window + MinCwnd uint32 + // Send window for fast retransmit + FastRtxWnd uint32 + // Step of congestion window increase at Congestion Avoidance + CwndCAStep uint32 } // Server accepts a SCTP stream over a conn @@ -325,6 +336,9 @@ func createAssociation(config Config) *Association { netConn: config.NetConn, maxReceiveBufferSize: maxReceiveBufferSize, maxMessageSize: maxMessageSize, + minCwnd: config.MinCwnd, + fastRtxWnd: config.FastRtxWnd, + cwndCAStep: config.CwndCAStep, // These two max values have us not need to follow // 5.1.1 where this peer may be incapable of supporting @@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt if a.willRetransmitFast { a.willRetransmitFast = false - toFastRetrans := []chunk{} + toFastRetrans := []*chunkPayloadData{} fastRetransSize := commonHeaderSize + fastRetransWnd := a.MTU() + if fastRetransWnd < a.fastRtxWnd { + fastRetransWnd = a.fastRtxWnd + } for i := 0; ; i++ { c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) if !ok { @@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt // packet. dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) - if a.MTU() < fastRetransSize+dataChunkSize { + if fastRetransWnd < fastRetransSize+dataChunkSize { break } @@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt } if len(toFastRetrans) > 0 { - raw, err := a.marshalPacket(a.createPacket(toFastRetrans)) - if err != nil { - a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) - } else { + for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) + continue + } rawPackets = append(rawPackets, raw) } } @@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 { } func (a *Association) setCWND(cwnd uint32) { + if cwnd < a.minCwnd { + cwnd = a.minCwnd + } atomic.StoreUint32(&a.cwnd, cwnd) } @@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // reset partial_bytes_acked to (partial_bytes_acked - cwnd). if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 { a.partialBytesAcked -= a.CWND() - a.setCWND(a.CWND() + a.MTU()) + step := a.MTU() + if step < a.cwndCAStep { + step = a.cwndCAStep + } + a.setCWND(a.CWND() + step) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", a.name, a.CWND(), a.ssthresh, totalBytesAcked) }