diff --git a/src/epiworld-common.hpp b/src/epiworld-common.hpp index 84da898..45bf22b 100644 --- a/src/epiworld-common.hpp +++ b/src/epiworld-common.hpp @@ -1,87 +1,273 @@ #ifndef DEFM_COMMON_H #define DEFM_COMMON_H -using namespace pybind11::literals; +#include +#include -inline void pyprinter(const char * fmt, ...) -{ +#ifdef _WIN32 +#include +#include +#define unlink _unlink +#else +#include +#include +#endif + +#if !defined(_WIN32) && (defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) +#define EPIWORLD_PLATFORM_UNIX +#elif defined(__CYGWIN__) && !defined(_WIN32) +#define EPIWORLD_PLATFORM_UNIXISH +#elif defined(_WIN32) || defined(_WIN64) +#define EPIWORLD_PLATFORM_WINDOWS +#else +#define EPIWORLD_PLATFORM_UNKNOWN +#endif + +#ifdef EPIWORLD_PLATFORM_WINDOWS +const char EPIWORLD_OS_PATHSEP = '\\'; +#else +const char EPIWORLD_OS_PATHSEP = '/'; +#endif - // Creating a buffer +namespace epiworld { +inline void pyprinter(const char * fmt, ...) { char buffer[1024]; va_list args; va_start(args, fmt); - vsprintf(&buffer[0], fmt, args); + vsnprintf(buffer, sizeof(buffer), fmt, args); va_end(args); - // Passing to pyprint pybind11::print(std::string(buffer), pybind11::arg("end") = ""); +} +} + +#define printf_epiworld epiworld::pyprinter +#include "epiworld.hpp" + +namespace epiworld { +class Saver { +private: + std::function*)> fun; + std::vector what; + std::string fn; + std::string id; + bool file_output; + +public: + Saver( + std::vector what, + std::string fn, + std::string id, + bool file_output); + + void unlink_siblings() const; + const std::ostream& out(std::ostream &stream) const; + std::function*)> operator*(); +}; + +inline std::ostream& operator<<(std::ostream &stream, const Saver& data) { + data.out(stream); + return stream; } -#define printf_epiworld pyprinter +static std::string parse_kwarg_string(const pybind11::kwargs& kwargs, const char* key, const std::string& _default) { + PyObject* item = PyDict_GetItemString(kwargs.ptr(), key); -#include "epiworld.hpp" -// #include "models/defm.hpp" - -// inline void check_covar( -// int & idx_, -// std::string & idx, -// std::shared_ptr< defm::DEFM > & ptr -// ) { - -// // Retrieving the matching covariate -// if (idx != "") -// { - -// // Getting the covariate names -// auto cnames = ptr->get_X_names(); - -// // Can we find it? -// for (size_t i = 0u; i < cnames.size(); ++i) { -// if (cnames[i] == idx) -// { -// idx_ = i; -// break; -// } -// } - -// if (idx_ < 0) -// throw std::range_error( -// "The variable " + idx + "does not exists." -// ); - -// } - -// } - -// #define DEFM_DEFINE_ACCESS(object) \ -// std::function element_access; \ -// if ((object)->get_column_major()) \ -// { \ -// element_access = [](size_t i, size_t j, size_t nrow, size_t) -> size_t { \ -// return i + j * nrow; \ -// }; \ -// } else { \ -// element_access = [](size_t i, size_t j, size_t, size_t ncol) -> size_t { \ -// return j + i * ncol; \ -// }; \ -// } - - -// /** -// * @brief Create a numpy array from a pointer -// * @param res The numpy array -// * @param ptr The pointer -// * @param nrows The number of rows -// * @param ncols The number of columns -// * @param type_ The type of the array -// */ -// #define DEFM_WRAP_NUMPY(var_res, var_ptr, nrows, ncols, type_) \ -// py::array_t< type_ > var_res ({nrows, ncols}); \ -// auto res_buff = var_res .request(); \ -// type_ * var_ptr = static_cast< type_ * >(res_buff.ptr); + if (item != nullptr) { + return std::string(PyBytes_AS_STRING(PyUnicode_AsEncodedString(PyObject_Str(item), "utf-8", "?"))); + } + + return _default; +} + +static int parse_kwarg_int(const pybind11::kwargs& kwargs, const char* key, int _default) { + PyObject* item = PyDict_GetItemString(kwargs.ptr(), key); + + if (item != nullptr) { + return PyLong_AsLong(item); + } + + return _default; +} + +static bool parse_kwarg_bool(const pybind11::kwargs& kwargs, const char* key, bool _default) { + PyObject* item = PyDict_GetItemString(kwargs.ptr(), key); + + if (item != nullptr) { + return item == Py_True; + } + + return _default; +} + +static std::string dirname(const std::string& filepath) { + struct stat s; + +#if EPIWORLD_PLATFORM_WINDOWS + if (_stat(filepath.c_str(), &s) == 0) { +#else + if (stat(filepath.c_str(), &s) == 0) { +#endif + if (s.st_mode & S_IFREG) { + goto treat_as_file; + } else { + return filepath; + } + } else { + goto treat_as_file; + } + +treat_as_file: + std::string directory; + const size_t last_slash_idx = filepath.rfind(EPIWORLD_OS_PATHSEP); + + if (std::string::npos != last_slash_idx) { + directory = filepath.substr(0, last_slash_idx); + } + + return directory; +} + +static std::vector get_files_in_dir(const std::string& directory) { + std::vector found; + +#ifdef _WIN32 + WIN32_FIND_DATA find_file_data; + HANDLE hFind = FindFirstFile((directory + "\\*").c_str(), &find_file_data); + if (hFind == INVALID_HANDLE_VALUE) { + std::cerr << "Error opening directory: " << GetLastError() << std::endl; + return found; + } + + do { + std::string file_name = find_file_data.cFileName; + if (!(find_file_data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { + found.push_back(directory + "\\" + file_name); + } + } while (FindNextFile(hFind, &find_file_data) != 0); + + FindClose(hFind); +#else + struct dirent* entry; + DIR* dir = opendir(directory.c_str()); + + if (dir == nullptr) { + throw std::runtime_error(directory + ": " + strerror(errno)); + } + + while ((entry = readdir(dir)) != nullptr) { + std::string file_name = entry->d_name; + if (entry->d_type != DT_DIR) { + found.push_back(directory + "/" + file_name); + } + } + + closedir(dir); +#endif + + return found; +} + +inline std::string temp_id(size_t len) { + const char alphanum[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, sizeof(alphanum) - 2); + std::string id; + + id.reserve(len); + + for (size_t i = 0; i < len; i++) { + id += alphanum[dis(gen)]; + } + + return id; +} + +inline std::string temp_directory_path() { + std::string env_to_check[] {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}; + + for (auto env : env_to_check) { + char const *result = getenv(env.c_str()); + + if (result != nullptr) { + return result; + } + } + + /* Otherwise, default to a value hardcoded per platform. */ +#if defined(EPIWORLD_PLATFORM_UNIX) || defined(EPIWORLD_PLATFORM_UNIXISH) + return "/tmp/"; +#elif defined(EPIWORLD_PLATFORM_WINDOWS) + /* I can't see us ever getting here, Windows isn't at heterogeneous as UNIX. */ + throw std::runtime_error("TEMP not defined on Windows, are you nuts!?"); +#elif EPIWORLD_PLATFORM_UNKNOWN + return ""; /* Current directory. */ +#endif +} + +inline Saver::Saver( + std::vector what, + std::string fn, + std::string id, + bool file_output) : + fun(epiworld::make_save_run( + fn, + std::find(what.begin(), what.end(), "total_hist") != what.end(), + std::find(what.begin(), what.end(), "virus_info") != what.end(), + std::find(what.begin(), what.end(), "virus_hist") != what.end(), + std::find(what.begin(), what.end(), "tool_info") != what.end(), + std::find(what.begin(), what.end(), "tool_hist") != what.end(), + std::find(what.begin(), what.end(), "transmission") != what.end(), + std::find(what.begin(), what.end(), "transition") != what.end(), + std::find(what.begin(), what.end(), "reproductive") != what.end(), + std::find(what.begin(), what.end(), "generation") != what.end() + )), + what(what), + fn(fn), + id(id), + file_output(file_output) {} + +inline void Saver::unlink_siblings() const { + auto dir = dirname(fn); + auto contestants = get_files_in_dir(dir); + + for (auto contestant : contestants) { + if (unlink(contestant.c_str()) != 0 && errno != ENOENT) { + throw std::runtime_error("Failed to remove file " + contestant + ": " + strerror(errno)); + } + } + } + +inline const std::ostream& Saver::out(std::ostream &stream) const { + stream << "A saver for -run_multiple-" << std::endl; + stream << "Saves the following: "; + + for (const auto whatum : what) { + stream << whatum; + + if (what.back() != whatum) { + stream << ", "; + } else { + stream << std::endl; + } + } + + stream << "To file : " << (file_output ? "yes" : "no") << std::endl; + if (file_output) { + stream << "Saver pattern : " << fn << std::endl; + } + + return stream; +} + +inline std::function*)> Saver::operator*() { + return fun; +} +} #endif diff --git a/src/epiworldpy/__init__.py b/src/epiworldpy/__init__.py index 5d15030..fcf624f 100644 --- a/src/epiworldpy/__init__.py +++ b/src/epiworldpy/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from ._core import __doc__, __version__, ModelSEIR +from ._core import __doc__, __version__, ModelSEIR, Saver -__all__ = ["__doc__", "__version__", "ModelSEIR"] +__all__ = ["__doc__", "__version__", "ModelSEIR", "Saver"] diff --git a/src/main.cpp b/src/main.cpp index ccc3282..f257534 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,11 +1,16 @@ +#include +#include +#include #include #include #include #include #include // for py::array_t #include "epiworld-common.hpp" +#include "model-bones.hpp" #include "misc.hpp" + using namespace epiworld; using namespace pybind11::literals; @@ -14,7 +19,6 @@ namespace py = pybind11; #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) - std::shared_ptr< epimodels::ModelSEIRCONN > ModelSEIR( std::string name, int n, @@ -40,7 +44,6 @@ std::shared_ptr< epimodels::ModelSEIRCONN > ModelSEIR( return object; } - PYBIND11_MODULE(_core, m) { m.doc() = R"pbdoc( epiworldpy wrapper @@ -54,11 +57,92 @@ PYBIND11_MODULE(_core, m) { ModelSEIR )pbdoc"; - // Only this is necessary to expose the class py::class_, std::shared_ptr>>(m, "Model") .def("get_name", &Model::get_name); - // Only this is necessary to expose the class + py::class_(m, "Saver") + .def(py::init([](py::args args, const py::kwargs& kwargs) { + /* TODO: Verify that this has the same effect as `make_saver` in: + * https://github.com/UofUEpiBio/epiworldR/blob/main/R/make_saver.R + */ + + bool file_output = true; + struct stat sinfo; + std::string fn = parse_kwarg_string(kwargs, "fn", ""); + std::string id = temp_id(5); + std::vector whats; + std::vector valid_whats = { + "total_hist", + "virus_info", + "virus_hist", + "tool_info", + "tool_hist", + "transmission", + "transition", + "reproductive", + "generation" + }; + + /* Make sure valid arguments are passed into this constructor, and marshall + * things out all the same. */ + for (auto arg : args) { + std::string whatum = arg.cast(); + + if (std::find(valid_whats.begin(), valid_whats.end(), whatum) == valid_whats.end()) { + throw std::invalid_argument("What '" + whatum + "' is not supported."); + } + + whats.push_back(whatum); + } + + /* Handle the filename. If only we have C++17's std::filesystem... */ + if (fn.empty()) { + int error = 0; + std::string norm = temp_directory_path() + EPIWORLD_OS_PATHSEP + "epiworldpy-" + id; + +#ifdef EPIWORLD_PLATFORM_WINDOWS + error = _mkdir(norm.c_str()); +#else + error = mkdir(norm.c_str(), 0733); +#endif + + if (error != 0) { + throw std::runtime_error(strerror(error)); + } + + fn = norm + EPIWORLD_OS_PATHSEP + "%05lu-episimulation.csv"; + file_output = false; +#if EPIWORLD_PLATFORM_WINDOWS + } else if (_stat(fn.c_str(), &sinfo) != 0) { +#else + } else if (stat(fn.c_str(), &sinfo) != 0) { +#endif + throw std::runtime_error("The directory \"" + fn + "\" does not exist."); + } + + return epiworld::Saver(whats, fn, id, file_output); + })) + .def("run_multiple", []( + Saver &self, + Model &model, + int ndays, + int nsims, + const py::kwargs& kwargs) { + int seed = parse_kwarg_int(kwargs, "seed", std::time(0)); + int nthreads = parse_kwarg_int(kwargs, "nthreads", 1); + bool reset = parse_kwarg_bool(kwargs, "reset", true); + bool verbose = parse_kwarg_bool(kwargs, "verbose", true); + + /* Do we have previously saved files? */ + self.unlink_siblings(); + + /* Dispatch! */ + model.run_multiple(ndays, nsims, seed, *self, reset, verbose, nthreads); + + /* EpiworldR does this so we do too. */ + return model; + }); + py::class_, std::shared_ptr>>(m, "DataBase") .def("get_hist_total", [](DataBase &self) { /* Lo, one of the times in modern C++ where the 'new' keyword isn't out of place. */ @@ -172,7 +256,7 @@ PYBIND11_MODULE(_core, m) { }); // Only this is necessary to expose the class - py::class_, std::shared_ptr>>(m, "ModelSEIRCONN") + py::class_, std::shared_ptr>, Model>(m, "ModelSEIRCONN") // .def(py::init<>()) .def("print", &epimodels::ModelSEIRCONN::print, // py::call_guard(), @@ -190,7 +274,6 @@ PYBIND11_MODULE(_core, m) { py::arg("seed") = 1u ) .def("get_db", [](epimodels::ModelSEIRCONN &self) { - //std::cout << "!!! " << self.get_db().get_model()->get_name() << std::endl; return std::shared_ptr>(&self.get_db(), [](DataBase*){ /* do nothing, no delete */ }); }, py::return_value_policy::reference); diff --git a/tests/test_saver.py b/tests/test_saver.py new file mode 100644 index 0000000..e505cf8 --- /dev/null +++ b/tests/test_saver.py @@ -0,0 +1,18 @@ +import epiworldpy as epiworld + +def test_saver_basic(): + covid19 = epiworld.ModelSEIR( + name = 'covid-19', + n = 10000, + prevalence = .01, + contact_rate = 2.0, + transmission_rate = .1, + incubation_days = 7.0, + recovery_rate = 0.14 + ) + + saver = epiworld.Saver("total_hist", "virus_hist") + + saver.run_multiple(covid19, 100, 4, nthreads=1) + + # TODO: Verify things worked correctly, as is the point of tesing. \ No newline at end of file