Skip to content

Commit

Permalink
Polish
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Dec 8, 2023
1 parent 89cc215 commit 71cd890
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
6 changes: 4 additions & 2 deletions coroio/ssl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ TSslContext::~TSslContext() {
SSL_CTX_free(Ctx);
}

TSslContext TSslContext::Client() {
TSslContext TSslContext::Client(const std::function<void(const char*)>& logFunc) {
TSslContext ctx;
ctx.Ctx = SSL_CTX_new(TLS_client_method());
ctx.LogFunc = logFunc;
return ctx;
}

TSslContext TSslContext::Server(const char* certfile, const char* keyfile) {
TSslContext TSslContext::Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc) {
TSslContext ctx;
ctx.Ctx = SSL_CTX_new(TLS_server_method());
ctx.LogFunc = logFunc;

if (SSL_CTX_use_certificate_file(ctx.Ctx, certfile, SSL_FILETYPE_PEM) != 1) {
throw std::runtime_error("SSL_CTX_use_certificate_file failed");
Expand Down
18 changes: 9 additions & 9 deletions coroio/ssl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ namespace NNet {

struct TSslContext {
SSL_CTX* Ctx;
std::function<void(const char*)> LogFunc = {};

TSslContext(TSslContext&& other)
: Ctx(other.Ctx)
, LogFunc(other.LogFunc)
{
other.Ctx = nullptr;
}

~TSslContext();

static TSslContext Client();
static TSslContext Server(const char* certfile, const char* keyfile);
static TSslContext Client(const std::function<void(const char*)>& logFunc = {});
static TSslContext Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc = {});

private:
TSslContext();
Expand All @@ -34,13 +36,12 @@ struct TSslContext {
template<typename THandle>
class TSslSocket {
public:
TSslSocket(THandle& socket, TSslContext& ctx, const std::function<void(const char*)>& logFunc = {})
TSslSocket(THandle& socket, TSslContext& ctx)
: Socket(socket)
, Ctx(ctx)
, Ssl(SSL_new(Ctx.Ctx))
, Rbio(BIO_new(BIO_s_mem()))
, Wbio(BIO_new(BIO_s_mem()))
, LogFunc(logFunc)
{
SSL_set_bio(Ssl, Rbio, Wbio);
}
Expand Down Expand Up @@ -153,21 +154,21 @@ class TSslSocket {
}
}

if (LogFunc) {
LogFunc("SSL Handshake established\n");
if (Ctx.LogFunc) {
Ctx.LogFunc("SSL Handshake established\n");
}
}

void LogState() {
if (!LogFunc) return;
if (!Ctx.LogFunc) return;

char buf[1024];

const char * state = SSL_state_string_long(Ssl);
if (state != LastState) {
if (state) {
snprintf(buf, sizeof(buf), "SSL-STATE: %s", state);
LogFunc(buf);
Ctx.LogFunc(buf);
}
LastState = state;
}
Expand All @@ -180,7 +181,6 @@ class TSslSocket {
BIO* Rbio = nullptr;
BIO* Wbio = nullptr;

std::function<void(const char*)> LogFunc = {};
const char* LastState = nullptr;
};

Expand Down
11 changes: 6 additions & 5 deletions examples/sslechoclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using namespace NNet;

template<bool debug, typename TPoller>
TVoidSuspendedTask client(TPoller& poller, TAddress addr)
TVoidSuspendedTask client(TPoller& poller, TSslContext& ctx, TAddress addr)
{
static constexpr int maxLineSize = 4096;
using TSocket = typename TPoller::TSocket;
Expand All @@ -15,8 +15,7 @@ TVoidSuspendedTask client(TPoller& poller, TAddress addr)
try {
TFileHandle input{0, poller}; // stdin
TSocket socket{std::move(addr), poller};
TSslContext ctx = TSslContext::Client();
TSslSocket sslSocket(socket, ctx, [&](const char* s) { std::cerr << s << "\n"; });
TSslSocket sslSocket(socket, ctx);
TLineReader lineReader(input, maxLineSize);
TByteWriter byteWriter(sslSocket);
TByteReader byteReader(sslSocket);
Expand All @@ -42,9 +41,11 @@ void run(bool debug, TAddress address)
NNet::TLoop<TPoller> loop;
NNet::THandle h;
if (debug) {
h = client<true>(loop.Poller(), std::move(address));
TSslContext ctx = TSslContext::Client([&](const char* s) { std::cerr << s << "\n"; });
h = client<true>(loop.Poller(), ctx, std::move(address));
} else {
h = client<false>(loop.Poller(), std::move(address));
TSslContext ctx = TSslContext::Client();
h = client<false>(loop.Poller(), ctx, std::move(address));
}
while (!h.done()) {
loop.Step();
Expand Down
13 changes: 9 additions & 4 deletions examples/sslechoserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ using NNet::TKqueue;
#endif

template<bool debug, typename TSocket>
TVoidTask client_handler(TSocket socket, NNet::TSslContext& ctx, int buffer_size) {
TVoidTask clientHandler(TSocket socket, NNet::TSslContext& ctx, int buffer_size) {
std::vector<char> buffer(buffer_size); ssize_t size = 0;

try {
NNet::TSslSocket<TSocket> sslSocket(socket, ctx, [&](const char* s) { std::cerr << s << "\n"; });
NNet::TSslSocket<TSocket> sslSocket(socket, ctx);

co_await sslSocket.Accept();
while ((size = co_await sslSocket.ReadSome(buffer.data(), buffer_size)) > 0) {
Expand All @@ -40,7 +40,12 @@ TVoidTask client_handler(TSocket socket, NNet::TSslContext& ctx, int buffer_size
template<bool debug, typename TPoller>
TVoidTask server(TPoller& poller, TAddress address, int buffer_size)
{
NNet::TSslContext ctx = NNet::TSslContext::Server("server.crt", "server.key");
std::function<void(const char*)> sslDebugLogFunc;
if constexpr(debug) {
sslDebugLogFunc = [](const char* s) { std::cerr << s << "\n"; };
}

NNet::TSslContext ctx = NNet::TSslContext::Server("server.crt", "server.key", sslDebugLogFunc);
typename TPoller::TSocket socket(std::move(address), poller);
socket.Bind();
socket.Listen();
Expand All @@ -50,7 +55,7 @@ TVoidTask server(TPoller& poller, TAddress address, int buffer_size)
if constexpr (debug) {
std::cerr << "Accepted\n";
}
client_handler<debug>(std::move(client), ctx, buffer_size);
clientHandler<debug>(std::move(client), ctx, buffer_size);
}
co_return;
}
Expand Down

0 comments on commit 71cd890

Please sign in to comment.