Skip to content

Commit

Permalink
Name and unpair individual clients (#2042)
Browse files Browse the repository at this point in the history
  • Loading branch information
xanderfrangos authored May 27, 2024
1 parent 287ac4c commit 5fcd07e
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 96 deletions.
61 changes: 59 additions & 2 deletions src/confighttp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ namespace confighttp {
// TODO: Input Validation
pt::read_json(ss, inputTree);
std::string pin = inputTree.get<std::string>("pin");
outputTree.put("status", nvhttp::pin(pin));
std::string name = inputTree.get<std::string>("name");
outputTree.put("status", nvhttp::pin(pin, name));
}
catch (std::exception &e) {
BOOST_LOG(warning) << "SavePin: "sv << e.what();
Expand All @@ -717,6 +718,60 @@ namespace confighttp {
response->write(data.str());
});
nvhttp::erase_all_clients();
proc::proc.terminate();
outputTree.put("status", true);
}

void
unpair(resp_https_t response, req_https_t request) {
if (!authenticate(response, request)) return;

print_req(request);

std::stringstream ss;
ss << request->content.rdbuf();

pt::ptree inputTree, outputTree;

auto g = util::fail_guard([&]() {
std::ostringstream data;
pt::write_json(data, outputTree);
response->write(data.str());
});

try {
// TODO: Input Validation
pt::read_json(ss, inputTree);
std::string uuid = inputTree.get<std::string>("uuid");
outputTree.put("status", nvhttp::unpair_client(uuid));
}
catch (std::exception &e) {
BOOST_LOG(warning) << "Unpair: "sv << e.what();
outputTree.put("status", false);
outputTree.put("error", e.what());
return;
}
}

void
listClients(resp_https_t response, req_https_t request) {
if (!authenticate(response, request)) return;

print_req(request);

pt::ptree named_certs = nvhttp::get_all_clients();

pt::ptree outputTree;

outputTree.put("status", false);

auto g = util::fail_guard([&]() {
std::ostringstream data;
pt::write_json(data, outputTree);
response->write(data.str());
});

outputTree.add_child("named_certs", named_certs);
outputTree.put("status", true);
}

