From 05e1e250e383057d9cf09e49dd520b9d206586e6 Mon Sep 17 00:00:00 2001 From: Venki Korukanti Date: Tue, 19 Sep 2023 16:14:03 -0700 Subject: [PATCH 1/2] [Kernel] Add `partition_value` and `element_at` expressions --- .../expressions/PartitionValueExpression.java | 81 ++++++++++ .../kernel/expressions/ScalarExpression.java | 6 +- .../DefaultExpressionEvaluator.java | 66 ++++++-- .../expressions/ElementAtEvaluator.java | 133 +++++++++++++++ .../internal/expressions/ExpressionUtils.java | 31 +++- .../expressions/ExpressionVisitor.java | 10 +- .../expressions/PartitionValueEvaluator.java | 119 ++++++++++++++ .../DefaultExpressionEvaluatorSuite.scala | 152 +++++++++++++++++- 8 files changed, 581 insertions(+), 17 deletions(-) create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java create mode 100644 kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java new file mode 100644 index 00000000000..b793c428702 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java @@ -0,0 +1,81 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions; + +import java.util.Collections; +import java.util.List; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.annotation.Evolving; +import io.delta.kernel.types.DataType; + +/** + * Expression to decode the serialized partition value into partition type value according the + * + * Delta Protocol spec. + *

+ *

+ * + * @since 3.0.0 + */ +@Evolving +public class PartitionValueExpression implements Expression { + private final DataType partitionValueType; + private final Expression serializedPartitionValue; + + /** + * Create {@code partition_value} expression. + * + * @param serializedPartitionValue Input expression providing the partition values in + * serialized format. + * @param partitionDataType Partition data type to which string partition value is + * deserialized as according to the Delta Protocol. + */ + public PartitionValueExpression( + Expression serializedPartitionValue, DataType partitionDataType) { + this.serializedPartitionValue = requireNonNull(serializedPartitionValue); + this.partitionValueType = requireNonNull(partitionDataType); + } + + /** + * Get the expression reference to the serialized partition value. + */ + public Expression getInput() { + return serializedPartitionValue; + } + + /** + * Get the data type of the partition value. + */ + public DataType getDataType() { + return partitionValueType; + } + + @Override + public List getChildren() { + return Collections.singletonList(serializedPartitionValue); + } + + @Override + public String toString() { + return format("partition_value(%s, %s)", serializedPartitionValue, partitionValueType); + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java index 7b8dd76a1dd..5797e599724 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java @@ -29,11 +29,11 @@ * output value. A subclass of these expressions are of type {@link Predicate} whose result type is * `boolean`. See {@link Predicate} for predicate type scalar expressions. Supported * non-predicate type scalar expressions are listed below. - * TODO: Currently there aren't any. 
Will be added in future. An example one looks like this:
  * <ol>
- *   <li>Name: +
+ *   <li>Name: element_at
  *   <ul>
- *     <li>SQL semantic: expr1 + expr2</li>
+ *     <li>Semantic: element_at(map, key). Return the value of given key
+ *     from the map type input. Ex: `element_at(map(1, 'a', 2, 'b'), 2)` returns 'b'</li>
  *     <li>Since version: 3.0.0</li>
  *   </ul>
  *   </li>
  3. diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index 89302e0c8a4..c7b57dec889 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -29,8 +29,7 @@ import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; -import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compare; -import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.evalNullability; +import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.*; import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo; /** @@ -157,7 +156,7 @@ ExpressionTransformResult visitLiteral(Literal literal) { ExpressionTransformResult visitColumn(Column column) { String[] names = column.getNames(); DataType currentType = inputDataSchema; - for(int level = 0; level < names.length; level++) { + for (int level = 0; level < names.length; level++) { assertColumnExists(currentType instanceof StructType, inputDataSchema, column); StructType structSchema = ((StructType) currentType); int ordinal = structSchema.indexOf(names[level]); @@ -173,6 +172,42 @@ ExpressionTransformResult visitCast(ImplicitCastExpression cast) { throw new UnsupportedOperationException("CAST expression is not expected."); } + @Override + ExpressionTransformResult visitPartitionValue(PartitionValueExpression partitionValue) { + ExpressionTransformResult serializedPartValueInput = visit(partitionValue.getInput()); 
+ checkArgument( + serializedPartValueInput.outputType instanceof StringType, + "%s: expected string input, but got %s", + partitionValue, serializedPartValueInput.outputType); + DataType partitionColType = partitionValue.getDataType(); + if (partitionColType instanceof StructType || + partitionColType instanceof ArrayType || + partitionColType instanceof MapType) { + throw new UnsupportedOperationException( + "unsupported partition data type: " + partitionColType); + } + return new ExpressionTransformResult( + new PartitionValueExpression(serializedPartValueInput.expression, partitionColType), + partitionColType); + } + + @Override + ExpressionTransformResult visitElementAt(ScalarExpression elementAt) { + ExpressionTransformResult transformedMapInput = visit(childAt(elementAt, 0)); + ExpressionTransformResult transformedLookupKey = visit(childAt(elementAt, 1)); + + ScalarExpression transformedExpression = ElementAtEvaluator.validateAndTransform( + elementAt, + transformedMapInput.expression, + transformedMapInput.outputType, + transformedLookupKey.expression, + transformedLookupKey.outputType); + + return new ExpressionTransformResult( + transformedExpression, + ((MapType) transformedMapInput.outputType).getValueType()); + } + private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { @@ -187,9 +222,8 @@ private Predicate validateIsPredicate( } private Expression transformBinaryComparator(Predicate predicate) { - checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); - ExpressionTransformResult leftResult = visit(predicate.getChildren().get(0)); - ExpressionTransformResult rightResult = visit(predicate.getChildren().get(1)); + ExpressionTransformResult leftResult = visit(getLeft(predicate)); + ExpressionTransformResult rightResult = visit(getRight(predicate)); Expression left = leftResult.expression; Expression right = rightResult.expression; if 
(!leftResult.outputType.equivalent(rightResult.outputType)) { @@ -324,7 +358,7 @@ ColumnVector visitColumn(Column column) { String[] names = column.getNames(); DataType currentType = input.getSchema(); ColumnVector columnVector = null; - for(int level = 0; level < names.length; level++) { + for (int level = 0; level < names.length; level++) { assertColumnExists(currentType instanceof StructType, input.getSchema(), column); StructType structSchema = ((StructType) currentType); int ordinal = structSchema.indexOf(names[level]); @@ -347,6 +381,19 @@ ColumnVector visitCast(ImplicitCastExpression cast) { return cast.eval(inputResult); } + @Override + ColumnVector visitPartitionValue(PartitionValueExpression partitionValue) { + ColumnVector input = visit(partitionValue.getInput()); + return PartitionValueEvaluator.eval(input, partitionValue.getDataType()); + } + + @Override + ColumnVector visitElementAt(ScalarExpression elementAt) { + ColumnVector map = visit(childAt(elementAt, 0)); + ColumnVector lookupKey = visit(childAt(elementAt, 1)); + return ElementAtEvaluator.eval(map, lookupKey); + } + /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
@@ -355,9 +402,8 @@ ColumnVector visitCast(ImplicitCastExpression cast) { * @return Triplet of (result vector size, left operand result, left operand result) */ private PredicateChildrenEvalResult evalBinaryExpressionChildren(Predicate predicate) { - checkArgument(predicate.getChildren().size() == 2, "expected two inputs"); - ColumnVector left = visit(predicate.getChildren().get(0)); - ColumnVector right = visit(predicate.getChildren().get(1)); + ColumnVector left = visit(getLeft(predicate)); + ColumnVector right = visit(getRight(predicate)); checkArgument( left.getSize() == right.getSize(), "Left and right operand returned different results: left=%d, right=d", diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java new file mode 100644 index 00000000000..1436f18e6fc --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java @@ -0,0 +1,133 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions; + +import java.util.Arrays; +import static java.lang.String.format; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.ScalarExpression; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.utils.Utils; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; +import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo; + +/** + * Utility methods to evaluate {@code element_at} expression. + */ +class ElementAtEvaluator { + private ElementAtEvaluator() {} + + /** + * Validate and transform the {@code element_at} expression with given validated and + * transformed inputs. + */ + static ScalarExpression validateAndTransform( + ScalarExpression elementAt, + Expression mapInput, + DataType mapInputType, + Expression lookupKey, + DataType lookupKeyType) { + + MapType asMapType = validateSupportedMapType(elementAt, mapInputType); + DataType keyTypeFromMapInput = asMapType.getKeyType(); + + if (!keyTypeFromMapInput.equivalent(lookupKeyType)) { + if (canCastTo(lookupKeyType, keyTypeFromMapInput)) { + lookupKey = new ImplicitCastExpression(lookupKey, keyTypeFromMapInput); + } else { + throw new UnsupportedOperationException(format( + "%s: lookup key type (%s) is different from the map key type (%s)", + elementAt, lookupKeyType, asMapType.getKeyType())); + } + } + return new ScalarExpression(elementAt.getName(), Arrays.asList(mapInput, lookupKey)); + } + + /** + * Utility method to evaluate the {@code element_at} on given map and key vectors. + * @param map {@link ColumnVector} of {@code map(string, string)} type. + * @param lookupKey {@link ColumnVector} of {@code string} type. 
+ * @return + */ + static ColumnVector eval(ColumnVector map, ColumnVector lookupKey) { + return new ColumnVector() { + // Store the last lookup value to avoid multiple looks up for same row id. + private int lastLookupRowId = -1; + private Object lastLookupValue = null; + + @Override + public DataType getDataType() { + return ((MapType) map.getDataType()).getValueType(); + } + + @Override + public int getSize() { + return map.getSize(); + } + + @Override + public void close() { + Utils.closeCloseables(map, lookupKey); + } + + @Override + public boolean isNullAt(int rowId) { + if (rowId == lastLookupRowId) { + return lastLookupValue == null; + } + return map.isNullAt(rowId) || lookupValue(rowId) == null; + } + + @Override + public String getString(int rowId) { + lookupValue(rowId); + return lastLookupValue == null ? null : (String) lastLookupValue; + } + + private Object lookupValue(int rowId) { + if (rowId == lastLookupRowId) { + return lastLookupValue; + } + // TODO: this needs to be updated after the new way of accessing the complex + // types is merged. + lastLookupRowId = rowId; + String keyValue = lookupKey.getString(rowId); + lastLookupValue = map.getMap(rowId).get(keyValue); + return lastLookupValue; + } + }; + } + + private static MapType validateSupportedMapType(Expression elementAt, DataType mapInputType) { + checkArgument( + mapInputType instanceof MapType, + "expected a map type input as first argument: " + elementAt); + MapType asMapType = (MapType) mapInputType; + // TODO: we may extend type support in future, but for the need is just a look + // in map(string, string). 
+ if (asMapType.getKeyType().equivalent(StringType.INSTANCE) && + asMapType.getValueType().equivalent(StringType.INSTANCE)) { + return asMapType; + } + throw new UnsupportedOperationException( + format("%s: Supported only on type map(string, string) input data", elementAt)); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java index 8997afc9fc3..5ff13573b4b 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java @@ -17,8 +17,11 @@ import java.math.BigDecimal; import java.util.Comparator; +import java.util.List; +import static java.lang.String.format; import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.expressions.Expression; import io.delta.kernel.types.*; import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; @@ -46,7 +49,7 @@ static boolean[] evalNullability(ColumnVector left, ColumnVector right) { * Utility method to compare the left and right according to the natural ordering * and return an integer array where each row contains the comparison result (-1, 0, 1) for * corresponding rows in the input vectors compared. - * + *

    * Only primitive data types are supported. */ static int[] compare(ColumnVector left, ColumnVector right) { @@ -174,4 +177,30 @@ static void compareBinary(ColumnVector left, ColumnVector right, int[] result) { } } } + + /** + * Utility method to return the left child of the binary input expression + */ + static Expression getLeft(Expression expression) { + List children = expression.getChildren(); + checkArgument( + children.size() == 2, + format("%s: expected two inputs, but got %s", expression, children.size())); + return children.get(0); + } + + /** + * Utility method to return the right child of the binary input expression + */ + static Expression getRight(Expression expression) { + List children = expression.getChildren(); + checkArgument( + children.size() == 2, + format("%s: expected two inputs, but got %s", expression, children.size())); + return children.get(1); + } + + static Expression childAt(Expression expression, int index) { + return expression.getChildren().get(index); + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index 0e0e5000687..62defe46d50 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -47,8 +47,14 @@ abstract class ExpressionVisitor { abstract R visitCast(ImplicitCastExpression cast); + abstract R visitPartitionValue(PartitionValueExpression partitionValue); + + abstract R visitElementAt(ScalarExpression elementAt); + final R visit(Expression expression) { - if (expression instanceof ScalarExpression) { + if (expression instanceof PartitionValueExpression) { + return visitPartitionValue((PartitionValueExpression) expression); + } else if (expression instanceof 
ScalarExpression) { return visitScalarExpression((ScalarExpression) expression); } else if (expression instanceof Literal) { return visitLiteral((Literal) expression); @@ -81,6 +87,8 @@ private R visitScalarExpression(ScalarExpression expression) { case ">": case ">=": return visitComparator(new Predicate(name, children)); + case "ELEMENT_AT": + return visitElementAt(expression); default: throw new UnsupportedOperationException( String.format("Scalar expression `%s` is not supported.", name)); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java new file mode 100644 index 00000000000..57fcf83130d --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java @@ -0,0 +1,119 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.expressions; + +import java.math.BigDecimal; +import java.sql.Date; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DateType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; + +import io.delta.kernel.internal.util.InternalUtils; + +/** + * Utility methods to evaluate {@code partition_value} expression + */ +class PartitionValueEvaluator { + /** + * Evaluate the {@code partition_value} expression for given input column vector and generate + * a column vector with decoded values according to the given partition type. + */ + static ColumnVector eval(ColumnVector input, DataType partitionType) { + return new ColumnVector() { + @Override + public DataType getDataType() { + return partitionType; + } + + @Override + public int getSize() { + return input.getSize(); + } + + @Override + public void close() { + input.close(); + } + + @Override + public boolean isNullAt(int rowId) { + return input.isNullAt(rowId); + } + + @Override + public boolean getBoolean(int rowId) { + return Boolean.parseBoolean(input.getString(rowId)); + } + + @Override + public byte getByte(int rowId) { + return Byte.parseByte(input.getString(rowId)); + } + + @Override + public short getShort(int rowId) { + return Short.parseShort(input.getString(rowId)); + } + + @Override + public int getInt(int rowId) { + if (partitionType.equivalent(IntegerType.INSTANCE)) { + return Integer.parseInt(input.getString(rowId)); + } else if (partitionType.equivalent(DateType.INSTANCE)) { + return InternalUtils.daysSinceEpoch(Date.valueOf(input.getString(rowId))); + } + throw new UnsupportedOperationException("Invalid value request for data type"); + } + + @Override + public long getLong(int rowId) { + if (partitionType.equivalent(LongType.INSTANCE)) { + return Long.parseLong(input.getString(rowId)); + } + // TODO: partition value of timestamp type are not yet supported + 
throw new UnsupportedOperationException("Invalid value request for data type"); + } + + @Override + public float getFloat(int rowId) { + return Float.parseFloat(input.getString(rowId)); + } + + @Override + public double getDouble(int rowId) { + return Double.parseDouble(input.getString(rowId)); + } + + @Override + public byte[] getBinary(int rowId) { + return input.isNullAt(rowId) ? null : input.getString(rowId).getBytes(); + } + + @Override + public String getString(int rowId) { + return input.getString(rowId); + } + + @Override + public BigDecimal getDecimal(int rowId) { + return input.isNullAt(rowId) ? null : new BigDecimal(input.getString(rowId)); + } + }; + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index da82422f883..a8e0ae4cead 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -17,17 +17,20 @@ package io.delta.kernel.defaults.internal.expressions import java.lang.{Boolean => BooleanJ} import java.math.{BigDecimal => BigDecimalJ} +import java.sql.Date import java.util import java.util.Optional import io.delta.kernel.data.{ColumnarBatch, ColumnVector} import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch -import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultStructVector} +import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultMapVector, DefaultStructVector} import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject +import io.delta.kernel.defaults.utils.TestUtils import io.delta.kernel.expressions._ import 
io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE import io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE import io.delta.kernel.expressions.Literal._ +import io.delta.kernel.internal.util.InternalUtils import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite @@ -386,11 +389,156 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa val right = literals(4) Seq.range(5, literals.length).foreach { idx => comparatorToExpResults.foreach { comparator => - testComparator(comparator, right, literals(idx), null) + testComparator(comparator, right, literals(idx), null) } } } + test("evaluate expression: element_at") { + import scala.collection.JavaConverters._ + val nullStr = null.asInstanceOf[String] + val testMapValues = Seq( + Map("k0" -> "v00", "k1" -> "v01", "k3" -> nullStr, nullStr -> "v04").asJava, + Map("k0" -> "v10", "k1" -> nullStr, "k3" -> "v13", nullStr -> "v14").asJava, + Map("k0" -> nullStr, "k1" -> "v21", "k3" -> "v23", nullStr -> "v24").asJava, + null + ) + val testMapVector = new ColumnVector { + override def getDataType: DataType = + new MapType(StringType.INSTANCE, StringType.INSTANCE, true) + + override def getSize: Int = testMapValues.size + + override def close(): Unit = {} + + override def isNullAt(rowId: Int): Boolean = testMapValues(rowId) == null + + override def getMap[K, V](rowId: Int): util.Map[K, V] = + testMapValues(rowId).asInstanceOf[util.Map[K, V]] + } + val inputBatch = new DefaultColumnarBatch( + testMapVector.getSize, + new StructType().add("partitionValues", testMapVector.getDataType), + Seq(testMapVector).toArray + ) + Seq("k0", "k1", "k2", null).foreach { lookupKey => + val expOutput = testMapValues.map(map => { + if (map == null) null + else map.get(lookupKey) + }) + + val lookupKeyExpr = if (lookupKey == null) { + Literal.ofNull(StringType.INSTANCE) + } else { + Literal.ofString(lookupKey) + } + val elementAtExpr = new ScalarExpression( + "element_at", + util.Arrays.asList(new 
Column("partitionValues"), lookupKeyExpr)) + + val outputVector = evaluator(inputBatch.getSchema, elementAtExpr, StringType.INSTANCE) + .eval(inputBatch) + assert(outputVector.getSize === testMapValues.size) + assert(outputVector.getDataType === StringType.INSTANCE) + Seq.range(0, testMapValues.size).foreach { rowId => + val expNull = expOutput(rowId) == null + assert(outputVector.isNullAt(rowId) == expNull) + if (!expNull) { + assert(outputVector.getString(rowId) === expOutput(rowId)) + } + } + } + } + + test("evaluate expression: element_at - unsupported map type input") { + val inputSchema = new StructType() + .add("as_map", new MapType(IntegerType.INSTANCE, BooleanType.INSTANCE, true)) + val elementAtExpr = new ScalarExpression( + "element_at", + util.Arrays.asList(new Column("as_map"), Literal.ofString("empty"))) + + val ex = intercept[UnsupportedOperationException] { + evaluator(inputSchema, elementAtExpr, StringType.INSTANCE) + } + assert(ex.getMessage.contains( + "ELEMENT_AT(column(`as_map`), empty): Supported only on type map(string, string) input data")) + } + + test("evaluate expression: element_at - unsupported lookup type input") { + val inputSchema = new StructType() + .add("as_map", new MapType(StringType.INSTANCE, StringType.INSTANCE, true)) + val elementAtExpr = new ScalarExpression( + "element_at", + util.Arrays.asList(new Column("as_map"), Literal.ofShort(24))) + + val ex = intercept[UnsupportedOperationException] { + evaluator(inputSchema, elementAtExpr, StringType.INSTANCE) + } + assert(ex.getMessage.contains("ELEMENT_AT(column(`as_map`), 24): " + + "lookup key type (short) is different from the map key type (string)")) + } + + test("evaluate expression: partition_value") { + // (serialized partition value, partition col type, expected deserialized partition value) + val testCases = Seq( + ("true", BooleanType.INSTANCE, true), + ("false", BooleanType.INSTANCE, false), + (null, BooleanType.INSTANCE, null), + ("24", ByteType.INSTANCE, 24.toByte), 
+ ("null", ByteType.INSTANCE, null), + ("876", ShortType.INSTANCE, 876.toShort), + ("null", ShortType.INSTANCE, null), + ("2342342", IntegerType.INSTANCE, 2342342), + ("null", IntegerType.INSTANCE, null), + ("234234223", LongType.INSTANCE, 234234223L), + ("null", LongType.INSTANCE, null), + ("23423.4223", FloatType.INSTANCE, 23423.4223f), + ("null", FloatType.INSTANCE, null), + ("23423.422233", DoubleType.INSTANCE, 23423.422233d), + ("null", DoubleType.INSTANCE, null), + ("234.422233", new DecimalType(10, 6), new BigDecimalJ("234.422233")), + ("null", DoubleType.INSTANCE, null), + ("string_val", StringType.INSTANCE, "string_val"), + ("null", StringType.INSTANCE, null), + ("binary_val", BinaryType.INSTANCE, "binary_val".getBytes()), + ("null", BinaryType.INSTANCE, null), + ("2021-11-18", DateType.INSTANCE, InternalUtils.daysSinceEpoch(Date.valueOf("2021-11-18"))), + ("null", DateType.INSTANCE, null), + ("2021-11-18", DateType.INSTANCE, InternalUtils.daysSinceEpoch(Date.valueOf("2021-11-18"))), + ("null", DateType.INSTANCE, null) + // TODO: timestamp partition value types are not yet supported in reading + ) + + val inputBatch = zeroColumnBatch(rowCount = 1) + testCases.foreach { testCase => + val (serializedPartVal, partType, deserializedPartVal) = testCase + val literalSerializedPartVal = if (serializedPartVal == "null") { + Literal.ofNull(StringType.INSTANCE) + } else { + Literal.ofString(serializedPartVal) + } + val expr = new PartitionValueExpression(literalSerializedPartVal, partType) + val outputVector = evaluator(inputBatch.getSchema, expr, partType).eval(inputBatch) + assert(outputVector.getSize === 1) + assert(outputVector.getDataType === partType) + assert(outputVector.isNullAt(0) === (deserializedPartVal == null)) + if (deserializedPartVal != null) { + assert(getValueAsObject(outputVector, 0) === deserializedPartVal) + } + } + } + + test("evaluate expression: partition_value - invalid serialize value") { + val inputBatch = zeroColumnBatch(rowCount = 1) + 
val (serializedPartVal, partType) = ("23423sdfsdf", IntegerType.INSTANCE) + val expr = new PartitionValueExpression(Literal.ofString(serializedPartVal), partType) + val ex = intercept[IllegalArgumentException] { + val outputVector = evaluator(inputBatch.getSchema, expr, partType).eval(inputBatch) + outputVector.getInt(0) + } + assert(ex.getMessage.contains(serializedPartVal)) + } + /** * Utility method to generate a [[dataType]] column vector of given size. * The nullability of rows is determined by the [[testIsNullValue(dataType, rowId)]]. From f6d235131f2fddf60f696fcb9902979b30a8161e Mon Sep 17 00:00:00 2001 From: Venki Korukanti Date: Wed, 27 Sep 2023 10:14:05 -0700 Subject: [PATCH 2/2] review --- .../kernel/expressions/PartitionValueExpression.java | 3 ++- .../defaults/internal/expressions/ElementAtEvaluator.java | 8 +++++--- .../expressions/DefaultExpressionEvaluatorSuite.scala | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java index b793c428702..329f01ef00b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/PartitionValueExpression.java @@ -26,7 +26,8 @@ /** * Expression to decode the serialized partition value into partition type value according the * - * Delta Protocol spec. + * Delta Protocol spec. Currently all valid partition types are supported except the + * `timestamp` and `timestamp without timezone` types. *

    *

      *
    • Name: partition_value diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java index 1436f18e6fc..1028e9cb918 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java @@ -65,11 +65,13 @@ static ScalarExpression validateAndTransform( * Utility method to evaluate the {@code element_at} on given map and key vectors. * @param map {@link ColumnVector} of {@code map(string, string)} type. * @param lookupKey {@link ColumnVector} of {@code string} type. - * @return + * @return result {@link ColumnVector} containing the lookup values. */ static ColumnVector eval(ColumnVector map, ColumnVector lookupKey) { return new ColumnVector() { // Store the last lookup value to avoid multiple looks up for same row id. + // The general pattern is call `isNullAt(rowId)` followed by `getString`. + // So the cache of one value is enough. private int lastLookupRowId = -1; private Object lastLookupValue = null; @@ -121,8 +123,8 @@ private static MapType validateSupportedMapType(Expression elementAt, DataType m mapInputType instanceof MapType, "expected a map type input as first argument: " + elementAt); MapType asMapType = (MapType) mapInputType; - // TODO: we may extend type support in future, but for the need is just a look - // in map(string, string). + // TODO: we may extend type support in future, but currently the need is just a lookup + // in map column of type `map(string -> string)`. 
if (asMapType.getKeyType().equivalent(StringType.INSTANCE) && asMapType.getValueType().equivalent(StringType.INSTANCE)) { return asMapType; diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index a8e0ae4cead..b4766a029c7 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -405,7 +405,7 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa ) val testMapVector = new ColumnVector { override def getDataType: DataType = - new MapType(StringType.INSTANCE, StringType.INSTANCE, true) + new MapType(StringType.INSTANCE, StringType.INSTANCE, true /* valueContainsNull */) override def getSize: Int = testMapValues.size