diff --git a/coroio/ssl.cpp b/coroio/ssl.cpp index 9e7b5a8..f51bd66 100644 --- a/coroio/ssl.cpp +++ b/coroio/ssl.cpp @@ -18,15 +18,17 @@ TSslContext::~TSslContext() { SSL_CTX_free(Ctx); } -TSslContext TSslContext::Client() { +TSslContext TSslContext::Client(const std::function& logFunc) { TSslContext ctx; ctx.Ctx = SSL_CTX_new(TLS_client_method()); + ctx.LogFunc = logFunc; return ctx; } -TSslContext TSslContext::Server(const char* certfile, const char* keyfile) { +TSslContext TSslContext::Server(const char* certfile, const char* keyfile, const std::function& logFunc) { TSslContext ctx; ctx.Ctx = SSL_CTX_new(TLS_server_method()); + ctx.LogFunc = logFunc; if (SSL_CTX_use_certificate_file(ctx.Ctx, certfile, SSL_FILETYPE_PEM) != 1) { throw std::runtime_error("SSL_CTX_use_certificate_file failed"); diff --git a/coroio/ssl.hpp b/coroio/ssl.hpp index 87e55ff..f25e517 100644 --- a/coroio/ssl.hpp +++ b/coroio/ssl.hpp @@ -15,17 +15,19 @@ namespace NNet { struct TSslContext { SSL_CTX* Ctx; + std::function LogFunc = {}; TSslContext(TSslContext&& other) : Ctx(other.Ctx) + , LogFunc(other.LogFunc) { other.Ctx = nullptr; } ~TSslContext(); - static TSslContext Client(); - static TSslContext Server(const char* certfile, const char* keyfile); + static TSslContext Client(const std::function& logFunc = {}); + static TSslContext Server(const char* certfile, const char* keyfile, const std::function& logFunc = {}); private: TSslContext(); @@ -34,13 +36,12 @@ struct TSslContext { template class TSslSocket { public: - TSslSocket(THandle& socket, TSslContext& ctx, const std::function& logFunc = {}) + TSslSocket(THandle& socket, TSslContext& ctx) : Socket(socket) , Ctx(ctx) , Ssl(SSL_new(Ctx.Ctx)) , Rbio(BIO_new(BIO_s_mem())) , Wbio(BIO_new(BIO_s_mem())) - , LogFunc(logFunc) { SSL_set_bio(Ssl, Rbio, Wbio); } @@ -153,13 +154,13 @@ class TSslSocket { } } - if (LogFunc) { - LogFunc("SSL Handshake established\n"); + if (Ctx.LogFunc) { + Ctx.LogFunc("SSL Handshake established\n"); } } void LogState() { - if (!LogFunc) return; + if (!Ctx.LogFunc) return; char buf[1024]; @@ -167,7 +168,7 @@ class TSslSocket { if (state != LastState) { if (state) { snprintf(buf, sizeof(buf), "SSL-STATE: %s", state); - LogFunc(buf); + Ctx.LogFunc(buf); } LastState = state; } @@ -180,7 +181,6 @@ class TSslSocket { BIO* Rbio = nullptr; BIO* Wbio = nullptr; - std::function LogFunc = {}; const char* LastState = nullptr; }; diff --git a/examples/sslechoclient.cpp b/examples/sslechoclient.cpp index 4e46152..a68ed97 100644 --- a/examples/sslechoclient.cpp +++ b/examples/sslechoclient.cpp @@ -5,7 +5,7 @@ using namespace NNet; template -TVoidSuspendedTask client(TPoller& poller, TAddress addr) +TVoidSuspendedTask client(TPoller& poller, TSslContext& ctx, TAddress addr) { static constexpr int maxLineSize = 4096; using TSocket = typename TPoller::TSocket; @@ -15,8 +15,7 @@ TVoidSuspendedTask client(TPoller& poller, TAddress addr) try { TFileHandle input{0, poller}; // stdin TSocket socket{std::move(addr), poller}; - TSslContext ctx = TSslContext::Client(); - TSslSocket sslSocket(socket, ctx, [&](const char* s) { std::cerr << s << "\n"; }); + TSslSocket sslSocket(socket, ctx); TLineReader lineReader(input, maxLineSize); TByteWriter byteWriter(sslSocket); TByteReader byteReader(sslSocket); @@ -42,9 +41,11 @@ void run(bool debug, TAddress address) NNet::TLoop loop; NNet::THandle h; if (debug) { - h = client(loop.Poller(), std::move(address)); + TSslContext ctx = TSslContext::Client([&](const char* s) { std::cerr << s << "\n"; }); + h = client(loop.Poller(), ctx, std::move(address)); } else { - h = client(loop.Poller(), std::move(address)); + TSslContext ctx = TSslContext::Client(); + h = client(loop.Poller(), ctx, std::move(address)); } while (!h.done()) { loop.Step(); diff --git a/examples/sslechoserver.cpp b/examples/sslechoserver.cpp index 04a08a4..3586be5 100644 --- a/examples/sslechoserver.cpp +++ b/examples/sslechoserver.cpp @@ -15,11 +15,11 @@ using NNet::TKqueue; #endif template -TVoidTask client_handler(TSocket socket, NNet::TSslContext& ctx, int buffer_size) { +TVoidTask clientHandler(TSocket socket, NNet::TSslContext& ctx, int buffer_size) { std::vector buffer(buffer_size); ssize_t size = 0; try { - NNet::TSslSocket sslSocket(socket, ctx, [&](const char* s) { std::cerr << s << "\n"; }); + NNet::TSslSocket sslSocket(socket, ctx); co_await sslSocket.Accept(); while ((size = co_await sslSocket.ReadSome(buffer.data(), buffer_size)) > 0) { @@ -40,7 +40,12 @@ TVoidTask client_handler(TSocket socket, NNet::TSslContext& ctx, int buffer_size template TVoidTask server(TPoller& poller, TAddress address, int buffer_size) { - NNet::TSslContext ctx = NNet::TSslContext::Server("server.crt", "server.key"); + std::function sslDebugLogFunc; + if constexpr(debug) { + sslDebugLogFunc = [](const char* s) { std::cerr << s << "\n"; }; + } + + NNet::TSslContext ctx = NNet::TSslContext::Server("server.crt", "server.key", sslDebugLogFunc); typename TPoller::TSocket socket(std::move(address), poller); socket.Bind(); socket.Listen(); @@ -50,7 +55,7 @@ TVoidTask server(TPoller& poller, TAddress address, int buffer_size) if constexpr (debug) { std::cerr << "Accepted\n"; } - client_handler(std::move(client), ctx, buffer_size); + clientHandler(std::move(client), ctx, buffer_size); } co_return; }