Skip to content

Commit

Permalink
Unittests with ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Dec 9, 2023
1 parent 09dced0 commit 64b0dc0
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 0 deletions.
23 changes: 23 additions & 0 deletions coroio/ssl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "ssl.hpp"
#include <stdexcept>
#include <assert.h>

namespace NNet {

Expand Down Expand Up @@ -45,4 +46,26 @@ TSslContext TSslContext::Server(const char* certfile, const char* keyfile, const
return ctx;
}

TSslContext TSslContext::ServerFromMem(const void* certMem, const void* keyMem, const std::function<void(const char*)>& logFunc) {
TSslContext ctx;
ctx.Ctx = SSL_CTX_new(TLS_server_method());
ctx.LogFunc = logFunc;

auto cbio = std::shared_ptr<BIO>(BIO_new_mem_buf(certMem, -1), BIO_free);
auto cert = std::shared_ptr<X509>(PEM_read_bio_X509(cbio.get(), NULL, 0, nullptr), X509_free);
if (!cert) {
throw std::runtime_error("Cannot load X509 certificate");
}
SSL_CTX_use_certificate(ctx.Ctx, cert.get());

auto kbio = std::shared_ptr<BIO>(BIO_new_mem_buf(keyMem, -1), BIO_free);
auto key = std::shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(kbio.get(), NULL, 0, nullptr), EVP_PKEY_free);
if (!key) {
throw std::runtime_error("Cannot load Key");
}
SSL_CTX_use_PrivateKey(ctx.Ctx, key.get());

return ctx;
}

} // namespace NNet
2 changes: 2 additions & 0 deletions coroio/ssl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "base.hpp"
#include "corochain.hpp"
#include "sockutils.hpp"

namespace NNet {

Expand All @@ -28,6 +29,7 @@ struct TSslContext {

static TSslContext Client(const std::function<void(const char*)>& logFunc = {});
static TSslContext Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc = {});
static TSslContext ServerFromMem(const void* certfile, const void* keyfile, const std::function<void(const char*)>& logFunc = {});

private:
TSslContext();
Expand Down
21 changes: 21 additions & 0 deletions tests/server.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
static const char* testMemCert = R"__(
-----BEGIN CERTIFICATE-----
MIIDETCCAfkCFE795Bs5d55BoKzfkMSCs28/mX7OMA0GCSqGSIb3DQEBCwUAMEUx
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjMxMjA5MjAxOTUxWhcNMjQxMjA4MjAx
OTUxWjBFMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UE
CgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOC
AQ8AMIIBCgKCAQEA1i70LSK8y88n/etvRj1b516Xlzqhx7xjm1/GvVHC9h7dUXzY
kq4Nx0Snf2L1PF1JhhgEwTzKJfAXYHzeHdKOaa08eC7r0tfikxIkXpuLJrccyfVU
IQxu62cj4LCjMeR3grjjuJNmoZ86ISA3D22Fo9SuiMw+9wqhWEYmM0f0EhJ2tdAd
Tu9XGikF7T4R+n/G6kmz9uiKwu73FOcbOJ5AcKxBY046lvKuEypkmv3tqxPwJF/i
T3W/pgMd9qhP3eNylEQCaaIiVmMpOQN/1RpQ0JWLAchvYVMrOJcjONtlGBJ0YufP
SrsesRpucsBFtRRRGm/jIWQxhALxPxqTV5F56QIDAQABMA0GCSqGSIb3DQEBCwUA
A4IBAQAx6jo7IYH+qh3cKtkqp7uS5MPatupIuqaFuZzOznVEtlM0HatyagjjAMtY
LDI55gRLa9hFK1+qh3hcNHj/qJkDdsUVaip+7PEtqbmPTt8c9clnLruDgw3QDdBU
0aRb4OZdL63oSDIfGGq5DPSlKoyLzQGtuikmKtKYd+RLbv3wNt5yekWjLXUT/jcP
9FFFKaCbEwLz4JI0juatUtpljKS4mnllXWXBPbsLL0WX5QGZllH0iDQwxYnDz78s
HMxbEX94FY7aDOujIqv/N12R6IpWzw/FyNKrAqiFKbdPCdhrnxl9Y7AWqLR8aboG
7fUdkUW7zoOdH8alcCgHMDNRpfpz
-----END CERTIFICATE-----
)__";
30 changes: 30 additions & 0 deletions tests/server.key
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
static const char* testMemKey = R"__(
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDWLvQtIrzLzyf9
629GPVvnXpeXOqHHvGObX8a9UcL2Ht1RfNiSrg3HRKd/YvU8XUmGGATBPMol8Bdg
fN4d0o5prTx4LuvS1+KTEiRem4smtxzJ9VQhDG7rZyPgsKMx5HeCuOO4k2ahnzoh
IDcPbYWj1K6IzD73CqFYRiYzR/QSEna10B1O71caKQXtPhH6f8bqSbP26IrC7vcU
5xs4nkBwrEFjTjqW8q4TKmSa/e2rE/AkX+JPdb+mAx32qE/d43KURAJpoiJWYyk5
A3/VGlDQlYsByG9hUys4lyM422UYEnRi589Kux6xGm5ywEW1FFEab+MhZDGEAvE/
GpNXkXnpAgMBAAECggEAUDzL1cLfNnz3Lu1NxOMEtHMf2BQrej+Nky34roDcSEa8
w6PBIIYa/E0wcIz6cTBDdHw3/8pNspO0tj1hGowANP+kmSN+zgB5TX5s6JJduVW8
77271Bury/1aF/kkUfMUgIDSMpnpx192r+U5K0rs1zi8X9wgNH0jf4XcFrb8bO4V
TWh3aS4U+pNuwDytiCgkAmLoR5qgMQxrLR1m5bhwBu6wmjhkgJl4vj8KOH0VQZxB
iQDjK6WhhJ5ruNwOdsro6gYEuImm/2Wso4tz/AqR4PC6+XbAIN5wihcI02ZJivp7
ZpFbnSz+6NIL3vZ87iaweGU+E7p6yt75RuYIjeJ3fwKBgQDw9fX1g06SqO88b/J/
V8FIZMrH9+85/zdrFc/1Lpch5oaJXAuJTH4QeCKmAyY82U5PS/X9CDhYqrKdhJIH
rvyusqlyfCYhtpLeMgFnMbMg8CTvHLUQyCrVvcSQPyEs9F9hy3lSKkAKGppWRfgK
yRBvc/f6H2lcaJtVeQZroaVqkwKBgQDjjSXU7ME9o7w/PuUryO/H/YHFnSXvXQ+4
kQra3oR0iHC0Z+i0Ze+yxS8UuGncIvpR55NODI2fA51zLJ56FAmUNvMwj6m63eEB
3RXRl/QeuV5Yy27kYFBP7VOZdjmw0x/Gp+tXLAm8xJP4vLwCwU9UlOgxvJ5QGzGC
6BKrdRjLEwKBgQDLbrzDFKKni1y/Z7wR6uLR3dad0SL1khUVoYq68yTBiECZg05y
ElR0Txjhk9MamFRW+kip4eDAawz1k9E+D2xhiZEpiMsgt2VzlkA9AWa8LkLgZRox
Gu2fGuHy7nlx3LcSd5jr16PNY/xdTiFF6c6oaf43+4EWdXJ/TPgwsn5XZQKBgQDH
dWPh/h1s0Gcj8Rekh59W6CmmdJdZ93LeT5T6QO5Nz4MrP6HE701qoFkiinuQUMCm
ppyCX5KL/fk3ibboP0QePQRyXptihzbCEW8cp1t+yvGeV8O+P4ZmaRtMe0saahWC
ZpJteNaYNp+V+qm6qIPHGjdl0XXbtdpyasZisGOpLQKBgC8u2N1RxmZl9RhudEAx
840YvhFMv++BXVBIU8y3eQhsHgW4wKdS3kCdNqL5oNveJI/HMvVgcPoauq+iBOz2
MZAk4XfpVZz8VQbMRJuvgslHCr3BaytsIHg4M3mvn39rvkbfzaYTg3lJpNEdVuh1
bFSKGNOZ1Uqdrg2UTW66K8HV
-----END PRIVATE KEY-----
)__";
53 changes: 53 additions & 0 deletions tests/tests.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "coroio/ssl.hpp"
#include <chrono>
#include <array>
#include <exception>
Expand All @@ -12,6 +13,9 @@ extern "C" {
#include <cmocka.h>
}

