diff --git a/spatial/src/spatial/core/index/rtree/rtree_index_plan_scan.cpp b/spatial/src/spatial/core/index/rtree/rtree_index_plan_scan.cpp index c453956..7a531b5 100644 --- a/spatial/src/spatial/core/index/rtree/rtree_index_plan_scan.cpp +++ b/spatial/src/spatial/core/index/rtree/rtree_index_plan_scan.cpp @@ -1,9 +1,12 @@ #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/optimizer/column_lifetime_analyzer.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/optimizer/matcher/function_matcher.hpp" #include "duckdb/optimizer/optimizer_extension.hpp" #include "duckdb/optimizer/remove_unused_columns.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator_extension.hpp" @@ -11,14 +14,15 @@ #include "spatial/core/geometry/bbox.hpp" #include "spatial/core/geometry/geometry_type.hpp" #include "spatial/core/index/rtree/rtree_index.hpp" +#include "spatial/core/index/rtree/rtree_index_create_logical.hpp" #include "spatial/core/index/rtree/rtree_index_scan.hpp" #include "spatial/core/index/rtree/rtree_module.hpp" #include "spatial/core/types.hpp" #include "spatial/core/util/math.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/optimizer/matcher/function_matcher.hpp" -#include "spatial/core/index/rtree/rtree_index_create_logical.hpp" +#include +#include +#include namespace spatial { @@ -95,7 +99,7 @@ class RTreeIndexScanOptimizer : public OptimizerExtension { return true; } - static bool TryOptimize(ClientContext &context, unique_ptr &plan) { + static bool TryOptimize(Binder &binder, ClientContext &context, unique_ptr &plan, unique_ptr &root) { // Look for a FILTER with a spatial predicate followed 
by a LOGICAL_GET table scan auto &op = *plan; @@ -116,11 +120,17 @@ class RTreeIndexScanOptimizer : public OptimizerExtension { if (filter.children.front()->type != LogicalOperatorType::LOGICAL_GET) { return false; } - auto &get = filter.children.front()->Cast(); + auto &get_ptr = filter.children.front(); + auto &get = get_ptr->Cast(); if (get.function.name != "seq_scan") { return false; } + // We can't optimize if the get already has dynamic filters pushed down :( + if(get.dynamic_filters && get.dynamic_filters->HasFilters()) { + return false; + } + // We can replace the scan function with a rtree index scan (if the table has a rtree index) // Get the table auto &table = *get.GetTable(); @@ -182,24 +192,57 @@ class RTreeIndexScanOptimizer : public OptimizerExtension { return false; } - // Replace the scan with our custom index scan function + // Replace the scan function with the index scan; if no table filters are pushed down into the get, that is all we need to do. NOTE(review): cardinality() is now called on the original scan function, before get.function is replaced - confirm that is intended. + const auto cardinality = get.function.cardinality(context, bind_data.get()); get.function = RTreeIndexScanFunction::GetFunction(); - auto cardinality = get.function.cardinality(context, bind_data.get()); get.has_estimated_cardinality = cardinality->has_estimated_cardinality; get.estimated_cardinality = cardinality->estimated_cardinality; get.bind_data = std::move(bind_data); - + if(get.table_filters.filters.empty()) { + return true; + } + get.projection_ids.clear(); + get.types.clear(); + + // Otherwise, things get more complicated. We need to pull up the filters from the table scan as our index scan + // does not support regular filter pushdown. 
+ auto new_filter = make_uniq(); + auto &column_ids = get.GetColumnIds(); + for(const auto &entry : get.table_filters.filters) { + idx_t column_id = entry.first; + auto &type = get.returned_types[column_id]; + bool found = false; + for(idx_t i = 0; i < column_ids.size(); i++) { + if (column_ids[i] == column_id) { + column_id = i; + found = true; + break; + } + } + if (!found) { + throw InternalException("Could not find column id for filter"); + } + auto column = make_uniq(type, ColumnBinding(get.table_index, column_id)); + new_filter->expressions.push_back(entry.second->ToExpression(*column)); + } + new_filter->children.push_back(std::move(get_ptr)); + new_filter->ResolveOperatorTypes(); + get_ptr = std::move(new_filter); return true; } - static void Optimize(OptimizerExtensionInput &input, unique_ptr &plan) { - if (!TryOptimize(input.context, plan)) { + static void OptimizeRecursive(OptimizerExtensionInput &input, unique_ptr &plan, unique_ptr &root) { + if (!TryOptimize(input.optimizer.binder, input.context, plan, root)) { // No match: continue with the children for (auto &child : plan->children) { - Optimize(input, child); + OptimizeRecursive(input, child, root); } } } + + static void Optimize(OptimizerExtensionInput &input, unique_ptr &plan) { + OptimizeRecursive(input, plan, plan); + } }; //----------------------------------------------------------------------------- diff --git a/test/data/segments.parquet b/test/data/segments.parquet new file mode 100644 index 0000000..64e93d8 Binary files /dev/null and b/test/data/segments.parquet differ diff --git a/test/sql/index/rtree_filter_pullup.test b/test/sql/index/rtree_filter_pullup.test new file mode 100644 index 0000000..2288a65 --- /dev/null +++ b/test/sql/index/rtree_filter_pullup.test @@ -0,0 +1,19 @@ +require spatial + +require parquet + +query I rowsort +SELECT id FROM '__WORKING_DIRECTORY__/test/data/segments.parquet' +WHERE subtype='road' AND ST_Intersects(geometry, ST_Buffer(ST_GeomFromText('POINT 
(-8476562 4795814)'), 100)); +---- +0862aac667ffffff043df7e4c6756d14 +0862aac667ffffff047de7f2111f86ad +0862aac667ffffff047ffedcbc2db0f8 + + +query III rowsort +SELECT id, subtype, class FROM '__WORKING_DIRECTORY__/test/data/segments.parquet' +WHERE subtype='road' AND class='residential' AND ST_Intersects(geometry, ST_Buffer(ST_GeomFromText('POINT (-8476562 4795814)'), 100)); +---- +0862aac667ffffff047de7f2111f86ad road residential +0862aac667ffffff047ffedcbc2db0f8 road residential