Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limits (#60) #61

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Netstr/Extensions/MessagingExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.EntityFrameworkCore;
using Netstr.Messaging;
using Netstr.Messaging.Events;
using Netstr.Messaging.Events.Handlers;
using Netstr.Messaging.Events.Handlers.Replaceable;
using Netstr.Messaging.Events.Validators;
Expand Down Expand Up @@ -28,11 +29,13 @@ public static IServiceCollection AddMessaging(this IServiceCollection services)
// event
services.AddSingleton<IEventDispatcher, EventDispatcher>();
services.AddSingleton<IEventHandler, DeleteEventHandler>();
services.AddSingleton<IEventHandler, RegularEventHandler>();
services.AddSingleton<IEventHandler, ReplaceableEventHandler>();
services.AddSingleton<IEventHandler, EphemeralEventHandler>();
services.AddSingleton<IEventHandler, AddressableEventHandler>();

// RegularEventHandler needs to go last
services.AddSingleton<IEventHandler, RegularEventHandler>();

services.AddEventValidators();
services.AddSubscriptionValidators();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Netstr.Messaging.Events.Handlers;
using Netstr.Messaging.Models;

namespace Netstr.Messaging
namespace Netstr.Messaging.Events
{
/// <summary>
/// Dispatches EVENT message to someone who can handle it.
Expand Down
3 changes: 2 additions & 1 deletion src/Netstr/Messaging/Events/Handlers/RegularEventHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public RegularEventHandler(
this.db = db;
}

public override bool CanHandleEvent(Event e) => (e.IsRegular() || e.IsUnknown()) && !e.IsDelete();
// this event handler also serves as a fallback for all unknown events
public override bool CanHandleEvent(Event e) => true;

protected override async Task HandleEventCoreAsync(IWebSocketAdapter sender, Event e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public EventTagsValidator(IOptions<LimitsOptions> limits)

public string? Validate(Event e, ClientContext context)
{
if (e.Tags.Length > this.limits.Value.MaxEventTags)
if (this.limits.Value.MaxEventTags > 0 && e.Tags.Length > this.limits.Value.MaxEventTags)
{
return Messages.InvalidTooManyTags;
}
Expand Down
5 changes: 3 additions & 2 deletions src/Netstr/Messaging/MessageHandlers/CountMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public CountMessageHandler(
IDbContextFactory<NetstrDbContext> db,
IEnumerable<ISubscriptionRequestValidator> validators,
IOptions<LimitsOptions> limits,
IOptions<AuthOptions> auth)
: base(validators, limits, auth)
IOptions<AuthOptions> auth,
ILogger<CountMessageHandler> logger)
: base(validators, limits, auth, logger)
{
this.db = db;
}
Expand Down
23 changes: 22 additions & 1 deletion src/Netstr/Messaging/MessageHandlers/EventMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Netstr.Messaging.Models;
using Netstr.Options;
using System.Text.Json;
using System.Threading.RateLimiting;

namespace Netstr.Messaging.MessageHandlers
{
Expand All @@ -16,17 +17,28 @@ public class EventMessageHandler : IMessageHandler
private readonly IEventDispatcher eventDispatcher;
private readonly IEnumerable<IEventValidator> validators;
private readonly IOptions<AuthOptions> auth;
private readonly PartitionedRateLimiter<string> rateLimiter;

public EventMessageHandler(
ILogger<EventMessageHandler> logger,
IEventDispatcher eventDispatcher,
IEnumerable<IEventValidator> validators,
IOptions<AuthOptions> auth)
IOptions<AuthOptions> auth,
IOptions<LimitsOptions> limits
)
{
this.logger = logger;
this.eventDispatcher = eventDispatcher;
this.validators = validators;
this.auth = auth;
this.rateLimiter = PartitionedRateLimiter.Create<string, string>(
x => RateLimitPartition.GetSlidingWindowLimiter(x, _ => new SlidingWindowRateLimiterOptions
{
AutoReplenishment = true,
PermitLimit = limits.Value.MaxEventsPerMinute > 0 ? limits.Value.MaxEventsPerMinute : int.MaxValue,
SegmentsPerWindow = 6,
Window = TimeSpan.FromMinutes(1)
}));
}

public bool CanHandleMessage(string type) => type == MessageType.Event;
Expand All @@ -41,6 +53,15 @@ public async Task HandleMessageAsync(IWebSocketAdapter sender, JsonDocument[] pa
throw new MessageProcessingException(Messages.ErrorProcessingEvent);
}

using var lease = this.rateLimiter.AttemptAcquire(sender.Context.IpAddress);

if (!lease.IsAcquired)
{
this.logger.LogInformation($"User {sender.Context.IpAddress} is rate limited");
await sender.SendNotOkAsync(e.Id, Messages.RateLimited);
return;
}

var auth = this.auth.Value.Mode;

if (!sender.Context.IsAuthenticated() && (auth == AuthMode.Always || auth == AuthMode.Publishing))
Expand Down
24 changes: 23 additions & 1 deletion src/Netstr/Messaging/MessageHandlers/FilterMessageHandlerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
using Netstr.Messaging.Subscriptions;
using Netstr.Messaging.Subscriptions.Validators;
using Netstr.Options;
using System.Reflection;
using System.Text.Json;
using System.Threading.RateLimiting;

namespace Netstr.Messaging.MessageHandlers
{
Expand All @@ -18,15 +20,27 @@ public abstract class FilterMessageHandlerBase : IMessageHandler
protected readonly IEnumerable<ISubscriptionRequestValidator> validators;
protected readonly IOptions<LimitsOptions> limits;
protected readonly IOptions<AuthOptions> auth;
protected readonly ILogger<FilterMessageHandlerBase> logger;
protected readonly PartitionedRateLimiter<string> rateLimiter;

protected FilterMessageHandlerBase(
IEnumerable<ISubscriptionRequestValidator> validators,
IOptions<LimitsOptions> limits,
IOptions<AuthOptions> auth)
IOptions<AuthOptions> auth,
ILogger<FilterMessageHandlerBase> logger)
{
this.validators = validators;
this.limits = limits;
this.auth = auth;
this.logger = logger;
this.rateLimiter = PartitionedRateLimiter.Create<string, string>(
x => RateLimitPartition.GetSlidingWindowLimiter(x, _ => new SlidingWindowRateLimiterOptions
{
AutoReplenishment = true,
PermitLimit = limits.Value.MaxSubscriptionsPerMinute > 0 ? limits.Value.MaxSubscriptionsPerMinute : int.MaxValue,
SegmentsPerWindow = 6,
Window = TimeSpan.FromMinutes(1)
}));
}

protected abstract string AcceptedMessageType { get; }
Expand All @@ -41,6 +55,14 @@ public async Task HandleMessageAsync(IWebSocketAdapter adapter, JsonDocument[] p
}

var id = parameters[1].DeserializeRequired<string>();
using var lease = this.rateLimiter.AttemptAcquire(adapter.Context.IpAddress);

if (!lease.IsAcquired)
{
this.logger.LogInformation($"User {adapter.Context.IpAddress} is rate limited");
await adapter.SendClosedAsync(id, Messages.RateLimited);
return;
}

if (this.auth.Value.Mode == AuthMode.Always && !adapter.Context.IsAuthenticated())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public SubscribeMessageHandler(
IDbContextFactory<NetstrDbContext> db,
IEnumerable<ISubscriptionRequestValidator> validators,
IOptions<LimitsOptions> limits,
IOptions<AuthOptions> auth)
: base(validators, limits, auth)
IOptions<AuthOptions> auth,
ILogger<SubscribeMessageHandler> logger)
: base(validators, limits, auth, logger)
{
this.db = db;
}
Expand Down
1 change: 1 addition & 0 deletions src/Netstr/Messaging/Messages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public static class Messages
public const string PowNotEnough = "pow: difficulty {0} is less than {1}";
public const string PowNoMatch = "pow: difficulty {0} doesn't match target of {1}";
public const string UnsupportedFilter = "unsupported: filter contains unknown elements";
public const string RateLimited = "rate-limited: slow down there chief";

public const string CannotParseMessage = "unable to parse the message";
public const string CannotProcessMessageType = "unknown message type";
Expand Down
5 changes: 4 additions & 1 deletion src/Netstr/Messaging/Models/ClientContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
/// </summary>
public class ClientContext
{
public ClientContext(string clientId)
public ClientContext(string clientId, string ipAddress)
{
ClientId = clientId;
IpAddress = ipAddress;
Challenge = Guid.NewGuid().ToString();
}

public string ClientId { get; }

public string IpAddress { get; }

public string Challenge { get; }

public string? PublicKey { get; private set; }
Expand Down
6 changes: 4 additions & 2 deletions src/Netstr/Messaging/WebSockets/WebSocketAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public WebSocketAdapter(
IMessageDispatcher dispatcher,
CancellationToken cancellationToken,
WebSocket ws,
IHeaderDictionary headers)
IHeaderDictionary headers,
ConnectionInfo connectionInfo)
{
this.logger = logger;
this.connection = connection;
Expand All @@ -45,8 +46,9 @@ public WebSocketAdapter(
e => logger.LogWarning($"Dropping following events due to capacity limit of {limits.Value.MaxPendingEvents}: {JsonSerializer.Serialize(e.Messages)}"));

var id = headers["sec-websocket-key"].ToString();


Context = new ClientContext(id);
Context = new ClientContext(id, connectionInfo.RemoteIpAddress?.ToString() ?? string.Empty);
}

public ClientContext Context { get; }
Expand Down
5 changes: 3 additions & 2 deletions src/Netstr/Messaging/WebSockets/WebSocketAdapterFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public WebSocketAdapterFactory(
this.lifetime = lifetime;
}

public IWebSocketListenerAdapter CreateAdapter(WebSocket socket, IHeaderDictionary headers)
public IWebSocketListenerAdapter CreateAdapter(WebSocket socket, IHeaderDictionary headers, ConnectionInfo connection)
{
var adapter = new WebSocketAdapter(
this.logger,
Expand All @@ -43,7 +43,8 @@ public IWebSocketListenerAdapter CreateAdapter(WebSocket socket, IHeaderDictiona
this.dispatcher,
this.lifetime.ApplicationStopping,
socket,
headers);
headers,
connection);

this.tracker.Add(adapter);

Expand Down
2 changes: 1 addition & 1 deletion src/Netstr/Middleware/WebSocketsMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public async Task Invoke(HttpContext context)
this.logger.LogInformation($"Accepting websocket connection from {context.Connection.RemoteIpAddress}");

var ws = await context.WebSockets.AcceptWebSocketAsync();
var adapter = this.factory.CreateAdapter(ws, context.Request.Headers);
var adapter = this.factory.CreateAdapter(ws, context.Request.Headers, context.Connection);

await adapter.StartAsync();

Expand Down
2 changes: 2 additions & 0 deletions src/Netstr/Options/LimitsOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ public class LimitsOptions
public int MaxCreatedAtLowerOffset { get; init; }
public int MaxCreatedAtUpperOffset { get; init; }
public int MaxPendingEvents { get; init; }
public int MaxEventsPerMinute { get; init; }
public int MaxSubscriptionsPerMinute { get; init; }
}
}
4 changes: 3 additions & 1 deletion src/Netstr/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
"MaxEventTags": 1000,
"MaxCreatedAtLowerOffset": 31536000,
"MaxCreatedAtUpperOffset": 60,
"MaxPendingEvents": 1024
"MaxPendingEvents": 1024,
"MaxSubscriptionsPerMinute": 60,
"MaxEventsPerMinute": 300
},
"ConnectionStrings": {
"NetstrDatabase": "Host=localhost:5432;Database=Netsrt;Username=Netstr;Password=Netstr"
Expand Down
Binary file modified src/Netstr/wwwroot/favicon.ico
Binary file not shown.
8 changes: 5 additions & 3 deletions test/Netstr.Tests/Events/EventHandlersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Moq;
using Netstr.Data;
using Netstr.Messaging;
using Netstr.Messaging.Events;
using Netstr.Messaging.Events.Handlers;
using Netstr.Messaging.Events.Handlers.Replaceable;
using Netstr.Messaging.Models;
Expand Down Expand Up @@ -57,16 +58,17 @@ public EventHandlersTests()
Mock.Of<IMessageDispatcher>(),
CancellationToken.None,
this.ws.Object,
Mock.Of<IHeaderDictionary>());
Mock.Of<IHeaderDictionary>(),
Mock.Of<ConnectionInfo>());