#include "server.crt"
#include "server.key"

namespace {

static uint32_t rand_(uint32_t* seed) {
Expand Down Expand Up @@ -622,6 +626,54 @@ void test_zero_copy_line_splitter(void**) {
}
}

template<typename TPoller>
void test_read_write_full_ssl(void**) {
using TLoop = TLoop<TPoller>;
using TSocket = typename TPoller::TSocket;

std::vector<char> data(1024*1024);
int cur = 0;
for (auto& ch : data) {
ch = cur + 'a';
cur = (cur + 1) % ('z' - 'a' + 1);
}

TLoop loop;
TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
socket.Bind();
socket.Listen();

TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());

NNet::TVoidSuspendedTask h1 = [](TSocket& client, const std::vector<char>& data) -> NNet::TVoidSuspendedTask
{
TSslContext ctx = TSslContext::Client();
auto sslClient = TSslSocket(client, ctx);
co_await sslClient.Connect();
co_await TByteWriter(sslClient).Write(data.data(), data.size());
co_return;
}(client, data);

std::vector<char> received(1024*1024);
NNet::TVoidSuspendedTask h2 = [](TSocket& server, std::vector<char>& received) -> NNet::TVoidSuspendedTask
{
TSslContext ctx = TSslContext::ServerFromMem(testMemCert, testMemKey);
auto client = std::move(co_await server.Accept());
auto sslClient = TSslSocket(client, ctx);
co_await sslClient.Accept();
co_await TByteReader(sslClient).Read(received.data(), received.size());
co_return;
}(socket, received);

while (!(h1.done() && h2.done())) {
loop.Step();
}

assert_memory_equal(data.data(), received.data(), data.size());

h1.destroy(); h2.destroy();
}

#ifdef __linux__

namespace {
Expand Down Expand Up @@ -807,6 +859,7 @@ int main() {
my_unit_poller(test_read_write_full),
my_unit_poller(test_read_write_struct),
my_unit_poller(test_read_write_lines),
my_unit_poller(test_read_write_full_ssl),
#ifdef __linux__
cmocka_unit_test(test_uring_create),
cmocka_unit_test(test_uring_write),
Expand Down

0 comments on commit 64b0dc0

Please sign in to comment.