Skip to content

Commit

Permalink
Correctly reset deeplima between files
Browse files Browse the repository at this point in the history
No more crash when analyzing several files
  • Loading branch information
kleag committed May 15, 2024
1 parent a7786ff commit f8554d0
Show file tree
Hide file tree
Showing 21 changed files with 367 additions and 248 deletions.
15 changes: 11 additions & 4 deletions deeplima/apps/deeplima.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ using namespace deeplima;
class file_parser
{
public:
std::shared_ptr<segmentation::ISegmentation> psegm = nullptr;
std::shared_ptr< ITokenSequenceAnalyzer > panalyzer = nullptr;
std::shared_ptr<segmentation::ISegmentation> psegm = nullptr; // tokenizer
std::shared_ptr< ITokenSequenceAnalyzer > panalyzer = nullptr; // tagger
std::shared_ptr< dumper::AbstractDumper > pdumper_segm_only = nullptr; // used when using segmentation only
std::shared_ptr< dumper::DumperBase > pdumper_complete = nullptr; // used when using tagger
std::shared_ptr<DependencyParser> parser = nullptr;
Expand Down Expand Up @@ -241,6 +241,8 @@ void init(const std::map<std::string, std::string>& models_fn,
});
}

// NOTE: Commented out because psegm is now instantiated anew for each file in
// parse_file. This is a temporary workaround until reusing it across files works.
// psegm->register_handler([panalyzer]
// (const std::vector<segmentation::token_pos>& tokens,
// uint32_t len)
Expand Down Expand Up @@ -289,7 +291,8 @@ void parse_file(std::istream& input,
}
catch (std::runtime_error& e)
{
std::cerr << "In parse_file: failed to load model file " << models_fn.find("tok")->second << ": "
std::cerr << "In parse_file: failed to load model file "
<< models_fn.find("tok")->second << ": "
<< e.what() << std::endl;
throw;
}
Expand Down Expand Up @@ -340,7 +343,7 @@ void parse_file(std::istream& input,
// std::cerr << "Waiting for PoS tagger to stop. Calling panalyzer->finalize" << std::endl;
panalyzer->finalize();
pdumper_complete->flush();
std::cerr << "Analyzer stopped. panalyzer->finalize returned" << std::endl;
// std::cerr << "Analyzer stopped. panalyzer->finalize returned" << std::endl;
}

