Skip to content

Commit

Permalink
[Kernel] Add partition pruning related utility methods
Browse files Browse the repository at this point in the history
## Description
Part of #2071 (Partition Pruning in Kernel). This PR adds the following utility methods:
  * Dividing the `Predicate` given to `ScanBuilder.withFilter` into data column and partition column predicates
  * Rewrite the partition column `Predicate` to refer to the columns in the scan file columnar batch with the appropriate partition value deserialization expressions applied.  

## How was this patch tested?
Added UTs
  • Loading branch information
vkorukanti authored Sep 28, 2023
1 parent 879df3c commit ca69895
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,4 @@ public Predicate getLeft() {
/**
 * Returns the right-hand operand of this binary predicate (the second child).
 */
public Predicate getRight() {
return (Predicate) getChildren().get(1);
}

@Override
public String toString() {
return "(" + getLeft() + " AND " + getRight() + ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,4 @@ public Predicate getLeft() {
/**
 * Returns the right-hand operand of this binary predicate (the second child).
 */
public Predicate getRight() {
return (Predicate) getChildren().get(1);
}

@Override
public String toString() {
return "(" + getLeft() + " OR " + getRight() + ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ public Predicate(String name, List<Expression> children) {

@Override
public String toString() {
if (COMPARATORS.contains(name)) {
if (BINARY_OPERATORS.contains(name)) {
return String.format("(%s %s %s)", children.get(0), name, children.get(1));
}
return super.toString();
}

private static final Set<String> COMPARATORS =
Stream.of("<", "<=", ">", ">=", "=").collect(Collectors.toSet());
private static final Set<String> BINARY_OPERATORS =
Stream.of("<", "<=", ">", ">=", "=", "AND", "OR").collect(Collectors.toSet());
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.delta.kernel.Scan;
import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.Row;
import io.delta.kernel.expressions.Column;
import io.delta.kernel.fs.FileStatus;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StringType;
Expand All @@ -42,6 +43,11 @@ private InternalScanFileUtils() {}

private static final String TABLE_ROOT_COL_NAME = "tableRoot";
private static final DataType TABLE_ROOT_DATA_TYPE = StringType.INSTANCE;
/**
* {@link Column} expression referring to the `partitionValues` in scan `add` file.
*/
public static final Column ADD_FILE_PARTITION_COL_REF =
new Column(new String[] {"add", "partitionValues"});

public static StructField TABLE_ROOT_STRUCT_FIELD = new StructField(
TABLE_ROOT_COL_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@

import java.math.BigDecimal;
import java.sql.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static java.util.Arrays.asList;

import io.delta.kernel.client.ExpressionHandler;
import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.expressions.ExpressionEvaluator;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.*;
import io.delta.kernel.types.*;
import io.delta.kernel.utils.Tuple2;
import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE;
import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE;

import io.delta.kernel.internal.InternalScanFileUtils;

public class PartitionUtils {
private PartitionUtils() {}
Expand Down Expand Up @@ -100,6 +103,126 @@ public static ColumnarBatch withPartitionColumns(
return dataBatch;
}

/**
 * Split the given predicate into a predicate on partition columns and a predicate on data
 * columns. Only top-level {@code AND} conjuncts are split; any other predicate is assigned
 * wholly to one side based on the columns it references.
 *
 * @param predicate Predicate to split.
 * @param partitionColNames Set of partition column names. NOTE(review): matching via
 *                          {@code Set.contains} is case-sensitive — confirm callers
 *                          normalize column-name case.
 * @return Tuple of (partition column predicate, data column predicate). Either element is
 *         {@code ALWAYS_TRUE} when the input contributes nothing to that side.
 */
public static Tuple2<Predicate, Predicate> splitMetadataAndDataPredicates(
        Predicate predicate,
        Set<String> partitionColNames) {
    String predicateName = predicate.getName();
    List<Expression> children = predicate.getChildren();
    if ("AND".equalsIgnoreCase(predicateName)) {
        Predicate left = (Predicate) children.get(0);
        Predicate right = (Predicate) children.get(1);
        Tuple2<Predicate, Predicate> leftResult =
            splitMetadataAndDataPredicates(left, partitionColNames);
        Tuple2<Predicate, Predicate> rightResult =
            splitMetadataAndDataPredicates(right, partitionColNames);

        // AND conjuncts can be split independently and recombined per side:
        // (Pl AND Dl) AND (Pr AND Dr) == (Pl AND Pr) AND (Dl AND Dr).
        return new Tuple2<>(
            combineWithAndOp(leftResult._1, rightResult._1),
            combineWithAndOp(leftResult._2, rightResult._2));
    }
    if (hasNonPartitionColumns(children, partitionColNames)) {
        // Fix: use the diamond constructor — the original used a raw `new Tuple2(...)`,
        // causing an unchecked-conversion warning and inconsistency with the branch below.
        return new Tuple2<>(ALWAYS_TRUE, predicate);
    } else {
        return new Tuple2<>(predicate, ALWAYS_TRUE);
    }
}

/**
 * Rewrite the given partition predicate, which refers to columns in the table schema, so that
 * it refers instead to the {@code partitionValues} entry of the scan file rows returned by
 * {@link io.delta.kernel.Scan#getScanFiles(TableClient)}.
 * <p>
 * E.g. given predicate on partition columns:
 * {@code p1 = 'new york' && p2 >= 26} where p1 is of type string and p2 is of int
 * Rewritten expression looks like:
 * {@code element_at(Column('add', 'partitionValues'), 'p1') = 'new york'
 * &&
 * partition_value(element_at(Column('add', 'partitionValues'), 'p2'), 'integer') >= 26}
 * <p>
 * The column `add.partitionValues` is a {@literal map(string -> string)} type. Each partition
 * value is in string serialization format according to the Delta protocol. Expression
 * `partition_value` deserializes the string value into the given partition column type value.
 * String type partition values don't need any deserialization.
 *
 * @param predicate Predicate containing filters only on partition columns.
 * @param partitionColNameTypes Map of partition column name to its data type.
 * @return An equivalent predicate expressed against the scan file schema.
 */
public static Predicate rewritePartitionPredicateOnScanFileSchema(
        Predicate predicate, Map<String, DataType> partitionColNameTypes) {
    // Rewrite each child in place, preserving the operator name and child order.
    List<Expression> rewrittenChildren = new ArrayList<>();
    for (Expression child : predicate.getChildren()) {
        rewrittenChildren.add(rewritePartitionColumnRef(child, partitionColNameTypes));
    }
    return new Predicate(predicate.getName(), rewrittenChildren);
}

/**
 * Rewrite one child expression of a partition predicate. A {@link Column} reference to a
 * partition column `p` becomes `element_at(add.partitionValues, 'p')`, wrapped in a
 * `partition_value` deserialization expression for non-string partition types. Nested
 * {@link Predicate}s are rewritten recursively; any other expression (e.g. literals) is
 * returned unchanged.
 */
private static Expression rewritePartitionColumnRef(
        Expression expression, Map<String, DataType> partitionColNameTypes) {
    Column scanFilePartitionValuesRef = InternalScanFileUtils.ADD_FILE_PARTITION_COL_REF;
    if (expression instanceof Column) {
        Column column = (Column) expression;
        // Only the first name element is used to look up the partition type.
        String partColName = column.getNames()[0];
        // NOTE(review): `get` returns null if this is not a known partition column —
        // presumably callers pass predicates referring only to partition columns; confirm.
        DataType partColType = partitionColNameTypes.get(partColName);

        Expression elementAt =
            new ScalarExpression(
                "element_at",
                asList(scanFilePartitionValuesRef, Literal.ofString(partColName)));

        if (partColType instanceof StringType) {
            // String partition values are already in their final form.
            return elementAt;
        }

        // Add expression to decode the partition value based on the partition column type.
        return new PartitionValueExpression(elementAt, partColType);
    } else if (expression instanceof Predicate) {
        return rewritePartitionPredicateOnScanFileSchema(
            (Predicate) expression, partitionColNameTypes);
    }

    return expression;
}

/**
 * Returns true if any {@link Column} referenced anywhere in the given expression trees is
 * not a top-level partition column.
 *
 * @param children Expressions to scan for column references.
 * @param partitionColNames Set of partition column names.
 * @return true when at least one non-partition column reference is found.
 */
private static boolean hasNonPartitionColumns(
        List<Expression> children,
        Set<String> partitionColNames) {
    for (Expression child : children) {
        if (child instanceof Column) {
            String[] names = ((Column) child).getNames();
            // Partition columns are never of nested types.
            if (names.length != 1 || !partitionColNames.contains(names[0])) {
                return true;
            }
        } else if (hasNonPartitionColumns(child.getChildren(), partitionColNames)) {
            // Bug fix: the original `else { return recurse(...); }` returned the recursive
            // result directly, so a `false` from the first non-Column child (e.g. a Literal
            // on the left of `1 = data1`) skipped checking the remaining siblings. Only
            // short-circuit on `true`; otherwise keep scanning.
            return true;
        }
    }
    return false;
}

/**
 * Combine two predicates with AND, simplifying away ALWAYS_TRUE / ALWAYS_FALSE operands.
 */
private static Predicate combineWithAndOp(Predicate left, Predicate right) {
    String l = left.getName().toUpperCase();
    String r = right.getName().toUpperCase();

    // FALSE dominates a conjunction.
    if ("ALWAYS_FALSE".equals(l) || "ALWAYS_FALSE".equals(r)) {
        return ALWAYS_FALSE;
    }
    // TRUE is the identity of a conjunction: drop it from either side.
    if ("ALWAYS_TRUE".equals(l)) {
        return right;
    }
    return "ALWAYS_TRUE".equals(r) ? left : new And(left, right);
}

private static Literal literalForPartitionValue(DataType dataType, String partitionValue) {
if (partitionValue == null) {
return Literal.ofNull(dataType);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.internal.util

import java.util

import scala.collection.JavaConverters._

import io.delta.kernel.expressions._
import io.delta.kernel.expressions.Literal._
import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates}
import io.delta.kernel.types._
import org.scalatest.funsuite.AnyFunSuite

class PartitionUtilsSuite extends AnyFunSuite {
  // Table schema used by all cases below:
  // Data columns: data1: int, data2: string, data3: struct(data31: boolean, data32: long)
  // Partition columns: part1: int, part2: date, part3: string
  private val partitionColsToType = new util.HashMap[String, DataType]() {
    {
      put("part1", IntegerType.INSTANCE)
      put("part2", DateType.INSTANCE)
      put("part3", StringType.INSTANCE)
    }
  }

  private val partitionCols: java.util.Set[String] = partitionColsToType.keySet()

  // Test cases for verifying partition of predicate into data and partition predicates
  // Map entry format: (predicate -> (partition predicate, data predicate))
  val partitionTestCases = Map[Predicate, (String, String)](
    // single predicate on a data column
    predicate("=", col("data1"), ofInt(12)) ->
      ("ALWAYS_TRUE()", "(column(`data1`) = 12)"),
    // multiple predicates on data columns joined with AND
    predicate("AND",
      predicate("=", col("data1"), ofInt(12)),
      predicate(">=", col("data2"), ofString("sss"))) ->
      ("ALWAYS_TRUE()", "((column(`data1`) = 12) AND (column(`data2`) >= sss))"),
    // multiple predicates on data columns joined with OR
    predicate("OR",
      predicate("<=", col("data2"), ofString("sss")),
      predicate("=", col("data3", "data31"), ofBoolean(true))) ->
      ("ALWAYS_TRUE()", "((column(`data2`) <= sss) OR (column(`data3`.`data31`) = true))"),
    // single predicate on a partition column
    predicate("=", col("part1"), ofInt(12)) ->
      ("(column(`part1`) = 12)", "ALWAYS_TRUE()"),
    // multiple predicates on partition columns joined with AND
    predicate("AND",
      predicate("=", col("part1"), ofInt(12)),
      predicate(">=", col("part3"), ofString("sss"))) ->
      ("((column(`part1`) = 12) AND (column(`part3`) >= sss))", "ALWAYS_TRUE()"),
    // multiple predicates on partition columns joined with OR
    predicate("OR",
      predicate("<=", col("part3"), ofString("sss")),
      predicate("=", col("part1"), ofInt(2781))) ->
      ("((column(`part3`) <= sss) OR (column(`part1`) = 2781))", "ALWAYS_TRUE()"),

    // predicates (each on data and partition column) joined with AND
    predicate("AND",
      predicate("=", col("data1"), ofInt(12)),
      predicate(">=", col("part3"), ofString("sss"))) ->
      ("(column(`part3`) >= sss)", "(column(`data1`) = 12)"),

    // predicates (each on data and partition column) joined with OR
    // (an OR over mixed columns cannot be split; all of it goes to the data side)
    predicate("OR",
      predicate("=", col("data1"), ofInt(12)),
      predicate(">=", col("part3"), ofString("sss"))) ->
      ("ALWAYS_TRUE()", "((column(`data1`) = 12) OR (column(`part3`) >= sss))"),

    // predicates (multiple on data and partition columns) joined with AND
    predicate("AND",
      predicate("AND",
        predicate("=", col("data1"), ofInt(12)),
        predicate(">=", col("data2"), ofString("sss"))),
      predicate("AND",
        predicate("=", col("part1"), ofInt(12)),
        predicate(">=", col("part3"), ofString("sss")))) ->
      (
        "((column(`part1`) = 12) AND (column(`part3`) >= sss))",
        "((column(`data1`) = 12) AND (column(`data2`) >= sss))"
      ),

    // predicates (multiple on data and partition columns joined with OR) joined with AND
    predicate("AND",
      predicate("OR",
        predicate("=", col("data1"), ofInt(12)),
        predicate(">=", col("data2"), ofString("sss"))),
      predicate("OR",
        predicate("=", col("part1"), ofInt(12)),
        predicate(">=", col("part3"), ofString("sss")))) ->
      (
        "((column(`part1`) = 12) OR (column(`part3`) >= sss))",
        "((column(`data1`) = 12) OR (column(`data2`) >= sss))"
      ),

    // predicates (multiple on data and partition columns joined with OR) joined with OR
    predicate("OR",
      predicate("OR",
        predicate("=", col("data1"), ofInt(12)),
        predicate(">=", col("data2"), ofString("sss"))),
      predicate("OR",
        predicate("=", col("part1"), ofInt(12)),
        predicate(">=", col("part3"), ofString("sss")))) ->
      (
        "ALWAYS_TRUE()",
        "(((column(`data1`) = 12) OR (column(`data2`) >= sss)) OR " +
          "((column(`part1`) = 12) OR (column(`part3`) >= sss)))"
      ),

    // predicates (data and partitions compared in the same expression)
    predicate("AND",
      predicate("=", col("data1"), col("part1")),
      predicate(">=", col("part3"), ofString("sss"))) ->
      (
        "(column(`part3`) >= sss)",
        "(column(`data1`) = column(`part1`))"
      )
  )

  partitionTestCases.foreach {
    case (predicate, (partitionPredicate, dataPredicate)) =>
      test(s"split predicate into data and partition predicates: $predicate") {
        val metadataAndDataPredicates = splitMetadataAndDataPredicates(predicate, partitionCols)
        assert(metadataAndDataPredicates._1.toString === partitionPredicate)
        assert(metadataAndDataPredicates._2.toString === dataPredicate)
      }
  }

  // Map entry format: (given predicate -> expected rewritten predicate)
  val rewriteTestCases = Map(
    // single predicate on a partition column
    predicate("=", col("part2"), ofTimestamp(12)) ->
      "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)",
    // multiple predicates on partition columns joined with AND
    predicate("AND",
      predicate("=", col("part1"), ofInt(12)),
      predicate(">=", col("part3"), ofString("sss"))) ->
      """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND
        |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))"""
        .stripMargin.replaceAll("\n", " "),
    // multiple predicates on partition columns joined with OR
    predicate("OR",
      predicate("<=", col("part3"), ofString("sss")),
      predicate("=", col("part1"), ofInt(2781))) ->
      """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR
        |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))"""
        .stripMargin.replaceAll("\n", " ")
  )

  rewriteTestCases.foreach {
    case (predicate, expRewrittenPredicate) =>
      test(s"rewrite partition predicate on scan file schema: $predicate") {
        val actRewrittenPredicate =
          rewritePartitionPredicateOnScanFileSchema(predicate, partitionColsToType)
        assert(actRewrittenPredicate.toString === expRewrittenPredicate)
      }
  }

  // Helper to build a (possibly nested) column reference from name parts.
  private def col(names: String*): Column = {
    new Column(names.toArray)
  }

  // Helper to build a predicate with the given operator name and children.
  private def predicate(name: String, children: Expression*): Predicate = {
    new Predicate(name, children.asJava)
  }
}

0 comments on commit ca69895

Please sign in to comment.