Expand Down Expand Up @@ -765,7 +820,9 @@ namespace confighttp {
server.resource["^/api/restart$"]["POST"] = restart;
server.resource["^/api/password$"]["POST"] = savePassword;
server.resource["^/api/apps/([0-9]+)$"]["DELETE"] = deleteApp;
server.resource["^/api/clients/unpair$"]["POST"] = unpairAll;
server.resource["^/api/clients/unpair-all$"]["POST"] = unpairAll;
server.resource["^/api/clients/list$"]["GET"] = listClients;
server.resource["^/api/clients/unpair$"]["POST"] = unpair;
server.resource["^/api/apps/close$"]["POST"] = closeApp;
server.resource["^/api/covers/upload$"]["POST"] = uploadCover;
server.resource["^/images/sunshine.ico$"]["GET"] = getFaviconImage;
Expand Down
176 changes: 132 additions & 44 deletions src/nvhttp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ namespace nvhttp {
std::string pkey;
} conf_intern;

struct named_cert_t {
std::string name;
std::string uuid;
std::string cert;
};

struct client_t {
std::string uniqueID;
std::vector<std::string> certs;
std::vector<named_cert_t> named_devices;
};

struct pair_session_t {
Expand All @@ -145,7 +151,7 @@ namespace nvhttp {

// uniqueID, session
std::unordered_map<std::string, pair_session_t> map_id_sess;
std::unordered_map<std::string, client_t> map_id_client;
client_t client_root;
std::atomic<uint32_t> session_id_counter;

using args_t = SimpleWeb::CaseInsensitiveMultimap;
Expand Down Expand Up @@ -189,22 +195,18 @@ namespace nvhttp {
root.erase("root"s);

root.put("root.uniqueid", http::unique_id);
auto &nodes = root.add_child("root.devices", pt::ptree {});
for (auto &[_, client] : map_id_client) {
pt::ptree node;

node.put("uniqueid"s, client.uniqueID);

pt::ptree cert_nodes;
for (auto &cert : client.certs) {
pt::ptree cert_node;
cert_node.put_value(cert);
cert_nodes.push_back(std::make_pair(""s, cert_node));
}
node.add_child("certs"s, cert_nodes);

nodes.push_back(std::make_pair(""s, node));
client_t &client = client_root;
pt::ptree node;

pt::ptree named_cert_nodes;
for (auto &named_cert : client.named_devices) {
pt::ptree named_cert_node;
named_cert_node.put("name"s, named_cert.name);
named_cert_node.put("cert"s, named_cert.cert);
named_cert_node.put("uuid"s, named_cert.uuid);
named_cert_nodes.push_back(std::make_pair(""s, named_cert_node));
}
root.add_child("root.named_devices"s, named_cert_nodes);

try {
pt::write_json(config::nvhttp.file_state, root);
Expand All @@ -223,48 +225,79 @@ namespace nvhttp {
return;
}

pt::ptree root;
pt::ptree tree;
try {
pt::read_json(config::nvhttp.file_state, root);
pt::read_json(config::nvhttp.file_state, tree);
}
catch (std::exception &e) {
BOOST_LOG(error) << "Couldn't read "sv << config::nvhttp.file_state << ": "sv << e.what();

return;
}

auto unique_id_p = root.get_optional<std::string>("root.uniqueid");
auto unique_id_p = tree.get_optional<std::string>("root.uniqueid");
if (!unique_id_p) {
// This file doesn't contain moonlight credentials
http::unique_id = uuid_util::uuid_t::generate().string();
return;
}
http::unique_id = std::move(*unique_id_p);

auto device_nodes = root.get_child("root.devices");

for (auto &[_, device_node] : device_nodes) {
auto uniqID = device_node.get<std::string>("uniqueid");
auto &client = map_id_client.emplace(uniqID, client_t {}).first->second;

client.uniqueID = uniqID;
auto root = tree.get_child("root");
client_t client;

// Import from old format
if (root.get_child_optional("devices")) {
auto device_nodes = root.get_child("devices");
for (auto &[_, device_node] : device_nodes) {
auto uniqID = device_node.get<std::string>("uniqueid");

if (device_node.count("certs")) {
for (auto &[_, el] : device_node.get_child("certs")) {
named_cert_t named_cert;
named_cert.name = ""s;
named_cert.cert = el.get_value<std::string>();
named_cert.uuid = uuid_util::uuid_t::generate().string();
client.named_devices.emplace_back(named_cert);
client.certs.emplace_back(named_cert.cert);
}
}
}
}

for (auto &[_, el] : device_node.get_child("certs")) {
client.certs.emplace_back(el.get_value<std::string>());
if (root.count("named_devices")) {
for (auto &[_, el] : root.get_child("named_devices")) {
named_cert_t named_cert;
named_cert.name = el.get_child("name").get_value<std::string>();
named_cert.cert = el.get_child("cert").get_value<std::string>();
named_cert.uuid = el.get_child("uuid").get_value<std::string>();
client.named_devices.emplace_back(named_cert);
client.certs.emplace_back(named_cert.cert);
}
}

// Empty certificate chain and import certs from file
cert_chain.clear();
for (auto &cert : client.certs) {
cert_chain.add(crypto::x509(cert));
}
for (auto &named_cert : client.named_devices) {
cert_chain.add(crypto::x509(named_cert.cert));
}

client_root = client;
}

void
update_id_client(const std::string &uniqueID, std::string &&cert, op_e op) {
switch (op) {
case op_e::ADD: {
auto &client = map_id_client[uniqueID];
client_t &client = client_root;
client.certs.emplace_back(std::move(cert));
client.uniqueID = uniqueID;
} break;
case op_e::REMOVE:
map_id_client.erase(uniqueID);
client_t client;
client_root = client;
break;
}

Expand Down Expand Up @@ -579,15 +612,16 @@ namespace nvhttp {
/**
* @brief Compare the user supplied pin to the Moonlight pin.
* @param pin The user supplied pin.
* @param name The user supplied name.
* @return `true` if the pin is correct, `false` otherwise.
*
* EXAMPLES:
* ```cpp
* bool pin_status = nvhttp::pin("1234");
* bool pin_status = nvhttp::pin("1234", "laptop");
* ```
*/
bool
pin(std::string pin) {
pin(std::string pin, std::string name) {
pt::ptree tree;
if (map_id_sess.empty()) {
return false;
Expand All @@ -613,6 +647,14 @@ namespace nvhttp {
auto &sess = std::begin(map_id_sess)->second;
getservercert(sess, tree, pin);

// set up named cert
client_t &client = client_root;
named_cert_t named_cert;
named_cert.name = name;
named_cert.cert = sess.client.cert;
named_cert.uuid = uuid_util::uuid_t::generate().string();
client.named_devices.emplace_back(named_cert);

// response to the request for pin
std::ostringstream data;
pt::write_xml(data, tree);
Expand Down Expand Up @@ -645,9 +687,7 @@ namespace nvhttp {
auto clientID = args.find("uniqueid"s);

if (clientID != std::end(args)) {
if (auto it = map_id_client.find(clientID->second); it != std::end(map_id_client)) {
pair_status = 1;
}
pair_status = 1;
}
}

Expand Down Expand Up @@ -742,6 +782,20 @@ namespace nvhttp {
response->close_connection_after_response = true;
}

pt::ptree
get_all_clients() {
pt::ptree named_cert_nodes;
client_t &client = client_root;
for (auto &named_cert : client.named_devices) {
pt::ptree named_cert_node;
named_cert_node.put("name"s, named_cert.name);
named_cert_node.put("uuid"s, named_cert.uuid);
named_cert_nodes.push_back(std::make_pair(""s, named_cert_node));
}

return named_cert_nodes;
}

void
applist(resp_https_t response, req_https_t request) {
print_req<SimpleWeb::HTTPS>(request);
Expand Down Expand Up @@ -1020,12 +1074,6 @@ namespace nvhttp {
conf_intern.pkey = file_handler::read_file(config::nvhttp.pkey.c_str());
conf_intern.servercert = file_handler::read_file(config::nvhttp.cert.c_str());

for (auto &[_, client] : map_id_client) {
for (auto &cert : client.certs) {
cert_chain.add(crypto::x509(cert));
}
}

auto add_cert = std::make_shared<safe::queue_t<crypto::x509_t>>(30);

// resume doesn't always get the parameter "localAudioPlayMode"
Expand Down Expand Up @@ -1149,8 +1197,48 @@ namespace nvhttp {
*/
void
erase_all_clients() {
map_id_client.clear();
client_t client;
client_root = client;
cert_chain.clear();
save_state();
}

/**
* @brief Remove single client.
*
* EXAMPLES:
* ```cpp
* nvhttp::unpair_client("4D7BB2DD-5704-A405-B41C-891A022932E1");
* ```
*/
int
unpair_client(std::string uuid) {
int removed = 0;
client_t &client = client_root;
for (auto it = client.named_devices.begin(); it != client.named_devices.end();) {
if ((*it).uuid == uuid) {
// Find matching cert and remove it
for (auto cert = client.certs.begin(); cert != client.certs.end();) {
if ((*cert) == (*it).cert) {
cert = client.certs.erase(cert);
removed++;
}
else {
++cert;
}
}

// And then remove the named cert
it = client.named_devices.erase(it);
removed++;
}
else {
++it;
}
}

save_state();
load_state();
return removed;
}
} // namespace nvhttp
9 changes: 8 additions & 1 deletion src/nvhttp.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// standard includes
#include <string>

// lib includes
#include <boost/property_tree/ptree.hpp>

// local includes
#include "thread_safe.h"

Expand Down Expand Up @@ -43,7 +46,11 @@ namespace nvhttp {
void
start();
bool
pin(std::string pin);
pin(std::string pin, std::string name);
int
unpair_client(std::string uniqueid);
boost::property_tree::ptree
get_all_clients();
void
erase_all_clients();
} // namespace nvhttp
Loading

0 comments on commit 5fcd07e

Please sign in to comment.