Skip to content

Commit

Permalink
Fix protolite compilation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699936577
  • Loading branch information
achoum authored and copybara-github committed Nov 25, 2024
1 parent ddeea33 commit 4b76ec7
Show file tree
Hide file tree
Showing 20 changed files with 116 additions and 93 deletions.
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/learner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:fold_generator_cc_proto",
"//yggdrasil_decision_forests/utils:hyper_parameters",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:protobuf",
"//yggdrasil_decision_forests/utils:random",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:synchronization_primitives",
Expand Down
23 changes: 5 additions & 18 deletions yggdrasil_decision_forests/learner/abstract_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "yggdrasil_decision_forests/utils/fold_generator.h"
#include "yggdrasil_decision_forests/utils/hyper_parameters.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/random.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/synchronization_primitives.h"
Expand Down Expand Up @@ -491,13 +492,12 @@ absl::Status AbstractLearner::CheckConfiguration(
return absl::OkStatus();
}

#ifdef YGG_PROTOBUF_LITE
if (config.has_maximum_model_size_in_memory_in_bytes()) {
if (config.has_maximum_model_size_in_memory_in_bytes() &&
!utils::ProtoSizeInBytesIsAvailable()) {
return absl::InvalidArgumentError(
"YDF has been compiled with YGG_PROTOBUF_LITE. Model size "
"cannot be estimated.");
"Cannot constraint the model size during training as YDF was compiled "
"with protobuf lite");
}
#endif // YGG_PROTOBUF_LITE

