From 19e73c455955c46faec24284828e66ab7d6f176a Mon Sep 17 00:00:00 2001
From: Eric Daniels <eric@erdaniels.com>
Date: Wed, 28 Feb 2024 10:19:29 -0500
Subject: [PATCH] Add SACKs sent and packets stats

---
 association.go       | 21 ++++++++++++-------
 association_stats.go | 50 +++++++++++++++++++++++++++++++++++---------
 association_test.go  | 12 +++++------
 3 files changed, 60 insertions(+), 23 deletions(-)

diff --git a/association.go b/association.go
index 9b2aad55..ce049d36 100644
--- a/association.go
+++ b/association.go
@@ -476,8 +476,11 @@ func (a *Association) Close() error {
 	<-a.readLoopCloseCh
 
 	a.log.Debugf("[%s] association closed", a.name)
+	a.log.Debugf("[%s] stats nPackets (in) : %d", a.name, a.stats.getNumPacketsReceived())
+	a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent())
 	a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
-	a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs())
+	a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived())
+	a.log.Debugf("[%s] stats nSACKs (out) : %d\n", a.name, a.stats.getNumSACKsSent())
 	a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
 	a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
 	a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
@@ -547,7 +550,7 @@ func (a *Association) readLoop() {
 
 		a.log.Debugf("[%s] association closed", a.name)
 		a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs())
-		a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs())
+		a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived())
 		a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts())
 		a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts())
 		a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans())
@@ -596,6 +599,7 @@ loop:
 				break loop
 			}
 			atomic.AddUint64(&a.bytesSent, uint64(len(raw)))
