diff --git a/train-rank/src/entity/entry.h b/train-rank/src/entity/entry.h index 7ecd030..2cf950b 100644 --- a/train-rank/src/entity/entry.h +++ b/train-rank/src/entity/entry.h @@ -4,6 +4,8 @@ #include #include +#include + static const char ENTRY_ID[] = "id"; static const char ENTRY_FILE_TYPE[] = "file_type"; static const char ENTRY_READ_LATER[] = "readlater"; @@ -17,6 +19,7 @@ static const char ENTRY_LANGUAGE[] = "language"; static const char ENTRY_URL[] = "url"; static const char ENTRY_PURE_CONTENT[] = "pure_content"; static const char ENTRY_TITLE[] = "title"; +static const char ENTRY_CREATED_AT[] = "created_at"; struct Entry { std::string id; @@ -32,4 +35,5 @@ struct Entry { std::string url; std::string pure_content; std::string title; + boost::posix_time::ptime timestamp; }; diff --git a/train-rank/src/knowledgebase_api.cpp b/train-rank/src/knowledgebase_api.cpp index e291949..c51ff47 100644 --- a/train-rank/src/knowledgebase_api.cpp +++ b/train-rank/src/knowledgebase_api.cpp @@ -3,15 +3,37 @@ #include #include +#include #include #include #include "easylogging++.h" +#include + using namespace web::json; +using std::string; +using std::unordered_map; +using std::vector; + namespace knowledgebase { +namespace { + +boost::posix_time::ptime parse_time(string& timestamp_str) { + timestamp_str[10] = ' '; + timestamp_str.pop_back(); // remove the last 'Z' + + try { + return boost::posix_time::time_from_string(timestamp_str); + } catch (boost::bad_lexical_cast& e) { + LOG(ERROR) << "Failed to convert timestamp " << timestamp_str << ", Reason: " << e.what(); + return boost::posix_time::ptime(boost::gregorian::date(1970, 1, 1)); + } +} + +} // namespace EntryCache::EntryCache():cache_miss(0), cache_hit(0) { } @@ -20,62 +42,31 @@ EntryCache& EntryCache::getInstance() { return instance; } +void EntryCache::init() { + loadAllEntries(); +} + std::optional EntryCache::getEntryById(const std::string& id) { if (cache.find(id) != cache.end()) { ++cache_hit; return std::make_optional(cache[id]); } + auto result = GetEntryById(id); + if (result != std::nullopt) { + ++cache_hit; + cache[id] = result.value(); + return result; + } ++cache_miss; + return std::nullopt; +} - http_client client(U(std::getenv("KNOWLEDGE_BASE_API_URL"))); - std::string current_entry_api_suffix = - std::string(ENTRY_API_SUFFIX) + "/" + id; - - LOG(DEBUG) << "current_entry_api_suffix " << current_entry_api_suffix - << std::endl; - - uri_builder builder(U(current_entry_api_suffix)); - - // Impression current_impression; - std::optional option_entry = std::nullopt; - - client.request(methods::GET, builder.to_string()) - .then([](http_response response) -> pplx::task { - if (response.status_code() == status_codes::OK) { - return response.extract_json(); - } - return pplx::task_from_result(web::json::value()); - }) - .then([&option_entry](pplx::task previousTask) { - try { - web::json::value const &v = previousTask.get(); - int code = v.at("code").as_integer(); - std::string message = "null"; - if (v.has_string_field("message")) { - message = v.at("message").as_string(); - } - LOG(DEBUG) << "code " << code << " message " << message << std::endl; - if (code == 0) { - web::json::value current_value = v.at("data"); - option_entry = convertFromWebJsonValueToEntry(current_value); - } - } catch (http_exception const &e) { - LOG(ERROR) << "Error exception " << e.what() << std::endl; - } - }) - .wait(); - - if (option_entry.has_value()) { - cache[id] = option_entry.value(); - } - return option_entry; +void EntryCache::loadAllEntries() { + cache = getEntries(FLAGS_recommend_source_name); } void EntryCache::dumpStatistics() { LOG(INFO) << "Cache hit: " << cache_hit << ", miss: " << cache_miss << endl; - for (auto& pr : cache) { - LOG(INFO) << pr.first << endl; - } } bool updateAlgorithmScoreAndMetadata( @@ -330,6 +321,15 @@ std::optional convertFromWebJsonValueToEntry( return std::nullopt; } + if (current_item.has_string_field(ENTRY_CREATED_AT)) { + auto timestamp_str = current_item.at(ENTRY_CREATED_AT).as_string(); + temp_entry.timestamp = parse_time(timestamp_str); + } else { + LOG(ERROR) << "current web json value have no " << ENTRY_CREATED_AT + << std::endl; + return std::nullopt; + } + return std::make_optional(temp_entry); } @@ -825,6 +825,77 @@ std::optional GetEntryById(const std::string &id) { return option_entry; } +unordered_map getEntries(const string& source) { + unordered_map result; + int offset = 0, count = 0, limit = 100; + do { + vector entry_list; + getEntries(limit, offset, source, &entry_list, &count); + for (auto& entry : entry_list) { + result.emplace(entry.id, std::move(entry)); + } + offset += limit; + } while (result.size() < count); + return result; +} + +void getEntries(int limit, int offset, const string& source, + std::vector *entry_list, int *count) { + entry_list->clear(); + http_client client(U(std::getenv("KNOWLEDGE_BASE_API_URL"))); + std::string current_suffix = + std::string(ENTRY_API_SUFFIX) + "?offset=" + std::to_string(offset) + + "&limit=" + std::to_string(limit) + "&source=" + source + "&extract=true" + + "&fields=id,file_type,language,url,title,readlater,crawler,starred,disabled,saved,unread,extract,created_at"; + LOG(DEBUG) << "current_suffix " << current_suffix + << std::endl; + + bool success = false; + while (!success) { + client.request(methods::GET, U(current_suffix)) + .then([](http_response response) -> pplx::task { + if (response.status_code() == status_codes::OK) { + return response.extract_json(); + } + return pplx::task_from_result(web::json::value()); + }) + .then([&](pplx::task previousTask) { + try { + web::json::value const &v = previousTask.get(); + int code = v.at("code").as_integer(); + std::string message = "null"; + if (v.has_string_field("message")) { + message = v.at("message").as_string(); + } + LOG(DEBUG) << "code " << code << " message " << message << std::endl; + if (code == 0) { + web::json::value data_value = v.at("data"); + web::json::value items = data_value.at("items"); + *count = data_value.at("count").as_integer(); + for (auto iter = items.as_array().cbegin(); iter != items.as_array().cend(); ++iter) { + auto item = convertFromWebJsonValueToEntry(*iter); + if (item == std::nullopt) { + continue; + } + entry_list->push_back(item.value()); + } + } + success = true; + } catch (http_exception const &e) { + LOG(ERROR) << "Error exception " << e.what() << std::endl; + } + }) + .wait(); + + if (!success) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + LOG(INFO) << "getEntries retrying..." << std::endl; + } else { + LOG(INFO) << "getEntries succeed." << std::endl; + } + } +} + std::optional GetAlgorithmById(const std::string &id) { /** * @brief diff --git a/train-rank/src/knowledgebase_api.h b/train-rank/src/knowledgebase_api.h index df3a36a..27a2a03 100644 --- a/train-rank/src/knowledgebase_api.h +++ b/train-rank/src/knowledgebase_api.h @@ -2,6 +2,9 @@ #include #include #include + +#include + using std::cerr; using std::endl; @@ -15,12 +18,16 @@ using namespace web::http::client; #include "entity/rank_algorithm.h" #include "entity/reco_metadata.h" +DECLARE_string(recommend_source_name); + namespace knowledgebase { class EntryCache { public: static EntryCache& getInstance(); + void init(); std::optional getEntryById(const std::string& id); + void loadAllEntries(); void dumpStatistics(); @@ -63,6 +70,10 @@ void getAllAlgorithmAccordingRanked(std::string source, bool ranked, int* count); std::optional GetAlgorithmById(const std::string& id); std::optional GetEntryById(const std::string& id); +std::unordered_map getEntries(const std::string& source); +std::unordered_map getEntries(const std::string& source); +void getEntries(int limit, int offset, const std::string& source, + std::vector *entry_list, int *count); bool updateKnowledgeConfig(const std::string& source, const std::string& key, const web::json::value& value); bool updateLastRankTime(std::string source, int64_t last_rank_time); diff --git a/train-rank/src/lr/feature_extractor.cpp b/train-rank/src/lr/feature_extractor.cpp index a5ead3c..f1eddc3 100644 --- a/train-rank/src/lr/feature_extractor.cpp +++ b/train-rank/src/lr/feature_extractor.cpp @@ -44,14 +44,14 @@ double EmbeddingDistanceExtractor::extract(const Impression& item) { if (!item.embedding.has_value()) { return 0.0; } - return extract(Item(item.id, item.embedding.value())); + return extract(Item(item.entry_id, item.embedding.value())); } double EmbeddingDistanceExtractor::extract(const Algorithm& item) { if (!item.embedding.has_value()) { return 0.0; } - return extract(Item(item.id, item.embedding.value())); + return extract(Item(item.entry, item.embedding.value())); } double EmbeddingDistanceExtractor::extract(const Item& item) { @@ -64,7 +64,7 @@ double EmbeddingDistanceExtractor::extract(const Item& item) { minNegCos = std::max(minNegCos, embeddingCosine(neg.embedding, item.embedding)); } // TODO(haochengwang): Optimize me - return minPosCos; + return minPosCos + 1.0; } Reason EmbeddingDistanceExtractor::getReason(const Algorithm& item) { diff --git a/train-rank/src/lr/feature_extractor.h b/train-rank/src/lr/feature_extractor.h index f12cc19..8154cf0 100644 --- a/train-rank/src/lr/feature_extractor.h +++ b/train-rank/src/lr/feature_extractor.h @@ -26,7 +26,8 @@ class EmbeddingDistanceExtractor : public FeatureExtractor { struct Item { std::string id; std::vector embedding; - Item(std::string id, const std::vector& embedding): id(id), embedding(embedding) {} + Item(std::string id, const std::vector& embedding): + id(id), embedding(embedding) {} }; EmbeddingDistanceExtractor(); diff --git a/train-rank/src/rssrank.cpp b/train-rank/src/rssrank.cpp index 4cf960c..08170db 100755 --- a/train-rank/src/rssrank.cpp +++ b/train-rank/src/rssrank.cpp @@ -6,12 +6,15 @@ #include namespace fs = std::filesystem; +#include #include #include #include #include #include +#include + #include "common_tool.h" #include "data_process.h" #include "easylogging++.h" @@ -96,6 +99,19 @@ std::unique_ptr loadDefaultModel() { new lr::LogisticRegression(1, vector{1.0, 0.0})); } +double getTimeCoefficient(boost::posix_time::ptime timestamp) { + auto now = boost::posix_time::second_clock::local_time(); + auto diff = (now - timestamp).total_seconds(); + + if (diff <= 0) { + return 1.0; + } + + double result = 1.8 / (1.0 + std::exp(- 86400.0 / diff)) - 0.8; + + return result; +} + } // namespace std::string getRankModelPath(ModelPathType model_path_type) { @@ -222,43 +238,48 @@ getAllEntryToPrerankSourceForCurrentSourceKnowledge() { const char *source_name = std::getenv("TERMINUS_RECOMMEND_SOURCE_NAME"); const int batch_size = 100; - int offset = 0; - while (true) { - std::vector temp_algorithm; - int count = 0; - knowledgebase::getAlgorithmAccordingRanked(batch_size, offset, source_name, - false, &temp_algorithm, &count); - LOG(INFO) << "offset " << offset << " limit " << batch_size << " count " - << count; + auto get_scores = [&](bool ranked) { + int offset = 0; + while (true) { + std::vector temp_algorithm; + int count = 0; + knowledgebase::getAlgorithmAccordingRanked(batch_size, offset, source_name, + ranked, &temp_algorithm, &count); + LOG(INFO) << "offset " << offset << " limit " << batch_size << " count " + << count; + + // algorithm_list.insert(algorithm_list.end(),temp_algorithm.begin(),temp_algorithm.end()); + for (Algorithm current : temp_algorithm) { + if (current.prerank_score != std::nullopt) { + std::optional temp_entry = + knowledgebase::GetEntryById(current.entry); + if (temp_entry == std::nullopt) { + LOG(INFO) << "entry [" << current.entry << "] not exist, algorithm [" + << current.id << "]" << std::endl; + continue; + } - // algorithm_list.insert(algorithm_list.end(),temp_algorithm.begin(),temp_algorithm.end()); - for (Algorithm current : temp_algorithm) { - if (current.prerank_score != std::nullopt) { - std::optional temp_entry = - knowledgebase::GetEntryById(current.entry); - if (temp_entry == std::nullopt) { - LOG(INFO) << "entry [" << current.entry << "] not exist, algorithm [" - << current.id << "]" << std::endl; - continue; - } + if (temp_entry.value().extract == false) { + LOG(INFO) << "entry [" << current.entry + << "] not extract, algorithm [" << current.id << "]" + << std::endl; + continue; + } - if (temp_entry.value().extract == false) { - LOG(INFO) << "entry [" << current.entry - << "] not extract, algorithm [" << current.id << "]" - << std::endl; - continue; + algorithm_entry_id_to_score_with_meta[current.id] = + ScoreWithMetadata(current.prerank_score.value() * getTimeCoefficient(temp_entry.value().timestamp)); + } + } + offset = offset + batch_size; + if (offset >= count) { + break; } - - algorithm_entry_id_to_score_with_meta[current.id] = - ScoreWithMetadata(current.prerank_score.value()); } - } - offset = offset + batch_size; - if (offset >= count) { - break; - } - } + }; + get_scores(false); + get_scores(true); + LOG(INFO) << "algorithm_entry_id_to_score_with_meta " << algorithm_entry_id_to_score_with_meta.size() << std::endl; return algorithm_entry_id_to_score_with_meta; @@ -699,6 +720,8 @@ bool rankLR() { return false; } + knowledgebase::EntryCache::getInstance().init(); + if (!doRank()) { std::unordered_map entry_to_score_with_metadata = rssrank::getAllEntryToPrerankSourceForCurrentSourceKnowledge(); @@ -758,7 +781,7 @@ bool doRank() { std::unordered_map id_to_score_with_meta; for (const auto ¤t_item : not_ranked_algorithm_to_entry) { std::optional temp_entry = - knowledgebase::GetEntryById(current_item.second); + knowledgebase::EntryCache::getInstance().getEntryById(current_item.second); if (temp_entry == std::nullopt) { LOG(ERROR) << "entry [" << current_item.second << "] not exist, algorithm [" << current_item.first << "]" @@ -788,6 +811,7 @@ bool doRank() { } auto score = logistic_regression->predict(features); + score *= getTimeCoefficient(temp_entry.value().timestamp); if (FLAGS_verbose) { LOG(INFO) << "Score: " << score << endl; }