[Coral-Hive] [Coral-Trino] Make named_struct a Coral IR operator and Migrate GenericProject Function #431

Merged · 10 commits · Jul 10, 2023
@@ -1,11 +1,10 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;
@@ -17,7 +16,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.ReflectiveConvertletTable;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
@@ -26,7 +24,6 @@
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveInOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
@@ -35,17 +32,6 @@
*/
public class HiveConvertletTable extends ReflectiveConvertletTable {

@SuppressWarnings("unused")
public RexNode convertNamedStruct(SqlRexContext cx, HiveNamedStructFunction func, SqlCall call) {
List<RexNode> operandExpressions = new ArrayList<>(call.operandCount() / 2);
for (int i = 0; i < call.operandCount(); i += 2) {
operandExpressions.add(cx.convertExpression(call.operand(i + 1)));
}
RelDataType retType = cx.getValidator().getValidatedNodeType(call);
RexNode rowNode = cx.getRexBuilder().makeCall(retType, SqlStdOperatorTable.ROW, operandExpressions);
return cx.getRexBuilder().makeCast(retType, rowNode);
}

@SuppressWarnings("unused")
public RexNode convertHiveInOperator(SqlRexContext cx, HiveInOperator operator, SqlCall call) {
List<SqlNode> operandList = call.getOperandList();
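The convertlet removed above expanded named_struct into CAST(ROW(...)) while the Rel plan was being built; after this change, named_struct is kept as a Coral IR operator and the cast rewrite is deferred to the Trino-specific transformers further down. The operator definition itself is not part of this diff, so the snippet below is only an illustrative sketch, assuming a Calcite-style SqlReturnTypeInference, of how a struct return type could be derived from the alternating name/value operands of named_struct.

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlReturnTypeInference;

public class NamedStructReturnTypeSketch {
  // Hypothetical helper, not the actual Coral operator definition: derives
  // RecordType(field_a T1, field_b T2, ...) from operands laid out as
  // ('field_a', value_a, 'field_b', value_b, ...).
  public static final SqlReturnTypeInference NAMED_STRUCT_RETURN_TYPE = opBinding -> {
    RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
    List<String> fieldNames = new ArrayList<>();
    List<RelDataType> fieldTypes = new ArrayList<>();
    for (int i = 0; i < opBinding.getOperandCount(); i += 2) {
      // Even operands are the field name literals, odd operands are the field values.
      fieldNames.add(opBinding.getOperandLiteralValue(i, String.class));
      fieldTypes.add(opBinding.getOperandType(i + 1));
    }
    return typeFactory.createStructType(fieldTypes, fieldNames);
  };
}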
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
@@ -497,12 +497,11 @@ public void testStructPeekDisallowed() {
public void testStructReturnFieldAccess() {
final String sql = "select named_struct('field_a', 10, 'field_b', 'abc').field_b";
RelNode rel = toRel(sql);
final String expectedRel = "LogicalProject(EXPR$0=[CAST(ROW(10, 'abc')):"
+ "RecordType(INTEGER NOT NULL field_a, CHAR(3) NOT NULL field_b) NOT NULL.field_b])\n"
final String expectedRel = "LogicalProject(EXPR$0=[named_struct('field_a', 10, 'field_b', 'abc').field_b])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(relToStr(rel), expectedRel);
final String expectedSql = "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b\n"
+ "FROM (VALUES (0)) t (ZERO)";
final String expectedSql =
"SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b\n" + "FROM (VALUES (0)) t (ZERO)";
assertEquals(relToHql(rel), expectedSql);
}

@@ -1,5 +1,5 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
@@ -44,8 +44,7 @@ public void testMixedTypes() {
final String sql = "SELECT named_struct('abc', 123, 'def', 'xyz')";
RelNode rel = toRel(sql);
final String generated = relToStr(rel);
final String expected = ""
+ "LogicalProject(EXPR$0=[CAST(ROW(123, 'xyz')):RecordType(INTEGER NOT NULL abc, CHAR(3) NOT NULL def) NOT NULL])\n"
final String expected = "" + "LogicalProject(EXPR$0=[named_struct('abc', 123, 'def', 'xyz')])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
@@ -54,9 +53,8 @@ public void testMixedTypes() {
public void testNullFieldValue() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', 150)";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, 150)):RecordType(INTEGER abc, INTEGER NOT NULL def) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected = "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', 150)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}

@@ -65,7 +63,7 @@ public void testAllNullValues() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', cast(NULL as double))";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, CAST(null:NULL):DOUBLE)):RecordType(INTEGER abc, DOUBLE def) NOT NULL])\n"
"LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', CAST(null:NULL):DOUBLE)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
@@ -74,10 +72,9 @@ public void testAllNullValues() {
public void testNestedComplexTypes() {
final String sql = "SELECT named_struct('arr', array(10, 15), 's', named_struct('f1', 123, 'f2', array(20.5)))";
final String generated = sqlToRelStr(sql);
final String expected = "LogicalProject(EXPR$0=[CAST(ROW(ARRAY(10, 15), CAST(ROW(123, ARRAY(20.5:DECIMAL(3, 1)))):"
+ "RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL)):"
+ "RecordType(INTEGER NOT NULL ARRAY NOT NULL arr, RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL s) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected =
"LogicalProject(EXPR$0=[named_struct('arr', ARRAY(10, 15), 's', named_struct('f1', 123, 'f2', ARRAY(20.5:DECIMAL(3, 1))))])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
// verified by human that expected string is correct and retained here to protect from future changes
assertEquals(generated, expected);
}
@@ -40,8 +40,6 @@
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
import static org.apache.calcite.sql.type.ReturnTypes.explicit;
@@ -160,14 +158,6 @@ public TrinoRexConverter(RelNode node, Map<String, Boolean> configs) {

@Override
public RexNode visitCall(RexCall call) {
// GenericProject requires a nontrivial function rewrite because of the following:
// - makes use of Trino built-in UDFs transform_values for map objects and transform for array objects
// which has lambda functions as parameters
// - syntax is difficult for Calcite to parse
// - the return type varies based on a desired schema to be projected
if (call.getOperator() instanceof GenericProjectFunction) {
return GenericProjectToTrinoConverter.convertGenericProject(rexBuilder, call, node);
}

final String operatorName = call.getOperator().getName();

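The special-case handling removed here is not dropped from Coral; per this PR it moves out of the RexNode visitor and into a SqlCall-level transformer (GenericProjectTransformer, registered in DataTypeDerivedSqlCallConverter below). As a rough, hypothetical illustration of the target shapes described in the deleted comment, and not the transformer's actual output: a struct column is projected through a Trino CAST to the desired row type, while map and array columns are rewritten through transform_values and transform with lambda arguments.

public class GenericProjectTargetShapes {
  // Hypothetical Trino expressions for illustration only; the column and field names are invented.
  // Struct case: project the desired subset of fields and cast to the target row type.
  static final String STRUCT_SHAPE = "CAST(ROW(s.f1) AS ROW(f1 INTEGER))";
  // Map case: rewrite each value with Trino's transform_values and a lambda parameter.
  static final String MAP_SHAPE = "transform_values(m, (k, v) -> CAST(ROW(v.f1) AS ROW(f1 INTEGER)))";
  // Array case: rewrite each element with Trino's transform and a lambda parameter.
  static final String ARRAY_SHAPE = "transform(a, x -> CAST(ROW(x.f1) AS ROW(f1 INTEGER)))";
}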
@@ -15,6 +15,8 @@
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;


/**
@@ -31,7 +33,8 @@ public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) {
SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator();
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil));
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil));
}

@Override
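With the convertlet gone, named_struct now reaches the Trino layer unexpanded, and NamedStructToCastTransformer is expected to produce the cast form that the hive2rel test above previously asserted directly. A minimal before/after sketch of the intended rewrite, reusing the example from that test (the transformer's internals are not shown in this diff):

public class NamedStructRewriteSketch {
  // Before: Coral IR keeps the Hive function call.
  static final String CORAL_SQL =
      "SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b FROM (VALUES (0)) t (ZERO)";
  // After: a ROW constructor cast to the derived row type, matching the expected SQL
  // that the test asserted before this change.
  static final String TRINO_SQL =
      "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b FROM (VALUES (0)) t (ZERO)";
}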
@@ -21,7 +21,7 @@
* If a column, colA, has a RelDataType, relDataTypeA, with a Trino type string, trinoTypeStringA = buildStructDataTypeString(relDataTypeA),
* then the following operation is syntactically and semantically correct in Trino: CAST(colA as trinoTypeStringA)
*/
class RelDataTypeToTrinoTypeStringConverter {
public class RelDataTypeToTrinoTypeStringConverter {
Contributor:

May I know why these three classes, including TrinoMapTransformValuesFunction and TrinoStructCastRowFunction, were changed to public? I don't see any usage of this class in this PR.

Contributor Author:

These classes are used in GenericProjectTransformer, for example here. Previously, GenericProjectTransformer was in the same package as RelDataTypeToTrinoTypeStringConverter, but it has now been moved to another package.

private RelDataTypeToTrinoTypeStringConverter() {
}

@@ -19,7 +19,7 @@
Instead, we represent the input to this UDF as a string, and its return type is passed as a parameter
on creation.
*/
class TrinoMapTransformValuesFunction extends GenericTemplateFunction {
public class TrinoMapTransformValuesFunction extends GenericTemplateFunction {
public TrinoMapTransformValuesFunction(RelDataType transformValuesDataType) {
super(transformValuesDataType, "transform_values");
}
@@ -19,7 +19,7 @@
Instead, we represent the input to this UDF as a string, and its return type is passed as a parameter
on creation.
*/
class TrinoStructCastRowFunction extends GenericTemplateFunction {
public class TrinoStructCastRowFunction extends GenericTemplateFunction {
public TrinoStructCastRowFunction(RelDataType structDataType) {
super(structDataType, "cast");
}
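For context on the visibility change discussed in the review thread above, here is a minimal sketch of how a transformer in another package might construct these now-public operators once a target RelDataType has been derived. The type-factory setup is illustrative only (not how GenericProjectTransformer actually obtains its types), and the import package is assumed from the removed GenericProjectToTrinoConverter import earlier in this diff.

import java.util.Arrays;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;

import com.linkedin.coral.trino.rel2trino.functions.TrinoMapTransformValuesFunction;
import com.linkedin.coral.trino.rel2trino.functions.TrinoStructCastRowFunction;

public class TemplateFunctionUsageSketch {
  public static void main(String[] args) {
    RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
    // ROW(a INTEGER, b VARCHAR): the row type a struct column should be cast to.
    RelDataType rowType = typeFactory.createStructType(
        Arrays.asList(typeFactory.createSqlType(SqlTypeName.INTEGER),
            typeFactory.createSqlType(SqlTypeName.VARCHAR)),
        Arrays.asList("a", "b"));
    // MAP<VARCHAR, ROW(a INTEGER, b VARCHAR)>: the type produced by transform_values.
    RelDataType mapType =
        typeFactory.createMapType(typeFactory.createSqlType(SqlTypeName.VARCHAR), rowType);

    // With the classes now public, transformers outside this package can build the operators directly.
    TrinoStructCastRowFunction castRow = new TrinoStructCastRowFunction(rowType);
    TrinoMapTransformValuesFunction transformValues = new TrinoMapTransformValuesFunction(mapType);
    System.out.println(castRow + ", " + transformValues);
  }
}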