diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java index c82877f2c96..ce6afa2a12b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java @@ -26,18 +26,14 @@ import io.delta.kernel.client.FileReadContext; import io.delta.kernel.client.ParquetHandler; import io.delta.kernel.client.TableClient; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.DataReadResult; -import io.delta.kernel.data.FileDataReadResult; -import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Expression; -import io.delta.kernel.expressions.Literal; +import io.delta.kernel.data.*; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Tuple2; import io.delta.kernel.utils.Utils; +import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; import io.delta.kernel.internal.actions.DeletionVectorDescriptor; import io.delta.kernel.internal.data.AddFileColumnarBatch; @@ -67,9 +63,9 @@ public interface Scan { * Get the remaining filter that is not guaranteed to be satisfied for the data Delta Kernel * returns. This filter is used by Delta Kernel to do data skipping when possible. * - * @return the remaining filter as an {@link Expression}. + * @return the remaining filter as a {@link Predicate}. */ - Optional getRemainingFilter(); + Optional getRemainingFilter(); /** * Get the scan state associated with the current scan. This state is common across all @@ -88,9 +84,8 @@ public interface Scan { * @param scanFileRowIter an iterator of {@link Row}s. 
Each {@link Row} represents one scan file * from the {@link ColumnarBatch} returned by * {@link Scan#getScanFiles(TableClient)} - * @param filter An optional filter that can be used for data skipping while reading - * the - * scan files. + * @param predicate An optional predicate that can be used for data skipping while reading + * the scan files. * @return Data read from the input scan files as an iterator of {@link DataReadResult}s. Each * {@link DataReadResult} instance contains the data read and an optional selection * vector that indicates data rows as valid or invalid. It is the responsibility of the @@ -101,7 +96,7 @@ static CloseableIterator readData( TableClient tableClient, Row scanState, CloseableIterator scanFileRowIter, - Optional filter) throws IOException { + Optional predicate) throws IOException { StructType physicalSchema = Utils.getPhysicalSchema(tableClient, scanState); StructType logicalSchema = Utils.getLogicalSchema(tableClient, scanState); List partitionColumns = Utils.getPartitionColumns(scanState); @@ -122,7 +117,7 @@ static CloseableIterator readData( CloseableIterator filesReadContextsIter = parquetHandler.contextualizeFileReads( scanFileRowIter, - filter.orElse(Literal.TRUE)); + predicate.orElse(ALWAYS_TRUE)); CloseableIterator data = parquetHandler.readParquetFiles( filesReadContextsIter, diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java b/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java index 55b1923a073..deff60fdd16 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/ScanBuilder.java @@ -18,7 +18,7 @@ import io.delta.kernel.annotation.Evolving; import io.delta.kernel.client.TableClient; -import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.StructType; /** @@ -34,10 +34,10 @@ public interface ScanBuilder { * the given filter. 
* * @param tableClient {@link TableClient} instance to use in Delta Kernel. - * @param filter an {@link Expression} which evaluates to boolean. + * @param predicate a {@link Predicate} to prune the metadata or data. * @return A {@link ScanBuilder} with filter applied. */ - ScanBuilder withFilter(TableClient tableClient, Expression filter); + ScanBuilder withFilter(TableClient tableClient, Predicate predicate); /** * Apply the given readSchema. If the builder already has a projection applied, calling diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java b/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java index 2c8b3de61de..d5d130012a3 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/client/ExpressionHandler.java @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.client; import io.delta.kernel.annotation.Evolving; +import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.ExpressionEvaluator; +import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructType; /** @@ -32,12 +33,16 @@ public interface ExpressionHandler { /** * Create an {@link ExpressionEvaluator} that can evaluate the given expression on - * {@link io.delta.kernel.data.ColumnarBatch}s with the given batchSchema. + * {@link ColumnarBatch}s with the given batchSchema. The expression is + * expected to be a scalar expression where for each one input row there + * is a one output value. * - * @param batchSchema Schema of the input data. + * @param inputSchema Input data schema * @param expression Expression to evaluate. - * @return An {@link ExpressionEvaluator} instance bound to the given expression and - * batchSchema. + * @param outputType Expected result data type. 
*/ - ExpressionEvaluator getEvaluator(StructType batchSchema, Expression expression); + ExpressionEvaluator getEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/client/FileHandler.java b/kernel/kernel-api/src/main/java/io/delta/kernel/client/FileHandler.java index 73fbca53656..71cd92892a7 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/client/FileHandler.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/client/FileHandler.java @@ -18,7 +18,7 @@ import io.delta.kernel.annotation.Evolving; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.fs.FileStatus; import io.delta.kernel.utils.CloseableIterator; @@ -49,5 +49,5 @@ public interface FileHandler { */ CloseableIterator contextualizeFileReads( CloseableIterator fileIter, - Expression predicate); + Predicate predicate); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysFalse.java similarity index 62% rename from kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysFalse.java index 8f2076c56e5..64cf1f1f5cf 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/EqualTo.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysFalse.java @@ -13,21 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; +import java.util.Collections; + +import io.delta.kernel.annotation.Evolving; + /** - * Evaluates {@code expr1} = {@code expr2} for {@code new EqualTo(expr1, expr2)}. + * Predicate which always evaluates to {@code false}. 
+ * + * @since 3.0.0 */ -public final class EqualTo extends BinaryComparison implements Predicate { +@Evolving +public final class AlwaysFalse extends Predicate { + public static final AlwaysFalse ALWAYS_FALSE = new AlwaysFalse(); - public EqualTo(Expression left, Expression right) { - super(left, right, "="); - } - - @Override - protected Object nullSafeEval(Object leftResult, Object rightResult) { - return compare(leftResult, rightResult) == 0; + private AlwaysFalse() { + super("ALWAYS_FALSE", Collections.emptyList()); } } - diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysTrue.java similarity index 60% rename from kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysTrue.java index 05f48b5c5a6..031cac08430 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/LeafExpression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/AlwaysTrue.java @@ -13,31 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; import java.util.Collections; -import java.util.List; -import java.util.Set; + +import io.delta.kernel.annotation.Evolving; /** - * An {@link Expression} with no children. + * Predicate which always evaluates to {@code true}. 
+ * + * @since 3.0.0 */ -public abstract class LeafExpression implements Expression { - - protected LeafExpression() {} +@Evolving +public final class AlwaysTrue extends Predicate { + public static final AlwaysTrue ALWAYS_TRUE = new AlwaysTrue(); - @Override - public List children() { - return Collections.emptyList(); + private AlwaysTrue() { + super("ALWAYS_TRUE", Collections.emptyList()); } - - @Override - public Set references() { - return Collections.emptySet(); - } - - public abstract boolean equals(Object o); - - public abstract int hashCode(); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java index 4ad3328c933..9ba046466a9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java @@ -13,50 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; -import java.util.Collection; +import java.util.Arrays; -import io.delta.kernel.types.BooleanType; +import io.delta.kernel.annotation.Evolving; /** - * Evaluates logical {@code expr1} AND {@code expr2} for {@code new And(expr1, expr2)}. + * {@code AND} expression + *

+ * Definition: *

- * Requires both left and right input expressions evaluate to booleans. + *

    + *
  • Logical {@code expr1} AND {@code expr2} on two inputs.
  • + *
  • Requires both left and right input expressions to be of type {@link Predicate}.
  • + *
  • Result is null if at least one of the inputs is null.
  • + *
+ * + * @since 3.0.0 */ -public final class And extends BinaryOperator implements Predicate { - - public static And apply(Collection conjunctions) { - if (conjunctions.size() == 0) { - throw new IllegalArgumentException("And.apply must be called with at least 1 element"); - } - - return (And) conjunctions - .stream() - // we start off with And(true, true) - // then we get the 1st expression: And(And(true, true), expr1) - // then we get the 2nd expression: And(And(true, true), expr1), expr2) etc. - .reduce(new And(Literal.TRUE, Literal.TRUE), And::new); +@Evolving +public final class And extends Predicate { + public And(Predicate left, Predicate right) { + super("AND", Arrays.asList(left, right)); } - public And(Expression left, Expression right) { - super(left, right, "&&"); - if (!(left.dataType() instanceof BooleanType) || - !(right.dataType() instanceof BooleanType)) { + /** + * @return Left side operand. + */ + public Predicate getLeft() { + return (Predicate) getChildren().get(0); + } - throw new IllegalArgumentException( - String.format( - "'And' requires expressions of type boolean. Got %s and %s.", - left.dataType(), - right.dataType() - ) - ); - } + /** + * @return Right side operand. + */ + public Predicate getRight() { + return (Predicate) getChildren().get(1); } @Override - public Object nullSafeEval(Object leftResult, Object rightResult) { - return (boolean) leftResult && (boolean) rightResult; + public String toString() { + return "(" + getLeft() + " AND " + getRight() + ")"; } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java deleted file mode 100644 index 54e9648d85b..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryComparison.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -import java.util.Comparator; - -import io.delta.kernel.internal.expressions.CastingComparator; - -/** - * A {@link BinaryOperator} that compares the left and right {@link Expression}s and evaluates to a - * boolean value. - */ -public abstract class BinaryComparison extends BinaryOperator implements Predicate { - private final Comparator comparator; - - protected BinaryComparison(Expression left, Expression right, String symbol) { - super(left, right, symbol); - - // super asserted that left and right DataTypes were the same - - comparator = CastingComparator.forDataType(left.dataType()); - } - - protected int compare(Object leftResult, Object rightResult) { - return comparator.compare(leftResult, rightResult); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java deleted file mode 100644 index 82a67b328e1..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryExpression.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -import java.util.Arrays; -import java.util.List; -import java.util.Objects; - -import io.delta.kernel.data.Row; - -/** - * An {@link Expression} with two inputs and one output. The output is by default evaluated to null - * if either input is evaluated to null. - */ -public abstract class BinaryExpression implements Expression { - protected final Expression left; - protected final Expression right; - - protected BinaryExpression(Expression left, Expression right) { - this.left = left; - this.right = right; - } - - public Expression getLeft() { - return left; - } - - public Expression getRight() { - return right; - } - - @Override - public final Object eval(Row row) { - Object leftResult = left.eval(row); - if (null == leftResult) return null; - - Object rightResult = right.eval(row); - if (null == rightResult) return null; - - return nullSafeEval(leftResult, rightResult); - } - - protected abstract Object nullSafeEval(Object leftResult, Object rightResult); - - @Override - public List children() { - return Arrays.asList(left, right); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BinaryExpression that = (BinaryExpression) o; - return Objects.equals(left, that.left) && - Objects.equals(right, that.right); - } - - @Override - public int hashCode() { - return Objects.hash(left, right); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java 
b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java deleted file mode 100644 index 4c30e2c79bb..00000000000 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/BinaryOperator.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.delta.kernel.expressions; - -/** - * A {@link BinaryExpression} that is an operator, meaning the string representation is - * {@code x symbol y}, rather than {@code funcName(x, y)}. - *

- * Requires both inputs to be of the same data type. - */ -public abstract class BinaryOperator extends BinaryExpression { - protected final String symbol; - - protected BinaryOperator(Expression left, Expression right, String symbol) { - super(left, right); - this.symbol = symbol; - - if (!left.dataType().equals(right.dataType())) { - throw new IllegalArgumentException( - String.format( - "BinaryOperator left and right DataTypes must be the same. Found %s and %s.", - left.dataType(), - right.dataType() - )); - } - } - - @Override - public String toString() { - return String.format("(%s %s %s)", left.toString(), symbol, right.toString()); - } -} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java index 5d7fc02188e..3ae2fb1109a 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Column.java @@ -13,101 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; import java.util.Collections; -import java.util.Objects; -import java.util.Set; +import java.util.List; -import io.delta.kernel.data.Row; -import io.delta.kernel.types.BooleanType; -import io.delta.kernel.types.DataType; -import io.delta.kernel.types.IntegerType; -import io.delta.kernel.types.LongType; -import io.delta.kernel.types.StringType; -import io.delta.kernel.types.StructType; +import io.delta.kernel.annotation.Evolving; /** - * A column whose row-value will be computed based on the data in a {@link Row}. - *

- * It is recommended that you instantiate using an existing table schema - * {@link StructType} with {@link StructType#column(int)}. - *

- * Only supports primitive data types, see - * Delta Transaction Log Protocol: Primitive Types. + * An expression type that refers to a column by name (case-sensitive) in the input. + * + * @since 3.0.0 */ -public final class Column extends LeafExpression { - private final int ordinal; +@Evolving +public final class Column implements Expression { private final String name; - private final DataType dataType; - private final RowEvaluator evaluator; - public Column(int ordinal, String name, DataType dataType) { - this.ordinal = ordinal; + public Column(String name) { this.name = name; - this.dataType = dataType; - - if (dataType instanceof IntegerType) { - evaluator = (row -> row.getInt(ordinal)); - } else if (dataType instanceof BooleanType) { - evaluator = (row -> row.getBoolean(ordinal)); - } else if (dataType instanceof LongType) { - evaluator = (row -> row.getLong(ordinal)); - } else if (dataType instanceof StringType) { - evaluator = (row -> row.getString(ordinal)); - } else { - throw new UnsupportedOperationException( - String.format( - "The data type %s of column %s at ordinal %s is not supported", - dataType, - name, - ordinal) - ); - } } - public String name() { + /** + * @return the column name. + */ + public String getName() { return name; } @Override - public Object eval(Row row) { - return row.isNullAt(ordinal) ? 
null : evaluator.nullSafeEval(row); - } - - @Override - public DataType dataType() { - return dataType; + public List getChildren() { + return Collections.emptyList(); } @Override public String toString() { - return "Column(" + name + ")"; - } - - @Override - public Set references() { - return Collections.singleton(name); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Column column = (Column) o; - return Objects.equals(ordinal, column.ordinal) && - Objects.equals(name, column.name) && - Objects.equals(dataType, column.dataType); - } - - @Override - public int hashCode() { - return Objects.hash(name, dataType); - } - - @FunctionalInterface - private interface RowEvaluator { - Object nullSafeEval(Row row); + return "column(" + name + ")"; } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java index af4bc592ecc..761203ff7e7 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Expression.java @@ -13,48 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; -import java.util.HashSet; import java.util.List; -import java.util.Set; -import io.delta.kernel.data.Row; -import io.delta.kernel.types.DataType; +import io.delta.kernel.annotation.Evolving; /** - * Generic interface for all Expressions + * Base interface for all Kernel expressions. + * + * @since 3.0.0 */ +@Evolving public interface Expression { - - /** - * @param row the input row to evaluate. - * @return the result of evaluating this expression on the given input {@link Row}. - */ - Object eval(Row row); - - /** - * @return the {@link DataType} of the result of evaluating this expression. 
- */ - DataType dataType(); - - /** - * @return the String representation of this expression. - */ - String toString(); - - /** - * @return a {@link List} of the immediate children of this node - */ - List children(); - /** - * @return the names of columns referenced by this expression. + * @return a list of expressions that are input to this expression. */ - default Set references() { - Set result = new HashSet<>(); - children().forEach(child -> result.addAll(child.references())); - return result; - } + List getChildren(); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java index 2f653e02a02..109aea65dc0 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Literal.java @@ -13,139 +13,170 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.Objects; +import java.util.Collections; +import java.util.List; -import io.delta.kernel.data.Row; +import io.delta.kernel.annotation.Evolving; import io.delta.kernel.types.*; + import static io.delta.kernel.internal.util.InternalUtils.checkArgument; /** * A literal value. *

- * Only supports primitive data types, see - * Delta Transaction Log Protocol: Primitive Types. + * Definition: + *

    + *
  • Represents literal of primitive types as defined in the protocol + * + * Delta Transaction Log Protocol: Primitive Types + *
  • + *
  • Use {@link #getValue()} to fetch the literal value. Returned value type + * depends on the type of the literal data type. See the {@link #getValue()} for further + * details.
  • + *
+ * + * @since 3.0.0 */ -public final class Literal extends LeafExpression { - public static final Literal TRUE = Literal.of(true); - public static final Literal FALSE = Literal.of(false); - +@Evolving +public final class Literal implements Expression { /** - * Create a boolean {@link Literal} object + * Create a {@code boolean} type literal expression. * - * @param value boolean value - * @return a {@link Literal} with data type {@link BooleanType} + * @param value literal value + * @return a {@link Literal} of type {@link BooleanType} */ - public static Literal of(boolean value) { + public static Literal ofBoolean(boolean value) { return new Literal(value, BooleanType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link ByteType} + * Create a {@code byte} type literal expression. + * + * @param value literal value + * @return a {@link Literal} of type {@link ByteType} */ - public static Literal of(byte value) { + public static Literal ofByte(byte value) { return new Literal(value, ByteType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link ShortType} + * Create a {@code short} type literal expression. + * + * @param value literal value + * @return a {@link Literal} of type {@link ShortType} */ - public static Literal of(short value) { + public static Literal ofShort(short value) { return new Literal(value, ShortType.INSTANCE); } /** - * Create an integer {@link Literal} object + * Create a {@code integer} type literal expression. * - * @param value integer value - * @return a {@link Literal} with data type {@link IntegerType} + * @param value literal value + * @return a {@link Literal} of type {@link IntegerType} */ - public static Literal of(int value) { + public static Literal ofInt(int value) { return new Literal(value, IntegerType.INSTANCE); } /** - * Create a long {@link Literal} object + * Create a {@code long} type literal expression. 
* - * @param value long value - * @return a {@link Literal} with data type {@link LongType} + * @param value literal value + * @return a {@link Literal} of type {@link LongType} */ - public static Literal of(long value) { + public static Literal ofLong(long value) { return new Literal(value, LongType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link FloatType} + * Create a {@code float} type literal expression. + * + * @param value literal value + * @return a {@link Literal} of type {@link FloatType} */ - public static Literal of(float value) { + public static Literal ofFloat(float value) { return new Literal(value, FloatType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link DoubleType} + * Create a {@code double} type literal expression. + * + * @param value literal value + * @return a {@link Literal} of type {@link DoubleType} */ - public static Literal of(double value) { + public static Literal ofDouble(double value) { return new Literal(value, DoubleType.INSTANCE); } /** - * Create a string {@link Literal} object + * Create a {@code string} type literal expression. * - * @param value string value - * @return a {@link Literal} with data type {@link StringType} + * @param value literal value + * @return a {@link Literal} of type {@link StringType} */ - public static Literal of(String value) { + public static Literal ofString(String value) { return new Literal(value, StringType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link BinaryType} + * Create a {@code binary} type literal expression. + * + * @param value binary literal value as an array of bytes + * @return a {@link Literal} of type {@link BinaryType} */ - public static Literal of(byte[] value) { + public static Literal ofBinary(byte[] value) { return new Literal(value, BinaryType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link DateType} + * Create a {@code date} type literal expression. 
+ * + * @param daysSinceEpochUTC number of days since the epoch in UTC timezone. + * @return a {@link Literal} of type {@link DateType} */ - public static Literal of(Date value) { - return new Literal(value, DateType.INSTANCE); + public static Literal ofDate(int daysSinceEpochUTC) { + return new Literal(daysSinceEpochUTC, DateType.INSTANCE); } /** + * Create a {@code timestamp} type literal expression. + * + * @param microsSinceEpochUTC microseconds since epoch time in UTC timezone. * @return a {@link Literal} with data type {@link TimestampType} */ - public static Literal of(Timestamp value) { - return new Literal(value, TimestampType.INSTANCE); + public static Literal ofTimestamp(long microsSinceEpochUTC) { + return new Literal(microsSinceEpochUTC, TimestampType.INSTANCE); } /** - * @return a {@link Literal} with data type {@link DecimalType} + * Create a {@code decimal} type literal expression. + * + * @param value decimal literal value + * @param precision precision of the decimal literal + * @param scale scale of the decimal literal + * @return a {@link Literal} with data type {@link DecimalType} with given {@code precision} + * and {@code scale}. */ - public static Literal of(BigDecimal value, int precision, int scale) { + public static Literal ofDecimal(BigDecimal value, int precision, int scale) { // throws an error if rounding is required to set the specified scale BigDecimal valueToStore = value.setScale(scale); checkArgument(valueToStore.precision() <= precision, String.format( - "Decimal precision=%s for decimal %s exceeds max precision %s", - valueToStore.precision(), valueToStore, precision)); + "Decimal precision=%s for decimal %s exceeds max precision %s", + valueToStore.precision(), valueToStore, precision)); return new Literal(valueToStore, new DecimalType(precision, scale)); } /** + * Create {@code null} value literal. + * + * @param dataType {@link DataType} of the null literal. 
* @return a null {@link Literal} with the given data type */ public static Literal ofNull(DataType dataType) { - if (dataType instanceof ArrayType - || dataType instanceof MapType - || dataType instanceof StructType) { - throw new IllegalArgumentException( - dataType + " is an invalid data type for Literal."); - } return new Literal(null, dataType); } @@ -153,21 +184,46 @@ public static Literal ofNull(DataType dataType) { private final DataType dataType; private Literal(Object value, DataType dataType) { + if (dataType instanceof ArrayType + || dataType instanceof MapType + || dataType instanceof StructType) { + throw new IllegalArgumentException(dataType + " is an invalid data type for Literal."); + } this.value = value; this.dataType = dataType; } - public Object value() { - return value; - } - - @Override - public Object eval(Row record) { + /** + * Get the literal value. If the value is null a {@code null} is returned. For non-null + * literal the returned value is one of the following types based on the literal data type. + * + *
    + *
  • BOOLEAN: {@link Boolean}
  • + *
  • BYTE: {@link Byte}
  • + *
  • SHORT: {@link Short}
  • + *
  • INTEGER: {@link Integer}
  • + *
  • LONG: {@link Long}
  • + *
  • FLOAT: {@link Float}
  • + *
  • DOUBLE: {@link Double}
  • + *
  • DATE: {@link Integer} represents the number of days since epoch in UTC
  • + *
  • TIMESTAMP: {@link Long} represents the microseconds since epoch in UTC
  • + *
  • DECIMAL: {@link BigDecimal}. Use {@link #getDataType()} to find the precision and + * scale
  • + *
+ * + * @return Literal value. + */ + public Object getValue() { return value; } - @Override - public DataType dataType() { + /** + * Get the datatype of the literal object. Datatype lets the caller interpret the value of the + * literal object returned by {@link #getValue()} + * + * @return Datatype of the literal object. + */ + public DataType getDataType() { return dataType; } @@ -177,20 +233,7 @@ public String toString() { } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Literal literal = (Literal) o; - return Objects.equals(value, literal.value) && - Objects.equals(dataType, literal.dataType); - } - - @Override - public int hashCode() { - return Objects.hash(value, dataType); + public List getChildren() { + return Collections.emptyList(); } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java new file mode 100644 index 00000000000..e8e2fa37572 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java @@ -0,0 +1,58 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions; + +import java.util.Arrays; + +import io.delta.kernel.annotation.Evolving; + +/** + * {@code OR} expression + *

+ * Definition: + *

    + *
  • Logical {@code expr1} OR {@code expr2} on two inputs.
  • + *
  • Requires both left and right input expressions of type {@link Predicate}.
  • + *
  • Result is null if at least one of the inputs is null.
  • + *
+ * + * @since 3.0.0 + */ +@Evolving +public final class Or extends Predicate { + public Or(Predicate left, Predicate right) { + super("OR", Arrays.asList(left, right)); + } + + /** + * @return Left side operand. + */ + public Predicate getLeft() { + return (Predicate) getChildren().get(0); + } + + /** + * @return Right side operand. + */ + public Predicate getRight() { + return (Predicate) getChildren().get(1); + } + + @Override + public String toString() { + return "(" + getLeft() + " OR " + getRight() + ")"; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java index a26a01273f9..8f98005a9f5 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java @@ -13,18 +13,95 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.delta.kernel.expressions; -import io.delta.kernel.types.BooleanType; -import io.delta.kernel.types.DataType; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import io.delta.kernel.annotation.Evolving; +import io.delta.kernel.client.ExpressionHandler; /** - * An {@link Expression} that defines a relation on inputs. Evaluates to true, false, or null. + * Defines predicate scalar expression which is an extension of {@link ScalarExpression} + * that evaluates to true, false, or null for each input row. + *

+ * Currently, implementations of {@link ExpressionHandler} are required to support at least the + * following scalar expressions. + *

    + *
  1. Name: = + *
      + *
    • SQL semantic: expr1 = expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  2. + *
  3. Name: < + *
      + *
    • SQL semantic: expr1 < expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  4. + *
  5. Name: <= + *
      + *
    • SQL semantic: expr1 <= expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  6. + *
  7. Name: > + *
      + *
    • SQL semantic: expr1 > expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  8. + *
  9. Name: >= + *
      + *
    • SQL semantic: expr1 >= expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  10. + *
  11. Name: ALWAYS_TRUE + *
      + *
    • SQL semantic: Constant expression whose value is `true`
    • + *
    • Since version: 3.0.0
    • + *
    + *
  12. + *
  13. Name: ALWAYS_FALSE + *
      + *
    • SQL semantic: Constant expression whose value is `false`
    • + *
    • Since version: 3.0.0
    • + *
    + *
  14. + *
  15. Name: AND + *
      + *
    • SQL semantic: expr1 AND expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  16. + *
  17. Name: OR + *
      + *
    • SQL semantic: expr1 OR expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  18. + *
+ * + * @since 3.0.0 */ -public interface Predicate extends Expression { +@Evolving +public class Predicate extends ScalarExpression { + public Predicate(String name, List children) { + super(name, children); + } + @Override - default DataType dataType() { - return BooleanType.INSTANCE; + public String toString() { + if (COMPARATORS.contains(name)) { + return String.format("(%s %s %s)", children.get(0), name, children.get(1)); + } + return super.toString(); } + + private static final Set COMPARATORS = + Stream.of("<", "<=", ">", ">=", "=").collect(Collectors.toSet()); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java new file mode 100644 index 00000000000..7b8dd76a1dd --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java @@ -0,0 +1,68 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.annotation.Evolving; + +/** + * Scalar SQL expressions which take zero or more inputs and for each input row generate one + * output value. 
A subclass of these expressions are of type {@link Predicate} whose result type is + * `boolean`. See {@link Predicate} for predicate type scalar expressions. Supported + * non-predicate type scalar expressions are listed below. + * TODO: Currently there aren't any. Will be added in future. An example one looks like this: + *
    + *
  1. Name: + + *
      + *
    • SQL semantic: expr1 + expr2
    • + *
    • Since version: 3.0.0
    • + *
    + *
  2. + *
+ * + * @since 3.0.0 + */ +@Evolving +public class ScalarExpression implements Expression { + protected final String name; + protected final List children; + + public ScalarExpression(String name, List children) { + this.name = requireNonNull(name, "name is null").toUpperCase(Locale.ENGLISH); + this.children = Collections.unmodifiableList(new ArrayList<>(children)); + } + + @Override + public String toString() { + return String.format("%s(%s)", name, + children.stream().map(Object::toString).collect(Collectors.joining(", "))); + } + + public String getName() { + return name; + } + + @Override + public List getChildren() { + return children; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java index c8e2e3521c9..aeb2a1574fe 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java @@ -22,7 +22,7 @@ import io.delta.kernel.Scan; import io.delta.kernel.ScanBuilder; import io.delta.kernel.client.TableClient; -import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.StructType; import io.delta.kernel.types.TimestampType; import io.delta.kernel.utils.CloseableIterator; @@ -47,7 +47,7 @@ public class ScanBuilderImpl private final Path dataPath; private StructType readSchema; - private Optional filter; + private Optional predicate; public ScanBuilderImpl( Path dataPath, @@ -62,15 +62,15 @@ public ScanBuilderImpl( this.tableClient = tableClient; this.readSchema = snapshotSchema; - this.filter = Optional.empty(); + this.predicate = Optional.empty(); } @Override - public ScanBuilder withFilter(TableClient tableClient, Expression filter) { - if (this.filter.isPresent()) { + public ScanBuilder withFilter(TableClient tableClient, Predicate predicate) { + if 
(this.predicate.isPresent()) { throw new IllegalArgumentException("There already exists a filter in current builder"); } - this.filter = Optional.of(filter); + this.predicate = Optional.of(predicate); return this; } @@ -104,7 +104,7 @@ public Scan build() { readSchema, protocolAndMetadata, filesIter, - filter, + predicate, dataPath); } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 6705b16e762..7fa829d7862 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -26,7 +26,7 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Tuple2; @@ -60,7 +60,7 @@ public class ScanImpl private final StructType readSchema; private final CloseableIterator filesIter; private final Lazy> protocolAndMetadata; - private final Optional filter; + private final Optional filter; private boolean accessedScanFiles; @@ -69,7 +69,7 @@ public ScanImpl( StructType readSchema, Lazy> protocolAndMetadata, CloseableIterator filesIter, - Optional filter, + Optional filter, Path dataPath) { this.snapshotSchema = snapshotSchema; this.readSchema = readSchema; @@ -159,7 +159,7 @@ public Row getScanState(TableClient tableClient) { } @Override - public Optional getRemainingFilter() { + public Optional getRemainingFilter() { return filter; } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/checkpoints/Checkpointer.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/checkpoints/Checkpointer.java index 0824ea3167b..276128067d1 100644 --- 
a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/checkpoints/Checkpointer.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/checkpoints/Checkpointer.java @@ -26,7 +26,7 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.FileDataReadResult; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Literal; +import io.delta.kernel.expressions.AlwaysTrue; import io.delta.kernel.fs.FileStatus; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Utils; @@ -106,7 +106,7 @@ private Optional loadMetadataFromFile(TableClient tableClien jsonHandler.contextualizeFileReads( Utils.singletonCloseableIterator( InternalUtils.getScanFileRow(lastCheckpointFile)), - Literal.TRUE + AlwaysTrue.ALWAYS_TRUE ); CloseableIterator jsonIter = tableClient.getJsonHandler().readJsonFiles( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java index fcd4ec481ce..92daef0884e 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java @@ -19,16 +19,22 @@ import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.*; - -import io.delta.kernel.client.*; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; + +import io.delta.kernel.client.FileReadContext; +import io.delta.kernel.client.JsonHandler; +import io.delta.kernel.client.ParquetHandler; +import io.delta.kernel.client.TableClient; import io.delta.kernel.data.FileDataReadResult; -import io.delta.kernel.expressions.Literal; import io.delta.kernel.fs.FileStatus; import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Tuple2; 
import io.delta.kernel.utils.Utils; +import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; import io.delta.kernel.internal.util.InternalUtils; @@ -37,16 +43,15 @@ * iterator of (FileDataReadResult, isFromCheckpoint) tuples, where the schema of the * FileDataReadResult semantically represents actions (or, a subset of action fields) parsed from * the Delta Log. - * + *

* Users must pass in a `readSchema` to select which actions and sub-fields they want to consume. */ class ActionsIterator implements CloseableIterator> { - private final TableClient tableClient; /** * Iterator over the files. - * + *

* Each file will be split (by 1, or more) to yield an iterator of FileDataReadResults. */ private final Iterator filesIter; @@ -56,7 +61,7 @@ class ActionsIterator implements CloseableIterator * If it is ever empty, that means there are no more batches to produce. */ private Optional>> @@ -65,9 +70,9 @@ class ActionsIterator implements CloseableIterator files, - StructType readSchema) { + TableClient tableClient, + List files, + StructType readSchema) { this.tableClient = tableClient; this.filesIter = files.iterator(); this.readSchema = readSchema; @@ -90,7 +95,7 @@ public boolean hasNext() { /** * @return a tuple of (FileDataReadResult, isFromCheckpoint), where FileDataReadResult conforms - * to the instance {@link #readSchema}. + * to the instance {@link #readSchema}. */ @Override public Tuple2 next() { @@ -98,7 +103,9 @@ public Tuple2 next() { throw new IllegalStateException("Can't call `next` on a closed iterator."); } - if (!hasNext()) throw new NoSuchElementException("No next element"); + if (!hasNext()) { + throw new NoSuchElementException("No next element"); + } return actionsIter.get().next(); } @@ -151,7 +158,7 @@ private void tryEnsureNextActionsIterIsReady() { * Get the next file from `filesIter` (.json or .checkpoint.parquet), contextualize it * (allow the connector to split it), and then read it + inject the `isFromCheckpoint` * information. - * + *

* Requires that `filesIter.hasNext` is true. */ private CloseableIterator> getNextActionsIter() { @@ -171,7 +178,7 @@ private CloseableIterator> getNextActionsIte jsonHandler.contextualizeFileReads( Utils.singletonCloseableIterator( InternalUtils.getScanFileRow(nextFile)), - Literal.TRUE + ALWAYS_TRUE ); iteratorsToClose[0] = fileReadContextIter; @@ -197,7 +204,7 @@ private CloseableIterator> getNextActionsIte parquetHandler.contextualizeFileReads( Utils.singletonCloseableIterator( InternalUtils.getScanFileRow(nextFile)), - Literal.TRUE); + ALWAYS_TRUE); iteratorsToClose[0] = fileReadContextIter; @@ -232,8 +239,8 @@ private CloseableIterator> getNextActionsIte * Take input (iterator, boolean) and produce an iterator. */ private CloseableIterator> combine( - CloseableIterator fileReadDataIter, - boolean isFromCheckpoint) { + CloseableIterator fileReadDataIter, + boolean isFromCheckpoint) { return new CloseableIterator>() { @Override public boolean hasNext() { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/InternalUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/InternalUtils.java index 51a97dababb..84b03e74a58 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/InternalUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/InternalUtils.java @@ -16,6 +16,9 @@ package io.delta.kernel.internal.util; import java.io.IOException; +import java.sql.Date; +import java.time.LocalDate; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.Optional; @@ -28,6 +31,8 @@ import io.delta.kernel.internal.data.AddFileColumnarBatch; public class InternalUtils { + private static final LocalDate EPOCH = LocalDate.ofEpochDay(0); + private InternalUtils() {} public static Row getScanFileRow(FileStatus fileStatus) { @@ -116,4 +121,28 @@ public static void checkArgument(boolean isValid, String message) throw new IllegalArgumentException(message); } } + + /** + 
* Precondition-style validation that throws {@link IllegalArgumentException}. + * + * @param isValid {@code true} if valid, {@code false} if an exception should be thrown + * @param message A String message for the exception. + * @param args Objects used to fill in {@code %s} placeholders in the message + * @throws IllegalArgumentException if {@code isValid} is false + */ + public static void checkArgument(boolean isValid, String message, Object... args) + throws IllegalArgumentException { + if (!isValid) { + throw new IllegalArgumentException( + String.format(String.valueOf(message), args)); + } + } + + /** + * Utility method to get the number of days since epoch this given date is. + */ + public static int daysSinceEpoch(Date date) { + LocalDate localDate = date.toLocalDate(); + return (int) ChronoUnit.DAYS.between(EPOCH, localDate); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java index 86bb30834da..008d39a15c4 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java @@ -88,8 +88,8 @@ public static ColumnarBatch withPartitionColumns( dataBatchSchema, literalForPartitionValue( structField.getDataType(), - partitionValues.get(structField.getName()) - ) + partitionValues.get(structField.getName())), + structField.getDataType() ); ColumnVector partitionVector = evaluator.eval(dataBatch); @@ -106,38 +106,38 @@ private static Literal literalForPartitionValue(DataType dataType, String partit } if (dataType instanceof BooleanType) { - return Literal.of(Boolean.parseBoolean(partitionValue)); + return Literal.ofBoolean(Boolean.parseBoolean(partitionValue)); } if (dataType instanceof ByteType) { - return Literal.of(Byte.parseByte(partitionValue)); + return Literal.ofByte(Byte.parseByte(partitionValue)); } if 
(dataType instanceof ShortType) { - return Literal.of(Short.parseShort(partitionValue)); + return Literal.ofShort(Short.parseShort(partitionValue)); } if (dataType instanceof IntegerType) { - return Literal.of(Integer.parseInt(partitionValue)); + return Literal.ofInt(Integer.parseInt(partitionValue)); } if (dataType instanceof LongType) { - return Literal.of(Long.parseLong(partitionValue)); + return Literal.ofLong(Long.parseLong(partitionValue)); } if (dataType instanceof FloatType) { - return Literal.of(Float.parseFloat(partitionValue)); + return Literal.ofFloat(Float.parseFloat(partitionValue)); } if (dataType instanceof DoubleType) { - return Literal.of(Double.parseDouble(partitionValue)); + return Literal.ofDouble(Double.parseDouble(partitionValue)); } if (dataType instanceof StringType) { - return Literal.of(partitionValue); + return Literal.ofString(partitionValue); } if (dataType instanceof BinaryType) { - return Literal.of(partitionValue.getBytes()); + return Literal.ofBinary(partitionValue.getBytes()); } if (dataType instanceof DateType) { - return Literal.of(Date.valueOf(partitionValue)); + return Literal.ofDate(InternalUtils.daysSinceEpoch(Date.valueOf(partitionValue))); } if (dataType instanceof DecimalType) { DecimalType decimalType = (DecimalType) dataType; - return Literal.of( + return Literal.ofDecimal( new BigDecimal(partitionValue), decimalType.getPrecision(), decimalType.getScale()); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java index efbfb45e154..83d41d34805 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BooleanType.java @@ -24,6 +24,8 @@ */ @Evolving public class BooleanType extends BasePrimitiveType { + // TODO: Should remove the `INSTANCE` to `BOOLEAN` so that it can be static imported where + // needed and referred without the 
`BooleanType.` prefix. Same for other types. public static final BooleanType INSTANCE = new BooleanType(); private BooleanType() { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java index 3f5dfb274e5..f10d2c10498 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java @@ -87,6 +87,6 @@ public String toJson() { @Override public String toString() { - return String.format("Map[%s, %s]", keyType, valueType); + return String.format("map[%s, %s]", keyType, valueType); } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java index babd4c1c0b8..0af04948377 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java @@ -114,19 +114,7 @@ public StructField at(int index) { */ public Column column(int ordinal) { final StructField field = at(ordinal); - return new Column(ordinal, field.getName(), field.getDataType()); - } - - /** - * Creates a {@link Column} expression for the field with the given {@code fieldName}. 
- * - * @param fieldName the name of the {@link StructField} to create a column for - * @return a {@link Column} expression for the {@link StructField} with name {@code fieldName} - */ - public Column column(String fieldName) { - Tuple2 fieldAndOrdinal = nameToFieldAndOrdinal.get(fieldName); - System.out.println("Created column " + fieldName + " with ordinal " + fieldAndOrdinal._2); - return new Column(fieldAndOrdinal._2, fieldName, fieldAndOrdinal._1.getDataType()); + return new Column(field.getName()); } @Override @@ -146,8 +134,7 @@ public boolean equivalent(DataType dataType) { @Override public String toString() { return String.format( - "%s(%s)", - getClass().getSimpleName(), + "struct(%s)", fields.stream().map(StructField::toString).collect(Collectors.joining(", ")) ); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/CloseableIterator.java b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/CloseableIterator.java index 762ae5c6fc9..3e0cf88a16a 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/CloseableIterator.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/CloseableIterator.java @@ -25,7 +25,7 @@ import io.delta.kernel.annotation.Evolving; /** - * Closeable extension of {@link Iterator} + * Closeable extension of {@link Iterator} * * @param the type of elements returned by this iterator * @since 3.0.0 diff --git a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java index 6e261afc88e..fa023a8cfe5 100644 --- a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java +++ b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java @@ -33,7 +33,7 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.FileDataReadResult; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Expression; 
+import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.ArrayType; import io.delta.kernel.types.BooleanType; import io.delta.kernel.types.DataType; @@ -56,7 +56,7 @@ public class JsonHandlerTestImpl @Override public CloseableIterator contextualizeFileReads( - CloseableIterator fileIter, Expression predicate) { + CloseableIterator fileIter, Predicate predicate) { throw new UnsupportedOperationException("not yet implemented"); } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/expressions/ExpressionsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/expressions/ExpressionsSuite.scala new file mode 100644 index 00000000000..4e13301ee9e --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/expressions/ExpressionsSuite.scala @@ -0,0 +1,38 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.expressions + +import io.delta.kernel.types._ +import org.scalatest.funsuite.AnyFunSuite + +class ExpressionsSuite extends AnyFunSuite { + test("expressions: unsupported literal data types") { + val ex1 = intercept[IllegalArgumentException] { + Literal.ofNull(new ArrayType(IntegerType.INSTANCE, true)) + } + assert(ex1.getMessage.contains("array[integer] is an invalid data type for Literal.")) + + val ex2 = intercept[IllegalArgumentException] { + Literal.ofNull(new MapType(IntegerType.INSTANCE, IntegerType.INSTANCE, true)) + } + assert(ex2.getMessage.contains("map[integer, integer] is an invalid data type for Literal.")) + + val ex3 = intercept[IllegalArgumentException] { + Literal.ofNull(new StructType().add("s1", BooleanType.INSTANCE)) + } + assert(ex3.getMessage.matches("struct.* is an invalid data type for Literal.")) + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java index 59fe39b7401..ae3957a13a0 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultExpressionHandler.java @@ -13,110 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package io.delta.kernel.defaults.client; -import java.sql.Date; -import java.util.Optional; - import io.delta.kernel.client.ExpressionHandler; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.Row; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.ExpressionEvaluator; -import io.delta.kernel.expressions.Literal; -import io.delta.kernel.types.*; -import io.delta.kernel.utils.CloseableIterator; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.StructType; -import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; -import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; +import io.delta.kernel.defaults.internal.expressions.DefaultExpressionEvaluator; /** * Default implementation of {@link ExpressionHandler} */ public class DefaultExpressionHandler implements ExpressionHandler { @Override - public ExpressionEvaluator getEvaluator(StructType batchSchema, Expression expression) { - return new DefaultExpressionEvaluator(expression); - } - - private static ColumnVector evalBooleanOutputExpression( - ColumnarBatch input, Expression expression) { - checkArgument(expression.dataType().equals(BooleanType.INSTANCE), - "expression should return a boolean"); - - final int batchSize = input.getSize(); - boolean[] result = new boolean[batchSize]; - boolean[] nullResult = new boolean[batchSize]; - CloseableIterator rows = input.getRows(); - for (int currentIndex = 0; currentIndex < batchSize; currentIndex++) { - Object evalResult = expression.eval(rows.next()); - if (evalResult == null) { - nullResult[currentIndex] = true; - } else { - result[currentIndex] = ((Boolean) evalResult).booleanValue(); - } - } - return new DefaultBooleanVector(batchSize, 
Optional.of(nullResult), result); - } - - private static ColumnVector evalLiteralExpression(ColumnarBatch input, Literal literal) { - Object result = literal.value(); - DataType dataType = literal.dataType(); - int size = input.getSize(); - - if (result == null) { - return new DefaultConstantVector(dataType, size, null); - } - - if (dataType instanceof BooleanType || - dataType instanceof ByteType || - dataType instanceof ShortType || - dataType instanceof IntegerType || - dataType instanceof LongType || - dataType instanceof FloatType || - dataType instanceof DoubleType || - dataType instanceof StringType || - dataType instanceof BinaryType || - dataType instanceof DecimalType) { - return new DefaultConstantVector(dataType, size, result); - } else if (dataType instanceof DateType) { - int numOfDaysSinceEpoch = daysSinceEpoch((Date) result); - return new DefaultConstantVector(dataType, size, numOfDaysSinceEpoch); - } - // TODO: support timestamptype - - throw new UnsupportedOperationException( - "unsupported expression encountered: " + literal); - } - - private static class DefaultExpressionEvaluator - implements ExpressionEvaluator { - private final Expression expression; - - private DefaultExpressionEvaluator(Expression expression) { - this.expression = expression; - } - - @Override - public ColumnVector eval(ColumnarBatch input) { - if (expression instanceof Literal) { - return evalLiteralExpression(input, (Literal) expression); - } - - if (expression.dataType().equals(BooleanType.INSTANCE)) { - return evalBooleanOutputExpression(input, expression); - } - // TODO: Boolean output type expressions are good enough for first preview release - // which enables partition pruning and file skipping using file stats. 
- - throw new UnsupportedOperationException("not yet implemented"); - } - - @Override - public void close() { /* nothing to close */ } + public ExpressionEvaluator getEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType) { + return new DefaultExpressionEvaluator(inputSchema, expression, outputType); } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultFileHandler.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultFileHandler.java index 93a60a0414a..1b497f09ebb 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultFileHandler.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/client/DefaultFileHandler.java @@ -20,7 +20,7 @@ import io.delta.kernel.client.FileHandler; import io.delta.kernel.client.FileReadContext; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.utils.CloseableIterator; /** @@ -29,7 +29,7 @@ public class DefaultFileHandler implements FileHandler { @Override public CloseableIterator contextualizeFileReads( - CloseableIterator fileIter, Expression filter) { + CloseableIterator fileIter, Predicate filter) { requireNonNull(fileIter, "fileIter is null"); requireNonNull(filter, "filter is null"); // TODO: we are not using the filter now, will be used later. 
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java index 7f853d8e2f3..793878c02a9 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/DefaultKernelUtils.java @@ -15,9 +15,7 @@ */ package io.delta.kernel.defaults.internal; -import java.sql.Date; import java.time.LocalDate; -import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -74,6 +72,8 @@ public static Type findSubFieldType(GroupType groupType, StructField field) { return null; } + // TODO: Move these precondition checks into a separate utility class. + /** * Precondition-style validation that throws {@link IllegalArgumentException}. * @@ -117,20 +117,6 @@ public static void checkArgument(boolean isValid, String message, Object... args } } - /** - * Precondition-style validation that throws {@link IllegalStateException}. - * - * @param isValid {@code true} if valid, {@code false} if an exception should be thrown - * @param message A String message for the exception. - * @throws IllegalStateException if {@code isValid} is false - */ - public static void checkState(boolean isValid, String message) - throws IllegalStateException { - if (!isValid) { - throw new IllegalStateException(message); - } - } - private static List pruneFields(GroupType type, StructType deltaDataType) { // prune fields including nested pruning like in pruneSchema return deltaDataType.fields().stream() @@ -156,16 +142,6 @@ private static Type prunedType(Type type, DataType deltaType) { } } - /** - * Utility method to get the number of days since epoch this given date is. 
- * - * @param date - */ - public static int daysSinceEpoch(Date date) { - LocalDate localDate = date.toLocalDate(); - return (int) ChronoUnit.DAYS.between(EPOCH, localDate); - } - ////////////////////////////////////////////////////////////////////////////////// // Below utils are adapted from org.apache.spark.sql.catalyst.util.DateTimeUtils ////////////////////////////////////////////////////////////////////////////////// diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java new file mode 100644 index 00000000000..b8cd2899026 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -0,0 +1,370 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.Arrays; +import java.util.Optional; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.client.ExpressionHandler; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.expressions.*; +import io.delta.kernel.types.*; + +import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; +import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; +import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compare; +import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.evalNullability; +import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo; + +/** + * Implementation of {@link ExpressionEvaluator} for default {@link ExpressionHandler}. + * It takes care of validating, adding necessary implicit casts and evaluating the + * {@link Expression} on given {@link ColumnarBatch}. + */ +public class DefaultExpressionEvaluator implements ExpressionEvaluator { + private final Expression expression; + + /** + * Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and + * inputSchema. + * + * @param inputSchema Input data schema + * @param expression Expression to evaluate. + * @param outputType Expected result data type.
+ */ + public DefaultExpressionEvaluator( + StructType inputSchema, + Expression expression, + DataType outputType) { + ExpressionTransformResult transformResult = + new ExpressionTransformer(inputSchema).visit(expression); + if (!transformResult.outputType.equivalent(outputType)) { + throw new UnsupportedOperationException(format("Can not create an expression handler " + + "for expression `%s` returns result of type %s", expression, outputType)); + } + this.expression = transformResult.expression; + } + + @Override + public ColumnVector eval(ColumnarBatch input) { + return new ExpressionEvalVisitor(input).visit(expression); + } + + @Override + public void close() { /* nothing to close */ } + + /** + * Encapsulates the result of {@link ExpressionTransformer} + */ + private static class ExpressionTransformResult { + public final Expression expression; // transformed expression + public final DataType outputType; // output type of the expression + + ExpressionTransformResult(Expression expression, DataType outputType) { + this.expression = expression; + this.outputType = outputType; + } + } + + /** + * Implementation of {@link ExpressionVisitor} to validate the given expression as follows. + *

<ul> + * <li>given input column is part of the input data schema</li> + * <li>expression inputs are of supported types. Insert cast according to the rules in + * {@link ImplicitCastExpression} to make the types compatible for evaluation by + * {@link ExpressionEvalVisitor} + * </li> + * </ul> + * <p>
+ * Return type of each expression visit is a tuple of new rewritten expression and its result + * data type. + */ + private static class ExpressionTransformer + extends ExpressionVisitor { + private StructType inputDataSchema; + + ExpressionTransformer(StructType inputDataSchema) { + this.inputDataSchema = requireNonNull(inputDataSchema, "inputDataSchema is null"); + } + + @Override + ExpressionTransformResult visitAnd(And and) { + Predicate left = validateIsPredicate(and, visit(and.getLeft())); + Predicate right = validateIsPredicate(and, visit(and.getRight())); + return new ExpressionTransformResult(new And(left, right), BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitOr(Or or) { + Predicate left = validateIsPredicate(or, visit(or.getLeft())); + Predicate right = validateIsPredicate(or, visit(or.getRight())); + return new ExpressionTransformResult(new Or(left, right), BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitAlwaysTrue(AlwaysTrue alwaysTrue) { + // nothing to validate or rewrite. + return new ExpressionTransformResult(alwaysTrue, BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitAlwaysFalse(AlwaysFalse alwaysFalse) { + // nothing to validate or rewrite. 
+ return new ExpressionTransformResult(alwaysFalse, BooleanType.INSTANCE); + } + + @Override + ExpressionTransformResult visitComparator(Predicate predicate) { + switch (predicate.getName()) { + case "=": + case ">": + case ">=": + case "<": + case "<=": + return new ExpressionTransformResult( + transformBinaryComparator(predicate), + BooleanType.INSTANCE); + default: + throw new UnsupportedOperationException( + "unsupported expression encountered: " + predicate); + } + } + + @Override + ExpressionTransformResult visitLiteral(Literal literal) { + // nothing to validate or rewrite + return new ExpressionTransformResult(literal, literal.getDataType()); + } + + @Override + ExpressionTransformResult visitColumn(Column column) { + int ordinal = inputDataSchema.indexOf(column.getName()); + if (ordinal == -1) { + throw new IllegalArgumentException( + format("Column `%s` doesn't exist in input data schema: %s", + column.getName(), inputDataSchema)); + } + return new ExpressionTransformResult(column, inputDataSchema.at(ordinal).getDataType()); + } + + @Override + ExpressionTransformResult visitCast(ImplicitCastExpression cast) { + throw new UnsupportedOperationException("CAST expression is not expected."); + } + + private Predicate validateIsPredicate( + Expression baseExpression, + ExpressionTransformResult result) { + checkArgument( + result.outputType instanceof BooleanType && + result.expression instanceof Predicate, + "%s: expected a predicate expression but got %s with output type %s.", + baseExpression, + result.expression, + result.outputType); + return (Predicate) result.expression; + } + + private Expression transformBinaryComparator(Predicate predicate) { + checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); + ExpressionTransformResult leftResult = visit(predicate.getChildren().get(0)); + ExpressionTransformResult rightResult = visit(predicate.getChildren().get(1)); + Expression left = leftResult.expression; + Expression right = 
rightResult.expression; + if (!leftResult.outputType.equivalent(rightResult.outputType)) { + if (canCastTo(leftResult.outputType, rightResult.outputType)) { + left = new ImplicitCastExpression(left, rightResult.outputType); + } else if (canCastTo(rightResult.outputType, leftResult.outputType)) { + right = new ImplicitCastExpression(right, leftResult.outputType); + } else { + String msg = format("%s: operands are of different types which are not " + + "comparable: left type=%s, right type=%s", + predicate, leftResult.outputType, rightResult.outputType); + throw new UnsupportedOperationException(msg); + } + } + return new Predicate(predicate.getName(), Arrays.asList(left, right)); + } + } + + /** + * Implementation of {@link ExpressionVisitor} to evaluate expression on a + * {@link ColumnarBatch}. + */ + private static class ExpressionEvalVisitor extends ExpressionVisitor { + private final ColumnarBatch input; + + ExpressionEvalVisitor(ColumnarBatch input) { + this.input = input; + } + + @Override + ColumnVector visitAnd(And and) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(and); + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult); + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = argResults.leftResult.getBoolean(rowId) && + argResults.rightResult.getBoolean(rowId); + } + return new DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitOr(Or or) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(or); + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult); + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = argResults.leftResult.getBoolean(rowId) || + argResults.rightResult.getBoolean(rowId); + } + return new 
DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitAlwaysTrue(AlwaysTrue alwaysTrue) { + return new DefaultConstantVector(BooleanType.INSTANCE, input.getSize(), true); + } + + @Override + ColumnVector visitAlwaysFalse(AlwaysFalse alwaysFalse) { + return new DefaultConstantVector(BooleanType.INSTANCE, input.getSize(), false); + } + + @Override + ColumnVector visitComparator(Predicate predicate) { + PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate); + + int numRows = argResults.rowCount; + boolean[] result = new boolean[numRows]; + boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult); + int[] compareResult = compare(argResults.leftResult, argResults.rightResult); + switch (predicate.getName()) { + case "=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] == 0; + } + break; + case ">": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] > 0; + } + break; + case ">=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] >= 0; + } + break; + case "<": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] < 0; + } + break; + case "<=": + for (int rowId = 0; rowId < numRows; rowId++) { + result[rowId] = compareResult[rowId] <= 0; + } + break; + default: + throw new UnsupportedOperationException( + "unsupported expression encountered: " + predicate); + } + + return new DefaultBooleanVector(numRows, Optional.of(nullability), result); + } + + @Override + ColumnVector visitLiteral(Literal literal) { + DataType dataType = literal.getDataType(); + if (dataType instanceof BooleanType || + dataType instanceof ByteType || + dataType instanceof ShortType || + dataType instanceof IntegerType || + dataType instanceof LongType || + dataType instanceof FloatType || + dataType instanceof DoubleType || + dataType instanceof 
StringType || + dataType instanceof BinaryType || + dataType instanceof DecimalType || + dataType instanceof DateType || + dataType instanceof TimestampType) { + return new DefaultConstantVector(dataType, input.getSize(), literal.getValue()); + } + + throw new UnsupportedOperationException( + "unsupported expression encountered: " + literal); + } + + @Override + ColumnVector visitColumn(Column column) { + int ordinal = input.getSchema().indexOf(column.getName()); + if (ordinal == -1) { + throw new IllegalArgumentException( + format("Column `%s` doesn't exist in input data schema: %s", + column.getName(), input.getSchema())); + } + return input.getColumnVector(ordinal); + } + + @Override + ColumnVector visitCast(ImplicitCastExpression cast) { + ColumnVector inputResult = visit(cast.getInput()); + return cast.eval(inputResult); + } + + /** + * Utility method to evaluate inputs to the binary input expression. Also validates the + * evaluated expression result {@link ColumnVector}s are of the same size. 
+ * + * @param predicate Predicate with exactly two child expressions to evaluate + * @return Triplet of (result vector size, left operand result, right operand result) + */ + private PredicateChildrenEvalResult evalBinaryExpressionChildren(Predicate predicate) { + checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); + ColumnVector left = visit(predicate.getChildren().get(0)); + ColumnVector right = visit(predicate.getChildren().get(1)); + checkArgument( + left.getSize() == right.getSize(), + "Left and right operand returned different results: left=%d, right=%d", + left.getSize(), + right.getSize()); + return new PredicateChildrenEvalResult(left.getSize(), left, right); + } + } + + /** + * Encapsulates children expression result of binary input predicate + */ + private static class PredicateChildrenEvalResult { + public final int rowCount; + public final ColumnVector leftResult; + public final ColumnVector rightResult; + + PredicateChildrenEvalResult( + int rowCount, ColumnVector leftResult, ColumnVector rightResult) { + this.rowCount = rowCount; + this.leftResult = leftResult; + this.rightResult = rightResult; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java new file mode 100644 index 00000000000..8997afc9fc3 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java @@ -0,0 +1,177 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import java.math.BigDecimal; +import java.util.Comparator; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.types.*; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; + +/** + * Utility methods used by the default expression evaluator. + */ +class ExpressionUtils { + private ExpressionUtils() {} + + /** + * Utility method that calculates the nullability result from given two vectors. Result is + * null if at least one side is a null. + */ + static boolean[] evalNullability(ColumnVector left, ColumnVector right) { + int numRows = left.getSize(); + boolean[] nullability = new boolean[numRows]; + for (int rowId = 0; rowId < numRows; rowId++) { + nullability[rowId] = left.isNullAt(rowId) || right.isNullAt(rowId); + } + return nullability; + } + + /** + * Utility method to compare the left and right according to the natural ordering + * and return an integer array where each row contains the comparison result (-1, 0, 1) for + * corresponding rows in the input vectors compared. + * + * Only primitive data types are supported. 
+ */ + static int[] compare(ColumnVector left, ColumnVector right) { + checkArgument( + left.getSize() == right.getSize(), + "Left and right operand have different vector sizes."); + DataType dataType = left.getDataType(); + + int numRows = left.getSize(); + int[] result = new int[numRows]; + if (dataType instanceof BooleanType) { + compareBoolean(left, right, result); + } else if (dataType instanceof ByteType) { + compareByte(left, right, result); + } else if (dataType instanceof ShortType) { + compareShort(left, right, result); + } else if (dataType instanceof IntegerType || dataType instanceof DateType) { + compareInt(left, right, result); + } else if (dataType instanceof LongType || dataType instanceof TimestampType) { + compareLong(left, right, result); + } else if (dataType instanceof FloatType) { + compareFloat(left, right, result); + } else if (dataType instanceof DoubleType) { + compareDouble(left, right, result); + } else if (dataType instanceof DecimalType) { + compareDecimal(left, right, result); + } else if (dataType instanceof StringType) { + compareString(left, right, result); + } else if (dataType instanceof BinaryType) { + compareBinary(left, right, result); + } else { + throw new UnsupportedOperationException(dataType + " can not be compared."); + } + return result; + } + + static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)); + } + } + } + + static void compareByte(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId)); + } + } + } + + static void compareShort(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; 
rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId)); + } + } + } + + static void compareInt(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId)); + } + } + } + + static void compareLong(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId)); + } + } + } + + static void compareFloat(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId)); + } + } + } + + static void compareDouble(ColumnVector left, ColumnVector right, int[] result) { + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId)); + } + } + } + + static void compareString(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = Comparator.naturalOrder(); + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId)); + } + } + } + + static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = Comparator.naturalOrder(); + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = comparator.compare(left.getDecimal(rowId), 
right.getDecimal(rowId)); + } + } + } + + static void compareBinary(ColumnVector left, ColumnVector right, int[] result) { + Comparator comparator = (leftOp, rightOp) -> { + int i = 0; + while (i < leftOp.length && i < rightOp.length) { + if (leftOp[i] != rightOp[i]) { + return Byte.compare(leftOp[i], rightOp[i]); + } + i++; + } + return Integer.compare(leftOp.length, rightOp.length); + }; + for (int rowId = 0; rowId < left.getSize(); rowId++) { + if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { + result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId)); + } + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java new file mode 100644 index 00000000000..0e0e5000687 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -0,0 +1,102 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.List; +import java.util.Locale; +import static java.util.stream.Collectors.joining; + +import io.delta.kernel.expressions.*; +import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; +import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; + +/** + * Interface to allow visiting an expression tree and implementing handling for each + * specific expression type. + * + * @param <R> Return type of result of visit expression methods. + */ +abstract class ExpressionVisitor<R> { + + abstract R visitAnd(And and); + + abstract R visitOr(Or or); + + abstract R visitAlwaysTrue(AlwaysTrue alwaysTrue); + + abstract R visitAlwaysFalse(AlwaysFalse alwaysFalse); + + abstract R visitComparator(Predicate predicate); + + abstract R visitLiteral(Literal literal); + + abstract R visitColumn(Column column); + + abstract R visitCast(ImplicitCastExpression cast); + + final R visit(Expression expression) { + if (expression instanceof ScalarExpression) { + return visitScalarExpression((ScalarExpression) expression); + } else if (expression instanceof Literal) { + return visitLiteral((Literal) expression); + } else if (expression instanceof Column) { + return visitColumn((Column) expression); + } else if (expression instanceof ImplicitCastExpression) { + return visitCast((ImplicitCastExpression) expression); + } + + throw new UnsupportedOperationException( + String.format("Expression %s is not supported.", expression)); + } + + private R visitScalarExpression(ScalarExpression expression) { + List<Expression> children = expression.getChildren(); + String name = expression.getName().toUpperCase(Locale.ENGLISH); + switch (name) { + case "ALWAYS_TRUE": + return visitAlwaysTrue(ALWAYS_TRUE); + case "ALWAYS_FALSE": + return visitAlwaysFalse(ALWAYS_FALSE); + case "AND": + return visitAnd( + new And(elemAsPredicate(children, 0), elemAsPredicate(children, 1))); + case "OR": + return visitOr(new
Or(elemAsPredicate(children, 0), elemAsPredicate(children, 1))); + case "=": + case "<": + case "<=": + case ">": + case ">=": + return visitComparator(new Predicate(name, children)); + default: + throw new UnsupportedOperationException( + String.format("Scalar expression `%s` is not supported.", name)); + } + } + + private static Predicate elemAsPredicate(List expressions, int index) { + if (expressions.size() <= index) { + throw new RuntimeException( + String.format("Trying to access invalid entry (%d) in list %s", index, + expressions.stream().map(Object::toString).collect(joining(",")))); + } + Expression elemExpression = expressions.get(index); + if (!(elemExpression instanceof Predicate)) { + throw new RuntimeException("Expected a predicate, but got " + elemExpression); + } + return (Predicate) expressions.get(index); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java new file mode 100644 index 00000000000..63dca50a748 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpression.java @@ -0,0 +1,265 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.*; +import static java.lang.String.format; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.types.DataType; + +import io.delta.kernel.defaults.client.DefaultExpressionHandler; + +/** + * An implicit cast expression to convert the input type to another given type. Here is the valid + * list of casts + *

<p> + * <ul> + * <li>{@code byte} to {@code short, int, long, float, double}</li> + * <li>{@code short} to {@code int, long, float, double}</li> + * <li>{@code int} to {@code long, float, double}</li> + * <li>{@code long} to {@code float, double}</li> + * <li>{@code float} to {@code double}</li> + * </ul> + * + * <p> + * The above list is not exhaustive. Based on the need, we can add more casts. + * <p>

+ * In {@link DefaultExpressionHandler} this is used when the operands of an expression are not of + * the same type, but the evaluator expects same type inputs. There could be more use cases, but + * for now this is the only use case. + */ +final class ImplicitCastExpression implements Expression { + private final Expression input; + private final DataType outputType; + + /** + * Create a cast around the given input expression to specified output data + * type. It is the responsibility of the caller to validate the input expression can be cast + * to the new type using {@link #canCastTo(DataType, DataType)} + */ + ImplicitCastExpression(Expression input, DataType outputType) { + this.input = requireNonNull(input, "input is null"); + this.outputType = requireNonNull(outputType, "outputType is null"); + } + + public Expression getInput() { + return input; + } + + public DataType getOutputType() { + return outputType; + } + + @Override + public List getChildren() { + return Collections.singletonList(input); + } + + /** + * Evaluate the given column expression on the input {@link ColumnVector}. + * + * @param input {@link ColumnVector} data of the input to the cast expression. + * @return {@link ColumnVector} result applying target type casting on every element in the + * input {@link ColumnVector}. + */ + ColumnVector eval(ColumnVector input) { + String fromTypeStr = input.getDataType().toString(); + switch (fromTypeStr) { + case "byte": + return new ByteUpConverter(outputType, input); + case "short": + return new ShortUpConverter(outputType, input); + case "integer": + return new IntUpConverter(outputType, input); + case "long": + return new LongUpConverter(outputType, input); + case "float": + return new FloatUpConverter(outputType, input); + default: + throw new UnsupportedOperationException( + format("Cast from %s is not supported", fromTypeStr)); + } + } + + /** + * Map containing for each type what are the target cast types can be. 
+ */ + private static final Map> UP_CASTABLE_TYPE_TABLE = unmodifiableMap( + new HashMap>() { + { + this.put("byte", Arrays.asList("short", "integer", "long", "float", "double")); + this.put("short", Arrays.asList("integer", "long", "float", "double")); + this.put("integer", Arrays.asList("long", "float", "double")); + this.put("long", Arrays.asList("float", "double")); + this.put("float", Arrays.asList("double")); + } + }); + + /** + * Utility method which returns whether the given {@code from} type can be cast to {@code to} + * type. + */ + static boolean canCastTo(DataType from, DataType to) { + // TODO: The type name should be a first class method on `DataType` instead of getting it + // using the `toString`. + String fromStr = from.toString(); + String toStr = to.toString(); + return UP_CASTABLE_TYPE_TABLE.containsKey(fromStr) && + UP_CASTABLE_TYPE_TABLE.get(fromStr).contains(toStr); + } + + /** + * Base class for up casting {@link ColumnVector} data. + */ + private abstract static class UpConverter implements ColumnVector { + protected final DataType targetType; + protected final ColumnVector inputVector; + + UpConverter(DataType targetType, ColumnVector inputVector) { + this.targetType = targetType; + this.inputVector = inputVector; + } + + @Override + public DataType getDataType() { + return targetType; + } + + @Override + public boolean isNullAt(int rowId) { + return inputVector.isNullAt(rowId); + } + + @Override + public int getSize() { + return inputVector.getSize(); + } + + @Override + public void close() { + inputVector.close(); + } + } + + private static class ByteUpConverter extends UpConverter { + ByteUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public short getShort(int rowId) { + return inputVector.getByte(rowId); + } + + @Override + public int getInt(int rowId) { + return inputVector.getByte(rowId); + } + + @Override + public long getLong(int rowId) { + return 
inputVector.getByte(rowId); + } + + @Override + public float getFloat(int rowId) { + return inputVector.getByte(rowId); + } + + @Override + public double getDouble(int rowId) { + return inputVector.getByte(rowId); + } + } + + private static class ShortUpConverter extends UpConverter { + ShortUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public int getInt(int rowId) { + return inputVector.getShort(rowId); + } + + @Override + public long getLong(int rowId) { + return inputVector.getShort(rowId); + } + + @Override + public float getFloat(int rowId) { + return inputVector.getShort(rowId); + } + + @Override + public double getDouble(int rowId) { + return inputVector.getShort(rowId); + } + } + + private static class IntUpConverter extends UpConverter { + IntUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public long getLong(int rowId) { + return inputVector.getInt(rowId); + } + + @Override + public float getFloat(int rowId) { + return inputVector.getInt(rowId); + } + + @Override + public double getDouble(int rowId) { + return inputVector.getInt(rowId); + } + } + + private static class LongUpConverter extends UpConverter { + LongUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public float getFloat(int rowId) { + return inputVector.getLong(rowId); + } + + @Override + public double getDouble(int rowId) { + return inputVector.getLong(rowId); + } + } + + private static class FloatUpConverter extends UpConverter { + FloatUpConverter(DataType targetType, ColumnVector inputVector) { + super(targetType, inputVector); + } + + @Override + public double getDouble(int rowId) { + return inputVector.getFloat(rowId); + } + } +} diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java 
b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java deleted file mode 100644 index 6a10c1cb9b9..00000000000 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultExpressionHandler.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.delta.kernel.defaults.client; - -import java.sql.Date; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; - -import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.expressions.And; -import io.delta.kernel.expressions.Column; -import io.delta.kernel.expressions.EqualTo; -import io.delta.kernel.expressions.Expression; -import io.delta.kernel.expressions.Literal; -import io.delta.kernel.types.*; - -import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch; -import io.delta.kernel.defaults.internal.data.vector.DefaultIntVector; -import io.delta.kernel.defaults.internal.data.vector.DefaultLongVector; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; - -public class TestDefaultExpressionHandler { - /** - * Evaluate literal expressions. 
This is used to populate the partition column vectors. - */ - @Test - public void evalLiterals() { - StructType inputSchema = new StructType(); - ColumnVector[] data = new ColumnVector[0]; - - List testCases = new ArrayList<>(); - testCases.add(Literal.of(true)); - testCases.add(Literal.of(false)); - testCases.add(Literal.ofNull(BooleanType.INSTANCE)); - testCases.add(Literal.of((byte) 24)); - testCases.add(Literal.ofNull(ByteType.INSTANCE)); - testCases.add(Literal.of((short) 876)); - testCases.add(Literal.ofNull(ShortType.INSTANCE)); - testCases.add(Literal.of(2342342)); - testCases.add(Literal.ofNull(IntegerType.INSTANCE)); - testCases.add(Literal.of(234234223L)); - testCases.add(Literal.ofNull(LongType.INSTANCE)); - testCases.add(Literal.of(23423.4223f)); - testCases.add(Literal.ofNull(FloatType.INSTANCE)); - testCases.add(Literal.of(23423.422233d)); - testCases.add(Literal.ofNull(DoubleType.INSTANCE)); - testCases.add(Literal.of("string_val")); - testCases.add(Literal.ofNull(StringType.INSTANCE)); - testCases.add(Literal.of("binary_val".getBytes())); - testCases.add(Literal.ofNull(BinaryType.INSTANCE)); - testCases.add(Literal.of(new Date(234234234))); - testCases.add(Literal.ofNull(DateType.INSTANCE)); - // testCases.add(Literal.of(new Timestamp(2342342342232L))); - // testCases.add(Literal.ofNull(TimestampType.INSTANCE)); - - ColumnarBatch[] inputBatches = new ColumnarBatch[] { - new DefaultColumnarBatch(0, inputSchema, data), - new DefaultColumnarBatch(25, inputSchema, data), - new DefaultColumnarBatch(128, inputSchema, data) - }; - - for (Literal expression : testCases) { - DataType outputDataType = expression.dataType(); - - for (ColumnarBatch inputBatch : inputBatches) { - ColumnVector outputVector = eval(inputSchema, inputBatch, expression); - assertEquals(inputBatch.getSize(), outputVector.getSize()); - assertEquals(outputDataType, outputVector.getDataType()); - for (int rowId = 0; rowId < outputVector.getSize(); rowId++) { - if (expression.value() == 
null) { - assertTrue(outputVector.isNullAt(rowId)); - continue; - } - Object expRowValue = expression.value(); - if (outputDataType instanceof BooleanType) { - assertEquals(expRowValue, outputVector.getBoolean(rowId)); - } else if (outputDataType instanceof ByteType) { - assertEquals(expRowValue, outputVector.getByte(rowId)); - } else if (outputDataType instanceof ShortType) { - assertEquals(expRowValue, outputVector.getShort(rowId)); - } else if (outputDataType instanceof IntegerType) { - assertEquals(expRowValue, outputVector.getInt(rowId)); - } else if (outputDataType instanceof LongType) { - assertEquals(expRowValue, outputVector.getLong(rowId)); - } else if (outputDataType instanceof FloatType) { - assertEquals(expRowValue, outputVector.getFloat(rowId)); - } else if (outputDataType instanceof DoubleType) { - assertEquals(expRowValue, outputVector.getDouble(rowId)); - } else if (outputDataType instanceof StringType) { - assertEquals(expRowValue, outputVector.getString(rowId)); - } else if (outputDataType instanceof BinaryType) { - assertEquals(expRowValue, outputVector.getBinary(rowId)); - } else if (outputDataType instanceof DateType) { - assertEquals( - daysSinceEpoch((Date) expRowValue), outputVector.getInt(rowId)); - } else { - throw new UnsupportedOperationException( - "unsupported output type encountered: " + outputDataType); - } - } - } - } - } - - @Test - public void evalBooleanExpressionSimple() { - Expression expression = new EqualTo( - new Column(0, "intType", IntegerType.INSTANCE), - Literal.of(3)); - - for (int size : Arrays.asList(26, 234, 567)) { - StructType inputSchema = new StructType() - .add("intType", IntegerType.INSTANCE); - ColumnVector[] data = new ColumnVector[] { - intVector(size) - }; - - ColumnarBatch inputBatch = new DefaultColumnarBatch(size, inputSchema, data); - - ColumnVector output = eval(inputSchema, inputBatch, expression); - for (int rowId = 0; rowId < size; rowId++) { - if (data[0].isNullAt(rowId)) { - // expect the output 
to be null as well - assertTrue(output.isNullAt(rowId)); - } else { - assertFalse(output.isNullAt(rowId)); - boolean expValue = rowId % 7 == 3; - assertEquals(expValue, output.getBoolean(rowId)); - } - } - } - } - - @Test - public void evalBooleanExpressionComplex() { - Expression expression = new And( - new EqualTo(new Column(0, "intType", IntegerType.INSTANCE), Literal.of(3)), - new EqualTo(new Column(1, "longType", LongType.INSTANCE), Literal.of(4L)) - ); - - for (int size : Arrays.asList(26, 234, 567)) { - StructType inputSchema = new StructType() - .add("intType", IntegerType.INSTANCE) - .add("longType", LongType.INSTANCE); - ColumnVector[] data = new ColumnVector[] { - intVector(size), - longVector(size), - }; - - ColumnarBatch inputBatch = new DefaultColumnarBatch(size, inputSchema, data); - - ColumnVector output = eval(inputSchema, inputBatch, expression); - for (int rowId = 0; rowId < size; rowId++) { - if (data[0].isNullAt(rowId) || data[1].isNullAt(rowId)) { - // expect the output to be null as well - assertTrue(output.isNullAt(rowId)); - } else { - assertFalse(output.isNullAt(rowId)); - boolean expValue = (rowId % 7 == 3) && (rowId * 200L / 87 == 4); - assertEquals(expValue, output.getBoolean(rowId)); - } - } - } - } - - private static ColumnVector eval( - StructType inputSchema, ColumnarBatch input, Expression expression) { - return new DefaultExpressionHandler() - .getEvaluator(inputSchema, expression) - .eval(input); - } - - private static ColumnVector intVector(int size) { - int[] values = new int[size]; - boolean[] nullability = new boolean[size]; - - for (int rowId = 0; rowId < size; rowId++) { - if (rowId % 5 == 0) { - nullability[rowId] = true; - } else { - values[rowId] = rowId % 7; - } - } - - return new DefaultIntVector( - IntegerType.INSTANCE, size, Optional.of(nullability), values); - } - - private static ColumnVector longVector(int size) { - long[] values = new long[size]; - boolean[] nullability = new boolean[size]; - - for (int rowId = 
0; rowId < size; rowId++) { - if (rowId % 5 == 0) { - nullability[rowId] = true; - } else { - values[rowId] = rowId * 200L % 87; - } - } - - return new DefaultLongVector(LongType.INSTANCE, size, Optional.of(nullability), values); - } -} diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java index 75ad7b775f2..00e7f8f4e19 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java @@ -15,11 +15,7 @@ */ package io.delta.kernel.defaults.client; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; @@ -33,17 +29,14 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.FileDataReadResult; import io.delta.kernel.data.Row; -import io.delta.kernel.expressions.Literal; import io.delta.kernel.fs.FileStatus; -import io.delta.kernel.types.BooleanType; -import io.delta.kernel.types.LongType; -import io.delta.kernel.types.MapType; -import io.delta.kernel.types.StringType; -import io.delta.kernel.types.StructType; +import io.delta.kernel.types.*; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Utils; +import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; import io.delta.kernel.defaults.utils.DefaultKernelTestUtils; + import io.delta.kernel.defaults.internal.data.DefaultJsonRow; public class TestDefaultJsonHandler { @@ -60,7 +53,7 @@ public void contextualizeFiles() throws Exception { try (CloseableIterator inputScanFiles = testFiles(); CloseableIterator fileReadContexts = - 
JSON_HANDLER.contextualizeFileReads(testFiles(), Literal.TRUE)) { + JSON_HANDLER.contextualizeFileReads(testFiles(), ALWAYS_TRUE)) { while (inputScanFiles.hasNext() || fileReadContexts.hasNext()) { assertEquals(inputScanFiles.hasNext(), fileReadContexts.hasNext()); Row inputScanFile = inputScanFiles.next(); @@ -76,7 +69,7 @@ public void readJsonFiles() try ( CloseableIterator data = JSON_HANDLER.readJsonFiles( - JSON_HANDLER.contextualizeFileReads(testFiles(), Literal.TRUE), + JSON_HANDLER.contextualizeFileReads(testFiles(), ALWAYS_TRUE), new StructType() .add("path", StringType.INSTANCE) .add("size", LongType.INSTANCE) diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java index 8b8591f96d1..5fd42a6756c 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java @@ -30,12 +30,12 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.types.*; +import static io.delta.kernel.internal.util.InternalUtils.daysSinceEpoch; import io.delta.kernel.defaults.client.DefaultTableClient; import io.delta.kernel.defaults.integration.DataBuilderUtils.TestColumnBatchBuilder; import static io.delta.kernel.defaults.integration.DataBuilderUtils.row; import static io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.daysSinceEpoch; /** * Test reading Delta lake tables end to end using the Kernel APIs and default {@link TableClient} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala index 
bb7be1413c9..cc75a1d121e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala @@ -17,21 +17,18 @@ package io.delta.kernel.defaults import java.util.{Optional, TimeZone} -import collection.JavaConverters._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration - import io.delta.kernel.{Scan, Snapshot} import io.delta.kernel.client.TableClient import io.delta.kernel.data.Row -import io.delta.kernel.types.StructType -import io.delta.kernel.utils.CloseableIterator - import io.delta.kernel.defaults.client.DefaultTableClient +import io.delta.kernel.types._ +import io.delta.kernel.utils.CloseableIterator +import org.apache.hadoop.conf.Configuration trait TestUtils { - lazy val defaultTableClient = DefaultTableClient.create(new Configuration()) implicit class CloseableIteratorOps[T](private val iter: CloseableIterator[T]) { @@ -127,4 +124,27 @@ trait TestUtils { TimeZone.setDefault(currentDefault) } } + + /** All simple data type used in parameterized tests where type is one of the test dimensions. */ + val SIMPLE_TYPES = Seq( + BooleanType.INSTANCE, + ByteType.INSTANCE, + ShortType.INSTANCE, + IntegerType.INSTANCE, + LongType.INSTANCE, + FloatType.INSTANCE, + DoubleType.INSTANCE, + DateType.INSTANCE, + TimestampType.INSTANCE, + StringType.INSTANCE, + BinaryType.INSTANCE, + new DecimalType(10, 5) + ) + + /** All types. Used in parameterized tests where type is one of the test dimensions. 
*/ + val ALL_TYPES = SIMPLE_TYPES ++ Seq( + new ArrayType(BooleanType.INSTANCE, true), + new MapType(IntegerType.INSTANCE, LongType.INSTANCE, true), + new StructType().add("s1", BooleanType.INSTANCE).add("s2", IntegerType.INSTANCE) + ) } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala new file mode 100644 index 00000000000..d8c3ccfaa35 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -0,0 +1,491 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions + +import java.lang.{Boolean => BooleanJ} +import java.math.{BigDecimal => BigDecimalJ} +import java.util + +import io.delta.kernel.data.{ColumnarBatch, ColumnVector} +import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch +import io.delta.kernel.defaults.TestUtils +import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject +import io.delta.kernel.expressions._ +import io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE +import io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE +import io.delta.kernel.expressions.Literal._ +import io.delta.kernel.types._ +import org.scalatest.funsuite.AnyFunSuite + +class DefaultExpressionEvaluatorSuite extends AnyFunSuite with TestUtils { + test("evaluate expression: literal") { + val testLiterals = Seq( + Literal.ofBoolean(true), + Literal.ofBoolean(false), + Literal.ofNull(BooleanType.INSTANCE), + ofByte(24.toByte), + Literal.ofNull(ByteType.INSTANCE), + Literal.ofShort(876.toShort), + Literal.ofNull(ShortType.INSTANCE), + Literal.ofInt(2342342), + Literal.ofNull(IntegerType.INSTANCE), + Literal.ofLong(234234223L), + Literal.ofNull(LongType.INSTANCE), + Literal.ofFloat(23423.4223f), + Literal.ofNull(FloatType.INSTANCE), + Literal.ofDouble(23423.422233d), + Literal.ofNull(DoubleType.INSTANCE), + Literal.ofString("string_val"), + Literal.ofNull(StringType.INSTANCE), + Literal.ofBinary("binary_val".getBytes), + Literal.ofNull(BinaryType.INSTANCE), + Literal.ofDate(4234), + Literal.ofNull(DateType.INSTANCE), + Literal.ofTimestamp(2342342342232L), + Literal.ofNull(TimestampType.INSTANCE)) + + val inputBatches: Seq[ColumnarBatch] = Seq[ColumnarBatch]( + zeroColumnBatch(rowCount = 0), + zeroColumnBatch(rowCount = 25), + zeroColumnBatch(rowCount = 128)) + + for (literal <- testLiterals) { + val outputDataType = literal.getDataType + for (inputBatch <- inputBatches) { + val outputVector: ColumnVector = + evaluator(inputBatch.getSchema, 
literal, literal.getDataType) + .eval(inputBatch) + + assert(inputBatch.getSize === outputVector.getSize) + assert(outputDataType === outputVector.getDataType) + + for (rowId <- 0 until outputVector.getSize) { + if (literal.getValue == null) { + assert( + outputVector.isNullAt(rowId), + s"expected a null at $rowId for $literal expression") + } else { + assert( + literal.getValue === getValueAsObject(outputVector, rowId), + s"invalid value at $rowId for $literal expression" + ) + } + } + } + } + } + + SIMPLE_TYPES.foreach { dataType => + test(s"evaluate expression: column of type $dataType") { + val batchSize = 78; + val batchSchema = new StructType().add("col1", dataType) + val batch = new DefaultColumnarBatch( + batchSize, + batchSchema, + Array[ColumnVector](testColumnVector(batchSize, dataType))) + + val outputVector = evaluator(batchSchema, new Column("col1"), dataType) + .eval(batch) + + assert(batchSize === outputVector.getSize) + assert(dataType === outputVector.getDataType) + Seq.range(0, outputVector.getSize).foreach { rowId => + assert( + testIsNullValue(dataType, rowId) === outputVector.isNullAt(rowId), + s"unexpected nullability at $rowId for $dataType type vector") + if (!outputVector.isNullAt(rowId)) { + assert( + testColumnValue(dataType, rowId) === getValueAsObject(outputVector, rowId), + s"unexpected value at $rowId for $dataType type vector") + } + } + } + } + + test("evaluate expression: always true, always false") { + Seq(ALWAYS_TRUE, ALWAYS_FALSE).foreach { expr => + val batch = zeroColumnBatch(rowCount = 87) + val outputVector = evaluator(batch.getSchema, expr, BooleanType.INSTANCE).eval(batch) + assert(outputVector.getSize === 87) + assert(outputVector.getDataType === BooleanType.INSTANCE) + Seq.range(0, 87).foreach { rowId => + assert(!outputVector.isNullAt(rowId)) + assert(outputVector.getBoolean(rowId) == (expr == ALWAYS_TRUE)) + } + } + } + + test("evaluate expression: and, or") { + val leftColumn = booleanVector( + Seq[BooleanJ](true, 
true, false, false, null, true, null, false, null)) + val rightColumn = booleanVector( + Seq[BooleanJ](true, false, false, true, true, null, false, null, null)) + val expAndOutputVector = booleanVector( + Seq[BooleanJ](true, false, false, false, null, null, null, null, null)) + val expOrOutputVector = booleanVector( + Seq[BooleanJ](true, true, false, true, null, null, null, null, null)) + + val schema = new StructType() + .add("left", BooleanType.INSTANCE) + .add("right", BooleanType.INSTANCE) + val batch = new DefaultColumnarBatch(leftColumn.getSize, schema, Array(leftColumn, rightColumn)) + + val left = comparator("=", new Column("left"), Literal.ofBoolean(true)) + val right = comparator("=", new Column("right"), Literal.ofBoolean(true)) + + // And + val andExpression = and(left, right) + val actAndOutputVector = evaluator(schema, andExpression, BooleanType.INSTANCE).eval(batch) + checkBooleanVectors(actAndOutputVector, expAndOutputVector) + + // Or + val orExpression = or(left, right) + val actOrOutputVector = evaluator(schema, orExpression, BooleanType.INSTANCE).eval(batch) + checkBooleanVectors(actOrOutputVector, expOrOutputVector) + } + + + test("evaluate expression: comparators (=, <, <=, >, >=)") { + // Literals for each data type from the data type value range, used as inputs to comparator + // (small, big, small, null) + val literals = Seq( + (ofByte(1.toByte), ofByte(2.toByte), ofByte(1.toByte), ofNull(ByteType.INSTANCE)), + (ofShort(1.toShort), ofShort(2.toShort), ofShort(1.toShort), ofNull(ShortType.INSTANCE)), + (ofInt(1), ofInt(2), ofInt(1), ofNull(IntegerType.INSTANCE)), + (ofLong(1L), ofLong(2L), ofLong(1L), ofNull(LongType.INSTANCE)), + (ofFloat(1.0F), ofFloat(2.0F), ofFloat(1.0F), ofNull(FloatType.INSTANCE)), + (ofDouble(1.0), ofDouble(2.0), ofDouble(1.0), ofNull(DoubleType.INSTANCE)), + (ofBoolean(false), ofBoolean(true), ofBoolean(false), ofNull(BooleanType.INSTANCE)), + ( + ofTimestamp(343L), + ofTimestamp(123212312L), + ofTimestamp(343L), + 
ofNull(TimestampType.INSTANCE) + ), + (ofDate(-12123), ofDate(123123), ofDate(-12123), ofNull(DateType.INSTANCE)), + (ofString("apples"), ofString("oranges"), ofString("apples"), ofNull(StringType.INSTANCE)), + ( + ofBinary("apples".getBytes()), + ofBinary("oranges".getBytes()), + ofBinary("apples".getBytes()), + ofNull(BinaryType.INSTANCE) + ), + ( + ofDecimal(BigDecimalJ.valueOf(1.12), 7, 3), + ofDecimal(BigDecimalJ.valueOf(5233.232), 7, 3), + ofDecimal(BigDecimalJ.valueOf(1.12), 7, 3), + ofNull(new DecimalType(7, 3)) + ) + ) + + // Mapping of comparator to expected results for: + // comparator(small, big) + // comparator(big, small) + // comparator(small, small) + // comparator(small, null) + // comparator(null, big) + // comparator(null, null) + val comparatorToExpResults = Map[String, Seq[BooleanJ]]( + "<" -> Seq(true, false, false, null, null, null), + "<=" -> Seq(true, false, true, null, null, null), + ">" -> Seq(false, true, false, null, null, null), + ">=" -> Seq(false, true, true, null, null, null), + "=" -> Seq(false, false, true, null, null, null) + ) + + literals.foreach { + case (small1, big, small2, nullLit) => + comparatorToExpResults.foreach { + case (comparator, expectedResults) => + testComparator(comparator, small1, big, expectedResults(0)) + testComparator(comparator, big, small1, expectedResults(1)) + testComparator(comparator, small1, small2, expectedResults(2)) + testComparator(comparator, small1, nullLit, expectedResults(3)) + testComparator(comparator, nullLit, big, expectedResults(4)) + testComparator(comparator, nullLit, nullLit, expectedResults(5)) + } + } + } + + // Literals for each data type from the data type value range, used as inputs to comparator + // (byte, short, int, long, float, null double) + val literals = Seq( + ofByte(1.toByte), + ofShort(223), + ofInt(-234), + ofLong(223L), + ofFloat(-2423423.9f), + ofNull(DoubleType.INSTANCE) + ) + + test("evaluate expression: comparators `byte` with other implicit types") { + // Mapping of
comparator to expected results for: + // (byte, short), (byte, int), (byte, long), (byte, float), (byte, double) + val comparatorToExpResults = Map[String, Seq[BooleanJ]]( + "<" -> Seq(true, false, true, false, null), + "<=" -> Seq(true, false, true, false, null), + ">" -> Seq(false, true, false, true, null), + ">=" -> Seq(false, true, false, true, null), + "=" -> Seq(false, false, false, false, null) + ) + + // Left operand is the first literal in [[literals]], which is a byte type + // Right operands are the remaining literals to the right side of it in [[literals]] + val right = literals(0) + Seq.range(1, literals.length).foreach { idx => + comparatorToExpResults.foreach { + case (comparator, expectedResults) => + testComparator(comparator, right, literals(idx), expectedResults(idx - 1)) + } + } + } + + test("evaluate expression: comparators `short` with other implicit types") { + // Mapping of comparator to expected results for: + // (short, int), (short, long), (short, float), (short, double) + val comparatorToExpResults = Map[String, Seq[BooleanJ]]( + "<" -> Seq(false, false, false, null), + "<=" -> Seq(false, true, false, null), + ">" -> Seq(true, false, true, null), + ">=" -> Seq(true, true, true, null), + "=" -> Seq(false, true, false, null) + ) + + // Left operand is the second literal in [[literals]], which is a short type + // Right operands are the remaining literals to the right side of it in [[literals]] + val right = literals(1) + Seq.range(2, literals.length).foreach { idx => + comparatorToExpResults.foreach { + case (comparator, expectedResults) => + testComparator(comparator, right, literals(idx), expectedResults(idx - 2)) + } + } + } + + test("evaluate expression: comparators `int` with other implicit types") { + // Mapping of comparator to expected results for: (int, long), (int, float), (int, double) + val comparatorToExpResults = Map[String, Seq[BooleanJ]]( + "<" -> Seq(true, false, null), + "<=" -> Seq(true, false, null), + ">" -> Seq(false, true, null), + ">=" ->
Seq(false, true, null), + "=" -> Seq(false, false, null) + ) + + // Left operand is the third literal in [[literals]], which is an int type + // Right operands are the remaining literals to the right side of it in [[literals]] + val right = literals(2) + Seq.range(3, literals.length).foreach { idx => + comparatorToExpResults.foreach { + case (comparator, expectedResults) => + testComparator(comparator, right, literals(idx), expectedResults(idx - 3)) + } + } + } + + test("evaluate expression: comparators `long` with other implicit types") { + // Mapping of comparator to expected results for: (long, float), (long, double) + val comparatorToExpResults = Map[String, Seq[BooleanJ]]( + "<" -> Seq(false, null), + "<=" -> Seq(false, null), + ">" -> Seq(true, null), + ">=" -> Seq(true, null), + "=" -> Seq(false, null) + ) + + // Left operand is the fourth literal in [[literals]], which is a long type + // Right operands are the remaining literals to the right side of it in [[literals]] + val right = literals(3) + Seq.range(4, literals.length).foreach { idx => + comparatorToExpResults.foreach { + case (comparator, expectedResults) => + testComparator(comparator, right, literals(idx), expectedResults(idx - 4)) + } + } + } + + test("evaluate expression: unsupported implicit casts") { + intercept[UnsupportedOperationException] { + testComparator("<", ofInt(21), ofDate(123), null) + } + } + + test("evaluate expression: comparators `float` with other implicit types") { + // Comparator results for: (float, double) is always null as one of the operands is null + val comparatorToExpResults = Seq("<", "<=", ">", ">=", "=") + + // Left operand is the fifth literal in [[literals]], which is a float type + // Right operands are the remaining literals to the right side of it in [[literals]] + val right = literals(4) + Seq.range(5, literals.length).foreach { idx => + comparatorToExpResults.foreach { comparator => + testComparator(comparator, right, literals(idx), null) + } + } + } + + /** + * Utility method to generate
a [[dataType]] column vector of given size. + * The nullability of rows is determined by the [[testIsNullValue(dataType, rowId)]]. + * The row values are determined by [[testColumnValue(dataType, rowId)]]. + */ + private def testColumnVector(size: Int, dataType: DataType): ColumnVector = { + new ColumnVector { + override def getDataType: DataType = dataType + + override def getSize: Int = size + + override def close(): Unit = {} + + override def isNullAt(rowId: Int): Boolean = testIsNullValue(dataType, rowId) + + override def getBoolean(rowId: Int): Boolean = + testColumnValue(dataType, rowId).asInstanceOf[Boolean] + + override def getByte(rowId: Int): Byte = testColumnValue(dataType, rowId).asInstanceOf[Byte] + + override def getShort(rowId: Int): Short = + testColumnValue(dataType, rowId).asInstanceOf[Short] + + override def getInt(rowId: Int): Int = testColumnValue(dataType, rowId).asInstanceOf[Int] + + override def getLong(rowId: Int): Long = testColumnValue(dataType, rowId).asInstanceOf[Long] + + override def getFloat(rowId: Int): Float = + testColumnValue(dataType, rowId).asInstanceOf[Float] + + override def getDouble(rowId: Int): Double = + testColumnValue(dataType, rowId).asInstanceOf[Double] + + override def getBinary(rowId: Int): Array[Byte] = + testColumnValue(dataType, rowId).asInstanceOf[Array[Byte]] + + override def getString(rowId: Int): String = + testColumnValue(dataType, rowId).asInstanceOf[String] + + override def getDecimal(rowId: Int): BigDecimalJ = + testColumnValue(dataType, rowId).asInstanceOf[BigDecimalJ] + } + } + + /** Utility method to generate a consistent `isNull` value for given column type and row id */ + private def testIsNullValue(dataType: DataType, rowId: Int): Boolean = { + dataType match { + case BooleanType.INSTANCE => rowId % 4 == 0 + case ByteType.INSTANCE => rowId % 8 == 0 + case ShortType.INSTANCE => rowId % 12 == 0 + case IntegerType.INSTANCE => rowId % 20 == 0 + case LongType.INSTANCE => rowId % 25 == 0 + case 
FloatType.INSTANCE => rowId % 5 == 0
+      case DoubleType.INSTANCE => rowId % 10 == 0
+      case StringType.INSTANCE => rowId % 2 == 0
+      case BinaryType.INSTANCE => rowId % 3 == 0
+      case DateType.INSTANCE => rowId % 5 == 0
+      case TimestampType.INSTANCE => rowId % 3 == 0
+      case _ =>
+        if (dataType.isInstanceOf[DecimalType]) rowId % 6 == 0
+        else throw new UnsupportedOperationException(s"$dataType is not supported")
+    }
+  }
+
+  /** Generates a deterministic test value for the given column type and row id. */
+  private def testColumnValue(dataType: DataType, rowId: Int): Any = {
+    dataType match {
+      case BooleanType.INSTANCE => rowId % 7 == 0
+      case ByteType.INSTANCE => (rowId * 7 / 17).toByte
+      case ShortType.INSTANCE => (rowId * 9 / 87).toShort
+      case IntegerType.INSTANCE => rowId * 2876 / 176  // Int arithmetic: the division truncates
+      case LongType.INSTANCE => rowId * 287623L / 91
+      case FloatType.INSTANCE => rowId * 7651.2323f / 91
+      case DoubleType.INSTANCE => rowId * 23423.23d / 17
+      case StringType.INSTANCE => (rowId % 19).toString
+      case BinaryType.INSTANCE => Array[Byte]((rowId % 21).toByte, (rowId % 7 - 1).toByte)
+      case DateType.INSTANCE => (rowId * 28234) % 2876  // an Int; unit not shown here - confirm
+      case TimestampType.INSTANCE => (rowId * 2342342L) % 23  // a Long; unit not shown - confirm
+      case _ =>  // DecimalType is recognised by class check; any other type is rejected
+        if (dataType.isInstanceOf[DecimalType]) new BigDecimalJ(rowId * 22342.23)
+        else throw new UnsupportedOperationException(s"$dataType is not supported")
+    }
+  }
+
+  private def booleanVector(values: Seq[BooleanJ]): ColumnVector = {  // boolean vector over `values`
+    new ColumnVector {
+      override def getDataType: DataType = BooleanType.INSTANCE
+
+      override def getSize: Int = values.length
+
+      override def close(): Unit = {}  // nothing to release
+
+      override def isNullAt(rowId: Int): Boolean = values(rowId) == null  // null entry => null row
+
+      override def getBoolean(rowId: Int): Boolean = values(rowId)  // check isNullAt first
+    }
+  }
+
+  private def evaluator(inputSchema: StructType, expression: Expression, outputType: DataType)
+    : DefaultExpressionEvaluator = {  // thin factory: forwards all three arguments unchanged
+    new DefaultExpressionEvaluator(inputSchema, expression, outputType)
+  }
+
+  /** Create a columnar batch of the given `rowCount` with zero columns in it. */
+  private def zeroColumnBatch(rowCount: Int): ColumnarBatch = {
+    new DefaultColumnarBatch(rowCount, new StructType(), new Array[ColumnVector](0))
+  }
+
+  private def and(left: Predicate, right: Predicate): And = {  // shorthand for `new And`
+    new And(left, right)
+  }
+
+  private def or(left: Predicate, right: Predicate): Or = {  // shorthand for `new Or`
+    new Or(left, right)
+  }
+
+  private def comparator(symbol: String, left: Expression, right: Expression): Predicate = {
+    new Predicate(symbol, util.Arrays.asList(left, right))  // binary predicate named `symbol`
+  }
+
+  private def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = {
+    assert(actual.getDataType === expected.getDataType)  // type and size must match first
+    assert(actual.getSize === expected.getSize)
+    Seq.range(0, actual.getSize).foreach { rowId =>
+      assert(actual.isNullAt(rowId) === expected.isNullAt(rowId))
+      if (!actual.isNullAt(rowId)) {  // values are compared only for non-null rows
+        assert(
+          actual.getBoolean(rowId) === expected.getBoolean(rowId),
+          s"unexpected value at $rowId"
+        )
+      }
+    }
+  }
+
+  private def testComparator(  // NOTE(review): param `comparator` shadows the helper method above
+    comparator: String, left: Expression, right: Expression, expResult: BooleanJ): Unit = {
+    val expression = new Predicate(comparator, util.Arrays.asList(left, right))
+    val batch = zeroColumnBatch(rowCount = 1)  // evaluated against one row with no columns
+    val outputVector = evaluator(batch.getSchema, expression, BooleanType.INSTANCE).eval(batch)
+
+    assert(outputVector.getSize === 1)
+    assert(outputVector.getDataType === BooleanType.INSTANCE)
+    assert(  // a null `expResult` means the single output row is expected to be null
+      outputVector.isNullAt(0) === (expResult == null),
+      s"Unexpected null value: $comparator($left, $right)")
+    if (expResult != null) {  // the value is only checked when a non-null result is expected
+      assert(
+        outputVector.getBoolean(0) === expResult,
+        s"Unexpected value: $comparator($left, $right)")
+    }
+  }
+}
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala
new file mode 100644
index 00000000000..40a3fce6b51
--- /dev/null
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.internal.expressions
+
+import io.delta.kernel.data.ColumnVector
+import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject
+import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo
+import io.delta.kernel.defaults.TestUtils
+import io.delta.kernel.expressions.Column
+import io.delta.kernel.types._
+import org.scalatest.funsuite.AnyFunSuite
+
+class ImplicitCastExpressionSuite extends AnyFunSuite with TestUtils {
+  private val allowedCasts: Set[(DataType, DataType)] = Set(  // expected-castable (from, to) pairs
+    (ByteType.INSTANCE, ShortType.INSTANCE),
+    (ByteType.INSTANCE, IntegerType.INSTANCE),
+    (ByteType.INSTANCE, LongType.INSTANCE),
+    (ByteType.INSTANCE, FloatType.INSTANCE),
+    (ByteType.INSTANCE, DoubleType.INSTANCE),
+
+    (ShortType.INSTANCE, IntegerType.INSTANCE),
+    (ShortType.INSTANCE, LongType.INSTANCE),
+    (ShortType.INSTANCE, FloatType.INSTANCE),
+    (ShortType.INSTANCE, DoubleType.INSTANCE),
+
+    (IntegerType.INSTANCE, LongType.INSTANCE),
+    (IntegerType.INSTANCE, FloatType.INSTANCE),
+    (IntegerType.INSTANCE, DoubleType.INSTANCE),
+
+    (LongType.INSTANCE, FloatType.INSTANCE),
+    (LongType.INSTANCE, DoubleType.INSTANCE),
+    (FloatType.INSTANCE, DoubleType.INSTANCE))
+
+  test("can cast to") {  // canCastTo must agree with allowedCasts for every ordered type pair
+    Seq.range(0, ALL_TYPES.length).foreach { fromTypeIdx =>  // ALL_TYPES: defined outside this file
+      val fromType: DataType = ALL_TYPES(fromTypeIdx)
+      Seq.range(0, ALL_TYPES.length).foreach { toTypeIdx =>
+        val toType: DataType = ALL_TYPES(toTypeIdx)
+        assert(canCastTo(fromType, toType) ===
+          allowedCasts.contains((fromType, toType)))
+      }
+    }
+  }
+
+  allowedCasts.foreach { castPair =>  // registers one test per allowed (from, to) pair
+    test(s"eval cast expression: ${castPair._1} -> ${castPair._2}") {
+      val fromType = castPair._1
+      val toType = castPair._2
+      val inputVector = testData(87, fromType, (rowId) => rowId % 7 == 0)  // every 7th row is null
+      val outputVector = new ImplicitCastExpression(new Column("id"), toType)
+        .eval(inputVector)
+      checkCastOutput(inputVector, toType, outputVector)
+    }
+  }
+
+  def testData(size: Int, dataType: DataType, nullability: (Int) => Boolean): ColumnVector = {
+    new ColumnVector {  // synthetic vector: `nullability` decides which rows report null
+      override def getDataType: DataType = dataType
+      override def getSize: Int = size
+      override def close(): Unit = {}
+      override def isNullAt(rowId: Int): Boolean = nullability(rowId)
+
+      override def getByte(rowId: Int): Byte = {
+        assert(dataType === ByteType.INSTANCE)  // each getter asserts it matches `dataType`
+        generateValue(rowId).toByte  // generated values exceed Byte range; toByte narrows them
+      }
+
+      override def getShort(rowId: Int): Short = {
+        assert(dataType === ShortType.INSTANCE)
+        generateValue(rowId).toShort
+      }
+
+      override def getInt(rowId: Int): Int = {
+        assert(dataType === IntegerType.INSTANCE)
+        generateValue(rowId).toInt
+      }
+
+      override def getLong(rowId: Int): Long = {
+        assert(dataType === LongType.INSTANCE)
+        generateValue(rowId).toLong
+      }
+
+      override def getFloat(rowId: Int): Float = {
+        assert(dataType === FloatType.INSTANCE)
+        generateValue(rowId).toFloat
+      }
+
+      override def getDouble(rowId: Int): Double = {
+        assert(dataType === DoubleType.INSTANCE)
+        generateValue(rowId)
+      }
+    }
+  }
+
+  // Utility method to generate a value based on the rowId. The returned value is a double
+  // that the callers can convert to the appropriate numeric type.
+  private def generateValue(rowId: Int): Double = rowId * 2.76 + 7623
+
+  private def checkCastOutput(input: ColumnVector, toType: DataType, output: ColumnVector): Unit = {
+    assert(input.getSize === output.getSize)  // a cast must keep the row count
+    assert(toType === output.getDataType)
+    Seq.range(0, input.getSize).foreach { rowId =>  // NOTE(review): null rows are value-compared too
+      assert(input.isNullAt(rowId) === output.isNullAt(rowId))
+      assert(getValueAsObject(input, rowId) === getValueAsObject(output, rowId))
+    }
+  }
+}