Skip to content

Commit

Permalink
Fix the comms library reporting that all client messages have been pr…
Browse files Browse the repository at this point in the history
…ocessed when a blockwise request is in progress
  • Loading branch information
sergeuz committed Apr 21, 2023
1 parent 2700304 commit fb730fe
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 31 deletions.
13 changes: 6 additions & 7 deletions communication/inc/dtls_message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace particle
namespace protocol
{

class Protocol;

/**
* Please centralize this somewhere else!
*/
Expand Down Expand Up @@ -79,7 +81,6 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
int (*restore)(void* data, size_t max_length, uint8_t type, void* reserved);

uint32_t (*calculate_crc)(const uint8_t* data, uint32_t length);
void (*notify_client_messages_processed)(void* reserved);
};

private:
Expand All @@ -92,6 +93,7 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
mbedtls_pk_context pkey;
mbedtls_timing_delay_context timer;
Callbacks callbacks;
Protocol* protocol;
uint8_t* server_public;
uint16_t server_public_len;
uint32_t keys_checksum;
Expand Down Expand Up @@ -123,13 +125,14 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
void reset_session();

public:
DTLSMessageChannel() :
explicit DTLSMessageChannel(Protocol* protocol) :
ssl_context(),
conf(),
clicert(),
pkey(),
timer(),
callbacks(),
protocol(protocol),
server_public(nullptr),
server_public_len(0),
keys_checksum(0),
Expand Down Expand Up @@ -167,11 +170,7 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>

virtual ProtocolError command(Command cmd, void* arg=nullptr) override;

virtual void notify_client_messages_processed() override {
if (callbacks.notify_client_messages_processed) {
callbacks.notify_client_messages_processed(nullptr);
}
}
virtual void notify_client_messages_processed() override;

virtual AppStateDescriptor cached_app_state_descriptor() const override;

Expand Down
7 changes: 6 additions & 1 deletion communication/inc/message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ struct MessageChannel : public Channel
*/
virtual void notify_client_messages_processed()=0;

/**
* Returns `true` if there are client messages being processed at the moment, or `false` otherwise.
*/
virtual bool has_pending_client_messages() const = 0;

/**
* Get a descriptor of the cached application state.
*
Expand All @@ -262,7 +267,7 @@ struct MessageChannel : public Channel
class AbstractMessageChannel : public MessageChannel
{
public:
void set_debug_enabled(bool enabled) override {
void set_debug_enabled(bool /* enabled */) override {
}
};

Expand Down
4 changes: 3 additions & 1 deletion communication/inc/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,10 +624,12 @@ class Protocol

virtual int get_describe_data(spark_protocol_describe_data* data, void* reserved);

virtual int get_status(protocol_status* status) const = 0;
int get_status(protocol_status* status) const;

void notify_message_complete(message_id_t msg_id, CoAPCode::Enum responseCode);

void notify_client_messages_processed();

/**
* Retrieves the next token.
*/
Expand Down
23 changes: 15 additions & 8 deletions communication/src/coap_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include "service_debug.h"

#include "communication_diagnostic.h"

#include <limits>
#include <utility>

namespace particle
{
Expand All @@ -52,8 +54,10 @@ class CoAPChannel : public T
}

public:
CoAPChannel(message_id_t msg_seed=0) : message_id(msg_seed)
{
template<typename... ArgsT>
explicit CoAPChannel(ArgsT&&... args) :
T(std::forward<ArgsT>(args)...),
message_id(0) {
}

/**
Expand Down Expand Up @@ -554,8 +558,10 @@ class CoAPReliableChannel : public T
DelegateChannel delegateChannel;

public:

CoAPReliableChannel(M m=0) : millis(m) {
template<typename... ArgsT>
explicit CoAPReliableChannel(ArgsT&&... args) :
T(std::forward<ArgsT>(args)...),
millis(nullptr) {
delegateChannel.init(this);
}

Expand Down Expand Up @@ -622,6 +628,11 @@ class CoAPReliableChannel : public T
return receive(msg, true);
}

bool has_pending_client_messages() const override
{
return client.has_messages();
}

/**
* Pulls messages from the message channel
*/
Expand All @@ -637,10 +648,6 @@ class CoAPReliableChannel : public T
return client.has_messages() || server.has_unacknowledged_requests();
}

