From 5570c74dca5b19dab7e91d0630606d565f342bd1 Mon Sep 17 00:00:00 2001
From: juhoautio-rovio
Date: Sat, 7 Sep 2024 00:57:34 +0300
Subject: [PATCH] Add support for array types

Co-authored-by: fabricebaranski
---
 README.md                                     |  3 +-
 .../ingest/DataSegmentCommitMessage.java      |  7 +++
 .../java/com/rovio/ingest/TaskDataWriter.java | 39 ++++++++++++
 .../com/rovio/ingest/model/FieldType.java     | 35 ++++++++++-
 .../com/rovio/ingest/model/SegmentSpec.java   | 21 ++++++-
 .../extensions/DruidDatasetExtensions.scala   |  6 +-
 .../com/rovio/ingest/SegmentSpecTest.java     | 63 +++++++++++++++++++
 .../ingest/DruidDatasetExtensionsSpec.scala   |  6 +-
 8 files changed, 170 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 6f3ce8d..af6290b 100644
--- a/README.md
+++ b/README.md
@@ -64,8 +64,9 @@ The Dataset extension performs the following validations:
 * The Dataset has no columns with unknown types, unless `excludeColumnsWithUnknownTypes` is set to true
 
 The Dataset extension performs the following transformations:
-* Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they
+* Drops all columns of complex datatypes such as `StructType` or `MapType` as they
 are not supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true,
 otherwise validation has already failed.
+* `ArrayType` columns are supported when the element type is `StringType`, `LongType` or `DoubleType`
 * Converts `Date`/`Timestamp` type columns to `String`, except for the `time_column` - See [Druid Docs / Data types](https://druid.apache.org/docs/latest/querying/sql.html#standard-types)
 * Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given [segment granularity](#segment-granularity)
diff --git a/src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java b/src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java
index 13c5771..3d1fb77 100644
--- a/src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java
+++ b/src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java
@@ -20,6 +20,7 @@
 import com.fasterxml.jackson.databind.InjectableValues;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.jsontype.NamedType;
+import org.apache.druid.guice.NestedDataModule;
 import org.apache.druid.jackson.DefaultObjectMapper;
 import org.apache.druid.math.expr.ExprMacroTable;
 import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
@@ -29,6 +30,8 @@
 import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchBuildComplexMetricSerde;
 import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchMergeComplexMetricSerde;
 import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchModule;
+import org.apache.druid.segment.DefaultColumnFormatConfig;
+import org.apache.druid.segment.nested.NestedDataComplexTypeSerde;
 import org.apache.druid.segment.serde.ComplexMetrics;
 import org.apache.druid.timeline.DataSegment;
 import org.apache.druid.timeline.partition.LinearShardSpec;
@@ -52,6 +55,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {
             // ExpressionMacroTable is injected in AggregatorFactories.
             // However, ExprMacro are not actually required as the DataSource is write-only.
             .addValue(ExprMacroTable.class, ExprMacroTable.nil())
+            .addValue(DefaultColumnFormatConfig.class, new DefaultColumnFormatConfig(null))
             // PruneLoadSpecHolder are injected in DataSegment.
.addValue(DataSegment.PruneSpecsHolder.class, DataSegment.PruneSpecsHolder.DEFAULT); @@ -61,12 +65,14 @@ public class DataSegmentCommitMessage implements WriterCommitMessage { MAPPER.setTimeZone(TimeZone.getTimeZone("UTC")); + new NestedDataModule().getJacksonModules().forEach(MAPPER::registerModule); new SketchModule().getJacksonModules().forEach(MAPPER::registerModule); new HllSketchModule().getJacksonModules().forEach(MAPPER::registerModule); new KllSketchModule().getJacksonModules().forEach(MAPPER::registerModule); new DoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule); new ArrayOfDoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule); + NestedDataModule.registerHandlersAndSerde(); HllSketchModule.registerSerde(); KllSketchModule.registerSerde(); DoublesSketchModule.registerSerde(); @@ -75,6 +81,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage { ComplexMetrics.registerSerde("arrayOfDoublesSketch", new ArrayOfDoublesSketchMergeComplexMetricSerde()); ComplexMetrics.registerSerde("arrayOfDoublesSketchMerge", new ArrayOfDoublesSketchMergeComplexMetricSerde()); ComplexMetrics.registerSerde("arrayOfDoublesSketchBuild", new ArrayOfDoublesSketchBuildComplexMetricSerde()); + ComplexMetrics.registerSerde(NestedDataComplexTypeSerde.TYPE_NAME, NestedDataComplexTypeSerde.INSTANCE); } diff --git a/src/main/java/com/rovio/ingest/TaskDataWriter.java b/src/main/java/com/rovio/ingest/TaskDataWriter.java index c6515cf..8c5ab19 100644 --- a/src/main/java/com/rovio/ingest/TaskDataWriter.java +++ b/src/main/java/com/rovio/ingest/TaskDataWriter.java @@ -15,7 +15,9 @@ */ package com.rovio.ingest; +import com.fasterxml.jackson.core.JsonProcessingException; import com.rovio.ingest.model.Field; +import com.rovio.ingest.model.FieldType; import com.rovio.ingest.model.SegmentSpec; import com.rovio.ingest.util.ReflectionUtils; import com.rovio.ingest.util.SegmentStorageUpdater; @@ -38,6 +40,7 @@ import org.apache.druid.segment.indexing.RealtimeTuningConfig; import org.apache.druid.segment.loading.DataSegmentKiller; import org.apache.druid.segment.loading.DataSegmentPusher; +import org.apache.druid.segment.nested.StructuredData; import org.apache.druid.segment.realtime.FireDepartmentMetrics; import org.apache.druid.segment.realtime.appenderator.Appenderator; import org.apache.druid.segment.realtime.appenderator.DefaultOfflineAppenderatorFactory; @@ -50,6 +53,7 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.WriterCommitMessage; import org.apache.spark.sql.types.DataType; @@ -63,6 +67,7 @@ import java.io.IOException; import java.time.LocalDate; import java.time.ZoneOffset; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -217,6 +222,40 @@ private Map parse(InternalRow record) { // Convert to Java String as Spark return UTF8String which is not compatible with Druid sketches. 
value = value.toString(); } + if (value != null && segmentSpec.getComplexDimensionColumns().contains(columnName) && sqlType == DataTypes.StringType) { + try { + value = MAPPER.readValue(value.toString(), StructuredData.class); + } catch (JsonProcessingException e) { + value = null; + } + } + if (value != null && field.getFieldType() == FieldType.ARRAY_OF_STRING) { + ArrayData arrayData = record.getArray(field.getOrdinal()); + int arraySize = arrayData.numElements(); + List valueArrayOfString = new ArrayList<>(arraySize); + for (int i = 0; i < arraySize; i++) { + valueArrayOfString.add(arrayData.get(i, DataTypes.StringType).toString()); + } + value = valueArrayOfString; + } + if (value != null && field.getFieldType()== FieldType.ARRAY_OF_DOUBLE) { + ArrayData arrayData = record.getArray(field.getOrdinal()); + int arraySize = arrayData.numElements(); + Double[] valueArrayOfFloat = new Double[arraySize]; + for (int i = 0; i < arraySize; i++) { + valueArrayOfFloat[i] = (Double) arrayData.get(i, DataTypes.DoubleType); + } + value = valueArrayOfFloat; + } + if (value != null && field.getFieldType()== FieldType.ARRAY_OF_LONG) { + ArrayData arrayData = record.getArray(field.getOrdinal()); + int arraySize = arrayData.numElements(); + Long[] valueArrayOfLong = new Long[arraySize]; + for (int i = 0; i < arraySize; i++) { + valueArrayOfLong[i] = (Long) arrayData.get(i, DataTypes.LongType); + } + value = valueArrayOfLong; + } map.put(columnName, value); } } diff --git a/src/main/java/com/rovio/ingest/model/FieldType.java b/src/main/java/com/rovio/ingest/model/FieldType.java index e875721..53ffd00 100644 --- a/src/main/java/com/rovio/ingest/model/FieldType.java +++ b/src/main/java/com/rovio/ingest/model/FieldType.java @@ -18,11 +18,16 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import java.util.Objects; + public enum FieldType { TIMESTAMP, DOUBLE, LONG, - STRING; + STRING, + ARRAY_OF_DOUBLE, + ARRAY_OF_LONG, + ARRAY_OF_STRING; public static FieldType from(DataType dataType) { if (isNumericType(dataType)) { @@ -41,6 +46,18 @@ public static FieldType from(DataType dataType) { return STRING; } + if (isArrayOfNumericType(dataType)) { + return ARRAY_OF_LONG; + } + + if (isArrayOfDoubleType(dataType)) { + return ARRAY_OF_DOUBLE; + } + + if (isArrayOfStringType(dataType)) { + return ARRAY_OF_STRING; + } + throw new IllegalArgumentException("Unsupported Type " + dataType); } @@ -55,4 +72,20 @@ private static boolean isNumericType(DataType dataType) { || dataType == DataTypes.ByteType; } + private static boolean isArrayOfNumericType(DataType dataType) { + return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.LongType)) + || Objects.equals(dataType, DataTypes.createArrayType(DataTypes.IntegerType)) + || Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ShortType)) + || Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ByteType)); + } + + private static boolean isArrayOfDoubleType(DataType dataType) { + return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.DoubleType)) + || Objects.equals(dataType, DataTypes.createArrayType(DataTypes.FloatType)); + } + private static boolean isArrayOfStringType(DataType dataType) { + return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.StringType)) + || Objects.equals(dataType, DataTypes.createArrayType(DataTypes.BooleanType)); + } + } diff --git a/src/main/java/com/rovio/ingest/model/SegmentSpec.java b/src/main/java/com/rovio/ingest/model/SegmentSpec.java index 
342bb2a..7b78d8a 100644 --- a/src/main/java/com/rovio/ingest/model/SegmentSpec.java +++ b/src/main/java/com/rovio/ingest/model/SegmentSpec.java @@ -30,6 +30,8 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.segment.AutoTypeColumnSchema; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.indexing.DataSchema; import org.apache.druid.segment.indexing.granularity.GranularitySpec; @@ -66,8 +68,8 @@ public class SegmentSpec implements Serializable { private final String dimensionsSpec; private final String metricsSpec; private final String transformSpec; - private final Set complexMetricColumns; + private final Set complexDimensionColumns; private SegmentSpec(String dataSource, String timeColumn, String segmentGranularity, String queryGranularity, List fields, Field partitionTime, Field partitionNum, boolean rollup, @@ -88,6 +90,11 @@ private SegmentSpec(String dataSource, String timeColumn, String segmentGranular .filter(aggregatorFactory -> aggregatorFactory.getIntermediateType().is(ValueType.COMPLEX)) .flatMap((AggregatorFactory aggregatorFactory) -> aggregatorFactory.requiredFields().stream()) .collect(Collectors.toSet()); + this.complexDimensionColumns = getDimensionsSpec().getDimensions() + .stream() + .filter(dimensionSchema -> dimensionSchema.getColumnType() == ColumnType.NESTED_DATA) + .map(DimensionSchema::getName) + .collect(Collectors.toSet()); } public static SegmentSpec from(String datasource, String timeColumn, List excludedDimensions, @@ -127,7 +134,7 @@ public static SegmentSpec from(String datasource, String timeColumn, List f.getFieldType() == FieldType.TIMESTAMP && !f.getName().equals(timeColumn) && !f.getName().equals(PARTITION_TIME_COLUMN_NAME)), String.format("Schema has another timestamp field other than \"%s\"", timeColumn)); - Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING), + Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING || f.getFieldType() == FieldType.ARRAY_OF_STRING), "Schema has no dimensions"); Preconditions.checkArgument(!rollup || fields.stream().anyMatch(f -> f.getFieldType() == FieldType.LONG || f.getFieldType() == FieldType.DOUBLE), @@ -217,6 +224,12 @@ private ImmutableList getDimensionSchemas() { builder.add(new DoubleDimensionSchema(fieldName)); } else if (field.getFieldType() == FieldType.TIMESTAMP) { builder.add(new LongDimensionSchema(fieldName)); + } else if (field.getFieldType() == FieldType.ARRAY_OF_STRING) { + builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.STRING_ARRAY)); + } else if (field.getFieldType() == FieldType.ARRAY_OF_DOUBLE) { + builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.DOUBLE_ARRAY)); + } else if (field.getFieldType() == FieldType.ARRAY_OF_LONG) { + builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.LONG_ARRAY)); } } } @@ -269,4 +282,8 @@ private AggregatorFactory[] getAggregators() { public Set getComplexMetricColumns() { return complexMetricColumns; } + + public Set getComplexDimensionColumns() { + return complexDimensionColumns; + } } diff --git a/src/main/scala/com/rovio/ingest/extensions/DruidDatasetExtensions.scala b/src/main/scala/com/rovio/ingest/extensions/DruidDatasetExtensions.scala index a1b1a50..edd5918 100644 --- 
a/src/main/scala/com/rovio/ingest/extensions/DruidDatasetExtensions.scala +++ b/src/main/scala/com/rovio/ingest/extensions/DruidDatasetExtensions.scala @@ -51,8 +51,8 @@ object DruidDatasetExtensions { */ @SerialVersionUID(1L) implicit class DruidDataset(dataset: Dataset[Row]) extends Serializable { - private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType) - private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType) + private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType, ArrayType(LongType), ArrayType(DoubleType)) + private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType, ArrayType(StringType)) private val log = LoggerFactory.getLogger(classOf[DruidDataset]) /** @@ -66,7 +66,7 @@ object DruidDatasetExtensions { *

      * The method performs the following transformations:
      *
-     *   • Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they are not
+     *   • Drops all columns of complex datatypes such as `StructType` or `MapType` as they are not
      *   supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true,
      *   otherwise validation has already failed.
  • Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given segment diff --git a/src/test/java/com/rovio/ingest/SegmentSpecTest.java b/src/test/java/com/rovio/ingest/SegmentSpecTest.java index 8a61312..406566c 100644 --- a/src/test/java/com/rovio/ingest/SegmentSpecTest.java +++ b/src/test/java/com/rovio/ingest/SegmentSpecTest.java @@ -344,4 +344,67 @@ public void shouldSupportMetricsSpecAsJson() { assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity()); assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity()); } + + @Test + public void shouldSupportArrayDimensions() { + StructType schema = new StructType() + .add("updateTime", DataTypes.TimestampType) + .add("user_id", DataTypes.StringType) + .add("countries", DataTypes.createArrayType(DataTypes.StringType)); + String metricsSpec = "[]"; + SegmentSpec spec = SegmentSpec.from("temp", "updateTime", Collections.emptyList(), "DAY", "DAY", schema, false, metricsSpec); + + assertEquals("temp", spec.getDataSchema().getDataSource()); + assertEquals("updateTime", spec.getTimeColumn()); + List dimensions = spec.getDataSchema().getDimensionsSpec().getDimensions(); + assertEquals(2, dimensions.size()); + List expected = Arrays.asList("user_id", "countries"); + assertTrue(dimensions.stream().allMatch(d -> expected.contains(d.getName()))); + assertTrue(dimensions.stream().anyMatch(d -> ValueType.STRING == d.getColumnType().getType() && d.getName().equals("user_id"))); + assertTrue(dimensions.stream().anyMatch(d -> ValueType.ARRAY == d.getColumnType().getType() && d.getName().equals("countries"))); + assertFalse(spec.getDataSchema().getGranularitySpec().isRollup()); + + assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity()); + assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity()); + } + + @Test + public void shouldDeserializeDimensionSpec() { + StructType schema = new StructType() + .add("__time", DataTypes.TimestampType) + .add("dim1", DataTypes.StringType) + .add("dim2", DataTypes.LongType) + .add("dim3", DataTypes.createArrayType(DataTypes.LongType)) + .add("dim4", DataTypes.StringType) + .add("dim5", DataTypes.createArrayType(DataTypes.StringType)) + .add("dim6", DataTypes.createArrayType(DataTypes.DoubleType)); + String dimensionsSpec = + "{\"dimensions\": " + + "[{\"type\": \"string\", \"name\": \"dim1\" }," + + "{\"type\": \"long\", \"name\": \"dim2\" }," + + "{\"type\": \"auto\", \"name\": \"dim3\" }," + + "{\"type\": \"json\", \"name\": \"dim4\", \"formatVersion\": 5, \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," + + "{\"type\": \"string\", \"name\": \"dim5\", \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," + + "{\"type\": \"double\", \"name\": \"dim6\" }],\n" + + "\"includeAllDimensions\": false,\n" + + "\"useSchemaDiscovery\": false}"; + String metricsSpec = "[" + + "{\n" + + " \"type\": \"longSum\",\n" + + " \"name\": \"metric2\",\n" + + " \"fieldName\": \"dim2\",\n" + + " \"expression\": null\n" + + "},\n" + + "{\n" + + " \"type\": \"doubleSum\",\n" + + " \"name\": \"metric6\",\n" + + " \"fieldName\": \"dim6\",\n" + + " \"expression\": null\n" + + "}\n" + + "]"; + SegmentSpec spec = SegmentSpec.from("temp", "__time", Collections.singletonList("updateTime"), "DAY", "DAY", schema, false, dimensionsSpec, metricsSpec, null); + List dimensions = 
         spec.getDataSchema().getDimensionsSpec().getDimensions();
+        assertEquals(6, dimensions.size());
+    }
+
 }
diff --git a/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsSpec.scala b/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsSpec.scala
index bbbf0d2..fbf3f30 100644
--- a/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsSpec.scala
+++ b/src/test/scala/com/rovio/ingest/DruidDatasetExtensionsSpec.scala
@@ -27,7 +27,7 @@ import org.scalatest.flatspec.AnyFlatSpec
 
 // must define classes outside of the actual test methods, otherwise spark can't find them
 case class KpiRow(date: String, country: String, dau: Integer, revenue: Double, is_segmented: Boolean)
-case class RowWithUnsupportedType(date: String, country: String, dau: Integer, labels: Array[String])
+case class RowWithUnsupportedType(date: String, country: String, dau: Integer, values: Array[Long])
 case class PartitionedRow(date: String, country: String, dau: Integer, revenue: Double, `__PARTITION_NUM__`: Integer)
 case class ExpectedRow(`__PARTITION_TIME__`: String, `__PARTITION_NUM__`: Integer, count: Integer)
 
@@ -221,7 +221,7 @@ class DruidDatasetExtensionsSpec extends AnyFlatSpec with Matchers with BeforeAn
 
   it should "exclude columns with unsupported types" in {
     val ds = Seq(RowWithUnsupportedType(
-      date="2019-10-17", country="US", dau=50, labels=Array("A", "B"))).toDS
+      date="2019-10-17", country="US", dau=50, values=Array(1L, 2L))).toDS
       .withColumn("date", 'date.cast(DataTypes.TimestampType))
 
     val result = ds.repartitionByDruidSegmentSize("date", "DAY", 2, excludeColumnsWithUnknownTypes = true)
@@ -237,7 +237,7 @@ class DruidDatasetExtensionsSpec extends AnyFlatSpec with Matchers with BeforeAn
 
   it should "throw exception from unsupported types by default" in {
     val ds = Seq(RowWithUnsupportedType(
-      date="2019-10-17", country="US", dau=50, labels=Array("A", "B"))).toDS
+      date="2019-10-17", country="US", dau=50, values=Array(1L, 2L))).toDS
       .withColumn("date", 'date.cast(DataTypes.TimestampType))
 
     assertThrows[IllegalArgumentException] {
       ds.repartitionByDruidSegmentSize("date", "DAY", 2)
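
For orientation, below is a minimal Spark-side sketch of what this patch enables: a Dataset with an `ArrayType(StringType)` column now passes validation and is written as a Druid `ARRAY<STRING>` dimension instead of being dropped. The `repartitionByDruidSegmentSize` call mirrors the usage in `DruidDatasetExtensionsSpec`; the row class, column names and `druid.*` writer options are illustrative only and are not part of this change, and the metadata-store and deep-storage options required for a real write are omitted.

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.types.DataTypes
import com.rovio.ingest.extensions.DruidDatasetExtensions._

// Hypothetical row type for illustration: `countries` is an Array[String] column,
// which this patch maps to FieldType.ARRAY_OF_STRING and a STRING_ARRAY dimension.
// (Defined at top level so Spark can find the encoder, as in the test suite.)
case class KpiArrayRow(date: String, country: String, dau: Int, countries: Array[String])

object ArrayIngestExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("array-ingest-example").getOrCreate()
    import spark.implicits._

    val ds = Seq(KpiArrayRow("2019-10-17", "US", 50, Array("US", "GB")))
      .toDS()
      .withColumn("date", $"date".cast(DataTypes.TimestampType))

    // Before this patch an Array[String] column failed validation (or was dropped when
    // excludeColumnsWithUnknownTypes = true); with it, the column is kept as a dimension.
    val prepared = ds.repartitionByDruidSegmentSize("date", "DAY", 2)

    // Option names are illustrative, following existing README usage; remaining
    // writer options (metadata store, deep storage, etc.) are omitted here.
    prepared.write
      .format("com.rovio.ingest.DruidSource")
      .option("druid.datasource", "example_datasource")
      .option("druid.time_column", "date")
      .mode(SaveMode.Overwrite)
      .save()
  }
}
```

`Array[Long]` and `Array[Double]` columns are handled the same way, mapping to `LONG_ARRAY` and `DOUBLE_ARRAY` auto-typed dimensions respectively.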