diff --git a/deeplima/apps/deeplima.cpp b/deeplima/apps/deeplima.cpp index 720185690..0ee5a57db 100644 --- a/deeplima/apps/deeplima.cpp +++ b/deeplima/apps/deeplima.cpp @@ -24,177 +24,17 @@ namespace po = boost::program_options; // #define DP_BUFFER_SIZE (16384) #define DP_BUFFER_SIZE (128) -void init(const std::map& models_fn, - const deeplima::PathResolver& path_resolver, - size_t threads, - size_t out_fmt=1, - bool tag_use_mp=true); -void parse_file(std::istream& input, - const std::map& models_fn, - const deeplima::PathResolver& path_resolver, - size_t threads, - size_t out_fmt=1, - bool tag_use_mp=true); - -int main(int argc, char* argv[]) -{ - setlocale(LC_ALL, "en_US.UTF-8"); - // std::cerr << "deeplima (git commit hash: " << deeplima::version::get_git_commit_hash() << ", " - // << "git branch: " << deeplima::version::get_git_branch() - // << ")" << std::endl; - - size_t threads = 1; - std::string input_format, output_format, tok_model, tag_model, lem_model, dp_model; - std::string lem_dict; - std::string fixed_ini; - std::string lower_ini; - std::string fixed_lemm; - bool tag_use_mp; - std::vector input_files; - - po::options_description desc("deeplima (analysis demo)"); - desc.add_options() - ("help,h", "Display this help message") - ("input-format", po::value(&input_format)->default_value("plain"), "Input format: plain|conllu") - ("output-format", po::value(&output_format)->default_value("conllu"), "Output format: conllu|vertical|horizontal") - ("tok-model", po::value(&tok_model)->default_value(""), "Tokenization model") - ("tag-model", po::value(&tag_model)->default_value(""), "Tagging model") - ("lem-model", po::value(&lem_model)->default_value(""), "Lemmatization model") - ("dp-model", po::value(&dp_model)->default_value(""), "Dependency parsing model") - ("lem-dict", po::value(&lem_dict)->default_value(""), "Lemmatization dictionary") - ("fixed-ini", po::value(&fixed_ini)->default_value(""), "List of upos wiht fixed lemmas") - ("lower-ini", 
po::value(&lower_ini)->default_value(""), "List of upos wih lowercased lemmas") - ("fixed-lemm", po::value(&fixed_lemm)->default_value(""), "List of upos wih lowercased lemmas") - ("input-file", po::value>(&input_files), "Input file names") - ("threads", po::value(&threads), "Max threads to use") - ("tag-use-mp", po::value(&tag_use_mp)->default_value(true), "Use mixed-precision calculations in tagger") - ; - - po::positional_options_description pos_desc; - pos_desc.add("input-file", -1); - - po::variables_map vm; - - try - { - po::store(po::command_line_parser(argc, argv).options(desc).positional(pos_desc).run(), vm); - po::notify(vm); - } - catch (const boost::program_options::unknown_option& e) - { - std::cerr << e.what() << std::endl; - return -1; - } - - if (vm.count("help")) { - std::cout << desc << std::endl; - return 0; - } - - if (tok_model.empty() && tag_model.empty()) - { - std::cerr << "No model is provided: --tok-model or --tag-model parameters are required." << std::endl << std::endl; - std::cout << desc << std::endl; - return -1; - } - - std::map models; - - if (tok_model.size() > 0) - { - models["tok"] = tok_model; - } - - if (tag_model.size() > 0) - { - models["tag"] = tag_model; - } - - if (lem_model.size() > 0) - { - models["lem"] = lem_model; - } - - if (dp_model.size() > 0) - { - models["dp"] = dp_model; - } - - if (lem_dict.size() > 0) - { - models["lem_dict"] = lem_dict; - } - - if (fixed_ini.size() > 0) - { - models["fixed_ini"] = fixed_ini; - } - - if (lower_ini.size() > 0) - { - models["lower_ini"] = lower_ini; - } - - if (fixed_lemm.size() > 0) - { - models["fixed_lemm"] = fixed_lemm; - } - - size_t out_fmt = 1; - if (output_format.size() > 0) - { - if (output_format == "horizontal") - { - out_fmt = 2; - } - } - - deeplima::PathResolver path_resolver; - - init(models, path_resolver, threads, out_fmt, tag_use_mp); - - if (vm.count("input-file") > 0) - { - - char read_buffer[READ_BUFFER_SIZE]; - for ( const auto& fn : input_files ) - { - 
std::cerr << "Reading file: " << fn << std::endl; - std::ifstream file(fn, std::ifstream::binary | std::ios::in); - if (!file.is_open()) - { - std::cerr << "Failed to open file: " << fn << std::endl; - throw std::runtime_error(std::string("Failed to open file "+fn)); - } - file.rdbuf()->pubsetbuf(read_buffer, READ_BUFFER_SIZE); - try - { - parse_file(file, models, path_resolver, threads, out_fmt, tag_use_mp); - } - catch (const std::runtime_error& e) - { - std::cerr << "Analysis failure: Exception while analyzing file " << fn << ":" << std::endl - << e.what() << std::endl; - return 1; - } - } - } - else - { - try - { - parse_file(std::cin, models, path_resolver, threads, out_fmt, tag_use_mp); - } - catch (const std::runtime_error& e) - { - std::cerr << "Analysis failure: Exception while analyzing text from stdin:" << std::endl - << e.what() << std::endl; - return 1; - } - } - - return 0; -} +// void init(const std::map& models_fn, +// const deeplima::PathResolver& path_resolver, +// size_t threads, +// size_t out_fmt=1, +// bool tag_use_mp=true); +// void parse_file(std::istream& input, +// const std::map& models_fn, +// const deeplima::PathResolver& path_resolver, +// size_t threads, +// size_t out_fmt=1, +// bool tag_use_mp=true); #include "deeplima/segmentation/impl/segmentation_impl.h" #include "deeplima/ner.h" @@ -205,12 +45,29 @@ int main(int argc, char* argv[]) using namespace deeplima; +class file_parser +{ +public: std::shared_ptr psegm = nullptr; std::shared_ptr< ITokenSequenceAnalyzer > panalyzer = nullptr; std::shared_ptr< dumper::AbstractDumper > pdumper = nullptr; std::shared_ptr< dumper::DumperBase > pDumperBase = nullptr; std::shared_ptr parser = nullptr; +/** + * @param models_fn models file names, a map associating a kind of model to its file path + * keys are: + * - tok: the tokenizer model (can be absent if loading a pre-tokenized text) + * - tag: the tagger model (can be absent if doing tokenization only) + * - lem: the lemmatizer model (can be 
absent) + * - lem_dict: the lemmatizer dictionary (can be absent) + * - fixed_ini: + * - lower_ini: + * - fixed_lemm: + * - dp: the dependency parser model (can be absent) + * @param tag_use_mp if true, use mixed precision (int16) model, otherwise, use + * full precision (float) + */ void init(const std::map& models_fn, const PathResolver& path_resolver, size_t threads, @@ -219,25 +76,26 @@ void init(const std::map& models_fn, { // std::cerr << "deeplima parse_file threads = " << threads << std::endl; - if (models_fn.end() != models_fn.find("tok")) - { - psegm = std::make_shared(); - try - { - std::dynamic_pointer_cast(psegm)->load(models_fn.find("tok")->second); - } - catch (std::runtime_error& e) - { - std::cerr << "In parse_file: failed to load model file " << models_fn.find("tok")->second << ": " - << e.what() << std::endl; - throw; - } - std::dynamic_pointer_cast(psegm)->init(threads, 16*1024); - } - else - { - psegm = std::make_shared(); - } + // if (models_fn.end() != models_fn.find("tok")) + // { + // psegm = std::make_shared(); + // try + // { + // std::dynamic_pointer_cast(psegm)->load(models_fn.find("tok")->second); + // } + // catch (std::runtime_error& e) + // { + // std::cerr << "In parse_file: failed to load model file " << models_fn.find("tok")->second << ": " + // << e.what() << std::endl; + // throw; + // } + // std::dynamic_pointer_cast(psegm)->init(threads, 16*1024); + // } + // else + // { + // // Not a real segmenter but a CoNLLU reader + // psegm = std::make_shared(); + // } if (models_fn.end() != models_fn.find("tag")) { @@ -319,7 +177,7 @@ void init(const std::map& models_fn, conllu_dumper->set_classes(i, panalyzer->get_class_names()[i], panalyzer->get_classes()[i]); } - panalyzer->register_handler([&parser](std::shared_ptr< StringIndex > stridx, + panalyzer->register_handler([this](std::shared_ptr< StringIndex > stridx, const token_buffer_t<>& tokens, const std::vector& lemmata, std::shared_ptr< StdMatrix > classes, @@ -382,13 +240,13 @@ 
void init(const std::map& models_fn, }); } - psegm->register_handler([panalyzer] - (const std::vector& tokens, - uint32_t len) - { - // std::cerr << "In psegm handler. Calling panalyzer functor" << std::endl; - (*panalyzer)(tokens, len); - }); + // psegm->register_handler([panalyzer] + // (const std::vector& tokens, + // uint32_t len) + // { + // // std::cerr << "In psegm handler. Calling panalyzer functor" << std::endl; + // (*panalyzer)(tokens, len); + // }); } else { @@ -403,13 +261,13 @@ void init(const std::map& models_fn, throw std::runtime_error("Unknown output format"); break; } - psegm->register_handler([pdumper] - (const std::vector& tokens, - uint32_t len) - { - // std::cerr << "In psegm handler. Calling pdumper functor" << std::endl; - (*pdumper)(tokens, len); - }); + // psegm->register_handler([pdumper] + // (const std::vector& tokens, + // uint32_t len) + // { + // // std::cerr << "In psegm handler. Calling pdumper functor" << std::endl; + // (*pdumper)(tokens, len); + // }); } } @@ -421,6 +279,47 @@ void parse_file(std::istream& input, size_t out_fmt, bool tag_use_mp) { + if (models_fn.end() != models_fn.find("tok")) + { + psegm = std::make_shared(); + try + { + std::dynamic_pointer_cast(psegm)->load(models_fn.find("tok")->second); + } + catch (std::runtime_error& e) + { + std::cerr << "In parse_file: failed to load model file " << models_fn.find("tok")->second << ": " + << e.what() << std::endl; + throw; + } + std::dynamic_pointer_cast(psegm)->init(threads, 16*1024); + } + else + { + // Not a real segmenter but a CoNLLU reader + psegm = std::make_shared(); + } + if (models_fn.end() != models_fn.find("tag")) + { + psegm->register_handler([this] + (const std::vector& tokens, + uint32_t len) + { + // std::cerr << "In psegm handler. Calling panalyzer functor" << std::endl; + (*panalyzer)(tokens, len); + }); + } + else + { + psegm->register_handler([this] + (const std::vector& tokens, + uint32_t len) + { + // std::cerr << "In psegm handler. 
Calling pdumper functor" << std::endl; + (*pdumper)(tokens, len); + }); + + } auto parsing_begin = std::chrono::high_resolution_clock::now(); std::dynamic_pointer_cast(psegm)->reset(); psegm->parse_from_stream([&input] @@ -478,4 +377,167 @@ void parse_file(std::istream& input, psegm->finalize(); } +}; + + +int main(int argc, char* argv[]) +{ + setlocale(LC_ALL, "en_US.UTF-8"); + // std::cerr << "deeplima (git commit hash: " << deeplima::version::get_git_commit_hash() << ", " + // << "git branch: " << deeplima::version::get_git_branch() + // << ")" << std::endl; + + size_t threads = 1; + std::string input_format, output_format, tok_model, tag_model, lem_model, dp_model; + std::string lem_dict; + std::string fixed_ini; + std::string lower_ini; + std::string fixed_lemm; + bool tag_use_mp; + std::vector input_files; + + po::options_description desc("deeplima (analysis demo)"); + desc.add_options() + ("help,h", "Display this help message") + ("input-format", po::value(&input_format)->default_value("plain"), "Input format: plain|conllu") + ("output-format", po::value(&output_format)->default_value("conllu"), "Output format: conllu|vertical|horizontal") + ("tok-model", po::value(&tok_model)->default_value(""), "Tokenization model") + ("tag-model", po::value(&tag_model)->default_value(""), "Tagging model") + ("lem-model", po::value(&lem_model)->default_value(""), "Lemmatization model") + ("dp-model", po::value(&dp_model)->default_value(""), "Dependency parsing model") + ("lem-dict", po::value(&lem_dict)->default_value(""), "Lemmatization dictionary") + ("fixed-ini", po::value(&fixed_ini)->default_value(""), "List of upos wiht fixed lemmas") + ("lower-ini", po::value(&lower_ini)->default_value(""), "List of upos wih lowercased lemmas") + ("fixed-lemm", po::value(&fixed_lemm)->default_value(""), "List of upos wih lowercased lemmas") + ("input-file", po::value>(&input_files), "Input file names") + ("threads", po::value(&threads), "Max threads to use") + ("tag-use-mp", 
po::value(&tag_use_mp)->default_value(true), "Use mixed-precision calculations in tagger") + ; + + po::positional_options_description pos_desc; + pos_desc.add("input-file", -1); + + po::variables_map vm; + + try + { + po::store(po::command_line_parser(argc, argv).options(desc).positional(pos_desc).run(), vm); + po::notify(vm); + } + catch (const boost::program_options::unknown_option& e) + { + std::cerr << e.what() << std::endl; + return -1; + } + + if (vm.count("help")) { + std::cout << desc << std::endl; + return 0; + } + + if (tok_model.empty() && tag_model.empty()) + { + std::cerr << "No model is provided: --tok-model or --tag-model parameters are required." << std::endl << std::endl; + std::cout << desc << std::endl; + return -1; + } + + std::map models; + + if (tok_model.size() > 0) + { + models["tok"] = tok_model; + } + + if (tag_model.size() > 0) + { + models["tag"] = tag_model; + } + + if (lem_model.size() > 0) + { + models["lem"] = lem_model; + } + + if (dp_model.size() > 0) + { + models["dp"] = dp_model; + } + + if (lem_dict.size() > 0) + { + models["lem_dict"] = lem_dict; + } + + if (fixed_ini.size() > 0) + { + models["fixed_ini"] = fixed_ini; + } + + if (lower_ini.size() > 0) + { + models["lower_ini"] = lower_ini; + } + + if (fixed_lemm.size() > 0) + { + models["fixed_lemm"] = fixed_lemm; + } + + size_t out_fmt = 1; + if (output_format.size() > 0) + { + if (output_format == "horizontal") + { + out_fmt = 2; + } + } + + deeplima::PathResolver path_resolver; + + file_parser fp; + fp.init(models, path_resolver, threads, out_fmt, tag_use_mp); + + if (vm.count("input-file") > 0) + { + + char read_buffer[READ_BUFFER_SIZE]; + for ( const auto& fn : input_files ) + { + std::cerr << "Reading file: " << fn << std::endl; + std::ifstream file(fn, std::ifstream::binary | std::ios::in); + if (!file.is_open()) + { + std::cerr << "Failed to open file: " << fn << std::endl; + throw std::runtime_error(std::string("Failed to open file "+fn)); + } + 
file.rdbuf()->pubsetbuf(read_buffer, READ_BUFFER_SIZE); + try + { + fp.parse_file(file, models, path_resolver, threads, out_fmt, tag_use_mp); + } + catch (const std::runtime_error& e) + { + std::cerr << "Analysis failure: Exception while analyzing file " << fn << ":" << std::endl + << e.what() << std::endl; + return 1; + } + } + } + else + { + try + { + fp.parse_file(std::cin, models, path_resolver, threads, out_fmt, tag_use_mp); + } + catch (const std::runtime_error& e) + { + std::cerr << "Analysis failure: Exception while analyzing text from stdin:" << std::endl + << e.what() << std::endl; + return 1; + } + } + + return 0; +}