+			a.stats.incPacketsSent()
 		}
 
 		if !ok {
@@ -670,7 +674,7 @@ func (a *Association) handleInbound(raw []byte) error {
 		return nil
 	}
 
-	a.handleChunkStart()
+	a.handleChunksStart()
 
 	for _, c := range p.chunks {
 		if err := a.handleChunk(p, c); err != nil {
@@ -678,7 +682,7 @@ func (a *Association) handleInbound(raw []byte) error {
 		}
 	}
 
-	a.handleChunkEnd()
+	a.handleChunksEnd()
 
 	return nil
 }
@@ -825,6 +829,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
 	if a.ackState == ackStateImmediate {
 		a.ackState = ackStateIdle
 		sack := a.createSelectiveAckChunk()
+		a.stats.incSACKsSent()
 		a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
 		raw, err := a.marshalPacket(a.createPacket([]chunk{sack}))
 		if err != nil {
@@ -1718,7 +1723,7 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error {
 		return nil
 	}
 
-	a.stats.incSACKs()
+	a.stats.incSACKsReceived()
 
 	if sna32GT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) {
 		// RFC 4960 sec 6.2.1.  Processing a Received SACK
@@ -2377,15 +2382,17 @@ func pack(p *packet) []*packet {
 	return []*packet{p}
 }
 
-func (a *Association) handleChunkStart() {
+func (a *Association) handleChunksStart() {
 	a.lock.Lock()
 	defer a.lock.Unlock()
 
+	a.stats.incPacketsReceived()
+
 	a.delayedAckTriggered = false
 	a.immediateAckTriggered = false
 }
 
-func (a *Association) handleChunkEnd() {
+func (a *Association) handleChunksEnd() {
 	a.lock.Lock()
 	defer a.lock.Unlock()
 
diff --git a/association_stats.go b/association_stats.go
index 60883c47..0e4e581b 100644
--- a/association_stats.go
+++ b/association_stats.go
@@ -8,11 +8,30 @@ import (
 )
 
 type associationStats struct {
-	nDATAs       uint64
-	nSACKs       uint64
-	nT3Timeouts  uint64
-	nAckTimeouts uint64
-	nFastRetrans uint64
+	nPacketsReceived uint64
+	nPacketsSent     uint64
+	nDATAs           uint64
+	nSACKsReceived   uint64
+	nSACKsSent       uint64
+	nT3Timeouts      uint64
+	nAckTimeouts     uint64
+	nFastRetrans     uint64
+}
+
+func (s *associationStats) incPacketsReceived() {
+	atomic.AddUint64(&s.nPacketsReceived, 1)
+}
+
+func (s *associationStats) getNumPacketsReceived() uint64 {
+	return atomic.LoadUint64(&s.nPacketsReceived)
+}
+
+func (s *associationStats) incPacketsSent() {
+	atomic.AddUint64(&s.nPacketsSent, 1)
+}
+
+func (s *associationStats) getNumPacketsSent() uint64 {
+	return atomic.LoadUint64(&s.nPacketsSent)
 }
 
 func (s *associationStats) incDATAs() {
@@ -23,12 +42,20 @@ func (s *associationStats) getNumDATAs() uint64 {
 	return atomic.LoadUint64(&s.nDATAs)
 }
 
-func (s *associationStats) incSACKs() {
-	atomic.AddUint64(&s.nSACKs, 1)
+func (s *associationStats) incSACKsReceived() {
+	atomic.AddUint64(&s.nSACKsReceived, 1)
+}
+
+func (s *associationStats) getNumSACKsReceived() uint64 {
+	return atomic.LoadUint64(&s.nSACKsReceived)
+}
+
+func (s *associationStats) incSACKsSent() {
+	atomic.AddUint64(&s.nSACKsSent, 1)
 }
 
-func (s *associationStats) getNumSACKs() uint64 {
-	return atomic.LoadUint64(&s.nSACKs)
+func (s *associationStats) getNumSACKsSent() uint64 {
+	return atomic.LoadUint64(&s.nSACKsSent)
 }
 
 func (s *associationStats) incT3Timeouts() {
@@ -56,8 +83,11 @@ func (s *associationStats) getNumFastRetrans() uint64 {
 }
 
 func (s *associationStats) reset() {
+	atomic.StoreUint64(&s.nPacketsReceived, 0)
+	atomic.StoreUint64(&s.nPacketsSent, 0)
 	atomic.StoreUint64(&s.nDATAs, 0)
-	atomic.StoreUint64(&s.nSACKs, 0)
+	atomic.StoreUint64(&s.nSACKsReceived, 0)
+	atomic.StoreUint64(&s.nSACKsSent, 0)
 	atomic.StoreUint64(&s.nT3Timeouts, 0)
 	atomic.StoreUint64(&s.nAckTimeouts, 0)
 	atomic.StoreUint64(&s.nFastRetrans, 0)
diff --git a/association_test.go b/association_test.go
index 870968db..3361da27 100644
--- a/association_test.go
+++ b/association_test.go
@@ -1823,7 +1823,7 @@ func TestAssocCongestionControl(t *testing.T) {
 		assert.False(t, inFastRecovery, "should not be in fast-recovery")
 
 		t.Logf("nDATAs      : %d\n", a1.stats.getNumDATAs())
-		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKs())
+		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKsReceived())
 		t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts())
 		t.Logf("nFastRetrans: %d\n", a0.stats.getNumFastRetrans())
 
@@ -1907,11 +1907,11 @@ func TestAssocCongestionControl(t *testing.T) {
 		assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty")
 
 		t.Logf("nDATAs      : %d\n", a1.stats.getNumDATAs())
-		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKs())
+		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKsReceived())
 		t.Logf("nT3Timeouts : %d\n", a0.stats.getNumT3Timeouts())
 
 		assert.Equal(t, uint64(nPacketsToSend), a1.stats.getNumDATAs(), "packet count mismatch")
-		assert.True(t, a0.stats.getNumSACKs() <= nPacketsToSend/2, "too many sacks")
+		assert.True(t, a0.stats.getNumSACKsReceived() <= nPacketsToSend/2, "too many sacks")
 		assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit")
 
 		closeAssociationPair(br, a0, a1)
@@ -2002,7 +2002,7 @@ func TestAssocCongestionControl(t *testing.T) {
 		assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty")
 
 		t.Logf("nDATAs      : %d\n", a1.stats.getNumDATAs())
-		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKs())
+		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKsReceived())
 		t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts())
 
 		closeAssociationPair(br, a0, a1)
@@ -2081,11 +2081,11 @@ func TestAssocDelayedAck(t *testing.T) {
 		assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty")
 
 		t.Logf("nDATAs      : %d\n", a1.stats.getNumDATAs())
-		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKs())
+		t.Logf("nSACKs      : %d\n", a0.stats.getNumSACKsReceived())
 		t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts())
 
 		assert.Equal(t, uint64(1), a1.stats.getNumDATAs(), "DATA chunk count mismatch")
-		assert.Equal(t, a0.stats.getNumSACKs(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks")
+		assert.Equal(t, a0.stats.getNumSACKsReceived(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks")
 		assert.Equal(t, uint64(1), a1.stats.getNumAckTimeouts(), "ackTimeout count mismatch")
 		assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit")