Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
vkorukanti committed Sep 12, 2023
1 parent 83b923e commit d161973
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 89 deletions.
2 changes: 1 addition & 1 deletion icebergShaded/generate_iceberg_jars.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def generate_iceberg_jars():
# Search for all glob results
results = glob.glob(compiled_jar_abs_pattern)
# Compiled jars will include tests, sources, javadocs; exclude them
results = list(filter(lambda result: all(x not in result for x in ["test", "source", "javadoc"]), results))
results = list(filter(lambda result: all(x not in result for x in ["test", "sources", "javadoc"]), results))

if len(results) == 0:
raise Exception("Could not find the jar: " + compled_jar_rel_glob_pattern)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,18 @@ public static Literal ofDecimal(BigDecimal value, int precision, int scale) {
* @return a null {@link Literal} with the given data type
*/
public static Literal ofNull(DataType dataType) {
if (dataType instanceof ArrayType
|| dataType instanceof MapType
|| dataType instanceof StructType) {
throw new IllegalArgumentException(dataType + " is an invalid data type for Literal.");
}
return new Literal(null, dataType);
}

private final Object value;
private final DataType dataType;

private Literal(Object value, DataType dataType) {
if (dataType instanceof ArrayType
|| dataType instanceof MapType
|| dataType instanceof StructType) {
throw new IllegalArgumentException(dataType + " is an invalid data type for Literal.");
}
this.value = value;
this.dataType = dataType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,26 @@
*/
package io.delta.kernel.expressions;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.delta.kernel.client.ExpressionHandler;

/**
* Defines predicate scalar expression which is an extension of {@link ScalarExpression}
* that evaluates to true, false, or null for each input row.
* <p>
* Currently, Kernel allows following predicate scalar expressions.
* Currently, implementations of {@link ExpressionHandler} requires support for at least the
* following scalar expressions.
* <ol>
* <li>Name: <code>=</code>
* <ul>
* <li>SQL semantic: <code>expr1 = expr2</code></li>
* <li>Since version: 3.0.0</li>
* </ul>
* </li>
* <li>Name: <code>&lt;&gt;</code>
* <ul>
* <li>SQL semantic: <code>expr1 &lt;&gt; expr2</code></li>
* <li>Since version: 3.0.0</li>
* </ul>
* </li>
* <li>Name: <code>&lt;</code>
* <ul>
* <li>SQL semantic: <code>expr1 &lt; expr2</code></li>
Expand Down Expand Up @@ -93,9 +92,12 @@ public Predicate(String name, List<Expression> children) {

@Override
public String toString() {
if (Arrays.asList("<", "<=", ">", ">=", "=").contains(name)) {
if (COMPARATORS.contains(name)) {
return String.format("(%s %s %s)", children.get(0), name, children.get(1));
}
return super.toString();
}

private static final Set<String> COMPARATORS =
Stream.of("<", "<=", ">", ">=", "=").collect(Collectors.toSet());
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.expressions

import io.delta.kernel.types._
import org.scalatest.funsuite.AnyFunSuite

class ExpressionsSuite extends AnyFunSuite {
test("expressions: unsupported literal data types") {
val ex1 = intercept[IllegalArgumentException] {
Literal.ofNull(new ArrayType(IntegerType.INSTANCE, true))
}
assert(ex1.getMessage.contains("array[integer] is an invalid data type for Literal."))

val ex2 = intercept[IllegalArgumentException] {
Literal.ofNull(new MapType(IntegerType.INSTANCE, IntegerType.INSTANCE, true))
}
assert(ex2.getMessage.contains("map[integer, integer] is an invalid data type for Literal."))

val ex3 = intercept[IllegalArgumentException] {
Literal.ofNull(new StructType().add("s1", BooleanType.INSTANCE))
}
assert(ex3.getMessage.matches("struct.* is an invalid data type for Literal."))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compare;
import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.evalNullability;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.createCastExpression;

/**
* Implementation of {@link ExpressionEvaluator} for default {@link ExpressionHandler}.
Expand Down Expand Up @@ -192,9 +191,9 @@ private Expression transformBinaryComparator(Predicate predicate) {
Expression right = rightResult.expression;
if (!leftResult.outputType.equivalent(rightResult.outputType)) {
if (canCastTo(leftResult.outputType, rightResult.outputType)) {
left = createCastExpression(left, rightResult.outputType);
left = new ImplicitCastExpression(left, rightResult.outputType);
} else if (canCastTo(rightResult.outputType, leftResult.outputType)) {
right = createCastExpression(right, leftResult.outputType);
right = new ImplicitCastExpression(right, leftResult.outputType);
} else {
String msg = format("%s: operands are of different types which are not " +
"comparable: left type=%s, right type=%s",
Expand Down Expand Up @@ -331,7 +330,7 @@ ColumnVector visitColumn(Column column) {
@Override
ColumnVector visitCast(ImplicitCastExpression cast) {
ColumnVector inputResult = visit(cast.getInput());
return ImplicitCastExpression.evalCastExpression(cast, inputResult);
return cast.eval(inputResult);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;
import java.util.Locale;
import static java.util.stream.Collectors.joining;

import io.delta.kernel.expressions.*;
import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE;
Expand Down Expand Up @@ -73,8 +74,7 @@ private R visitScalarExpression(ScalarExpression expression) {
return visitAnd(
new And(elemAsPredicate(children, 0), elemAsPredicate(children, 1)));
case "OR":
return visitOr(
new Or(elemAsPredicate(children, 0), elemAsPredicate(children, 1)));
return visitOr(new Or(elemAsPredicate(children, 0), elemAsPredicate(children, 1)));
case "=":
case "<":
case "<=":
Expand All @@ -88,8 +88,15 @@ private R visitScalarExpression(ScalarExpression expression) {
}

private static Predicate elemAsPredicate(List<Expression> expressions, int index) {
// Exception may happen due to connector impl issues than user issues, so no need to be
// very specific about the exception messages.
if (expressions.size() <= index) {
throw new RuntimeException(
String.format("Trying to access invalid entry (%d) in list %s", index,
expressions.stream().map(Object::toString).collect(joining(","))));
}
Expression elemExpression = expressions.get(index);
if (!(elemExpression instanceof Predicate)) {
throw new RuntimeException("Expected a predicate, but got " + elemExpression);
}
return (Predicate) expressions.get(index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ final class ImplicitCastExpression implements Expression {
private final Expression input;
private final DataType outputType;

/**
* Create a cast around the given input expression to specified output data
* type. It is the responsibility of the caller to validate the input expression can be cast
* to the new type using {@link #canCastTo(DataType, DataType)}
*/
ImplicitCastExpression(Expression input, DataType outputType) {
this.input = requireNonNull(input, "input is null");
this.outputType = requireNonNull(outputType, "outputType is null");
Expand All @@ -67,6 +72,32 @@ public List<Expression> getChildren() {
return Collections.singletonList(input);
}

/**
* Evaluate the given column expression on the input {@link ColumnVector}.
*
* @param input {@link ColumnVector} data of the input to the cast expression.
* @return {@link ColumnVector} result applying target type casting on every element in the
* input {@link ColumnVector}.
*/
ColumnVector eval(ColumnVector input) {
String fromTypeStr = input.getDataType().toString();
switch (fromTypeStr) {
case "byte":
return new ByteUpConverter(outputType, input);
case "short":
return new ShortUpConverter(outputType, input);
case "integer":
return new IntUpConverter(outputType, input);
case "long":
return new LongUpConverter(outputType, input);
case "float":
return new FloatUpConverter(outputType, input);
default:
throw new UnsupportedOperationException(
format("Cast from %s is not supported", fromTypeStr));
}
}

/**
* Map containing for each type what are the target cast types can be.
*/
Expand Down Expand Up @@ -94,46 +125,6 @@ static boolean canCastTo(DataType from, DataType to) {
UP_CASTABLE_TYPE_TABLE.get(fromStr).contains(toStr);
}

/**
* Utility method to create a cast around the given input expression to specified output data
* type. It is the responsibility of the caller to validate the input expression can be cast
* to the new type using {@link #canCastTo(DataType, DataType)}
*
* @return
*/
static ImplicitCastExpression createCastExpression(Expression inputExpression,
DataType outputType) {
return new ImplicitCastExpression(inputExpression, outputType);
}

/**
* Utility method to evaluate the given column expression on the input {@link ColumnVector}.
*
* @param cast Cast expression.
* @param input {@link ColumnVector} data of the input to the cast expression.
* @return {@link ColumnVector} result applying target type casting on every element in the
* input {@link ColumnVector}.
*/
static ColumnVector evalCastExpression(ImplicitCastExpression cast, ColumnVector input) {
String fromTypeStr = input.getDataType().toString();
DataType toType = cast.getOutputType();
switch (fromTypeStr) {
case "byte":
return new ByteUpConverter(toType, input);
case "short":
return new ShortUpConverter(toType, input);
case "integer":
return new IntUpConverter(toType, input);
case "long":
return new LongUpConverter(toType, input);
case "float":
return new FloatUpConverter(toType, input);
default:
throw new UnsupportedOperationException(
format("Cast from %s is not supported", fromTypeStr));
}
}

/**
* Base class for up casting {@link ColumnVector} data.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,6 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with TestUtils {
}
}

// Literals of complex types are not supported.
// TODO: should this be moved to kernel-api module?
test("expression: unsupported literal data types") {
val ex1 = intercept[IllegalArgumentException] {
Literal.ofNull(new ArrayType(IntegerType.INSTANCE, true))
}
assert(ex1.getMessage.contains("array[integer] is an invalid data type for Literal."))

val ex2 = intercept[IllegalArgumentException] {
Literal.ofNull(new MapType(IntegerType.INSTANCE, IntegerType.INSTANCE, true))
}
assert(ex2.getMessage.contains("map[integer, integer] is an invalid data type for Literal."))

val ex3 = intercept[IllegalArgumentException] {
Literal.ofNull(new StructType().add("s1", BooleanType.INSTANCE))
}
assert(ex3.getMessage.matches("struct.* is an invalid data type for Literal."))
}

SIMPLE_TYPES.foreach { dataType =>
test(s"evaluate expression: column of type $dataType") {
val batchSize = 78;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package io.delta.kernel.defaults.internal.expressions

import io.delta.kernel.data.ColumnVector
import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject
import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.{canCastTo, evalCastExpression}
import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo
import io.delta.kernel.defaults.TestUtils
import io.delta.kernel.expressions.Column
import io.delta.kernel.types._
Expand Down Expand Up @@ -60,10 +60,9 @@ class ImplicitCastExpressionSuite extends AnyFunSuite with TestUtils {
val fromType = castPair._1
val toType = castPair._2
val inputVector = testData(87, fromType, (rowId) => rowId % 7 == 0)
val outputVector = evalCastExpression(
new ImplicitCastExpression(new Column("id"), toType),
inputVector)
checkCastOutput(inputVector, outputVector)
val outputVector = new ImplicitCastExpression(new Column("id"), toType)
.eval(inputVector)
checkCastOutput(inputVector, toType, outputVector)
}
}

Expand Down Expand Up @@ -110,9 +109,10 @@ class ImplicitCastExpressionSuite extends AnyFunSuite with TestUtils {
// which the callers can cast to appropriate numerical type.
private def generateValue(rowId: Int): Double = rowId * 2.76 + 7623

private def checkCastOutput(input: ColumnVector, output: ColumnVector): Unit = {
private def checkCastOutput(input: ColumnVector, toType: DataType, output: ColumnVector): Unit = {
assert(input.getSize === output.getSize)
Seq(0, input.getSize).foreach { rowId =>
assert(toType === output.getDataType)
Seq.range(0, input.getSize).foreach { rowId =>
assert(input.isNullAt(rowId) === output.isNullAt(rowId))
assert(getValueAsObject(input, rowId) === getValueAsObject(output, rowId))
}
Expand Down

0 comments on commit d161973

Please sign in to comment.