Skip to content

Commit

Permalink
Use RETENTION_HOUR for further SQL restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
chenselena committed Oct 15, 2024
1 parent 0a1b0c6 commit 35b59a6
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,30 +43,33 @@ public void setupSpark() {

@Test
public void testSimpleSetReplicationPolicy() {
String replicationConfigJson = "{\"cluster\":\"a\", \"interval\":\"b\"}";
String replicationConfigJson = "[{\"cluster\":\"a\", \"interval\":\"24H\"}]";
Dataset<Row> ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a', interval:'b'}))");
+ "({cluster:'a', interval:24H}))");
assert isPlanValid(ds, replicationConfigJson);

// Test support with multiple clusters
replicationConfigJson =
"{\"cluster\":\"a\", \"interval\":\"b\"}, {\"cluster\":\"aa\", \"interval\":\"bb\"}";
"[{\"cluster\":\"a\", \"interval\":\"12H\"}, {\"cluster\":\"aa\", \"interval\":\"12H\"}]";
ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a', interval:'b'}, {cluster:'aa', interval:'bb'}))");
+ "({cluster:'a', interval:12h}, {cluster:'aa', interval:12H}))");
assert isPlanValid(ds, replicationConfigJson);
}

@Test
public void testSimpleSetReplicationPolicyOptionalInterval() {
// Test with optional interval
replicationConfigJson = "{\"cluster\":\"a\"}";
ds =
String replicationConfigJson = "[{\"cluster\":\"a\"}]";
Dataset<Row> ds =
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = " + "({cluster:'a'}))");
assert isPlanValid(ds, replicationConfigJson);

// Test with optional interval for multiple clusters
replicationConfigJson = "{\"cluster\":\"a\"}, {\"cluster\":\"b\"}";
replicationConfigJson = "[{\"cluster\":\"a\"}, {\"cluster\":\"b\"}]";
ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
Expand Down Expand Up @@ -108,7 +111,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval: 'bb'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval: '12h'}))")
.show());

// Missing interval value but keyword present
Expand All @@ -126,7 +129,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval:'a'}, {cluster:, interval: 'b'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval:'12H'}, {cluster:, interval: '12H'}))")
.show());

// Missing cluster keyword for multiple clusters
Expand All @@ -135,15 +138,16 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval:'a'}, {interval: 'b'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval:'a'}, {interval: '12h'}))")
.show());

// Missing cluster keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval: 'ss'}))")
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval: '12h'}))")
.show());

// Typo in keyword interval
Expand All @@ -152,7 +156,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interv: 'ss'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interv: '12h'}))")
.show());

// Typo in keyword cluster
Expand All @@ -161,7 +165,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', interval: 'ss'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', interval: '12h'}))")
.show());

// Missing quote in cluster value
Expand All @@ -170,16 +174,25 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', interval: 'ss}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', interval: '12h}))")
.show());

// Typo in REPLICATION keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', interval: '12h'}))")
.show());

// Typo in REPLICATION keyword
// Interval input does not follow 'h/H' format
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', interval: 'ss}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interval: '12'}))")
.show());

// Missing cluster and interval values
Expand Down Expand Up @@ -213,7 +226,6 @@ public void tearDownSpark() {

@SneakyThrows
private boolean isPlanValid(Dataset<Row> dataframe, String replicationConfigJson) {
replicationConfigJson = "[" + replicationConfigJson + "]";
String queryStr = dataframe.queryExecution().explainString(ExplainMode.fromString("simple"));
JsonArray jsonArray = new Gson().fromJson(replicationConfigJson, JsonArray.class);
boolean isValid = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ replicationPolicyClusterClause
;

replicationPolicyIntervalClause
: INTERVAL ':' STRING
: INTERVAL ':' RETENTION_HOUR
;

columnRetentionPolicyPatternClause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
val cluster = typedVisit[String](ctx.replicationPolicyClusterClause())
val interval = if (ctx.replicationPolicyIntervalClause() != null)
typedVisit[String](ctx.replicationPolicyIntervalClause())
else null
else
null
(cluster, Option(interval))
}

Expand All @@ -114,7 +115,7 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
}

override def visitReplicationPolicyIntervalClause(ctx: ReplicationPolicyIntervalClauseContext): (String) = {
ctx.STRING().getText
ctx.RETENTION_HOUR().getText.toUpperCase
}

override def visitColumnRetentionPolicy(ctx: ColumnRetentionPolicyContext): (String, String) = {
Expand Down Expand Up @@ -146,7 +147,7 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
}

override def visitDuration(ctx: DurationContext): (String, Int) = {
val granularity = if (ctx.RETENTION_DAY != null) {
val granularity = if (ctx.RETENTION_DAY() != null) {
TimePartitionSpec.GranularityEnum.DAY.getValue()
} else if (ctx.RETENTION_YEAR() != null) {
TimePartitionSpec.GranularityEnum.YEAR.getValue()
Expand Down

0 comments on commit 35b59a6

Please sign in to comment.