diff --git a/icebergShaded/generate_iceberg_jars.py b/icebergShaded/generate_iceberg_jars.py index 551638aca90..8f19bc8d81d 100644 --- a/icebergShaded/generate_iceberg_jars.py +++ b/icebergShaded/generate_iceberg_jars.py @@ -124,7 +124,7 @@ def generate_iceberg_jars(): # Search for all glob results results = glob.glob(compiled_jar_abs_pattern) # Compiled jars will include tests, sources, javadocs; exclude them - results = list(filter(lambda result: all(x not in result for x in ["test", "source", "javadoc"]), results)) + results = list(filter(lambda result: all(x not in result for x in ["test", "sources", "javadoc"]), results)) if len(results) == 0: raise Exception("Could not find the jar: " + compled_jar_rel_glob_pattern) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java b/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java index 0b84a482ee3..bb6987f8ea6 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.client; +import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.ExpressionEvaluator; +import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructType; /** @@ -28,12 +29,18 @@ public interface ExpressionHandler { /** * Create an {@link ExpressionEvaluator} that can evaluate the given expression on - * {@link io.delta.kernel.data.ColumnarBatch}s with the given batchSchema. + * {@link ColumnarBatch}s with the given batchSchema. The expression is + * expected to be a scalar expression where for each one input row there + * is a one output value. * - * @param batchSchema Schema of the input data. + * @param inputSchema Input data schema * @param expression Expression to evaluate. + * @param outputType Expected result data type. * @return An {@link ExpressionEvaluator} instance bound to the given expression and - * batchSchema. + * inputSchem. */ - ExpressionEvaluator getEvaluator(StructType batchSchema, Expression expression); + ExpressionEvaluator getEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java deleted file mode 100644 index 4ad3328c933..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -import java.util.Collection; - -import io.delta.kernel.types.BooleanType; - -/** - * Evaluates logical {@code expr1} AND {@code expr2} for {@code new And(expr1, expr2)}. - *

