Skip to content

Commit

Permalink
Initial changes
Browse files Browse the repository at this point in the history
  • Loading branch information
chenselena committed Oct 9, 2024
1 parent a5c00d9 commit ad12bbd
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,46 +43,39 @@ public void setupSpark() {

@Test
public void testSimpleSetReplicationPolicy() {
String replicationConfigJson =
"{\"cluster\":\"a\", \"schedule\":\"b\"}, {\"cluster\": \"aa\", \"schedule\": \"bb\"}";
String replicationConfigJson = "{\"cluster\":\"a\", \"interval\":\"b\"}";
Dataset<Row> ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a', schedule:'b'}, {cluster: 'aa', schedule: 'bb'}))");
+ "({cluster:'a', interval:'b'}))");
assert isPlanValid(ds, replicationConfigJson);
}

replicationConfigJson = "{\"cluster\":\"a\", \"schedule\":\"b\"}";
ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:'a', schedule:'b'}))");
@Test
public void testSimpleSetReplicationPolicyOptionalInterval() {
String replicationConfigJson = "{\"cluster\":\"a\"}";
Dataset<Row> ds =
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = " + "({cluster:'a'}))");
assert isPlanValid(ds, replicationConfigJson);
}

@Test
public void testReplicationPolicyWithoutProperSyntax() {
// missing schedule keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa'}))")
.show());

// Missing cluster keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({schedule: 'ss'}))")
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval: 'ss'}))")
.show());

// Typo in keyword schedule
// Typo in keyword interval
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', schedul: 'ss'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interv: 'ss'}))")
.show());

// Typo in keyword cluster
Expand All @@ -91,7 +84,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', schedule: 'ss'}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', interval: 'ss'}))")
.show());

// Missing quote in cluster value
Expand All @@ -100,7 +93,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', schedule: 'ss}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', interval: 'ss}))")
.show());

// Type in REPLICATION keyword
Expand All @@ -109,7 +102,7 @@ public void testReplicationPolicyWithoutProperSyntax() {
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', schedule: 'ss}))")
"ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', interval: 'ss}))")
.show());