const auto& label_col_spec = data_spec.columns(config_link.label());
// Check the type of the label column.
Expand Down Expand Up @@ -663,11 +663,6 @@ absl::Status AbstractLearner::SetHyperParametersImpl(
const auto hparam =
generic_hyper_params->Get(kHParamMaximumModelSizeInMemoryInBytes);
if (hparam.has_value()) {
#ifdef YGG_PROTOBUF_LITE
return absl::InvalidArgumentError(
"YDF has been compiled with YGG_PROTOBUF_LITE. Model size "
"cannot be estimated.");
#endif // YGG_PROTOBUF_LITE
if (hparam.value().value().real() >= 0) {
training_config_.set_maximum_model_size_in_memory_in_bytes(
hparam.value().value().real());
Expand Down Expand Up @@ -929,14 +924,6 @@ absl::Status AbstractLearner::CheckCapabilities() const {
training_config().learner()));
}

#ifdef YGG_PROTOBUF_LITE
if (training_config().has_maximum_model_size_in_memory_in_bytes()) {
return absl::InvalidArgumentError(
"YDF has been compiled with YGG_PROTOBUF_LITE. Model size "
"cannot be estimated.");
}
#endif // YGG_PROTOBUF_LITE

// Monotonic constraints
if (!capabilities.support_monotonic_constraints() &&
training_config().monotonic_constraints_size() > 0) {
Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/learner/multitasker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/model/multitasker",
"//yggdrasil_decision_forests/serving:example_set",
"//yggdrasil_decision_forests/utils:concurrency",
"//yggdrasil_decision_forests/utils:protobuf",
"//yggdrasil_decision_forests/utils:regex",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:synchronization_primitives",
Expand Down
11 changes: 6 additions & 5 deletions yggdrasil_decision_forests/learner/multitasker/multitasker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "yggdrasil_decision_forests/model/multitasker/multitasker.h"
#include "yggdrasil_decision_forests/serving/example_set.h"
#include "yggdrasil_decision_forests/utils/concurrency.h"
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/regex.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/synchronization_primitives.h"
Expand Down Expand Up @@ -345,11 +346,11 @@ MultitaskerLearner::BuildSubTrainingConfig(const int learner_idx) const {

if (training_config().has_maximum_model_size_in_memory_in_bytes() &&
!sub_learner_config.has_maximum_model_size_in_memory_in_bytes()) {
#ifdef YGG_PROTOBUF_LITE
return absl::InvalidArgumentError(
"YDF has been compiled with YGG_PROTOBUF_LITE. Model size "
"cannot be estimated.");
#endif // YGG_PROTOBUF_LITE
if (!utils::ProtoSizeInBytesIsAvailable()) {
return absl::InvalidArgumentError(
"YDF has been compiled with YGG_PROTOBUF_LITE. Model size "
"cannot be estimated.");
}
sub_learner_config.set_maximum_model_size_in_memory_in_bytes(
training_config().maximum_model_size_in_memory_in_bytes());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ RandomForestLearner::TrainWithStatusImpl(
concurrent_fields.num_nodes_completed_trees.assign(rf_config.num_trees(),
-1);
concurrent_fields.model_size_in_bytes =
mdl->AbstractAttributesSizeInBytes();
mdl->AbstractAttributesSizeInBytes().value_or(0);
}

// Note: "num_trained_trees" is defined outside of the following brackets so
Expand Down Expand Up @@ -699,7 +699,8 @@ RandomForestLearner::TrainWithStatusImpl(
if (training_config().has_maximum_model_size_in_memory_in_bytes()) {
const auto tree_size_in_bytes =
decision_tree->EstimateModelSizeInBytes();
concurrent_fields.model_size_in_bytes += tree_size_in_bytes;
concurrent_fields.model_size_in_bytes +=
tree_size_in_bytes.value_or(0);
// Note: A model should contain at least one tree.
if (num_trained_trees > 0 &&
concurrent_fields.model_size_in_bytes >
Expand Down
17 changes: 9 additions & 8 deletions yggdrasil_decision_forests/model/abstract_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1443,21 +1443,22 @@ AbstractModel::BuildFastEngine(
return engine_or;
}

size_t AbstractModel::AbstractAttributesSizeInBytes() const {
#ifdef YGG_PROTOBUF_LITE
return 0;
#else
size_t size = sizeof(*this) + name_.size() + data_spec_.SpaceUsedLong();
std::optional<size_t> AbstractModel::AbstractAttributesSizeInBytes() const {
if (!utils::ProtoSizeInBytesIsAvailable()) {
return {};
}
size_t size = sizeof(*this) + name_.size() +
utils::ProtoSizeInBytes(data_spec_).value_or(0);
size +=
input_features_.size() * sizeof(decltype(input_features_)::value_type);
if (weights_.has_value()) {
size += weights_->ByteSizeLong();
size += utils::ProtoSizeInBytes(weights_.value()).value_or(0);
}
for (const auto& v : precomputed_variable_importances_) {
size += sizeof(v) + v.first.size() + v.second.SpaceUsedLong();
size += sizeof(v) + v.first.size() +
utils::ProtoSizeInBytes(v.second).value_or(0);
}
return size;
#endif // YGG_PROTOBUF_LITE
}

absl::Status AbstractModel::ValidateModelIOOptions(
Expand Down
5 changes: 2 additions & 3 deletions yggdrasil_decision_forests/model/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,8 @@ class AbstractModel {
virtual std::optional<size_t> ModelSizeInBytes() const { return {}; }

// Estimates the memory usage of the attributes defined in the "AbstractModel"
// object.
// Returns 0 if the model is compiled with YGG_PROTOBUF_LITE.
size_t AbstractAttributesSizeInBytes() const;
// object. Returns {} if the model size is not available.
std::optional<size_t> AbstractAttributesSizeInBytes() const;

// List of input features of the model.
const std::vector<int>& input_features() const { return input_features_; }
Expand Down
6 changes: 0 additions & 6 deletions yggdrasil_decision_forests/model/decision_tree/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ cc_library_ydf(
"decision_tree_io.h",
"structure_analysis.h",
],
defines = select({
"//yggdrasil_decision_forests:ydf_protobuf_lite": [
"YGG_PROTOBUF_LITE",
],
"//conditions:default": [],
}),
deps = [
":decision_tree_cc_proto",
":decision_tree_io_blob_sequence",
Expand Down
37 changes: 17 additions & 20 deletions yggdrasil_decision_forests/model/decision_tree/decision_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,29 +376,27 @@ void AppendConditionDescription(
node.num_pos_training_examples_without_weight(), node.na_value());
}

size_t DecisionTree::EstimateModelSizeInBytes() const {
#ifdef YGG_PROTOBUF_LITE
return 0;
#else
std::optional<size_t> DecisionTree::EstimateModelSizeInBytes() const {
if (!utils::ProtoSizeInBytesIsAvailable()) {
return {};
}
if (root_) {
return root_->EstimateSizeInByte() + sizeof(DecisionTree);
return root_->EstimateSizeInByte().value_or(0) + sizeof(DecisionTree);
} else {
return sizeof(DecisionTree);
}
#endif // YGG_PROTOBUF_LITE
}

size_t NodeWithChildren::EstimateSizeInByte() const {
#ifdef YGG_PROTOBUF_LITE
return 0;
#else
size_t size = node_.SpaceUsedLong();
std::optional<size_t> NodeWithChildren::EstimateSizeInByte() const {
if (!utils::ProtoSizeInBytesIsAvailable()) {
return 0;
}
size_t size = utils::ProtoSizeInBytes(node_).value_or(0);
if (!IsLeaf()) {
size += children_[0]->EstimateSizeInByte();
size += children_[1]->EstimateSizeInByte();
size += children_[0]->EstimateSizeInByte().value_or(0);
size += children_[1]->EstimateSizeInByte().value_or(0);
}
return size;
#endif // YGG_PROTOBUF_LITE
}

int64_t NodeWithChildren::NumNodes() const {
Expand Down Expand Up @@ -1702,17 +1700,16 @@ void SetLeafIndices(DecisionForest* trees) {
}
}

size_t EstimateSizeInByte(
std::optional<size_t> EstimateSizeInByte(
const std::vector<std::unique_ptr<DecisionTree>>& trees) {
#ifdef YGG_PROTOBUF_LITE
return 0;
#else
if (!utils::ProtoSizeInBytesIsAvailable()) {
return {};
}
size_t size = 0;
for (const auto& tree : trees) {
size += tree->EstimateModelSizeInBytes();
size += tree->EstimateModelSizeInBytes().value_or(0);
}
return size;
#endif // YGG_PROTOBUF_LITE
}

// Number of nodes in a list of decision trees.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class NodeWithChildren {
public:
// Approximate size in memory (expressed in bytes) of the node and all its
// children.
size_t EstimateSizeInByte() const;
std::optional<size_t> EstimateSizeInByte() const;

// Exports the node (and its children) to a RecordIO writer. The nodes are
// stored sequentially with a depth-first exploration.
Expand Down Expand Up @@ -237,7 +237,7 @@ class DecisionTree {
public:
// Estimates the memory usage of the model in RAM. The serialized or the
// compiled version of the model can be much smaller.
size_t EstimateModelSizeInBytes() const;
std::optional<size_t> EstimateModelSizeInBytes() const;

// Number of nodes in the tree.
int64_t NumNodes() const;
Expand Down Expand Up @@ -349,7 +349,7 @@ void SetLeafIndices(DecisionForest* trees);

// Estimate the size (in bytes) of a list of decision trees.
// Returns 0 if the size cannot be estimated.
size_t EstimateSizeInByte(const DecisionForest& trees);
std::optional<size_t> EstimateSizeInByte(const DecisionForest& trees);

// Number of nodes in a list of decision trees.
int64_t NumberOfNodes(const DecisionForest& trees);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ absl::Status SaveTreesToDisk(
ASSIGN_OR_RETURN(const auto format_impl, GetFormatImplementation(format));
// FutureWork(gbm): The current function is fully sequential. If speed
// becomes an issue, make it so that the shards are written in parallel.
*num_shards =
std::max<int>(1, (EstimateSizeInByte(trees) + kMaxShardSizeInByte - 1) /
kMaxShardSizeInByte);

auto tree_size = EstimateSizeInByte(trees);
if (tree_size.has_value()) {
*num_shards = std::max<int>(
1, (tree_size.value() + kMaxShardSizeInByte - 1) / kMaxShardSizeInByte);
} else {
*num_shards = 1;
}

const int64_t num_nodes = NumberOfNodes(trees);
const int num_nodes_per_shard =
std::max<int>(1, (num_nodes + *num_shards - 1) / *num_shards);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,11 @@ absl::Status GradientBoostedTreesModel::Validate() const {
}

std::optional<size_t> GradientBoostedTreesModel::ModelSizeInBytes() const {
#ifdef YGG_PROTOBUF_LITE
return std::nullopt;
#else
return AbstractAttributesSizeInBytes() +
decision_tree::EstimateSizeInByte(decision_trees_);
#endif // YGG_PROTOBUF_LITE
OPTIONAL_ASSIGN_OR_RETURN(const auto abstract_size,
AbstractAttributesSizeInBytes());
OPTIONAL_ASSIGN_OR_RETURN(const auto tree_size,
decision_tree::EstimateSizeInByte(decision_trees_));
return abstract_size + tree_size;
}

int64_t GradientBoostedTreesModel::NumNodes() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,11 @@ absl::Status IsolationForestModel::Validate() const {
}

std::optional<size_t> IsolationForestModel::ModelSizeInBytes() const {
#ifdef YGG_PROTOBUF_LITE
return std::nullopt;
#else
return AbstractAttributesSizeInBytes() +
decision_tree::EstimateSizeInByte(decision_trees_);
#endif // YGG_PROTOBUF_LITE
OPTIONAL_ASSIGN_OR_RETURN(const auto abstract_size,
AbstractAttributesSizeInBytes());
OPTIONAL_ASSIGN_OR_RETURN(const auto tree_size,
decision_tree::EstimateSizeInByte(decision_trees_));
return abstract_size + tree_size;
}

void IsolationForestModel::PredictLambda(
Expand Down
11 changes: 5 additions & 6 deletions yggdrasil_decision_forests/model/random_forest/random_forest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,11 @@ absl::Status RandomForestModel::Validate() const {
}

std::optional<size_t> RandomForestModel::ModelSizeInBytes() const {
#ifdef YGG_PROTOBUF_LITE
return std::nullopt;
#else
return AbstractAttributesSizeInBytes() +
decision_tree::EstimateSizeInByte(decision_trees_);
#endif // YGG_PROTOBUF_LITE
OPTIONAL_ASSIGN_OR_RETURN(const auto abstract_size,
AbstractAttributesSizeInBytes());
OPTIONAL_ASSIGN_OR_RETURN(const auto tree_size,
decision_tree::EstimateSizeInByte(decision_trees_));
return abstract_size + tree_size;
}

int64_t RandomForestModel::NumNodes() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down
10 changes: 7 additions & 3 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,13 @@ cc_library_ydf(

cc_library_ydf(
name = "protobuf",
hdrs = [
"protobuf.h",
],
hdrs = ["protobuf.h"],
defines = select({
"//yggdrasil_decision_forests:ydf_protobuf_lite": [
"YGG_PROTOBUF_LITE",
],
"//conditions:default": [],
}),
deps = [
":logging",
"@com_google_absl//absl/status",
Expand Down
5 changes: 4 additions & 1 deletion yggdrasil_decision_forests/utils/logging_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#include "absl/log/initialize.h"
#include "absl/log/log.h"

#ifndef YDF_NO_ALSOLOGTOSTDERR_FLAG
ABSL_FLAG(bool, alsologtostderr, false, "Log all messages to stderr");
#endif

namespace yggdrasil_decision_forests::logging {
void InitLoggingLib() { absl::InitializeLog(); }
Expand All @@ -35,9 +37,10 @@ void InitLogging(const char* usage, int* argc, char*** argv,
absl::InitializeLog();
absl::SetProgramUsageMessage(usage);
absl::ParseCommandLine(*argc, *argv);

#ifndef YDF_NO_ALSOLOGTOSTDERR_FLAG
if (absl::GetFlag(FLAGS_alsologtostderr)) {
absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
absl::SetMinLogLevel(absl::LogSeverityAtLeast::kInfo);
}
#endif
}
Loading

0 comments on commit 4b76ec7

Please sign in to comment.