this.clients = new WebSocketAdapterCollection();

var handlers = new IEventHandler[]
{
new RegularEventHandler(Mock.Of<ILogger<RegularEventHandler>>(), auth, this.clients, this.dbFactoryMock.Object),
new EphemeralEventHandler(Mock.Of<ILogger<EphemeralEventHandler>>(), auth, this.clients),
new ReplaceableEventHandler(Mock.Of<ILogger<ReplaceableEventHandler>>(), auth, this.clients, this.dbFactoryMock.Object),
new AddressableEventHandler(Mock.Of<ILogger<ReplaceableEventHandler>>(), auth, this.clients, this.dbFactoryMock.Object)
new AddressableEventHandler(Mock.Of<ILogger<ReplaceableEventHandler>>(), auth, this.clients, this.dbFactoryMock.Object),
new RegularEventHandler(Mock.Of<ILogger<RegularEventHandler>>(), auth, this.clients, this.dbFactoryMock.Object)
};
this.dispatcher = new EventDispatcher(Mock.Of<ILogger<EventDispatcher>>(), handlers);
_ = Task.Run(this.adapter.StartAsync);
Expand Down
4 changes: 2 additions & 2 deletions test/Netstr.Tests/Events/EventVerificationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void AcceptsValidEvent()
Signature = "44224ca5edd01161f617a7347d4f0b1c9a8ccf7bfb3f70bd74db3d6e26f44aa5318f3d39c93f5769d24fa5e56bd98eed7cd23a114cc3412650678a0280ed94f4"
};