bool has_unacknowledged_client_requests() const {
return client.has_messages();
}

/**
* Pulls messages from the channel and stores it in a message store for
* reliable receipt and retransmission.
Expand Down
2 changes: 2 additions & 0 deletions communication/src/description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ ProtocolError Description::receiveAckOrRst(const Message& msg, int* descFlags) {
activeReq_.reset();
if (!reqQueue_.isEmpty()) {
CHECK_PROTOCOL(sendNextRequest(reqQueue_.takeFirst()));
} else {
proto_->notify_client_messages_processed();
}
*descFlags = flags;
}
Expand Down
6 changes: 6 additions & 0 deletions communication/src/description.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class Description {

ProtocolError serialize(Appender* appender, int descFlags);

bool hasPendingClientRequests() const;

void reset();

private:
Expand Down Expand Up @@ -96,6 +98,10 @@ class Description {
system_tick_t millis() const;
};

inline bool Description::hasPendingClientRequests() const {
return activeReq_.has_value() || !reqQueue_.isEmpty();
}

} // namespace protocol

} // namespace particle
4 changes: 4 additions & 0 deletions communication/src/dtls_message_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ ProtocolError DTLSMessageChannel::command(Command command, void* arg)
return NO_ERROR;
}

void DTLSMessageChannel::notify_client_messages_processed() {
protocol->notify_client_messages_processed();
}

AppStateDescriptor DTLSMessageChannel::cached_app_state_descriptor() const
{
return sessionPersist.app_state_descriptor();
Expand Down
3 changes: 0 additions & 3 deletions communication/src/dtls_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ void DTLSProtocol::init(const char *id,
channelCallbacks.save = callbacks.save;
channelCallbacks.restore = callbacks.restore;
}
if (offsetof(SparkCallbacks, notify_client_messages_processed) + sizeof(SparkCallbacks::notify_client_messages_processed) <= callbacks.size) {
channelCallbacks.notify_client_messages_processed = callbacks.notify_client_messages_processed;
}

// TODO: Ideally, the next token value should be stored in the session data
mbedtls_default_rng(nullptr, &next_token, sizeof(next_token));
Expand Down
16 changes: 5 additions & 11 deletions communication/src/dtls_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ class DTLSProtocol : public Protocol

// todo - this a duplicate of LightSSLProtocol - factor out

DTLSProtocol() : Protocol(channel) {}
DTLSProtocol() :
Protocol(channel),
channel(this),
device_id() {
}

void init(const char *id,
const SparkKeys &keys,
Expand Down Expand Up @@ -120,16 +124,6 @@ class DTLSProtocol : public Protocol
}
}

int get_status(protocol_status* status) const override {
SPARK_ASSERT(status);
status->flags = 0;
if (channel.has_unacknowledged_client_requests()) {
status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES;
}
return NO_ERROR;
}


/**
* Ensures that all outstanding sent coap messages have been acknowledged.
*/
Expand Down
16 changes: 16 additions & 0 deletions communication/src/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,22 @@ int Protocol::get_describe_data(spark_protocol_describe_data* data, void* reserv
return 0;
}

int Protocol::get_status(protocol_status* status) const {
SPARK_ASSERT(status);
status->flags = 0;
if (channel.has_pending_client_messages() || description.hasPendingClientRequests()) {
status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES;
}
return ProtocolError::NO_ERROR;
}

void Protocol::notify_client_messages_processed() {
if (callbacks.notify_client_messages_processed && !channel.has_pending_client_messages() &&
!description.hasPendingClientRequests()) { // Ensure there's no pending blockwise requests
callbacks.notify_client_messages_processed(nullptr /* reserved */);
}
}

size_t Protocol::get_max_transmit_message_size() const
{
if (!max_transmit_message_size) {
Expand Down

0 comments on commit fb730fe

Please sign in to comment.