Skip to content

Commit

Permalink
#692 Add the possibility to specify "Sec-WebSocket-Protocol" (#693)
Browse files Browse the repository at this point in the history
* #692 Add possibility to specific Sec-WebSocket-Protocol
Add documentation
Not using find_first_of as it is a C++ 17 function
Added get_subprotocol getter
use const& instead of && for subprotocols parameter
Remove closing of the WebSocket if no subprotocol is matching
Fix missing include <algorithm>
Add unit test for subprotocol
  • Loading branch information
KaSSaaaa authored Sep 10, 2024
1 parent a7217e6 commit 210cbfd
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/guides/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ The maximum payload size that a connection accepts can be adjusted either global
By default, this limit is disabled. To disable the global setting in specific routes, you only need to call `#!cpp CROW_WEBSOCKET_ROUTE(app, "/url").max_payload(UINT64_MAX)`.
## Subprotocols
<span class="tag">[:octicons-feed-tag-16: master](https://github.com/CrowCpp/Crow)</span>
Specifies the possible subprotocols that are available for the client. If specified, the first match with the client's requested subprotocols will be returned in the "Sec-WebSocket-Protocol" header of the handshake response. Otherwise, the connection will be closed. If no subprotocol are specified on both the client and the server side, the connection process will continue normally. It can be specified by using `#!cpp CROW_WEBSOCKET_ROUTE(app, "/url").subprotocols(<values>)`.
For more info about websocket routes go [here](../reference/classcrow_1_1_web_socket_rule.html).
Expand Down
11 changes: 9 additions & 2 deletions include/crow/routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<SocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
new crow::websocket::Connection<SocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
}
#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
new crow::websocket::Connection<SSLAdaptor, App>(req, std::move(adaptor), app_, max_payload_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
new crow::websocket::Connection<SSLAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
}
#endif

Expand All @@ -478,6 +478,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
return *this;
}

self_t& subprotocols(const std::vector<std::string>& subprotocols)
{
subprotocols_ = subprotocols;
return *this;
}

template<typename Func>
self_t& onopen(Func f)
{
Expand Down Expand Up @@ -522,6 +528,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
std::function<bool(const crow::request&, void**)> accept_handler_;
uint64_t max_payload_;
bool max_payload_override_ = false;
std::vector<std::string> subprotocols_;
};

/// Allows the user to assign parameters using functions.
Expand Down
39 changes: 39 additions & 0 deletions include/crow/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -907,5 +907,44 @@ namespace crow

return v.substr(begin, end - begin);
}

/**
* @brief splits a string based on a separator
*/
inline static std::vector<std::string> split(const std::string& v, const std::string& separator)
{
std::vector<std::string> result;
size_t startPos = 0;

for (size_t foundPos = v.find(separator); foundPos != std::string::npos; foundPos = v.find(separator, startPos))
{
result.push_back(v.substr(startPos, foundPos - startPos));
startPos = foundPos + separator.size();
}

result.push_back(v.substr(startPos));
return result;
}

/**
* @brief Returns the first occurence that matches between two ranges of iterators
* @param first1 begin() iterator of the first range
* @param last1 end() iterator of the first range
* @param first2 begin() iterator of the second range
* @param last2 end() iterator of the second range
* @return first occurence that matches between two ranges of iterators
*/
template <typename Iter1, typename Iter2>
inline static Iter1 find_first_of(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2)
{
for (; first1 != last1; ++first1)
{
if (std::find(first2, last2, *first1) != last2)
{
return first1;
}
}
return last1;
}
} // namespace utility
} // namespace crow
28 changes: 27 additions & 1 deletion include/crow/websocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
virtual void send_pong(std::string msg) = 0;
virtual void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure) = 0;
virtual std::string get_remote_ip() = 0;
virtual std::string get_subprotocol() const = 0;
virtual ~connection() = default;