this.validators.ToList().ForEach(x => x.Validate(e, new ClientContext("test")).Should().BeNull());
this.validators.ToList().ForEach(x => x.Validate(e, new ClientContext("test", "ip")).Should().BeNull());
}

[Theory]
Expand Down Expand Up @@ -82,7 +82,7 @@ public void RejectsIfValidationFails(string id, string pubkey, string signature,
Signature = signature
};

var result = this.validators.Select(x => x.Validate(e, new ClientContext("test"))).FirstOrDefault(x => x != null);
var result = this.validators.Select(x => x.Validate(e, new ClientContext("test", "ip"))).FirstOrDefault(x => x != null);

result.Should().Be(error);
}
Expand Down
5 changes: 3 additions & 2 deletions test/Netstr.Tests/MessageDispatcherTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Moq;
using Netstr.Data;
using Netstr.Messaging;
using Netstr.Messaging.Events;
using Netstr.Messaging.MessageHandlers;
using Netstr.Options;

Expand All @@ -21,8 +22,8 @@ public MessageDispatcherTests()

this.handlers =
[
new EventMessageHandler(Mock.Of<ILogger<EventMessageHandler>>(), eventDispatcher.Object, [], Mock.Of<IOptions<AuthOptions>>()),
new SubscribeMessageHandler(Mock.Of<IDbContextFactory<NetstrDbContext>>(), [], Mock.Of<IOptions<LimitsOptions>>(), Mock.Of<IOptions<AuthOptions>>()),
new EventMessageHandler(Mock.Of<ILogger<EventMessageHandler>>(), eventDispatcher.Object, [], Mock.Of<IOptions<AuthOptions>>(), Mock.Of<IOptions<LimitsOptions>>()),
new SubscribeMessageHandler(Mock.Of<IDbContextFactory<NetstrDbContext>>(), [], Mock.Of<IOptions<LimitsOptions>>(), Mock.Of<IOptions<AuthOptions>>(), Mock.Of<ILogger<SubscribeMessageHandler>>()),
new UnsubscribeMessageHandler(),
];

Expand Down
Loading
Loading