Skip to content

Commit

Permalink
[Kernel] Add partition_value and element_at expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
vkorukanti committed Sep 24, 2023
1 parent 87f80ce commit 503cde6
Show file tree
Hide file tree
Showing 8 changed files with 581 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.expressions;

import java.util.Collections;
import java.util.List;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

import io.delta.kernel.annotation.Evolving;
import io.delta.kernel.types.DataType;

/**
* Expression to decode the serialized partition value into partition type value according the
* <a href=https://github.com/delta-io/delta/blob/master/PROTOCOL.md#partition-value-serialization>
* Delta Protocol spec</a>.
* <p>
* <ul>
* <li>Name: <code>partition_value</code>
* <li>Semantic: <code>partition_value(string, datatype)</code>. Decode the partition
* value of type <i>datatype</i> from the serialized string format.</li>
* </ul>
*
* @since 3.0.0
*/
@Evolving
public class PartitionValueExpression implements Expression {
private final DataType partitionValueType;
private final Expression serializedPartitionValue;

/**
* Create {@code partition_value} expression.
*
* @param serializedPartitionValue Input expression providing the partition values in
* serialized format.
* @param partitionDataType Partition data type to which string partition value is
* deserialized as according to the Delta Protocol.
*/
public PartitionValueExpression(
Expression serializedPartitionValue, DataType partitionDataType) {
this.serializedPartitionValue = requireNonNull(serializedPartitionValue);
this.partitionValueType = requireNonNull(partitionDataType);
}

/**
* Get the expression reference to the serialized partition value.
*/
public Expression getInput() {
return serializedPartitionValue;
}

/**
* Get the data type of the partition value.
*/
public DataType getDataType() {
return partitionValueType;
}

@Override
public List<Expression> getChildren() {
return Collections.singletonList(serializedPartitionValue);
}

@Override
public String toString() {
return format("partition_value(%s, %d)", serializedPartitionValue, partitionValueType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
* output value. A subclass of these expressions are of type {@link Predicate} whose result type is
* `boolean`. See {@link Predicate} for predicate type scalar expressions. Supported
* non-predicate type scalar expressions are listed below.
* TODO: Currently there aren't any. Will be added in future. An example one looks like this:
* <ol>
* <li>Name: <code>+</code>
* <li>Name: <code>element_at</code>
* <ul>
* <li>SQL semantic: <code>expr1 + expr2</code></li>
* <li>Semantic: <code>element_at(map, key)</code>. Return the value of given <i>key</i>
* from the <i>map</i> type input. Ex: `element_at(map(1, 'a', 2, 'b'), 2)` returns 'b'</li>
* <li>Since version: 3.0.0</li>
* </ul>
* </li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compare;
import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.evalNullability;
import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.*;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;

/**
Expand Down Expand Up @@ -157,7 +156,7 @@ ExpressionTransformResult visitLiteral(Literal literal) {
ExpressionTransformResult visitColumn(Column column) {
String[] names = column.getNames();
DataType currentType = inputDataSchema;
for(int level = 0; level < names.length; level++) {
for (int level = 0; level < names.length; level++) {
assertColumnExists(currentType instanceof StructType, inputDataSchema, column);
StructType structSchema = ((StructType) currentType);
int ordinal = structSchema.indexOf(names[level]);
Expand All @@ -173,6 +172,42 @@ ExpressionTransformResult visitCast(ImplicitCastExpression cast) {
throw new UnsupportedOperationException("CAST expression is not expected.");
}

@Override
ExpressionTransformResult visitPartitionValue(PartitionValueExpression partitionValue) {
ExpressionTransformResult serializedPartValueInput = visit(partitionValue.getInput());
checkArgument(
serializedPartValueInput.outputType instanceof StringType,
"%s: expected string input, but got %s",
partitionValue, serializedPartValueInput.outputType);
DataType partitionColType = partitionValue.getDataType();
if (partitionColType instanceof StructType ||
partitionColType instanceof ArrayType ||
partitionColType instanceof MapType) {
throw new UnsupportedOperationException(
"unsupported partition data type: " + partitionColType);
}
return new ExpressionTransformResult(
new PartitionValueExpression(serializedPartValueInput.expression, partitionColType),
partitionColType);
}

@Override
ExpressionTransformResult visitElementAt(ScalarExpression elementAt) {
ExpressionTransformResult transformedMapInput = visit(childAt(elementAt, 0));
ExpressionTransformResult transformedLookupKey = visit(childAt(elementAt, 1));

ScalarExpression transformedExpression = ElementAtEvaluator.validateAndTransform(
elementAt,
transformedMapInput.expression,
transformedMapInput.outputType,
transformedLookupKey.expression,
transformedLookupKey.outputType);

return new ExpressionTransformResult(
transformedExpression,
((MapType) transformedMapInput.outputType).getValueType());
}

private Predicate validateIsPredicate(
Expression baseExpression,
ExpressionTransformResult result) {
Expand All @@ -187,9 +222,8 @@ private Predicate validateIsPredicate(
}

private Expression transformBinaryComparator(Predicate predicate) {
checkArgument(predicate.getChildren().size() == 2, "expected two inputs");
ExpressionTransformResult leftResult = visit(predicate.getChildren().get(0));
ExpressionTransformResult rightResult = visit(predicate.getChildren().get(1));
ExpressionTransformResult leftResult = visit(getLeft(predicate));
ExpressionTransformResult rightResult = visit(getRight(predicate));
Expression left = leftResult.expression;
Expression right = rightResult.expression;
if (!leftResult.outputType.equivalent(rightResult.outputType)) {
Expand Down Expand Up @@ -324,7 +358,7 @@ ColumnVector visitColumn(Column column) {
String[] names = column.getNames();
DataType currentType = input.getSchema();
ColumnVector columnVector = null;
for(int level = 0; level < names.length; level++) {
for (int level = 0; level < names.length; level++) {
assertColumnExists(currentType instanceof StructType, input.getSchema(), column);
StructType structSchema = ((StructType) currentType);
int ordinal = structSchema.indexOf(names[level]);
Expand All @@ -347,6 +381,19 @@ ColumnVector visitCast(ImplicitCastExpression cast) {
return cast.eval(inputResult);
}

@Override
ColumnVector visitPartitionValue(PartitionValueExpression partitionValue) {
ColumnVector input = visit(partitionValue.getInput());
return PartitionValueEvaluator.eval(input, partitionValue.getDataType());
}

@Override
ColumnVector visitElementAt(ScalarExpression elementAt) {
ColumnVector map = visit(childAt(elementAt, 0));
ColumnVector lookupKey = visit(childAt(elementAt, 1));
return ElementAtEvaluator.eval(map, lookupKey);
}

/**
* Utility method to evaluate inputs to the binary input expression. Also validates the
* evaluated expression result {@link ColumnVector}s are of the same size.
Expand All @@ -355,9 +402,8 @@ ColumnVector visitCast(ImplicitCastExpression cast) {
* @return Triplet of (result vector size, left operand result, left operand result)
*/
private PredicateChildrenEvalResult evalBinaryExpressionChildren(Predicate predicate) {
checkArgument(predicate.getChildren().size() == 2, "expected two inputs");
ColumnVector left = visit(predicate.getChildren().get(0));
ColumnVector right = visit(predicate.getChildren().get(1));
ColumnVector left = visit(getLeft(predicate));
ColumnVector right = visit(getRight(predicate));
checkArgument(
left.getSize() == right.getSize(),
"Left and right operand returned different results: left=%d, right=d",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.defaults.internal.expressions;

import java.util.Arrays;
import static java.lang.String.format;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.ScalarExpression;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.MapType;
import io.delta.kernel.types.StringType;
import io.delta.kernel.utils.Utils;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;

/**
* Utility methods to evaluate {@code element_at} expression.
*/
class ElementAtEvaluator {
private ElementAtEvaluator() {}

/**
* Validate and transform the {@code element_at} expression with given validated and
* transformed inputs.
*/
static ScalarExpression validateAndTransform(
ScalarExpression elementAt,
Expression mapInput,
DataType mapInputType,
Expression lookupKey,
DataType lookupKeyType) {

MapType asMapType = validateSupportedMapType(elementAt, mapInputType);
DataType keyTypeFromMapInput = asMapType.getKeyType();

if (!keyTypeFromMapInput.equivalent(lookupKeyType)) {
if (canCastTo(lookupKeyType, keyTypeFromMapInput)) {
lookupKey = new ImplicitCastExpression(lookupKey, keyTypeFromMapInput);
} else {
throw new UnsupportedOperationException(format(
"%s: lookup key type (%s) is different from the map key type (%s)",
elementAt, lookupKeyType, asMapType.getKeyType()));
}
}
return new ScalarExpression(elementAt.getName(), Arrays.asList(mapInput, lookupKey));
}

/**
* Utility method to evaluate the {@code element_at} on given map and key vectors.
* @param map {@link ColumnVector} of {@code map(string, string)} type.
* @param lookupKey {@link ColumnVector} of {@code string} type.
* @return
*/
static ColumnVector eval(ColumnVector map, ColumnVector lookupKey) {
return new ColumnVector() {
// Store the last lookup value to avoid multiple looks up for same row id.
private int lastLookupRowId = -1;
private Object lastLookupValue = null;

@Override
public DataType getDataType() {
return ((MapType) map.getDataType()).getValueType();
}

@Override
public int getSize() {
return map.getSize();
}

@Override
public void close() {
Utils.closeCloseables(map, lookupKey);
}

@Override
public boolean isNullAt(int rowId) {
if (rowId == lastLookupRowId) {
return lastLookupValue == null;
}
return map.isNullAt(rowId) || lookupValue(rowId) == null;
}

@Override
public String getString(int rowId) {
lookupValue(rowId);
return lastLookupValue == null ? null : (String) lastLookupValue;
}

private Object lookupValue(int rowId) {
if (rowId == lastLookupRowId) {
return lastLookupValue;
}
// TODO: this needs to be updated after the new way of accessing the complex
// types is merged.
lastLookupRowId = rowId;
String keyValue = lookupKey.getString(rowId);
lastLookupValue = map.getMap(rowId).get(keyValue);
return lastLookupValue;
}
};
}

private static MapType validateSupportedMapType(Expression elementAt, DataType mapInputType) {
checkArgument(
mapInputType instanceof MapType,
"expected a map type input as first argument: " + elementAt);
MapType asMapType = (MapType) mapInputType;
// TODO: we may extend type support in future, but for the need is just a look
// in map(string, string).
if (asMapType.getKeyType().equivalent(StringType.INSTANCE) &&
asMapType.getValueType().equivalent(StringType.INSTANCE)) {
return asMapType;
}
throw new UnsupportedOperationException(
format("%s: Supported only on type map(string, string) input data", elementAt));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

import java.math.BigDecimal;
import java.util.Comparator;
import java.util.List;
import static java.lang.String.format;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.types.*;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
Expand Down Expand Up @@ -46,7 +49,7 @@ static boolean[] evalNullability(ColumnVector left, ColumnVector right) {
* Utility method to compare the left and right according to the natural ordering
* and return an integer array where each row contains the comparison result (-1, 0, 1) for
* corresponding rows in the input vectors compared.
*
* <p>
* Only primitive data types are supported.
*/
static int[] compare(ColumnVector left, ColumnVector right) {
Expand Down Expand Up @@ -174,4 +177,30 @@ static void compareBinary(ColumnVector left, ColumnVector right, int[] result) {
}
}
}

/**
* Utility method to return the left child of the binary input expression
*/
static Expression getLeft(Expression expression) {
List<Expression> children = expression.getChildren();
checkArgument(
children.size() == 2,
format("%s: expected two inputs, but got %s", expression, children.size()));
return children.get(0);
}

/**
* Utility method to return the right child of the binary input expression
*/
static Expression getRight(Expression expression) {
List<Expression> children = expression.getChildren();
checkArgument(
children.size() == 2,
format("%s: expected two inputs, but got %s", expression, children.size()));
return children.get(1);
}

static Expression childAt(Expression expression, int index) {
return expression.getChildren().get(index);
}
}
Loading

0 comments on commit 503cde6

Please sign in to comment.