Skip to content

Commit

Permalink
Fix client session disposal in server.
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 committed Dec 9, 2018
1 parent ab24368 commit 28b2562
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 132 deletions.
267 changes: 140 additions & 127 deletions Source/MQTTnet/Server/MqttClientSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,129 @@ public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapte
return _workerTask;
}

public void Stop(MqttClientDisconnectType type)
{
Stop(type, false);
}

public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket)
{
if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket));

var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel);
if (!checkSubscriptionsResult.IsSubscribed)
{
return;
}

publishPacket = new MqttPublishPacket
{
Topic = publishPacket.Topic,
Payload = publishPacket.Payload,
QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel,
Retain = publishPacket.Retain,
Dup = false
};

if (publishPacket.QualityOfServiceLevel > 0)
{
publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
}

if (_options.ClientMessageQueueInterceptor != null)
{
var context = new MqttClientMessageQueueInterceptorContext(
senderClientSession?.ClientId,
ClientId,
publishPacket.ToApplicationMessage());

_options.ClientMessageQueueInterceptor?.Invoke(context);

if (!context.AcceptEnqueue || context.ApplicationMessage == null)
{
return;
}

publishPacket.Topic = context.ApplicationMessage.Topic;
publishPacket.Payload = context.ApplicationMessage.Payload;
publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel;
}

_pendingPacketsQueue.Enqueue(publishPacket);
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

_subscriptionsManager.Subscribe(new MqttSubscribePacket
{
TopicFilters = topicFilters
});

EnqueueSubscribedRetainedMessages(topicFilters);
return Task.FromResult(0);
}

public Task UnsubscribeAsync(IList<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

_subscriptionsManager.Unsubscribe(new MqttUnsubscribePacket
{
TopicFilters = topicFilters
});

return Task.FromResult(0);
}

public void ClearPendingApplicationMessages()
{
_pendingPacketsQueue.Clear();
}

public void Dispose()
{
_pendingPacketsQueue?.Dispose();

_cancellationTokenSource?.Cancel ();
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}

private void Stop(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_cancellationTokenSource?.Cancel(false);

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;

if (_willMessage != null && !_wasCleanDisconnect)
{
_sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket());
}

_willMessage = null;

if (!isInsideSession)
{
_workerTask?.GetAwaiter().GetResult();
}
}
finally
{
_logger.Info("Client '{0}': Disconnected (clean={1}).", ClientId, _wasCleanDisconnect);
_eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect);
}
}

private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter)
{
if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket));
Expand All @@ -98,7 +221,11 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne

//workaround for https://github.com/dotnet/corefx/issues/24430
#pragma warning disable 4014
_cleanupHandle = _cancellationTokenSource.Token.Register(() => TryDisposeAdapterAsync(adapter));
_cleanupHandle = _cancellationTokenSource.Token.Register(async () =>
{
await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
TryDisposeAdapter(adapter);
});
#pragma warning restore 4014
//end workaround

Expand Down Expand Up @@ -149,7 +276,9 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne
_adapterEndpoint = null;
_adapterProtocolVersion = null;

await TryDisposeAdapterAsync(adapter).ConfigureAwait(false);
// Uncomment as soon as the workaround above is no longer needed.
//await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false);
//TryDisposeAdapter(adapter);

_cleanupHandle?.Dispose();
_cleanupHandle = null;
Expand All @@ -160,7 +289,7 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne
}
}

private async Task TryDisposeAdapterAsync(IMqttChannelAdapter adapter)
private void TryDisposeAdapter(IMqttChannelAdapter adapter)
{
if (adapter == null)
{
Expand All @@ -172,145 +301,29 @@ private async Task TryDisposeAdapterAsync(IMqttChannelAdapter adapter)
adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted;
adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted;

await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
adapter.Dispose();
}
catch (Exception exception)
{
_logger.Error(exception, exception.Message);
}
finally
{
try
{
adapter.Dispose();
}
catch
{
}
}
}

public void Stop(MqttClientDisconnectType type)
{
Stop(type, false);
}

