diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dc6bab3..edab32d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -114,6 +114,7 @@ add_library(yugawara yugawara/analyzer/details/collect_join_keys.cpp yugawara/analyzer/details/rewrite_scan.cpp yugawara/analyzer/details/classify_expression.cpp + yugawara/analyzer/details/inline_variables.cpp # analyzer misc. yugawara/analyzer/details/detect_join_endpoint_style.cpp diff --git a/src/yugawara/analyzer/details/inline_variables.cpp b/src/yugawara/analyzer/details/inline_variables.cpp new file mode 100644 index 0000000..9a75314 --- /dev/null +++ b/src/yugawara/analyzer/details/inline_variables.cpp @@ -0,0 +1,168 @@ +#include "inline_variables.h" + +#include + +#include + +namespace yugawara::analyzer::details { + +namespace { + +class engine { +public: + explicit engine(inline_variables::map_type const& replacements) noexcept : + replacements_ { replacements } + {} + + void process(::takatori::util::ownership_reference<::takatori::scalar::expression> target) { + if (auto expr = target.find()) { + auto replacement = dispatch(*expr); + if (replacement) { + target.set(std::move(replacement)); + } + } + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> dispatch(::takatori::scalar::expression& expression) { + return ::takatori::scalar::dispatch(*this, expression); + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::expression& expression) { + (void) expression; + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::immediate& expression) { + (void) expression; + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::variable_reference& expression) { + auto found = replacements_.find(expression.variable().reference()); + if (found == replacements_.end()) { + return {}; + } + return ::takatori::util::clone_unique(found->second); + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::unary& expression) { + if (auto replacement = dispatch(expression.operand())) { + expression.operand(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::cast& expression) { + if (auto replacement = dispatch(expression.operand())) { + expression.operand(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::binary& expression) { + if (auto replacement = dispatch(expression.left())) { + expression.left(std::move(replacement)); + } + if (auto replacement = dispatch(expression.right())) { + expression.right(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::compare& expression) { + if (auto replacement = dispatch(expression.left())) { + expression.left(std::move(replacement)); + } + if (auto replacement = dispatch(expression.right())) { + expression.right(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::match& expression) { + if (auto replacement = dispatch(expression.input())) { + expression.input(std::move(replacement)); + } + if (auto replacement = dispatch(expression.pattern())) { + expression.pattern(std::move(replacement)); + } + if (auto replacement = dispatch(expression.escape())) { + expression.escape(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::conditional& expression) { + for (auto&& alternative : expression.alternatives()) { + if (auto replacement = dispatch(alternative.condition())) { + alternative.condition(std::move(replacement)); + } + if (auto replacement = dispatch(alternative.body())) { + alternative.body(std::move(replacement)); + } + } + if (auto otherwise = expression.default_expression()) { + if (auto replacement = dispatch(*otherwise)) { + expression.default_expression(std::move(replacement)); + } + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::coalesce& expression) { + apply_list(expression.alternatives()); + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::let& expression) { + for (auto&& decl : expression.variables()) { + if (auto replacement = dispatch(decl.value())) { + decl.value(std::move(replacement)); + } + } + if (auto replacement = dispatch(expression.body())) { + expression.body(std::move(replacement)); + } + return {}; + } + + [[nodiscard]] std::unique_ptr<::takatori::scalar::expression> operator()(::takatori::scalar::function_call& expression) { + apply_list(expression.arguments()); + return {}; + } + +private: + inline_variables::map_type const& replacements_; + + void apply_list(::takatori::tree::tree_element_vector<::takatori::scalar::expression>& list) { + for (auto iter = list.begin(); iter != list.end(); ++iter) { + if (auto replacement = dispatch(*iter)) { + (void) list.exchange(iter, std::move(replacement)); + } + } + } +}; + +} // namespace + +void inline_variables::reserve(std::size_t size) { + variables_.reserve(size); + replacements_.reserve(size); +} + +void inline_variables::declare( + ::takatori::descriptor::variable variable, + std::unique_ptr<::takatori::scalar::expression> replacement) { + auto key = variable.reference(); + if (replacements_.find(key) == replacements_.end()) { + variables_.emplace_back(std::move(variable)); + } + replacements_.emplace(key, std::move(replacement)); +} + +void inline_variables::apply(::takatori::util::ownership_reference<::takatori::scalar::expression> target) const { + engine e { replacements_ }; + e.process(std::move(target)); +} + +} // namespace yugawara::analyzer::details diff --git a/src/yugawara/analyzer/details/inline_variables.h b/src/yugawara/analyzer/details/inline_variables.h new file mode 100644 index 0000000..525cb04 --- /dev/null +++ b/src/yugawara/analyzer/details/inline_variables.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +#include + +#include + +#include + +namespace yugawara::analyzer::details { + +class inline_variables { +public: + using map_type = ::tsl::hopscotch_map< + ::takatori::descriptor::variable::reference_type, + std::unique_ptr<::takatori::scalar::expression>>; + + /** + * @brief reserves the space for the variables to be inlined. + * @param size the number of variables to be inlined + */ + void reserve(std::size_t size); + + /** + * @brief declares a variable to be inlined. + * @param variable the target variable + * @param replacement the replacement expression + * @return this + */ + void declare( + ::takatori::descriptor::variable variable, + std::unique_ptr<::takatori::scalar::expression> replacement); + + /** + * @brief applies the inlining of the variables to the target expression. + * @param target the target expression + */ + void apply(::takatori::util::ownership_reference<::takatori::scalar::expression> target) const; + +private: + std::vector<::takatori::descriptor::variable> variables_ {}; + map_type replacements_; +}; + +} // namespace yugawara::analyzer::details diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9424b6d..b900dc6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -100,6 +100,7 @@ add_test_executable(yugawara/analyzer/details/rewrite_join_test.cpp) add_test_executable(yugawara/analyzer/details/collect_join_keys_test.cpp) add_test_executable(yugawara/analyzer/details/rewrite_scan_test.cpp) add_test_executable(yugawara/analyzer/details/classify_expression_test.cpp) +add_test_executable(yugawara/analyzer/details/inline_variables_test.cpp) add_test_executable(yugawara/analyzer/intermediate_plan_optimizer_test.cpp) # serializer diff --git a/test/yugawara/analyzer/details/classify_expression_test.cpp b/test/yugawara/analyzer/details/classify_expression_test.cpp index 76ab5db..a4f7307 100644 --- a/test/yugawara/analyzer/details/classify_expression_test.cpp +++ b/test/yugawara/analyzer/details/classify_expression_test.cpp @@ -41,7 +41,6 @@ class classify_expression_test : public ::testing::Test { }; } - type::repository types; binding::factory bindings; }; diff --git a/test/yugawara/analyzer/details/inline_variables_test.cpp b/test/yugawara/analyzer/details/inline_variables_test.cpp new file mode 100644 index 0000000..5620f91 --- /dev/null +++ b/test/yugawara/analyzer/details/inline_variables_test.cpp @@ -0,0 +1,431 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace yugawara::analyzer::details { + +// import test utils +using namespace ::yugawara::testing; + +class inline_variables_test : public ::testing::Test { +protected: + + scalar::variable_reference var(descriptor::variable v) { + return scalar::variable_reference { + std::move(v), + }; + } + + scalar::immediate dummy(int v = 0) { + return scalar::immediate { + v::int4 { v }, + t::int4 {}, + }; + } + + descriptor::variable declare(int v) { + auto r = bindings.stream_variable(); + inliner.declare(r, std::make_unique(v::int4 { v }, t::int4 {})); + return r; + } + + std::unique_ptr apply(scalar::expression&& expr) { + auto target = ::takatori::util::clone_unique(std::move(expr)); + inliner.apply(target); + return target; + } + + inline_variables inliner; + binding::factory bindings; +}; + +TEST_F(inline_variables_test, immediate) { + auto r = apply(scalar::immediate { + v::int4 { 100 }, + t::int4 {}, + }); + EXPECT_EQ(*r, (scalar::immediate { + v::int4 { 100 }, + t::int4 {}, + })); +} + +TEST_F(inline_variables_test, variable_reference_miss) { + auto v = bindings.stream_variable(); + auto r = apply(scalar::variable_reference { + v, + }); + EXPECT_EQ(*r, (scalar::variable_reference { + v, + })); +} + +TEST_F(inline_variables_test, variable_reference_hit) { + auto v = declare(10); + auto r = apply(scalar::variable_reference { + v, + }); + EXPECT_EQ(*r, dummy(10)); +} + +TEST_F(inline_variables_test, unary_miss) { + auto v = bindings.stream_variable(); + auto r = apply(scalar::unary { + scalar::unary_operator::plus, + var(v), + }); + EXPECT_EQ(*r, (scalar::unary { + scalar::unary_operator::plus, + var(v), + })); +} + +TEST_F(inline_variables_test, unary_operand) { + auto v = declare(10); + auto r = apply(scalar::unary { + scalar::unary_operator::plus, + var(v), + }); + EXPECT_EQ(*r, (scalar::unary { + scalar::unary_operator::plus, + dummy(10), + })); +} + +TEST_F(inline_variables_test, cast_miss) { + auto v = bindings.stream_variable(); + auto r = apply(scalar::cast { + t::int8 {}, + scalar::cast_loss_policy::error, + var(v), + }); + EXPECT_EQ(*r, (scalar::cast { + t::int8 {}, + scalar::cast_loss_policy::error, + var(v), + })); +} + +TEST_F(inline_variables_test, cast_operand) { + auto v = declare(10); + auto r = apply(scalar::cast { + t::int8 {}, + scalar::cast_loss_policy::error, + var(v), + }); + EXPECT_EQ(*r, (scalar::cast { + t::int8 {}, + scalar::cast_loss_policy::error, + dummy(10), + })); +} + +TEST_F(inline_variables_test, binary_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto r = apply(scalar::binary { + scalar::binary_operator::add, + var(v1), + var(v2), + }); + EXPECT_EQ(*r, (scalar::binary { + scalar::binary_operator::add, + var(v1), + var(v2), + })); +} + +TEST_F(inline_variables_test, binary_hit) { + auto v1 = declare(10); + auto v2 = declare(20); + auto r = apply(scalar::binary { + scalar::binary_operator::add, + var(v1), + var(v2), + }); + EXPECT_EQ(*r, (scalar::binary { + scalar::binary_operator::add, + dummy(10), + dummy(20), + })); +} + +TEST_F(inline_variables_test, compare_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto r = apply(scalar::compare { + scalar::comparison_operator::equal, + var(v1), + var(v2), + }); + EXPECT_EQ(*r, (scalar::compare { + scalar::comparison_operator::equal, + var(v1), + var(v2), + })); +} + +TEST_F(inline_variables_test, compare_left) { + auto v1 = declare(10); + auto v2 = declare(20); + auto r = apply(scalar::compare { + scalar::comparison_operator::equal, + var(v1), + var(v2), + }); + EXPECT_EQ(*r, (scalar::compare { + scalar::comparison_operator::equal, + dummy(10), + dummy(20), + })); +} + +TEST_F(inline_variables_test, match_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto r = apply(scalar::match { + scalar::match_operator::like, + var(v1), + var(v2), + var(v3), + }); + EXPECT_EQ(*r, (scalar::match { + scalar::match_operator::like, + var(v1), + var(v2), + var(v3), + })); +} + +TEST_F(inline_variables_test, match_hit) { + auto v1 = declare(10); + auto v2 = declare(20); + auto v3 = declare(30); + auto r = apply(scalar::match { + scalar::match_operator::like, + var(v1), + var(v2), + var(v3), + }); + EXPECT_EQ(*r, (scalar::match { + scalar::match_operator::like, + dummy(10), + dummy(20), + dummy(30), + })); +} + +TEST_F(inline_variables_test, conditional_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto v4 = bindings.stream_variable(); + auto v5 = bindings.stream_variable(); + auto r = apply(scalar::conditional { + { + scalar::conditional::alternative { var(v1), var(v2), }, + scalar::conditional::alternative { var(v3), var(v4), }, + }, + var(v5), + }); + EXPECT_EQ(*r, (scalar::conditional { + { + scalar::conditional::alternative { var(v1), var(v2), }, + scalar::conditional::alternative { var(v3), var(v4), }, + }, + var(v5), + })); +} + +TEST_F(inline_variables_test, conditional_hit) { + auto v1 = declare(10); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto v4 = declare(20); + auto v5 = declare(30); + auto r = apply(scalar::conditional { + { + scalar::conditional::alternative { var(v1), var(v2), }, + scalar::conditional::alternative { var(v3), var(v4), }, + }, + var(v5), + }); + EXPECT_EQ(*r, (scalar::conditional { + { + scalar::conditional::alternative {dummy(10), var(v2), }, + scalar::conditional::alternative { var(v3), dummy(20), }, + }, + dummy(30), + })); +} + +TEST_F(inline_variables_test, coalesce_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto r = apply(scalar::coalesce { + { + var(v1), + var(v2), + var(v3), + } + }); + EXPECT_EQ(*r, (scalar::coalesce { + { + var(v1), + var(v2), + var(v3), + } + })); +} + +TEST_F(inline_variables_test, coalesce_hit) { + auto v1 = declare(10); + auto v2 = declare(20); + auto v3 = declare(30); + auto r = apply(scalar::coalesce { + { + var(v1), + var(v2), + var(v3), + } + }); + EXPECT_EQ(*r, (scalar::coalesce { + { + dummy(10), + dummy(20), + dummy(30), + } + })); +} + +TEST_F(inline_variables_test, let_miss) { + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto v4 = bindings.stream_variable(); + auto v5 = bindings.stream_variable(); + auto r = apply(scalar::let { + { + scalar::let::variable { v1, var(v2) }, + scalar::let::variable { v3, var(v4) }, + }, + var(v5), + }); + EXPECT_EQ(*r, (scalar::let { + { + scalar::let::variable { v1, var(v2) }, + scalar::let::variable { v3, var(v4) }, + }, + var(v5), + })); +} + +TEST_F(inline_variables_test, let_hit) { + auto v1 = bindings.stream_variable(); + auto v2 = declare(10); + auto v3 = bindings.stream_variable(); + auto v4 = bindings.stream_variable(); + auto v5 = declare(20); + auto r = apply(scalar::let { + { + scalar::let::variable { v1, var(v2) }, + scalar::let::variable { v3, var(v4) }, + }, + var(v5), + }); + EXPECT_EQ(*r, (scalar::let { + { + scalar::let::variable { v1, dummy(10) }, + scalar::let::variable { v3, var(v4) }, + }, + dummy(20), + })); +} + +TEST_F(inline_variables_test, function_call_miss) { + auto f = bindings.function({ + 1, + "f", + t::int4 {}, + { + t::int4 {}, + t::int4 {}, + t::int4 {}, + }, + }); + auto v1 = bindings.stream_variable(); + auto v2 = bindings.stream_variable(); + auto v3 = bindings.stream_variable(); + auto r = apply(scalar::function_call { + f, + { + var(v1), + var(v2), + var(v3), + }, + }); + EXPECT_EQ(*r, (scalar::function_call { + f, + { + var(v1), + var(v2), + var(v3), + }, + })); +} + +TEST_F(inline_variables_test, function_call_hit) { + auto f = bindings.function({ + 1, + "f", + t::int4 {}, + { + t::int4 {}, + t::int4 {}, + t::int4 {}, + }, + }); + auto v1 = declare(10); + auto v2 = declare(20); + auto v3 = declare(30); + auto r = apply(scalar::function_call { + f, + { + var(v1), + var(v2), + var(v3), + }, + }); + EXPECT_EQ(*r, (scalar::function_call { + f, + { + dummy(10), + dummy(20), + dummy(30), + }, + })); +} + +} // namespace yugawara::analyzer::details