Skip to content

Commit

Permalink
Merge 'Use custom expr equality check in translation and planning' fr…
Browse files Browse the repository at this point in the history
…om Preston Thorpe

Idk how I missed these during the initial PR 👍

Reviewed-by: Jussi Saurio <[email protected]>

Closes #541
  • Loading branch information
jussisaurio committed Dec 23, 2024
2 parents 0a479a9 + fbf4245 commit 3ab7f7a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
7 changes: 2 additions & 5 deletions core/translate/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::ext::{ExtFunc, UuidFunc};
use crate::function::JsonFunc;
use crate::function::{AggFunc, Func, FuncCtx, MathFuncArity, ScalarFunc};
use crate::schema::Type;
use crate::util::normalize_ident;
use crate::util::{exprs_are_equivalent, normalize_ident};
use crate::vdbe::{builder::ProgramBuilder, BranchOffset, Insn};
use crate::Result;

Expand Down Expand Up @@ -556,10 +556,7 @@ pub fn translate_expr(
) -> Result<usize> {
if let Some(precomputed_exprs_to_registers) = precomputed_exprs_to_registers {
for (precomputed_expr, reg) in precomputed_exprs_to_registers.iter() {
// TODO: implement a custom equality check for expressions
// there are lots of examples where this breaks, even simple ones like
// sum(x) != SUM(x)
if expr == *precomputed_expr {
if exprs_are_equivalent(expr, precomputed_expr) {
program.emit_insn(Insn::Copy {
src_reg: *reg,
dst_reg: target_register,
Expand Down
12 changes: 10 additions & 2 deletions core/translate/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ use super::{
Aggregate, BTreeTableReference, Direction, GroupBy, Plan, ResultSetColumn, SourceOperator,
},
};
use crate::{function::Func, schema::Schema, util::normalize_ident, Result};
use crate::{
function::Func,
schema::Schema,
util::{exprs_are_equivalent, normalize_ident},
Result,
};
use sqlite3_parser::ast::{self, FromClause, JoinType, ResultColumn};

pub struct OperatorIdCounter {
Expand All @@ -23,7 +28,10 @@ impl OperatorIdCounter {
}

fn resolve_aggregates(expr: &ast::Expr, aggs: &mut Vec<Aggregate>) -> bool {
if aggs.iter().any(|a| a.original_expr == *expr) {
if aggs
.iter()
.any(|a| exprs_are_equivalent(&a.original_expr, expr))
{
return true;
}
match expr {
Expand Down

0 comments on commit 3ab7f7a

Please sign in to comment.