diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveConvertletTable.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveConvertletTable.java index e5a8ed35d..85fa0b659 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveConvertletTable.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveConvertletTable.java @@ -1,11 +1,10 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ package com.linkedin.coral.hive.hive2rel; -import java.util.ArrayList; import java.util.List; import com.google.common.base.Preconditions; @@ -17,7 +16,6 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.fun.SqlCastFunction; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql2rel.ReflectiveConvertletTable; import org.apache.calcite.sql2rel.SqlRexContext; import org.apache.calcite.sql2rel.SqlRexConvertlet; @@ -26,7 +24,6 @@ import com.linkedin.coral.com.google.common.collect.ImmutableList; import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; import com.linkedin.coral.hive.hive2rel.functions.HiveInOperator; -import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction; /** @@ -35,17 +32,6 @@ */ public class HiveConvertletTable extends ReflectiveConvertletTable { - @SuppressWarnings("unused") - public RexNode convertNamedStruct(SqlRexContext cx, HiveNamedStructFunction func, SqlCall call) { - List operandExpressions = new ArrayList<>(call.operandCount() / 2); - for (int i = 0; i < call.operandCount(); i += 2) { - operandExpressions.add(cx.convertExpression(call.operand(i + 1))); - } - RelDataType retType = cx.getValidator().getValidatedNodeType(call); - RexNode rowNode = cx.getRexBuilder().makeCall(retType, SqlStdOperatorTable.ROW, operandExpressions); - return cx.getRexBuilder().makeCast(retType, rowNode); - } - @SuppressWarnings("unused") public RexNode convertHiveInOperator(SqlRexContext cx, HiveInOperator operator, SqlCall call) { List operandList = call.getOperandList(); diff --git a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverterTest.java b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverterTest.java index 778dbfef4..6ab45e7bc 100644 --- a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverterTest.java +++ b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverterTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2017-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -497,12 +497,11 @@ public void testStructPeekDisallowed() { public void testStructReturnFieldAccess() { final String sql = "select named_struct('field_a', 10, 'field_b', 'abc').field_b"; RelNode rel = toRel(sql); - final String expectedRel = "LogicalProject(EXPR$0=[CAST(ROW(10, 'abc')):" - + "RecordType(INTEGER NOT NULL field_a, CHAR(3) NOT NULL field_b) NOT NULL.field_b])\n" + final String expectedRel = "LogicalProject(EXPR$0=[named_struct('field_a', 10, 'field_b', 'abc').field_b])\n" + " LogicalValues(tuples=[[{ 0 }]])\n"; assertEquals(relToStr(rel), expectedRel); - final String expectedSql = "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b\n" - + "FROM (VALUES (0)) t (ZERO)"; + final String expectedSql = + "SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b\n" + "FROM (VALUES (0)) t (ZERO)"; assertEquals(relToHql(rel), expectedSql); } diff --git a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/NamedStructTest.java b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/NamedStructTest.java index 58c32ff98..12c4410e2 100644 --- a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/NamedStructTest.java +++ b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/NamedStructTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -44,8 +44,7 @@ public void testMixedTypes() { final String sql = "SELECT named_struct('abc', 123, 'def', 'xyz')"; RelNode rel = toRel(sql); final String generated = relToStr(rel); - final String expected = "" - + "LogicalProject(EXPR$0=[CAST(ROW(123, 'xyz')):RecordType(INTEGER NOT NULL abc, CHAR(3) NOT NULL def) NOT NULL])\n" + final String expected = "" + "LogicalProject(EXPR$0=[named_struct('abc', 123, 'def', 'xyz')])\n" + " LogicalValues(tuples=[[{ 0 }]])\n"; assertEquals(generated, expected); } @@ -54,9 +53,8 @@ public void testMixedTypes() { public void testNullFieldValue() { final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', 150)"; final String generated = sqlToRelStr(sql); - final String expected = - "LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, 150)):RecordType(INTEGER abc, INTEGER NOT NULL def) NOT NULL])\n" - + " LogicalValues(tuples=[[{ 0 }]])\n"; + final String expected = "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', 150)])\n" + + " LogicalValues(tuples=[[{ 0 }]])\n"; assertEquals(generated, expected); } @@ -65,7 +63,7 @@ public void testAllNullValues() { final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', cast(NULL as double))"; final String generated = sqlToRelStr(sql); final String expected = - "LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, CAST(null:NULL):DOUBLE)):RecordType(INTEGER abc, DOUBLE def) NOT NULL])\n" + "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', CAST(null:NULL):DOUBLE)])\n" + " LogicalValues(tuples=[[{ 0 }]])\n"; assertEquals(generated, expected); } @@ -74,10 +72,9 @@ public void testAllNullValues() { public void testNestedComplexTypes() { final String sql = "SELECT named_struct('arr', array(10, 15), 's', named_struct('f1', 123, 'f2', array(20.5)))"; final String generated = sqlToRelStr(sql); - final String expected = "LogicalProject(EXPR$0=[CAST(ROW(ARRAY(10, 15), CAST(ROW(123, ARRAY(20.5:DECIMAL(3, 1)))):" - + "RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL)):" - + "RecordType(INTEGER NOT NULL ARRAY NOT NULL arr, RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL s) NOT NULL])\n" - + " LogicalValues(tuples=[[{ 0 }]])\n"; + final String expected = + "LogicalProject(EXPR$0=[named_struct('arr', ARRAY(10, 15), 's', named_struct('f1', 123, 'f2', ARRAY(20.5:DECIMAL(3, 1))))])\n" + + " LogicalValues(tuples=[[{ 0 }]])\n"; // verified by human that expected string is correct and retained here to protect from future changes assertEquals(generated, expected); } diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java index 8db5cca72..928e51c9d 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java @@ -15,6 +15,7 @@ import com.linkedin.coral.common.utils.TypeDerivationUtil; import com.linkedin.coral.hive.hive2rel.HiveToRelConverter; import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer; +import com.linkedin.coral.trino.rel2trino.transformers.NamedStructOperandTransformer; /** @@ -31,7 +32,8 @@ public class DataTypeDerivedSqlCallConverter extends SqlShuttle { public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) { SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator(); TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode); - operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil)); + operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil), + new NamedStructOperandTransformer(typeDerivationUtil)); } @Override diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/NamedStructOperandTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/NamedStructOperandTransformer.java new file mode 100644 index 000000000..89ceeb14e --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/NamedStructOperandTransformer.java @@ -0,0 +1,61 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transformers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlRowTypeSpec; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeUtil; + +import com.linkedin.coral.common.transformers.SqlCallTransformer; +import com.linkedin.coral.common.utils.TypeDerivationUtil; +import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction; + +import static org.apache.calcite.sql.parser.SqlParserPos.ZERO; + + +/** + * Converts Coral's named_struct function to CAST AS ROW(types) function. + */ +public class NamedStructOperandTransformer extends SqlCallTransformer { + + public NamedStructOperandTransformer(TypeDerivationUtil typeDerivationUtil) { + super(typeDerivationUtil); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return sqlCall.getOperator().equals(HiveNamedStructFunction.NAMED_STRUCT); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List inputOperands = sqlCall.getOperandList(); + + List rowTypes = new ArrayList<>(); + List fieldNames = new ArrayList<>(); + for (int i = 0; i < inputOperands.size(); i += 2) { + assert inputOperands.get(i) instanceof SqlLiteral; + fieldNames.add(((SqlLiteral) inputOperands.get(i)).getStringValue()); + } + + List rowCallOperands = new ArrayList<>(); + for (int i = 1; i < inputOperands.size(); i += 2) { + rowCallOperands.add(inputOperands.get(i)); + RelDataType type = deriveRelDatatype(inputOperands.get(i)); + rowTypes.add(SqlTypeUtil.convertTypeToSpec(type)); + } + SqlNode rowCall = SqlStdOperatorTable.ROW.createCall(ZERO, rowCallOperands); + return SqlStdOperatorTable.CAST.createCall(ZERO, rowCall, new SqlRowTypeSpec(fieldNames, rowTypes, ZERO)); + } +}