if (parser)
Expand All @@ -356,10 +359,14 @@ void parse_file(std::istream& input,

uint64_t token_counter = 0;
if(nullptr != pdumper_segm_only)
{
token_counter = pdumper_segm_only->get_token_counter();
pdumper_segm_only->reset();
}
else if (nullptr != pdumper_complete)
{
token_counter = pdumper_complete->get_token_counter();
pdumper_complete->reset();
}
else
{
Expand Down
45 changes: 24 additions & 21 deletions deeplima/include/deeplima/dependency_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "utils/str_index.h"
#include "helpers/path_resolver.h"
#include "deeplima/graph_dp.h"
#include "deeplima/token_type.h"
// #include "graph_dp/impl/graph_dp_impl.h"
#include "segmentation/impl/segmentation_decoder.h"
#include "token_sequence_analyzer.h"
Expand Down Expand Up @@ -76,7 +77,7 @@ class DependencyParser
m_ptoken(nullptr)
{ }

inline typename tokens_with_analysis_t::token_t::token_flags_t flags() const
inline token_flags_t flags() const
{
assert(nullptr != m_ptoken);
return m_ptoken->m_flags;
Expand All @@ -85,7 +86,7 @@ class DependencyParser
inline bool eos() const
{
assert(nullptr != m_ptoken);
return flags() & DependencyParser::tokens_with_analysis_t::token_t::token_flags_t::sentence_brk;
return flags() & token_flags_t::sentence_brk;
}

inline uint32_t cls(size_t idx) const
Expand Down Expand Up @@ -151,7 +152,7 @@ class DependencyParser
return m_current >= m_end;
}

inline impl::token_t::token_flags_t flags() const
inline token_flags_t flags() const
{
assert(! end());
return m_buffer[m_current].m_flags;
Expand Down Expand Up @@ -235,8 +236,9 @@ class DependencyParser
: m_buffer_size(buffer_size),
m_current_buffer(0),
m_current_timepoint(0),
m_stridx_ptr(stridx)//,
// m_stridx(*stridx)
m_stridx_ptr(stridx),
// m_stridx(*stridx),
m_impl()
{
assert(m_buffer_size > 0);
assert(num_buffers > 0);
Expand Down Expand Up @@ -307,9 +309,10 @@ class DependencyParser
}
}

// Apply the model to the sequence of tokens given by iter from the tagger
void operator()(TokenSequenceAnalyzer<>::TokenIterator& iter)
{
// std::cerr << "DependencyParser::operator()" << std::endl;
// std::cerr << "DependencyParser::operator(TokenSequenceAnalyzer<>::TokenIterator& iter)" << std::endl;
if (m_current_timepoint >= m_buffer_size)
{
acquire_buffer();
Expand All @@ -334,7 +337,7 @@ class DependencyParser
token.m_len = 0;
token.m_form_idx = m_stridx_ptr->get_idx("<ROOT>");
// std::cerr << "<ROOT>" << std::endl;
token.m_flags = impl::token_t::token_flags_t(segmentation::token_pos::flag_t::none);
token.m_flags = token_flags_t::none;
token.m_lemm_idx = token.m_form_idx;
insert_root = false;
tokens_to_process--;
Expand All @@ -360,8 +363,8 @@ class DependencyParser
token.m_classes[i] = iter.token_class(i);
}

if (iter.flags() & segmentation::token_pos::flag_t::sentence_brk ||
iter.flags() & segmentation::token_pos::flag_t::paragraph_brk)
if (iter.flags() & token_flags_t::sentence_brk ||
iter.flags() & token_flags_t::paragraph_brk)
{
insert_root = true;
}
Expand Down Expand Up @@ -480,8 +483,8 @@ class DependencyParser
// << "; m_buffer_size=" << m_buffer_size
// << "; token=" << iter.form() << std::endl;

if (iter.flags() & segmentation::token_pos::flag_t::sentence_brk ||
iter.flags() & segmentation::token_pos::flag_t::paragraph_brk)
if (iter.flags() & token_flags_t::sentence_brk ||
iter.flags() & token_flags_t::paragraph_brk)
{
break;
// lengths.push_back(this_sentence_tokens);
Expand Down Expand Up @@ -554,16 +557,16 @@ class GraphDpImpl: public deeplima::graph_dp::impl::GraphDependencyParser
m_curr_buff_idx(0)
{}

GraphDpImpl(
size_t threads,
size_t buffer_size_per_thread
)
: deeplima::graph_dp::impl::GraphDependencyParser(
0 /* TODO: FIX ME */, 4, threads * 2, buffer_size_per_thread, threads),
m_fastText(std::make_shared<FastTextVectorizer<eigen_wrp::EigenMatrixXf::matrix_t, Eigen::Index>>()),
m_current_timepoint(deeplima::graph_dp::impl::GraphDependencyParser::get_start_timepoint())
{
}
// GraphDpImpl(
// size_t threads,
// size_t buffer_size_per_thread
// )
// : deeplima::graph_dp::impl::GraphDependencyParser(
// 0 /* TODO: FIX ME */, 4, threads * 2, buffer_size_per_thread, threads),
// m_fastText(std::make_shared<FastTextVectorizer<eigen_wrp::EigenMatrixXf::matrix_t, Eigen::Index>>()),
// m_current_timepoint(deeplima::graph_dp::impl::GraphDependencyParser::get_start_timepoint())
// {
// }

std::shared_ptr<EmbdUInt64Float> convert(const EmbdStrFloat& src)
{
Expand Down
28 changes: 21 additions & 7 deletions deeplima/include/deeplima/dumper_conllu.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

// #include "deeplima/segmentation/impl/segmentation_impl.h"

#include "deeplima/token_type.h"

namespace deeplima
{
namespace dumper
Expand Down Expand Up @@ -158,6 +160,11 @@ class AbstractDumper
: m_token_counter(0) { }

virtual ~AbstractDumper() { }

void reset()
{
m_token_counter = 0;
}
};

class Horizontal : public AbstractDumper
Expand Down Expand Up @@ -199,8 +206,8 @@ class Horizontal : public AbstractDumper
}
std::cout << str << " ";

if (tokens[i].m_flags & deeplima::segmentation::token_pos::flag_t::sentence_brk ||
tokens[i].m_flags & deeplima::segmentation::token_pos::flag_t::paragraph_brk)
if (tokens[i].m_flags & token_flags_t::sentence_brk ||
tokens[i].m_flags & token_flags_t::paragraph_brk)
{
// std::cerr << "Horizontal endl" << std::endl;
std::cout << std::endl;
Expand Down Expand Up @@ -265,8 +272,8 @@ class TokensToConllU : public AbstractDumper
increment_token_counter();

m_next_token_idx += 1;
if (tokens[i].m_flags & deeplima::segmentation::token_pos::flag_t::sentence_brk ||
tokens[i].m_flags & deeplima::segmentation::token_pos::flag_t::paragraph_brk)
if (tokens[i].m_flags & token_flags_t::sentence_brk ||
tokens[i].m_flags & token_flags_t::paragraph_brk)
{
// std::cerr << "TokensToConllU end of sentence" << std::endl;
std::cout << std::endl;
Expand All @@ -285,6 +292,7 @@ class DumperBase
virtual ~DumperBase() = default;
virtual uint64_t get_token_counter() const = 0;
virtual void flush() = 0;
virtual void reset() = 0;
};

template <class I>
Expand All @@ -296,6 +304,11 @@ class AnalysisToConllU : public DumperBase
std::vector<ConllToken> m_tokens;
uint32_t m_root;

void reset()
{
m_token_counter = 0;
}

inline void increment_token_counter()
{
++m_token_counter;
Expand All @@ -315,13 +328,14 @@ class AnalysisToConllU : public DumperBase
m_has_feats(false),
m_first_feature_to_print(0)
{
// std::cerr << "AnalysisToConllU()" << (void*)this << std::endl;
}

virtual ~AnalysisToConllU()
{
// std::cerr << "~AnalysisToConllU " << (void*)this << std::endl;
// if (m_next_token_idx > 1)
// {
// std::cerr << "on AnalysisToConllU destructor" << std::endl;
// std::cout << std::endl;
// }
}
Expand Down Expand Up @@ -557,8 +571,8 @@ class AnalysisToConllU : public DumperBase
increment_token_counter();

m_next_token_idx += 1;
if (iter.flags() & deeplima::segmentation::token_pos::flag_t::sentence_brk ||
iter.flags() & deeplima::segmentation::token_pos::flag_t::paragraph_brk)
if (iter.flags() & token_flags_t::sentence_brk ||
iter.flags() & token_flags_t::paragraph_brk)
{
// std::cerr << "AnalysisToConllU::operator() on sent/para break. m_next_token_idx="
// << m_next_token_idx << std::endl;
Expand Down
27 changes: 14 additions & 13 deletions deeplima/include/deeplima/eigen_wrp/bilstm_and_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ class Op_BiLSTM_Dense_ArgMax : public Op_Base

bool precompute()
{
std::cerr << "fw weights size: " << bilstm.fw.weight_hh.rows() << " x " << bilstm.fw.weight_hh.cols() << std::endl;
// std::cerr << "fw weights size: " << bilstm.fw.weight_hh.rows()
// << " x " << bilstm.fw.weight_hh.cols() << std::endl;

size_t hidden_size = bilstm.fw.weight_hh.cols();
std::cerr << "precompute(fw.input):" << std::endl;
// std::cerr << "precompute(fw.input):" << std::endl;
// /*
mul_fw.matmul_input = bilstm.fw.weight_hh.block(0, 0, hidden_size, hidden_size).inverse().partialPivLu();
mul_fw.matmul_forget = bilstm.fw.weight_hh.block(hidden_size, 0, hidden_size, hidden_size).inverse().partialPivLu();
Expand All @@ -110,15 +111,15 @@ class Op_BiLSTM_Dense_ArgMax : public Op_Base
// */

hidden_size = bilstm.bw.weight_hh.cols();
std::cerr << "precompute(bw.input):" << std::endl;
// std::cerr << "precompute(bw.input):" << std::endl;
// /*
mul_bw.matmul_input = bilstm.bw.weight_hh.block(0, 0, hidden_size, hidden_size).inverse().partialPivLu();
mul_bw.matmul_forget = bilstm.bw.weight_hh.block(hidden_size, 0, hidden_size, hidden_size).inverse().partialPivLu();
mul_bw.matmul_update = bilstm.bw.weight_hh.block(hidden_size*2, 0, hidden_size, hidden_size).inverse().partialPivLu();
mul_bw.matmul_output = bilstm.bw.weight_hh.block(hidden_size*3, 0, hidden_size, hidden_size).inverse().partialPivLu();
// */

std::cerr << "end of precomputing" << std::endl;
// std::cerr << "end of precomputing" << std::endl;
return true;
}
#else
Expand All @@ -142,18 +143,18 @@ class Op_BiLSTM_Dense_ArgMax : public Op_Base
{
if constexpr (std::is_integral_v<AuxScalar> && std::is_signed_v<AuxScalar>)
{
std::cerr << "Converting hh to fixed_point" << std::endl;
std::cerr << "min(fw_weight_hh) = " << bilstm.fw.weight_hh.minCoeff() << " "
<< "max(fw_weight_hh) = " << bilstm.fw.weight_hh.maxCoeff() << std::endl;
// std::cerr << "Converting hh to fixed_point" << std::endl;
// std::cerr << "min(fw_weight_hh) = " << bilstm.fw.weight_hh.minCoeff() << " "
// << "max(fw_weight_hh) = " << bilstm.fw.weight_hh.maxCoeff() << std::endl;
convert_matrix(bilstm.fw.weight_hh, weight_fw_hh_fixed_point);
std::cerr << "min(fw_weight_hh) = " << static_cast<T>(weight_fw_hh_fixed_point.minCoeff()) / WEIGHT_FRACTION_MULT << " "
<< "max(fw_weight_hh) = " << static_cast<T>(weight_fw_hh_fixed_point.maxCoeff()) / WEIGHT_FRACTION_MULT << std::endl;
// std::cerr << "min(fw_weight_hh) = " << static_cast<T>(weight_fw_hh_fixed_point.minCoeff()) / WEIGHT_FRACTION_MULT << " "
// << "max(fw_weight_hh) = " << static_cast<T>(weight_fw_hh_fixed_point.maxCoeff()) / WEIGHT_FRACTION_MULT << std::endl;

std::cerr << "min(bw_weight_hh) = " << bilstm.bw.weight_hh.minCoeff() << " "
<< "max(bw_weight_hh) = " << bilstm.bw.weight_hh.maxCoeff() << std::endl;
// std::cerr << "min(bw_weight_hh) = " << bilstm.bw.weight_hh.minCoeff() << " "
// << "max(bw_weight_hh) = " << bilstm.bw.weight_hh.maxCoeff() << std::endl;
convert_matrix(bilstm.bw.weight_hh, weight_bw_hh_fixed_point);
std::cerr << "min(bw_weight_hh) = " << static_cast<T>(weight_bw_hh_fixed_point.minCoeff()) / WEIGHT_FRACTION_MULT << " "
<< "max(bw_weight_hh) = " << static_cast<T>(weight_bw_hh_fixed_point.maxCoeff()) / WEIGHT_FRACTION_MULT << std::endl;
// std::cerr << "min(bw_weight_hh) = " << static_cast<T>(weight_bw_hh_fixed_point.minCoeff()) / WEIGHT_FRACTION_MULT << " "
//           << "max(bw_weight_hh) = " << static_cast<T>(weight_bw_hh_fixed_point.maxCoeff()) / WEIGHT_FRACTION_MULT << std::endl;
}

return true;
Expand Down
4 changes: 2 additions & 2 deletions deeplima/include/deeplima/eigen_wrp/lstm_beam_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ class Op_LSTM_Beam_Decoder : public Op_Base
decoding_step++;

M& states_c = wb->states_c;
if (states_c.cols() != beam_size)
if ((size_t)states_c.cols() != beam_size)
states_c = M::Zero(hidden_size, beam_size);
for (size_t i = 0; i < beam_size; ++i) states_c.col(i) = c;
M& states_h = wb->states_h;
if (states_h.cols() != beam_size)
if ((size_t)states_h.cols() != beam_size)
states_h = M::Zero(hidden_size, beam_size);
for (size_t i = 0; i < beam_size; ++i) states_h.col(i) = h;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ class WordSeqEmbdVectorizerWithPrecomputing
m_pModel->precompute_inputs(input, Parent::m_precomputed_vectors[Parent::m_curr_bucket_id], 0);
Parent::m_curr_bucket_id++;

if (Parent::m_curr_bucket_id >= Parent::m_precomputed_vectors.size())
if ((size_t)Parent::m_curr_bucket_id >= Parent::m_precomputed_vectors.size())
{
Parent::m_precomputed_vectors.resize(Parent::m_curr_bucket_id + 1);
}
Expand Down
24 changes: 23 additions & 1 deletion deeplima/include/deeplima/ner.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,31 @@ namespace impl
#error Unknown inference engine
#endif

/**
 * A kind of RnnSequenceClassifier, presumably used for named-entity
 * tagging (to be confirmed), and also the parent of TaggingImpl, which
 * is used as a member of TokenSequenceAnalyzer.
 */
template <typename AuxScalar=float>
class EntityTaggingClassifier: public RnnSequenceClassifier<Model<AuxScalar>, BaseMatrix, uint8_t>
{};
{
public:
EntityTaggingClassifier() :
RnnSequenceClassifier<Model<AuxScalar>, BaseMatrix, uint8_t>()
{
}

// EntityTaggingClassifier(uint32_t max_feat,
// uint32_t overlap,
// uint32_t num_slots,
// uint32_t slot_len,
// uint32_t num_threads) :
// RnnSequenceClassifier<Model<AuxScalar>, BaseMatrix, uint8_t>(
// max_feat, overlap, num_slots, slot_len, num_threads)
// {
// }

virtual ~EntityTaggingClassifier() = default;
};

} // namespace impl

Expand Down
Loading

0 comments on commit f8554d0

Please sign in to comment.