From 19b6c9e9f5f148e53d1b2a63907cd73d091fda1c Mon Sep 17 00:00:00 2001
From: Allison Portis
Date: Thu, 14 Sep 2023 19:15:31 -0700
Subject: [PATCH] [Kernel] Add test utilities like checkAnswer and checkTable
 (#2034)

#### Which Delta project/connector is this regarding?

- [ ] Spark
- [ ] Standalone
- [ ] Flink
- [x] Kernel
- [ ] Other (fill in here)

## Description

Improves the testing infrastructure for Scala tests in Delta Kernel. For now this adds the utilities to `kernel-defaults`; if we later have tests with `ColumnarBatch`es in `kernel-api`, we can move them there.

## How was this patch tested?

Refactors existing tests to use the new infrastructure.
---
 .../kernel/defaults/DeletionVectorSuite.scala |  66 ++--
 .../defaults/DeltaTableReadsSuite.scala       |  90 +++---
 .../kernel/defaults/LogReplaySuite.scala      |   1 +
 .../defaults/ParquetBatchReaderSuite.scala    |  28 +-
 .../io/delta/kernel/defaults/TestUtils.scala  | 150 ---------
 .../DefaultExpressionEvaluatorSuite.scala     |   2 +-
 .../ImplicitCastExpressionSuite.scala         |   2 +-
 .../delta/kernel/defaults/utils/TestRow.scala | 125 ++++++++
 .../kernel/defaults/utils/TestUtils.scala     | 284 ++++++++++++++++++
 9 files changed, 487 insertions(+), 261 deletions(-)
 delete mode 100644 kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala
 create mode 100644 kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala
 create mode 100644 kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala
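As a usage sketch of the new utilities (the suite and golden-table names below are placeholders; `checkTable`, `TestRow`, and `TestUtils` are the ones added in this patch):

```scala
import io.delta.golden.GoldenTableUtils.goldenTablePath
import io.delta.kernel.defaults.utils.{TestRow, TestUtils}
import org.scalatest.funsuite.AnyFunSuite

class ExampleReadSuite extends AnyFunSuite with TestUtils {
  test("reads the expected rows") {
    // Resolves the latest snapshot, reads all columns, and compares
    // against the expected rows ignoring row order.
    checkTable(
      path = "file:" + goldenTablePath("example-table"), // placeholder table name
      expectedAnswer = (0L until 10L).map(TestRow(_))
    )
  }
}
```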
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeletionVectorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeletionVectorSuite.scala
index 1480ac869cd..4893e44fc8f 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeletionVectorSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeletionVectorSuite.scala
@@ -17,77 +17,59 @@ package io.delta.kernel.defaults
 
 import io.delta.golden.GoldenTableUtils.goldenTablePath
-import io.delta.kernel.Table
 import io.delta.kernel.defaults.client.DefaultTableClient
-import io.delta.kernel.defaults.utils.DefaultKernelTestUtils
+import io.delta.kernel.defaults.utils.{TestRow, TestUtils}
+import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.funsuite.AnyFunSuite
 
 class DeletionVectorSuite extends AnyFunSuite with TestUtils {
 
   test("end-to-end usage: reading a table with dv") {
-    val path = DefaultKernelTestUtils.getTestResourceFilePath("basic-dv-no-checkpoint")
-    val expectedResult = Seq.range(start = 2, end = 10).toSet
-
-    val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-    val result = readSnapshot(snapshot).map { row =>
-      row.getLong(0)
-    }
-
-    assert(result.toSet === expectedResult)
+    checkTable(
+      path = getTestResourceFilePath("basic-dv-no-checkpoint"),
+      expectedAnswer = (2L until 10L).map(TestRow(_))
+    )
   }
 
   test("end-to-end usage: reading a table with dv with checkpoint") {
-    val path = DefaultKernelTestUtils.getTestResourceFilePath("basic-dv-with-checkpoint")
-    val expectedResult = Seq.range(start = 0, end = 500).filter(_ % 11 != 0).toSet
-
-    val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-    val result = readSnapshot(snapshot).map { row =>
-      row.getLong(0)
-    }
-
-    assert(result.toSet === expectedResult)
+    checkTable(
+      path = getTestResourceFilePath("basic-dv-with-checkpoint"),
+      expectedAnswer = (0L until 500L).filter(_ % 11 != 0).map(TestRow(_))
+    )
   }
 
   test("end-to-end usage: reading partitioned dv table with checkpoint") {
-    // kernel expects a fully qualified path
-    val path = "file:" + goldenTablePath("dv-partitioned-with-checkpoint")
-    val expectedResult = (0 until 50).map(x => (x%10, x, s"foo${x % 5}"))
-      .filter{ case (_, col1, _) =>
-        !(col1 % 2 == 0 && col1 < 30)
-      }.toSet
-
     val conf = new Configuration()
     // Set the batch size small enough so there will be multiple batches
     conf.setInt("delta.kernel.default.parquet.reader.batch-size", 2)
     val tableClient = DefaultTableClient.create(conf)
 
-    val snapshot = Table.forPath(path).getLatestSnapshot(tableClient)
-    val result = readSnapshot(snapshot, tableClient = tableClient).map { row =>
-      (row.getInt(0), row.getInt(1), row.getString(2))
-    }
+    val expectedResult = (0 until 50).map(x => (x%10, x, s"foo${x % 5}"))
+      .filter{ case (_, col1, _) =>
+        !(col1 % 2 == 0 && col1 < 30)
+      }
 
-    assert (result.toSet == expectedResult)
+    checkTable(
+      path = "file:" + goldenTablePath("dv-partitioned-with-checkpoint"),
+      expectedAnswer = expectedResult.map(TestRow.fromTuple(_)),
+      tableClient = tableClient
+    )
   }
 
   // TODO: update to use goldenTables once bug is fixed in delta-spark see issue #1886
   test(
     "end-to-end usage: reading partitioned dv table with checkpoint with columnMappingMode=name") {
-    val path = DefaultKernelTestUtils.getTestResourceFilePath("dv-with-columnmapping")
     val expectedResult = (0 until 50).map(x => (x%10, x, s"foo${x % 5}"))
      .filter{ case (_, col1, _) =>
        !(col1 % 2 == 0 && col1 < 30)
-      }.toSet
-
-    val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-    val result = readSnapshot(snapshot).map { row =>
-      (row.getInt(0), row.getInt(1), row.getString(2))
-    }
-
-    assert (result.toSet == expectedResult)
+      }
+    checkTable(
+      path = getTestResourceFilePath("dv-with-columnmapping"),
+      expectedAnswer = expectedResult.map(TestRow.fromTuple(_))
+    )
   }
-
   // TODO detect corrupted DV checksum
   // TODO detect corrupted dv size
   // TODO multiple dvs in one file
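The partitioned-DV test above shrinks the Parquet reader batch size to force multi-batch reads; a minimal sketch of that pattern in isolation (config key and API as used in the suite):

```scala
import org.apache.hadoop.conf.Configuration
import io.delta.kernel.defaults.client.DefaultTableClient

// Set the batch size small enough that a single file yields several ColumnarBatches.
val conf = new Configuration()
conf.setInt("delta.kernel.default.parquet.reader.batch-size", 2)
val tableClient = DefaultTableClient.create(conf)
// Then pass it through, e.g.:
// checkTable(path = ..., expectedAnswer = ..., tableClient = tableClient)
```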
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala
index 3a4328bab97..0494f4e11a7 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala
@@ -22,8 +22,8 @@ import org.scalatest.funsuite.AnyFunSuite
 
 import io.delta.golden.GoldenTableUtils.goldenTablePath
 import io.delta.kernel.Table
-
 import io.delta.kernel.defaults.internal.DefaultKernelUtils
+import io.delta.kernel.defaults.utils.{TestRow, TestUtils}
 
 class DeltaTableReadsSuite extends AnyFunSuite with TestUtils {
 
@@ -54,52 +54,44 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils {
      4 | null | null
    */
 
-  def row0: (Int, Option[Long]) = (
+  def row0: TestRow = TestRow(
     0,
-    Some(1580544550000000L) // 2020-02-01 08:09:10 UTC to micros since the epoch
+    1580544550000000L // 2020-02-01 08:09:10 UTC to micros since the epoch
   )
 
-  def row1: (Int, Option[Long]) = (
+  def row1: TestRow = TestRow(
     1,
-    Some(915181200000000L) // 1999-01-01 09:00:00 UTC to micros since the epoch
+    915181200000000L // 1999-01-01 09:00:00 UTC to micros since the epoch
   )
 
-  def row2: (Int, Option[Long]) = (
+  def row2: TestRow = TestRow(
     2,
-    Some(946717200000000L) // 2000-01-01 09:00:00 UTC to micros since the epoch
+    946717200000000L // 2000-01-01 09:00:00 UTC to micros since the epoch
  )
 
-  def row3: (Int, Option[Long]) = (
+  def row3: TestRow = TestRow(
     3,
-    Some(-31536000000000L) // 1969-01-01 00:00:00 UTC to micros since the epoch
+    -31536000000000L // 1969-01-01 00:00:00 UTC to micros since the epoch
   )
 
-  def row4: (Int, Option[Long]) = (
+  def row4: TestRow = TestRow(
     4,
-    None
+    null
   )
 
-  // TODO: refactor this once testing utilities have support for Rows/ColumnarBatches
-  def utcTableExpectedResult: Set[(Int, Option[Long])] = Set(row0, row1, row2, row3, row4)
+  def utcTableExpectedResult: Seq[TestRow] = Seq(row0, row1, row2, row3, row4)
 
   def testTimestampTable(
       goldenTableName: String,
       timeZone: String,
-      expectedResult: Set[(Int, Option[Long])]): Unit = {
+      expectedResult: Seq[TestRow]): Unit = {
     withTimeZone(timeZone) {
-      // kernel expects a fully qualified path
-      val path = "file:" + goldenTablePath(goldenTableName)
-      val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-
-      // for now omit "part" column since we don't support reading timestamp partition values
-      val readSchema = snapshot.getSchema(defaultTableClient)
-        .withoutField("part")
-
-      val result = readSnapshot(snapshot, readSchema).map { row =>
-        (row.getInt(0), if (row.isNullAt(1)) Option.empty[Long] else Some(row.getLong(1)))
-      }
-
-      assert(result.toSet == expectedResult)
+      checkTable(
+        path = "file:" + goldenTablePath(goldenTableName),
+        expectedAnswer = expectedResult,
+        // for now omit "part" column since we don't support reading timestamp partition values
+        readCols = Seq("id", "time")
+      )
     }
   }
 
@@ -113,9 +105,16 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils {
   }
 
   // PST table - all the "time" col timestamps are + 8 hours
-  def pstTableExpectedResult: Set[(Int, Option[Long])] = utcTableExpectedResult.map {
-    case (id, col) =>
-      (id, col.map(_ + DefaultKernelUtils.DateTimeConstants.MICROS_PER_HOUR * 8))
+  def pstTableExpectedResult: Seq[TestRow] = utcTableExpectedResult.map { testRow =>
+    val values = testRow.toSeq
+    TestRow(
+      values(0),
+      if (values(1) == null) {
+        null
+      } else {
+        values(1).asInstanceOf[Long] + DefaultKernelUtils.DateTimeConstants.MICROS_PER_HOUR * 8
+      }
+    )
   }
 
   for (timeZone <- Seq("UTC", "Iceland", "PST", "America/Los_Angeles")) {
@@ -138,30 +137,23 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils {
       ).map { tup =>
         (new BigDecimal(tup._1), new BigDecimal(tup._2), new BigDecimal(tup._3),
           new BigDecimal(tup._4))
-      }.toSet
-
-      // kernel expects a fully qualified path
-      val path = "file:" + goldenTablePath(tablePath)
-      val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-
-      val result = readSnapshot(snapshot).map { row =>
-        (row.getDecimal(0), row.getDecimal(1), row.getDecimal(2), row.getDecimal(3))
       }
 
-      assert(expectedResult == result.toSet)
+      checkTable(
+        path = "file:" + goldenTablePath(tablePath),
+        expectedAnswer = expectedResult.map(TestRow.fromTuple(_))
+      )
     }
   }
 
-  test("end to end: multi-part checkpoint") {
-    val expectedResult = Seq(0) ++ (0 until 30)
-
-    // kernel expects a fully qualified path
-    val path = "file:" + goldenTablePath("multi-part-checkpoint")
-    val snapshot = Table.forPath(path).getLatestSnapshot(defaultTableClient)
-    val result = readSnapshot(snapshot).map { row =>
-      row.getLong(0)
-    }
+  //////////////////////////////////////////////////////////////////////////////////
+  // Misc tests
+  //////////////////////////////////////////////////////////////////////////////////
 
-    assert(result.toSet == expectedResult.toSet)
+  test("end to end: multi-part checkpoint") {
+    checkTable(
+      path = "file:" + goldenTablePath("multi-part-checkpoint"),
+      expectedAnswer = (Seq(0L) ++ (0L until 30L)).map(TestRow(_))
+    )
   }
 }
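The timestamp tests above read the same golden table under several default time zones. `withTimeZone` (carried over into the new `TestUtils`) scopes `TimeZone.getDefault` to the block and restores it afterwards; a small sketch:

```scala
import java.util.TimeZone

withTimeZone("PST") {
  // The JVM default time zone is "PST" only inside this block.
  assert(TimeZone.getDefault.getID == "PST")
  // ... run reads whose expected values depend on the default time zone ...
}
```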
goldenTablePath("multi-part-checkpoint"), + expectedAnswer = (Seq(0L) ++ (0L until 30L)).map(TestRow(_)) + ) } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplaySuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplaySuite.scala index 0264321588d..cc812990aba 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplaySuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplaySuite.scala @@ -28,6 +28,7 @@ import org.scalatest.funsuite.AnyFunSuite // scalastyle:off println class LogReplaySuite extends AnyFunSuite { + // TODO: refactor to use TestUtils private val tableClient = DefaultTableClient.create(new Configuration() {{ // Set the batch sizes to small so that we get to test the multiple batch scenarios. diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ParquetBatchReaderSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ParquetBatchReaderSuite.scala index b377af05715..44312ef0f8f 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ParquetBatchReaderSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ParquetBatchReaderSuite.scala @@ -19,13 +19,12 @@ import java.io.File import java.math.BigDecimal import org.scalatest.funsuite.AnyFunSuite - import org.apache.hadoop.conf.Configuration import io.delta.golden.GoldenTableUtils.goldenTableFile import io.delta.kernel.types.{DecimalType, IntegerType, StructType} - import io.delta.kernel.defaults.internal.parquet.ParquetBatchReader +import io.delta.kernel.defaults.utils.{TestRow, TestUtils} class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { @@ -47,8 +46,8 @@ class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { test("decimals encoded using dictionary encoding ") { val expectedResult = (0 until 1000000).map { i => - (i, BigDecimal.valueOf(i%5), BigDecimal.valueOf(i%6), BigDecimal.valueOf(i%2)) - }.toSet + TestRow(i, BigDecimal.valueOf(i%5), BigDecimal.valueOf(i%6), BigDecimal.valueOf(i%2)) + } val readSchema = new StructType() .add("id", IntegerType.INSTANCE) @@ -59,12 +58,8 @@ class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { val batchReader = new ParquetBatchReader(new Configuration()) for (file <- Seq(DECIMAL_TYPES_DICT_FILE_V1, DECIMAL_TYPES_DICT_FILE_V2)) { val batches = batchReader.read(file, readSchema) - - val result = batches.toSeq.flatMap(_.getRows.toSeq).map { row => - (row.getInt(0), row.getDecimal(1), row.getDecimal(2), row.getDecimal(3)) - } - - assert(expectedResult == result.toSet) + val result = batches.toSeq.flatMap(_.getRows.toSeq) + checkAnswer(result, expectedResult) } } @@ -80,7 +75,7 @@ class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { val expectedResult = (0 until 99998).map { i => if (i % 85 == 0) { val n = BigDecimal.valueOf(i) - (i, n.movePointLeft(1).setScale(1), n.setScale(5), n.setScale(5)) + TestRow(i, n.movePointLeft(1).setScale(1), n.setScale(5), n.setScale(5)) } else { val negation = if (i % 33 == 0) { -1 @@ -88,14 +83,14 @@ class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { 1 } val n = BigDecimal.valueOf(i*negation) - ( + TestRow( i, n.movePointLeft(1), expand(n).movePointLeft(5), expand(expand(expand(n))).movePointLeft(5) ) } - }.toSet + } val readSchema = new StructType() .add("id", IntegerType.INSTANCE) @@ -106,11 +101,8 @@ class ParquetBatchReaderSuite extends AnyFunSuite with TestUtils { val batchReader = 
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala
deleted file mode 100644
index cc75a1d121e..00000000000
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/TestUtils.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Copyright (2021) The Delta Lake Project Authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.delta.kernel.defaults
-
-import java.util.{Optional, TimeZone}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-
-import io.delta.kernel.{Scan, Snapshot}
-import io.delta.kernel.client.TableClient
-import io.delta.kernel.data.Row
-import io.delta.kernel.defaults.client.DefaultTableClient
-import io.delta.kernel.types._
-import io.delta.kernel.utils.CloseableIterator
-import org.apache.hadoop.conf.Configuration
-
-trait TestUtils {
-  lazy val defaultTableClient = DefaultTableClient.create(new Configuration())
-
-  implicit class CloseableIteratorOps[T](private val iter: CloseableIterator[T]) {
-
-    def forEach(f: T => Unit): Unit = {
-      try {
-        while (iter.hasNext) {
-          f(iter.next())
-        }
-      } finally {
-        iter.close()
-      }
-    }
-
-    def toSeq: Seq[T] = {
-      try {
-        val result = new ArrayBuffer[T]
-        while (iter.hasNext) {
-          result.append(iter.next())
-        }
-        result
-      } finally {
-        iter.close()
-      }
-    }
-  }
-
-  implicit class StructTypeOps(schema: StructType) {
-
-    def withoutField(name: String): StructType = {
-      val newFields = schema.fields().asScala
-        .filter(_.getName != name).asJava
-      new StructType(newFields)
-    }
-  }
-
-  def readSnapshot(
-      snapshot: Snapshot,
-      readSchema: StructType = null,
-      tableClient: TableClient = defaultTableClient): Seq[Row] = {
-
-    val result = ArrayBuffer[Row]()
-
-    var scanBuilder = snapshot.getScanBuilder(tableClient)
-
-    if (readSchema != null) {
-      scanBuilder = scanBuilder.withReadSchema(tableClient, readSchema)
-    }
-
-    val scan = scanBuilder.build()
-
-    val scanState = scan.getScanState(tableClient);
-    val fileIter = scan.getScanFiles(tableClient)
-    // TODO serialize scan state and scan rows
-
-    fileIter.forEach { fileColumnarBatch =>
-      // TODO deserialize scan state and scan rows
-      val dataBatches = Scan.readData(
-        tableClient,
-        scanState,
-        fileColumnarBatch.getRows(),
-        Optional.empty()
-      )
-
-      dataBatches.forEach { batch =>
-        val selectionVector = batch.getSelectionVector()
-        val data = batch.getData()
-
-        var i = 0
-        val rowIter = data.getRows()
-        try {
-          while (rowIter.hasNext) {
-            val row = rowIter.next()
-            if (!selectionVector.isPresent || selectionVector.get.getBoolean(i)) { // row is valid
-              result.append(row)
-            }
-            i += 1
-          }
-        } finally {
-          rowIter.close()
-        }
-      }
-    }
-    result
-  }
-
-  def withTimeZone(zoneId: String)(f: => Unit): Unit = {
-    val currentDefault = TimeZone.getDefault
-    try {
-      TimeZone.setDefault(TimeZone.getTimeZone(zoneId))
-      f
-    } finally {
-      TimeZone.setDefault(currentDefault)
-    }
-  }
-
-  /** All simple data type used in parameterized tests where type is one of the test dimensions. */
-  val SIMPLE_TYPES = Seq(
-    BooleanType.INSTANCE,
-    ByteType.INSTANCE,
-    ShortType.INSTANCE,
-    IntegerType.INSTANCE,
-    LongType.INSTANCE,
-    FloatType.INSTANCE,
-    DoubleType.INSTANCE,
-    DateType.INSTANCE,
-    TimestampType.INSTANCE,
-    StringType.INSTANCE,
-    BinaryType.INSTANCE,
-    new DecimalType(10, 5)
-  )
-
-  /** All types. Used in parameterized tests where type is one of the test dimensions. */
-  val ALL_TYPES = SIMPLE_TYPES ++ Seq(
-    new ArrayType(BooleanType.INSTANCE, true),
-    new MapType(IntegerType.INSTANCE, LongType.INSTANCE, true),
-    new StructType().add("s1", BooleanType.INSTANCE).add("s2", IntegerType.INSTANCE)
-  )
-}
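The deleted trait is reproduced nearly verbatim in the new `utils/TestUtils.scala` below; `readSnapshot` in particular keeps its signature, so tests that need raw Kernel `Row`s rather than a whole-table comparison can still call it directly. A sketch, inside a suite mixing in the new `TestUtils` (the table path and expected rows are placeholders):

```scala
import io.delta.kernel.Table
import io.delta.kernel.defaults.utils.TestRow

val snapshot = Table.forPath("file:/path/to/table") // placeholder path
  .getLatestSnapshot(defaultTableClient)
// Rows masked out by deletion-vector selection vectors are already filtered.
val rows = readSnapshot(snapshot) // Seq[Row]
checkAnswer(rows, Seq(TestRow(0L), TestRow(1L))) // illustrative expected rows
```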
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
index d8c3ccfaa35..075d4851c19 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
@@ -21,7 +21,7 @@ import java.util
 
 import io.delta.kernel.data.{ColumnarBatch, ColumnVector}
 import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch
-import io.delta.kernel.defaults.TestUtils
+import io.delta.kernel.defaults.utils.TestUtils
 import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject
 import io.delta.kernel.expressions._
 import io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala
index 40a3fce6b51..7aad68e5a76 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala
@@ -18,7 +18,7 @@ package io.delta.kernel.defaults.internal.expressions
 import io.delta.kernel.data.ColumnVector
 import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject
 import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo
-import io.delta.kernel.defaults.TestUtils
+import io.delta.kernel.defaults.utils.TestUtils
 import io.delta.kernel.expressions.Column
 import io.delta.kernel.types._
 import org.scalatest.funsuite.AnyFunSuite
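Beyond the basic form, `checkTable` (added below in `utils/TestUtils.scala`) accepts optional parameters for column pruning, a custom table client, and a schema assertion; a sketch with placeholder names:

```scala
import io.delta.golden.GoldenTableUtils.goldenTablePath
import io.delta.kernel.defaults.utils.TestRow
import io.delta.kernel.types.{IntegerType, StructType, TimestampType}

val expectedSchema = new StructType()
  .add("id", IntegerType.INSTANCE)
  .add("time", TimestampType.INSTANCE)
val expectedRows = (0 until 5).map(i => TestRow(i, i.toLong)) // illustrative rows
checkTable(
  path = "file:" + goldenTablePath("some-table"), // placeholder table name
  expectedAnswer = expectedRows,
  readCols = Seq("id", "time"),   // read only these columns
  expectedSchema = expectedSchema // checked against the full table schema
)
```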
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala
new file mode 100644
index 00000000000..8c81f4ea3ff
--- /dev/null
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala
@@ -0,0 +1,125 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.utils
+
+import scala.collection.JavaConverters._
+
+import io.delta.kernel.data.Row
+import io.delta.kernel.types._
+
+/**
+ * Corresponding Scala class for each Kernel data type:
+ * - BooleanType --> boolean
+ * - ByteType --> byte
+ * - ShortType --> short
+ * - IntegerType --> int
+ * - LongType --> long
+ * - FloatType --> float
+ * - DoubleType --> double
+ * - StringType --> String
+ * - DateType --> int (number of days since the epoch)
+ * - TimestampType --> long (number of microseconds since the unix epoch)
+ * - DecimalType --> java.math.BigDecimal
+ * - BinaryType --> Array[Byte]
+ *
+ * TODO: complex types
+ * - StructType?
+ * - ArrayType?
+ * - MapType?
+ */
+class TestRow(val values: Array[Any]) {
+  // TODO: we could make this extend Row and create a way to generate Seq(Any) from Rows but it
+  // would complicate a lot of the code for not much benefit
+
+  def length: Int = values.length
+
+  def get(i: Int): Any = values(i)
+
+  def toSeq: Seq[Any] = values.clone()
+
+  def mkString(start: String, sep: String, end: String): String = {
+    val n = length
+    val builder = new StringBuilder
+    builder.append(start)
+    if (n > 0) {
+      builder.append(get(0))
+      var i = 1
+      while (i < n) {
+        builder.append(sep)
+        builder.append(get(i))
+        i += 1
+      }
+    }
+    builder.append(end)
+    builder.toString()
+  }
+
+  override def toString: String = this.mkString("[", ",", "]")
+}
+
+object TestRow {
+
+  /**
+   * Construct a [[TestRow]] with the given values. See the docs for [[TestRow]] for
+   * the Scala type corresponding to each Kernel data type.
+   */
+  def apply(values: Any*): TestRow = {
+    new TestRow(values.toArray)
+  }
+
+  /**
+   * Construct a [[TestRow]] with the same values as a Kernel [[Row]].
+   */
+  def apply(row: Row): TestRow = {
+    TestRow.fromSeq(row.getSchema.fields().asScala.zipWithIndex.map { case (field, i) =>
+      field.getDataType match {
+        case _ if row.isNullAt(i) => null
+        case _: BooleanType => row.getBoolean(i)
+        case _: ByteType => row.getByte(i)
+        case _: IntegerType => row.getInt(i)
+        case _: LongType => row.getLong(i)
+        case _: ShortType => row.getShort(i)
+        case _: DateType => row.getInt(i)
+        case _: TimestampType => row.getLong(i)
+        case _: FloatType => row.getFloat(i)
+        case _: DoubleType => row.getDouble(i)
+        case _: StringType => row.getString(i)
+        case _: BinaryType => row.getBinary(i)
+        case _: DecimalType => row.getDecimal(i)
+        // TODO complex types
+        // case _: StructType => row.getStruct(i)
+        // case _: MapType => row.getMap(i)
+        // case _: ArrayType => row.getArray(i)
+        case _ => throw new UnsupportedOperationException("unrecognized data type")
+      }
+    })
+  }
+
+  /**
+   * Construct a [[TestRow]] from the given seq of values. See the docs for [[TestRow]] for
+   * the Scala type corresponding to each Kernel data type.
+   */
+  def fromSeq(values: Seq[Any]): TestRow = {
+    new TestRow(values.toArray)
+  }
+
+  /**
+   * Construct a [[TestRow]] with the elements of the given tuple. See the docs for
+   * [[TestRow]] for the Scala type corresponding to each Kernel data type.
+   */
+  def fromTuple(tuple: Product): TestRow = fromSeq(tuple.productIterator.toSeq)
+}
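`TestRow` can be built three ways: from varargs, from a `Seq`, or from a tuple; there is also an `apply(Row)` that converts a Kernel `Row` using its schema. A small demonstration (values are illustrative):

```scala
import io.delta.kernel.defaults.utils.TestRow

val a = TestRow(1, "foo", 2.5)              // varargs of column values
val b = TestRow.fromSeq(Seq(1, "foo", 2.5)) // from a Seq
val c = TestRow.fromTuple((1, "foo", 2.5))  // from a tuple
// TestRow does not override equals; suites compare through checkAnswer.
// toSeq/toString expose the underlying values directly:
assert(a.toSeq == b.toSeq && b.toSeq == c.toSeq)
assert(a.toString == "[1,foo,2.5]")
```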
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala
new file mode 100644
index 00000000000..2dce9ee4ad4
--- /dev/null
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala
@@ -0,0 +1,284 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.utils
+
+import java.util.{Optional, TimeZone}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import io.delta.kernel.{Scan, Snapshot, Table}
+import io.delta.kernel.client.TableClient
+import io.delta.kernel.data.Row
+import io.delta.kernel.defaults.client.DefaultTableClient
+import io.delta.kernel.types._
+import io.delta.kernel.utils.CloseableIterator
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.Assertions
+
+trait TestUtils extends Assertions {
+
+  lazy val configuration = new Configuration()
+  lazy val defaultTableClient = DefaultTableClient.create(configuration)
+
+  implicit class CloseableIteratorOps[T](private val iter: CloseableIterator[T]) {
+
+    def forEach(f: T => Unit): Unit = {
+      try {
+        while (iter.hasNext) {
+          f(iter.next())
+        }
+      } finally {
+        iter.close()
+      }
+    }
+
+    def toSeq: Seq[T] = {
+      try {
+        val result = new ArrayBuffer[T]
+        while (iter.hasNext) {
+          result.append(iter.next())
+        }
+        result
+      } finally {
+        iter.close()
+      }
+    }
+  }
+
+  implicit class StructTypeOps(schema: StructType) {
+
+    def withoutField(name: String): StructType = {
+      val newFields = schema.fields().asScala
+        .filter(_.getName != name).asJava
+      new StructType(newFields)
+    }
+  }
+
+  def readSnapshot(
+      snapshot: Snapshot,
+      readSchema: StructType = null,
+      tableClient: TableClient = defaultTableClient): Seq[Row] = {
+
+    val result = ArrayBuffer[Row]()
+
+    var scanBuilder = snapshot.getScanBuilder(tableClient)
+
+    if (readSchema != null) {
+      scanBuilder = scanBuilder.withReadSchema(tableClient, readSchema)
+    }
+
+    val scan = scanBuilder.build()
+
+    val scanState = scan.getScanState(tableClient);
+    val fileIter = scan.getScanFiles(tableClient)
+    // TODO serialize scan state and scan rows
+
+    fileIter.forEach { fileColumnarBatch =>
+      // TODO deserialize scan state and scan rows
+      val dataBatches = Scan.readData(
+        tableClient,
+        scanState,
+        fileColumnarBatch.getRows(),
+        Optional.empty()
+      )
+
+      dataBatches.forEach { batch =>
+        val selectionVector = batch.getSelectionVector()
+        val data = batch.getData()
+
+        var i = 0
+        val rowIter = data.getRows()
+        try {
+          while (rowIter.hasNext) {
+            val row = rowIter.next()
+            if (!selectionVector.isPresent || selectionVector.get.getBoolean(i)) { // row is valid
+              result.append(row)
+            }
+            i += 1
+          }
+        } finally {
+          rowIter.close()
+        }
+      }
+    }
+    result
+  }
+
+  /**
+   * Execute {@code f} with {@code TimeZone.getDefault()} set to the time zone provided.
+   *
+   * @param zoneId the ID for a TimeZone, either an abbreviation such as "PST", a full name such as
+   *               "America/Los_Angeles", or a custom ID such as "GMT-8:00".
+   */
+  def withTimeZone(zoneId: String)(f: => Unit): Unit = {
+    val currentDefault = TimeZone.getDefault
+    try {
+      TimeZone.setDefault(TimeZone.getTimeZone(zoneId))
+      f
+    } finally {
+      TimeZone.setDefault(currentDefault)
+    }
+  }
+
+  /** All simple data types used in parameterized tests where type is one of the test dimensions. */
+  val SIMPLE_TYPES = Seq(
+    BooleanType.INSTANCE,
+    ByteType.INSTANCE,
+    ShortType.INSTANCE,
+    IntegerType.INSTANCE,
+    LongType.INSTANCE,
+    FloatType.INSTANCE,
+    DoubleType.INSTANCE,
+    DateType.INSTANCE,
+    TimestampType.INSTANCE,
+    StringType.INSTANCE,
+    BinaryType.INSTANCE,
+    new DecimalType(10, 5)
+  )
+
+  /** All types. Used in parameterized tests where type is one of the test dimensions. */
+  val ALL_TYPES = SIMPLE_TYPES ++ Seq(
+    new ArrayType(BooleanType.INSTANCE, true),
+    new MapType(IntegerType.INSTANCE, LongType.INSTANCE, true),
+    new StructType().add("s1", BooleanType.INSTANCE).add("s2", IntegerType.INSTANCE)
+  )
+
+  /**
+   * Compares the rows in the table's latest snapshot with the expected answer and fails if they
+   * do not match. The comparison is order independent. If expectedSchema is provided, checks
+   * that the latest snapshot's schema is equivalent.
+   *
+   * @param path fully qualified path of the table to check
+   * @param expectedAnswer expected rows
+   * @param readCols subset of columns to read; if null then all columns will be read
+   * @param tableClient table client to use to read the table
+   * @param expectedSchema expected schema to check for; if null then no check is performed
+   */
+  def checkTable(
+      path: String,
+      expectedAnswer: Seq[TestRow],
+      readCols: Seq[String] = null,
+      tableClient: TableClient = defaultTableClient,
+      expectedSchema: StructType = null
+      // filter
+      // version
+  ): Unit = {
+
+    val snapshot = Table.forPath(path).getLatestSnapshot(tableClient)
+
+    val readSchema = if (readCols == null) {
+      null
+    } else {
+      val schema = snapshot.getSchema(tableClient)
+      new StructType(readCols.map(schema.get(_)).asJava)
+    }
+
+    if (expectedSchema != null) {
+      assert(
+        expectedSchema == snapshot.getSchema(tableClient),
+        s"""
+           |Expected schema does not match actual schema:
+           |Expected schema: $expectedSchema
+           |Actual schema: ${snapshot.getSchema(tableClient)}
+           |""".stripMargin
+      )
+    }
+
+    val result = readSnapshot(snapshot, readSchema = readSchema, tableClient = tableClient)
+    checkAnswer(result, expectedAnswer)
+  }
+
+  def checkAnswer(result: => Seq[Row], expectedAnswer: Seq[TestRow]): Unit = {
+    checkAnswer(result.map(TestRow(_)), expectedAnswer)
+  }
+
+  def checkAnswer(result: Seq[TestRow], expectedAnswer: Seq[TestRow]): Unit = {
+    if (!compare(prepareAnswer(result), prepareAnswer(expectedAnswer))) {
+      fail(genErrorMessage(expectedAnswer, result))
+    }
+  }
+
+  private def prepareAnswer(answer: Seq[TestRow]): Seq[TestRow] = {
+    // Converts data to types that we can do equality comparison using Scala collections.
+    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+    // Java's java.math.BigDecimal.compareTo).
+    // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+    // the equality test.
+    val converted = answer.map(prepareRow)
+    converted.sortBy(_.toString())
+  }
+
+  // We need to call prepareRow recursively to handle schemas with struct types.
+  private def prepareRow(row: TestRow): TestRow = {
+    TestRow.fromSeq(row.toSeq.map {
+      case null => null
+      case bd: java.math.BigDecimal => BigDecimal(bd)
+      // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
+      case seq: Seq[_] => seq.map {
+        case b: java.lang.Byte => b.byteValue
+        case s: java.lang.Short => s.shortValue
+        case i: java.lang.Integer => i.intValue
+        case l: java.lang.Long => l.longValue
+        case f: java.lang.Float => f.floatValue
+        case d: java.lang.Double => d.doubleValue
+        case x => x
+      }
+      // Convert array to Seq for easy equality check.
+      case b: Array[_] => b.toSeq
+      case r: TestRow => prepareRow(r)
+      case o => o
+    })
+  }
+
+  private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+    case (null, null) => true
+    case (null, _) => false
+    case (_, null) => false
+    case (a: Array[_], b: Array[_]) =>
+      a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) }
+    case (a: Map[_, _], b: Map[_, _]) =>
+      a.size == b.size && a.keys.forall { aKey =>
+        b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
+      }
+    case (a: Iterable[_], b: Iterable[_]) =>
+      a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) }
+    case (a: Product, b: Product) =>
+      compare(a.productIterator.toSeq, b.productIterator.toSeq)
+    case (a: TestRow, b: TestRow) =>
+      compare(a.toSeq, b.toSeq)
+    // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
+    case (a: Double, b: Double) =>
+      java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
+    case (a: Float, b: Float) =>
+      java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
+    case (a, b) => a.equals(b) // In Scala, == does not call equals for boxed numeric classes?
+  }
+
+  private def genErrorMessage(expectedAnswer: Seq[TestRow], result: Seq[TestRow]): String = {
+    // TODO: improve to include schema or Java type information to help debugging
+    s"""
+       |== Results ==
+       |
+       |== Expected Answer - ${expectedAnswer.size} ==
+       |${prepareAnswer(expectedAnswer).map(_.toString()).mkString("(", ",", ")")}
+       |
+       |== Result - ${result.size} ==
+       |${prepareAnswer(result).map(_.toString()).mkString("(", ",", ")")}
+       |
+       |""".stripMargin
+  }
+}
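Two comparison details worth noting: `checkAnswer` sorts the prepared rows by their string form, so it is order-independent, and `compare` goes through raw bit patterns for floats and doubles so that `0.0` and `-0.0` are distinguished even though `==` equates them. A sketch, inside a suite mixing in `TestUtils`:

```scala
import io.delta.kernel.defaults.utils.TestRow

// Order-independent: this passes even though the sequences differ in order.
checkAnswer(Seq(TestRow(2L), TestRow(1L)), Seq(TestRow(1L), TestRow(2L)))

// Why compare() uses raw bits: == treats 0.0 and -0.0 as equal,
// but their bit patterns differ.
assert(0.0 == -0.0)
assert(java.lang.Double.doubleToRawLongBits(0.0) !=
  java.lang.Double.doubleToRawLongBits(-0.0))
```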