diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java index 8f71bf0cb..f922d1cb4 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java @@ -112,6 +112,18 @@ protected R visit(ASTNode node, C ctx) { case HiveParser.TOK_ORDERBY: return visitOrderBy(node, ctx); + case HiveParser.TOK_GROUPING_SETS: + return visitGroupingSets(node,ctx); + + case HiveParser.TOK_ROLLUP_GROUPBY: + return visitRollUpGroupBy(node,ctx); + + case HiveParser.TOK_CUBE_GROUPBY: + return visitCubeGroupBy(node,ctx); + + case HiveParser.TOK_GROUPING_SETS_EXPRESSION: + return visitGroupingSetsExpression(node,ctx); + case HiveParser.TOK_TABSORTCOLNAMEASC: return visitSortColNameAsc(node, ctx); @@ -434,6 +446,20 @@ protected R visitOrderBy(ASTNode node, C ctx) { return visitChildren(node, ctx).get(0); } + protected R visitGroupingSets(ASTNode node, C ctx){ + return visitChildren(node, ctx).get(0); + } + + protected R visitRollUpGroupBy(ASTNode node , C ctx){ + return visitChildren(node, ctx).get(0); + } + + protected R visitCubeGroupBy(ASTNode node,C ctx){ + return visitChildren(node, ctx).get(0); + } + protected R visitGroupingSetsExpression(ASTNode node, C ctx){ + return visitChildren(node, ctx).get(0); + } protected R visitGroupBy(ASTNode node, C ctx) { return visitChildren(node, ctx).get(0); } diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java index 8bf1ac4a0..7f66310a6 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java @@ -16,6 +16,7 @@ import javax.annotation.Nullable; +import com.linkedin.coral.common.calcite.CalciteUtil; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.sql.JoinConditionType; import org.apache.calcite.sql.JoinType; @@ -63,6 +64,7 @@ import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.*; import static org.apache.calcite.sql.parser.SqlParserPos.ZERO; @@ -487,6 +489,84 @@ protected SqlNode visitGroupBy(ASTNode node, ParseContext ctx) { return ctx.grpBy; } + + + @Override + protected SqlNode visitGroupingSets(ASTNode node, ParseContext ctx){ + // When Hive recognizes a grouping set, it omits the group by node in the constructed AST. + // However, Calcite requires the grouping set to be within a group by node when building SqlNode. + // Therefore, we need to add a group by node. + if(ctx.grpBy == null){ + ctx.grpBy = new SqlNodeList(new ArrayList(), ZERO); + } + + List groupingSets = visitChildren(node, ctx); + List operands = new ArrayList<>(); + if (groupingSets.isEmpty()) { + operands.add(CalciteUtil.createSqlNodeList(Collections.emptyList())); + } else { + // In Hive, fields used in GROUPING SETS must be declared after the GROUP BY clause. + // However, when constructing Calcite SqlNode, the SqlIdentifier after GROUP BY isn't required; + // only the child nodes of TOK_GROUPING_SETS are needed, so they are filtered out. + List operand = groupingSets.stream() + .filter(f -> !(f instanceof SqlIdentifier)) + .collect(Collectors.toList()); + operands.add(new SqlNodeList(operand, ZERO)); + } + SqlNode groupingSetsNode = GROUPING_SETS.createCall(ZERO, operands); + ctx.grpBy.add(groupingSetsNode); + + return groupingSetsNode; + } + + @Override + protected SqlNode visitGroupingSetsExpression(ASTNode node, ParseContext ctx){ + List identifiers = visitChildren(node, ctx); + if (identifiers == null || identifiers.isEmpty()){ + return CalciteUtil.createSqlNodeList(Collections.emptyList()); + }else{ + return ROW.createCall(ZERO, new SqlNodeList(identifiers, ZERO)); + } + } + + @Override + protected SqlNode visitCubeGroupBy(ASTNode node, ParseContext ctx){ + if(ctx.grpBy == null){ + ctx.grpBy = new SqlNodeList(new ArrayList(), ZERO); + } + + List cubeGroupby = visitChildren(node, ctx); + List operands = new ArrayList<>(); + if (cubeGroupby.isEmpty()) { + operands.add(CalciteUtil.createSqlNodeList(Collections.emptyList())); + } else { + operands.add(new SqlNodeList(cubeGroupby, ZERO)); + } + SqlNode cubeCall = CUBE.createCall(ZERO, operands); + ctx.grpBy.add(cubeCall); + + return cubeCall; + } + + @Override + protected SqlNode visitRollUpGroupBy(ASTNode node, ParseContext ctx){ + if(ctx.grpBy == null){ + ctx.grpBy = new SqlNodeList(new ArrayList(), ZERO); + } + + List rollupGroupby = visitChildren(node, ctx); + List operands = new ArrayList<>(); + if (rollupGroupby.isEmpty()) { + operands.add(CalciteUtil.createSqlNodeList(Collections.emptyList())); + } else { + operands.add(new SqlNodeList(rollupGroupby, ZERO)); + } + SqlNode rollupCall = ROLLUP.createCall(ZERO, operands); + ctx.grpBy.add(rollupCall); + + return rollupCall; + } + @Override protected SqlNode visitOperator(ASTNode node, ParseContext ctx) { ArrayList children = node.getChildren(); diff --git a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java index 6deded3ca..57776d7c5 100644 --- a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java +++ b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java @@ -195,7 +195,16 @@ public Iterator getValidateSql() { "SELECT CASE WHEN `a` THEN 10 WHEN `b` THEN 20 ELSE 30 END FROM `foo`"), ImmutableList.of("SELECT named_struct('abc', 123, 'def', 234.23) FROM foo", "SELECT `named_struct`('abc', 123, 'def', 234.23) FROM `foo`"), - ImmutableList.of("SELECT 0L FROM foo", "SELECT 0 FROM `foo`")); + ImmutableList.of("SELECT 0L FROM foo", "SELECT 0 FROM `foo`"), + //test grouping set + ImmutableList.of("select deptno, job, avg(sal) from emp group by deptno,job grouping sets((deptno,job), (deptno),())", + "SELECT `deptno`, `job`, AVG(`sal`) FROM `emp` GROUP BY GROUPING SETS(ROW(`deptno`, `job`), ROW(`deptno`), ())"), + ImmutableList.of("select deptno, job, avg(sal) from emp group by deptno,job with cube", + "SELECT `deptno`, `job`, AVG(`sal`) FROM `emp` GROUP BY CUBE(`deptno`, `job`)"), + ImmutableList.of("select deptno, job, avg(sal) from emp group by deptno,job with rollup", + "SELECT `deptno`, `job`, AVG(`sal`) FROM `emp` GROUP BY ROLLUP(`deptno`, `job`)") + + ); return convertAndValidateSql.stream().map(x -> new Object[] { x.get(0), x.get(1) }).iterator(); }