From ca698951e6e223a07d584445e164678626da986c Mon Sep 17 00:00:00 2001
From: Venki Korukanti
Date: Thu, 28 Sep 2023 12:42:57 -0700
Subject: [PATCH] [Kernel] Add partition pruning related utility methods

## Description

Part of delta-io/delta#2071 (Partition Pruning in Kernel).

This PR adds the following utility methods (a usage sketch follows the list):

* Splitting the `Predicate` given to `ScanBuilder.withFilter` into a data column predicate and a
  partition column predicate.
* Rewriting the partition column `Predicate` to refer to the columns in the scan file columnar
  batch, with the appropriate partition value deserialization expressions applied.
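As a rough usage sketch for reviewers (the wrapper class, the example table layout with partition
column `part1` / data column `data1`, and the literal values are made up for illustration; the
`PartitionUtils` methods, expression classes and types are the ones used in this PR), the two
utilities are intended to be chained like this:

```java
import java.util.*;

import io.delta.kernel.expressions.*;
import io.delta.kernel.internal.util.PartitionUtils;
import io.delta.kernel.types.*;
import io.delta.kernel.utils.Tuple2;

public class PartitionPruningSketch {
    public static void main(String[] args) {
        // Hypothetical table: partition column `part1` (int), data column `data1` (int).
        Set<String> partitionColNames = Collections.singleton("part1");
        Map<String, DataType> partitionColNameTypes = new HashMap<>();
        partitionColNameTypes.put("part1", IntegerType.INSTANCE);

        // Filter given to `ScanBuilder.withFilter`: part1 = 10 AND data1 = 12
        Predicate filter = new Predicate("AND", Arrays.asList(
            new Predicate("=", Arrays.asList(
                new Column(new String[] {"part1"}), Literal.ofInt(10))),
            new Predicate("=", Arrays.asList(
                new Column(new String[] {"data1"}), Literal.ofInt(12)))));

        // 1. Split the filter into (partition column predicate, data column predicate).
        Tuple2<Predicate, Predicate> split =
            PartitionUtils.splitMetadataAndDataPredicates(filter, partitionColNames);
        System.out.println(split._1); // (column(`part1`) = 10)
        System.out.println(split._2); // (column(`data1`) = 12)

        // 2. Rewrite the partition predicate against `add.partitionValues` in the scan file
        //    batch returned by Scan.getScanFiles, adding `partition_value` deserialization
        //    for non-string partition column types.
        Predicate scanFilePredicate = PartitionUtils.rewritePartitionPredicateOnScanFileSchema(
            split._1, partitionColNameTypes);
        System.out.println(scanFilePredicate);
    }
}
```

The first element of the returned tuple is what gets evaluated against the scan file batch for
pruning; the second element remains as the filter on the data read from the surviving files.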
## How was this patch tested?

Added UTs
---
 .../java/io/delta/kernel/expressions/And.java |   5 -
 .../java/io/delta/kernel/expressions/Or.java  |   5 -
 .../delta/kernel/expressions/Predicate.java   |   6 +-
 .../internal/InternalScanFileUtils.java       |   6 +
 .../kernel/internal/util/PartitionUtils.java  | 133 ++++++++++++-
 .../internal/util/PartitionUtilsSuite.scala   | 181 ++++++++++++++++++
 6 files changed, 318 insertions(+), 18 deletions(-)
 create mode 100644 kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala

diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java
index 9ba046466a9..04c17181254 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/And.java
@@ -51,9 +51,4 @@ public Predicate getLeft() {
     public Predicate getRight() {
         return (Predicate) getChildren().get(1);
     }
-
-    @Override
-    public String toString() {
-        return "(" + getLeft() + " AND " + getRight() + ")";
-    }
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java
index e8e2fa37572..3c7a84d9003 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Or.java
@@ -50,9 +50,4 @@ public Predicate getLeft() {
     public Predicate getRight() {
         return (Predicate) getChildren().get(1);
     }
-
-    @Override
-    public String toString() {
-        return "(" + getLeft() + " OR " + getRight() + ")";
-    }
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
index 8f98005a9f5..e3662231b9e 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
@@ -96,12 +96,12 @@ public Predicate(String name, List<Expression> children) {
 
     @Override
     public String toString() {
-        if (COMPARATORS.contains(name)) {
+        if (BINARY_OPERATORS.contains(name)) {
             return String.format("(%s %s %s)", children.get(0), name, children.get(1));
         }
         return super.toString();
     }
 
-    private static final Set<String> COMPARATORS =
-        Stream.of("<", "<=", ">", ">=", "=").collect(Collectors.toSet());
+    private static final Set<String> BINARY_OPERATORS =
+        Stream.of("<", "<=", ">", ">=", "=", "AND", "OR").collect(Collectors.toSet());
 }
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java
index f93524da6a3..b7ef2dca092 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java
@@ -22,6 +22,7 @@
 import io.delta.kernel.Scan;
 import io.delta.kernel.client.TableClient;
 import io.delta.kernel.data.Row;
+import io.delta.kernel.expressions.Column;
 import io.delta.kernel.fs.FileStatus;
 import io.delta.kernel.types.DataType;
 import io.delta.kernel.types.StringType;
@@ -42,6 +43,11 @@ private InternalScanFileUtils() {}
 
     private static final String TABLE_ROOT_COL_NAME = "tableRoot";
     private static final DataType TABLE_ROOT_DATA_TYPE = StringType.INSTANCE;
+    /**
+     * {@link Column} expression referring to the `partitionValues` in the scan `add` file.
+     */
+    public static final Column ADD_FILE_PARTITION_COL_REF =
+        new Column(new String[] {"add", "partitionValues"});
 
     public static StructField TABLE_ROOT_STRUCT_FIELD = new StructField(
         TABLE_ROOT_COL_NAME,
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java
index 008d39a15c4..3aea6bbf8b5 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java
@@ -17,19 +17,22 @@
 
 import java.math.BigDecimal;
 import java.sql.Date;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import static java.util.Arrays.asList;
 
 import io.delta.kernel.client.ExpressionHandler;
+import io.delta.kernel.client.TableClient;
 import io.delta.kernel.data.ColumnVector;
 import io.delta.kernel.data.ColumnarBatch;
-import io.delta.kernel.expressions.ExpressionEvaluator;
-import io.delta.kernel.expressions.Literal;
+import io.delta.kernel.expressions.*;
 import io.delta.kernel.types.*;
 import io.delta.kernel.utils.Tuple2;
+import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE;
+import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE;
+
+import io.delta.kernel.internal.InternalScanFileUtils;
 
 public class PartitionUtils {
     private PartitionUtils() {}
@@ -100,6 +103,126 @@ public static ColumnarBatch withPartitionColumns(
         return dataBatch;
     }
 
+    /**
+     * Split the given predicate into a predicate on partition columns and a predicate on data
+     * columns.
+     *
+     * @param predicate Predicate to split.
+     * @param partitionColNames Partition column names of the table.
+     * @return Tuple of partition column predicate and data column predicate.
+     */
+    public static Tuple2<Predicate, Predicate> splitMetadataAndDataPredicates(
+            Predicate predicate,
+            Set<String> partitionColNames) {
+        String predicateName = predicate.getName();
+        List<Expression> children = predicate.getChildren();
+        if ("AND".equalsIgnoreCase(predicateName)) {
+            Predicate left = (Predicate) children.get(0);
+            Predicate right = (Predicate) children.get(1);
+            Tuple2<Predicate, Predicate> leftResult =
+                splitMetadataAndDataPredicates(left, partitionColNames);
+            Tuple2<Predicate, Predicate> rightResult =
+                splitMetadataAndDataPredicates(right, partitionColNames);
+
+            return new Tuple2<>(
+                combineWithAndOp(leftResult._1, rightResult._1),
+                combineWithAndOp(leftResult._2, rightResult._2));
+        }
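+        // A predicate that is not a top-level AND is not split any further: it goes entirely to
+        // the partition side only when it references partition columns alone, otherwise entirely
+        // to the data side (pruning files using just part of, e.g., an OR could skip files
+        // incorrectly).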
+        if (hasNonPartitionColumns(children, partitionColNames)) {
+            return new Tuple2<>(ALWAYS_TRUE, predicate);
+        } else {
+            return new Tuple2<>(predicate, ALWAYS_TRUE);
+        }
+    }
+
+    /**
+     * Utility method to rewrite a partition predicate that refers to the table schema as a
+     * predicate that refers to the {@code partitionValues} in the scan files read from the
+     * Delta log. The scan file batch is returned by
+     * {@link io.delta.kernel.Scan#getScanFiles(TableClient)}.
+     * <p>
+     * E.g. given a predicate on partition columns
+     * {@code p1 = 'new york' && p2 >= 26}, where p1 is of type string and p2 is of type int,
+     * the rewritten expression looks like:
+     * {@code element_at(Column('add', 'partitionValues'), 'p1') = 'new york'
+     * &&
+     * partition_value(element_at(Column('add', 'partitionValues'), 'p2'), 'integer') >= 26}
+     * <p>
+     * The column `add.partitionValues` is of {@literal map(string -> string)} type. Each
+     * partition value is in the string serialization format defined by the Delta protocol.
+     * Expression `partition_value` deserializes the string value into the given partition
+     * column type value. String type partition values don't need any deserialization.
+     *
+     * @param predicate Predicate containing filters only on partition columns.
+     * @param partitionColNameTypes Map of partition column names and their types.
+     * @return Rewritten predicate referring to the scan file schema.
+     */
+    public static Predicate rewritePartitionPredicateOnScanFileSchema(
+            Predicate predicate, Map<String, DataType> partitionColNameTypes) {
+        return new Predicate(
+            predicate.getName(),
+            predicate.getChildren().stream()
+                .map(child -> rewritePartitionColumnRef(child, partitionColNameTypes))
+                .collect(Collectors.toList()));
+    }
+
+    private static Expression rewritePartitionColumnRef(
+            Expression expression, Map<String, DataType> partitionColNameTypes) {
+        Column scanFilePartitionValuesRef = InternalScanFileUtils.ADD_FILE_PARTITION_COL_REF;
+        if (expression instanceof Column) {
+            Column column = (Column) expression;
+            String partColName = column.getNames()[0];
+            DataType partColType = partitionColNameTypes.get(partColName);
+
+            Expression elementAt =
+                new ScalarExpression(
+                    "element_at",
+                    asList(scanFilePartitionValuesRef, Literal.ofString(partColName)));
+
+            if (partColType instanceof StringType) {
+                return elementAt;
+            }
+
+            // Add an expression to decode the partition value based on the partition column type.
+            return new PartitionValueExpression(elementAt, partColType);
+        } else if (expression instanceof Predicate) {
+            return rewritePartitionPredicateOnScanFileSchema(
+                (Predicate) expression, partitionColNameTypes);
+        }
+
+        return expression;
+    }
+
+    private static boolean hasNonPartitionColumns(
+            List<Expression> children,
+            Set<String> partitionColNames) {
+        for (Expression child : children) {
+            if (child instanceof Column) {
+                String[] names = ((Column) child).getNames();
+                // Partition columns are never of nested types.
+                if (names.length != 1 || !partitionColNames.contains(names[0])) {
+                    return true;
+                }
+            } else if (hasNonPartitionColumns(child.getChildren(), partitionColNames)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private static Predicate combineWithAndOp(Predicate left, Predicate right) {
+        String leftName = left.getName().toUpperCase();
+        String rightName = right.getName().toUpperCase();
+        if (leftName.equals("ALWAYS_FALSE") || rightName.equals("ALWAYS_FALSE")) {
+            return ALWAYS_FALSE;
+        }
+        if (leftName.equals("ALWAYS_TRUE")) {
+            return right;
+        }
+        if (rightName.equals("ALWAYS_TRUE")) {
+            return left;
+        }
+        return new And(left, right);
+    }
+
     private static Literal literalForPartitionValue(DataType dataType, String partitionValue) {
         if (partitionValue == null) {
             return Literal.ofNull(dataType);
diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala
new file mode 100644
index 00000000000..a5c84fb3f0b
--- /dev/null
+++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala
@@ -0,0 +1,181 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.delta.kernel.internal.util + +import java.util + +import scala.collection.JavaConverters._ + +import io.delta.kernel.expressions._ +import io.delta.kernel.expressions.Literal._ +import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates} +import io.delta.kernel.types._ +import org.scalatest.funsuite.AnyFunSuite + +class PartitionUtilsSuite extends AnyFunSuite { + // Table schema + // Data columns: data1: int, data2: string, date3: struct(data31: boolean, data32: long) + // Partition columns: part1: int, part2: date, part3: string + private val partitionColsToType = new util.HashMap[String, DataType]() { + { + put("part1", IntegerType.INSTANCE) + put("part2", DateType.INSTANCE) + put("part3", StringType.INSTANCE) + } + } + + private val partitionCols: java.util.Set[String] = partitionColsToType.keySet() + + // Test cases for verifying partition of predicate into data and partition predicates + // Map entry format (predicate -> (partition predicate, data predicate) + val partitionTestCases = Map[Predicate, (String, String)]( + // single predicate on a data column + predicate("=", col("data1"), ofInt(12)) -> + ("ALWAYS_TRUE()", "(column(`data1`) = 12)"), + // multiple predicates on data columns joined with AND + predicate("AND", + predicate("=", col("data1"), ofInt(12)), + predicate(">=", col("data2"), ofString("sss"))) -> + ("ALWAYS_TRUE()", "((column(`data1`) = 12) AND (column(`data2`) >= sss))"), + // multiple predicates on data columns joined with OR + predicate("OR", + predicate("<=", col("data2"), ofString("sss")), + predicate("=", col("data3", "data31"), ofBoolean(true))) -> + ("ALWAYS_TRUE()", "((column(`data2`) <= sss) OR (column(`data3`.`data31`) = true))"), + // single predicate on a partition column + predicate("=", col("part1"), ofInt(12)) -> + ("(column(`part1`) = 12)", "ALWAYS_TRUE()"), + // multiple predicates on partition columns joined with AND + predicate("AND", + predicate("=", col("part1"), ofInt(12)), + predicate(">=", col("part3"), ofString("sss"))) -> + ("((column(`part1`) = 12) AND (column(`part3`) >= sss))", "ALWAYS_TRUE()"), + // multiple predicates on partition columns joined with OR + predicate("OR", + predicate("<=", col("part3"), ofString("sss")), + predicate("=", col("part1"), ofInt(2781))) -> + ("((column(`part3`) <= sss) OR (column(`part1`) = 2781))", "ALWAYS_TRUE()"), + + // predicates (each on data and partition column) joined with AND + predicate("AND", + predicate("=", col("data1"), ofInt(12)), + predicate(">=", col("part3"), ofString("sss"))) -> + ("(column(`part3`) >= sss)", "(column(`data1`) = 12)"), + + // predicates (each on data and partition column) joined with OR + predicate("OR", + predicate("=", col("data1"), ofInt(12)), + predicate(">=", col("part3"), ofString("sss"))) -> + ("ALWAYS_TRUE()", "((column(`data1`) = 12) OR (column(`part3`) >= sss))"), + + // predicates (multiple on data and partition columns) joined with AND + predicate("AND", + predicate("AND", + predicate("=", col("data1"), ofInt(12)), + predicate(">=", col("data2"), ofString("sss"))), + predicate("AND", + predicate("=", col("part1"), ofInt(12)), + predicate(">=", col("part3"), ofString("sss")))) -> + ( + "((column(`part1`) = 12) AND (column(`part3`) >= sss))", + "((column(`data1`) = 12) AND (column(`data2`) >= sss))" + ), + + // predicates (multiple on data and partition columns joined with OR) joined with AND + predicate("AND", + predicate("OR", + predicate("=", col("data1"), ofInt(12)), + 
+        predicate(">=", col("data2"), ofString("sss"))),
+      predicate("OR",
+        predicate("=", col("part1"), ofInt(12)),
+        predicate(">=", col("part3"), ofString("sss")))) ->
+      (
+        "((column(`part1`) = 12) OR (column(`part3`) >= sss))",
+        "((column(`data1`) = 12) OR (column(`data2`) >= sss))"
+      ),
+
+    // predicates (multiple on data and partition columns joined with OR) joined with OR
+    predicate("OR",
+      predicate("OR",
+        predicate("=", col("data1"), ofInt(12)),
+        predicate(">=", col("data2"), ofString("sss"))),
+      predicate("OR",
+        predicate("=", col("part1"), ofInt(12)),
+        predicate(">=", col("part3"), ofString("sss")))) ->
+      (
+        "ALWAYS_TRUE()",
+        "(((column(`data1`) = 12) OR (column(`data2`) >= sss)) OR " +
+          "((column(`part1`) = 12) OR (column(`part3`) >= sss)))"
+      ),
+
+    // predicates (data and partitions compared in the same expression)
+    predicate("AND",
+      predicate("=", col("data1"), col("part1")),
+      predicate(">=", col("part3"), ofString("sss"))) ->
+      (
+        "(column(`part3`) >= sss)",
+        "(column(`data1`) = column(`part1`))"
+      )
+  )
+
+  partitionTestCases.foreach {
+    case (predicate, (partitionPredicate, dataPredicate)) =>
+      test(s"split predicate into data and partition predicates: $predicate") {
+        val metadataAndDataPredicates = splitMetadataAndDataPredicates(predicate, partitionCols)
+        assert(metadataAndDataPredicates._1.toString === partitionPredicate)
+        assert(metadataAndDataPredicates._2.toString === dataPredicate)
+      }
+  }
+
+  // Map entry format: (given predicate -> expected rewritten predicate)
+  val rewriteTestCases = Map(
+    // single predicate on a partition column
+    predicate("=", col("part2"), ofTimestamp(12)) ->
+      "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)",
+    // multiple predicates on partition columns joined with AND
+    predicate("AND",
+      predicate("=", col("part1"), ofInt(12)),
+      predicate(">=", col("part3"), ofString("sss"))) ->
+      """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND
+        |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))"""
+        .stripMargin.replaceAll("\n", " "),
+    // multiple predicates on partition columns joined with OR
+    predicate("OR",
+      predicate("<=", col("part3"), ofString("sss")),
+      predicate("=", col("part1"), ofInt(2781))) ->
+      """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR
+        |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))"""
+        .stripMargin.replaceAll("\n", " ")
+  )
+
+  rewriteTestCases.foreach {
+    case (predicate, expRewrittenPredicate) =>
+      test(s"rewrite partition predicate on scan file schema: $predicate") {
+        val actRewrittenPredicate =
+          rewritePartitionPredicateOnScanFileSchema(predicate, partitionColsToType)
+        assert(actRewrittenPredicate.toString === expRewrittenPredicate)
+      }
+  }
+
+  private def col(names: String*): Column = {
+    new Column(names.toArray)
+  }
+
+  private def predicate(name: String, children: Expression*): Predicate = {
+    new Predicate(name, children.asJava)
+  }
+}