diff --git a/discoverx/explorer.py b/discoverx/explorer.py index a83ea48..1a83ffc 100644 --- a/discoverx/explorer.py +++ b/discoverx/explorer.py @@ -41,7 +41,7 @@ def _to_info_list(self, info_rows: list[Row]) -> list[TableInfo]: filtered_tables = [self._to_info_row(row) for row in info_rows] return filtered_tables - def get_tables_info(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> list[TableInfo]: + def get_tables_info(self, catalogs: str, schemas: str, tables: str, columns: list[str] = [], with_tags=False) -> list[TableInfo]: # Filter tables by matching filter - table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns) + table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags) @@ -52,7 +52,9 @@ def get_tables_info(self, catalogs: str, schemas: str, tables: str, columns: lis return self._to_info_list(filtered_tables) - def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> str: + def _get_table_list_sql( + self, catalogs: str, schemas: str, tables: str, columns: list[str] = [], with_tags=False + ) -> str: """ Returns a SQL expression which returns a list of columns matching the specified filters @@ -84,7 +86,7 @@ def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: match_any_col = "|".join([f'({c.replace("*", ".*")})' for c in columns]) columns_sql = f"""AND regexp_like(column_name, "^{match_any_col}$")""" - sql = f""" + with_column_info_sql = f""" WITH tb_list AS ( SELECT DISTINCT table_catalog, @@ -114,6 +116,44 @@ def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: GROUP BY info_schema.table_catalog, info_schema.table_schema, info_schema.table_name ), + with_column_info AS ( + SELECT + col_list.* + FROM col_list + INNER JOIN tb_list ON ( + col_list.table_catalog <=> tb_list.table_catalog AND + col_list.table_schema = tb_list.table_schema AND + col_list.table_name = tb_list.table_name) + ) + + """ + + tags_sql = f""" + , + catalog_tags AS ( + SELECT + 
info_schema.catalog_name AS table_catalog, + collect_list(struct(tag_name, tag_value)) as catalog_tags + FROM {self.information_schema}.catalog_tags info_schema + WHERE + catalog_name != "system" + {catalog_tags_sql if catalogs != "*" else ""} + GROUP BY info_schema.catalog_name + ), + + schema_tags AS ( + SELECT + info_schema.catalog_name AS table_catalog, + info_schema.schema_name AS table_schema, + collect_list(struct(tag_name, tag_value)) as schema_tags + FROM {self.information_schema}.schema_tags info_schema + WHERE + schema_name != "information_schema" + {catalog_tags_sql if catalogs != "*" else ""} + {schema_tags_sql if schemas != "*" else ""} + GROUP BY info_schema.catalog_name, info_schema.schema_name + ), + table_tags AS ( SELECT info_schema.catalog_name AS table_catalog, @@ -129,27 +169,76 @@ def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: GROUP BY info_schema.catalog_name, info_schema.schema_name, info_schema.table_name ), - with_column_info AS ( + column_tags AS ( SELECT - col_list.* - FROM col_list - INNER JOIN tb_list ON ( - col_list.table_catalog <=> tb_list.table_catalog AND - col_list.table_schema = tb_list.table_schema AND - col_list.table_name = tb_list.table_name) - ) + info_schema.catalog_name AS table_catalog, + info_schema.schema_name AS table_schema, + info_schema.table_name, + collect_list(struct(column_name, tag_name, tag_value)) as column_tags + FROM {self.information_schema}.column_tags info_schema + WHERE + schema_name != "information_schema" + {catalog_tags_sql if catalogs != "*" else ""} + {schema_tags_sql if schemas != "*" else ""} + {table_sql if tables != "*" else ""} + GROUP BY info_schema.catalog_name, info_schema.schema_name, info_schema.table_name + ), - SELECT - with_column_info.*, - table_tags.table_tags - FROM with_column_info + tags AS ( + SELECT + tb_list.table_catalog, + tb_list.table_schema, + tb_list.table_name, + catalog_tags.catalog_tags, + schema_tags.schema_tags, + 
table_tags.table_tags, + column_tags.column_tags + FROM tb_list LEFT OUTER JOIN table_tags ON ( - with_column_info.table_catalog <=> table_tags.table_catalog AND - with_column_info.table_schema = table_tags.table_schema AND - with_column_info.table_name = table_tags.table_name) + table_tags.table_catalog <=> tb_list.table_catalog AND + table_tags.table_schema = tb_list.table_schema AND + table_tags.table_name = tb_list.table_name + ) + LEFT OUTER JOIN schema_tags + ON tb_list.table_catalog <=> schema_tags.table_catalog AND tb_list.table_schema = schema_tags.table_schema + LEFT OUTER JOIN column_tags + ON tb_list.table_catalog <=> column_tags.table_catalog AND tb_list.table_schema = column_tags.table_schema AND tb_list.table_name = column_tags.table_name + LEFT OUTER JOIN catalog_tags + ON catalog_tags.table_catalog <=> tb_list.table_catalog + ) + """ + if with_tags: + sql = ( + with_column_info_sql + + tags_sql + + f""" + SELECT + with_column_info.*, + + tags.catalog_tags, + tags.schema_tags, + tags.table_tags, + tags.column_tags + FROM with_column_info + LEFT OUTER JOIN tags ON ( + with_column_info.table_catalog <=> tags.table_catalog AND + with_column_info.table_schema = tags.table_schema AND + with_column_info.table_name = tags.table_name) + """ + ) + else: + sql = ( + with_column_info_sql + + f""" + SELECT + * + FROM with_column_info + """ + ) + return helper.strip_margin(sql) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 50d1438..6917479 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -158,6 +158,19 @@ def sample_datasets(spark: SparkSession, request): f"CREATE TABLE IF NOT EXISTS default.columns USING delta LOCATION '{warehouse_dir}/columns' AS SELECT * FROM view_columns_mock" ) + # column_tags + test_file_path = module_path.parent / "data/column_tags.csv" + ( + spark.read.option("header", True) + .schema( + "catalog_name string, schema_name string, table_name string, column_name string, tag_name 
string, tag_value string" + ) + .csv(str(test_file_path.resolve())) + ).createOrReplaceTempView("column_tags_temp_view") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.column_tags USING delta LOCATION '{warehouse_dir}/column_tags' AS SELECT * FROM column_tags_temp_view" + ) + # table_tags test_file_path = module_path.parent / "data/table_tags.csv" ( @@ -169,6 +182,28 @@ def sample_datasets(spark: SparkSession, request): f"CREATE TABLE IF NOT EXISTS default.table_tags USING delta LOCATION '{warehouse_dir}/table_tags' AS SELECT * FROM table_tags_temp_view" ) + # schema_tags + test_file_path = module_path.parent / "data/schema_tags.csv" + ( + spark.read.option("header", True) + .schema("catalog_name string,schema_name string,tag_name string,tag_value string") + .csv(str(test_file_path.resolve())) + ).createOrReplaceTempView("schema_tags_temp_view") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.schema_tags USING delta LOCATION '{warehouse_dir}/schema_tags' AS SELECT * FROM schema_tags_temp_view" + ) + + # catalog_tags + test_file_path = module_path.parent / "data/catalog_tags.csv" + ( + spark.read.option("header", True) + .schema("catalog_name string,tag_name string,tag_value string") + .csv(str(test_file_path.resolve())) + ).createOrReplaceTempView("catalog_tags_temp_view") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.catalog_tags USING delta LOCATION '{warehouse_dir}/catalog_tags' AS SELECT * FROM catalog_tags_temp_view" + ) + logging.info("Sample datasets created") yield None @@ -178,7 +213,10 @@ def sample_datasets(spark: SparkSession, request): spark.sql("DROP TABLE IF EXISTS default.tb_1") spark.sql("DROP TABLE IF EXISTS default.tb_2") spark.sql("DROP TABLE IF EXISTS default.columns") + spark.sql("DROP TABLE IF EXISTS default.column_tags") spark.sql("DROP TABLE IF EXISTS default.table_tags") + spark.sql("DROP TABLE IF EXISTS default.schema_tags") + spark.sql("DROP TABLE IF EXISTS default.catalog_tags") if Path(warehouse_dir).exists(): 
shutil.rmtree(warehouse_dir) diff --git a/tests/unit/data/catalog_tags.csv b/tests/unit/data/catalog_tags.csv new file mode 100644 index 0000000..d0cb12f --- /dev/null +++ b/tests/unit/data/catalog_tags.csv @@ -0,0 +1,2 @@ +catalog_name,tag_name,tag_value +,catalog-pii,true \ No newline at end of file diff --git a/tests/unit/data/column_tags.csv b/tests/unit/data/column_tags.csv new file mode 100644 index 0000000..d2c7517 --- /dev/null +++ b/tests/unit/data/column_tags.csv @@ -0,0 +1,4 @@ +catalog_name,schema_name,table_name,column_name,tag_name,tag_value +hive_metastore,default,tb_all_types,int_col,my_int_tag, +,default,tb_1,id,pk, +,default,tb_1,ip,pii,true \ No newline at end of file diff --git a/tests/unit/data/schema_tags.csv b/tests/unit/data/schema_tags.csv new file mode 100644 index 0000000..7660bf1 --- /dev/null +++ b/tests/unit/data/schema_tags.csv @@ -0,0 +1,2 @@ +catalog_name,schema_name,tag_name,tag_value +,default,schema-pii,true \ No newline at end of file diff --git a/tests/unit/data/table_tags.csv b/tests/unit/data/table_tags.csv index 501db03..2099042 100644 --- a/tests/unit/data/table_tags.csv +++ b/tests/unit/data/table_tags.csv @@ -1,4 +1,3 @@ catalog_name,schema_name,table_name,tag_name,tag_value hive_metastore,default,tb_all_types,int_col,my_int_tag -,default,tb_1,pk, ,default,tb_1,pii,true \ No newline at end of file