void userdata(void* u) { userdata_ = u; }
Expand Down Expand Up @@ -109,7 +110,8 @@ namespace crow // NOTE: Already documented in "crow/app.h"
///
/// Requires a request with an "Upgrade: websocket" header.<br>
/// Automatically handles the handshake.
Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler, uint64_t max_payload,
Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
Expand All @@ -132,6 +134,17 @@ namespace crow // NOTE: Already documented in "crow/app.h"
return;
}

std::string requested_subprotocols_header = req.get_header_value("Sec-WebSocket-Protocol");
if (!subprotocols.empty() || !requested_subprotocols_header.empty())
{
auto requested_subprotocols = utility::split(requested_subprotocols_header, ", ");
auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end());
if (subprotocol != subprotocols.end())
{
subprotocol_ = *subprotocol;
}
}

if (accept_handler_)
{
void* ud = nullptr;
Expand Down Expand Up @@ -268,6 +281,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
max_payload_bytes_ = payload;
}

/// Returns the matching client/server subprotocol, empty string if none matched.
std::string get_subprotocol() const override
{
return subprotocol_;
}

protected:
/// Generate the websocket headers using an opcode and the message size (in bytes).
std::string build_header(int opcode, size_t size)
Expand Down Expand Up @@ -307,6 +326,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
write_buffers_.emplace_back(header);
write_buffers_.emplace_back(std::move(hello));
write_buffers_.emplace_back(crlf);
if (!subprotocol_.empty())
{
write_buffers_.emplace_back("Sec-WebSocket-Protocol: ");
write_buffers_.emplace_back(subprotocol_);
write_buffers_.emplace_back(crlf);
}
write_buffers_.emplace_back(crlf);
do_write();
if (open_handler_)
Expand Down Expand Up @@ -779,6 +804,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
uint16_t remaining_length16_{0};
uint64_t remaining_length_{0};
uint64_t max_payload_bytes_{UINT64_MAX};
std::string subprotocol_;
bool close_connection_{false};
bool is_reading{false};
bool has_mask_{false};
Expand Down
51 changes: 51 additions & 0 deletions tests/unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3203,6 +3203,57 @@ TEST_CASE("websocket_max_payload")
app.stop();
} // websocket_max_payload

TEST_CASE("websocket_subprotocols")
{
static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Protocol: myprotocol\r\nSec-WebSocket-Version: 13\r\nHost: localhost\r\n\r\n";

static websocket::connection* connection = nullptr;
static bool connected{false};

SimpleApp app;

CROW_WEBSOCKET_ROUTE(app, "/ws")
.subprotocols({"anotherprotocol", "myprotocol"})
.onaccept([&](const crow::request& req, void**) {
CROW_LOG_INFO << "Accepted websocket with URL " << req.url;
return true;
})
.onopen([&](websocket::connection& con) {
connected = true;
connection = &con;
CROW_LOG_INFO << "Connected websocket and subprotocol is " << con.get_subprotocol();
})
.onclose([&](websocket::connection&, const std::string&, uint16_t) {
CROW_LOG_INFO << "Closing websocket";
});

app.validate();

auto _ = app.bindaddr(LOCALHOST_ADDRESS).port(45451).run_async();
app.wait_for_server_start();
asio::io_service is;

asio::ip::tcp::socket c(is);
c.connect(asio::ip::tcp::endpoint(
asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451));


char buf[2048];

//----------Handshake----------
{
std::fill_n(buf, 2048, 0);
c.send(asio::buffer(http_message));

c.receive(asio::buffer(buf, 2048));
std::this_thread::sleep_for(std::chrono::milliseconds(5));
CHECK(connected);
CHECK(connection->get_subprotocol() == "myprotocol");
}

app.stop();
}

#ifdef CROW_ENABLE_COMPRESSION
TEST_CASE("zlib_compression")
{
Expand Down

0 comments on commit 210cbfd

Please sign in to comment.