diff --git a/Source/MQTTnet/Server/MqttClientSession.cs b/Source/MQTTnet/Server/MqttClientSession.cs index 9e402bebf..7a1d6fec6 100644 --- a/Source/MQTTnet/Server/MqttClientSession.cs +++ b/Source/MQTTnet/Server/MqttClientSession.cs @@ -79,6 +79,129 @@ public Task RunAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapte return _workerTask; } + public void Stop(MqttClientDisconnectType type) + { + Stop(type, false); + } + + public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) + { + if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); + + var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel); + if (!checkSubscriptionsResult.IsSubscribed) + { + return; + } + + publishPacket = new MqttPublishPacket + { + Topic = publishPacket.Topic, + Payload = publishPacket.Payload, + QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel, + Retain = publishPacket.Retain, + Dup = false + }; + + if (publishPacket.QualityOfServiceLevel > 0) + { + publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + } + + if (_options.ClientMessageQueueInterceptor != null) + { + var context = new MqttClientMessageQueueInterceptorContext( + senderClientSession?.ClientId, + ClientId, + publishPacket.ToApplicationMessage()); + + _options.ClientMessageQueueInterceptor?.Invoke(context); + + if (!context.AcceptEnqueue || context.ApplicationMessage == null) + { + return; + } + + publishPacket.Topic = context.ApplicationMessage.Topic; + publishPacket.Payload = context.ApplicationMessage.Payload; + publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; + } + + _pendingPacketsQueue.Enqueue(publishPacket); + } + + public Task SubscribeAsync(IList topicFilters) + { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + _subscriptionsManager.Subscribe(new MqttSubscribePacket + { + TopicFilters = topicFilters + }); + + EnqueueSubscribedRetainedMessages(topicFilters); + return Task.FromResult(0); + } + + public Task UnsubscribeAsync(IList topicFilters) + { + if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); + + _subscriptionsManager.Unsubscribe(new MqttUnsubscribePacket + { + TopicFilters = topicFilters + }); + + return Task.FromResult(0); + } + + public void ClearPendingApplicationMessages() + { + _pendingPacketsQueue.Clear(); + } + + public void Dispose() + { + _pendingPacketsQueue?.Dispose(); + + _cancellationTokenSource?.Cancel (); + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + } + + private void Stop(MqttClientDisconnectType type, bool isInsideSession) + { + try + { + var cts = _cancellationTokenSource; + if (cts == null || cts.IsCancellationRequested) + { + return; + } + + _cancellationTokenSource?.Cancel(false); + + _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; + + if (_willMessage != null && !_wasCleanDisconnect) + { + _sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket()); + } + + _willMessage = null; + + if (!isInsideSession) + { + _workerTask?.GetAwaiter().GetResult(); + } + } + finally + { + _logger.Info("Client '{0}': Disconnected (clean={1}).", ClientId, _wasCleanDisconnect); + _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); + } + } + private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter adapter) { if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket)); @@ -98,7 +221,11 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne //workaround for https://github.com/dotnet/corefx/issues/24430 #pragma warning disable 4014 - _cleanupHandle = _cancellationTokenSource.Token.Register(() => TryDisposeAdapterAsync(adapter)); + _cleanupHandle = _cancellationTokenSource.Token.Register(async () => + { + await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); + TryDisposeAdapter(adapter); + }); #pragma warning restore 4014 //end workaround @@ -149,7 +276,9 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne _adapterEndpoint = null; _adapterProtocolVersion = null; - await TryDisposeAdapterAsync(adapter).ConfigureAwait(false); + // Uncomment as soon as the workaround above is no longer needed. + //await TryDisconnectAdapterAsync(adapter).ConfigureAwait(false); + //TryDisposeAdapter(adapter); _cleanupHandle?.Dispose(); _cleanupHandle = null; @@ -160,7 +289,7 @@ private async Task RunInternalAsync(MqttConnectPacket connectPacket, IMqttChanne } } - private async Task TryDisposeAdapterAsync(IMqttChannelAdapter adapter) + private void TryDisposeAdapter(IMqttChannelAdapter adapter) { if (adapter == null) { @@ -172,145 +301,29 @@ private async Task TryDisposeAdapterAsync(IMqttChannelAdapter adapter) adapter.ReadingPacketStarted -= OnAdapterReadingPacketStarted; adapter.ReadingPacketCompleted -= OnAdapterReadingPacketCompleted; - await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); + adapter.Dispose(); } catch (Exception exception) { - _logger.Error(exception, exception.Message); - } - finally - { - try - { - adapter.Dispose(); - } - catch - { - } - } - } - - public void Stop(MqttClientDisconnectType type) - { - Stop(type, false); - } - - private void Stop(MqttClientDisconnectType type, bool isInsideSession) - { - try - { - var cts = _cancellationTokenSource; - if (cts == null || cts.IsCancellationRequested) - { - return; - } - - _cancellationTokenSource?.Cancel(false); - - _wasCleanDisconnect = type == MqttClientDisconnectType.Clean; - - if (_willMessage != null && !_wasCleanDisconnect) - { - _sessionsManager.EnqueueApplicationMessage(this, _willMessage.ToPublishPacket()); - } - - _willMessage = null; - - if (!isInsideSession) - { - _workerTask?.GetAwaiter().GetResult(); - } - } - finally - { - _logger.Info("Client '{0}': Disconnected (clean={1}).", ClientId, _wasCleanDisconnect); - _eventDispatcher.OnClientDisconnected(ClientId, _wasCleanDisconnect); + _logger.Error(exception, "Error while disposing channel adapter."); } } - public void EnqueueApplicationMessage(MqttClientSession senderClientSession, MqttPublishPacket publishPacket) + private async Task TryDisconnectAdapterAsync(IMqttChannelAdapter adapter) { - if (publishPacket == null) throw new ArgumentNullException(nameof(publishPacket)); - - var checkSubscriptionsResult = _subscriptionsManager.CheckSubscriptions(publishPacket.Topic, publishPacket.QualityOfServiceLevel); - if (!checkSubscriptionsResult.IsSubscribed) + if (adapter == null) { return; } - publishPacket = new MqttPublishPacket - { - Topic = publishPacket.Topic, - Payload = publishPacket.Payload, - QualityOfServiceLevel = checkSubscriptionsResult.QualityOfServiceLevel, - Retain = publishPacket.Retain, - Dup = false - }; - - if (publishPacket.QualityOfServiceLevel > 0) + try { - publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNewPacketIdentifier(); + await adapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false); } - - if (_options.ClientMessageQueueInterceptor != null) + catch (Exception exception) { - var context = new MqttClientMessageQueueInterceptorContext( - senderClientSession?.ClientId, - ClientId, - publishPacket.ToApplicationMessage()); - - _options.ClientMessageQueueInterceptor?.Invoke(context); - - if (!context.AcceptEnqueue || context.ApplicationMessage == null) - { - return; - } - - publishPacket.Topic = context.ApplicationMessage.Topic; - publishPacket.Payload = context.ApplicationMessage.Payload; - publishPacket.QualityOfServiceLevel = context.ApplicationMessage.QualityOfServiceLevel; + _logger.Error(exception, "Error while disconnecting channel adapter."); } - - _pendingPacketsQueue.Enqueue(publishPacket); - } - - public Task SubscribeAsync(IList topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - _subscriptionsManager.Subscribe(new MqttSubscribePacket - { - TopicFilters = topicFilters - }); - - EnqueueSubscribedRetainedMessages(topicFilters); - return Task.FromResult(0); - } - - public Task UnsubscribeAsync(IList topicFilters) - { - if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters)); - - _subscriptionsManager.Unsubscribe(new MqttUnsubscribePacket - { - TopicFilters = topicFilters - }); - - return Task.FromResult(0); - } - - public void ClearPendingApplicationMessages() - { - _pendingPacketsQueue.Clear(); - } - - public void Dispose() - { - _pendingPacketsQueue?.Dispose(); - - _cancellationTokenSource?.Cancel (); - _cancellationTokenSource?.Dispose(); - _cancellationTokenSource = null; } private void ProcessReceivedPacket(IMqttChannelAdapter adapter, MqttBasePacket packet, CancellationToken cancellationToken) diff --git a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs index 39352ec3f..457de926b 100644 --- a/Tests/MQTTnet.Core.Tests/MqttServerTests.cs +++ b/Tests/MQTTnet.Core.Tests/MqttServerTests.cs @@ -301,11 +301,6 @@ public async Task MqttServer_ShutdownDisconnectsClientsGracefully() [TestMethod] public async Task MqttServer_HandleCleanDisconnect() { - MqttNetGlobalLogger.LogMessagePublished += (_, e) => - { - System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}"); - }; - var serverAdapter = new MqttTcpServerAdapter(new MqttNetLogger().CreateChildLogger()); var s = new MqttFactory().CreateMqttServer(new[] { serverAdapter }, new MqttNetLogger()); @@ -338,6 +333,51 @@ public async Task MqttServer_HandleCleanDisconnect() Assert.AreEqual(clientConnectedCalled, clientDisconnectedCalled); } + [TestMethod] + public async Task MqttServer_Client_Disconnect_Without_Errors() + { + var errors = 0; + + MqttNetGlobalLogger.LogMessagePublished += (_, e) => + { + System.Diagnostics.Debug.WriteLine($"[{e.TraceMessage.Timestamp:s}] {e.TraceMessage.Source} {e.TraceMessage.Message}"); + + if (e.TraceMessage.Level == MqttNetLogLevel.Error) + { + errors++; + } + }; + + bool clientWasConnected; + + var server = new MqttFactory().CreateMqttServer(); + try + { + var options = new MqttServerOptionsBuilder().Build(); + await server.StartAsync(options); + + var client = new MqttFactory().CreateMqttClient(); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); + + await client.ConnectAsync(clientOptions); + + clientWasConnected = true; + + await client.DisconnectAsync(); + + await Task.Delay(500); + } + finally + { + await server.StopAsync(); + } + + Assert.IsTrue(clientWasConnected); + Assert.AreEqual(0, errors); + } + [TestMethod] public async Task MqttServer_LotsOfRetainedMessages() {