diff --git a/lakesoul-common/src/main/java/com/dmetasoul/lakesoul/meta/dao/TableInfoDao.java b/lakesoul-common/src/main/java/com/dmetasoul/lakesoul/meta/dao/TableInfoDao.java
index bea8b0a03..199f84059 100644
--- a/lakesoul-common/src/main/java/com/dmetasoul/lakesoul/meta/dao/TableInfoDao.java
+++ b/lakesoul-common/src/main/java/com/dmetasoul/lakesoul/meta/dao/TableInfoDao.java
@@ -34,7 +34,8 @@ public TableInfo selectByTableId(String tableId) {
         String sql = String.format("select * from table_info where table_id = '%s'", tableId);
         return getTableInfo(sql);
     }
-    public List<TableInfo> selectByNamespace(String namespace){
+
+    public List<TableInfo> selectByNamespace(String namespace) {
         String sql = String.format("select * from table_info where table_namespace='%s'", namespace);
         return getTableInfos(sql);
     }
@@ -105,6 +106,7 @@ private TableInfo getTableInfo(String sql) {
         }
         return tableInfo;
     }
+
     private List<TableInfo> getTableInfos(String sql) {
         Connection conn = null;
         PreparedStatement pstmt = null;
@@ -284,4 +286,12 @@ public static TableInfo tableInfoFromResultSet(ResultSet rs) throws SQLException
                 .setDomain(rs.getString("domain"))
                 .build();
     }
+
+    public static boolean isArrowKindSchema(String schema) {
+        return schema.charAt(schema.indexOf('"') + 1) == 'f';
+    }
+
+    public static boolean isSparkKindSchema(String schema) {
+        return schema.charAt(schema.indexOf('"') + 1) == 't';
+    }
 }
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/metadata/LakeSoulCatalog.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/metadata/LakeSoulCatalog.java
index e4c1200fd..402dfc7cb 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/metadata/LakeSoulCatalog.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/metadata/LakeSoulCatalog.java
@@ -124,7 +124,7 @@ public void dropDatabase(String databaseName, boolean ignoreIfNotExists, boolean
         List<String> tables = listTables(databaseName);
         if (!tables.isEmpty()) {
             if (cascade) {
-                for (String table: tables) {
+                for (String table : tables) {
                     try {
                         dropTable(new ObjectPath(databaseName, table), true);
                     } catch (TableNotExistException e) {
@@ -207,7 +207,7 @@ public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists)
         dbManager.deleteShortTableName(tableInfo.getTableName(), tableName, tablePath.getDatabaseName());
         dbManager.deleteDataCommitInfo(tableId);
         dbManager.deletePartitionInfoByTableId(tableId);
-        if(FlinkUtil.isTable(tableInfo)){
+        if (FlinkUtil.isTable(tableInfo)) {
            Path path = new Path(tableInfo.getTablePath());
            try {
                path.getFileSystem().delete(path, true);
@@ -274,7 +274,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig
         }
         String tableId = TABLE_ID_PREFIX + UUID.randomUUID();
         String qualifiedPath = "";
-        String sparkSchema = FlinkUtil.toSparkSchema(schema, cdcColumn).json();
+        String sparkSchema = FlinkUtil.toArrowSchema(schema, cdcColumn).toJson();
         List<String> partitionKeys = Collections.emptyList();
         if (table instanceof ResolvedCatalogTable) {
             partitionKeys = ((ResolvedCatalogTable) table).getPartitionKeys();
@@ -284,7 +284,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig
         } else {
             String flinkWarehouseDir = GlobalConfiguration.loadConfiguration().get(FLINK_WAREHOUSE_DIR);
             if (null != flinkWarehouseDir) {
-                path = String.join("/", flinkWarehouseDir, tablePath.getDatabaseName(), tablePath.getObjectName());
+                path = String.join("/", flinkWarehouseDir, tablePath.getDatabaseName(), tablePath.getObjectName());
             }
         }
         try {
@@ -298,9 +298,9 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig
         }
         if (table instanceof ResolvedCatalogView) {
             tableOptions.put(LAKESOUL_VIEW.key(), "true");
-            tableOptions.put(LAKESOUL_VIEW_TYPE.key(),LAKESOUL_VIEW_TYPE.defaultValue());
-            tableOptions.put(VIEW_ORIGINAL_QUERY,((ResolvedCatalogView) table).getOriginalQuery());
-            tableOptions.put(VIEW_EXPANDED_QUERY,((ResolvedCatalogView) table).getExpandedQuery());
+            tableOptions.put(LAKESOUL_VIEW_TYPE.key(), LAKESOUL_VIEW_TYPE.defaultValue());
+            tableOptions.put(VIEW_ORIGINAL_QUERY, ((ResolvedCatalogView) table).getOriginalQuery());
+            tableOptions.put(VIEW_EXPANDED_QUERY, ((ResolvedCatalogView) table).getExpandedQuery());
         }
         String json = JSON.toJSONString(tableOptions);
         JSONObject properties = JSON.parseObject(json);
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/sink/committer/LakeSoulSinkGlobalCommitter.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/sink/committer/LakeSoulSinkGlobalCommitter.java
index 875d3fde4..c096a6bb6 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/sink/committer/LakeSoulSinkGlobalCommitter.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/sink/committer/LakeSoulSinkGlobalCommitter.java
@@ -9,7 +9,9 @@
 import com.dmetasoul.lakesoul.meta.DBManager;
 import com.dmetasoul.lakesoul.meta.DBConfig;
 import com.dmetasoul.lakesoul.meta.DBUtil;
+import com.dmetasoul.lakesoul.meta.dao.TableInfoDao;
 import com.dmetasoul.lakesoul.meta.entity.TableInfo;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.flink.api.connector.sink.GlobalCommitter;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
@@ -19,6 +21,7 @@
 import org.apache.flink.lakesoul.sink.writer.AbstractLakeSoulMultiTableSinkWriter;
 import org.apache.flink.lakesoul.tool.FlinkUtil;
 import org.apache.flink.lakesoul.types.TableSchemaIdentity;
+import org.apache.spark.sql.arrow.ArrowUtils;
 import org.apache.spark.sql.arrow.DataTypeCastUtils;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
@@ -116,9 +119,11 @@ public List commit(
         String tableName = identity.tableId.table();
         String tableNamespace = identity.tableId.schema();
         boolean isCdc = identity.useCDC;
-        StructType msgSchema = FlinkUtil.toSparkSchema(identity.rowType, isCdc ? Optional.of(
+        Schema msgSchema = FlinkUtil.toArrowSchema(identity.rowType, isCdc ? Optional.of(
                 identity.cdcColumn) : Optional.empty());
+        StructType sparkSchema = ArrowUtils.fromArrowSchema(msgSchema);
+
         TableInfo tableInfo = dbManager.getTableInfoByNameAndNamespace(tableName, tableNamespace);
         LOG.info("Committing: {}, {}, {}, {} {}", tableNamespace, tableName, isCdc, msgSchema, tableInfo);
         if (tableInfo == null) {
@@ -137,7 +142,7 @@ public List commit(
                     properties.put(CDC_CHANGE_COLUMN, CDC_CHANGE_COLUMN_DEFAULT);
                 }
             }
            dbManager.createNewTable(tableId, tableNamespace, tableName, identity.tableLocation, msgSchema.json(),
-                    properties, partition);
+            dbManager.createNewTable(tableId, tableNamespace, tableName, identity.tableLocation, msgSchema.toJson(),
                     properties, partition);
         } else {
             DBUtil.TablePartitionKeys partitionKeys = DBUtil.parseTableInfoPartitions(tableInfo.getPartitions());
@@ -149,11 +154,17 @@ public List commit(
                     !new HashSet<>(partitionKeys.rangeKeys).containsAll(identity.partitionKeyList)) {
                 throw new IOException("Change of partition key column of table " + tableName + " is forbidden");
             }
-            StructType origSchema = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            StructType origSchema = null;
+            if (TableInfoDao.isArrowKindSchema(tableInfo.getTableSchema())) {
+                Schema arrowSchema = Schema.fromJSON(tableInfo.getTableSchema());
+                origSchema = ArrowUtils.fromArrowSchema(arrowSchema);
+            } else {
+                origSchema = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            }
             scala.Tuple3 equalOrCanCastTuple3 =
                     DataTypeCastUtils.checkSchemaEqualOrCanCast(origSchema,
-                            msgSchema,
+                            ArrowUtils.fromArrowSchema(msgSchema),
                             identity.partitionKeyList,
                             identity.primaryKeys);
             String equalOrCanCast = equalOrCanCastTuple3._1();
@@ -162,9 +173,9 @@ public List commit(
             if (equalOrCanCast.equals(DataTypeCastUtils.CAN_CAST())) {
                 LOG.warn("Schema change found, origin schema = {}, changed schema = {}",
                         origSchema.json(),
-                        msgSchema.json());
+                        msgSchema.toJson());
                 if (logicallyDropColumn) {
-                    List<String> droppedColumn = DataTypeCastUtils.getDroppedColumn(origSchema, msgSchema);
+                    List<String> droppedColumn = DataTypeCastUtils.getDroppedColumn(origSchema, sparkSchema);
                     if (droppedColumn.size() > 0) {
                         LOG.warn("Dropping Column {} Logically", droppedColumn);
                         dbManager.logicallyDropColumn(tableInfo.getTableId(), droppedColumn);
@@ -172,7 +183,7 @@ public List commit(
                         dbManager.updateTableSchema(tableInfo.getTableId(), mergeStructType.json());
                     }
                 } else {
-                    dbManager.updateTableSchema(tableInfo.getTableId(), msgSchema.json());
+                    dbManager.updateTableSchema(tableInfo.getTableId(), msgSchema.toJson());
                 }
             } else {
                 LOG.info("Changing table schema: {}, {}, {}, {}, {}, {}",
@@ -182,7 +193,7 @@ public List commit(
                         msgSchema,
                         identity.useCDC,
                         identity.cdcColumn);
-                dbManager.updateTableSchema(tableInfo.getTableId(), msgSchema.json());
+                dbManager.updateTableSchema(tableInfo.getTableId(), msgSchema.toJson());
                 if (JSONObject.parseObject(tableInfo.getProperties()).containsKey(DBConfig.TableInfoProperty.DROPPED_COLUMN)) {
                     dbManager.removeLogicallyDropColumn(tableInfo.getTableId());
                 }
diff --git a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/tool/FlinkUtil.java b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/tool/FlinkUtil.java
index 392ae92d6..97f6222c1 100644
--- a/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/tool/FlinkUtil.java
+++ b/lakesoul-flink/src/main/java/org/apache/flink/lakesoul/tool/FlinkUtil.java
@@ -8,7 +8,10 @@
 import com.alibaba.fastjson.JSONObject;
 import com.dmetasoul.lakesoul.lakesoul.io.NativeIOBase;
 import com.dmetasoul.lakesoul.meta.*;
+import com.dmetasoul.lakesoul.meta.dao.TableInfoDao;
 import com.dmetasoul.lakesoul.meta.entity.TableInfo;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.GlobalConfiguration;
 import org.apache.flink.core.fs.FileSystem;
@@ -30,6 +33,7 @@
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.types.RowKind;
 import org.apache.hadoop.fs.permission.FsAction;
@@ -72,6 +76,75 @@ public static String getRangeValue(CatalogPartitionSpec cps) {
         return "Null";
     }

+    public static org.apache.arrow.vector.types.pojo.Schema toArrowSchema(RowType rowType, Optional<String> cdcColumn) throws CatalogException {
+        List<Field> fields = new ArrayList<>();
+        String cdcColName = null;
+        if (cdcColumn.isPresent()) {
+            cdcColName = cdcColumn.get();
+            Field cdcField = ArrowUtils.toArrowField(cdcColName, new VarCharType(false, 16));
+            fields.add(cdcField);
+        }
+
+        for (RowType.RowField field : rowType.getFields()) {
+            String name = field.getName();
+            if (name.equals(SORT_FIELD)) continue;
+
+            LogicalType logicalType = field.getType();
+            Field arrowField = ArrowUtils.toArrowField(name, logicalType);
+            if (name.equals(cdcColName)) {
+                if (!arrowField.toString().equals(fields.get(0).toString())) {
+                    throw new CatalogException(CDC_CHANGE_COLUMN +
+                            "=" +
+                            cdcColName +
+                            "has an invalid field of" +
+                            field +
+                            "," +
+                            CDC_CHANGE_COLUMN +
+                            " require field of " +
+                            fields.get(0).toString());
+                }
+            } else {
+                fields.add(arrowField);
+            }
+        }
+        return new org.apache.arrow.vector.types.pojo.Schema(fields);
+    }
+
+    public static org.apache.arrow.vector.types.pojo.Schema toArrowSchema(TableSchema tsc, Optional<String> cdcColumn) throws CatalogException {
+        List<Field> fields = new ArrayList<>();
+        String cdcColName = null;
+        if (cdcColumn.isPresent()) {
+            cdcColName = cdcColumn.get();
+            Field cdcField = ArrowUtils.toArrowField(cdcColName, new VarCharType(false, 16));
+            fields.add(cdcField);
+        }
+
+        for (int i = 0; i < tsc.getFieldCount(); i++) {
+            String name = tsc.getFieldName(i).get();
+            DataType dt = tsc.getFieldDataType(i).get();
+            if (name.equals(SORT_FIELD)) continue;
+
+            LogicalType logicalType = dt.getLogicalType();
+            Field arrowField = ArrowUtils.toArrowField(name, logicalType);
+            if (name.equals(cdcColName)) {
+                if (!arrowField.toString().equals(fields.get(0).toString())) {
+                    throw new CatalogException(CDC_CHANGE_COLUMN +
+                            "=" +
+                            cdcColName +
+                            "has an invalid field of" +
+                            arrowField +
+                            "," +
+                            CDC_CHANGE_COLUMN +
+                            " require field of " +
+                            fields.get(0).toString());
+                }
+            } else {
+                fields.add(arrowField);
+            }
+        }
+        return new org.apache.arrow.vector.types.pojo.Schema(fields);
+    }
+
     public static StructType toSparkSchema(RowType rowType, Optional<String> cdcColumn) throws CatalogException {
         StructType stNew = new StructType();

@@ -220,10 +293,18 @@ public static CatalogBaseTable toFlinkCatalog(TableInfo tableInfo) {
         String tableSchema = tableInfo.getTableSchema();
         JSONObject properties = JSON.parseObject(tableInfo.getProperties());

-        StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableSchema);
-        org.apache.arrow.vector.types.pojo.Schema
-                arrowSchema =
-                org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = null;
+        System.out.println(tableSchema);
+        if (TableInfoDao.isArrowKindSchema(tableSchema)) {
+            try {
+                arrowSchema = org.apache.arrow.vector.types.pojo.Schema.fromJSON(tableSchema);
+            } catch (IOException e) {
+                throw new CatalogException(e);
+            }
+        } else {
+            StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableSchema);
+            arrowSchema = org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        }
         RowType rowType = ArrowUtils.fromArrowSchema(arrowSchema);

         Builder bd = Schema.newBuilder();
diff --git a/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/LakeSoulCatalogTest.java b/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/LakeSoulCatalogTest.java
index c4adad210..b105130d7 100644
--- a/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/LakeSoulCatalogTest.java
+++ b/lakesoul-flink/src/test/java/org/apache/flink/lakesoul/test/LakeSoulCatalogTest.java
@@ -16,6 +16,7 @@
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.catalog.Catalog;
 import org.apache.flink.table.catalog.exceptions.DatabaseAlreadyExistException;
+import org.apache.spark.sql.arrow.ArrowUtils;
 import org.apache.spark.sql.types.StructType;
 import org.assertj.core.api.Assertions;
 import org.junit.Before;
@@ -82,8 +83,7 @@ public void createTable() {
         tEnvs.executeSql("show tables").print();
         TableInfo info = DbManage.getTableInfoByNameAndNamespace("user_behaviorgg", "test_lakesoul_meta");
         assertEquals(info.getTableSchema(),
-                new StructType().add("user_id", LongType, false).add("dt", StringType).add("name", StringType)
-                        .json());
+                ArrowUtils.toArrowSchema(new StructType().add("user_id", LongType, false).add("dt", StringType).add("name", StringType), "UTC").toJson());
         tEnvs.executeSql("DROP TABLE user_behaviorgg");
     }

@@ -94,11 +94,11 @@ public void createTableWithLike() {
                 "'lakesoul_meta_host_port'='9043', 'path'='/tmp/user_behaviorgg', 'use_cdc'='true')");
         TableInfo info = DbManage.getTableInfoByNameAndNamespace("user_behaviorgg", "test_lakesoul_meta");

-        Assertions.assertThat(info.getTableSchema()).isEqualTo(new StructType().add("user_id", LongType, false).add("dt", StringType).add("name", StringType, false).json());
+        Assertions.assertThat(info.getTableSchema()).isEqualTo(ArrowUtils.toArrowSchema(new StructType().add("name", StringType, false).add("user_id", LongType, false).add("dt", StringType), "UTC").toJson());
         tEnvs.executeSql("CREATE TABLE if not exists like_table with ('path'='/tmp/like_table') like user_behaviorgg");
         TableInfo info2 = DbManage.getTableInfoByNameAndNamespace("like_table", "test_lakesoul_meta");
-        Assertions.assertThat(info2.getTableSchema()).isEqualTo(new StructType().add("user_id", LongType, false).add("dt", StringType).add("name", StringType, false).json());
+        Assertions.assertThat(info2.getTableSchema()).isEqualTo(ArrowUtils.toArrowSchema(new StructType().add("name", StringType, false).add("user_id", LongType, false).add("dt", StringType), "UTC").toJson());
         Assertions.assertThat(JSON.parseObject(info.getProperties()).get("lakesoul_cdc_change_column")).isEqualTo(JSON.parseObject(info2.getProperties()).get("lakesoul_cdc_change_column"));
         Assertions.assertThat(JSON.parseObject(info.getProperties()).get("path")).isEqualTo("/tmp/user_behaviorgg");
         Assertions.assertThat(JSON.parseObject(info2.getProperties()).get("path")).isEqualTo("/tmp/like_table");
diff --git a/lakesoul-presto/src/main/java/com/facebook/presto/lakesoul/LakeSoulMetadata.java b/lakesoul-presto/src/main/java/com/facebook/presto/lakesoul/LakeSoulMetadata.java
index 19d60b174..05b1e8d14 100644
--- a/lakesoul-presto/src/main/java/com/facebook/presto/lakesoul/LakeSoulMetadata.java
+++ b/lakesoul-presto/src/main/java/com/facebook/presto/lakesoul/LakeSoulMetadata.java
@@ -8,6 +8,7 @@
 import com.alibaba.fastjson.JSONObject;
 import com.dmetasoul.lakesoul.meta.DBManager;
 import com.dmetasoul.lakesoul.meta.DBUtil;
+import com.dmetasoul.lakesoul.meta.dao.TableInfoDao;
 import com.dmetasoul.lakesoul.meta.entity.TableInfo;
 import com.facebook.presto.lakesoul.handle.LakeSoulTableColumnHandle;
 import com.facebook.presto.lakesoul.handle.LakeSoulTableHandle;
@@ -16,8 +17,12 @@
 import com.facebook.presto.spi.*;
 import com.facebook.presto.spi.connector.ConnectorMetadata;
 import com.google.common.collect.ImmutableList;
+import org.apache.arrow.c.ArrowSchema;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.arrow.ArrowUtils;
 import org.apache.spark.sql.types.StructType;

+import java.io.IOException;
 import java.time.ZoneId;
 import java.util.*;
 import java.util.stream.Collectors;
@@ -72,9 +77,17 @@ public List getTableLayouts(
         TableInfo tableInfo = dbManager.getTableInfoByTableId(((LakeSoulTableHandle) table).getId());
         DBUtil.TablePartitionKeys partitionKeys = DBUtil.parseTableInfoPartitions(tableInfo.getPartitions());
         JSONObject properties = JSON.parseObject(tableInfo.getProperties());
-        StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableInfo.getTableSchema());
-        org.apache.arrow.vector.types.pojo.Schema arrowSchema =
-                org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = null;
+        if (TableInfoDao.isArrowKindSchema(tableInfo.getTableSchema())) {
+            try {
+                arrowSchema = Schema.fromJSON(tableInfo.getTableSchema());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        } else {
+            StructType struct = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            arrowSchema = org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        }
         HashMap allColumns = new HashMap<>();
         String cdcChangeColumn = properties.getString(CDC_CHANGE_COLUMN);
         for (org.apache.arrow.vector.types.pojo.Field field : arrowSchema.getFields()) {
@@ -119,9 +132,17 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect
             throw new RuntimeException("no such table: " + handle.getNames());
         }
         JSONObject properties = JSON.parseObject(tableInfo.getProperties());
-        StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableInfo.getTableSchema());
-        org.apache.arrow.vector.types.pojo.Schema arrowSchema =
-                org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = null;
+        if (TableInfoDao.isArrowKindSchema(tableInfo.getTableSchema())) {
+            try {
+                arrowSchema = Schema.fromJSON(tableInfo.getTableSchema());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        } else {
+            StructType struct = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            arrowSchema = org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        }

         List columns = new LinkedList<>();
         String cdcChangeColumn = properties.getString(CDC_CHANGE_COLUMN);
@@ -162,9 +183,19 @@ public Map getColumnHandles(ConnectorSession session, Conn
             throw new RuntimeException("no such table: " + table.getNames());
         }
         JSONObject properties = JSON.parseObject(tableInfo.getProperties());
-        StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableInfo.getTableSchema());
-        org.apache.arrow.vector.types.pojo.Schema arrowSchema =
-                org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = null;
+        if (TableInfoDao.isArrowKindSchema(tableInfo.getTableSchema())) {
+            try {
+                arrowSchema = Schema.fromJSON(tableInfo.getTableSchema());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        } else {
+            StructType struct = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            arrowSchema = org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        }
+
         HashMap map = new HashMap<>();
         String cdcChangeColumn = properties.getString(CDC_CHANGE_COLUMN);
         for (org.apache.arrow.vector.types.pojo.Field field : arrowSchema.getFields()) {
@@ -191,9 +222,17 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session,
             throw new RuntimeException("no such table: " + handle.getTableHandle().getNames());
         }

-        StructType struct = (StructType) org.apache.spark.sql.types.DataType.fromJson(tableInfo.getTableSchema());
-        org.apache.arrow.vector.types.pojo.Schema arrowSchema =
-                org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        org.apache.arrow.vector.types.pojo.Schema arrowSchema = null;
+        if (TableInfoDao.isArrowKindSchema(tableInfo.getTableSchema())) {
+            try {
+                arrowSchema = Schema.fromJSON(tableInfo.getTableSchema());
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        } else {
+            StructType struct = (StructType) StructType.fromJson(tableInfo.getTableSchema());
+            arrowSchema = org.apache.spark.sql.arrow.ArrowUtils.toArrowSchema(struct, ZoneId.of("UTC").toString());
+        }
         for (org.apache.arrow.vector.types.pojo.Field field : arrowSchema.getFields()) {
             Map properties = new HashMap<>();
             for (Map.Entry entry : field.getMetadata().entrySet()) {
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CreateTableCommand.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CreateTableCommand.scala
index 97a209025..7b41c6f83 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CreateTableCommand.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CreateTableCommand.scala
@@ -9,6 +9,7 @@ import com.dmetasoul.lakesoul.meta.{DataFileInfo, SparkMetaVersion}
 import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.arrow.ArrowUtils
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
@@ -167,7 +168,7 @@ case class CreateTableCommand(var table: CatalogTable,
           assertPathEmpty(sparkSession, tableWithLocation)
           // This is a user provided schema.
           // Doesn't come from a query, Follow nullability invariants.
-          val newTableInfo = getProvidedTableInfo(tc, table, table.schema.json)
+          val newTableInfo = getProvidedTableInfo(tc, table, ArrowUtils.toArrowSchema(table.schema).toJson)

           tc.commit(Seq.empty[DataFileInfo], Seq.empty[DataFileInfo], newTableInfo)
         } else {
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/alterTableCommands.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/alterTableCommands.scala
index 61cc82bba..3ab31466f 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/alterTableCommands.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/alterTableCommands.scala
@@ -5,6 +5,7 @@
 package org.apache.spark.sql.lakesoul.commands

 import com.dmetasoul.lakesoul.meta.DataFileInfo
+import org.apache.spark.sql.arrow.ArrowUtils
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
 import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, First}
@@ -154,7 +155,7 @@ case class AlterTableAddColumnsCommand(
     SchemaUtils.checkColumnNameDuplication(newSchema, "in adding columns")
     DataSourceUtils.checkFieldNames(new ParquetFileFormat, newSchema)

-    val newTableInfo = tableInfo.copy(table_schema = newSchema.json)
+    val newTableInfo = tableInfo.copy(table_schema = ArrowUtils.toArrowSchema(newSchema).toJson)
     tc.commit(Seq.empty[DataFileInfo], Seq.empty[DataFileInfo], newTableInfo)

     Seq.empty[Row]
@@ -228,7 +229,7 @@ case class AlterTableChangeColumnCommand(
       case (_, _@StructType(fields), _) => fields
     }

-    val newTableInfo = tableInfo.copy(table_schema = newSchema.json)
+    val newTableInfo = tableInfo.copy(table_schema = ArrowUtils.toArrowSchema(newSchema).toJson)
     tc.commit(Seq.empty[DataFileInfo], Seq.empty[DataFileInfo], newTableInfo)

     Seq.empty[Row]
@@ -356,7 +357,7 @@ case class AlterTableReplaceColumnsCommand(
     SchemaUtils.checkColumnNameDuplication(newSchema, "in replacing columns")
     DataSourceUtils.checkFieldNames(new ParquetFileFormat(), newSchema)

-    val newTableInfo = tableInfo.copy(table_schema = newSchema.json)
+    val newTableInfo = tableInfo.copy(table_schema = ArrowUtils.toArrowSchema(newSchema).toJson)
     tc.commit(Seq.empty[DataFileInfo], Seq.empty[DataFileInfo], newTableInfo)

     Seq.empty[Row]
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/schema/ImplicitMetadataOperation.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/schema/ImplicitMetadataOperation.scala
index f8aa164ee..03af8e178 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/schema/ImplicitMetadataOperation.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/schema/ImplicitMetadataOperation.scala
@@ -12,6 +12,7 @@
 import org.apache.spark.sql.lakesoul.utils.{PartitionUtils, SparkUtil, TableInfo}
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.arrow.ArrowUtils

 /**
@@ -144,7 +145,7 @@ trait ImplicitMetadataOperation extends Logging {
           namespace = table_info.namespace,
           table_path_s = Option(SparkUtil.makeQualifiedTablePath(new Path(table_info.table_path_s.get)).toString),
           table_id = table_info.table_id,
-          table_schema = dataSchema.json,
+          table_schema = ArrowUtils.toArrowSchema(dataSchema).toJson,
          range_column = normalizedRangePartitionCols.mkString(LAKESOUL_RANGE_PARTITION_SPLITTER),
          hash_column = normalizedHashPartitionCols.mkString(LAKESOUL_HASH_PARTITION_SPLITTER),
          bucket_num = realHashBucketNum,
@@ -153,14 +154,14 @@ trait ImplicitMetadataOperation extends Logging {

     } else if (isOverwriteMode && canOverwriteSchema && isNewSchema) {
       val newTableInfo = tc.tableInfo.copy(
-        table_schema = dataSchema.json
+        table_schema = ArrowUtils.toArrowSchema(dataSchema).toJson
       )
       tc.updateTableInfo(newTableInfo)
     } else if (isNewSchema && canMergeSchema) {
       logInfo(s"New merged schema: ${mergedSchema.treeString}")
-      tc.updateTableInfo(tc.tableInfo.copy(table_schema = mergedSchema.json))
+      tc.updateTableInfo(tc.tableInfo.copy(table_schema = ArrowUtils.toArrowSchema(mergedSchema).toJson))
     } else if (isNewSchema) {
       val errorBuilder = new MetadataMismatchErrorBuilder
       if (isNewSchema) {
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/utils/MetaData.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/utils/MetaData.scala
index 778b841b5..6ab671518 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/utils/MetaData.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/utils/MetaData.scala
@@ -5,9 +5,12 @@
 package org.apache.spark.sql.lakesoul.utils

 import com.dmetasoul.lakesoul.meta.DBConfig.{LAKESOUL_HASH_PARTITION_SPLITTER, LAKESOUL_RANGE_PARTITION_SPLITTER}
+import com.dmetasoul.lakesoul.meta.dao.TableInfoDao
 import com.dmetasoul.lakesoul.meta.{CommitState, CommitType, DataFileInfo, PartitionInfoScala}
 import com.fasterxml.jackson.annotation.JsonIgnore
+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.arrow.ArrowUtils
 import org.apache.spark.sql.types.{DataType, StructType}

 import java.util.UUID
@@ -46,8 +49,14 @@ case class TableInfo(namespace: String,
   //full table schema which contains partition columns
   @JsonIgnore
   lazy val schema: StructType =
-    Option(table_schema).map { s =>
-      DataType.fromJson(s).asInstanceOf[StructType]
+    Option(table_schema).map { s => {
+      // latest version: from arrow schema json
+      if (TableInfoDao.isArrowKindSchema(s))
+        ArrowUtils.fromArrowSchema(Schema.fromJSON(s))
+      else
+      // old version: from spark struct datatype
+        DataType.fromJson(s).asInstanceOf[StructType]
+    }
     }.getOrElse(StructType.apply(Nil))

   //range partition columns
diff --git a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala
index 02fb8bf95..719ec4cfe 100644
--- a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala
+++ b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala
@@ -5,16 +5,16 @@
 package org.apache.spark.sql.arrow

 import scala.collection.JavaConverters._
-
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.complex.MapVector
 import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
-
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

+import java.util
+
 private[sql] object ArrowUtils {

   val rootAllocator = new RootAllocator(Long.MaxValue)
@@ -70,21 +70,27 @@ private[sql] object ArrowUtils {
   }

   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
-  def toArrowField(
-      name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
+  def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String, metadata: util.Map[String, String] = null): Field = {
+
     dt match {
       case ArrayType(elementType, containsNull) =>
-        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null, metadata)
         new Field(name, fieldType,
-          Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
+          Seq(toArrowField("element", elementType, containsNull, timeZoneId, metadata)).asJava)
       case StructType(fields) =>
-        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null, metadata)
         new Field(name, fieldType,
           fields.map { field =>
-            toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+            val comment = field.getComment
+            val child_metadata = if (comment.isDefined) {
+              val map = new util.HashMap[String, String]
+              map.put("spark_comment", comment.get)
+              map
+            } else null
+            toArrowField(field.name, field.dataType, field.nullable, timeZoneId, child_metadata)
           }.toSeq.asJava)
       case MapType(keyType, valueType, valueContainsNull) =>
-        val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+        val mapType = new FieldType(nullable, new ArrowType.Map(false), null, metadata)
         // Note: Map Type struct can not be null, Struct Type key field can not be null
         new Field(name, mapType,
           Seq(toArrowField(MapVector.DATA_VECTOR_NAME,
@@ -94,7 +100,7 @@ private[sql] object ArrowUtils {
             nullable = false,
             timeZoneId)).asJava)
       case dataType =>
-        val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
+        val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null, metadata)
         new Field(name, fieldType, Seq.empty[Field].asJava)
     }
   }
@@ -113,7 +119,11 @@ private[sql] object ArrowUtils {
     case ArrowType.Struct.INSTANCE =>
       val fields = field.getChildren().asScala.map { child =>
         val dt = fromArrowField(child)
-        StructField(child.getName, dt, child.isNullable)
+        val comment = child.getMetadata.get("spark_comment")
+        if (comment == null)
+          StructField(child.getName, dt, child.isNullable)
+        else
+          StructField(child.getName, dt, child.isNullable).withComment(comment)
       }
       StructType(fields.toSeq)
     case arrowType => fromArrowType(arrowType)
@@ -121,16 +131,26 @@ private[sql] object ArrowUtils {
   }

   /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
-  def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+  def toArrowSchema(schema: StructType, timeZoneId: String = "UTC"): Schema = {
     new Schema(schema.map { field =>
-      toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+      val comment = field.getComment
+      val metadata = if (comment.isDefined) {
+        val map = new util.HashMap[String, String]
+        map.put("spark_comment", comment.get)
+        map
+      } else null
+      toArrowField(field.name, field.dataType, field.nullable, timeZoneId, metadata)
     }.asJava)
   }

   def fromArrowSchema(schema: Schema): StructType = {
     StructType(schema.getFields.asScala.map { field =>
       val dt = fromArrowField(field)
-      StructField(field.getName, dt, field.isNullable)
+      val comment = field.getMetadata.get("spark_comment")
+      if (comment == null)
+        StructField(field.getName, dt, field.isNullable)
+      else
+        StructField(field.getName, dt, field.isNullable).withComment(comment)
     }.toSeq)
   }
diff --git a/rust/lakesoul-io/src/datasource/parquet_source.rs b/rust/lakesoul-io/src/datasource/parquet_source.rs
index 96af1c273..30acbef2a 100644
--- a/rust/lakesoul-io/src/datasource/parquet_source.rs
+++ b/rust/lakesoul-io/src/datasource/parquet_source.rs
@@ -223,7 +223,7 @@ impl TableProvider for LakeSoulParquetProvider {
             let df = DataFrame::new(_state.clone(), self.plans[i].clone());
             let df = _filters
                 .iter()
-                .fold(df, |df, f| df.clone().filter(f.clone()).unwrap_or(df.clone()));
+                .fold(df, |df, f| df.clone().filter(f.clone()).unwrap_or(df));
             let df_schema = Arc::new(df.schema().clone());
             let projected_cols = schema_intersection(df_schema, projected_schema.clone(), &self.config.primary_keys);
             let df = if projected_cols.is_empty() {
@@ -290,7 +290,7 @@ impl LakeSoulParquetScanExec {
         Self {
             projections: projections.unwrap().clone(),
             origin_schema: schema.clone(),
-            projected_schema: project_schema(&schema.clone(), projections).unwrap(),
+            projected_schema: project_schema(&schema, projections).unwrap(),
             inputs,
             default_column_value,
             merge_operators,
@@ -357,7 +357,7 @@ impl ExecutionPlan for LakeSoulParquetScanExec {
                 .map(|&idx| {
                     datafusion::physical_expr::expressions::col(
                         self.origin_schema().field(idx).name(),
-                        &self.schema().clone(),
+                        &self.schema(),
                     )
                     .unwrap()
                 })
@@ -385,8 +385,8 @@ pub fn merge_stream(
     let merge_stream = if primary_keys.is_empty() {
         Box::pin(DefaultColumnStream::new_from_streams_with_default(
             streams,
-            schema.clone(),
-            default_column_value.clone(),
+            schema,
+            default_column_value,
         ))
     } else {
         let merge_schema: SchemaRef = Arc::new(Schema::new(
@@ -416,7 +416,7 @@ pub fn merge_stream(
             .collect();
         let merge_stream = SortedStreamMerger::new_from_streams(
             streams,
-            merge_schema.clone(),
+            merge_schema,
             primary_keys.iter().map(String::clone).collect(),
             batch_size,
             merge_ops,
@@ -424,8 +424,8 @@ pub fn merge_stream(
         .unwrap();
         Box::pin(DefaultColumnStream::new_from_streams_with_default(
             vec![Box::pin(merge_stream)],
-            schema.clone(),
-            default_column_value.clone(),
+            schema,
+            default_column_value,
         ))
     };
     Ok(merge_stream)
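Note on the schema-kind detection introduced in TableInfoDao: an Arrow Schema serialized via Schema#toJson starts with {"fields" : [...]}, while a Spark StructType serialized via StructType#json starts with {"type":"struct","fields":[...]}, so the character right after the first double quote is 'f' for an Arrow-style schema and 't' for a Spark-style one. Below is a minimal sketch of how the two predicates route a persisted table_schema string; the class name and the JSON literals are illustrative examples, not values taken from a real meta store.

    import com.dmetasoul.lakesoul.meta.dao.TableInfoDao;

    public class SchemaKindExample {
        public static void main(String[] args) {
            // Shape produced by Arrow's Schema#toJson: the first quoted key is "fields".
            String arrowKind = "{\"fields\" : [{\"name\" : \"id\", \"nullable\" : false}]}";
            // Shape produced by Spark's StructType#json: the first quoted key is "type".
            String sparkKind = "{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":false}]}";

            System.out.println(TableInfoDao.isArrowKindSchema(arrowKind)); // true  -> parse with Schema.fromJSON
            System.out.println(TableInfoDao.isSparkKindSchema(sparkKind)); // true  -> parse with DataType.fromJson
        }
    }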