Skip to content

Commit

Permalink
initial commit for namedstruct from PR#412
Browse files Browse the repository at this point in the history
  • Loading branch information
wmoustafa authored and aastha25 committed Jun 27, 2023
1 parent b2ba2e9 commit dc5753a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;
Expand All @@ -17,7 +16,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.ReflectiveConvertletTable;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
Expand All @@ -26,7 +24,6 @@
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveInOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
Expand All @@ -35,17 +32,6 @@
*/
public class HiveConvertletTable extends ReflectiveConvertletTable {

@SuppressWarnings("unused")
public RexNode convertNamedStruct(SqlRexContext cx, HiveNamedStructFunction func, SqlCall call) {
List<RexNode> operandExpressions = new ArrayList<>(call.operandCount() / 2);
for (int i = 0; i < call.operandCount(); i += 2) {
operandExpressions.add(cx.convertExpression(call.operand(i + 1)));
}
RelDataType retType = cx.getValidator().getValidatedNodeType(call);
RexNode rowNode = cx.getRexBuilder().makeCall(retType, SqlStdOperatorTable.ROW, operandExpressions);
return cx.getRexBuilder().makeCast(retType, rowNode);
}

@SuppressWarnings("unused")
public RexNode convertHiveInOperator(SqlRexContext cx, HiveInOperator operator, SqlCall call) {
List<SqlNode> operandList = call.getOperandList();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -497,12 +497,11 @@ public void testStructPeekDisallowed() {
public void testStructReturnFieldAccess() {
final String sql = "select named_struct('field_a', 10, 'field_b', 'abc').field_b";
RelNode rel = toRel(sql);
final String expectedRel = "LogicalProject(EXPR$0=[CAST(ROW(10, 'abc')):"
+ "RecordType(INTEGER NOT NULL field_a, CHAR(3) NOT NULL field_b) NOT NULL.field_b])\n"
final String expectedRel = "LogicalProject(EXPR$0=[named_struct('field_a', 10, 'field_b', 'abc').field_b])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(relToStr(rel), expectedRel);
final String expectedSql = "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b\n"
+ "FROM (VALUES (0)) t (ZERO)";
final String expectedSql =
"SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b\n" + "FROM (VALUES (0)) t (ZERO)";
assertEquals(relToHql(rel), expectedSql);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -44,8 +44,7 @@ public void testMixedTypes() {
final String sql = "SELECT named_struct('abc', 123, 'def', 'xyz')";
RelNode rel = toRel(sql);
final String generated = relToStr(rel);
final String expected = ""
+ "LogicalProject(EXPR$0=[CAST(ROW(123, 'xyz')):RecordType(INTEGER NOT NULL abc, CHAR(3) NOT NULL def) NOT NULL])\n"
final String expected = "" + "LogicalProject(EXPR$0=[named_struct('abc', 123, 'def', 'xyz')])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -54,9 +53,8 @@ public void testMixedTypes() {
public void testNullFieldValue() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', 150)";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, 150)):RecordType(INTEGER abc, INTEGER NOT NULL def) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected = "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', 150)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}

Expand All @@ -65,7 +63,7 @@ public void testAllNullValues() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', cast(NULL as double))";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, CAST(null:NULL):DOUBLE)):RecordType(INTEGER abc, DOUBLE def) NOT NULL])\n"
"LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', CAST(null:NULL):DOUBLE)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -74,10 +72,9 @@ public void testAllNullValues() {
public void testNestedComplexTypes() {
final String sql = "SELECT named_struct('arr', array(10, 15), 's', named_struct('f1', 123, 'f2', array(20.5)))";
final String generated = sqlToRelStr(sql);
final String expected = "LogicalProject(EXPR$0=[CAST(ROW(ARRAY(10, 15), CAST(ROW(123, ARRAY(20.5:DECIMAL(3, 1)))):"
+ "RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL)):"
+ "RecordType(INTEGER NOT NULL ARRAY NOT NULL arr, RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL s) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected =
"LogicalProject(EXPR$0=[named_struct('arr', ARRAY(10, 15), 's', named_struct('f1', 123, 'f2', ARRAY(20.5:DECIMAL(3, 1))))])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
// verified by human that expected string is correct and retained here to protect from future changes
assertEquals(generated, expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructOperandTransformer;


/**
Expand All @@ -31,7 +32,8 @@ public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) {
SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator();
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil));
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new NamedStructOperandTransformer(typeDerivationUtil));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlRowTypeSpec;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeUtil;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;

import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;


/**
* Converts Coral's named_struct function to CAST AS ROW(types) function.
*/
public class NamedStructOperandTransformer extends SqlCallTransformer {

public NamedStructOperandTransformer(TypeDerivationUtil typeDerivationUtil) {
super(typeDerivationUtil);
}

@Override
protected boolean condition(SqlCall sqlCall) {
return sqlCall.getOperator().equals(HiveNamedStructFunction.NAMED_STRUCT);
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
List<SqlNode> inputOperands = sqlCall.getOperandList();

List<SqlDataTypeSpec> rowTypes = new ArrayList<>();
List<String> fieldNames = new ArrayList<>();
for (int i = 0; i < inputOperands.size(); i += 2) {
assert inputOperands.get(i) instanceof SqlLiteral;
fieldNames.add(((SqlLiteral) inputOperands.get(i)).getStringValue());
}

List<SqlNode> rowCallOperands = new ArrayList<>();
for (int i = 1; i < inputOperands.size(); i += 2) {
rowCallOperands.add(inputOperands.get(i));
RelDataType type = deriveRelDatatype(inputOperands.get(i));
rowTypes.add(SqlTypeUtil.convertTypeToSpec(type));
}
SqlNode rowCall = SqlStdOperatorTable.ROW.createCall(ZERO, rowCallOperands);
return SqlStdOperatorTable.CAST.createCall(ZERO, rowCall, new SqlRowTypeSpec(fieldNames, rowTypes, ZERO));
}
}

0 comments on commit dc5753a

Please sign in to comment.