improve logging of rate limits (#18907)
* improve logging of rate limits

* fix rate limiting of logs
arvidn authored Nov 21, 2024
1 parent 29da4a3 commit e17e657
Showing 4 changed files with 99 additions and 62 deletions.
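The core API change in this commit is that RateLimiter.process_msg_and_check now returns Optional[str] instead of bool: None means the message is within limits, while a string names the limit that was exceeded so the caller can log it. Below is a minimal caller-side sketch of consuming the new return value; the check_and_log wrapper and logging setup are illustrative assumptions, not code from this diff.

# Hypothetical caller sketch (not part of this diff): shows how the new
# Optional[str] result of RateLimiter.process_msg_and_check might be logged.
import logging
from typing import Optional

log = logging.getLogger(__name__)

def check_and_log(limiter, message, our_capabilities, peer_capabilities) -> bool:
    # None means the message is within limits; a string names the limit hit.
    reason: Optional[str] = limiter.process_msg_and_check(message, our_capabilities, peer_capabilities)
    if reason is None:
        return True
    # The returned reason (e.g. "message size: ... > ...") goes straight into
    # the log, which is what "improve logging of rate limits" refers to.
    log.warning("Rate limit exceeded for message type %s: %s", message.type, reason)
    return False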
5 changes: 3 additions & 2 deletions chia/_tests/core/server/test_dos.py
@@ -3,6 +3,7 @@
import asyncio
import logging
import time
from typing import Optional

import pytest
from aiohttp import ClientSession, ClientTimeout, WSCloseCode, WSMessage, WSMsgType, WSServerHandshakeError
@@ -43,8 +44,8 @@ async def get_block_path(full_node: FullNodeAPI):


class FakeRateLimiter:
def process_msg_and_check(self, msg, capa, capb):
return True
def process_msg_and_check(self, msg, capa, capb) -> Optional[str]:
return None


class TestDos:
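Under the new contract a test double allows a message by returning None rather than True, as the updated FakeRateLimiter above shows. For comparison, a blocking counterpart (assumed, not part of this diff) would return a reason string in the same shape the real limiter produces:

from typing import Optional

class FakeBlockingRateLimiter:
    # Assumed counterpart to FakeRateLimiter: always reports a limit hit.
    def process_msg_and_check(self, msg, capa, capb) -> Optional[str]:
        return "message count: 1 > 0 (scale factor: 1.0)"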
76 changes: 38 additions & 38 deletions chia/_tests/core/server/test_rate_limits.py
@@ -34,25 +34,25 @@ async def test_too_many_messages(self):
r = RateLimiter(incoming=True)
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
for i in range(4999):
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(4999):
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

# Non-tx message
r = RateLimiter(incoming=True)
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
for i in range(200):
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(200):
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -63,40 +63,40 @@ async def test_large_message(self):
large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))

r = RateLimiter(incoming=True)
assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(large_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_tx_message, rl_v2, rl_v2) is None
assert r.process_msg_and_check(large_tx_message, rl_v2, rl_v2) is not None

small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))
r = RateLimiter(incoming=True)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2)
assert not r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2)
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None
assert r.process_msg_and_check(small_vdf_message, rl_v2, rl_v2) is None
assert r.process_msg_and_check(large_vdf_message, rl_v2, rl_v2) is not None

@pytest.mark.anyio
async def test_too_much_data(self):
# Too much data
r = RateLimiter(incoming=True)
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
for i in range(40):
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(300):
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

r = RateLimiter(incoming=True)
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
for i in range(10):
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)
assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(40):
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -109,15 +109,15 @@ async def test_non_tx_aggregate_limits(self):
message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 64))

for i in range(500):
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)
assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None

for i in range(500):
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)
assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(500):
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -127,12 +127,12 @@ async def test_non_tx_aggregate_limits(self):
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))

for i in range(2):
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)
assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(2):
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -141,56 +141,56 @@ async def test_periodic_reset(self):
r = RateLimiter(True, 5)
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
for i in range(10):
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(300):
response = r.process_msg_and_check(tx_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect
assert not r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is not None
await asyncio.sleep(6)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(tx_message, rl_v2, rl_v2) is None

# Counts reset also
r = RateLimiter(True, 5)
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
for i in range(4999):
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(4999):
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect
await asyncio.sleep(6)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_tx_message, rl_v2, rl_v2) is None

@pytest.mark.anyio
async def test_percentage_limits(self):
r = RateLimiter(True, 60, 40)
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
for i in range(50):
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_peak_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(50):
response = r.process_msg_and_check(new_peak_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

r = RateLimiter(True, 60, 40)
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
for i in range(5):
assert r.process_msg_and_check(block_message, rl_v2, rl_v2)
assert r.process_msg_and_check(block_message, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(5):
response = r.process_msg_and_check(block_message, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -201,14 +201,14 @@ async def test_percentage_limits(self):
message_3 = make_msg(ProtocolMessageTypes.plot_sync_start, bytes([1] * 32))

for i in range(180):
assert r.process_msg_and_check(message_1, rl_v2, rl_v2)
assert r.process_msg_and_check(message_1, rl_v2, rl_v2) is None
for i in range(180):
assert r.process_msg_and_check(message_2, rl_v2, rl_v2)
assert r.process_msg_and_check(message_2, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(100):
response = r.process_msg_and_check(message_3, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -218,12 +218,12 @@ async def test_percentage_limits(self):
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))

for i in range(2):
assert r.process_msg_and_check(message_4, rl_v2, rl_v2)
assert r.process_msg_and_check(message_4, rl_v2, rl_v2) is None

saw_disconnect = False
for i in range(2):
response = r.process_msg_and_check(message_5, rl_v2, rl_v2)
if not response:
if response is not None:
saw_disconnect = True
assert saw_disconnect

@@ -237,7 +237,7 @@ async def test_too_many_outgoing_messages(self):
passed = 0
blocked = 0
for i in range(non_tx_freq):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
passed += 1
else:
blocked += 1
@@ -248,7 +248,7 @@ async def test_too_many_outgoing_messages(self):
# ensure that *another* message type is not blocked because of this

new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is None

@pytest.mark.anyio
async def test_too_many_incoming_messages(self):
Expand All @@ -260,7 +260,7 @@ async def test_too_many_incoming_messages(self):
passed = 0
blocked = 0
for i in range(non_tx_freq):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2):
if r.process_msg_and_check(new_peers_message, rl_v2, rl_v2) is None:
passed += 1
else:
blocked += 1
@@ -271,7 +271,7 @@ async def test_too_many_incoming_messages(self):
# ensure that other message types *are* blocked because of this

new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
assert not r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2)
assert r.process_msg_and_check(new_signatures_message, rl_v2, rl_v2) is not None

@pytest.mark.parametrize(
"node_with_params",
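The test updates above are mechanical: assertions on the allowed path gain "is None", and the saw_disconnect checks flip from "if not response:" to "if response is not None:", since a blocked message now yields a truthy reason string instead of False. A condensed sketch of the new idiom, assuming the same imports, limiter r, message objects, and rl_v2 capability list as the tests above:

# Sketch only: reuses names (r, new_tx_message, rl_v2) defined in the tests above.
response = r.process_msg_and_check(new_tx_message, rl_v2, rl_v2)
if response is not None:
    # response is a human-readable reason, roughly of the form
    # "message count: 5001 > 5000.0 (scale factor: 1.0)"
    saw_disconnect = True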
49 changes: 39 additions & 10 deletions chia/server/rate_limits.py
@@ -4,6 +4,7 @@
import logging
import time
from collections import Counter
from typing import Optional

from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability
@@ -43,9 +44,11 @@ def __init__(self, incoming: bool, reset_seconds: int = 60, percentage_of_limit:

def process_msg_and_check(
self, message: Message, our_capabilities: list[Capability], peer_capabilities: list[Capability]
) -> bool:
) -> Optional[str]:
"""
Returns True if message can be processed successfully, false if a rate limit is passed.
Returns a string indicating which limit was hit if a rate limit is
exceeded, and the message should be blocked. Returns None if the limit was not
hit and the message is good to be sent or received.
"""

current_minute = int(time.time() // self.reset_seconds)
@@ -59,7 +62,7 @@ def process_msg_and_check(
message_type = ProtocolMessageTypes(message.type)
except Exception as e:
log.warning(f"Invalid message: {message.type}, {e}")
return True
return None

new_message_counts: int = self.message_counts[message_type] + 1
new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
@@ -81,25 +84,51 @@ def process_msg_and_check(
new_non_tx_count = self.non_tx_message_counts + 1
new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
if new_non_tx_count > non_tx_freq * proportion_of_limit:
return False
return " ".join(
[
f"non-tx count: {new_non_tx_count}",
f"> {non_tx_freq * proportion_of_limit}",
f"(scale factor: {proportion_of_limit})",
]
)
if new_non_tx_size > non_tx_max_total_size * proportion_of_limit:
return False
return " ".join(
[
f"non-tx size: {new_non_tx_size}",
f"> {non_tx_max_total_size * proportion_of_limit}",
f"(scale factor: {proportion_of_limit})",
]
)
else:
log.warning(f"Message type {message_type} not found in rate limits")
log.warning(
f"Message type {message_type} not found in rate limits (scale factor: {proportion_of_limit})",
)

if limits.max_total_size is None:
limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
assert limits.max_total_size is not None

if new_message_counts > limits.frequency * proportion_of_limit:
return False
return " ".join(
[
f"message count: {new_message_counts}"
f"> {limits.frequency * proportion_of_limit}"
f"(scale factor: {proportion_of_limit})"
]
)
if len(message.data) > limits.max_size:
return False
return f"message size: {len(message.data)} > {limits.max_size}"
if new_cumulative_size > limits.max_total_size * proportion_of_limit:
return False
return " ".join(
[
f"cumulative size: {new_cumulative_size}",
f"> {limits.max_total_size * proportion_of_limit}",
f"(scale factor: {proportion_of_limit})",
]
)

ret = True
return True
return None
finally:
if self.incoming or ret:
# now that we determined that it's OK to send the message, commit the
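The returned strings are assembled from the same quantities the limiter already compares, so a blocked message now reports the observed count or size, the effective limit, and the scale factor. A standalone sketch of that arithmetic and the resulting message follows; the concrete limit value and the definition of proportion_of_limit as percentage_of_limit / 100 are assumptions for illustration.

# Illustrative reconstruction of the reason-string format used above.
non_tx_freq = 1000                    # assumed aggregate non-tx frequency limit
percentage_of_limit = 100             # assumed default
proportion_of_limit = percentage_of_limit / 100   # assumed scaling, i.e. 1.0
new_non_tx_count = 1001

if new_non_tx_count > non_tx_freq * proportion_of_limit:
    reason = " ".join(
        [
            f"non-tx count: {new_non_tx_count}",
            f"> {non_tx_freq * proportion_of_limit}",
            f"(scale factor: {proportion_of_limit})",
        ]
    )
    # reason == "non-tx count: 1001 > 1000.0 (scale factor: 1.0)"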