Skip to content

Commit

Permalink
Merge pull request #30 from beclab/feat/rank-no-time-limit
Browse files Browse the repository at this point in the history
Batch load all entries into cache at one time; Add time coefficient…
  • Loading branch information
haochengwang authored Aug 27, 2024
2 parents fa33dd1 + 8e2824f commit 7b50055
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 82 deletions.
4 changes: 4 additions & 0 deletions train-rank/src/entity/entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <string>
#include <vector>

#include <boost/date_time.hpp>

static const char ENTRY_ID[] = "id";
static const char ENTRY_FILE_TYPE[] = "file_type";
static const char ENTRY_READ_LATER[] = "readlater";
Expand All @@ -17,6 +19,7 @@ static const char ENTRY_LANGUAGE[] = "language";
static const char ENTRY_URL[] = "url";
static const char ENTRY_PURE_CONTENT[] = "pure_content";
static const char ENTRY_TITLE[] = "title";
static const char ENTRY_CREATED_AT[] = "created_at";

struct Entry {
std::string id;
Expand All @@ -32,4 +35,5 @@ struct Entry {
std::string url;
std::string pure_content;
std::string title;
boost::posix_time::ptime timestamp;
};
161 changes: 116 additions & 45 deletions train-rank/src/knowledgebase_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,37 @@
#include <cpprest/json.h>

#include <chrono>
#include <exception>
#include <iostream>
#include <optional>
#include <thread>

#include "easylogging++.h"

#include <boost/date_time.hpp>

using namespace web::json;

using std::string;
using std::unordered_map;
using std::vector;

namespace knowledgebase {

namespace {

/**
 * @brief Convert a timestamp string into a boost ptime.
 *
 * The character at index 10 is replaced with a space and a trailing 'Z' is
 * stripped before the string is handed to
 * boost::posix_time::time_from_string (presumably the input looks like
 * "YYYY-MM-DDTHH:MM:SSZ" -- confirm against the knowledge base API).
 *
 * @param timestamp_str Timestamp text; it is modified in place.
 * @return The parsed time, or 1970-01-01 00:00:00 when the string is too
 *         short or cannot be parsed.
 */
boost::posix_time::ptime parse_time(string& timestamp_str) {
  const boost::posix_time::ptime fallback(boost::gregorian::date(1970, 1, 1));
  const string::size_type separator_pos = 10;

  // Guard the indexed write below: touching index 10 of a shorter string is
  // undefined behaviour.
  if (timestamp_str.size() <= separator_pos) {
    LOG(ERROR) << "Failed to convert timestamp " << timestamp_str
               << ", Reason: timestamp is too short";
    return fallback;
  }

  timestamp_str[separator_pos] = ' ';
  // Only drop the final character when it really is the 'Z' suffix, so that
  // a timestamp without it does not lose its last digit.
  if (timestamp_str.back() == 'Z') {
    timestamp_str.pop_back();
  }

  try {
    return boost::posix_time::time_from_string(timestamp_str);
  } catch (const std::exception& e) {
    // Covers boost::bad_lexical_cast (non-numeric fields) as well as the
    // std::out_of_range based errors raised for invalid date components.
    LOG(ERROR) << "Failed to convert timestamp " << timestamp_str << ", Reason: " << e.what();
    return fallback;
  }
}

} // namespace
EntryCache::EntryCache():cache_miss(0), cache_hit(0) {
}

Expand All @@ -20,62 +42,31 @@ EntryCache& EntryCache::getInstance() {
return instance;
}

void EntryCache::init() {
loadAllEntries();
}

std::optional<Entry> EntryCache::getEntryById(const std::string& id) {
if (cache.find(id) != cache.end()) {
++cache_hit;
return std::make_optional(cache[id]);
}
auto result = GetEntryById(id);
if (result != std::nullopt) {
++cache_hit;
cache[id] = result.value();
return result;
}
++cache_miss;
return std::nullopt;
}

http_client client(U(std::getenv("KNOWLEDGE_BASE_API_URL")));
std::string current_entry_api_suffix =
std::string(ENTRY_API_SUFFIX) + "/" + id;

LOG(DEBUG) << "current_entry_api_suffix " << current_entry_api_suffix
<< std::endl;

uri_builder builder(U(current_entry_api_suffix));

// Impression current_impression;
std::optional<Entry> option_entry = std::nullopt;

client.request(methods::GET, builder.to_string())
.then([](http_response response) -> pplx::task<web::json::value> {
if (response.status_code() == status_codes::OK) {
return response.extract_json();
}
return pplx::task_from_result(web::json::value());
})
.then([&option_entry](pplx::task<web::json::value> previousTask) {
try {
web::json::value const &v = previousTask.get();
int code = v.at("code").as_integer();
std::string message = "null";
if (v.has_string_field("message")) {
message = v.at("message").as_string();
}
LOG(DEBUG) << "code " << code << " message " << message << std::endl;
if (code == 0) {
web::json::value current_value = v.at("data");
option_entry = convertFromWebJsonValueToEntry(current_value);
}
} catch (http_exception const &e) {
LOG(ERROR) << "Error exception " << e.what() << std::endl;
}
})
.wait();

if (option_entry.has_value()) {
cache[id] = option_entry.value();
}
return option_entry;
void EntryCache::loadAllEntries() {
cache = getEntries(FLAGS_recommend_source_name);
}

// Logs the hit/miss counters, followed by one INFO line per cached key.
void EntryCache::dumpStatistics() {
  LOG(INFO) << "Cache hit: " << cache_hit << ", miss: " << cache_miss << endl;
  for (auto it = cache.begin(); it != cache.end(); ++it) {
    LOG(INFO) << it->first << endl;
  }
}

bool updateAlgorithmScoreAndMetadata(
Expand Down Expand Up @@ -330,6 +321,15 @@ std::optional<Entry> convertFromWebJsonValueToEntry(
return std::nullopt;
}

if (current_item.has_string_field(ENTRY_CREATED_AT)) {
auto timestamp_str = current_item.at(ENTRY_CREATED_AT).as_string();
temp_entry.timestamp = parse_time(timestamp_str);
} else {
LOG(ERROR) << "current web json value have no " << ENTRY_CREATED_AT
<< std::endl;
return std::nullopt;
}

return std::make_optional(temp_entry);
}

Expand Down Expand Up @@ -825,6 +825,77 @@ std::optional<Entry> GetEntryById(const std::string &id) {
return option_entry;
}

/**
 * @brief Load every entry of a source, page by page, keyed by entry id.
 *
 * Pages of 100 entries are requested through the paged getEntries() overload
 * until the offset has passed the total reported by the API.
 *
 * Paging is driven by the offset rather than by the number of collected
 * entries: items that fail conversion, duplicate ids, or a page answered with
 * a non-zero API code all keep result.size() below the reported total, which
 * would otherwise make the loop spin forever.
 *
 * @param source Source name passed through to the paged overload.
 * @return Map from entry id to entry; when an id occurs more than once the
 *         first occurrence is kept.
 */
unordered_map<string, Entry> getEntries(const string& source) {
  const int page_size = 100;
  unordered_map<string, Entry> result;
  int offset = 0;
  int count = 0;  // total reported by the API; updated by each page request
  do {
    vector<Entry> entry_list;
    getEntries(page_size, offset, source, &entry_list, &count);
    for (auto& entry : entry_list) {
      result.emplace(entry.id, std::move(entry));
    }
    offset += page_size;
  } while (offset < count);
  return result;
}

void getEntries(int limit, int offset, const string& source,
std::vector<Entry> *entry_list, int *count) {
entry_list->clear();
http_client client(U(std::getenv("KNOWLEDGE_BASE_API_URL")));
std::string current_suffix =
std::string(ENTRY_API_SUFFIX) + "?offset=" + std::to_string(offset)
+ "&limit=" + std::to_string(limit) + "&source=" + source + "&extract=true"
+ "&fields=id,file_type,language,url,title,readlater,crawler,starred,disabled,saved,unread,extract,created_at";
LOG(DEBUG) << "current_suffix " << current_suffix
<< std::endl;

bool success = false;
while (!success) {
client.request(methods::GET, U(current_suffix))
.then([](http_response response) -> pplx::task<web::json::value> {
if (response.status_code() == status_codes::OK) {
return response.extract_json();
}
return pplx::task_from_result(web::json::value());
})
.then([&](pplx::task<web::json::value> previousTask) {
try {
web::json::value const &v = previousTask.get();
int code = v.at("code").as_integer();
std::string message = "null";
if (v.has_string_field("message")) {
message = v.at("message").as_string();
}
LOG(DEBUG) << "code " << code << " message " << message << std::endl;
if (code == 0) {
web::json::value data_value = v.at("data");
web::json::value items = data_value.at("items");
*count = data_value.at("count").as_integer();
for (auto iter = items.as_array().cbegin(); iter != items.as_array().cend(); ++iter) {
auto item = convertFromWebJsonValueToEntry(*iter);
if (item == std::nullopt) {
continue;
}
entry_list->push_back(item.value());
}
}
success = true;
} catch (http_exception const &e) {
LOG(ERROR) << "Error exception " << e.what() << std::endl;
}
})
.wait();

if (!success) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
LOG(INFO) << "getEntries retrying..." << std::endl;
} else {
LOG(INFO) << "getEntries succeed." << std::endl;
}
}
}

std::optional<Algorithm> GetAlgorithmById(const std::string &id) {
/**
* @brief
Expand Down
11 changes: 11 additions & 0 deletions train-rank/src/knowledgebase_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <string>
#include <unordered_map>
#include <vector>

#include <gflags/gflags.h>

using std::cerr;
using std::endl;

Expand All @@ -15,12 +18,16 @@ using namespace web::http::client;
#include "entity/rank_algorithm.h"
#include "entity/reco_metadata.h"

DECLARE_string(recommend_source_name);

namespace knowledgebase {

class EntryCache {
public:
static EntryCache& getInstance();
void init();
std::optional<Entry> getEntryById(const std::string& id);
void loadAllEntries();

void dumpStatistics();

Expand Down Expand Up @@ -63,6 +70,10 @@ void getAllAlgorithmAccordingRanked(std::string source,
bool ranked, int* count);
std::optional<Algorithm> GetAlgorithmById(const std::string& id);
std::optional<Entry> GetEntryById(const std::string& id);
/// Loads all entries of `source` by requesting successive pages and returns
/// them keyed by entry id.
std::unordered_map<std::string, Entry> getEntries(const std::string& source);
/// Fetches a single page of entries for `source`.
/// `entry_list` is cleared and filled with the entries of the page; `count`
/// receives the total reported by the API when the request succeeds with
/// code 0.
void getEntries(int limit, int offset, const std::string& source,
                std::vector<Entry> *entry_list, int *count);
bool updateKnowledgeConfig(const std::string& source, const std::string& key,
const web::json::value& value);
bool updateLastRankTime(std::string source, int64_t last_rank_time);
Expand Down
6 changes: 3 additions & 3 deletions train-rank/src/lr/feature_extractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ double EmbeddingDistanceExtractor::extract(const Impression& item) {
if (!item.embedding.has_value()) {
return 0.0;
}
return extract(Item(item.id, item.embedding.value()));
return extract(Item(item.entry_id, item.embedding.value()));
}

double EmbeddingDistanceExtractor::extract(const Algorithm& item) {
if (!item.embedding.has_value()) {
return 0.0;
}
return extract(Item(item.id, item.embedding.value()));
return extract(Item(item.entry, item.embedding.value()));
}

double EmbeddingDistanceExtractor::extract(const Item& item) {
Expand All @@ -64,7 +64,7 @@ double EmbeddingDistanceExtractor::extract(const Item& item) {
minNegCos = std::max(minNegCos, embeddingCosine(neg.embedding, item.embedding));
}
// TODO(haochengwang): Optimize me
return minPosCos;
return minPosCos + 1.0;
}

Reason EmbeddingDistanceExtractor::getReason(const Algorithm& item) {
Expand Down
3 changes: 2 additions & 1 deletion train-rank/src/lr/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class EmbeddingDistanceExtractor : public FeatureExtractor {
struct Item {
std::string id;
std::vector<double> embedding;
Item(std::string id, const std::vector<double>& embedding): id(id), embedding(embedding) {}
Item(std::string id, const std::vector<double>& embedding):
id(id), embedding(embedding) {}
};
EmbeddingDistanceExtractor();

Expand Down
Loading

0 comments on commit 7b50055

Please sign in to comment.