diff --git a/java/src/main/java/ai/rapids/cudf/Schema.java b/java/src/main/java/ai/rapids/cudf/Schema.java index c8571dd841c..43603386649 100644 --- a/java/src/main/java/ai/rapids/cudf/Schema.java +++ b/java/src/main/java/ai/rapids/cudf/Schema.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; /** * The schema of data to be read in. @@ -221,6 +222,13 @@ public DType[] getChildTypes() { return ret; } + public int getNumChildren() { + if (childSchemas == null) { + return 0; + } + return childSchemas.size(); + } + int[] getFlattenedNumChildren() { flattenIfNeeded(); return flattenedCounts; @@ -243,7 +251,25 @@ public boolean isStructOrHasStructDescendant() { return false; } - public static class Builder { + public HostColumnVector.DataType asHostDataType() { + if (topLevelType == DType.LIST) { + assert(childSchemas != null && childSchemas.size() == 1); + HostColumnVector.DataType element = childSchemas.get(0).asHostDataType(); + return new HostColumnVector.ListType(true, element); + } else if (topLevelType == DType.STRUCT) { + if (childSchemas == null) { + return new HostColumnVector.StructType(true); + } else { + List childTypes = + childSchemas.stream().map(Schema::asHostDataType).collect(Collectors.toList()); + return new HostColumnVector.StructType(true, childTypes); + } + } else { + return new HostColumnVector.BasicType(true, topLevelType); + } + } + + public static class Builder { private final DType topLevelType; private final List names; private final List types; diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 4038b3a40b8..4e737451ed6 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -1220,8 +1220,26 @@ private static Table gatherJSONColumns(Schema schema, TableWithMeta twm, int emp columns[i] = tbl.getColumn(index).incRefCount(); } } else { - try (Scalar s = Scalar.fromNull(types[i])) { - columns[i] = ColumnVector.fromScalar(s, rowCount); + if (types[i] == DType.LIST) { + Schema listSchema = schema.getChild(i); + Schema elementSchema = listSchema.getChild(0); + try (Scalar s = Scalar.listFromNull(elementSchema.asHostDataType())) { + columns[i] = ColumnVector.fromScalar(s, rowCount); + } + } else if (types[i] == DType.STRUCT) { + Schema structSchema = schema.getChild(i); + int numStructChildren = structSchema.getNumChildren(); + DataType[] structChildrenTypes = new DataType[numStructChildren]; + for (int j = 0; j < numStructChildren; j++) { + structChildrenTypes[j] = structSchema.getChild(j).asHostDataType(); + } + try (Scalar s = Scalar.structFromNull(structChildrenTypes)) { + columns[i] = ColumnVector.fromScalar(s, rowCount); + } + } else { + try (Scalar s = Scalar.fromNull(types[i])) { + columns[i] = ColumnVector.fromScalar(s, rowCount); + } } } }