private void Stop(MqttClientDisconnectType type, bool isInsideSession)
{
try
{
var cts = _cancellationTokenSource;
if (cts == null || cts.IsCancellationRequested)
{
return;
}

_cancellationTokenSource?.Cancel(false);

_wasCleanDisconnect = type == MqttClientDisconnectType.Clean;

if (_willMessage != null && !_wasCleanDisconnect)
{
_sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket());
}

_willMessage = null;

if (!isInsideSession)
{
_workerTask?.GetAwaiter().GetResult();
}
}
finally
{
_logger.Info("Client '{0}': Disconnected (clean={1}).", ClientId, _wasCleanDisconnect);
_eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect);
_logger.Error(exception, "Error while disposing channel adapter.");
}
}

public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket)
private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter)
{
if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket));

var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel);
if (!checkSubscriptionsResult.IsSubscribed)
if (adapter == null)
{
return;
}

publishPacket = new MqttPublishPacket
{
Topic = publishPacket.Topic,
Payload = publishPacket.Payload,
QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel,
Retain = publishPacket.Retain,
Dup = false
};

if (publishPacket.QualityOfServiceLevel > 0)
try
{
publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier();
await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
}

if (_options.ClientMessageQueueInterceptor != null)
catch (Exception exception)
{
var context = new MqttClientMessageQueueInterceptorContext(
senderClientSession?.ClientId,
ClientId,
publishPacket.ToApplicationMessage());

_options.ClientMessageQueueInterceptor?.Invoke(context);

if (!context.AcceptEnqueue || context.ApplicationMessage == null)
{
return;
}

publishPacket.Topic = context.ApplicationMessage.Topic;
publishPacket.Payload = context.ApplicationMessage.Payload;
publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel;
_logger.Error(exception, "Error while disconnecting channel adapter.");
}

_pendingPacketsQueue.Enqueue(publishPacket);
}

public Task SubscribeAsync(IList<TopicFilter> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

_subscriptionsManager.Subscribe(new MqttSubscribePacket
{
TopicFilters = topicFilters
});

EnqueueSubscribedRetainedMessages(topicFilters);
return Task.FromResult(0);
}

public Task UnsubscribeAsync(IList<string> topicFilters)
{
if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

_subscriptionsManager.Unsubscribe(new MqttUnsubscribePacket
{
TopicFilters = topicFilters
});

return Task.FromResult(0);
}

public void ClearPendingApplicationMessages()
{
_pendingPacketsQueue.Clear();
}

public void Dispose()
{
_pendingPacketsQueue?.Dispose();

_cancellationTokenSource?.Cancel ();
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}

private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken)
Expand Down
50 changes: 45 additions & 5 deletions Tests/MQTTnet.Core.Tests/MqttServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,6 @@ public async Task MqttServer_ShutdownDisconnectsClientsGracefully()
[TestMethod]
public async Task MqttServer_HandleCleanDisconnect()
{
MqttNetGlobalLogger.LogMessagePublished += (_, e) =>
{
System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}");
};

var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger());
var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger());

Expand Down Expand Up @@ -338,6 +333,51 @@ public async Task MqttServer_HandleCleanDisconnect()
Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled);
}

[TestMethod]
public async Task MqttServer_Client_Disconnect_Without_Errors()
{
var errors = 0;

MqttNetGlobalLogger.LogMessagePublished += (_, e) =>
{
System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}");

if (e.TraceMessage.Level == MqttNetLogLevel.Error)
{
errors++;
}
};

bool clientWasConnected;

var server = new MqttFactory().CreateMqttServer();
try
{
var options = new MqttServerOptionsBuilder().Build();
await server.StartAsync(options);

var client = new MqttFactory().CreateMqttClient();
var clientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("localhost")
.Build();

await client.ConnectAsync(clientOptions);

clientWasConnected = true;

await client.DisconnectAsync();

await Task.Delay(500);
}
finally
{
await server.StopAsync();
}

Assert.IsTrue(clientWasConnected);
Assert.AreEqual(0, errors);
}

[TestMethod]
public async Task MqttServer_LotsOfRetainedMessages()
{
Expand Down

0 comments on commit 28b2562

Please sign in to comment.