// Missing cluster and schedule value
Expand Down Expand Up @@ -150,8 +143,11 @@ private boolean isPlanValid(Dataset<Row> dataframe, String replicationConfigJson
for (JsonElement element : jsonArray) {
JsonObject entry = element.getAsJsonObject();
String cluster = entry.get("cluster").getAsString();
String schedule = entry.get("schedule").getAsString();
isValid = queryStr.contains(cluster) && queryStr.contains(schedule);
isValid = queryStr.contains(cluster);
if (entry.has("interval")) {
String interval = entry.get("interval").getAsString();
isValid = queryStr.contains(cluster) && queryStr.contains(interval);
}
}
return isValid;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,15 @@ replicationPolicy
;

tableReplicationPolicy
: '(' replicationPolicyClause (',' replicationPolicyClause)* ')'
: '(' '{' replicationPolicyClusterClause (',' replicationPolicyIntervalClause)? '}' ')'
;

replicationPolicyClause
: '{' CLUSTER ':' STRING ',' SCHEDULE ':' STRING '}'
replicationPolicyClusterClause
: CLUSTER ':' STRING
;

replicationPolicyIntervalClause
: INTERVAL ':' STRING
;

columnRetentionPolicyPatternClause
Expand Down Expand Up @@ -165,7 +169,7 @@ SHOW: 'SHOW';
GRANTS: 'GRANTS';
PATTERN: 'PATTERN';
CLUSTER: 'CLUSTER';
SCHEDULE: 'SCHEDULE';
INTERVAL: 'INTERVAL';
WHERE: 'WHERE';
COLUMN: 'COLUMN';
PII: 'PII';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh

override def visitSetReplicationPolicy(ctx: SetReplicationPolicyContext): SetReplicationPolicy = {
val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
val replicationPolicies = typedVisit[Seq[String]](ctx.replicationPolicy())
SetReplicationPolicy(tableName, replicationPolicies)
val (clusterName, interval) = typedVisit[(String, String)](ctx.replicationPolicy())
SetReplicationPolicy(tableName, clusterName, interval)
}

// override def visitSetReplicationPolicy(ctx: SetReplicationPolicyContext): SetReplicationPolicy = {
// val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
// val replicationPolicies = typedVisit[Seq[String]](ctx.replicationPolicy())
// SetReplicationPolicy(tableName, replicationPolicies)
// }

override def visitSetSharingPolicy(ctx: SetSharingPolicyContext): SetSharingPolicy = {
val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
val sharing = typedVisit[String](ctx.sharingPolicy())
Expand Down Expand Up @@ -93,19 +99,45 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
typedVisit[(String, Int)](ctx.duration())
}

override def visitReplicationPolicy(ctx: ReplicationPolicyContext): (Seq[String]) = {
typedVisit[(Seq[String])](ctx.tableReplicationPolicy())
override def visitReplicationPolicy(ctx: ReplicationPolicyContext): (String, String) = {
typedVisit[(String, String)](ctx.tableReplicationPolicy())
}

override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): (Seq[String]) = {
ctx.replicationPolicyClause().map(ele => typedVisit[String](ele)).toSeq
override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): (String, String) = {
val clusterName = typedVisit[String](ctx.replicationPolicyClusterClause())
val interval = if (ctx.replicationPolicyIntervalClause() != null)
typedVisit[String](ctx.replicationPolicyIntervalClause())
else null
(clusterName, interval)
}

override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): (String) = {
val replicationPolicy = ctx.STRING().map(_.getText)
replicationPolicy.mkString(":")
// override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): (String, String) = {
// val clusterName = typedVisit[String](ctx.replicationPolicyClusterClause())
// val interval = typedVisit[String](ctx.replicationPolicyIntervalClause())
// (clusterName, interval)
// }

override def visitReplicationPolicyClusterClause(ctx: ReplicationPolicyClusterClauseContext): (String) = {
ctx.STRING().getText
}

override def visitReplicationPolicyIntervalClause(ctx: ReplicationPolicyIntervalClauseContext): (String) = {
ctx.STRING().getText
}

// override def visitReplicationPolicy(ctx: ReplicationPolicyContext): (Seq[String]) = {
// typedVisit[(Seq[String])](ctx.tableReplicationPolicy())
// }
//
// override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): (Seq[String]) = {
// ctx.replicationPolicyClause().map(ele => typedVisit[String](ele)).toSeq
// }
//
// override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): (String) = {
// val replicationPolicy = ctx.STRING().map(_.getText)
// replicationPolicy.mkString(":")
// }

override def visitColumnRetentionPolicy(ctx: ColumnRetentionPolicyContext): (String, String) = {
if (ctx.columnRetentionPolicyPatternClause() != null) {
(ctx.columnNameClause().identifier().getText(), ctx.columnRetentionPolicyPatternClause().retentionColumnPatternClause().STRING().getText)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package com.linkedin.openhouse.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.plans.logical.Command

case class SetReplicationPolicy(tableName: Seq[String], replicationPolicies: Seq[String]) extends Command {
case class SetReplicationPolicy(tableName: Seq[String], clusterName: String, interval: String) extends Command {
override def simpleString(maxFields: Int): String = {
s"SetReplicationPolicy: ${tableName} ${replicationPolicies}}"
s"SetReplicationPolicy: ${tableName} ${clusterName} ${interval}"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ case class OpenhouseDataSourceV2Strategy(spark: SparkSession) extends Strategy w
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case SetRetentionPolicy(CatalogAndIdentifierExtractor(catalog, ident), granularity, count, colName, colPattern) =>
SetRetentionPolicyExec(catalog, ident, granularity, count, colName, colPattern) :: Nil
case SetReplicationPolicy(CatalogAndIdentifierExtractor(catalog, ident), replicationPolicies) =>
SetReplicationPolicyExec(catalog, ident, replicationPolicies) :: Nil
case SetReplicationPolicy(CatalogAndIdentifierExtractor(catalog, ident), clusterName, interval) =>
SetReplicationPolicyExec(catalog, ident, clusterName, interval) :: Nil
case SetSharingPolicy(CatalogAndIdentifierExtractor(catalog, ident), sharing) =>
SetSharingPolicyExec(catalog, ident, sharing) :: Nil
case SetColumnPolicyTag(CatalogAndIdentifierExtractor(catalog, ident), policyTag, cols) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec

case class SetReplicationPolicyExec(catalog: TableCatalog, ident: Identifier, replicationPolicies: Seq[String]) extends V2CommandExec{
case class SetReplicationPolicyExec(catalog: TableCatalog, ident: Identifier, clusterName: String, interval: String) extends V2CommandExec{
override protected def run(): Seq[InternalRow] = {
catalog.loadTable(ident) match {
case iceberg: SparkTable if iceberg.table().properties().containsKey("openhouse.tableId") =>
val key = "updated.openhouse.policy"
val value = s"""{"replication":{"schedules":[${replicationPolicies.map(replication => s"""{config:{${replication}}}""").mkString(",")}]}}"""
val value = s"""{"replication":{"":[{}]}"""
iceberg.table().updateProperties()
.set(key, value)
.commit()
Expand Down

0 comments on commit ad12bbd

Please sign in to comment.