- * - Name:
+
+ * - Name:
element_at
*
- * - SQL semantic:
expr1 + expr2
+ * - Semantic:
element_at(map, key)
. Return the value of given key
+ * from the map type input. Ex: `element_at(map(1, 'a', 2, 'b'), 2)` returns 'b'
* - Since version: 3.0.0
*
*
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
index 89302e0c8a4..c7b57dec889 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
@@ -29,8 +29,7 @@
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
-import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.compare;
-import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.evalNullability;
+import static io.delta.kernel.defaults.internal.expressions.ExpressionUtils.*;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
/**
@@ -157,7 +156,7 @@ ExpressionTransformResult visitLiteral(Literal literal) {
ExpressionTransformResult visitColumn(Column column) {
String[] names = column.getNames();
DataType currentType = inputDataSchema;
- for(int level = 0; level < names.length; level++) {
+ for (int level = 0; level < names.length; level++) {
assertColumnExists(currentType instanceof StructType, inputDataSchema, column);
StructType structSchema = ((StructType) currentType);
int ordinal = structSchema.indexOf(names[level]);
@@ -173,6 +172,42 @@ ExpressionTransformResult visitCast(ImplicitCastExpression cast) {
throw new UnsupportedOperationException("CAST expression is not expected.");
}
+ @Override
+ ExpressionTransformResult visitPartitionValue(PartitionValueExpression partitionValue) {
+ ExpressionTransformResult serializedPartValueInput = visit(partitionValue.getInput());
+ checkArgument(
+ serializedPartValueInput.outputType instanceof StringType,
+ "%s: expected string input, but got %s",
+ partitionValue, serializedPartValueInput.outputType);
+ DataType partitionColType = partitionValue.getDataType();
+ if (partitionColType instanceof StructType ||
+ partitionColType instanceof ArrayType ||
+ partitionColType instanceof MapType) {
+ throw new UnsupportedOperationException(
+ "unsupported partition data type: " + partitionColType);
+ }
+ return new ExpressionTransformResult(
+ new PartitionValueExpression(serializedPartValueInput.expression, partitionColType),
+ partitionColType);
+ }
+
+ @Override
+ ExpressionTransformResult visitElementAt(ScalarExpression elementAt) {
+ ExpressionTransformResult transformedMapInput = visit(childAt(elementAt, 0));
+ ExpressionTransformResult transformedLookupKey = visit(childAt(elementAt, 1));
+
+ ScalarExpression transformedExpression = ElementAtEvaluator.validateAndTransform(
+ elementAt,
+ transformedMapInput.expression,
+ transformedMapInput.outputType,
+ transformedLookupKey.expression,
+ transformedLookupKey.outputType);
+
+ return new ExpressionTransformResult(
+ transformedExpression,
+ ((MapType) transformedMapInput.outputType).getValueType());
+ }
+
private Predicate validateIsPredicate(
Expression baseExpression,
ExpressionTransformResult result) {
@@ -187,9 +222,8 @@ private Predicate validateIsPredicate(
}
private Expression transformBinaryComparator(Predicate predicate) {
- checkArgument(predicate.getChildren().size() == 2, "expected two inputs");
- ExpressionTransformResult leftResult = visit(predicate.getChildren().get(0));
- ExpressionTransformResult rightResult = visit(predicate.getChildren().get(1));
+ ExpressionTransformResult leftResult = visit(getLeft(predicate));
+ ExpressionTransformResult rightResult = visit(getRight(predicate));
Expression left = leftResult.expression;
Expression right = rightResult.expression;
if (!leftResult.outputType.equivalent(rightResult.outputType)) {
@@ -324,7 +358,7 @@ ColumnVector visitColumn(Column column) {
String[] names = column.getNames();
DataType currentType = input.getSchema();
ColumnVector columnVector = null;
- for(int level = 0; level < names.length; level++) {
+ for (int level = 0; level < names.length; level++) {
assertColumnExists(currentType instanceof StructType, input.getSchema(), column);
StructType structSchema = ((StructType) currentType);
int ordinal = structSchema.indexOf(names[level]);
@@ -347,6 +381,19 @@ ColumnVector visitCast(ImplicitCastExpression cast) {
return cast.eval(inputResult);
}
+ @Override
+ ColumnVector visitPartitionValue(PartitionValueExpression partitionValue) {
+ ColumnVector input = visit(partitionValue.getInput());
+ return PartitionValueEvaluator.eval(input, partitionValue.getDataType());
+ }
+
+ @Override
+ ColumnVector visitElementAt(ScalarExpression elementAt) {
+ ColumnVector map = visit(childAt(elementAt, 0));
+ ColumnVector lookupKey = visit(childAt(elementAt, 1));
+ return ElementAtEvaluator.eval(map, lookupKey);
+ }
+
/**
* Utility method to evaluate inputs to the binary input expression. Also validates the
* evaluated expression result {@link ColumnVector}s are of the same size.
@@ -355,9 +402,8 @@ ColumnVector visitCast(ImplicitCastExpression cast) {
* @return Triplet of (result vector size, left operand result, left operand result)
*/
private PredicateChildrenEvalResult evalBinaryExpressionChildren(Predicate predicate) {
- checkArgument(predicate.getChildren().size() == 2, "expected two inputs");
- ColumnVector left = visit(predicate.getChildren().get(0));
- ColumnVector right = visit(predicate.getChildren().get(1));
+ ColumnVector left = visit(getLeft(predicate));
+ ColumnVector right = visit(getRight(predicate));
checkArgument(
left.getSize() == right.getSize(),
"Left and right operand returned different results: left=%d, right=d",
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java
new file mode 100644
index 00000000000..1028e9cb918
--- /dev/null
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.internal.expressions;
+
+import java.util.Arrays;
+import static java.lang.String.format;
+
+import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.expressions.Expression;
+import io.delta.kernel.expressions.ScalarExpression;
+import io.delta.kernel.types.DataType;
+import io.delta.kernel.types.MapType;
+import io.delta.kernel.types.StringType;
+import io.delta.kernel.utils.Utils;
+
+import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
+import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
+
+/**
+ * Utility methods to evaluate {@code element_at} expression.
+ */
+class ElementAtEvaluator {
+ private ElementAtEvaluator() {}
+
+ /**
+ * Validate and transform the {@code element_at} expression with given validated and
+ * transformed inputs.
+ */
+ static ScalarExpression validateAndTransform(
+ ScalarExpression elementAt,
+ Expression mapInput,
+ DataType mapInputType,
+ Expression lookupKey,
+ DataType lookupKeyType) {
+
+ MapType asMapType = validateSupportedMapType(elementAt, mapInputType);
+ DataType keyTypeFromMapInput = asMapType.getKeyType();
+
+ if (!keyTypeFromMapInput.equivalent(lookupKeyType)) {
+ if (canCastTo(lookupKeyType, keyTypeFromMapInput)) {
+ lookupKey = new ImplicitCastExpression(lookupKey, keyTypeFromMapInput);
+ } else {
+ throw new UnsupportedOperationException(format(
+ "%s: lookup key type (%s) is different from the map key type (%s)",
+ elementAt, lookupKeyType, asMapType.getKeyType()));
+ }
+ }
+ return new ScalarExpression(elementAt.getName(), Arrays.asList(mapInput, lookupKey));
+ }
+
+ /**
+ * Utility method to evaluate the {@code element_at} on given map and key vectors.
+ * @param map {@link ColumnVector} of {@code map(string, string)} type.
+ * @param lookupKey {@link ColumnVector} of {@code string} type.
+ * @return result {@link ColumnVector} containing the lookup values.
+ */
+ static ColumnVector eval(ColumnVector map, ColumnVector lookupKey) {
+ return new ColumnVector() {
+ // Store the last lookup value to avoid multiple looks up for same row id.
+ // The general pattern is call `isNullAt(rowId)` followed by `getString`.
+ // So the cache of one value is enough.
+ private int lastLookupRowId = -1;
+ private Object lastLookupValue = null;
+
+ @Override
+ public DataType getDataType() {
+ return ((MapType) map.getDataType()).getValueType();
+ }
+
+ @Override
+ public int getSize() {
+ return map.getSize();
+ }
+
+ @Override
+ public void close() {
+ Utils.closeCloseables(map, lookupKey);
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ if (rowId == lastLookupRowId) {
+ return lastLookupValue == null;
+ }
+ return map.isNullAt(rowId) || lookupValue(rowId) == null;
+ }
+
+ @Override
+ public String getString(int rowId) {
+ lookupValue(rowId);
+ return lastLookupValue == null ? null : (String) lastLookupValue;
+ }
+
+ private Object lookupValue(int rowId) {
+ if (rowId == lastLookupRowId) {
+ return lastLookupValue;
+ }
+ // TODO: this needs to be updated after the new way of accessing the complex
+ // types is merged.
+ lastLookupRowId = rowId;
+ String keyValue = lookupKey.getString(rowId);
+ lastLookupValue = map.getMap(rowId).get(keyValue);
+ return lastLookupValue;
+ }
+ };
+ }
+
+ private static MapType validateSupportedMapType(Expression elementAt, DataType mapInputType) {
+ checkArgument(
+ mapInputType instanceof MapType,
+ "expected a map type input as first argument: " + elementAt);
+ MapType asMapType = (MapType) mapInputType;
+ // TODO: we may extend type support in future, but currently the need is just a lookup
+ // in map column of type `map(string -> string)`.
+ if (asMapType.getKeyType().equivalent(StringType.INSTANCE) &&
+ asMapType.getValueType().equivalent(StringType.INSTANCE)) {
+ return asMapType;
+ }
+ throw new UnsupportedOperationException(
+ format("%s: Supported only on type map(string, string) input data", elementAt));
+ }
+}
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java
index 8997afc9fc3..5ff13573b4b 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionUtils.java
@@ -17,8 +17,11 @@
import java.math.BigDecimal;
import java.util.Comparator;
+import java.util.List;
+import static java.lang.String.format;
import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.expressions.Expression;
import io.delta.kernel.types.*;
import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
@@ -46,7 +49,7 @@ static boolean[] evalNullability(ColumnVector left, ColumnVector right) {
* Utility method to compare the left and right according to the natural ordering
* and return an integer array where each row contains the comparison result (-1, 0, 1) for
* corresponding rows in the input vectors compared.
- *
+ *
* Only primitive data types are supported.
*/
static int[] compare(ColumnVector left, ColumnVector right) {
@@ -174,4 +177,30 @@ static void compareBinary(ColumnVector left, ColumnVector right, int[] result) {
}
}
}
+
+ /**
+ * Utility method to return the left child of the binary input expression
+ */
+ static Expression getLeft(Expression expression) {
+ List children = expression.getChildren();
+ checkArgument(
+ children.size() == 2,
+ format("%s: expected two inputs, but got %s", expression, children.size()));
+ return children.get(0);
+ }
+
+ /**
+ * Utility method to return the right child of the binary input expression
+ */
+ static Expression getRight(Expression expression) {
+ List children = expression.getChildren();
+ checkArgument(
+ children.size() == 2,
+ format("%s: expected two inputs, but got %s", expression, children.size()));
+ return children.get(1);
+ }
+
+ static Expression childAt(Expression expression, int index) {
+ return expression.getChildren().get(index);
+ }
}
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
index 0e0e5000687..62defe46d50 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
@@ -47,8 +47,14 @@ abstract class ExpressionVisitor {
abstract R visitCast(ImplicitCastExpression cast);
+ abstract R visitPartitionValue(PartitionValueExpression partitionValue);
+
+ abstract R visitElementAt(ScalarExpression elementAt);
+
final R visit(Expression expression) {
- if (expression instanceof ScalarExpression) {
+ if (expression instanceof PartitionValueExpression) {
+ return visitPartitionValue((PartitionValueExpression) expression);
+ } else if (expression instanceof ScalarExpression) {
return visitScalarExpression((ScalarExpression) expression);
} else if (expression instanceof Literal) {
return visitLiteral((Literal) expression);
@@ -81,6 +87,8 @@ private R visitScalarExpression(ScalarExpression expression) {
case ">":
case ">=":
return visitComparator(new Predicate(name, children));
+ case "ELEMENT_AT":
+ return visitElementAt(expression);
default:
throw new UnsupportedOperationException(
String.format("Scalar expression `%s` is not supported.", name));
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java
new file mode 100644
index 00000000000..57fcf83130d
--- /dev/null
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/PartitionValueEvaluator.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.internal.expressions;
+
+import java.math.BigDecimal;
+import java.sql.Date;
+
+import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.types.DataType;
+import io.delta.kernel.types.DateType;
+import io.delta.kernel.types.IntegerType;
+import io.delta.kernel.types.LongType;
+
+import io.delta.kernel.internal.util.InternalUtils;
+
+/**
+ * Utility methods to evaluate {@code partition_value} expression
+ */
+class PartitionValueEvaluator {
+ /**
+ * Evaluate the {@code partition_value} expression for given input column vector and generate
+ * a column vector with decoded values according to the given partition type.
+ */
+ static ColumnVector eval(ColumnVector input, DataType partitionType) {
+ return new ColumnVector() {
+ @Override
+ public DataType getDataType() {
+ return partitionType;
+ }
+
+ @Override
+ public int getSize() {
+ return input.getSize();
+ }
+
+ @Override
+ public void close() {
+ input.close();
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ return input.isNullAt(rowId);
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ return Boolean.parseBoolean(input.getString(rowId));
+ }
+
+ @Override
+ public byte getByte(int rowId) {
+ return Byte.parseByte(input.getString(rowId));
+ }
+
+ @Override
+ public short getShort(int rowId) {
+ return Short.parseShort(input.getString(rowId));
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ if (partitionType.equivalent(IntegerType.INSTANCE)) {
+ return Integer.parseInt(input.getString(rowId));
+ } else if (partitionType.equivalent(DateType.INSTANCE)) {
+ return InternalUtils.daysSinceEpoch(Date.valueOf(input.getString(rowId)));
+ }
+ throw new UnsupportedOperationException("Invalid value request for data type");
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ if (partitionType.equivalent(LongType.INSTANCE)) {
+ return Long.parseLong(input.getString(rowId));
+ }
+ // TODO: partition value of timestamp type are not yet supported
+ throw new UnsupportedOperationException("Invalid value request for data type");
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ return Float.parseFloat(input.getString(rowId));
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ return Double.parseDouble(input.getString(rowId));
+ }
+
+ @Override
+ public byte[] getBinary(int rowId) {
+ return input.isNullAt(rowId) ? null : input.getString(rowId).getBytes();
+ }
+
+ @Override
+ public String getString(int rowId) {
+ return input.getString(rowId);
+ }
+
+ @Override
+ public BigDecimal getDecimal(int rowId) {
+ return input.isNullAt(rowId) ? null : new BigDecimal(input.getString(rowId));
+ }
+ };
+ }
+}
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
index da82422f883..b4766a029c7 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
@@ -17,17 +17,20 @@ package io.delta.kernel.defaults.internal.expressions
import java.lang.{Boolean => BooleanJ}
import java.math.{BigDecimal => BigDecimalJ}
+import java.sql.Date
import java.util
import java.util.Optional
import io.delta.kernel.data.{ColumnarBatch, ColumnVector}
import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch
-import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultStructVector}
+import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultMapVector, DefaultStructVector}
import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject
+import io.delta.kernel.defaults.utils.TestUtils
import io.delta.kernel.expressions._
import io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE
import io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE
import io.delta.kernel.expressions.Literal._
+import io.delta.kernel.internal.util.InternalUtils
import io.delta.kernel.types._
import org.scalatest.funsuite.AnyFunSuite
@@ -386,11 +389,156 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
val right = literals(4)
Seq.range(5, literals.length).foreach { idx =>
comparatorToExpResults.foreach { comparator =>
- testComparator(comparator, right, literals(idx), null)
+ testComparator(comparator, right, literals(idx), null)
}
}
}
+ test("evaluate expression: element_at") {
+ import scala.collection.JavaConverters._
+ val nullStr = null.asInstanceOf[String]
+ val testMapValues = Seq(
+ Map("k0" -> "v00", "k1" -> "v01", "k3" -> nullStr, nullStr -> "v04").asJava,
+ Map("k0" -> "v10", "k1" -> nullStr, "k3" -> "v13", nullStr -> "v14").asJava,
+ Map("k0" -> nullStr, "k1" -> "v21", "k3" -> "v23", nullStr -> "v24").asJava,
+ null
+ )
+ val testMapVector = new ColumnVector {
+ override def getDataType: DataType =
+ new MapType(StringType.INSTANCE, StringType.INSTANCE, true /* valueContainsNull */)
+
+ override def getSize: Int = testMapValues.size
+
+ override def close(): Unit = {}
+
+ override def isNullAt(rowId: Int): Boolean = testMapValues(rowId) == null
+
+ override def getMap[K, V](rowId: Int): util.Map[K, V] =
+ testMapValues(rowId).asInstanceOf[util.Map[K, V]]
+ }
+ val inputBatch = new DefaultColumnarBatch(
+ testMapVector.getSize,
+ new StructType().add("partitionValues", testMapVector.getDataType),
+ Seq(testMapVector).toArray
+ )
+ Seq("k0", "k1", "k2", null).foreach { lookupKey =>
+ val expOutput = testMapValues.map(map => {
+ if (map == null) null
+ else map.get(lookupKey)
+ })
+
+ val lookupKeyExpr = if (lookupKey == null) {
+ Literal.ofNull(StringType.INSTANCE)
+ } else {
+ Literal.ofString(lookupKey)
+ }
+ val elementAtExpr = new ScalarExpression(
+ "element_at",
+ util.Arrays.asList(new Column("partitionValues"), lookupKeyExpr))
+
+ val outputVector = evaluator(inputBatch.getSchema, elementAtExpr, StringType.INSTANCE)
+ .eval(inputBatch)
+ assert(outputVector.getSize === testMapValues.size)
+ assert(outputVector.getDataType === StringType.INSTANCE)
+ Seq.range(0, testMapValues.size).foreach { rowId =>
+ val expNull = expOutput(rowId) == null
+ assert(outputVector.isNullAt(rowId) == expNull)
+ if (!expNull) {
+ assert(outputVector.getString(rowId) === expOutput(rowId))
+ }
+ }
+ }
+ }
+
+ test("evaluate expression: element_at - unsupported map type input") {
+ val inputSchema = new StructType()
+ .add("as_map", new MapType(IntegerType.INSTANCE, BooleanType.INSTANCE, true))
+ val elementAtExpr = new ScalarExpression(
+ "element_at",
+ util.Arrays.asList(new Column("as_map"), Literal.ofString("empty")))
+
+ val ex = intercept[UnsupportedOperationException] {
+ evaluator(inputSchema, elementAtExpr, StringType.INSTANCE)
+ }
+ assert(ex.getMessage.contains(
+ "ELEMENT_AT(column(`as_map`), empty): Supported only on type map(string, string) input data"))
+ }
+
+ test("evaluate expression: element_at - unsupported lookup type input") {
+ val inputSchema = new StructType()
+ .add("as_map", new MapType(StringType.INSTANCE, StringType.INSTANCE, true))
+ val elementAtExpr = new ScalarExpression(
+ "element_at",
+ util.Arrays.asList(new Column("as_map"), Literal.ofShort(24)))
+
+ val ex = intercept[UnsupportedOperationException] {
+ evaluator(inputSchema, elementAtExpr, StringType.INSTANCE)
+ }
+ assert(ex.getMessage.contains("ELEMENT_AT(column(`as_map`), 24): " +
+ "lookup key type (short) is different from the map key type (string)"))
+ }
+
+ test("evaluate expression: partition_value") {
+ // (serialized partition value, partition col type, expected deserialized partition value)
+ val testCases = Seq(
+ ("true", BooleanType.INSTANCE, true),
+ ("false", BooleanType.INSTANCE, false),
+ (null, BooleanType.INSTANCE, null),
+ ("24", ByteType.INSTANCE, 24.toByte),
+ ("null", ByteType.INSTANCE, null),
+ ("876", ShortType.INSTANCE, 876.toShort),
+ ("null", ShortType.INSTANCE, null),
+ ("2342342", IntegerType.INSTANCE, 2342342),
+ ("null", IntegerType.INSTANCE, null),
+ ("234234223", LongType.INSTANCE, 234234223L),
+ ("null", LongType.INSTANCE, null),
+ ("23423.4223", FloatType.INSTANCE, 23423.4223f),
+ ("null", FloatType.INSTANCE, null),
+ ("23423.422233", DoubleType.INSTANCE, 23423.422233d),
+ ("null", DoubleType.INSTANCE, null),
+ ("234.422233", new DecimalType(10, 6), new BigDecimalJ("234.422233")),
+ ("null", DoubleType.INSTANCE, null),
+ ("string_val", StringType.INSTANCE, "string_val"),
+ ("null", StringType.INSTANCE, null),
+ ("binary_val", BinaryType.INSTANCE, "binary_val".getBytes()),
+ ("null", BinaryType.INSTANCE, null),
+ ("2021-11-18", DateType.INSTANCE, InternalUtils.daysSinceEpoch(Date.valueOf("2021-11-18"))),
+ ("null", DateType.INSTANCE, null),
+ ("2021-11-18", DateType.INSTANCE, InternalUtils.daysSinceEpoch(Date.valueOf("2021-11-18"))),
+ ("null", DateType.INSTANCE, null)
+ // TODO: timestamp partition value types are not yet supported in reading
+ )
+
+ val inputBatch = zeroColumnBatch(rowCount = 1)
+ testCases.foreach { testCase =>
+ val (serializedPartVal, partType, deserializedPartVal) = testCase
+ val literalSerializedPartVal = if (serializedPartVal == "null") {
+ Literal.ofNull(StringType.INSTANCE)
+ } else {
+ Literal.ofString(serializedPartVal)
+ }
+ val expr = new PartitionValueExpression(literalSerializedPartVal, partType)
+ val outputVector = evaluator(inputBatch.getSchema, expr, partType).eval(inputBatch)
+ assert(outputVector.getSize === 1)
+ assert(outputVector.getDataType === partType)
+ assert(outputVector.isNullAt(0) === (deserializedPartVal == null))
+ if (deserializedPartVal != null) {
+ assert(getValueAsObject(outputVector, 0) === deserializedPartVal)
+ }
+ }
+ }
+
+ test("evaluate expression: partition_value - invalid serialize value") {
+ val inputBatch = zeroColumnBatch(rowCount = 1)
+ val (serializedPartVal, partType) = ("23423sdfsdf", IntegerType.INSTANCE)
+ val expr = new PartitionValueExpression(Literal.ofString(serializedPartVal), partType)
+ val ex = intercept[IllegalArgumentException] {
+ val outputVector = evaluator(inputBatch.getSchema, expr, partType).eval(inputBatch)
+ outputVector.getInt(0)
+ }
+ assert(ex.getMessage.contains(serializedPartVal))
+ }
+
/**
* Utility method to generate a [[dataType]] column vector of given size.
* The nullability of rows is determined by the [[testIsNullValue(dataType, rowId)]].