- * Requires both left and right input expressions evaluate to booleans. - */ -public final class And extends BinaryOperator implements Predicate { - - public static And apply(Collection conjunctions) { - if (conjunctions.size() == 0) { - throw new IllegalArgumentException("And.apply must be called with at least 1 element"); - } - - return (And) conjunctions - .stream() - // we start off with And(true, true) - // then we get the 1st expression: And(And(true, true), expr1) - // then we get the 2nd expression: And(And(true, true), expr1), expr2) etc. - .reduce(new And(Literal.TRUE, Literal.TRUE), And::new); - } - - public And(Expression left, Expression right) { - super(left, right, "&&"); - if (!(left.dataType() instanceof BooleanType) || - !(right.dataType() instanceof BooleanType)) { - - throw new IllegalArgumentException( - String.format( - "'And' requires expressions of type boolean. Got %s and %s.", - left.dataType(), - right.dataType() - ) - ); - } - } - - @Override - public Object nullSafeEval(Object leftResult, Object rightResult) { - return (boolean) leftResult && (boolean) rightResult; - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java deleted file mode 100644 index 54e9648d85b..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -import java.util.Comparator; - -import io.delta.kernel.internal.expressions.CastingComparator; - -/** - * A {@link BinaryOperator} that compares the left and right {@link Expression}s and evaluates to a - * boolean value. - */ -public abstract class BinaryComparison extends BinaryOperator implements Predicate { - private final Comparator comparator; - - protected BinaryComparison(Expression left, Expression right, String symbol) { - super(left, right, symbol); - - // super asserted that left and right DataTypes were the same - - comparator = CastingComparator.forDataType(left.dataType()); - } - - protected int compare(Object leftResult, Object rightResult) { - return comparator.compare(leftResult, rightResult); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java deleted file mode 100644 index 82a67b328e1..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -import java.util.Arrays; -import java.util.List; -import java.util.Objects; - -import io.delta.kernel.data.Row; - -/** - * An {@link Expression} with two inputs and one output. The output is by default evaluated to null - * if either input is evaluated to null. - */ -public abstract class BinaryExpression implements Expression { - protected final Expression left; - protected final Expression right; - - protected BinaryExpression(Expression left, Expression right) { - this.left = left; - this.right = right; - } - - public Expression getLeft() { - return left; - } - - public Expression getRight() { - return right; - } - - @Override - public final Object eval(Row row) { - Object leftResult = left.eval(row); - if (null == leftResult) return null; - - Object rightResult = right.eval(row); - if (null == rightResult) return null; - - return nullSafeEval(leftResult, rightResult); - } - - protected abstract Object nullSafeEval(Object leftResult, Object rightResult); - - @Override - public List children() { - return Arrays.asList(left, right); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BinaryExpression that = (BinaryExpression) o; - return Objects.equals(left, that.left) && - Objects.equals(right, that.right); - } - - @Override - public int hashCode() { - return Objects.hash(left, right); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java deleted file mode 100644 index 4c30e2c79bb..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -/** - * A {@link BinaryExpression} that is an operator, meaning the string representation is - * {@code x symbol y}, rather than {@code funcName(x, y)}. - *

- * Requires both inputs to be of the same data type. - */ -public abstract class BinaryOperator extends BinaryExpression { - protected final String symbol; - - protected BinaryOperator(Expression left, Expression right, String symbol) { - super(left, right); - this.symbol = symbol; - - if (!left.dataType().equals(right.dataType())) { - throw new IllegalArgumentException( - String.format( - "BinaryOperator left and right DataTypes must be the same. Found %s and %s.", - left.dataType(), - right.dataType() - )); - } - } - - @Override - public String toString() { - return String.format("(%s %s %s)", left.toString(), symbol, right.toString()); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java index 5d7fc02188e..d99707873ac 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java @@ -13,101 +13,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; import java.util.Collections; -import java.util.Objects; -import java.util.Set; - -import io.delta.kernel.data.Row; -import io.delta.kernel.types.BooleanType; -import io.delta.kernel.types.DataType; -import io.delta.kernel.types.IntegerType; -import io.delta.kernel.types.LongType; -import io.delta.kernel.types.StringType; -import io.delta.kernel.types.StructType; +import java.util.List; /** - * A column whose row-value will be computed based on the data in a {@link Row}. - *

- * It is recommended that you instantiate using an existing table schema - * {@link StructType} with {@link StructType#column(int)}. - *

- * Only supports primitive data types, see - * Delta Transaction Log Protocol: Primitive Types. + * An expression type that refers to a column by name (case-sensitive) in the input. */ -public final class Column extends LeafExpression { - private final int ordinal; +public final class Column implements Expression { private final String name; - private final DataType dataType; - private final RowEvaluator evaluator; - public Column(int ordinal, String name, DataType dataType) { - this.ordinal = ordinal; + public Column(String name) { this.name = name; - this.dataType = dataType; - - if (dataType instanceof IntegerType) { - evaluator = (row -> row.getInt(ordinal)); - } else if (dataType instanceof BooleanType) { - evaluator = (row -> row.getBoolean(ordinal)); - } else if (dataType instanceof LongType) { - evaluator = (row -> row.getLong(ordinal)); - } else if (dataType instanceof StringType) { - evaluator = (row -> row.getString(ordinal)); - } else { - throw new UnsupportedOperationException( - String.format( - "The data type %s of column %s at ordinal %s is not supported", - dataType, - name, - ordinal) - ); - } } - public String name() { + /** + * @return the column name. + */ + public String getName() { return name; } @Override - public Object eval(Row row) { - return row.isNullAt(ordinal) ? null : evaluator.nullSafeEval(row); - } - - @Override - public DataType dataType() { - return dataType; + public List getChildren() { + return Collections.emptyList(); } @Override public String toString() { return "Column(" + name + ")"; } - - @Override - public Set references() { - return Collections.singleton(name); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Column column = (Column) o; - return Objects.equals(ordinal, column.ordinal) && - Objects.equals(name, column.name) && - Objects.equals(dataType, column.dataType); - } - - @Override - public int hashCode() { - return Objects.hash(name, dataType); - } - - @FunctionalInterface - private interface RowEvaluator { - Object nullSafeEval(Row row); - } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java index af4bc592ecc..e946a403431 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java @@ -13,48 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; -import java.util.HashSet; import java.util.List; -import java.util.Set; - -import io.delta.kernel.data.Row; -import io.delta.kernel.types.DataType; /** - * Generic interface for all Expressions + * Base interface for all Kernel expressions. */ public interface Expression { - - /** - * @param row the input row to evaluate. - * @return the result of evaluating this expression on the given input {@link Row}. - */ - Object eval(Row row); - - /** - * @return the {@link DataType} of the result of evaluating this expression. - */ - DataType dataType(); - - /** - * @return the String representation of this expression. - */ - String toString(); - - /** - * @return a {@link List} of the immediate children of this node - */ - List children(); - /** - * @return the names of columns referenced by this expression. + * @return a list of the input expressions. */ - default Set references() { - Set result = new HashSet<>(); - children().forEach(child -> result.addAll(child.references())); - return result; - } + List getChildren(); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java index 2f653e02a02..4630ced68b1 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java @@ -17,26 +17,35 @@ package io.delta.kernel.expressions; import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.Objects; +import java.util.Collections; +import java.util.List; -import io.delta.kernel.data.Row; import io.delta.kernel.types.*; + import static io.delta.kernel.internal.util.InternalUtils.checkArgument; /** * A literal value. *

- * Only supports primitive data types, see - * Delta Transaction Log Protocol: Primitive Types. + * Definition: + *

+ *

    + *
  • Represents only primitive data types as defined in the protocol + * Delta Transaction Log Protocol: Primitive Types + *
  • + *
  • Use {@link Literal#value} to fetch the literal value. Returned object type + * depends on the type of the literal data type. See the {@link Literal#value for further + * details}.
  • + *
*/ -public final class Literal extends LeafExpression { +public final class Literal + implements Expression { + // TODO: Remove these and use `AlwaysTrue` and `AlwaysFalse` predicate expressions instead. public static final Literal TRUE = Literal.of(true); public static final Literal FALSE = Literal.of(false); /** - * Create a boolean {@link Literal} object + * Create a {@code boolean} type {@link Literal} object * * @param value boolean value * @return a {@link Literal} with data type {@link BooleanType} @@ -111,17 +120,21 @@ public static Literal of(byte[] value) { } /** - * @return a {@link Literal} with data type {@link DateType} + * Create a {@code date} type {@link Literal}. + * + * @param daysSinceEpochUTC Number of days since the epoch in UTC timezone. + * @return a {@link Literal} with data type {@link BooleanType} */ - public static Literal of(Date value) { - return new Literal(value, DateType.INSTANCE); + public static Literal ofDate(int daysSinceEpochUTC) { + return new Literal(daysSinceEpochUTC, DateType.INSTANCE); } /** + * @param microsSinceEpochUTC value in microseconds since epoch time in UTC timezone. * @return a {@link Literal} with data type {@link TimestampType} */ - public static Literal of(Timestamp value) { - return new Literal(value, TimestampType.INSTANCE); + public static Literal ofTimestamp(long microsSinceEpochUTC) { + return new Literal(microsSinceEpochUTC, TimestampType.INSTANCE); } /** @@ -131,20 +144,22 @@ public static Literal of(BigDecimal value, int precision, int scale) { // throws an error if rounding is required to set the specified scale BigDecimal valueToStore = value.setScale(scale); checkArgument(valueToStore.precision() <= precision, String.format( - "Decimal precision=%s for decimal %s exceeds max precision %s", - valueToStore.precision(), valueToStore, precision)); + "Decimal precision=%s for decimal %s exceeds max precision %s", + valueToStore.precision(), valueToStore, precision)); return new Literal(valueToStore, new DecimalType(precision, scale)); } /** + * Create {@code null} value literal. + * + * @param dataType {@link DataType} of the null literal. * @return a null {@link Literal} with the given data type */ public static Literal ofNull(DataType dataType) { if (dataType instanceof ArrayType || dataType instanceof MapType || dataType instanceof StructType) { - throw new IllegalArgumentException( - dataType + " is an invalid data type for Literal."); + throw new IllegalArgumentException(dataType + " is an invalid data type for Literal."); } return new Literal(null, dataType); } @@ -157,17 +172,36 @@ private Literal(Object value, DataType dataType) { this.dataType = dataType; } - public Object value() { - return value; - } - - @Override - public Object eval(Row record) { + /** + * Get the literal value. If the value is null a {@code null} is returned. Otherwise one + * of the following types based on the literal data type. + * + *
    + *
  • BOOLEAN: {@link Boolean}
  • + *
  • BYTE: {@link Byte}
  • + *
  • SHORT: {@link Short}
  • + *
  • INTEGER: {@link Integer}
  • + *
  • LONG: {@link Long}
  • + *
  • FLOAT: {@link Float}
  • + *
  • DOUBLE: {@link Double}
  • + *
  • DATE: {@link Integer} represents the number of days since epoch in UTC
  • + *
  • TIMESTAMP: {@link Long} represents the microseconds since epoch in UTC
  • + *
  • DECIMAL: {@link BigDecimal}. Use {@link #dataType} to find the precision and scale
  • + *
+ * + * @return Literal value. + */ + public Object getValue() { return value; } - @Override - public DataType dataType() { + /** + * Get the datatype of the literal object. Datatype lets the caller interpret the value of the + * literal object returned by {@link #value} + * + * @return Datatype of the literal object. + */ + public DataType getDataType() { return dataType; } @@ -177,20 +211,7 @@ public String toString() { } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Literal literal = (Literal) o; - return Objects.equals(value, literal.value) && - Objects.equals(dataType, literal.dataType); - } - - @Override - public int hashCode() { - return Objects.hash(value, dataType); + public List getChildren() { + return Collections.emptyList(); } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java new file mode 100644 index 00000000000..2db91ae78c0 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java @@ -0,0 +1,58 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.expressions.predicates.Predicate; + +/** + * Scalar SQL expressions which take zero or more inputs and for each input row generate one + * output value. A subclass of these expressions are of type {@link Predicate} whose result type is + * `boolean`. See {@link Predicate} for predicate type scalar expressions. Supported + * non-predicate type scalar expressions are listed below. + * TODO: Currently there aren't any. Will be added in future. An example one looks like this: + *
    + *
  1. Name: + + *
      + *
    • SQL semantic: expr1 + expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  2. + *
+ */ +public class ScalarExpression implements Expression { + protected final String name; + protected final List children; + + public ScalarExpression(String name, List children) { + this.name = requireNonNull(name, "name is null").toUpperCase(Locale.ENGLISH); + this.children = Collections.unmodifiableList(new ArrayList<>(children)); + } + + public String getName() { + return name; + } + + @Override + public List getChildren() { + return children; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysFalse.java similarity index 66% rename from kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysFalse.java index a26a01273f9..1758def8c14 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysFalse.java @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.delta.kernel.expressions.predicates; -package io.delta.kernel.expressions; - -import io.delta.kernel.types.BooleanType; -import io.delta.kernel.types.DataType; +import java.util.Collections; /** - * An {@link Expression} that defines a relation on inputs. Evaluates to true, false, or null. + * Predicate which always evaluates to {@code false}. */ -public interface Predicate extends Expression { +public final class AlwaysFalse extends Predicate { + public AlwaysFalse() { + super("ALWAYS_FALSE", Collections.emptyList()); + } + @Override - default DataType dataType() { - return BooleanType.INSTANCE; + public String toString() { + return "false"; } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysTrue.java similarity index 61% rename from kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysTrue.java index 8f2076c56e5..918ab91dc8b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/AlwaysTrue.java @@ -13,21 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.delta.kernel.expressions.predicates; -package io.delta.kernel.expressions; +import java.util.Collections; /** - * Evaluates {@code expr1} = {@code expr2} for {@code new EqualTo(expr1, expr2)}. + * Predicate which always evaluates to {@code true}. */ -public final class EqualTo extends BinaryComparison implements Predicate { - - public EqualTo(Expression left, Expression right) { - super(left, right, "="); +public final class AlwaysTrue extends Predicate { + public AlwaysTrue() { + super("ALWAYS_TRUE", Collections.emptyList()); } @Override - protected Object nullSafeEval(Object leftResult, Object rightResult) { - return compare(leftResult, rightResult) == 0; + public String toString() { + return "true"; } } - diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/And.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/And.java new file mode 100644 index 00000000000..000a8bc4873 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/And.java @@ -0,0 +1,54 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions.predicates; + +import java.util.Arrays; + +/** + * {@code AND} expression + *

+ * Definition: + *

+ *

    + *
  • Logical {@code expr1} AND {@code expr2} on two inputs.
  • + *
  • Requires both left and right input expressions evaluate to be {@link Predicate}.
  • + *
  • Result is null if one or both of the inputs are null.
  • + *
+ */ +public final class And extends Predicate { + public And(Predicate left, Predicate right) { + super("AND", Arrays.asList(left, right)); + } + + /** + * @return Left side operand. + */ + public Predicate getLeft() { + return (Predicate) getChildren().get(0); + } + + /** + * @return Right side operand. + */ + public Predicate getRight() { + return (Predicate) getChildren().get(1); + } + + @Override + public String toString() { + return "(" + getLeft() + " AND " + getRight() + ")"; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Or.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Or.java new file mode 100644 index 00000000000..e7ed0716436 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Or.java @@ -0,0 +1,53 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions.predicates; + +import java.util.Arrays; + +/** + * {@code OR} expression + *

+ * Definition: + *

    + *
  • Logical {@code expr1} OR {@code expr2} on two inputs.
  • + *
  • Requires both left and right input expressions evaluate to be {@link Predicate}.
  • + *
  • Result is null if one or both of the inputs are null.
  • + *
+ */ +public final class Or extends Predicate { + public Or(Predicate left, Predicate right) { + super("OR", Arrays.asList(left, right)); + } + + /** + * @return Left side operand. + */ + public Predicate getLeft() { + return (Predicate) getChildren().get(0); + } + + /** + * @return Right side operand. + */ + public Predicate getRight() { + return (Predicate) getChildren().get(1); + } + + @Override + public String toString() { + return "(" + getLeft() + " OR " + getRight() + ")"; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Predicate.java new file mode 100644 index 00000000000..90c1bb0f59e --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/Predicate.java @@ -0,0 +1,95 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions.predicates; + +import java.util.List; + +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.ScalarExpression; + +/** + * Defines predicate scalar expression which is an extension of {@link Expression} + * that evaluates to true, false, or null for each input row. + *

+ * Currently, Kernel allows following predicate scalar expressions. + *

    + *
  1. Name: = + *
      + *
    • SQL semantic: expr1 = expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  2. + *
  3. Name: <> + *
      + *
    • SQL semantic: expr1 <> expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  4. + *
  5. Name: < + *
      + *
    • SQL semantic: expr1 < expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  6. + *
  7. Name: <= + *
      + *
    • SQL semantic: expr1 <= expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  8. + *
  9. Name: > + *
      + *
    • SQL semantic: expr1 > expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  10. + *
  11. Name: >= + *
      + *
    • SQL semantic: expr1 >= expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  12. + *
  13. Name: ALWAYS_TRUE + *
      + *
    • SQL semantic: Constant expression whose value is `true`
    • + *
    • Since version: 3.0.0
    • + *
    + *
  14. + *
  15. Name: ALWAYS_FALSE + *
      + *
    • SQL semantic: Constant expression whose value is `false`
    • + *
    • Since version: 3.0.0
    • + *
    + *
  16. + *
  17. Name: AND + *
      + *
    • SQL semantic: expr1 AND expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  18. + *
  19. Name: OR + *
      + *
    • SQL semantic: expr1 OR expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  20. + *
+ */ +public class Predicate extends ScalarExpression { + public Predicate(String name, List children) { + super(name, children); + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/package-info.java similarity index 53% rename from kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/package-info.java index 05f48b5c5a6..c920cf7bef1 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/predicates/package-info.java @@ -14,30 +14,8 @@ * limitations under the License. */ -package io.delta.kernel.expressions; - -import java.util.Collections; -import java.util.List; -import java.util.Set; - /** - * An {@link Expression} with no children. + * Predicate type expressions that defines the most common expressions which the connectors + * can use to pass predicates to Delta Kernel. */ -public abstract class LeafExpression implements Expression { - - protected LeafExpression() {} - - @Override - public List children() { - return Collections.emptyList(); - } - - @Override - public Set references() { - return Collections.emptySet(); - } - - public abstract boolean equals(Object o); - - public abstract int hashCode(); -} +package io.delta.kernel.expressions.predicates; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java index 86bb30834da..b6468ada71d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java @@ -30,6 +30,7 @@ import io.delta.kernel.expressions.Literal; import io.delta.kernel.types.*; import io.delta.kernel.utils.Tuple2; +import io.delta.kernel.utils.Utils; public class PartitionUtils { private PartitionUtils() {} @@ -88,8 +89,8 @@ public static ColumnarBatch withPartitionColumns( dataBatchSchema, literalForPartitionValue( structField.getDataType(), - partitionValues.get(structField.getName()) - ) + partitionValues.get(structField.getName())), + structField.getDataType() ); ColumnVector partitionVector = evaluator.eval(dataBatch); @@ -133,7 +134,7 @@ private static Literal literalForPartitionValue(DataType dataType, String partit return Literal.of(partitionValue.getBytes()); } if (dataType instanceof DateType) { - return Literal.of(Date.valueOf(partitionValue)); + return Literal.ofDate(Utils.daysSinceEpoch(Date.valueOf(partitionValue))); } if (dataType instanceof DecimalType) { DecimalType decimalType = (DecimalType) dataType; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java index 0904fc7dbe2..819fabacf89 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java @@ -19,6 +19,8 @@ * Data type representing {@code boolean} type values. */ public class BooleanType extends BasePrimitiveType { + // TODO: Should remove the `INSTANCE` to `BOOLEAN` so that it can be static imported where + // needed and referred without the `BooleanType.` prefix. Same for other types. public static final BooleanType INSTANCE = new BooleanType(); private BooleanType() { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java index d3d9e6a309d..33586b54df3 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java @@ -110,19 +110,7 @@ public StructField at(int index) { */ public Column column(int ordinal) { final StructField field = at(ordinal); - return new Column(ordinal, field.getName(), field.getDataType()); - } - - /** - * Creates a {@link Column} expression for the field with the given {@code fieldName}. - * - * @param fieldName the name of the {@link StructField} to create a column for - * @return a {@link Column} expression for the {@link StructField} with name {@code fieldName} - */ - public Column column(String fieldName) { - Tuple2 fieldAndOrdinal = nameToFieldAndOrdinal.get(fieldName); - System.out.println("Created column " + fieldName + " with ordinal " + fieldAndOrdinal._2); - return new Column(fieldAndOrdinal._2, fieldName, fieldAndOrdinal._1.getDataType()); + return new Column(field.getName()); } @Override diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/Utils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/Utils.java index 49b74fd245f..ed37de0f535 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/Utils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/Utils.java @@ -18,6 +18,9 @@ import java.io.Closeable; import java.io.IOException; +import java.sql.Date; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -35,6 +38,8 @@ import io.delta.kernel.internal.types.TableSchemaSerDe; public class Utils { + private static final LocalDate EPOCH = LocalDate.ofEpochDay(0); + /** * Utility method to create a singleton {@link CloseableIterator}. * @@ -252,4 +257,57 @@ public static Row requireNonNull(Row row, int ordinal, String columnName) { } return row; } + + /** + * Precondition-style validation that throws {@link IllegalArgumentException}. + * + * @param isValid {@code true} if valid, {@code false} if an exception should be thrown + * @throws IllegalArgumentException if {@code isValid} is false + */ + public static void checkArgument(boolean isValid) + throws IllegalArgumentException { + if (!isValid) { + throw new IllegalArgumentException(); + } + } + + /** + * Precondition-style validation that throws {@link IllegalArgumentException}. + * + * @param isValid {@code true} if valid, {@code false} if an exception should be thrown + * @param message A String message for the exception. + * @throws IllegalArgumentException if {@code isValid} is false + */ + public static void checkArgument(boolean isValid, String message) + throws IllegalArgumentException { + if (!isValid) { + throw new IllegalArgumentException(message); + } + } + + /** + * Precondition-style validation that throws {@link IllegalArgumentException}. + * + * @param isValid {@code true} if valid, {@code false} if an exception should be thrown + * @param message A String message for the exception. + * @param args Objects used to fill in {@code %s} placeholders in the message + * @throws IllegalArgumentException if {@code isValid} is false + */ + public static void checkArgument(boolean isValid, String message, Object... args) + throws IllegalArgumentException { + if (!isValid) { + throw new IllegalArgumentException( + String.format(String.valueOf(message), args)); + } + } + + /** + * Utility method to get the number of days since epoch this given date is. + * + * @param date + */ + public static int daysSinceEpoch(Date date) { + LocalDate localDate = date.toLocalDate(); + return (int) ChronoUnit.DAYS.between(EPOCH, localDate); + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java index e7c629a9464..78d0d902940 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java @@ -13,108 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.defaults.client; -import java.sql.Date; -import java.util.Optional; - import io.delta.kernel.client.ExpressionHandler; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.Row; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.ExpressionEvaluator; -import io.delta.kernel.expressions.Literal; -import io.delta.kernel.types.*; -import io.delta.kernel.utils.CloseableIterator; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.StructType; -import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; -import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; +import io.delta.kernel.defaults.internal.expressions.DefaultExpressionEvaluator; public class DefaultExpressionHandler implements ExpressionHandler { @Override - public ExpressionEvaluator getEvaluator(StructType batchSchema, Expression expression) { - return new DefaultExpressionEvaluator(expression); - } - - private static ColumnVector evalBooleanOutputExpression( - ColumnarBatch input, Expression expression) { - checkArgument(expression.dataType().equals(BooleanType.INSTANCE), - "expression should return a boolean"); - - final int batchSize = input.getSize(); - boolean[] result = new boolean[batchSize]; - boolean[] nullResult = new boolean[batchSize]; - CloseableIterator rows = input.getRows(); - for (int currentIndex = 0; currentIndex < batchSize; currentIndex++) { - Object evalResult = expression.eval(rows.next()); - if (evalResult == null) { - nullResult[currentIndex] = true; - } else { - result[currentIndex] = ((Boolean) evalResult).booleanValue(); - } - } - return new DefaultBooleanVector(batchSize, Optional.of(nullResult), result); - } - - private static ColumnVector evalLiteralExpression(ColumnarBatch input, Literal literal) { - Object result = literal.value(); - DataType dataType = literal.dataType(); - int size = input.getSize(); - - if (result == null) { - return new DefaultConstantVector(dataType, size, null); - } - - if (dataType instanceof BooleanType || - dataType instanceof ByteType || - dataType instanceof ShortType || - dataType instanceof IntegerType || - dataType instanceof LongType || - dataType instanceof FloatType || - dataType instanceof DoubleType || - dataType instanceof StringType || - dataType instanceof BinaryType || - dataType instanceof DecimalType) { - return new DefaultConstantVector(dataType, size, result); - } else if (dataType instanceof DateType) { - int numOfDaysSinceEpoch = daysSinceEpoch((Date) result); - return new DefaultConstantVector(dataType, size, numOfDaysSinceEpoch); - } - // TODO: support timestamptype - - throw new UnsupportedOperationException( - "unsupported expression encountered: " + literal); - } - - private static class DefaultExpressionEvaluator - implements ExpressionEvaluator { - private final Expression expression; - - private DefaultExpressionEvaluator(Expression expression) { - this.expression = expression; - } - - @Override - public ColumnVector eval(ColumnarBatch input) { - if (expression instanceof Literal) { - return evalLiteralExpression(input, (Literal) expression); - } - - if (expression.dataType().equals(BooleanType.INSTANCE)) { - return evalBooleanOutputExpression(input, expression); - } - // TODO: Boolean output type expressions are good enough for first preview release - // which enables partition pruning and file skipping using file stats. - - throw new UnsupportedOperationException("not yet implemented"); - } - - @Override - public void close() { /* nothing to close */ } + public ExpressionEvaluator getEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType) { + return new DefaultExpressionEvaluator(inputSchema, expression, outputType); } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java index 7f853d8e2f3..f52b371afe1 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java @@ -15,9 +15,7 @@ */ package io.delta.kernel.defaults.internal; -import java.sql.Date; import java.time.LocalDate; -import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -54,7 +52,7 @@ public static final MessageType pruneSchema( * given {@code field}. * * @param groupType Parquet group type coming from the file schema. - * @param field Sub field given as Delta Kernel's {@link StructField} + * @param field Sub field given as Delta Kernel's {@link StructField} * @return {@link Type} of the Parquet field. Returns {@code null}, if not found. */ public static Type findSubFieldType(GroupType groupType, StructField field) { @@ -74,6 +72,8 @@ public static Type findSubFieldType(GroupType groupType, StructField field) { return null; } + // TODO: Move these precondition checks into a separate utility class. + /** * Precondition-style validation that throws {@link IllegalArgumentException}. * @@ -106,7 +106,7 @@ public static void checkArgument(boolean isValid, String message) * * @param isValid {@code true} if valid, {@code false} if an exception should be thrown * @param message A String message for the exception. - * @param args Objects used to fill in {@code %s} placeholders in the message + * @param args Objects used to fill in {@code %s} placeholders in the message * @throws IllegalArgumentException if {@code isValid} is false */ public static void checkArgument(boolean isValid, String message, Object... args) @@ -156,16 +156,6 @@ private static Type prunedType(Type type, DataType deltaType) { } } - /** - * Utility method to get the number of days since epoch this given date is. - * - * @param date - */ - public static int daysSinceEpoch(Date date) { - LocalDate localDate = date.toLocalDate(); - return (int) ChronoUnit.DAYS.between(EPOCH, localDate); - } - ////////////////////////////////////////////////////////////////////////////////// // Below utils are adapted from org.apache.spark.sql.catalyst.util.DateTimeUtils ////////////////////////////////////////////////////////////////////////////////// diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java new file mode 100644 index 00000000000..773e0350ad8 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -0,0 +1,362 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.Arrays; +import java.util.Optional; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.expressions.Column; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.ExpressionEvaluator; +import io.delta.kernel.expressions.Literal; +import io.delta.kernel.expressions.predicates.*; +import io.delta.kernel.types.*; + +import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; +import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; +import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compareVectors; +import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.nullabilityEval; +import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo; +import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.createCastExpression; + +public class DefaultExpressionEvaluator implements ExpressionEvaluator { + private final Expression expression; + + public DefaultExpressionEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType) { + ExpressionTransformResult transformResult = + new ExpressionTransformer(inputSchema).visit(expression); + if (!transformResult.outputType.equivalent(outputType)) { + throw new UnsupportedOperationException(format("Can not create an expression handler " + + "for expression `%s` returns result of type %s", expression, outputType)); + } + this.expression = transformResult.expression; + } + + @Override + public ColumnVector eval(ColumnarBatch input) { + return new ExpressionEvalVisitor(input).visit(expression); + } + + @Override + public void close() { /* nothing to close */ } + + /** + * Encapsulates the result of {@link ExpressionTransformer} + */ + private static class ExpressionTransformResult { + public final Expression expression; // transformed expression + public final DataType outputType; // output type of the expression + + ExpressionTransformResult(Expression expression, DataType outputType) { + this.expression = expression; + this.outputType = outputType; + } + } + + /** + * Implementation of {@link ExpressionVisitor} to validate the given expression as follows. + *
    + *
  • given input column is part of the input data schema
  • + *
  • expression inputs are of supported types. Insert cast according to the rules in + * {@link ImplicitCastExpression} to make the types compatible for evaluation by + * {@link ExpressionEvalVisitor} + *
  • + *
+ *

+ * Return type of each expression visit is a tuple of new rewritten expression and its result + * data type. + */ + private static class ExpressionTransformer + extends ExpressionVisitor { + private StructType inputDataSchema; + + ExpressionTransformer(StructType inputDataSchema) { + this.inputDataSchema = requireNonNull(inputDataSchema, "inputDataSchema is null"); + } + + @Override + ExpressionTransformResult visitAnd(And and) { + Predicate left = validateIsPredicate(and, visit(and.getLeft())); + Predicate right = validateIsPredicate(and, visit(and.getRight())); + return new ExpressionTransformResult(new And(left, right), BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitOr(Or or) { + Predicate left = validateIsPredicate(or, visit(or.getLeft())); + Predicate right = validateIsPredicate(or, visit(or.getRight())); + return new ExpressionTransformResult(new Or(left, right), BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitAlwaysTrue(AlwaysTrue alwaysTrue) { + // nothing to validate or rewrite. + return new ExpressionTransformResult(alwaysTrue, BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitAlwaysFalse(AlwaysFalse alwaysFalse) { + // nothing to validate or rewrite. + return new ExpressionTransformResult(alwaysFalse, BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitPredicate(Predicate predicate) { + switch (predicate.getName()) { + case "=": + case ">": + case ">=": + case "<": + case "<=": + return new ExpressionTransformResult( + validateAndRewriteBinaryComparator(predicate), + BooleanType.INSTANCE); + default: + throw new UnsupportedOperationException( + "unsupported expression encountered: " + predicate); + } + } + + @Override + ExpressionTransformResult visitLiteral(Literal literal) { + // nothing to validate or rewrite + return new ExpressionTransformResult(literal, literal.getDataType()); + } + + @Override + ExpressionTransformResult visitColumn(Column column) { + int ordinal = inputDataSchema.indexOf(column.getName()); + if (ordinal == -1) { + throw new IllegalArgumentException( + format("Column `%s` doesn't exist in input data schema: %s", + column.getName(), inputDataSchema)); + } + return new ExpressionTransformResult(column, inputDataSchema.at(ordinal).getDataType()); + } + + @Override + ExpressionTransformResult visitCast(ImplicitCastExpression cast) { + throw new UnsupportedOperationException("CAST expression is not expected."); + } + + private Predicate validateIsPredicate( + Expression baseExpression, + ExpressionTransformResult result) { + checkArgument( + result.outputType instanceof BooleanType && + result.expression instanceof Predicate, + "%s: expected a predicate expression but got %s with output type %s.", + baseExpression, + result.expression, + result.outputType); + return (Predicate) result.expression; + } + + private Expression validateAndRewriteBinaryComparator(Predicate predicate) { + checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); + ExpressionTransformResult leftResult = visit(predicate.getChildren().get(0)); + ExpressionTransformResult rightResult = visit(predicate.getChildren().get(1)); + Expression left = leftResult.expression; + Expression right = rightResult.expression; + if (!leftResult.outputType.equivalent(rightResult.outputType)) { + if (canCastTo(leftResult.outputType, rightResult.outputType)) { + left = createCastExpression(left, rightResult.outputType); + } else if (canCastTo(rightResult.outputType, leftResult.outputType)) { + right = createCastExpression(right, leftResult.outputType); + } else { + String msg = format("%s: operands are of different types which are not " + + "comparable: left type=%s, right type=%", + predicate, leftResult.outputType, rightResult.outputType); + throw new UnsupportedOperationException(msg); + } + } + return new Predicate(predicate.getName(), Arrays.asList(left, right)); + } + } + + /** + * Implementation of {@link ExpressionVisitor} to evaluate expression on a + * {@link ColumnarBatch}. + */ + private static class ExpressionEvalVisitor extends ExpressionVisitor { + private final ColumnarBatch input; + + ExpressionEvalVisitor(ColumnarBatch input) { + this.input = input; + } + + @Override + ColumnVector visitAnd(And and) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(and); + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = nullabilityEval(argResults.leftResult, argResults.rightResult); + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = argResults.leftResult.getBoolean(rowId) && + argResults.rightResult.getBoolean(rowId); + } + return new DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitOr(Or or) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(or); + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = nullabilityEval(argResults.leftResult, argResults.rightResult); + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = argResults.leftResult.getBoolean(rowId) || + argResults.rightResult.getBoolean(rowId); + } + return new DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitAlwaysTrue(AlwaysTrue alwaysTrue) { + return new DefaultConstantVector(BooleanType.INSTANCE, input.getSize(), true); + } + + @Override + ColumnVector visitAlwaysFalse(AlwaysFalse alwaysFalse) { + return new DefaultConstantVector(BooleanType.INSTANCE, input.getSize(), false); + } + + @Override + ColumnVector visitPredicate(Predicate predicate) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate); + + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = nullabilityEval(argResults.leftResult, argResults.rightResult); + int[] compareResult = compareVectors(argResults.leftResult, argResults.rightResult); + switch (predicate.getName()) { + case "=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] == 0; + } + break; + case ">": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] > 0; + } + break; + case ">=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] >= 0; + } + break; + case "<": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] < 0; + } + break; + case "<=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] <= 0; + } + break; + default: + throw new UnsupportedOperationException( + "unsupported expression encountered: " + predicate); + } + + return new DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitLiteral(Literal literal) { + DataType dataType = literal.getDataType(); + if (dataType instanceof BooleanType || + dataType instanceof ByteType || + dataType instanceof ShortType || + dataType instanceof IntegerType || + dataType instanceof LongType || + dataType instanceof FloatType || + dataType instanceof DoubleType || + dataType instanceof StringType || + dataType instanceof BinaryType || + dataType instanceof DecimalType || + dataType instanceof DateType || + dataType instanceof TimestampType) { + return new DefaultConstantVector(dataType, input.getSize(), literal.getValue()); + } + + throw new UnsupportedOperationException( + "unsupported expression encountered: " + literal); + } + + @Override + ColumnVector visitColumn(Column column) { + int ordinal = input.getSchema().indexOf(column.getName()); + if (ordinal == -1) { + throw new IllegalArgumentException( + format("Column `%s` doesn't exist in input data schema: %s", + column.getName(), input.getSchema())); + } + return input.getColumnVector(ordinal); + } + + @Override + ColumnVector visitCast(ImplicitCastExpression cast) { + ColumnVector inputResult = visit(cast.getInput()); + return ImplicitCastExpression.evalCastExpression(cast, inputResult); + } + + /** + * Utility method to evaluate inputs to the binary input expression. Also validates the + * evaluated expression result {@link ColumnVector}s are of the same size. + * + * @param predicate + * @return Triplet of (result vector size, left operand result, left operand result) + */ + private PredicateChildrenEvalResult evalBinaryExpressionChildren( + Predicate predicate) { + checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); + ColumnVector left = visit(predicate.getChildren().get(0)); + ColumnVector right = visit(predicate.getChildren().get(1)); + checkArgument( + left.getSize() == right.getSize(), + "Left and right operand returned different results: left=%d, right=d", + left.getSize(), + right.getSize()); + return new PredicateChildrenEvalResult(left.getSize(), left, right); + } + } + + /** + * Encapsulates children expression result of binary input predicate + */ + private static class PredicateChildrenEvalResult { + public final int rowCount; + public final ColumnVector leftResult; + public final ColumnVector rightResult; + + PredicateChildrenEvalResult( + int rowCount, ColumnVector leftResult, ColumnVector rightResult) { + this.rowCount = rowCount; + this.leftResult = leftResult; + this.rightResult = rightResult; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java new file mode 100644 index 00000000000..a4497e13449 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java @@ -0,0 +1,157 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import java.math.BigDecimal; +import java.util.Comparator; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.types.*; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; + +class ExpressionUtils { + private ExpressionUtils() {} + + /** + * Utility method that calculates the nullability result from given two vectors. Result is + * null if at least one side is a null. + */ + static boolean[] nullabilityEval(ColumnVector left, ColumnVector right) { + int numRows = left.getSize(); + boolean[] nullability = new boolean[numRows]; + for (int rowId = 0; rowId < numRows; rowId++) { + nullability[rowId] = left.isNullAt(rowId) || right.isNullAt(rowId); + } + return nullability; + } + + /** + * Utility method to compare the left and right according to the natural ordering + * and return the comparison result (-1, 0, 1) for each row as an integer array. + */ + static int[] compareVectors(ColumnVector left, ColumnVector right) { + checkArgument( + left.getSize() == right.getSize(), + "Left and right operand have different vector sizes."); + DataType dataType = left.getDataType(); + + int numRows = left.getSize(); + int[] result = new int[numRows]; + if (dataType instanceof BooleanType) { + compareBooleanVectors(left, right, result); + } else if (dataType instanceof ByteType) { + compareByteVectors(left, right, result); + } else if (dataType instanceof ShortType) { + compareShortVectors(left, right, result); + } else if (dataType instanceof IntegerType || dataType instanceof DateType) { + compareIntVectors(left, right, result); + } else if (dataType instanceof LongType || dataType instanceof TimestampType) { + compareLongVectors(left, right, result); + } else if (dataType instanceof FloatType) { + compareFloatVectors(left, right, result); + } else if (dataType instanceof DoubleType) { + compareDoubleVectors(left, right, result); + } else if (dataType instanceof DecimalType) { + compareDecimalVectors(left, right, result); + } else if (dataType instanceof StringType) { + compareStringVectors(left, right, result); + } else if (dataType instanceof BinaryType) { + compareBinaryVectors(left, right, result); + } else { + throw new UnsupportedOperationException(dataType + " can not be compared."); + } + return result; + } + + static void compareBooleanVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)); + } + } + + static void compareByteVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId)); + } + } + + static void compareShortVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId)); + } + } + + static void compareIntVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId)); + } + } + + static void compareLongVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId)); + } + } + + static void compareFloatVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId)); + } + } + + static void compareDoubleVectors(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId)); + } + } + + static void compareStringVectors(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = Comparator.naturalOrder(); + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId)); + } + // When either side is null, it doesn't matter as the result is always null + } + } + + static void compareDecimalVectors(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = Comparator.naturalOrder(); + for (int rowId = 0; rowId < left.getSize(); rowId++) { + result[rowId] = comparator.compare(left.getDecimal(rowId), right.getDecimal(rowId)); + } + } + + static void compareBinaryVectors(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = (leftOp, rightOp) -> { + int i = 0; + while (i < leftOp.length && i < rightOp.length) { + if (leftOp[i] != rightOp[i]) { + return Byte.compare(leftOp[i], rightOp[i]); + } + i++; + } + return Integer.compare(leftOp.length, rightOp.length); + }; + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId)); + } + // When either side is null, it doesn't matter as the result is always null + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java new file mode 100644 index 00000000000..433ea2b2461 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -0,0 +1,73 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import io.delta.kernel.expressions.Column; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Literal; +import io.delta.kernel.expressions.predicates.AlwaysFalse; +import io.delta.kernel.expressions.predicates.AlwaysTrue; +import io.delta.kernel.expressions.predicates.And; +import io.delta.kernel.expressions.predicates.Or; +import io.delta.kernel.expressions.predicates.Predicate; + +/** + * Interface to allow visiting an expression tree and implementing handling for each + * specific expression type. + * + * @param Once an expression is process what type of value is returned/ + */ +abstract class ExpressionVisitor { + + abstract R visitAnd(And and); + + abstract R visitOr(Or or); + + abstract R visitAlwaysTrue(AlwaysTrue alwaysTrue); + + abstract R visitAlwaysFalse(AlwaysFalse alwaysFalse); + + abstract R visitPredicate(Predicate predicate); + + abstract R visitLiteral(Literal literal); + + abstract R visitColumn(Column column); + + abstract R visitCast(ImplicitCastExpression cast); + + final R visit(Expression expression) { + if (expression instanceof And) { + return visitAnd((And) expression); + } else if (expression instanceof Or) { + return visitOr((Or) expression); + } else if (expression instanceof AlwaysTrue) { + return visitAlwaysTrue((AlwaysTrue) expression); + } else if (expression instanceof AlwaysFalse) { + return visitAlwaysFalse((AlwaysFalse) expression); + } else if (expression instanceof Predicate) { + return visitPredicate((Predicate) expression); + } else if (expression instanceof Literal) { + return visitLiteral((Literal) expression); + } else if (expression instanceof Column) { + return visitColumn((Column) expression); + } else if (expression instanceof ImplicitCastExpression) { + return visitCast((ImplicitCastExpression) expression); + } + + throw new UnsupportedOperationException( + String.format("Expression %s is not supported.", expression)); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java new file mode 100644 index 00000000000..0d56d21a7fe --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java @@ -0,0 +1,274 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.*; +import static java.lang.String.format; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.types.DataType; + +import io.delta.kernel.defaults.client.DefaultExpressionHandler; + +/** + * An implicit cast expression to convert the input type to another given type. Here is the valid + * list of casts + *

+ *

    + *
  • {@code byte} to {@code short, int, long, float, double}
  • + *
  • {@code short} to {@code int, long, float, double}
  • + *
  • {@code int} to {@code long, float, double}
  • + *
  • {@code long} to {@code float, double}
  • + *
  • {@code float} to {@code double}
  • + *
+ * + *

+ * The above list is not exhaustive. Based on the need, we can add more casts. + *

+ * In {@link DefaultExpressionHandler} this is used when the operands of an expression are not of + * the same type, but the evaluator expects same type inputs. There could be more use cases, but + * for now this is the only use case. + */ +class ImplicitCastExpression implements Expression { + private final Expression input; + private final DataType outputType; + + ImplicitCastExpression(Expression input, DataType outputType) { + this.input = requireNonNull(input, "input is null"); + this.outputType = requireNonNull(outputType, "outputType is null"); + } + + public Expression getInput() { + return input; + } + + public DataType getOutputType() { + return outputType; + } + + @Override + public List getChildren() { + return Collections.singletonList(input); + } + + /** + * Map containing for each type what are the target cast types can be. + */ + private static final Map> UP_CASTABLE_TYPE_TABLE = unmodifiableMap( + new HashMap>() { + { + this.put("byte", Arrays.asList("short", "integer", "long", "float", "double")); + this.put("short", Arrays.asList("integer", "long", "float", "double")); + this.put("integer", Arrays.asList("long", "float", "double")); + this.put("long", Arrays.asList("float", "double")); + this.put("float", Arrays.asList("double")); + } + }); + + /** + * Utility method which returns whether the given {@code from} type can be cast to {@code to} + * type. + */ + static boolean canCastTo(DataType from, DataType to) { + // TODO: The type name should be a first class method on `DataType` instead of getting it + // using the `toString`. + String fromStr = from.toString(); + String toStr = to.toString(); + return UP_CASTABLE_TYPE_TABLE.containsKey(fromStr) && + UP_CASTABLE_TYPE_TABLE.get(fromStr).contains(toStr); + } + + /** + * Utility method to create a cast around the given input expression to specified output data + * type. It is the responsibility of the caller to validate the input expression can be cast + * to the new type using {@link #canCastTo(DataType, DataType)} + * + * @return + */ + static ImplicitCastExpression createCastExpression(Expression inputExpression, + DataType outputType) { + return new ImplicitCastExpression(inputExpression, outputType); + } + + /** + * Utility method to evaluate the given column expression on the input {@link ColumnVector}. + * + * @param cast Cast expression. + * @param input {@link ColumnVector} data of the input to the cast expression. + * @return {@link ColumnVector} result applying target type casting on every element in the + * input {@link ColumnVector}. + */ + static ColumnVector evalCastExpression(ImplicitCastExpression cast, ColumnVector input) { + String fromTypeStr = input.getDataType().toString(); + DataType toType = cast.getOutputType(); + switch (fromTypeStr) { + case "byte": + return new ByteUpConverter(toType, input); + case "short": + return new ShortUpConverter(toType, input); + case "integer": + return new IntUpConverter(toType, input); + case "long": + return new LongUpConverter(toType, input); + case "float": + return new FloatUpConverter(toType, input); + default: + throw new UnsupportedOperationException( + format("Cast from %s is not supported", fromTypeStr)); + } + } + + /** + * Base class for up casting {@link ColumnVector} data. + */ + private abstract static class UpConverter implements ColumnVector { + protected final DataType targetType; + protected final ColumnVector inputVector; + + UpConverter(DataType targetType, ColumnVector inputVector) { + this.targetType = targetType; + this.inputVector = inputVector; + } + + @Override + public DataType getDataType() { + return targetType; + } + + @Override + public boolean isNullAt(int rowId) { + return inputVector.isNullAt(rowId); + } + + @Override + public int getSize() { + return inputVector.getSize(); + } + + @Override + public void close() { + inputVector.close(); + } + } + + private static class ByteUpConverter extends UpConverter { + ByteUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public short getShort(int rowId) { + return getByte(rowId); + } + + @Override + public int getInt(int rowId) { + return getByte(rowId); + } + + @Override + public long getLong(int rowId) { + return getByte(rowId); + } + + @Override + public float getFloat(int rowId) { + return getByte(rowId); + } + + @Override + public double getDouble(int rowId) { + return getByte(rowId); + } + } + + private static class ShortUpConverter extends UpConverter { + ShortUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public int getInt(int rowId) { + return getShort(rowId); + } + + @Override + public long getLong(int rowId) { + return getShort(rowId); + } + + @Override + public float getFloat(int rowId) { + return getShort(rowId); + } + + @Override + public double getDouble(int rowId) { + return getShort(rowId); + } + } + + private static class IntUpConverter extends UpConverter { + IntUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public long getLong(int rowId) { + return getInt(rowId); + } + + @Override + public float getFloat(int rowId) { + return getInt(rowId); + } + + @Override + public double getDouble(int rowId) { + return getInt(rowId); + } + } + + private static class LongUpConverter extends UpConverter { + LongUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public float getFloat(int rowId) { + return getLong(rowId); + } + + @Override + public double getDouble(int rowId) { + return getLong(rowId); + } + } + + private static class FloatUpConverter extends UpConverter { + FloatUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public double getDouble(int rowId) { + return getFloat(rowId); + } + } +} diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java index 6a10c1cb9b9..db5cc4fbe8c 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java @@ -28,17 +28,17 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.expressions.And; import io.delta.kernel.expressions.Column; -import io.delta.kernel.expressions.EqualTo; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Literal; +import io.delta.kernel.expressions.predicates.And; +import io.delta.kernel.expressions.predicates.Predicate; import io.delta.kernel.types.*; +import static io.delta.kernel.utils.Utils.daysSinceEpoch; import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch; import io.delta.kernel.defaults.internal.data.vector.DefaultIntVector; import io.delta.kernel.defaults.internal.data.vector.DefaultLongVector; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; public class TestDefaultExpressionHandler { /** @@ -69,10 +69,11 @@ public void evalLiterals() { testCases.add(Literal.ofNull(StringType.INSTANCE)); testCases.add(Literal.of("binary_val".getBytes())); testCases.add(Literal.ofNull(BinaryType.INSTANCE)); - testCases.add(Literal.of(new Date(234234234))); + testCases.add(Literal.ofDate(daysSinceEpoch(new Date(234234234)))); testCases.add(Literal.ofNull(DateType.INSTANCE)); - // testCases.add(Literal.of(new Timestamp(2342342342232L))); - // testCases.add(Literal.ofNull(TimestampType.INSTANCE)); + + testCases.add(Literal.ofTimestamp(2342342342232L)); + testCases.add(Literal.ofNull(TimestampType.INSTANCE)); ColumnarBatch[] inputBatches = new ColumnarBatch[] { new DefaultColumnarBatch(0, inputSchema, data), @@ -81,18 +82,19 @@ public void evalLiterals() { }; for (Literal expression : testCases) { - DataType outputDataType = expression.dataType(); + DataType outputDataType = expression.getDataType(); for (ColumnarBatch inputBatch : inputBatches) { - ColumnVector outputVector = eval(inputSchema, inputBatch, expression); + ColumnVector outputVector = + eval(inputSchema, inputBatch, expression, outputDataType); assertEquals(inputBatch.getSize(), outputVector.getSize()); assertEquals(outputDataType, outputVector.getDataType()); for (int rowId = 0; rowId < outputVector.getSize(); rowId++) { - if (expression.value() == null) { + if (expression.getValue() == null) { assertTrue(outputVector.isNullAt(rowId)); continue; } - Object expRowValue = expression.value(); + Object expRowValue = expression.getValue(); if (outputDataType instanceof BooleanType) { assertEquals(expRowValue, outputVector.getBoolean(rowId)); } else if (outputDataType instanceof ByteType) { @@ -112,8 +114,10 @@ public void evalLiterals() { } else if (outputDataType instanceof BinaryType) { assertEquals(expRowValue, outputVector.getBinary(rowId)); } else if (outputDataType instanceof DateType) { - assertEquals( - daysSinceEpoch((Date) expRowValue), outputVector.getInt(rowId)); + assertEquals(expRowValue, outputVector.getInt(rowId)); + } else if (outputDataType instanceof TimestampType) { + long micros = (Long) expRowValue; + assertEquals(micros, outputVector.getLong(rowId)); } else { throw new UnsupportedOperationException( "unsupported output type encountered: " + outputDataType); @@ -125,20 +129,18 @@ public void evalLiterals() { @Test public void evalBooleanExpressionSimple() { - Expression expression = new EqualTo( - new Column(0, "intType", IntegerType.INSTANCE), - Literal.of(3)); + Expression expression = new Predicate( + "=", + Arrays.asList(new Column("intType"), Literal.of(3))); for (int size : Arrays.asList(26, 234, 567)) { StructType inputSchema = new StructType() .add("intType", IntegerType.INSTANCE); - ColumnVector[] data = new ColumnVector[] { - intVector(size) - }; + ColumnVector[] data = new ColumnVector[] {intVector(size)}; ColumnarBatch inputBatch = new DefaultColumnarBatch(size, inputSchema, data); - ColumnVector output = eval(inputSchema, inputBatch, expression); + ColumnVector output = eval(inputSchema, inputBatch, expression, BooleanType.INSTANCE); for (int rowId = 0; rowId < size; rowId++) { if (data[0].isNullAt(rowId)) { // expect the output to be null as well @@ -155,8 +157,12 @@ public void evalBooleanExpressionSimple() { @Test public void evalBooleanExpressionComplex() { Expression expression = new And( - new EqualTo(new Column(0, "intType", IntegerType.INSTANCE), Literal.of(3)), - new EqualTo(new Column(1, "longType", LongType.INSTANCE), Literal.of(4L)) + new Predicate( + "=", + Arrays.asList(new Column("intType"), Literal.of(3))), + new Predicate( + "=", + Arrays.asList(new Column("longType"), Literal.of(4L))) ); for (int size : Arrays.asList(26, 234, 567)) { @@ -170,7 +176,7 @@ public void evalBooleanExpressionComplex() { ColumnarBatch inputBatch = new DefaultColumnarBatch(size, inputSchema, data); - ColumnVector output = eval(inputSchema, inputBatch, expression); + ColumnVector output = eval(inputSchema, inputBatch, expression, BooleanType.INSTANCE); for (int rowId = 0; rowId < size; rowId++) { if (data[0].isNullAt(rowId) || data[1].isNullAt(rowId)) { // expect the output to be null as well @@ -185,9 +191,9 @@ public void evalBooleanExpressionComplex() { } private static ColumnVector eval( - StructType inputSchema, ColumnarBatch input, Expression expression) { + StructType inputSchema, ColumnarBatch input, Expression expression, DataType outputType) { return new DefaultExpressionHandler() - .getEvaluator(inputSchema, expression) + .getEvaluator(inputSchema, expression, outputType) .eval(input); } diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java index 8b8591f96d1..1cef07ce3e8 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java @@ -30,12 +30,12 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.types.*; +import static io.delta.kernel.utils.Utils.daysSinceEpoch; import io.delta.kernel.defaults.client.DefaultTableClient; import io.delta.kernel.defaults.integration.DataBuilderUtils.TestColumnBatchBuilder; import static io.delta.kernel.defaults.integration.DataBuilderUtils.row; import static io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; /** * Test reading Delta lake tables end to end using the Kernel APIs and default {@link TableClient} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala new file mode 100644 index 00000000000..a06c10ac135 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -0,0 +1,20 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions + +class DefaultExpressionEvaluatorSuite { + +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala new file mode 100644 index 00000000000..c4d3d180cf4 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala @@ -0,0 +1,60 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions + +import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo +import io.delta.kernel.types._ +import org.scalatest.funsuite.AnyFunSuite + +class ImplicitCastExpressionSuite extends AnyFunSuite { + test("can cast to") { + val allowedCasts: Set[(DataType, DataType)] = Set( + (ByteType.INSTANCE, ShortType.INSTANCE), + (ByteType.INSTANCE, IntegerType.INSTANCE), + (ByteType.INSTANCE, LongType.INSTANCE), + (ByteType.INSTANCE, FloatType.INSTANCE), + (ByteType.INSTANCE, DoubleType.INSTANCE), + + (ShortType.INSTANCE, IntegerType.INSTANCE), + (ShortType.INSTANCE, LongType.INSTANCE), + (ShortType.INSTANCE, FloatType.INSTANCE), + (ShortType.INSTANCE, DoubleType.INSTANCE), + + (IntegerType.INSTANCE, LongType.INSTANCE), + (IntegerType.INSTANCE, FloatType.INSTANCE), + (IntegerType.INSTANCE, DoubleType.INSTANCE), + + (LongType.INSTANCE, FloatType.INSTANCE), + (LongType.INSTANCE, DoubleType.INSTANCE), + (FloatType.INSTANCE, DoubleType.INSTANCE)) + + val types = Seq(ByteType.INSTANCE, ShortType.INSTANCE, IntegerType.INSTANCE, + LongType.INSTANCE, FloatType.INSTANCE, DoubleType.INSTANCE, DateType.INSTANCE, + TimestampType.INSTANCE, BooleanType.INSTANCE, StringType.INSTANCE, BinaryType.INSTANCE, + new DecimalType(10, 5), new ArrayType(BooleanType.INSTANCE, true), + new MapType(IntegerType.INSTANCE, LongType.INSTANCE, true) + ) + + Seq.range(0, types.length).foreach { fromTypeIdx => + val fromType: DataType = types(fromTypeIdx) + Seq.range(0, types.length).foreach { toTypeIdx => + val toType: DataType = types(toTypeIdx) + assert(canCastTo(fromType, toType) === + allowedCasts.contains((fromType, toType))) + } + } + } +}