Add support for array types
Co-authored-by: fabricebaranski <[email protected]>
juhoautio-rovio and fabricebaranski committed Sep 6, 2024
1 parent ca4e3fd · commit 5570c74
Showing 8 changed files with 170 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -64,8 +64,9 @@ The Dataset extension performs the following validations:
* The Dataset has no columns with unknown types, unless `excludeColumnsWithUnknownTypes` is set to true

The Dataset extension performs the following transformations:
* Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they
* Drops all columns of complex datatypes such as `StructType` or `MapType` as they
are not supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true, otherwise validation has already failed.
* `ArrayType` columns are supported with `StringType`, `LongType` and `DoubleType` elements (see the usage sketch after this file's diff)
* Converts `Date`/`Timestamp` type columns to `String`, except for the `time_column`
- See [Druid Docs / Data types](https://druid.apache.org/docs/latest/querying/sql.html#standard-types)
* Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given [segment granularity](#segment-granularity)
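To make the new array support concrete, a minimal end-to-end write sketch in Java follows. This is an editorial illustration, not part of the commit: the `com.rovio.ingest.DruidSource` format name and the `druid.*` option keys follow the README's existing examples, and the metadata/deep-storage options required for a real run are elided.

import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ArrayIngestExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

        // ARRAY<STRING> dimension column, enabled by this commit
        StructType schema = new StructType()
                .add("__time", DataTypes.TimestampType)
                .add("user_id", DataTypes.StringType)
                .add("countries", DataTypes.createArrayType(DataTypes.StringType));

        List<Row> rows = Arrays.asList(
                RowFactory.create(Timestamp.valueOf("2024-09-06 00:00:00"),
                        "u1", new String[]{"fi", "se"}));
        Dataset<Row> df = spark.createDataFrame(rows, schema);

        df.write()
                .format("com.rovio.ingest.DruidSource")
                .mode(SaveMode.Overwrite)
                .option("druid.datasource", "demo")
                .option("druid.time_column", "__time")
                // plus druid.metadata.* and segment storage options per the README
                .save();

        spark.stop();
    }
}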
7 changes: 7 additions & 0 deletions src/main/java/com/rovio/ingest/DataSegmentCommitMessage.java
@@ -20,6 +20,7 @@
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.NamedType;
import org.apache.druid.guice.NestedDataModule;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
@@ -29,6 +30,8 @@
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchBuildComplexMetricSerde;
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchMergeComplexMetricSerde;
import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchModule;
import org.apache.druid.segment.DefaultColumnFormatConfig;
import org.apache.druid.segment.nested.NestedDataComplexTypeSerde;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
@@ -52,6 +55,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {
// ExpressionMacroTable is injected in AggregatorFactories.
// However, ExprMacros are not actually required as the DataSource is write-only.
.addValue(ExprMacroTable.class, ExprMacroTable.nil())
.addValue(DefaultColumnFormatConfig.class, new DefaultColumnFormatConfig(null))
// PruneLoadSpecHolder are injected in DataSegment.
.addValue(DataSegment.PruneSpecsHolder.class, DataSegment.PruneSpecsHolder.DEFAULT);

@@ -61,12 +65,14 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {

MAPPER.setTimeZone(TimeZone.getTimeZone("UTC"));

new NestedDataModule().getJacksonModules().forEach(MAPPER::registerModule);
new SketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new HllSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new KllSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new DoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule);
new ArrayOfDoublesSketchModule().getJacksonModules().forEach(MAPPER::registerModule);

NestedDataModule.registerHandlersAndSerde();
HllSketchModule.registerSerde();
KllSketchModule.registerSerde();
DoublesSketchModule.registerSerde();
@@ -75,6 +81,7 @@ public class DataSegmentCommitMessage implements WriterCommitMessage {
ComplexMetrics.registerSerde("arrayOfDoublesSketch", new ArrayOfDoublesSketchMergeComplexMetricSerde());
ComplexMetrics.registerSerde("arrayOfDoublesSketchMerge", new ArrayOfDoublesSketchMergeComplexMetricSerde());
ComplexMetrics.registerSerde("arrayOfDoublesSketchBuild", new ArrayOfDoublesSketchBuildComplexMetricSerde());
ComplexMetrics.registerSerde(NestedDataComplexTypeSerde.TYPE_NAME, NestedDataComplexTypeSerde.INSTANCE);
}


39 changes: 39 additions & 0 deletions src/main/java/com/rovio/ingest/TaskDataWriter.java
@@ -15,7 +15,9 @@
*/
package com.rovio.ingest;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.rovio.ingest.model.Field;
import com.rovio.ingest.model.FieldType;
import com.rovio.ingest.model.SegmentSpec;
import com.rovio.ingest.util.ReflectionUtils;
import com.rovio.ingest.util.SegmentStorageUpdater;
@@ -38,6 +40,7 @@
import org.apache.druid.segment.indexing.RealtimeTuningConfig;
import org.apache.druid.segment.loading.DataSegmentKiller;
import org.apache.druid.segment.loading.DataSegmentPusher;
import org.apache.druid.segment.nested.StructuredData;
import org.apache.druid.segment.realtime.FireDepartmentMetrics;
import org.apache.druid.segment.realtime.appenderator.Appenderator;
import org.apache.druid.segment.realtime.appenderator.DefaultOfflineAppenderatorFactory;
@@ -50,6 +53,7 @@
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.DataType;
@@ -63,6 +67,7 @@
import java.io.IOException;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -217,6 +222,40 @@ private Map<String, Object> parse(InternalRow record) {
// Convert to Java String as Spark returns UTF8String, which is not compatible with Druid sketches.
value = value.toString();
}
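// Columns declared as nested (COMPLEX<json>) dimensions arrive from Spark as JSON strings;
// parse them into Druid's StructuredData, dropping unparseable values as null.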
if (value != null && segmentSpec.getComplexDimensionColumns().contains(columnName) && sqlType == DataTypes.StringType) {
try {
value = MAPPER.readValue(value.toString(), StructuredData.class);
} catch (JsonProcessingException e) {
value = null;
}
}
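// For array-typed fields, materialize Spark's ArrayData into Java collections
// that Druid's auto-typed array columns can ingest.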
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_STRING) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
List<String> valueArrayOfString = new ArrayList<>(arraySize);
for (int i = 0; i < arraySize; i++) {
Object element = arrayData.get(i, DataTypes.StringType);
// Convert UTF8String elements to java.lang.String; pass null elements through.
valueArrayOfString.add(element == null ? null : element.toString());
}
value = valueArrayOfString;
}
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_DOUBLE) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
Double[] valueArrayOfDouble = new Double[arraySize];
for (int i = 0; i < arraySize; i++) {
valueArrayOfDouble[i] = (Double) arrayData.get(i, DataTypes.DoubleType);
}
value = valueArrayOfDouble;
}
if (value != null && field.getFieldType() == FieldType.ARRAY_OF_LONG) {
ArrayData arrayData = record.getArray(field.getOrdinal());
int arraySize = arrayData.numElements();
Long[] valueArrayOfLong = new Long[arraySize];
for (int i = 0; i < arraySize; i++) {
valueArrayOfLong[i] = (Long) arrayData.get(i, DataTypes.LongType);
}
value = valueArrayOfLong;
}
map.put(columnName, value);
}
}
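An aside on the per-element toString() above: Spark's catalyst ArrayData returns UTF8String for string elements, not java.lang.String, which Druid cannot consume directly. A self-contained sketch (class name is illustrative):

import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

public class Utf8StringDemo {
    public static void main(String[] args) {
        ArrayData array = new GenericArrayData(new Object[]{
                UTF8String.fromString("fi"), UTF8String.fromString("se")});
        Object raw = array.get(0, DataTypes.StringType);
        System.out.println(raw.getClass().getSimpleName()); // UTF8String
        System.out.println(raw.toString());                 // "fi" as java.lang.String
    }
}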
35 changes: 34 additions & 1 deletion src/main/java/com/rovio/ingest/model/FieldType.java
@@ -18,11 +18,16 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

import java.util.Objects;

public enum FieldType {
TIMESTAMP,
DOUBLE,
LONG,
STRING;
STRING,
ARRAY_OF_DOUBLE,
ARRAY_OF_LONG,
ARRAY_OF_STRING;

public static FieldType from(DataType dataType) {
if (isNumericType(dataType)) {
@@ -41,6 +46,18 @@ public static FieldType from(DataType dataType) {
return STRING;
}

if (isArrayOfIntegerType(dataType)) {
return ARRAY_OF_LONG;
}

if (isArrayOfDoubleType(dataType)) {
return ARRAY_OF_DOUBLE;
}

if (isArrayOfStringType(dataType)) {
return ARRAY_OF_STRING;
}

throw new IllegalArgumentException("Unsupported Type " + dataType);
}

@@ -55,4 +72,20 @@ private static boolean isNumericType(DataType dataType) {
|| dataType == DataTypes.ByteType;
}

private static boolean isArrayOfIntegerType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.LongType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.IntegerType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ShortType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.ByteType));
}

private static boolean isArrayOfDoubleType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.DoubleType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.FloatType));
}

private static boolean isArrayOfStringType(DataType dataType) {
return Objects.equals(dataType, DataTypes.createArrayType(DataTypes.StringType))
|| Objects.equals(dataType, DataTypes.createArrayType(DataTypes.BooleanType));
}

}
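A quick illustration of the resulting mapping, assuming a small throwaway harness in the same package: integer-like element types widen to ARRAY_OF_LONG, float widens to ARRAY_OF_DOUBLE, and boolean arrays are treated as string arrays.

import com.rovio.ingest.model.FieldType;
import org.apache.spark.sql.types.DataTypes;

public class FieldTypeDemo {
    public static void main(String[] args) {
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.IntegerType))); // ARRAY_OF_LONG
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.FloatType)));   // ARRAY_OF_DOUBLE
        System.out.println(FieldType.from(DataTypes.createArrayType(DataTypes.BooleanType))); // ARRAY_OF_STRING
    }
}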
21 changes: 19 additions & 2 deletions src/main/java/com/rovio/ingest/model/SegmentSpec.java
@@ -30,6 +30,8 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.segment.AutoTypeColumnSchema;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.indexing.DataSchema;
import org.apache.druid.segment.indexing.granularity.GranularitySpec;
@@ -66,8 +68,8 @@ public class SegmentSpec implements Serializable {
private final String dimensionsSpec;
private final String metricsSpec;
private final String transformSpec;

private final Set<String> complexMetricColumns;
private final Set<String> complexDimensionColumns;

private SegmentSpec(String dataSource, String timeColumn, String segmentGranularity, String queryGranularity,
List<Field> fields, Field partitionTime, Field partitionNum, boolean rollup,
@@ -88,6 +90,11 @@ private SegmentSpec(String dataSource, String timeColumn, String segmentGranular
.filter(aggregatorFactory -> aggregatorFactory.getIntermediateType().is(ValueType.COMPLEX))
.flatMap((AggregatorFactory aggregatorFactory) -> aggregatorFactory.requiredFields().stream())
.collect(Collectors.toSet());
this.complexDimensionColumns = getDimensionsSpec().getDimensions()
.stream()
.filter(dimensionSchema -> dimensionSchema.getColumnType() == ColumnType.NESTED_DATA)
.map(DimensionSchema::getName)
.collect(Collectors.toSet());
}

public static SegmentSpec from(String datasource, String timeColumn, List<String> excludedDimensions,
Expand Down Expand Up @@ -127,7 +134,7 @@ public static SegmentSpec from(String datasource, String timeColumn, List<String
fields.stream().noneMatch(f -> f.getFieldType() == FieldType.TIMESTAMP && !f.getName().equals(timeColumn) && !f.getName().equals(PARTITION_TIME_COLUMN_NAME)),
String.format("Schema has another timestamp field other than \"%s\"", timeColumn));

Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING),
Preconditions.checkArgument(fields.stream().anyMatch(f -> f.getFieldType() == FieldType.STRING || f.getFieldType() == FieldType.ARRAY_OF_STRING),
"Schema has no dimensions");

Preconditions.checkArgument(!rollup || fields.stream().anyMatch(f -> f.getFieldType() == FieldType.LONG || f.getFieldType() == FieldType.DOUBLE),
@@ -217,6 +224,12 @@ private ImmutableList<DimensionSchema> getDimensionSchemas() {
builder.add(new DoubleDimensionSchema(fieldName));
} else if (field.getFieldType() == FieldType.TIMESTAMP) {
builder.add(new LongDimensionSchema(fieldName));
} else if (field.getFieldType() == FieldType.ARRAY_OF_STRING) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.STRING_ARRAY));
} else if (field.getFieldType() == FieldType.ARRAY_OF_DOUBLE) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.DOUBLE_ARRAY));
} else if (field.getFieldType() == FieldType.ARRAY_OF_LONG) {
builder.add(new AutoTypeColumnSchema(fieldName, ColumnType.LONG_ARRAY));
}
}
}
@@ -269,4 +282,8 @@ private AggregatorFactory[] getAggregators() {
public Set<String> getComplexMetricColumns() {
return complexMetricColumns;
}

public Set<String> getComplexDimensionColumns() {
return complexDimensionColumns;
}
}
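For reference, each array field above becomes a Druid auto-typed column. A minimal sketch using the same constructor as the diff (class name is illustrative):

import org.apache.druid.data.input.impl.DimensionSchema;
import org.apache.druid.segment.AutoTypeColumnSchema;
import org.apache.druid.segment.column.ColumnType;

public class AutoTypeDemo {
    public static void main(String[] args) {
        // "countries" (ARRAY<STRING> in Spark) becomes an auto-typed ARRAY<STRING> dimension
        DimensionSchema countries = new AutoTypeColumnSchema("countries", ColumnType.STRING_ARRAY);
        System.out.println(countries.getName() + " -> " + countries.getColumnType());
    }
}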
DruidDatasetExtensions.scala
@@ -51,8 +51,8 @@ object DruidDatasetExtensions {
*/
@SerialVersionUID(1L)
implicit class DruidDataset(dataset: Dataset[Row]) extends Serializable {
private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType)
private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType)
private val METRIC_TYPES = Array(FloatType, DoubleType, IntegerType, LongType, ShortType, ByteType, ArrayType(LongType), ArrayType(DoubleType))
private val DIMENSION_TYPES = Array(StringType, DateType, TimestampType, BooleanType, ArrayType(StringType))
private val log = LoggerFactory.getLogger(classOf[DruidDataset])

/**
@@ -66,7 +66,7 @@
* <p>
* The method performs the following transformations:
* <ul>
* <li>Drops all columns of complex datatypes such as `StructType`, `MapType` or `ArrayType` as they are not
* <li>Drops all columns of complex datatypes such as `StructType` or `MapType` as they are not
* supported by `DruidSource`. This is only done if `excludeColumnsWithUnknownTypes` is set to true,
* otherwise validation has already failed.</li>
* <li>Adds a new column `__PARTITION_TIME__` whose value is based on `time_column` column and the given segment
63 changes: 63 additions & 0 deletions src/test/java/com/rovio/ingest/SegmentSpecTest.java
@@ -344,4 +344,67 @@ public void shouldSupportMetricsSpecAsJson() {
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity());
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity());
}

@Test
public void shouldSupportArrayDimensions() {
StructType schema = new StructType()
.add("updateTime", DataTypes.TimestampType)
.add("user_id", DataTypes.StringType)
.add("countries", DataTypes.createArrayType(DataTypes.StringType));
String metricsSpec = "[]";
SegmentSpec spec = SegmentSpec.from("temp", "updateTime", Collections.emptyList(), "DAY", "DAY", schema, false, metricsSpec);

assertEquals("temp", spec.getDataSchema().getDataSource());
assertEquals("updateTime", spec.getTimeColumn());
List<DimensionSchema> dimensions = spec.getDataSchema().getDimensionsSpec().getDimensions();
assertEquals(2, dimensions.size());
List<String> expected = Arrays.asList("user_id", "countries");
assertTrue(dimensions.stream().allMatch(d -> expected.contains(d.getName())));
assertTrue(dimensions.stream().anyMatch(d -> ValueType.STRING == d.getColumnType().getType() && d.getName().equals("user_id")));
assertTrue(dimensions.stream().anyMatch(d -> ValueType.ARRAY == d.getColumnType().getType() && d.getName().equals("countries")));
assertFalse(spec.getDataSchema().getGranularitySpec().isRollup());

assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getSegmentGranularity());
assertEquals(Granularity.fromString("DAY"), spec.getDataSchema().getGranularitySpec().getQueryGranularity());
}

@Test
public void shouldDeserializeDimensionSpec() {
StructType schema = new StructType()
.add("__time", DataTypes.TimestampType)
.add("dim1", DataTypes.StringType)
.add("dim2", DataTypes.LongType)
.add("dim3", DataTypes.createArrayType(DataTypes.LongType))
.add("dim4", DataTypes.StringType)
.add("dim5", DataTypes.createArrayType(DataTypes.StringType))
.add("dim6", DataTypes.createArrayType(DataTypes.DoubleType));
String dimensionsSpec =
"{\"dimensions\": " +
"[{\"type\": \"string\", \"name\": \"dim1\" }," +
"{\"type\": \"long\", \"name\": \"dim2\" }," +
"{\"type\": \"auto\", \"name\": \"dim3\" }," +
"{\"type\": \"json\", \"name\": \"dim4\", \"formatVersion\": 5, \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," +
"{\"type\": \"string\", \"name\": \"dim5\", \"multiValueHandling\": \"array\", \"createBitmapIndex\": true }," +
"{\"type\": \"double\", \"name\": \"dim6\" }],\n" +
"\"includeAllDimensions\": false,\n" +
"\"useSchemaDiscovery\": false}";
String metricsSpec = "[" +
"{\n" +
" \"type\": \"longSum\",\n" +
" \"name\": \"metric2\",\n" +
" \"fieldName\": \"dim2\",\n" +
" \"expression\": null\n" +
"},\n" +
"{\n" +
" \"type\": \"doubleSum\",\n" +
" \"name\": \"metric6\",\n" +
" \"fieldName\": \"dim6\",\n" +
" \"expression\": null\n" +
"}\n" +
"]";
SegmentSpec spec = SegmentSpec.from("temp", "__time", Collections.singletonList("updateTime"), "DAY", "DAY", schema, false, dimensionsSpec, metricsSpec, null);
List<DimensionSchema> dimensions = spec.getDataSchema().getDimensionsSpec().getDimensions();
assertEquals(6, dimensions.size());
}

}