From e54410899d980d79264d4d5605e9fe5d45798b15 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Wed, 20 Mar 2024 21:32:09 +0530 Subject: [PATCH 01/15] Added Threshold Query Builder feature --- .../labs/remorph/reconcile/query_builder.py | 116 ++++++++++++++---- .../labs/remorph/reconcile/recon_config.py | 8 +- tests/unit/reconcile/test_query_builder.py | 65 ++++++++-- 3 files changed, 146 insertions(+), 43 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index e09177cca..12157231b 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -54,19 +54,20 @@ def build_hash_query(self) -> str: return select_query def _get_column_list(self) -> tuple[list[str], list[str]]: - column_mapping = self.table_conf.list_to_dict(ColumnMapping, "source_name") + tgt_column_mapping = self.table_conf.list_to_dict(ColumnMapping, "target_name") if self.table_conf.join_columns is None: join_columns = set() - elif self.layer == "source": - join_columns = {col.source_name for col in self.table_conf.join_columns} else: - join_columns = {col.target_name for col in self.table_conf.join_columns} + join_columns = set(self.table_conf.join_columns) if self.table_conf.select_columns is None: - select_columns = {sch.column_name for sch in self.schema} + columns = {sch.column_name for sch in self.schema} + select_columns = ( + columns if self.layer == "source" else self._get_mapped_columns(tgt_column_mapping, columns) + ) else: - select_columns = self._get_mapped_columns(self.layer, column_mapping, self.table_conf.select_columns) + select_columns = set(self.table_conf.select_columns) if self.table_conf.jdbc_reader_options and self.layer == "source": partition_column = {self.table_conf.jdbc_reader_options.partition_column} @@ -81,7 +82,7 @@ def _get_column_list(self) -> tuple[list[str], list[str]]: if self.table_conf.drop_columns is None: drop_columns = set() else: - drop_columns = self._get_mapped_columns(self.layer, column_mapping, self.table_conf.drop_columns) + drop_columns = set(self.table_conf.drop_columns) columns = sorted(all_columns - threshold_columns - drop_columns) key_columns = sorted(join_columns | partition_column) @@ -90,29 +91,26 @@ def _get_column_list(self) -> tuple[list[str], list[str]]: def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict) -> list[TransformRuleMapping]: transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") - column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "target_name") + column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "source_name") transformation_rule_mapping = [] for column in columns: - if column_mapping_dict and self.layer == "target": - transform_column = ( - column_mapping_dict.get(column).source_name if column_mapping_dict.get(column) else column - ) - else: - transform_column = column - if transformations_dict and transform_column in transformations_dict.keys(): - transformation = self._get_layer_transform(transformations_dict, transform_column, self.layer) + if transformations_dict and column in transformations_dict.keys(): + transformation = self._get_layer_transform(transformations_dict, column, self.layer) else: - column_data_type = schema.get(column).data_type - transformation = self._get_default_transformation(self.source, column_data_type).format(column) + column_origin = column if self.layer == "source" else 
self._get_column_map(column, column_mapping_dict) + column_data_type = schema.get(column_origin).data_type + transformation = self._get_default_transformation(self.source, column_data_type).format(column_origin) - if column_mapping_dict and column in column_mapping_dict.keys(): + if column_mapping_dict and column in column_mapping_dict.keys() and self.layer == "target": column_alias = column_mapping_dict.get(column).source_name + column_origin = column_mapping_dict.get(column).target_name else: column_alias = column + column_origin = column - transformation_rule_mapping.append(TransformRuleMapping(column, transformation, column_alias)) + transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) return transformation_rule_mapping @@ -163,10 +161,80 @@ def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, ke return select_query @staticmethod - def _get_mapped_columns(layer: str, column_mapping: dict, columns: list[str]) -> set[str]: - if layer == "source": - return set(columns) + def _get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: select_columns = set() for column in columns: - select_columns.add(column_mapping.get(column).target_name if column_mapping.get(column) else column) + select_columns.add(column_mapping.get(column).source_name if column_mapping.get(column) else column) return select_columns + + @staticmethod + def _get_column_map(column, column_mapping) -> str: + return column_mapping.get(column).target_name if column_mapping.get(column) else column + + def build_threshold_query(self) -> str: + column_mapping = self.table_conf.list_to_dict(ColumnMapping, "source_name") + transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") + + threshold_columns = set(threshold.column_name for threshold in self.table_conf.thresholds) + join_columns = set(self.table_conf.join_columns) + + if self.table_conf.jdbc_reader_options and self.layer == "source": + partition_column = {self.table_conf.jdbc_reader_options.partition_column} + else: + partition_column = set() + + all_columns = set(threshold_columns | join_columns | partition_column) + + query_columns = sorted( + all_columns if self.layer == "source" else self._get_mapped_columns(column_mapping, all_columns) + ) + + transformation_rule_mapping = self._get_custom_transformation(query_columns, transformations_dict, + column_mapping) + threshold_columns_expr = self._get_column_expr( + TransformRuleMapping.get_column_expression_with_alias, transformation_rule_mapping + ) + + if self.layer == "source": + table_name = self.table_conf.source_name + query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " + else: + table_name = self.table_conf.target_name + query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " + + # construct threshold select query + select_query = self._construct_threshold_query(table_name, query_filter, threshold_columns_expr) + + return select_query + + def _get_custom_transformation(self, columns, transformation, column_mapping): + transformation_rule_mapping = [] + for column in columns: + if transformation and column in transformation.keys(): + transformation = self._get_layer_transform(transformation, column, self.layer) + else: + transformation = None + + if column_mapping and column in column_mapping.keys() and self.layer == "target": + column_alias = column_mapping.get(column).source_name + column_src = column_mapping.get(column).target_name + 
else: + column_alias = column + column_src = column + + transformation_rule_mapping.append(TransformRuleMapping(column_src, transformation, column_alias)) + + return transformation_rule_mapping + + @staticmethod + def _construct_threshold_query(table_name, query_filter, threshold_columns_expr): + sql_query = StringIO() + column_expr = ",".join(threshold_columns_expr) + sql_query.write(f"select {column_expr} ") + + sql_query.write(f" from {table_name} where {query_filter}") + + select_query = sql_query.getvalue() + sql_query.close() + return select_query + diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index f1e08c28e..493a8bb9f 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -30,12 +30,6 @@ class JdbcReaderOptions: fetch_size: int = 100 -@dataclass -class JoinColumns: - source_name: str - target_name: str | None = None - - @dataclass class ColumnMapping: source_name: str @@ -67,7 +61,7 @@ class Filters: class Tables: source_name: str target_name: str - join_columns: list[JoinColumns] | None = None + join_columns: list[str] | None = None jdbc_reader_options: JdbcReaderOptions | None = None select_columns: list[str] | None = None drop_columns: list[str] | None = None diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 546ac3b3e..55f29e293 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -2,7 +2,6 @@ from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, JdbcReaderOptions, - JoinColumns, Schema, Tables, Thresholds, @@ -10,7 +9,7 @@ ) -def test_query_builder_without_join_column(): +def test_hash_query_builder_without_join_column(): table_conf = Tables( source_name="supplier", target_name="supplier", @@ -66,12 +65,12 @@ def test_query_builder_without_join_column(): assert actual_tgt_query == expected_tgt_query -def test_query_builder_with_defaults(): +def test_hash_query_builder_with_defaults(): table_conf = Tables( source_name="supplier", target_name="supplier", jdbc_reader_options=None, - join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey")], + join_columns=["s_suppkey"], select_columns=None, drop_columns=None, column_mapping=None, @@ -122,12 +121,12 @@ def test_query_builder_with_defaults(): assert actual_tgt_query == expected_tgt_query -def test_query_builder_with_select(): +def test_hash_query_builder_with_select(): table_conf = Tables( source_name="supplier", target_name="supplier", jdbc_reader_options=None, - join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")], + join_columns=["s_suppkey"], select_columns=["s_suppkey", "s_name", "s_address"], drop_columns=None, column_mapping=[ @@ -178,12 +177,12 @@ def test_query_builder_with_select(): assert actual_tgt_query == expected_tgt_query -def test_query_builder_with_transformations_with_drop_and_default_select(): +def test_hash_query_builder_with_transformations_with_drop_and_default_select(): table_conf = Tables( source_name="supplier", target_name="supplier", jdbc_reader_options=None, - join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")], + join_columns=["s_suppkey"], select_columns=None, drop_columns=["s_comment"], column_mapping=[ @@ -249,14 +248,14 @@ def test_query_builder_with_transformations_with_drop_and_default_select(): assert actual_tgt_query == expected_tgt_query -def 
test_query_builder_with_jdbc_reader_options(): +def test_hash_query_builder_with_jdbc_reader_options(): table_conf = Tables( source_name="supplier", target_name="supplier", jdbc_reader_options=JdbcReaderOptions( number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" ), - join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")], + join_columns=["s_suppkey"], select_columns=["s_suppkey", "s_name", "s_address"], drop_columns=None, column_mapping=[ @@ -309,14 +308,14 @@ def test_query_builder_with_jdbc_reader_options(): assert actual_tgt_query == expected_tgt_query -def test_query_builder_with_threshold(): +def test_hash_query_builder_with_threshold(): table_conf = Tables( source_name="supplier", target_name="supplier", jdbc_reader_options=JdbcReaderOptions( number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" ), - join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")], + join_columns=["s_suppkey"], select_columns=None, drop_columns=None, column_mapping=[ @@ -371,3 +370,45 @@ def test_query_builder_with_threshold(): ) assert actual_tgt_query == expected_tgt_query + + +def test_threshold_query_builder_with_defaults(): + table_conf = Tables( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=["s_suppkey"], + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + thresholds=[Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int")], + filters=None, + ) + src_schema = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + ] + + actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() + expected_src_query = 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1 ' + assert actual_src_query == expected_src_query + + tgt_schema = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + ] + + actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() + expected_tgt_query = 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1 ' + assert actual_tgt_query == expected_tgt_query From 8f1af1074bc695dfb3660d3ea748e204b1732c83 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Wed, 20 Mar 2024 22:53:59 +0530 Subject: [PATCH 02/15] Refactored the code to reduce redundancy --- .../labs/remorph/reconcile/query_builder.py | 131 ++++++++++-------- tests/unit/reconcile/test_query_builder.py | 37 +++-- 2 files changed, 92 insertions(+), 76 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index 12157231b..250a1fef8 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -31,14 +31,14 @@ def build_hash_query(self) -> str: columns, key_columns = self._get_column_list() col_transformations = self._generate_transformation_rule_mapping(columns, schema_info) - hash_columns_expr = self._get_column_expr( - 
TransformRuleMapping.get_column_expression_without_alias, col_transformations + hash_columns_expr = sorted( + self._get_column_expr(TransformRuleMapping.get_column_expression_without_alias, col_transformations) ) hash_expr = self._generate_hash_algorithm(self.source, hash_columns_expr) key_column_transformation = self._generate_transformation_rule_mapping(key_columns, schema_info) - key_column_expr = self._get_column_expr( - TransformRuleMapping.get_column_expression_with_alias, key_column_transformation + key_column_expr = sorted( + self._get_column_expr(TransformRuleMapping.get_column_expression_with_alias, key_column_transformation) ) if self.layer == "source": @@ -93,42 +93,23 @@ def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "source_name") - transformation_rule_mapping = [] - for column in columns: - - if transformations_dict and column in transformations_dict.keys(): - transformation = self._get_layer_transform(transformations_dict, column, self.layer) - else: - column_origin = column if self.layer == "source" else self._get_column_map(column, column_mapping_dict) - column_data_type = schema.get(column_origin).data_type - transformation = self._get_default_transformation(self.source, column_data_type).format(column_origin) + if transformations_dict: + columns_with_transformation = [column for column in columns if column in transformations_dict.keys()] + custom_transformation = self._get_custom_transformation( + columns_with_transformation, transformations_dict, column_mapping_dict + ) + else: + custom_transformation = [] - if column_mapping_dict and column in column_mapping_dict.keys() and self.layer == "target": - column_alias = column_mapping_dict.get(column).source_name - column_origin = column_mapping_dict.get(column).target_name - else: - column_alias = column - column_origin = column + columns_without_transformation = [column for column in columns if column not in transformations_dict.keys()] + default_transformation = self._get_default_transformation( + columns_without_transformation, column_mapping_dict, schema + ) - transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) + transformation_rule_mapping = custom_transformation + default_transformation return transformation_rule_mapping - @staticmethod - def _get_default_transformation(data_source: str, data_type: str) -> str: - if data_source == "oracle": - return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) - if data_source == "snowflake": - return SnowflakeDataSource.snowflake_datatype_mapper.get( - data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value - ) - if data_source == "databricks": - return DatabricksDataSource.databricks_datatype_mapper.get( - data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value - ) - msg = f"Unsupported source type --> {data_source}" - raise ValueError(msg) - @staticmethod def _get_layer_transform(transform_dict: dict[str, Transformation], column: str, layer: str) -> str: return transform_dict.get(column).source if layer == "source" else transform_dict.get(column).target @@ -171,6 +152,61 @@ def _get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: def _get_column_map(column, column_mapping) -> str: return column_mapping.get(column).target_name if column_mapping.get(column) else column 
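
The refactor below splits transformation resolution into two passes: columns that have a user-supplied rule are handled by _get_custom_transformation, and the remainder fall back to a per-datatype default via _get_default_transformation. A minimal sketch of the partition step, using illustrative values in the style of the unit tests (the column names here are examples only):

    transformations_dict = {
        "s_acctbal": Transformation(
            column_name="s_acctbal", source="trim(s_acctbal)", target="trim(s_acctbal_t)"
        )
    }  # user-supplied rules keyed by column name
    columns = ["s_acctbal", "s_name"]
    with_rule = [c for c in columns if c in transformations_dict]         # ["s_acctbal"]
    without_rule = [c for c in columns if c not in transformations_dict]  # ["s_name"]
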
+ def _get_custom_transformation(self, columns, transformation_dict, column_mapping): + transformation_rule_mapping = [] + for column in columns: + if column in transformation_dict.keys(): + transformation = self._get_layer_transform(transformation_dict, column, self.layer) + else: + transformation = None + + column_origin, column_alias = self._get_column_alias(self.layer, column, column_mapping) + + transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) + + return transformation_rule_mapping + + def _get_default_transformation(self, columns, column_mapping, schema): + transformation_rule_mapping = [] + for column in columns: + column_origin = column if self.layer == "source" else self._get_column_map(column, column_mapping) + column_data_type = schema.get(column_origin).data_type + transformation = self._get_default_transformation_mapping(self.source, column_data_type).format( + column_origin + ) + + column_origin, column_alias = self._get_column_alias(self.layer, column, column_mapping) + + transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) + + return transformation_rule_mapping + + @staticmethod + def _get_default_transformation_mapping(data_source: str, data_type: str) -> str: + if data_source == "oracle": + return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) + if data_source == "snowflake": + return SnowflakeDataSource.snowflake_datatype_mapper.get( + data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value + ) + if data_source == "databricks": + return DatabricksDataSource.databricks_datatype_mapper.get( + data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value + ) + msg = f"Unsupported source type --> {data_source}" + raise ValueError(msg) + + @staticmethod + def _get_column_alias(layer, column, column_mapping): + if column_mapping and column in column_mapping.keys() and layer == "target": + column_alias = column_mapping.get(column).source_name + column_origin = column_mapping.get(column).target_name + else: + column_alias = column + column_origin = column + + return column_origin, column_alias + def build_threshold_query(self) -> str: column_mapping = self.table_conf.list_to_dict(ColumnMapping, "source_name") transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") @@ -189,8 +225,9 @@ def build_threshold_query(self) -> str: all_columns if self.layer == "source" else self._get_mapped_columns(column_mapping, all_columns) ) - transformation_rule_mapping = self._get_custom_transformation(query_columns, transformations_dict, - column_mapping) + transformation_rule_mapping = self._get_custom_transformation( + query_columns, transformations_dict, column_mapping + ) threshold_columns_expr = self._get_column_expr( TransformRuleMapping.get_column_expression_with_alias, transformation_rule_mapping ) @@ -207,25 +244,6 @@ def build_threshold_query(self) -> str: return select_query - def _get_custom_transformation(self, columns, transformation, column_mapping): - transformation_rule_mapping = [] - for column in columns: - if transformation and column in transformation.keys(): - transformation = self._get_layer_transform(transformation, column, self.layer) - else: - transformation = None - - if column_mapping and column in column_mapping.keys() and self.layer == "target": - column_alias = column_mapping.get(column).source_name - column_src = column_mapping.get(column).target_name - else: - column_alias = column 
- column_src = column - - transformation_rule_mapping.append(TransformRuleMapping(column_src, transformation, column_alias)) - - return transformation_rule_mapping - @staticmethod def _construct_threshold_query(table_name, query_filter, threshold_columns_expr): sql_query = StringIO() @@ -237,4 +255,3 @@ def _construct_threshold_query(table_name, query_filter, threshold_columns_expr) select_query = sql_query.getvalue() sql_query.close() return select_query - diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 55f29e293..86523090f 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -218,11 +218,11 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() expected_src_query = ( - 'select lower(RAWTOHEX(STANDARD_HASH(trim(to_char(s_acctbal_t, ' - "'9999999999.99')) || trim(s_address) || trim(s_name) || " - "coalesce(trim(s_nationkey),'') || trim(s_phone) || " - "coalesce(trim(s_suppkey),''), 'SHA256'))) as hash_value__recon, " - "coalesce(trim(s_suppkey),'') as s_suppkey from supplier where 1 = 1 " + "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_nationkey),'') || " + "coalesce(trim(s_suppkey),'') || trim(s_address) || trim(s_name) || " + "trim(s_phone) || trim(to_char(s_acctbal_t, '9999999999.99')), 'SHA256'))) as " + "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier " + "where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -238,11 +238,10 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() expected_tgt_query = ( - 'select sha2(concat(cast(s_acctbal_t as decimal(38,2)), trim(s_address_t), ' - "trim(s_name), coalesce(trim(s_nationkey_t),''), " - "trim(s_phone_t), coalesce(trim(s_suppkey_t),'')),256) as " - "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier " - 'where 1 = 1 ' + "select sha2(concat(cast(s_acctbal_t as decimal(38,2)), " + "coalesce(trim(s_nationkey_t),''), coalesce(trim(s_suppkey_t),''), " + 'trim(s_address_t), trim(s_name), trim(s_phone_t)),256) as hash_value__recon, ' + "coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query @@ -342,12 +341,12 @@ def test_hash_query_builder_with_threshold(): actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() expected_src_query = ( - 'select lower(RAWTOHEX(STANDARD_HASH(trim(s_address) || ' - "coalesce(trim(s_comment),'') || trim(s_name) || " - "coalesce(trim(s_nationkey),'') || trim(s_phone) || " - "coalesce(trim(s_suppkey),''), 'SHA256'))) as hash_value__recon, " - "coalesce(trim(s_nationkey),'') as s_nationkey,coalesce(trim(s_suppkey),'') " - 'as s_suppkey from supplier where 1 = 1 ' + "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_comment),'') || " + "coalesce(trim(s_nationkey),'') || coalesce(trim(s_suppkey),'') || " + "trim(s_address) || trim(s_name) || trim(s_phone), 'SHA256'))) as " + "hash_value__recon, coalesce(trim(s_nationkey),'') as " + "s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey from supplier where 1 " + '= 1 ' ) assert actual_src_query == expected_src_query @@ -363,9 +362,9 @@ def test_hash_query_builder_with_threshold(): actual_tgt_query = QueryBuilder(table_conf, tgt_schema, 
"target", "databricks").build_hash_query() expected_tgt_query = ( - "select sha2(concat(trim(s_address_t), coalesce(trim(s_comment),''), " - "trim(s_name), coalesce(trim(s_nationkey),''), trim(s_phone), " - "coalesce(trim(s_suppkey_t),'')),256) as hash_value__recon, " + "select sha2(concat(coalesce(trim(s_comment),''), " + "coalesce(trim(s_nationkey),''), coalesce(trim(s_suppkey_t),''), " + 'trim(s_address_t), trim(s_name), trim(s_phone)),256) as hash_value__recon, ' "coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier where 1 = 1 " ) From 44329b7911d907de8819bfabbcac826063ff99db Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 Mar 2024 09:50:47 +0530 Subject: [PATCH 03/15] Added inline comments and renamed a function --- .../labs/remorph/reconcile/query_builder.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index 250a1fef8..73ec3388e 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -29,18 +29,21 @@ def build_hash_query(self) -> str: schema_info = {v.column_name: v for v in self.schema} columns, key_columns = self._get_column_list() - col_transformations = self._generate_transformation_rule_mapping(columns, schema_info) + # get transformation for columns considered for hashing + col_transformations = self._generate_transformation_rule_mapping(columns, schema_info) hash_columns_expr = sorted( self._get_column_expr(TransformRuleMapping.get_column_expression_without_alias, col_transformations) ) hash_expr = self._generate_hash_algorithm(self.source, hash_columns_expr) + # get transformation for columns considered for joining and partition key key_column_transformation = self._generate_transformation_rule_mapping(key_columns, schema_info) key_column_expr = sorted( self._get_column_expr(TransformRuleMapping.get_column_expression_with_alias, key_column_transformation) ) + # get table_name and query filter if self.layer == "source": table_name = self.table_conf.source_name query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " @@ -48,7 +51,7 @@ def build_hash_query(self) -> str: table_name = self.table_conf.target_name query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " - # construct select query + # construct select hash query select_query = self._construct_hash_query(table_name, query_filter, hash_expr, key_column_expr) return select_query @@ -56,11 +59,13 @@ def build_hash_query(self) -> str: def _get_column_list(self) -> tuple[list[str], list[str]]: tgt_column_mapping = self.table_conf.list_to_dict(ColumnMapping, "target_name") + # get join columns if self.table_conf.join_columns is None: join_columns = set() else: join_columns = set(self.table_conf.join_columns) + # get select columns if self.table_conf.select_columns is None: columns = {sch.column_name for sch in self.schema} select_columns = ( @@ -69,15 +74,16 @@ def _get_column_list(self) -> tuple[list[str], list[str]]: else: select_columns = set(self.table_conf.select_columns) + # get partition key for jdbc reader options if self.table_conf.jdbc_reader_options and self.layer == "source": partition_column = {self.table_conf.jdbc_reader_options.partition_column} else: partition_column = set() - # Combine all column names + # combine all column names for hashing all_columns = join_columns | select_columns - # Remove threshold 
and drop columns + # remove threshold and drop columns threshold_columns = {thresh.column_name for thresh in self.table_conf.thresholds or []} if self.table_conf.drop_columns is None: drop_columns = set() @@ -93,6 +99,7 @@ def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "source_name") + # compute custom transformation if transformations_dict: columns_with_transformation = [column for column in columns if column in transformations_dict.keys()] custom_transformation = self._get_custom_transformation( @@ -101,6 +108,7 @@ def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict else: custom_transformation = [] + # compute default transformation columns_without_transformation = [column for column in columns if column not in transformations_dict.keys()] default_transformation = self._get_default_transformation( columns_without_transformation, column_mapping_dict, schema @@ -130,6 +138,7 @@ def _generate_hash_algorithm(source: str, column_expr: list[str]) -> str: @staticmethod def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, key_column_expr: list[str]) -> str: sql_query = StringIO() + # construct hash expr sql_query.write(f"select {hash_expr} as {Constants.hash_column_name}") # add join column @@ -171,7 +180,7 @@ def _get_default_transformation(self, columns, column_mapping, schema): for column in columns: column_origin = column if self.layer == "source" else self._get_column_map(column, column_mapping) column_data_type = schema.get(column_origin).data_type - transformation = self._get_default_transformation_mapping(self.source, column_data_type).format( + transformation = self._get_default_transformation_expr(self.source, column_data_type).format( column_origin ) @@ -182,7 +191,7 @@ def _get_default_transformation(self, columns, column_mapping, schema): return transformation_rule_mapping @staticmethod - def _get_default_transformation_mapping(data_source: str, data_type: str) -> str: + def _get_default_transformation_expr(data_source: str, data_type: str) -> str: if data_source == "oracle": return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) if data_source == "snowflake": @@ -225,6 +234,7 @@ def build_threshold_query(self) -> str: all_columns if self.layer == "source" else self._get_mapped_columns(column_mapping, all_columns) ) + # get custom transformation transformation_rule_mapping = self._get_custom_transformation( query_columns, transformations_dict, column_mapping ) @@ -247,9 +257,11 @@ def build_threshold_query(self) -> str: @staticmethod def _construct_threshold_query(table_name, query_filter, threshold_columns_expr): sql_query = StringIO() + # construct threshold expr column_expr = ",".join(threshold_columns_expr) sql_query.write(f"select {column_expr} ") + # add query filter sql_query.write(f" from {table_name} where {query_filter}") select_query = sql_query.getvalue() From f8bf844b8a53d93f3a64a2ef757b07ffc2586c40 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 Mar 2024 10:55:32 +0530 Subject: [PATCH 04/15] Added a test case for threshold query builder --- tests/unit/reconcile/test_query_builder.py | 69 ++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 86523090f..473b3b4b8 100644 
--- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -411,3 +411,72 @@ def test_threshold_query_builder_with_defaults(): actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() expected_tgt_query = 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1 ' assert actual_tgt_query == expected_tgt_query + + +def test_threshold_query_builder_with_transformations_and_jdbc(): + table_conf = Tables( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=JdbcReaderOptions( + number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100" + ), + join_columns=["s_suppkey"], + select_columns=None, + drop_columns=["s_comment"], + column_mapping=[ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_address", target_name="s_address_t"), + ColumnMapping(source_name="s_nationkey", target_name="s_nationkey_t"), + ColumnMapping(source_name="s_phone", target_name="s_phone_t"), + ColumnMapping(source_name="s_acctbal", target_name="s_acctbal_t"), + ColumnMapping(source_name="s_comment", target_name="s_comment_t"), + ColumnMapping(source_name="s_suppdate", target_name="s_suppdate_t"), + ], + transformations=[ + Transformation(column_name="s_suppkey", source="trim(s_suppkey)", target="trim(s_suppkey_t)"), + Transformation(column_name="s_address", source="trim(s_address)", target="trim(s_address_t)"), + Transformation(column_name="s_phone", source="trim(s_phone)", target="trim(s_phone_t)"), + Transformation(column_name="s_name", source="trim(s_name)", target="trim(s_name)"), + Transformation( + column_name="s_acctbal", + source="trim(to_char(s_acctbal, '9999999999.99'))", + target="cast(s_acctbal_t as decimal(38,2))", + ), + ], + thresholds=[Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int"), + Thresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp")], + filters=None, + ) + src_schema = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + Schema("s_suppdate", "timestamp") + ] + + actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() + expected_src_query = ("select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " + "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " + "supplier where 1 = 1 ") + assert actual_src_query == expected_src_query + + tgt_schema = [ + Schema("s_suppkey_t", "number"), + Schema("s_name", "varchar"), + Schema("s_address_t", "varchar"), + Schema("s_nationkey_t", "number"), + Schema("s_phone_t", "varchar"), + Schema("s_acctbal_t", "number"), + Schema("s_comment_t", "varchar"), + Schema("s_suppdate_t", "timestamp") + ] + + actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() + expected_tgt_query = ("select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " + "s_suppdate,trim(s_suppkey_t) as s_suppkey from supplier where 1 = 1 ") + + assert actual_tgt_query == expected_tgt_query From 9011f5ebdbefac5cf7855463a651241a2d3c347a Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 Mar 2024 10:57:59 +0530 Subject: [PATCH 05/15] Fixed the format issue --- 
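Notes: a minimal usage sketch of the threshold query builder introduced in this series, distilled from the unit tests above. The output shown is the form produced at this point in the series, before PATCH 08 adds the schema/catalog placeholders; the Tables dataclass is also renamed to Table later in the series.

    from databricks.labs.remorph.reconcile.query_builder import QueryBuilder
    from databricks.labs.remorph.reconcile.recon_config import Schema, Tables, Thresholds

    # Compare s_acctbal across layers, tolerating a drift between 0 and 100.
    table_conf = Tables(
        source_name="supplier",
        target_name="supplier",
        join_columns=["s_suppkey"],
        thresholds=[Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int")],
    )
    schema = [Schema("s_suppkey", "number"), Schema("s_acctbal", "number")]

    query = QueryBuilder(table_conf, schema, "source", "oracle").build_threshold_query()
    # -> select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1
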
.../labs/remorph/reconcile/query_builder.py | 4 +--- tests/unit/reconcile/test_query_builder.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index 73ec3388e..5a21c7d04 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -180,9 +180,7 @@ def _get_default_transformation(self, columns, column_mapping, schema): for column in columns: column_origin = column if self.layer == "source" else self._get_column_map(column, column_mapping) column_data_type = schema.get(column_origin).data_type - transformation = self._get_default_transformation_expr(self.source, column_data_type).format( - column_origin - ) + transformation = self._get_default_transformation_expr(self.source, column_data_type).format(column_origin) column_origin, column_alias = self._get_column_alias(self.layer, column, column_mapping) diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 473b3b4b8..6bffd1139 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -443,8 +443,10 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): target="cast(s_acctbal_t as decimal(38,2))", ), ], - thresholds=[Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int"), - Thresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp")], + thresholds=[ + Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int"), + Thresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp"), + ], filters=None, ) src_schema = [ @@ -455,13 +457,15 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_phone", "varchar"), Schema("s_acctbal", "number"), Schema("s_comment", "varchar"), - Schema("s_suppdate", "timestamp") + Schema("s_suppdate", "timestamp"), ] actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() - expected_src_query = ("select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " - "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " - "supplier where 1 = 1 ") + expected_src_query = ( + "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " + "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " + "supplier where 1 = 1 " + ) assert actual_src_query == expected_src_query tgt_schema = [ @@ -472,11 +476,13 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_phone_t", "varchar"), Schema("s_acctbal_t", "number"), Schema("s_comment_t", "varchar"), - Schema("s_suppdate_t", "timestamp") + Schema("s_suppdate_t", "timestamp"), ] actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() - expected_tgt_query = ("select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " - "s_suppdate,trim(s_suppkey_t) as s_suppkey from supplier where 1 = 1 ") + expected_tgt_query = ( + "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " + "s_suppdate,trim(s_suppkey_t) as s_suppkey from supplier where 1 = 1 " + ) assert actual_tgt_query == expected_tgt_query From 1c3f87cd65139886184dc68ad9dadc36beaf3893 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 
Mar 2024 14:56:58 +0530 Subject: [PATCH 06/15] Added test case for snowflake source and exception scenario --- pyproject.toml | 2 +- .../labs/remorph/reconcile/recon_config.py | 4 +- tests/unit/reconcile/test_query_builder.py | 91 ++++++++++++++++++- 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c3bbb373c..c6cf9f283 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,7 @@ branch = true parallel = true [tool.coverage.report] -omit = ["src/databricks/labs/remorph/reconcile/*", +omit = [ "src/databricks/labs/remorph/coverage/*", "src/databricks/labs/remorph/helpers/execution_time.py", "__about__.py"] diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index 493a8bb9f..d010006e0 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -16,9 +16,7 @@ def get_column_expression_without_alias(self) -> str: return f"{self.column_name}" def get_column_expression_with_alias(self) -> str: - if self.alias_name: - return f"{self.get_column_expression_without_alias()} as {self.alias_name}" - return f"{self.get_column_expression_without_alias()} as {self.column_name}" + return f"{self.get_column_expression_without_alias()} as {self.alias_name}" @dataclass diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 6bffd1139..72b9b1ba6 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -1,3 +1,5 @@ +import pytest + from databricks.labs.remorph.reconcile.query_builder import QueryBuilder from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, @@ -5,7 +7,7 @@ Schema, Tables, Thresholds, - Transformation, + Transformation, Filters, ) @@ -371,6 +373,93 @@ def test_hash_query_builder_with_threshold(): assert actual_tgt_query == expected_tgt_query +def test_hash_query_builder_with_filters(): + table_conf = Tables( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=["s_suppkey"], + select_columns=["s_suppkey", "s_name", "s_address"], + drop_columns=None, + column_mapping=[ + ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"), + ColumnMapping(source_name="s_address", target_name="s_address_t"), + ], + transformations=None, + thresholds=None, + filters=Filters(source="s_name='t' and s_address='a'", target="s_name='t' and s_address_t='a'"), + ) + src_schema = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + ] + + actual_src_query = QueryBuilder(table_conf, src_schema, "source", "snowflake").build_hash_query() + expected_src_query = ( + "select sha2(concat(coalesce(trim(s_address),''), coalesce(trim(s_name),''), " + "coalesce(trim(s_suppkey),'')),256) as hash_value__recon, " + "coalesce(trim(s_suppkey),'') as s_suppkey from supplier where s_name='t' and " + "s_address='a'" + ) + assert actual_src_query == expected_src_query + + tgt_schema = [ + Schema("s_suppkey_t", "number"), + Schema("s_name", "varchar"), + Schema("s_address_t", "varchar"), + Schema("s_nationkey_t", "number"), + Schema("s_phone_t", "varchar"), + Schema("s_acctbal_t", "number"), + Schema("s_comment_t", "varchar"), + ] + + actual_tgt_query = QueryBuilder(table_conf, tgt_schema, 
"target", "databricks").build_hash_query() + expected_tgt_query = ( + "select sha2(concat(coalesce(trim(s_address_t),''), " + "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " + "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier " + "where s_name='t' and s_address_t='a'" + ) + + assert actual_tgt_query == expected_tgt_query + + +def test_hash_query_builder_with_unsupported_source(): + table_conf = Tables( + source_name="supplier", + target_name="supplier", + jdbc_reader_options=None, + join_columns=None, + select_columns=None, + drop_columns=None, + column_mapping=None, + transformations=None, + thresholds=None, + filters=None, + ) + src_schema = [ + Schema("s_suppkey", "number"), + Schema("s_name", "varchar"), + Schema("s_address", "varchar"), + Schema("s_nationkey", "number"), + Schema("s_phone", "varchar"), + Schema("s_acctbal", "number"), + Schema("s_comment", "varchar"), + ] + + query_builder = QueryBuilder(table_conf, src_schema, "source", "abc") + + with pytest.raises(Exception) as exc_info: + query_builder.build_hash_query() + + assert (str(exc_info.value) == "Unsupported source type --> abc") + + def test_threshold_query_builder_with_defaults(): table_conf = Tables( source_name="supplier", From f7d41bc1dedc049098c98f8835b3f694b444d776 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 Mar 2024 14:59:05 +0530 Subject: [PATCH 07/15] Fixed the format --- tests/unit/reconcile/test_query_builder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 72b9b1ba6..df15b037f 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -3,11 +3,12 @@ from databricks.labs.remorph.reconcile.query_builder import QueryBuilder from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, + Filters, JdbcReaderOptions, Schema, Tables, Thresholds, - Transformation, Filters, + Transformation, ) @@ -457,7 +458,7 @@ def test_hash_query_builder_with_unsupported_source(): with pytest.raises(Exception) as exc_info: query_builder.build_hash_query() - assert (str(exc_info.value) == "Unsupported source type --> abc") + assert str(exc_info.value) == "Unsupported source type --> abc" def test_threshold_query_builder_with_defaults(): From 822e506a2624408fd1b27b1b639a97c7f8cff0f9 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Thu, 21 Mar 2024 16:26:05 +0530 Subject: [PATCH 08/15] Added the schema and catalog information to query --- .../reconcile/connectors/data_source.py | 25 +++++++- .../reconcile/connectors/databricks.py | 13 +++- .../remorph/reconcile/connectors/oracle.py | 25 +++++--- .../remorph/reconcile/connectors/snowflake.py | 13 +++- .../labs/remorph/reconcile/execute.py | 6 +- .../labs/remorph/reconcile/query_builder.py | 28 ++++++--- .../labs/remorph/reconcile/recon_config.py | 4 +- tests/unit/reconcile/test_query_builder.py | 62 ++++++++++--------- 8 files changed, 114 insertions(+), 62 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py index df0a796de..8a25003c3 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/data_source.py +++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py @@ -1,3 +1,4 @@ +import re from abc import ABC, abstractmethod from databricks.sdk import WorkspaceClient # pylint: disable-next=wrong-import-order @@ -6,7 +7,6 @@ from 
databricks.labs.remorph.reconcile.recon_config import ( # pylint: disable=ungrouped-imports JdbcReaderOptions, Schema, - Tables, ) @@ -20,11 +20,18 @@ def __init__(self, source: str, spark: SparkSession, ws: WorkspaceClient, scope: self.scope = scope @abstractmethod - def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame: + def read_data( + self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions + ) -> DataFrame: return NotImplemented @abstractmethod - def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]: + def get_schema( + self, + catalog_name: str, + schema_name: str, + table_name: str, + ) -> list[Schema]: return NotImplemented def _get_jdbc_reader(self, query, jdbc_url, driver): @@ -48,3 +55,15 @@ def _get_jdbc_reader_options(jdbc_reader_options: JdbcReaderOptions): def _get_secrets(self, key_name): key = self.source + '_' + key_name return self.ws.secrets.get_secret(self.scope, key) + + @staticmethod + def _get_table_or_query( + catalog_name: str, + schema_name: str, + query: str, + ): + if re.search('select', query, re.IGNORECASE): + return query.format(catalog_name=catalog_name, schema_name=schema_name) + if catalog_name: + return catalog_name + "." + schema_name + "." + query + return schema_name + "." + query diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py b/src/databricks/labs/remorph/reconcile/connectors/databricks.py index d4be09730..383834998 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/databricks.py +++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py @@ -1,15 +1,22 @@ from pyspark.sql import DataFrame from databricks.labs.remorph.reconcile.connectors.data_source import DataSource -from databricks.labs.remorph.reconcile.recon_config import Schema, Tables +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema class DatabricksDataSource(DataSource): - def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame: + def read_data( + self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions + ) -> DataFrame: # Implement Databricks-specific logic here return NotImplemented - def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]: + def get_schema( + self, + catalog_name: str, + schema_name: str, + table_name: str, + ) -> list[Schema]: # Implement Databricks-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py index c3da2bd1f..076e87dc6 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/oracle.py +++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py @@ -3,7 +3,7 @@ from databricks.labs.remorph.reconcile.connectors.data_source import DataSource from databricks.labs.remorph.reconcile.constants import SourceDriver -from databricks.labs.remorph.reconcile.recon_config import Schema, Tables +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema class OracleDataSource(DataSource): @@ -16,16 +16,16 @@ def get_jdbc_url(self) -> str: f":{self._get_secrets('port')}/{self._get_secrets('database')}" ) - # TODO need to check schema_name,catalog_name is needed - def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame: + def read_data( + self, catalog_name: 
str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions + ) -> DataFrame: try: - if table_conf.jdbc_reader_options is None: - return self.reader(query).options(**self._get_timestamp_options()).load() + table_query = self._get_table_or_query(catalog_name, schema_name, query) + if jdbc_reader_options is None: + return self.reader(table_query).options(**self._get_timestamp_options()).load() return ( - self.reader(query) - .options( - **self._get_jdbc_reader_options(table_conf.jdbc_reader_options) | self._get_timestamp_options() - ) + self.reader(table_query) + .options(**self._get_jdbc_reader_options(jdbc_reader_options) | self._get_timestamp_options()) .load() ) except PySparkException as e: @@ -34,7 +34,12 @@ def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: ) raise PySparkException(error_msg) from e - def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]: + def get_schema( + self, + catalog_name: str, + schema_name: str, + table_name: str, + ) -> list[Schema]: try: schema_query = self._get_schema_query(table_name, schema_name) schema_df = self.reader(schema_query).load() diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py index f36c52381..f0d96150c 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py @@ -1,15 +1,22 @@ from pyspark.sql import DataFrame from databricks.labs.remorph.reconcile.connectors.data_source import DataSource -from databricks.labs.remorph.reconcile.recon_config import Schema, Tables +from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema class SnowflakeDataSource(DataSource): - def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame: + def read_data( + self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions + ) -> DataFrame: # Implement Snowflake-specific logic here return NotImplemented - def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]: + def get_schema( + self, + catalog_name: str, + schema_name: str, + table_name: str, + ) -> list[Schema]: # Implement Snowflake-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/execute.py b/src/databricks/labs/remorph/reconcile/execute.py index 7ef53e347..c23a1ed39 100644 --- a/src/databricks/labs/remorph/reconcile/execute.py +++ b/src/databricks/labs/remorph/reconcile/execute.py @@ -4,7 +4,7 @@ from databricks.labs.blueprint.installation import Installation from databricks.labs.remorph.reconcile.connectors.data_source import DataSource -from databricks.labs.remorph.reconcile.recon_config import TableRecon, Tables +from databricks.labs.remorph.reconcile.recon_config import Table, TableRecon logger = logging.getLogger(__name__) @@ -27,10 +27,10 @@ def __init__(self, source: DataSource, target: DataSource): self.source = source self.target = target - def compare_schemas(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool: + def compare_schemas(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool: raise NotImplementedError - def compare_data(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool: + def compare_data(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool: raise NotImplementedError diff --git 
a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index 5a21c7d04..fbbdb091a 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -11,7 +11,7 @@ from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, Schema, - Tables, + Table, Transformation, TransformRuleMapping, ) @@ -19,7 +19,7 @@ class QueryBuilder: - def __init__(self, table_conf: Tables, schema: list[Schema], layer: str, source: str): + def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: str): self.table_conf = table_conf self.schema = schema self.layer = layer @@ -45,10 +45,10 @@ def build_hash_query(self) -> str: # get table_name and query filter if self.layer == "source": - table_name = self.table_conf.source_name + table_name = self._get_table_name(self.source, self.table_conf.source_name) query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " else: - table_name = self.table_conf.target_name + table_name = self._get_table_name(self.source, self.table_conf.target_name) query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " # construct select hash query @@ -190,13 +190,13 @@ def _get_default_transformation(self, columns, column_mapping, schema): @staticmethod def _get_default_transformation_expr(data_source: str, data_type: str) -> str: - if data_source == "oracle": + if data_source == SourceType.ORACLE.value: return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) - if data_source == "snowflake": + if data_source == SourceType.SNOWFLAKE.value: return SnowflakeDataSource.snowflake_datatype_mapper.get( data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value ) - if data_source == "databricks": + if data_source == SourceType.DATABRICKS.value: return DatabricksDataSource.databricks_datatype_mapper.get( data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value ) @@ -241,10 +241,10 @@ def build_threshold_query(self) -> str: ) if self.layer == "source": - table_name = self.table_conf.source_name + table_name = self._get_table_name(self.source, self.table_conf.source_name) query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " else: - table_name = self.table_conf.target_name + table_name = self._get_table_name(self.source, self.table_conf.target_name) query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " # construct threshold select query @@ -265,3 +265,13 @@ def _construct_threshold_query(table_name, query_filter, threshold_columns_expr) select_query = sql_query.getvalue() sql_query.close() return select_query + + @staticmethod + def _get_table_name(source, table_name): + if source == SourceType.ORACLE.value: + return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string + table_name=table_name + ) + return "{{catalog_name}}.{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string + table_name=table_name + ) diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index d010006e0..8300b046e 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -56,7 +56,7 @@ class Filters: @dataclass -class Tables: +class Table: source_name: str target_name: str join_columns: list[str] | 
None = None @@ -83,7 +83,7 @@ class TableRecon: source_schema: str target_catalog: str target_schema: str - tables: list[Tables] + tables: list[Table] source_catalog: str | None = None diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index df15b037f..8293124d9 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -6,14 +6,14 @@ Filters, JdbcReaderOptions, Schema, - Tables, + Table, Thresholds, Transformation, ) def test_hash_query_builder_without_join_column(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -41,7 +41,7 @@ def test_hash_query_builder_without_join_column(): "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_nationkey),'') || " "coalesce(trim(s_phone),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " - "hash_value__recon from supplier " + "hash_value__recon from {schema_name}.supplier " "where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -62,14 +62,14 @@ def test_hash_query_builder_without_join_column(): "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " "coalesce(trim(s_name),''), coalesce(trim(s_nationkey),''), " "coalesce(trim(s_phone),''), coalesce(trim(s_suppkey),'')),256) as " - "hash_value__recon from supplier " + "hash_value__recon from {catalog_name}.{schema_name}.supplier " "where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query def test_hash_query_builder_with_defaults(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -97,7 +97,7 @@ def test_hash_query_builder_with_defaults(): "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_nationkey),'') || " "coalesce(trim(s_phone),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " - "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier " "where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -118,14 +118,14 @@ def test_hash_query_builder_with_defaults(): "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " "coalesce(trim(s_name),''), coalesce(trim(s_nationkey),''), " "coalesce(trim(s_phone),''), coalesce(trim(s_suppkey),'')),256) as " - "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {catalog_name}.{schema_name}.supplier " "where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query def test_hash_query_builder_with_select(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -154,7 +154,7 @@ def test_hash_query_builder_with_select(): expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " - "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier " "where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -173,7 +173,7 @@ def test_hash_query_builder_with_select(): expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), 
coalesce(trim(s_suppkey_t),'')),256) as " - "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier " "where 1 = 1 " ) @@ -181,7 +181,7 @@ def test_hash_query_builder_with_select(): def test_hash_query_builder_with_transformations_with_drop_and_default_select(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -224,7 +224,7 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_nationkey),'') || " "coalesce(trim(s_suppkey),'') || trim(s_address) || trim(s_name) || " "trim(s_phone) || trim(to_char(s_acctbal_t, '9999999999.99')), 'SHA256'))) as " - "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier " "where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -244,14 +244,14 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): "select sha2(concat(cast(s_acctbal_t as decimal(38,2)), " "coalesce(trim(s_nationkey_t),''), coalesce(trim(s_suppkey_t),''), " 'trim(s_address_t), trim(s_name), trim(s_phone_t)),256) as hash_value__recon, ' - "coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier where 1 = 1 " + "coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query def test_hash_query_builder_with_jdbc_reader_options(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=JdbcReaderOptions( @@ -283,7 +283,7 @@ def test_hash_query_builder_with_jdbc_reader_options(): "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " "hash_value__recon, coalesce(trim(s_nationkey),'') as s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey " - "from supplier " + "from {schema_name}.supplier " "where 1 = 1 " ) @@ -303,7 +303,7 @@ def test_hash_query_builder_with_jdbc_reader_options(): expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " - "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier " "where 1 = 1 " ) @@ -311,7 +311,7 @@ def test_hash_query_builder_with_jdbc_reader_options(): def test_hash_query_builder_with_threshold(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=JdbcReaderOptions( @@ -348,7 +348,7 @@ def test_hash_query_builder_with_threshold(): "coalesce(trim(s_nationkey),'') || coalesce(trim(s_suppkey),'') || " "trim(s_address) || trim(s_name) || trim(s_phone), 'SHA256'))) as " "hash_value__recon, coalesce(trim(s_nationkey),'') as " - "s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey from supplier where 1 " + "s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier where 1 " '= 1 ' ) assert actual_src_query == expected_src_query @@ -368,14 +368,14 @@ def test_hash_query_builder_with_threshold(): "select sha2(concat(coalesce(trim(s_comment),''), " "coalesce(trim(s_nationkey),''), 
coalesce(trim(s_suppkey_t),''), " 'trim(s_address_t), trim(s_name), trim(s_phone)),256) as hash_value__recon, ' - "coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier where 1 = 1 " + "coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query def test_hash_query_builder_with_filters(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -404,7 +404,7 @@ def test_hash_query_builder_with_filters(): expected_src_query = ( "select sha2(concat(coalesce(trim(s_address),''), coalesce(trim(s_name),''), " "coalesce(trim(s_suppkey),'')),256) as hash_value__recon, " - "coalesce(trim(s_suppkey),'') as s_suppkey from supplier where s_name='t' and " + "coalesce(trim(s_suppkey),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where s_name='t' and " "s_address='a'" ) assert actual_src_query == expected_src_query @@ -423,7 +423,7 @@ def test_hash_query_builder_with_filters(): expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " - "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier " + "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier " "where s_name='t' and s_address_t='a'" ) @@ -431,7 +431,7 @@ def test_hash_query_builder_with_filters(): def test_hash_query_builder_with_unsupported_source(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -462,7 +462,7 @@ def test_hash_query_builder_with_unsupported_source(): def test_threshold_query_builder_with_defaults(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=None, @@ -485,7 +485,9 @@ def test_threshold_query_builder_with_defaults(): ] actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() - expected_src_query = 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1 ' + expected_src_query = ( + 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 ' + ) assert actual_src_query == expected_src_query tgt_schema = [ @@ -499,12 +501,14 @@ def test_threshold_query_builder_with_defaults(): ] actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() - expected_tgt_query = 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from supplier where 1 = 1 ' + expected_tgt_query = ( + 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 ' + ) assert actual_tgt_query == expected_tgt_query def test_threshold_query_builder_with_transformations_and_jdbc(): - table_conf = Tables( + table_conf = Table( source_name="supplier", target_name="supplier", jdbc_reader_options=JdbcReaderOptions( @@ -554,7 +558,7 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): expected_src_query = ( "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " - "supplier where 1 = 1 " + "{schema_name}.supplier where 1 = 1 " ) assert actual_src_query == expected_src_query @@ -572,7 +576,7 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): actual_tgt_query = QueryBuilder(table_conf, 
tgt_schema, "target", "databricks").build_threshold_query() expected_tgt_query = ( "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " - "s_suppdate,trim(s_suppkey_t) as s_suppkey from supplier where 1 = 1 " + "s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query From ea30630468f8c610d45728a126cdde4b9d7f62ee Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 14:47:45 +0530 Subject: [PATCH 09/15] Refactored the QueryBuilder to Abstract class and added query_config class --- .../labs/remorph/reconcile/query_builder.py | 275 +++++++----------- .../labs/remorph/reconcile/query_config.py | 71 +++++ tests/unit/reconcile/test_query_builder.py | 67 +++-- 3 files changed, 220 insertions(+), 193 deletions(-) create mode 100644 src/databricks/labs/remorph/reconcile/query_config.py diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index fbbdb091a..5c4b39871 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from io import StringIO from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource @@ -8,110 +9,87 @@ Constants, SourceType, ) +from databricks.labs.remorph.reconcile.query_config import QueryConfig from databricks.labs.remorph.reconcile.recon_config import ( - ColumnMapping, - Schema, - Table, Transformation, TransformRuleMapping, ) +# pylint: disable=invalid-name -class QueryBuilder: - def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: str): - self.table_conf = table_conf - self.schema = schema - self.layer = layer - self.source = source +class QueryBuilder(ABC): - def build_hash_query(self) -> str: - schema_info = {v.column_name: v for v in self.schema} + def __init__(self, qc: QueryConfig): + self.qc = qc - columns, key_columns = self._get_column_list() + @abstractmethod + def build_query(self): + raise NotImplementedError - # get transformation for columns considered for hashing - col_transformations = self._generate_transformation_rule_mapping(columns, schema_info) - hash_columns_expr = sorted( - self._get_column_expr(TransformRuleMapping.get_column_expression_without_alias, col_transformations) - ) - hash_expr = self._generate_hash_algorithm(self.source, hash_columns_expr) - - # get transformation for columns considered for joining and partition key - key_column_transformation = self._generate_transformation_rule_mapping(key_columns, schema_info) - key_column_expr = sorted( - self._get_column_expr(TransformRuleMapping.get_column_expression_with_alias, key_column_transformation) - ) - - # get table_name and query filter - if self.layer == "source": - table_name = self._get_table_name(self.source, self.table_conf.source_name) - query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " - else: - table_name = self._get_table_name(self.source, self.table_conf.target_name) - query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " - - # construct select hash query - select_query = self._construct_hash_query(table_name, query_filter, hash_expr, key_column_expr) + def _get_custom_transformation(self, columns, transformation_dict, column_mapping): + transformation_rule_mapping = [] + for column in columns: + if column in transformation_dict.keys(): + 
transformation = self._get_layer_transform(transformation_dict, column, self.qc.layer) + else: + transformation = None - return select_query + column_origin, column_alias = self._get_column_alias(self.qc.layer, column, column_mapping) - def _get_column_list(self) -> tuple[list[str], list[str]]: - tgt_column_mapping = self.table_conf.list_to_dict(ColumnMapping, "target_name") + transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) - # get join columns - if self.table_conf.join_columns is None: - join_columns = set() - else: - join_columns = set(self.table_conf.join_columns) + return transformation_rule_mapping - # get select columns - if self.table_conf.select_columns is None: - columns = {sch.column_name for sch in self.schema} - select_columns = ( - columns if self.layer == "source" else self._get_mapped_columns(tgt_column_mapping, columns) + def _get_default_transformation(self, columns, column_mapping, schema): + transformation_rule_mapping = [] + for column in columns: + column_origin = column if self.qc.layer == "source" else self._get_column_map(column, column_mapping) + column_data_type = schema.get(column_origin).data_type + transformation = self._get_default_transformation_expr(self.qc.db_type, column_data_type).format( + column_origin ) - else: - select_columns = set(self.table_conf.select_columns) - - # get partition key for jdbc reader options - if self.table_conf.jdbc_reader_options and self.layer == "source": - partition_column = {self.table_conf.jdbc_reader_options.partition_column} - else: - partition_column = set() - # combine all column names for hashing - all_columns = join_columns | select_columns + column_origin, column_alias = self._get_column_alias(self.qc.layer, column, column_mapping) - # remove threshold and drop columns - threshold_columns = {thresh.column_name for thresh in self.table_conf.thresholds or []} - if self.table_conf.drop_columns is None: - drop_columns = set() - else: - drop_columns = set(self.table_conf.drop_columns) + transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) - columns = sorted(all_columns - threshold_columns - drop_columns) - key_columns = sorted(join_columns | partition_column) + return transformation_rule_mapping - return columns, key_columns + @staticmethod + def _get_default_transformation_expr(data_source: str, data_type: str) -> str: + if data_source == SourceType.ORACLE.value: + return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) + if data_source == SourceType.SNOWFLAKE.value: + return SnowflakeDataSource.snowflake_datatype_mapper.get( + data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value + ) + if data_source == SourceType.DATABRICKS.value: + return DatabricksDataSource.databricks_datatype_mapper.get( + data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value + ) + msg = f"Unsupported source type --> {data_source}" + raise ValueError(msg) - def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict) -> list[TransformRuleMapping]: - transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") - column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "source_name") + def _generate_transformation_rule_mapping(self, columns: list[str]) -> list[TransformRuleMapping]: # compute custom transformation - if transformations_dict: - columns_with_transformation = [column for column in columns if column in 
transformations_dict.keys()] + if self.qc.transformations_dict: + columns_with_transformation = [ + column for column in columns if column in self.qc.transformations_dict.keys() + ] custom_transformation = self._get_custom_transformation( - columns_with_transformation, transformations_dict, column_mapping_dict + columns_with_transformation, self.qc.transformations_dict, self.qc.src_column_mapping ) else: custom_transformation = [] # compute default transformation - columns_without_transformation = [column for column in columns if column not in transformations_dict.keys()] + columns_without_transformation = [ + column for column in columns if column not in self.qc.transformations_dict.keys() + ] default_transformation = self._get_default_transformation( - columns_without_transformation, column_mapping_dict, schema + columns_without_transformation, self.qc.src_column_mapping, self.qc.schema_dict ) transformation_rule_mapping = custom_transformation + default_transformation @@ -126,6 +104,48 @@ def _get_layer_transform(transform_dict: dict[str, Transformation], column: str, def _get_column_expr(func, column_transformations: list[TransformRuleMapping]): return [func(transformation) for transformation in column_transformations] + @staticmethod + def _get_column_map(column, column_mapping) -> str: + return column_mapping.get(column).target_name if column_mapping.get(column) else column + + @staticmethod + def _get_column_alias(layer, column, column_mapping): + if column_mapping and column in column_mapping.keys() and layer == "target": + column_alias = column_mapping.get(column).source_name + column_origin = column_mapping.get(column).target_name + else: + column_alias = column + column_origin = column + + return column_origin, column_alias + + +class HashQueryBuilder(QueryBuilder): + + def build_query(self): + columns = sorted( + (self.qc.join_columns | self.qc.select_columns) - self.qc.threshold_columns - self.qc.drop_columns + ) + key_columns = sorted(self.qc.join_columns | self.qc.partition_column) + + # get transformation for columns considered for hashing + col_transformations = self._generate_transformation_rule_mapping(columns) + hash_columns_expr = sorted( + self._get_column_expr(TransformRuleMapping.get_column_expression_without_alias, col_transformations) + ) + hash_expr = self._generate_hash_algorithm(self.qc.db_type, hash_columns_expr) + + # get transformation for columns considered for joining and partition key + key_column_transformation = self._generate_transformation_rule_mapping(key_columns) + key_column_expr = sorted( + self._get_column_expr(TransformRuleMapping.get_column_expression_with_alias, key_column_transformation) + ) + + # construct select hash query + select_query = self._construct_hash_query(self.qc.table_name, self.qc.query_filter, hash_expr, key_column_expr) + + return select_query + @staticmethod def _generate_hash_algorithm(source: str, column_expr: list[str]) -> str: if source in {SourceType.DATABRICKS.value, SourceType.SNOWFLAKE.value}: @@ -150,105 +170,26 @@ def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, ke sql_query.close() return select_query - @staticmethod - def _get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: - select_columns = set() - for column in columns: - select_columns.add(column_mapping.get(column).source_name if column_mapping.get(column) else column) - return select_columns - - @staticmethod - def _get_column_map(column, column_mapping) -> str: - return 
column_mapping.get(column).target_name if column_mapping.get(column) else column - - def _get_custom_transformation(self, columns, transformation_dict, column_mapping): - transformation_rule_mapping = [] - for column in columns: - if column in transformation_dict.keys(): - transformation = self._get_layer_transform(transformation_dict, column, self.layer) - else: - transformation = None - - column_origin, column_alias = self._get_column_alias(self.layer, column, column_mapping) - - transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) - - return transformation_rule_mapping - - def _get_default_transformation(self, columns, column_mapping, schema): - transformation_rule_mapping = [] - for column in columns: - column_origin = column if self.layer == "source" else self._get_column_map(column, column_mapping) - column_data_type = schema.get(column_origin).data_type - transformation = self._get_default_transformation_expr(self.source, column_data_type).format(column_origin) - column_origin, column_alias = self._get_column_alias(self.layer, column, column_mapping) +class ThresholdQueryBuilder(QueryBuilder): - transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) - - return transformation_rule_mapping - - @staticmethod - def _get_default_transformation_expr(data_source: str, data_type: str) -> str: - if data_source == SourceType.ORACLE.value: - return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value) - if data_source == SourceType.SNOWFLAKE.value: - return SnowflakeDataSource.snowflake_datatype_mapper.get( - data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value - ) - if data_source == SourceType.DATABRICKS.value: - return DatabricksDataSource.databricks_datatype_mapper.get( - data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value - ) - msg = f"Unsupported source type --> {data_source}" - raise ValueError(msg) - - @staticmethod - def _get_column_alias(layer, column, column_mapping): - if column_mapping and column in column_mapping.keys() and layer == "target": - column_alias = column_mapping.get(column).source_name - column_origin = column_mapping.get(column).target_name - else: - column_alias = column - column_origin = column - - return column_origin, column_alias - - def build_threshold_query(self) -> str: - column_mapping = self.table_conf.list_to_dict(ColumnMapping, "source_name") - transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name") - - threshold_columns = set(threshold.column_name for threshold in self.table_conf.thresholds) - join_columns = set(self.table_conf.join_columns) - - if self.table_conf.jdbc_reader_options and self.layer == "source": - partition_column = {self.table_conf.jdbc_reader_options.partition_column} - else: - partition_column = set() - - all_columns = set(threshold_columns | join_columns | partition_column) + def build_query(self): + all_columns = set(self.qc.threshold_columns | self.qc.join_columns | self.qc.partition_column) query_columns = sorted( - all_columns if self.layer == "source" else self._get_mapped_columns(column_mapping, all_columns) + all_columns + if self.qc.layer == "source" + else self.qc.get_mapped_columns(self.qc.src_column_mapping, all_columns) ) - # get custom transformation transformation_rule_mapping = self._get_custom_transformation( - query_columns, transformations_dict, column_mapping + query_columns, self.qc.transformations_dict, self.qc.src_column_mapping 
) threshold_columns_expr = self._get_column_expr( TransformRuleMapping.get_column_expression_with_alias, transformation_rule_mapping ) - if self.layer == "source": - table_name = self._get_table_name(self.source, self.table_conf.source_name) - query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 " - else: - table_name = self._get_table_name(self.source, self.table_conf.target_name) - query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 " - - # construct threshold select query - select_query = self._construct_threshold_query(table_name, query_filter, threshold_columns_expr) + select_query = self._construct_threshold_query(self.qc.table_name, self.qc.query_filter, threshold_columns_expr) return select_query @@ -265,13 +206,3 @@ def _construct_threshold_query(table_name, query_filter, threshold_columns_expr) select_query = sql_query.getvalue() sql_query.close() return select_query - - @staticmethod - def _get_table_name(source, table_name): - if source == SourceType.ORACLE.value: - return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string - table_name=table_name - ) - return "{{catalog_name}}.{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string - table_name=table_name - ) diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py new file mode 100644 index 000000000..fba70cc20 --- /dev/null +++ b/src/databricks/labs/remorph/reconcile/query_config.py @@ -0,0 +1,71 @@ +from databricks.labs.remorph.reconcile.constants import SourceType +from databricks.labs.remorph.reconcile.recon_config import ( + ColumnMapping, + Schema, + Table, + Transformation, +) + + +class QueryConfig: # pylint: disable=too-many-instance-attributes) + def __init__(self, table_conf: Table, schema: list[Schema], layer: str, db_type: str): + self.table_conf = table_conf + self.schema = schema + self.layer = layer + self.db_type = db_type + self.schema_dict = {v.column_name: v for v in schema} + self.tgt_column_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") + self.src_column_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") + self.transformations_dict = table_conf.list_to_dict(Transformation, "column_name") + self.select_columns = self.get_select_columns() + self.drop_columns = self.get_drop_columns() + self.join_columns = self.get_join_columns() + self.partition_column = self.get_partition_column() + self.threshold_columns = {thresh.column_name for thresh in table_conf.thresholds or []} + self.table_name = self._get_table_name() + self.query_filter = self._get_filter() + + def get_join_columns(self): + if self.table_conf.join_columns is None: + return set() + return set(self.table_conf.join_columns) + + def get_select_columns(self): + if self.table_conf.select_columns is None: + columns = {sch.column_name for sch in self.schema} + return columns if self.layer == "source" else self.get_mapped_columns(self.tgt_column_mapping, columns) + return set(self.table_conf.select_columns) + + def get_partition_column(self): + if self.table_conf.jdbc_reader_options and self.layer == "source": + return {self.table_conf.jdbc_reader_options.partition_column} + return set() + + def get_drop_columns(self): + if self.table_conf.drop_columns is None: + return set() + return set(self.table_conf.drop_columns) + + def _get_table_name(self): + table_name = self.table_conf.source_name if self.layer == "source" else 
self.table_conf.target_name + if self.db_type == SourceType.ORACLE.value: + return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string + table_name=table_name + ) + return "{{catalog_name}}.{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string + table_name=table_name + ) + + def _get_filter(self): + if self.table_conf.filters is None: + return " 1 = 1 " + if self.layer == "source": + return self.table_conf.filters.source + return self.table_conf.filters.target + + @staticmethod + def get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: + select_columns = set() + for column in columns: + select_columns.add(column_mapping.get(column).source_name if column_mapping.get(column) else column) + return select_columns diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 8293124d9..401e6bcf5 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -1,6 +1,10 @@ import pytest -from databricks.labs.remorph.reconcile.query_builder import QueryBuilder +from databricks.labs.remorph.reconcile.query_builder import ( + HashQueryBuilder, + ThresholdQueryBuilder, +) +from databricks.labs.remorph.reconcile.query_config import QueryConfig from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, Filters, @@ -11,6 +15,8 @@ Transformation, ) +# pylint: disable=invalid-name + def test_hash_query_builder_without_join_column(): table_conf = Table( @@ -35,7 +41,8 @@ def test_hash_query_builder_without_join_column(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || " "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " @@ -56,7 +63,8 @@ def test_hash_query_builder_without_join_column(): Schema("s_comment", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_acctbal),''), " "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " @@ -91,7 +99,8 @@ def test_hash_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || " "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " @@ -112,7 +121,8 @@ def test_hash_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_acctbal),''), " "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " @@ -150,7 +160,8 @@ def test_hash_query_builder_with_select(): Schema("s_comment", "varchar"), ] - actual_src_query = 
QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " @@ -169,7 +180,8 @@ def test_hash_query_builder_with_select(): Schema("s_comment_t", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -219,7 +231,8 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_nationkey),'') || " "coalesce(trim(s_suppkey),'') || trim(s_address) || trim(s_name) || " @@ -239,7 +252,8 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): Schema("s_comment_t", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(cast(s_acctbal_t as decimal(38,2)), " "coalesce(trim(s_nationkey_t),''), coalesce(trim(s_suppkey_t),''), " @@ -278,7 +292,8 @@ def test_hash_query_builder_with_jdbc_reader_options(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " @@ -299,7 +314,8 @@ def test_hash_query_builder_with_jdbc_reader_options(): Schema("s_comment_t", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -342,7 +358,8 @@ def test_hash_query_builder_with_threshold(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_comment),'') || " "coalesce(trim(s_nationkey),'') || coalesce(trim(s_suppkey),'') || " @@ -363,7 +380,8 @@ def test_hash_query_builder_with_threshold(): Schema("s_comment", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, 
tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_comment),''), " "coalesce(trim(s_nationkey),''), coalesce(trim(s_suppkey_t),''), " @@ -400,7 +418,8 @@ def test_hash_query_builder_with_filters(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "snowflake").build_hash_query() + qc = QueryConfig(table_conf, src_schema, "source", "snowflake") + actual_src_query = HashQueryBuilder(qc).build_query() expected_src_query = ( "select sha2(concat(coalesce(trim(s_address),''), coalesce(trim(s_name),''), " "coalesce(trim(s_suppkey),'')),256) as hash_value__recon, " @@ -419,7 +438,8 @@ def test_hash_query_builder_with_filters(): Schema("s_comment_t", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(qc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -453,10 +473,11 @@ def test_hash_query_builder_with_unsupported_source(): Schema("s_comment", "varchar"), ] - query_builder = QueryBuilder(table_conf, src_schema, "source", "abc") + qc = QueryConfig(table_conf, src_schema, "source", "abc") + query_builder = HashQueryBuilder(qc) with pytest.raises(Exception) as exc_info: - query_builder.build_hash_query() + query_builder.build_query() assert str(exc_info.value) == "Unsupported source type --> abc" @@ -484,7 +505,8 @@ def test_threshold_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = ThresholdQueryBuilder(qc).build_query() expected_src_query = ( 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 ' ) @@ -500,7 +522,8 @@ def test_threshold_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = ThresholdQueryBuilder(qc).build_query() expected_tgt_query = ( 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 ' ) @@ -554,7 +577,8 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_suppdate", "timestamp"), ] - actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_threshold_query() + qc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = ThresholdQueryBuilder(qc).build_query() expected_src_query = ( "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " @@ -573,7 +597,8 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_suppdate_t", "timestamp"), ] - actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_threshold_query() + qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = ThresholdQueryBuilder(qc).build_query() expected_tgt_query = ( "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " 
"s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " From c18445845a6b4db3c2dc155323a823a718538ee2 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 16:31:41 +0530 Subject: [PATCH 10/15] Fixed the review feedbacks --- .../reconcile/connectors/data_source.py | 27 ++++++++--------- .../reconcile/connectors/databricks.py | 10 +++---- .../remorph/reconcile/connectors/oracle.py | 16 +++++----- .../remorph/reconcile/connectors/snowflake.py | 10 +++---- .../labs/remorph/reconcile/query_builder.py | 30 +++++++++---------- .../labs/remorph/reconcile/query_config.py | 20 ++++++------- tests/unit/reconcile/test_query_builder.py | 8 ++--- 7 files changed, 54 insertions(+), 67 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py index 8a25003c3..2478ff121 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/data_source.py +++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py @@ -20,17 +20,15 @@ def __init__(self, source: str, spark: SparkSession, ws: WorkspaceClient, scope: self.scope = scope @abstractmethod - def read_data( - self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions - ) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: return NotImplemented @abstractmethod def get_schema( self, - catalog_name: str, - schema_name: str, - table_name: str, + catalog: str, + schema: str, + table: str, ) -> list[Schema]: return NotImplemented @@ -52,18 +50,17 @@ def _get_jdbc_reader_options(jdbc_reader_options: JdbcReaderOptions): "fetchsize": jdbc_reader_options.fetch_size, } - def _get_secrets(self, key_name): - key = self.source + '_' + key_name - return self.ws.secrets.get_secret(self.scope, key) + def _get_secrets(self, key): + return self.ws.secrets.get_secret(self.scope, self.source + '_' + key) @staticmethod def _get_table_or_query( - catalog_name: str, - schema_name: str, + catalog: str, + schema: str, query: str, ): if re.search('select', query, re.IGNORECASE): - return query.format(catalog_name=catalog_name, schema_name=schema_name) - if catalog_name: - return catalog_name + "." + schema_name + "." + query - return schema_name + "." + query + return query.format(catalog_name=catalog, schema_name=schema) + if catalog: + return catalog + "." + schema + "." + query + return schema + "." 
+ query diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py b/src/databricks/labs/remorph/reconcile/connectors/databricks.py index 383834998..fd73deff8 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/databricks.py +++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py @@ -5,17 +5,15 @@ class DatabricksDataSource(DataSource): - def read_data( - self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions - ) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: # Implement Databricks-specific logic here return NotImplemented def get_schema( self, - catalog_name: str, - schema_name: str, - table_name: str, + catalog: str, + schema: str, + table: str, ) -> list[Schema]: # Implement Databricks-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py index 076e87dc6..ecd2a3a07 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/oracle.py +++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py @@ -16,11 +16,9 @@ def get_jdbc_url(self) -> str: f":{self._get_secrets('port')}/{self._get_secrets('database')}" ) - def read_data( - self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions - ) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: try: - table_query = self._get_table_or_query(catalog_name, schema_name, query) + table_query = self._get_table_or_query(catalog, schema, query) if jdbc_reader_options is None: return self.reader(table_query).options(**self._get_timestamp_options()).load() return ( @@ -36,17 +34,17 @@ def read_data( def get_schema( self, - catalog_name: str, - schema_name: str, - table_name: str, + catalog: str, + schema: str, + table: str, ) -> list[Schema]: try: - schema_query = self._get_schema_query(table_name, schema_name) + schema_query = self._get_schema_query(table, schema) schema_df = self.reader(schema_query).load() return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_df.collect()] except PySparkException as e: error_msg = ( - f"An error occurred while fetching Oracle Schema using the following {table_name} in " + f"An error occurred while fetching Oracle Schema using the following {table} in " f"OracleDataSource: {e!s}" ) raise PySparkException(error_msg) from e diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py index f0d96150c..fb2731d34 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py @@ -5,17 +5,15 @@ class SnowflakeDataSource(DataSource): - def read_data( - self, catalog_name: str, schema_name: str, query: str, jdbc_reader_options: JdbcReaderOptions - ) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: # Implement Snowflake-specific logic here return NotImplemented def get_schema( self, - catalog_name: str, - schema_name: str, - table_name: str, + catalog: str, + schema: str, + table: str, ) -> list[Schema]: # Implement Snowflake-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py 
b/src/databricks/labs/remorph/reconcile/query_builder.py index 5c4b39871..4eabe1f2d 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -11,6 +11,7 @@ ) from databricks.labs.remorph.reconcile.query_config import QueryConfig from databricks.labs.remorph.reconcile.recon_config import ( + ColumnMapping, Transformation, TransformRuleMapping, ) @@ -106,7 +107,7 @@ def _get_column_expr(func, column_transformations: list[TransformRuleMapping]): @staticmethod def _get_column_map(column, column_mapping) -> str: - return column_mapping.get(column).target_name if column_mapping.get(column) else column + return column_mapping.get(column, ColumnMapping(source_name='', target_name=column)).target_name @staticmethod def _get_column_alias(layer, column, column_mapping): @@ -124,9 +125,11 @@ class HashQueryBuilder(QueryBuilder): def build_query(self): columns = sorted( - (self.qc.join_columns | self.qc.select_columns) - self.qc.threshold_columns - self.qc.drop_columns + (self.qc.get_join_columns() | self.qc.get_select_columns()) + - self.qc.get_threshold_columns() + - self.qc.get_drop_columns() ) - key_columns = sorted(self.qc.join_columns | self.qc.partition_column) + key_columns = sorted(self.qc.get_join_columns() | self.qc.get_partition_column()) # get transformation for columns considered for hashing col_transformations = self._generate_transformation_rule_mapping(columns) @@ -142,7 +145,9 @@ def build_query(self): ) # construct select hash query - select_query = self._construct_hash_query(self.qc.table_name, self.qc.query_filter, hash_expr, key_column_expr) + select_query = self._construct_hash_query( + self.qc.get_table_name(), self.qc.get_filter(), hash_expr, key_column_expr + ) return select_query @@ -174,7 +179,7 @@ def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, ke class ThresholdQueryBuilder(QueryBuilder): def build_query(self): - all_columns = set(self.qc.threshold_columns | self.qc.join_columns | self.qc.partition_column) + all_columns = set(self.qc.get_threshold_columns() | self.qc.get_join_columns() | self.qc.get_partition_column()) query_columns = sorted( all_columns @@ -189,20 +194,13 @@ def build_query(self): TransformRuleMapping.get_column_expression_with_alias, transformation_rule_mapping ) - select_query = self._construct_threshold_query(self.qc.table_name, self.qc.query_filter, threshold_columns_expr) + select_query = self._construct_threshold_query( + self.qc.get_table_name(), self.qc.get_filter(), threshold_columns_expr + ) return select_query @staticmethod def _construct_threshold_query(table_name, query_filter, threshold_columns_expr): - sql_query = StringIO() - # construct threshold expr column_expr = ",".join(threshold_columns_expr) - sql_query.write(f"select {column_expr} ") - - # add query filter - sql_query.write(f" from {table_name} where {query_filter}") - - select_query = sql_query.getvalue() - sql_query.close() - return select_query + return f"select {column_expr} from {table_name} where {query_filter}" diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py index fba70cc20..4f70b27ef 100644 --- a/src/databricks/labs/remorph/reconcile/query_config.py +++ b/src/databricks/labs/remorph/reconcile/query_config.py @@ -7,7 +7,7 @@ ) -class QueryConfig: # pylint: disable=too-many-instance-attributes) +class QueryConfig: def __init__(self, table_conf: Table, schema: list[Schema], layer: str, db_type: str): 
self.table_conf = table_conf self.schema = schema @@ -17,13 +17,9 @@ def __init__(self, table_conf: Table, schema: list[Schema], layer: str, db_type: self.tgt_column_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") self.src_column_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") self.transformations_dict = table_conf.list_to_dict(Transformation, "column_name") - self.select_columns = self.get_select_columns() - self.drop_columns = self.get_drop_columns() - self.join_columns = self.get_join_columns() - self.partition_column = self.get_partition_column() - self.threshold_columns = {thresh.column_name for thresh in table_conf.thresholds or []} - self.table_name = self._get_table_name() - self.query_filter = self._get_filter() + + def get_threshold_columns(self): + return {thresh.column_name for thresh in self.table_conf.thresholds or []} def get_join_columns(self): if self.table_conf.join_columns is None: @@ -46,7 +42,7 @@ def get_drop_columns(self): return set() return set(self.table_conf.drop_columns) - def _get_table_name(self): + def get_table_name(self): table_name = self.table_conf.source_name if self.layer == "source" else self.table_conf.target_name if self.db_type == SourceType.ORACLE.value: return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string @@ -56,7 +52,7 @@ def _get_table_name(self): table_name=table_name ) - def _get_filter(self): + def get_filter(self): if self.table_conf.filters is None: return " 1 = 1 " if self.layer == "source": @@ -67,5 +63,7 @@ def _get_filter(self): def get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: select_columns = set() for column in columns: - select_columns.add(column_mapping.get(column).source_name if column_mapping.get(column) else column) + select_columns.add( + column_mapping.get(column, ColumnMapping(source_name=column, target_name='')).source_name + ) return select_columns diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 401e6bcf5..1b9bb9e71 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -508,7 +508,7 @@ def test_threshold_query_builder_with_defaults(): qc = QueryConfig(table_conf, src_schema, "source", "oracle") actual_src_query = ThresholdQueryBuilder(qc).build_query() expected_src_query = ( - 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 ' + 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 ' ) assert actual_src_query == expected_src_query @@ -525,7 +525,7 @@ def test_threshold_query_builder_with_defaults(): qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") actual_tgt_query = ThresholdQueryBuilder(qc).build_query() expected_tgt_query = ( - 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 ' + 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 ' ) assert actual_tgt_query == expected_tgt_query @@ -581,7 +581,7 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): actual_src_query = ThresholdQueryBuilder(qc).build_query() expected_src_query = ( "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " - "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " + "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " "{schema_name}.supplier where 
1 = 1 " ) assert actual_src_query == expected_src_query @@ -601,7 +601,7 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): actual_tgt_query = ThresholdQueryBuilder(qc).build_query() expected_tgt_query = ( "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " - "s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " + "s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " ) assert actual_tgt_query == expected_tgt_query From 367bf7535c085ab641854063d993116e4f3b4bd9 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 18:29:46 +0530 Subject: [PATCH 11/15] Renamed the function parameters --- .../reconcile/connectors/data_source.py | 14 +- .../reconcile/connectors/databricks.py | 2 +- .../remorph/reconcile/connectors/oracle.py | 10 +- .../remorph/reconcile/connectors/snowflake.py | 2 +- .../labs/remorph/reconcile/query_builder.py | 149 +++++++++--------- .../labs/remorph/reconcile/query_config.py | 24 ++- .../labs/remorph/reconcile/recon_config.py | 6 +- 7 files changed, 99 insertions(+), 108 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py index 2478ff121..5435da68e 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/data_source.py +++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py @@ -20,7 +20,7 @@ def __init__(self, source: str, spark: SparkSession, ws: WorkspaceClient, scope: self.scope = scope @abstractmethod - def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame: return NotImplemented @abstractmethod @@ -41,13 +41,13 @@ def _get_jdbc_reader(self, query, jdbc_url, driver): ) @staticmethod - def _get_jdbc_reader_options(jdbc_reader_options: JdbcReaderOptions): + def _get_jdbc_reader_options(options: JdbcReaderOptions): return { - "numPartitions": jdbc_reader_options.number_partitions, - "partitionColumn": jdbc_reader_options.partition_column, - "lowerBound": jdbc_reader_options.lower_bound, - "upperBound": jdbc_reader_options.upper_bound, - "fetchsize": jdbc_reader_options.fetch_size, + "numPartitions": options.number_partitions, + "partitionColumn": options.partition_column, + "lowerBound": options.lower_bound, + "upperBound": options.upper_bound, + "fetchsize": options.fetch_size, } def _get_secrets(self, key): diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py b/src/databricks/labs/remorph/reconcile/connectors/databricks.py index fd73deff8..137ae208b 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/databricks.py +++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py @@ -5,7 +5,7 @@ class DatabricksDataSource(DataSource): - def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame: # Implement Databricks-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py index ecd2a3a07..f09578a7d 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/oracle.py +++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py @@ -16,14 +16,14 
@@ def get_jdbc_url(self) -> str: f":{self._get_secrets('port')}/{self._get_secrets('database')}" ) - def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame: try: table_query = self._get_table_or_query(catalog, schema, query) - if jdbc_reader_options is None: + if options is None: return self.reader(table_query).options(**self._get_timestamp_options()).load() return ( self.reader(table_query) - .options(**self._get_jdbc_reader_options(jdbc_reader_options) | self._get_timestamp_options()) + .options(**self._get_jdbc_reader_options(options) | self._get_timestamp_options()) .load() ) except PySparkException as e: @@ -66,7 +66,7 @@ def reader(self, query: str) -> DataFrameReader: return self._get_jdbc_reader(query, self.get_jdbc_url, SourceDriver.ORACLE.value) @staticmethod - def _get_schema_query(table_name: str, owner: str) -> str: + def _get_schema_query(table: str, owner: str) -> str: return f"""select column_name, case when (data_precision is not null and data_scale <> 0) then data_type || '(' || data_precision || ',' || data_scale || ')' @@ -78,4 +78,4 @@ def _get_schema_query(table_name: str, owner: str) -> str: else data_type || '(' || CHAR_LENGTH || ')' end data_type FROM ALL_TAB_COLUMNS - WHERE lower(TABLE_NAME) = '{table_name}' and lower(owner) = '{owner}' """ + WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}' """ diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py index fb2731d34..c85f05303 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py @@ -5,7 +5,7 @@ class SnowflakeDataSource(DataSource): - def read_data(self, catalog: str, schema: str, query: str, jdbc_reader_options: JdbcReaderOptions) -> DataFrame: + def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame: # Implement Snowflake-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index 4eabe1f2d..f869f2033 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -12,6 +12,7 @@ from databricks.labs.remorph.reconcile.query_config import QueryConfig from databricks.labs.remorph.reconcile.recon_config import ( ColumnMapping, + Schema, Transformation, TransformRuleMapping, ) @@ -28,34 +29,36 @@ def __init__(self, qc: QueryConfig): def build_query(self): raise NotImplementedError - def _get_custom_transformation(self, columns, transformation_dict, column_mapping): - transformation_rule_mapping = [] - for column in columns: - if column in transformation_dict.keys(): - transformation = self._get_layer_transform(transformation_dict, column, self.qc.layer) + def _get_custom_transformation( + self, cols: list[str], transform_dict: dict[str, Transformation], col_mapping: dict[str, ColumnMapping] + ): + transform_rule_mapping = [] + for col in cols: + if col in transform_dict.keys(): + transform = self._get_layer_transform(transform_dict, col, self.qc.layer) else: - transformation = None + transform = None - column_origin, column_alias = self._get_column_alias(self.qc.layer, column, column_mapping) + col_origin, col_alias = 
self._get_column_alias(self.qc.layer, col, col_mapping) - transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) + transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) - return transformation_rule_mapping + return transform_rule_mapping - def _get_default_transformation(self, columns, column_mapping, schema): - transformation_rule_mapping = [] - for column in columns: - column_origin = column if self.qc.layer == "source" else self._get_column_map(column, column_mapping) - column_data_type = schema.get(column_origin).data_type - transformation = self._get_default_transformation_expr(self.qc.db_type, column_data_type).format( - column_origin - ) + def _get_default_transformation( + self, cols: list[str], col_mapping: dict[str, ColumnMapping], schema: dict[str, Schema] + ): + transform_rule_mapping = [] + for col in cols: + col_origin = col if self.qc.layer == "source" else self._get_column_map(col, col_mapping) + col_data_type = schema.get(col_origin).data_type + transform = self._get_default_transformation_expr(self.qc.source, col_data_type).format(col_origin) - column_origin, column_alias = self._get_column_alias(self.qc.layer, column, column_mapping) + col_origin, col_alias = self._get_column_alias(self.qc.layer, col, col_mapping) - transformation_rule_mapping.append(TransformRuleMapping(column_origin, transformation, column_alias)) + transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) - return transformation_rule_mapping + return transform_rule_mapping @staticmethod def _get_default_transformation_expr(data_source: str, data_type: str) -> str: @@ -72,104 +75,98 @@ def _get_default_transformation_expr(data_source: str, data_type: str) -> str: msg = f"Unsupported source type --> {data_source}" raise ValueError(msg) - def _generate_transformation_rule_mapping(self, columns: list[str]) -> list[TransformRuleMapping]: + def _generate_transform_rule_mapping(self, cols: list[str]) -> list[TransformRuleMapping]: # compute custom transformation - if self.qc.transformations_dict: - columns_with_transformation = [ - column for column in columns if column in self.qc.transformations_dict.keys() - ] - custom_transformation = self._get_custom_transformation( - columns_with_transformation, self.qc.transformations_dict, self.qc.src_column_mapping + if self.qc.transform_dict: + cols_with_transform = [col for col in cols if col in self.qc.transform_dict.keys()] + custom_transform = self._get_custom_transformation( + cols_with_transform, self.qc.transform_dict, self.qc.src_col_mapping ) else: - custom_transformation = [] + custom_transform = [] # compute default transformation - columns_without_transformation = [ - column for column in columns if column not in self.qc.transformations_dict.keys() - ] - default_transformation = self._get_default_transformation( - columns_without_transformation, self.qc.src_column_mapping, self.qc.schema_dict + cols_without_transform = [col for col in cols if col not in self.qc.transform_dict.keys()] + default_transform = self._get_default_transformation( + cols_without_transform, self.qc.src_col_mapping, self.qc.schema_dict ) - transformation_rule_mapping = custom_transformation + default_transformation + transform_rule_mapping = custom_transform + default_transform - return transformation_rule_mapping + return transform_rule_mapping @staticmethod - def _get_layer_transform(transform_dict: dict[str, Transformation], column: str, layer: str) -> str: - return 
transform_dict.get(column).source if layer == "source" else transform_dict.get(column).target + def _get_layer_transform(transform_dict: dict[str, Transformation], col: str, layer: str) -> str: + return transform_dict.get(col).source if layer == "source" else transform_dict.get(col).target @staticmethod - def _get_column_expr(func, column_transformations: list[TransformRuleMapping]): - return [func(transformation) for transformation in column_transformations] + def _get_column_expr(func, col_transform: list[TransformRuleMapping]): + return [func(transform) for transform in col_transform] @staticmethod - def _get_column_map(column, column_mapping) -> str: - return column_mapping.get(column, ColumnMapping(source_name='', target_name=column)).target_name + def _get_column_map(col, col_mapping: dict[str, ColumnMapping]) -> str: + return col_mapping.get(col, ColumnMapping(source_name='', target_name=col)).target_name @staticmethod - def _get_column_alias(layer, column, column_mapping): - if column_mapping and column in column_mapping.keys() and layer == "target": - column_alias = column_mapping.get(column).source_name - column_origin = column_mapping.get(column).target_name + def _get_column_alias(layer: str, col: str, col_mapping: dict[str, ColumnMapping]): + if col_mapping and col in col_mapping.keys() and layer == "target": + col_alias = col_mapping.get(col).source_name + col_origin = col_mapping.get(col).target_name else: - column_alias = column - column_origin = column + col_alias = col + col_origin = col - return column_origin, column_alias + return col_origin, col_alias class HashQueryBuilder(QueryBuilder): def build_query(self): - columns = sorted( + hash_cols = sorted( (self.qc.get_join_columns() | self.qc.get_select_columns()) - self.qc.get_threshold_columns() - self.qc.get_drop_columns() ) - key_columns = sorted(self.qc.get_join_columns() | self.qc.get_partition_column()) + key_cols = sorted(self.qc.get_join_columns() | self.qc.get_partition_column()) # get transformation for columns considered for hashing - col_transformations = self._generate_transformation_rule_mapping(columns) - hash_columns_expr = sorted( - self._get_column_expr(TransformRuleMapping.get_column_expression_without_alias, col_transformations) + col_transform = self._generate_transform_rule_mapping(hash_cols) + hash_cols_expr = sorted( + self._get_column_expr(TransformRuleMapping.get_column_expr_without_alias, col_transform) ) - hash_expr = self._generate_hash_algorithm(self.qc.db_type, hash_columns_expr) + hash_expr = self._generate_hash_algorithm(self.qc.source, hash_cols_expr) # get transformation for columns considered for joining and partition key - key_column_transformation = self._generate_transformation_rule_mapping(key_columns) - key_column_expr = sorted( - self._get_column_expr(TransformRuleMapping.get_column_expression_with_alias, key_column_transformation) - ) + key_col_transform = self._generate_transform_rule_mapping(key_cols) + key_col_expr = sorted(self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, key_col_transform)) # construct select hash query select_query = self._construct_hash_query( - self.qc.get_table_name(), self.qc.get_filter(), hash_expr, key_column_expr + self.qc.get_table_name(), self.qc.get_filter(), hash_expr, key_col_expr ) return select_query @staticmethod - def _generate_hash_algorithm(source: str, column_expr: list[str]) -> str: + def _generate_hash_algorithm(source: str, col_expr: list[str]) -> str: if source in {SourceType.DATABRICKS.value, 
SourceType.SNOWFLAKE.value}: - hash_expr = "concat(" + ", ".join(column_expr) + ")" + hash_expr = "concat(" + ", ".join(col_expr) + ")" else: - hash_expr = " || ".join(column_expr) + hash_expr = " || ".join(col_expr) return (Constants.hash_algorithm_mapping.get(source.lower()).get("source")).format(hash_expr) @staticmethod - def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, key_column_expr: list[str]) -> str: + def _construct_hash_query(table: str, query_filter: str, hash_expr: str, key_col_expr: list[str]) -> str: sql_query = StringIO() # construct hash expr sql_query.write(f"select {hash_expr} as {Constants.hash_column_name}") # add join column - if key_column_expr: - sql_query.write(", " + ",".join(key_column_expr)) - sql_query.write(f" from {table_name} where {query_filter}") + if key_col_expr: + sql_query.write(", " + ",".join(key_col_expr)) + sql_query.write(f" from {table} where {query_filter}") select_query = sql_query.getvalue() sql_query.close() @@ -184,23 +181,19 @@ def build_query(self): query_columns = sorted( all_columns if self.qc.layer == "source" - else self.qc.get_mapped_columns(self.qc.src_column_mapping, all_columns) + else self.qc.get_mapped_columns(self.qc.src_col_mapping, all_columns) ) - transformation_rule_mapping = self._get_custom_transformation( - query_columns, self.qc.transformations_dict, self.qc.src_column_mapping - ) - threshold_columns_expr = self._get_column_expr( - TransformRuleMapping.get_column_expression_with_alias, transformation_rule_mapping + transform_rule_mapping = self._get_custom_transformation( + query_columns, self.qc.transform_dict, self.qc.src_col_mapping ) + col_expr = self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, transform_rule_mapping) - select_query = self._construct_threshold_query( - self.qc.get_table_name(), self.qc.get_filter(), threshold_columns_expr - ) + select_query = self._construct_threshold_query(self.qc.get_table_name(), self.qc.get_filter(), col_expr) return select_query @staticmethod - def _construct_threshold_query(table_name, query_filter, threshold_columns_expr): - column_expr = ",".join(threshold_columns_expr) - return f"select {column_expr} from {table_name} where {query_filter}" + def _construct_threshold_query(table, query_filter, col_expr): + expr = ",".join(col_expr) + return f"select {expr} from {table} where {query_filter}" diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py index 4f70b27ef..282fffebc 100644 --- a/src/databricks/labs/remorph/reconcile/query_config.py +++ b/src/databricks/labs/remorph/reconcile/query_config.py @@ -8,15 +8,15 @@ class QueryConfig: - def __init__(self, table_conf: Table, schema: list[Schema], layer: str, db_type: str): + def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: str): self.table_conf = table_conf self.schema = schema self.layer = layer - self.db_type = db_type + self.source = source self.schema_dict = {v.column_name: v for v in schema} - self.tgt_column_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") - self.src_column_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") - self.transformations_dict = table_conf.list_to_dict(Transformation, "column_name") + self.tgt_col_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") + self.src_col_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") + self.transform_dict = table_conf.list_to_dict(Transformation, "column_name") def 
get_threshold_columns(self): return {thresh.column_name for thresh in self.table_conf.thresholds or []} @@ -28,8 +28,8 @@ def get_join_columns(self): def get_select_columns(self): if self.table_conf.select_columns is None: - columns = {sch.column_name for sch in self.schema} - return columns if self.layer == "source" else self.get_mapped_columns(self.tgt_column_mapping, columns) + cols = {sch.column_name for sch in self.schema} + return cols if self.layer == "source" else self.get_mapped_columns(self.tgt_col_mapping, cols) return set(self.table_conf.select_columns) def get_partition_column(self): @@ -44,7 +44,7 @@ def get_drop_columns(self): def get_table_name(self): table_name = self.table_conf.source_name if self.layer == "source" else self.table_conf.target_name - if self.db_type == SourceType.ORACLE.value: + if self.source == SourceType.ORACLE.value: return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string table_name=table_name ) @@ -60,10 +60,8 @@ def get_filter(self): return self.table_conf.filters.target @staticmethod - def get_mapped_columns(column_mapping: dict, columns: set[str]) -> set[str]: + def get_mapped_columns(col_mapping: dict, cols: set[str]) -> set[str]: select_columns = set() - for column in columns: - select_columns.add( - column_mapping.get(column, ColumnMapping(source_name=column, target_name='')).source_name - ) + for col in cols: + select_columns.add(col_mapping.get(col, ColumnMapping(source_name=col, target_name='')).source_name) return select_columns diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index 8300b046e..2fd3e2eb3 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -10,13 +10,13 @@ class TransformRuleMapping: transformation: str alias_name: str - def get_column_expression_without_alias(self) -> str: + def get_column_expr_without_alias(self) -> str: if self.transformation: return f"{self.transformation}" return f"{self.column_name}" - def get_column_expression_with_alias(self) -> str: - return f"{self.get_column_expression_without_alias()} as {self.alias_name}" + def get_column_expr_with_alias(self) -> str: + return f"{self.get_column_expr_without_alias()} as {self.alias_name}" @dataclass From 39dd55845fe62db00bd09199efa1e97e992e7538 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 20:47:05 +0530 Subject: [PATCH 12/15] Fixed the review comments --- .../reconcile/connectors/data_source.py | 13 +-- .../reconcile/connectors/databricks.py | 7 +- .../remorph/reconcile/connectors/oracle.py | 14 +-- .../remorph/reconcile/connectors/snowflake.py | 7 +- .../labs/remorph/reconcile/query_builder.py | 62 +++++------ .../labs/remorph/reconcile/query_config.py | 104 +++++++++++------- 6 files changed, 100 insertions(+), 107 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py index 5435da68e..8487b76b2 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/data_source.py +++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py @@ -24,12 +24,7 @@ def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOp return NotImplemented @abstractmethod - def get_schema( - self, - catalog: str, - schema: str, - table: str, - ) -> list[Schema]: + def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]: return NotImplemented 
def _get_jdbc_reader(self, query, jdbc_url, driver): @@ -54,11 +49,7 @@ def _get_secrets(self, key): return self.ws.secrets.get_secret(self.scope, self.source + '_' + key) @staticmethod - def _get_table_or_query( - catalog: str, - schema: str, - query: str, - ): + def _get_table_or_query(catalog: str, schema: str, query: str) -> str: if re.search('select', query, re.IGNORECASE): return query.format(catalog_name=catalog, schema_name=schema) if catalog: diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py b/src/databricks/labs/remorph/reconcile/connectors/databricks.py index 137ae208b..991b63a01 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/databricks.py +++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py @@ -9,12 +9,7 @@ def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOp # Implement Databricks-specific logic here return NotImplemented - def get_schema( - self, - catalog: str, - schema: str, - table: str, - ) -> list[Schema]: + def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]: # Implement Databricks-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py index f09578a7d..49f4a2376 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/oracle.py +++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py @@ -21,23 +21,15 @@ def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOp table_query = self._get_table_or_query(catalog, schema, query) if options is None: return self.reader(table_query).options(**self._get_timestamp_options()).load() - return ( - self.reader(table_query) - .options(**self._get_jdbc_reader_options(options) | self._get_timestamp_options()) - .load() - ) + options = self._get_jdbc_reader_options(options) | self._get_timestamp_options() + return self.reader(table_query).options(**options).load() except PySparkException as e: error_msg = ( f"An error occurred while fetching Oracle Data using the following {query} in OracleDataSource : {e!s}" ) raise PySparkException(error_msg) from e - def get_schema( - self, - catalog: str, - schema: str, - table: str, - ) -> list[Schema]: + def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]: try: schema_query = self._get_schema_query(table, schema) schema_df = self.reader(schema_query).load() diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py index c85f05303..05bc88dab 100644 --- a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py @@ -9,12 +9,7 @@ def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOp # Implement Snowflake-specific logic here return NotImplemented - def get_schema( - self, - catalog: str, - schema: str, - table: str, - ) -> list[Schema]: + def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]: # Implement Snowflake-specific logic here return NotImplemented diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index f869f2033..b7d87839a 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -17,13 +17,11 @@ TransformRuleMapping, ) -# pylint: disable=invalid-name - class 
QueryBuilder(ABC): - def __init__(self, qc: QueryConfig): - self.qc = qc + def __init__(self, qrc: QueryConfig): + self.qrc = qrc @abstractmethod def build_query(self): @@ -31,15 +29,15 @@ def build_query(self): def _get_custom_transformation( self, cols: list[str], transform_dict: dict[str, Transformation], col_mapping: dict[str, ColumnMapping] - ): + ) -> list[TransformRuleMapping]: transform_rule_mapping = [] for col in cols: if col in transform_dict.keys(): - transform = self._get_layer_transform(transform_dict, col, self.qc.layer) + transform = self._get_layer_transform(transform_dict, col, self.qrc.get_layer()) else: transform = None - col_origin, col_alias = self._get_column_alias(self.qc.layer, col, col_mapping) + col_origin, col_alias = self._get_column_alias(self.qrc.get_layer(), col, col_mapping) transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) @@ -47,14 +45,14 @@ def _get_custom_transformation( def _get_default_transformation( self, cols: list[str], col_mapping: dict[str, ColumnMapping], schema: dict[str, Schema] - ): + ) -> list[TransformRuleMapping]: transform_rule_mapping = [] for col in cols: - col_origin = col if self.qc.layer == "source" else self._get_column_map(col, col_mapping) + col_origin = col if self.qrc.get_layer() == "source" else self._get_column_map(col, col_mapping) col_data_type = schema.get(col_origin).data_type - transform = self._get_default_transformation_expr(self.qc.source, col_data_type).format(col_origin) + transform = self._get_default_transformation_expr(self.qrc.get_source(), col_data_type).format(col_origin) - col_origin, col_alias = self._get_column_alias(self.qc.layer, col, col_mapping) + col_origin, col_alias = self._get_column_alias(self.qrc.get_layer(), col, col_mapping) transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) @@ -78,18 +76,18 @@ def _get_default_transformation_expr(data_source: str, data_type: str) -> str: def _generate_transform_rule_mapping(self, cols: list[str]) -> list[TransformRuleMapping]: # compute custom transformation - if self.qc.transform_dict: - cols_with_transform = [col for col in cols if col in self.qc.transform_dict.keys()] + if self.qrc.get_transform_dict(): + cols_with_transform = [col for col in cols if col in self.qrc.get_transform_dict().keys()] custom_transform = self._get_custom_transformation( - cols_with_transform, self.qc.transform_dict, self.qc.src_col_mapping + cols_with_transform, self.qrc.get_transform_dict(), self.qrc.get_src_col_mapping() ) else: custom_transform = [] # compute default transformation - cols_without_transform = [col for col in cols if col not in self.qc.transform_dict.keys()] + cols_without_transform = [col for col in cols if col not in self.qrc.get_transform_dict().keys()] default_transform = self._get_default_transformation( - cols_without_transform, self.qc.src_col_mapping, self.qc.schema_dict + cols_without_transform, self.qrc.get_src_col_mapping(), self.qrc.get_schema_dict() ) transform_rule_mapping = custom_transform + default_transform @@ -109,7 +107,7 @@ def _get_column_map(col, col_mapping: dict[str, ColumnMapping]) -> str: return col_mapping.get(col, ColumnMapping(source_name='', target_name=col)).target_name @staticmethod - def _get_column_alias(layer: str, col: str, col_mapping: dict[str, ColumnMapping]): + def _get_column_alias(layer: str, col: str, col_mapping: dict[str, ColumnMapping]) -> tuple[str, str]: if col_mapping and col in col_mapping.keys() and layer == "target": col_alias = 
col_mapping.get(col).source_name col_origin = col_mapping.get(col).target_name @@ -122,20 +120,20 @@ def _get_column_alias(layer: str, col: str, col_mapping: dict[str, ColumnMapping class HashQueryBuilder(QueryBuilder): - def build_query(self): + def build_query(self) -> str: hash_cols = sorted( - (self.qc.get_join_columns() | self.qc.get_select_columns()) - - self.qc.get_threshold_columns() - - self.qc.get_drop_columns() + (self.qrc.get_join_columns() | self.qrc.get_select_columns()) + - self.qrc.get_threshold_columns() + - self.qrc.get_drop_columns() ) - key_cols = sorted(self.qc.get_join_columns() | self.qc.get_partition_column()) + key_cols = sorted(self.qrc.get_join_columns() | self.qrc.get_partition_column()) # get transformation for columns considered for hashing col_transform = self._generate_transform_rule_mapping(hash_cols) hash_cols_expr = sorted( self._get_column_expr(TransformRuleMapping.get_column_expr_without_alias, col_transform) ) - hash_expr = self._generate_hash_algorithm(self.qc.source, hash_cols_expr) + hash_expr = self._generate_hash_algorithm(self.qrc.get_source(), hash_cols_expr) # get transformation for columns considered for joining and partition key key_col_transform = self._generate_transform_rule_mapping(key_cols) @@ -143,7 +141,7 @@ def build_query(self): # construct select hash query select_query = self._construct_hash_query( - self.qc.get_table_name(), self.qc.get_filter(), hash_expr, key_col_expr + self.qrc.get_table_name(), self.qrc.get_filter(), hash_expr, key_col_expr ) return select_query @@ -175,25 +173,27 @@ def _construct_hash_query(table: str, query_filter: str, hash_expr: str, key_col class ThresholdQueryBuilder(QueryBuilder): - def build_query(self): - all_columns = set(self.qc.get_threshold_columns() | self.qc.get_join_columns() | self.qc.get_partition_column()) + def build_query(self) -> str: + all_columns = set( + self.qrc.get_threshold_columns() | self.qrc.get_join_columns() | self.qrc.get_partition_column() + ) query_columns = sorted( all_columns - if self.qc.layer == "source" - else self.qc.get_mapped_columns(self.qc.src_col_mapping, all_columns) + if self.qrc.get_layer() == "source" + else self.qrc.get_mapped_columns(self.qrc.get_src_col_mapping(), all_columns) ) transform_rule_mapping = self._get_custom_transformation( - query_columns, self.qc.transform_dict, self.qc.src_col_mapping + query_columns, self.qrc.get_transform_dict(), self.qrc.get_src_col_mapping() ) col_expr = self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, transform_rule_mapping) - select_query = self._construct_threshold_query(self.qc.get_table_name(), self.qc.get_filter(), col_expr) + select_query = self._construct_threshold_query(self.qrc.get_table_name(), self.qrc.get_filter(), col_expr) return select_query @staticmethod - def _construct_threshold_query(table, query_filter, col_expr): + def _construct_threshold_query(table, query_filter, col_expr) -> str: expr = ",".join(col_expr) return f"select {expr} from {table} where {query_filter}" diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py index 282fffebc..a237e1d27 100644 --- a/src/databricks/labs/remorph/reconcile/query_config.py +++ b/src/databricks/labs/remorph/reconcile/query_config.py @@ -9,55 +9,75 @@ class QueryConfig: def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: str): - self.table_conf = table_conf - self.schema = schema - self.layer = layer - self.source = source - self.schema_dict = 
{v.column_name: v for v in schema} - self.tgt_col_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") - self.src_col_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") - self.transform_dict = table_conf.list_to_dict(Transformation, "column_name") - - def get_threshold_columns(self): - return {thresh.column_name for thresh in self.table_conf.thresholds or []} - - def get_join_columns(self): - if self.table_conf.join_columns is None: + self._table_conf = table_conf + self._schema = schema + self._layer = layer + self._source = source + self._schema_dict = {v.column_name: v for v in schema} + self._tgt_col_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") + self._src_col_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") + self._transform_dict = table_conf.list_to_dict(Transformation, "column_name") + + def get_table_conf(self): + return self._table_conf + + def get_schema(self): + return self._schema + + def get_source(self): + return self._source + + def get_layer(self): + return self._layer + + def get_schema_dict(self): + return self._schema_dict + + def get_tgt_col_mapping(self): + return self._tgt_col_mapping + + def get_src_col_mapping(self): + return self._src_col_mapping + + def get_transform_dict(self): + return self._transform_dict + + def get_threshold_columns(self) -> set[str]: + return {thresh.column_name for thresh in self._table_conf.thresholds or []} + + def get_join_columns(self) -> set[str]: + if self._table_conf.join_columns is None: return set() - return set(self.table_conf.join_columns) + return set(self._table_conf.join_columns) - def get_select_columns(self): - if self.table_conf.select_columns is None: - cols = {sch.column_name for sch in self.schema} - return cols if self.layer == "source" else self.get_mapped_columns(self.tgt_col_mapping, cols) - return set(self.table_conf.select_columns) + def get_select_columns(self) -> set[str]: + if self._table_conf.select_columns is None: + cols = {sch.column_name for sch in self._schema} + return cols if self._layer == "source" else self.get_mapped_columns(self._tgt_col_mapping, cols) + return set(self._table_conf.select_columns) - def get_partition_column(self): - if self.table_conf.jdbc_reader_options and self.layer == "source": - return {self.table_conf.jdbc_reader_options.partition_column} + def get_partition_column(self) -> set[str]: + if self._table_conf.jdbc_reader_options and self._layer == "source": + return {self._table_conf.jdbc_reader_options.partition_column} return set() - def get_drop_columns(self): - if self.table_conf.drop_columns is None: + def get_drop_columns(self) -> set[str]: + if self._table_conf.drop_columns is None: return set() - return set(self.table_conf.drop_columns) - - def get_table_name(self): - table_name = self.table_conf.source_name if self.layer == "source" else self.table_conf.target_name - if self.source == SourceType.ORACLE.value: - return "{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string - table_name=table_name - ) - return "{{catalog_name}}.{{schema_name}}.{table_name}".format( # pylint: disable=consider-using-f-string - table_name=table_name - ) - - def get_filter(self): - if self.table_conf.filters is None: + return set(self._table_conf.drop_columns) + + def get_table_name(self) -> str: + table_name = self._table_conf.source_name if self._layer == "source" else self._table_conf.target_name + if self._source == SourceType.ORACLE.value: + return f"{{schema_name}}.{table_name}" + return 
f"{{catalog_name}}.{{schema_name}}.{table_name}" + + def get_filter(self) -> str: + if self._table_conf.filters is None: return " 1 = 1 " - if self.layer == "source": - return self.table_conf.filters.source - return self.table_conf.filters.target + if self._layer == "source": + return self._table_conf.filters.source + return self._table_conf.filters.target @staticmethod def get_mapped_columns(col_mapping: dict, cols: set[str]) -> set[str]: From 8d13c01d9271a02a497ddf8ab120f3962641c213 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 21:04:29 +0530 Subject: [PATCH 13/15] Fixed the pylint error --- tests/unit/reconcile/test_query_builder.py | 78 +++++++++++----------- 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py index 1b9bb9e71..f3746acbc 100644 --- a/tests/unit/reconcile/test_query_builder.py +++ b/tests/unit/reconcile/test_query_builder.py @@ -15,8 +15,6 @@ Transformation, ) -# pylint: disable=invalid-name - def test_hash_query_builder_without_join_column(): table_conf = Table( @@ -41,8 +39,8 @@ def test_hash_query_builder_without_join_column(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || " "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " @@ -63,8 +61,8 @@ def test_hash_query_builder_without_join_column(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_acctbal),''), " "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " @@ -99,8 +97,8 @@ def test_hash_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || " "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || " @@ -121,8 +119,8 @@ def test_hash_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_acctbal),''), " "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), " @@ -160,8 +158,8 @@ def test_hash_query_builder_with_select(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select 
lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " @@ -180,8 +178,8 @@ def test_hash_query_builder_with_select(): Schema("s_comment_t", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -231,8 +229,8 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_nationkey),'') || " "coalesce(trim(s_suppkey),'') || trim(s_address) || trim(s_name) || " @@ -252,8 +250,8 @@ def test_hash_query_builder_with_transformations_with_drop_and_default_select(): Schema("s_comment_t", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(cast(s_acctbal_t as decimal(38,2)), " "coalesce(trim(s_nationkey_t),''), coalesce(trim(s_suppkey_t),''), " @@ -292,8 +290,8 @@ def test_hash_query_builder_with_jdbc_reader_options(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || " "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as " @@ -314,8 +312,8 @@ def test_hash_query_builder_with_jdbc_reader_options(): Schema("s_comment_t", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -358,8 +356,8 @@ def test_hash_query_builder_with_threshold(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_comment),'') || " "coalesce(trim(s_nationkey),'') || coalesce(trim(s_suppkey),'') || " @@ -380,8 +378,8 @@ def test_hash_query_builder_with_threshold(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, 
tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_comment),''), " "coalesce(trim(s_nationkey),''), coalesce(trim(s_suppkey_t),''), " @@ -418,8 +416,8 @@ def test_hash_query_builder_with_filters(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "snowflake") - actual_src_query = HashQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "snowflake") + actual_src_query = HashQueryBuilder(src_qrc).build_query() expected_src_query = ( "select sha2(concat(coalesce(trim(s_address),''), coalesce(trim(s_name),''), " "coalesce(trim(s_suppkey),'')),256) as hash_value__recon, " @@ -438,8 +436,8 @@ def test_hash_query_builder_with_filters(): Schema("s_comment_t", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = HashQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select sha2(concat(coalesce(trim(s_address_t),''), " "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as " @@ -473,8 +471,8 @@ def test_hash_query_builder_with_unsupported_source(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "abc") - query_builder = HashQueryBuilder(qc) + src_qrc = QueryConfig(table_conf, src_schema, "source", "abc") + query_builder = HashQueryBuilder(src_qrc) with pytest.raises(Exception) as exc_info: query_builder.build_query() @@ -505,8 +503,8 @@ def test_threshold_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = ThresholdQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = ThresholdQueryBuilder(src_qrc).build_query() expected_src_query = ( 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 ' ) @@ -522,8 +520,8 @@ def test_threshold_query_builder_with_defaults(): Schema("s_comment", "varchar"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = ThresholdQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = ThresholdQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( 'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 ' ) @@ -577,8 +575,8 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_suppdate", "timestamp"), ] - qc = QueryConfig(table_conf, src_schema, "source", "oracle") - actual_src_query = ThresholdQueryBuilder(qc).build_query() + src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle") + actual_src_query = ThresholdQueryBuilder(src_qrc).build_query() expected_src_query = ( "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey " "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from " @@ -597,8 +595,8 @@ def test_threshold_query_builder_with_transformations_and_jdbc(): Schema("s_suppdate_t", "timestamp"), ] - qc = QueryConfig(table_conf, tgt_schema, "target", "databricks") - actual_tgt_query = ThresholdQueryBuilder(qc).build_query() + tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks") + actual_tgt_query = 
ThresholdQueryBuilder(tgt_qrc).build_query() expected_tgt_query = ( "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as " "s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 " From 4605f55fd9c1b85d5a2d46e3924a7d446e4eeb0d Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Fri, 22 Mar 2024 22:38:53 +0530 Subject: [PATCH 14/15] Fixed the pylint error --- src/databricks/labs/remorph/reconcile/recon_config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index 2fd3e2eb3..7c9a90b74 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -55,6 +55,9 @@ class Filters: target: str = None +T = TypeVar("T") + + @dataclass class Table: source_name: str @@ -68,8 +71,6 @@ class Table: thresholds: list[Thresholds] | None = None filters: Filters | None = None - T = TypeVar("T") # pylint: disable=invalid-name - def list_to_dict(self, cls: type[T], key: str) -> T: for _, value in self.__dict__.items(): if isinstance(value, list): From b02970dad05c564300a79d901319e50a16cd0343 Mon Sep 17 00:00:00 2001 From: Ravikumar Date: Sat, 23 Mar 2024 10:05:17 +0530 Subject: [PATCH 15/15] Refactored the class --- .../labs/remorph/reconcile/query_builder.py | 46 ++++++-------- .../labs/remorph/reconcile/query_config.py | 61 ++++++++++--------- .../labs/remorph/reconcile/recon_config.py | 7 +-- 3 files changed, 55 insertions(+), 59 deletions(-) diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py index b7d87839a..01ef2ec89 100644 --- a/src/databricks/labs/remorph/reconcile/query_builder.py +++ b/src/databricks/labs/remorph/reconcile/query_builder.py @@ -33,11 +33,11 @@ def _get_custom_transformation( transform_rule_mapping = [] for col in cols: if col in transform_dict.keys(): - transform = self._get_layer_transform(transform_dict, col, self.qrc.get_layer()) + transform = self._get_layer_transform(transform_dict, col, self.qrc.layer) else: transform = None - col_origin, col_alias = self._get_column_alias(self.qrc.get_layer(), col, col_mapping) + col_origin, col_alias = self._get_column_alias(self.qrc.layer, col, col_mapping) transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) @@ -48,11 +48,11 @@ def _get_default_transformation( ) -> list[TransformRuleMapping]: transform_rule_mapping = [] for col in cols: - col_origin = col if self.qrc.get_layer() == "source" else self._get_column_map(col, col_mapping) + col_origin = col if self.qrc.layer == "source" else self._get_column_map(col, col_mapping) col_data_type = schema.get(col_origin).data_type - transform = self._get_default_transformation_expr(self.qrc.get_source(), col_data_type).format(col_origin) + transform = self._get_default_transformation_expr(self.qrc.source, col_data_type).format(col_origin) - col_origin, col_alias = self._get_column_alias(self.qrc.get_layer(), col, col_mapping) + col_origin, col_alias = self._get_column_alias(self.qrc.layer, col, col_mapping) transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias)) @@ -76,18 +76,18 @@ def _get_default_transformation_expr(data_source: str, data_type: str) -> str: def _generate_transform_rule_mapping(self, cols: list[str]) -> list[TransformRuleMapping]: # compute custom transformation - if 
self.qrc.get_transform_dict(): - cols_with_transform = [col for col in cols if col in self.qrc.get_transform_dict().keys()] + if self.qrc.transform_dict: + cols_with_transform = [col for col in cols if col in self.qrc.transform_dict.keys()] custom_transform = self._get_custom_transformation( - cols_with_transform, self.qrc.get_transform_dict(), self.qrc.get_src_col_mapping() + cols_with_transform, self.qrc.transform_dict, self.qrc.src_col_mapping ) else: custom_transform = [] # compute default transformation - cols_without_transform = [col for col in cols if col not in self.qrc.get_transform_dict().keys()] + cols_without_transform = [col for col in cols if col not in self.qrc.transform_dict.keys()] default_transform = self._get_default_transformation( - cols_without_transform, self.qrc.get_src_col_mapping(), self.qrc.get_schema_dict() + cols_without_transform, self.qrc.src_col_mapping, self.qrc.schema_dict ) transform_rule_mapping = custom_transform + default_transform @@ -122,27 +122,23 @@ class HashQueryBuilder(QueryBuilder): def build_query(self) -> str: hash_cols = sorted( - (self.qrc.get_join_columns() | self.qrc.get_select_columns()) - - self.qrc.get_threshold_columns() - - self.qrc.get_drop_columns() + (self.qrc.join_columns | self.qrc.select_columns) - self.qrc.threshold_columns - self.qrc.drop_columns ) - key_cols = sorted(self.qrc.get_join_columns() | self.qrc.get_partition_column()) + key_cols = sorted(self.qrc.join_columns | self.qrc.partition_column) # get transformation for columns considered for hashing col_transform = self._generate_transform_rule_mapping(hash_cols) hash_cols_expr = sorted( self._get_column_expr(TransformRuleMapping.get_column_expr_without_alias, col_transform) ) - hash_expr = self._generate_hash_algorithm(self.qrc.get_source(), hash_cols_expr) + hash_expr = self._generate_hash_algorithm(self.qrc.source, hash_cols_expr) # get transformation for columns considered for joining and partition key key_col_transform = self._generate_transform_rule_mapping(key_cols) key_col_expr = sorted(self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, key_col_transform)) # construct select hash query - select_query = self._construct_hash_query( - self.qrc.get_table_name(), self.qrc.get_filter(), hash_expr, key_col_expr - ) + select_query = self._construct_hash_query(self.qrc.table_name, self.qrc.filter, hash_expr, key_col_expr) return select_query @@ -153,7 +149,7 @@ def _generate_hash_algorithm(source: str, col_expr: list[str]) -> str: else: hash_expr = " || ".join(col_expr) - return (Constants.hash_algorithm_mapping.get(source.lower()).get("source")).format(hash_expr) + return (Constants.hash_algorithm_mapping.get(source).get("source")).format(hash_expr) @staticmethod def _construct_hash_query(table: str, query_filter: str, hash_expr: str, key_col_expr: list[str]) -> str: @@ -174,22 +170,20 @@ def _construct_hash_query(table: str, query_filter: str, hash_expr: str, key_col class ThresholdQueryBuilder(QueryBuilder): def build_query(self) -> str: - all_columns = set( - self.qrc.get_threshold_columns() | self.qrc.get_join_columns() | self.qrc.get_partition_column() - ) + all_columns = set(self.qrc.threshold_columns | self.qrc.join_columns | self.qrc.partition_column) query_columns = sorted( all_columns - if self.qrc.get_layer() == "source" - else self.qrc.get_mapped_columns(self.qrc.get_src_col_mapping(), all_columns) + if self.qrc.layer == "source" + else self.qrc.get_mapped_columns(self.qrc.src_col_mapping, all_columns) ) transform_rule_mapping = 
self._get_custom_transformation( - query_columns, self.qrc.get_transform_dict(), self.qrc.get_src_col_mapping() + query_columns, self.qrc.transform_dict, self.qrc.src_col_mapping ) col_expr = self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, transform_rule_mapping) - select_query = self._construct_threshold_query(self.qrc.get_table_name(), self.qrc.get_filter(), col_expr) + select_query = self._construct_threshold_query(self.qrc.table_name, self.qrc.filter, col_expr) return select_query diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py index a237e1d27..e0e6d2e62 100644 --- a/src/databricks/labs/remorph/reconcile/query_config.py +++ b/src/databricks/labs/remorph/reconcile/query_config.py @@ -13,66 +13,69 @@ def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: self._schema = schema self._layer = layer self._source = source - self._schema_dict = {v.column_name: v for v in schema} - self._tgt_col_mapping = table_conf.list_to_dict(ColumnMapping, "target_name") - self._src_col_mapping = table_conf.list_to_dict(ColumnMapping, "source_name") - self._transform_dict = table_conf.list_to_dict(Transformation, "column_name") - def get_table_conf(self): - return self._table_conf - - def get_schema(self): - return self._schema - - def get_source(self): + @property + def source(self): return self._source - def get_layer(self): + @property + def layer(self): return self._layer - def get_schema_dict(self): - return self._schema_dict + @property + def schema_dict(self): + return {v.column_name: v for v in self._schema} - def get_tgt_col_mapping(self): - return self._tgt_col_mapping + @property + def tgt_col_mapping(self): + return self._table_conf.list_to_dict(ColumnMapping, "target_name") - def get_src_col_mapping(self): - return self._src_col_mapping + @property + def src_col_mapping(self): + return self._table_conf.list_to_dict(ColumnMapping, "source_name") - def get_transform_dict(self): - return self._transform_dict + @property + def transform_dict(self): + return self._table_conf.list_to_dict(Transformation, "column_name") - def get_threshold_columns(self) -> set[str]: + @property + def threshold_columns(self) -> set[str]: return {thresh.column_name for thresh in self._table_conf.thresholds or []} - def get_join_columns(self) -> set[str]: + @property + def join_columns(self) -> set[str]: if self._table_conf.join_columns is None: return set() return set(self._table_conf.join_columns) - def get_select_columns(self) -> set[str]: + @property + def select_columns(self) -> set[str]: if self._table_conf.select_columns is None: cols = {sch.column_name for sch in self._schema} - return cols if self._layer == "source" else self.get_mapped_columns(self._tgt_col_mapping, cols) + return cols if self._layer == "source" else self.get_mapped_columns(self.tgt_col_mapping, cols) return set(self._table_conf.select_columns) - def get_partition_column(self) -> set[str]: + @property + def partition_column(self) -> set[str]: if self._table_conf.jdbc_reader_options and self._layer == "source": return {self._table_conf.jdbc_reader_options.partition_column} return set() - def get_drop_columns(self) -> set[str]: + @property + def drop_columns(self) -> set[str]: if self._table_conf.drop_columns is None: return set() return set(self._table_conf.drop_columns) - def get_table_name(self) -> str: + @property + def table_name(self) -> str: table_name = self._table_conf.source_name if self._layer == "source" else 
self._table_conf.target_name if self._source == SourceType.ORACLE.value: return f"{{schema_name}}.{table_name}" return f"{{catalog_name}}.{{schema_name}}.{table_name}" - def get_filter(self) -> str: + @property + def filter(self) -> str: if self._table_conf.filters is None: return " 1 = 1 " if self._layer == "source": @@ -80,7 +83,7 @@ def get_filter(self) -> str: return self._table_conf.filters.target @staticmethod - def get_mapped_columns(col_mapping: dict, cols: set[str]) -> set[str]: + def get_mapped_columns(col_mapping: dict[str, ColumnMapping], cols: set[str]) -> set[str]: select_columns = set() for col in cols: select_columns.add(col_mapping.get(col, ColumnMapping(source_name=col, target_name='')).source_name) diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py index 7c9a90b74..059514d09 100644 --- a/src/databricks/labs/remorph/reconcile/recon_config.py +++ b/src/databricks/labs/remorph/reconcile/recon_config.py @@ -55,9 +55,6 @@ class Filters: target: str = None -T = TypeVar("T") - - @dataclass class Table: source_name: str @@ -71,7 +68,9 @@ class Table: thresholds: list[Thresholds] | None = None filters: Filters | None = None - def list_to_dict(self, cls: type[T], key: str) -> T: + Typ = TypeVar("Typ") + + def list_to_dict(self, cls: type[Typ], key: str) -> Typ: for _, value in self.__dict__.items(): if isinstance(value, list): if all(isinstance(x, cls) for x in value):
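
A note on the read_data cleanup in [PATCH 12]: the Oracle connector now hoists the merged option dict into a local variable before the reader call, with the merge itself expressed as a PEP 584 dict union. A minimal sketch of that merge, using placeholder option names rather than the connector's actual keys:

# Sketch of the option merge in OracleDataSource.read_data; the keys below
# are placeholders, not the connector's real option names.
jdbc_opts = {"numPartitions": "4", "partitionColumn": "s_nationkey"}
timestamp_opts = {"sessionInitStatement": "alter session set time_zone = 'UTC'"}  # hypothetical

# Dict union: the right-hand operand wins on duplicate keys.
merged = jdbc_opts | timestamp_opts
# self.reader(table_query).options(**merged).load()  # as in the patched method
print(merged)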
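
[PATCH 14] lifts the TypeVar out of the Table class to clear a pylint invalid-name error, and [PATCH 15] moves it back inside under the name Typ. The helper it parameterises, list_to_dict, walks the dataclass fields for a list whose elements are all instances of the requested class and indexes them by an attribute. A simplified standalone restatement (not the shipped code; the stand-in dataclass and the return annotation are illustrative):

# Simplified re-statement of Table.list_to_dict, for illustration only.
from dataclasses import dataclass, field
from typing import TypeVar

Typ = TypeVar("Typ")


@dataclass
class ColumnMapping:
    source_name: str
    target_name: str


@dataclass
class TableSketch:
    # Stand-in for recon_config.Table, reduced to a single list-typed field.
    column_mapping: list[ColumnMapping] = field(default_factory=list)

    def list_to_dict(self, cls: type[Typ], key: str) -> dict[str, Typ] | None:
        # Scan the dataclass fields for a non-empty list of `cls` instances
        # and key it by the requested attribute, as Table.list_to_dict does.
        for value in self.__dict__.values():
            if isinstance(value, list) and value and all(isinstance(x, cls) for x in value):
                return {getattr(x, key): x for x in value}
        return None


mapping = TableSketch(column_mapping=[ColumnMapping("s_suppkey", "s_suppkey_t")])
by_source = mapping.list_to_dict(ColumnMapping, "source_name")
assert by_source["s_suppkey"].target_name == "s_suppkey_t"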
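
The QueryConfig rewrite in [PATCH 15] replaces the get_* accessors added in [PATCH 12] with read-only properties, so builder code shrinks from self.qrc.get_layer() to self.qrc.layer while the backing fields stay underscore-private. Reduced to two fields, the pattern looks like this (a sketch, not the full class):

# Minimal stand-in for the read-only property pattern adopted by QueryConfig.
class QueryConfigSketch:
    def __init__(self, layer: str, source: str):
        self._layer = layer
        self._source = source

    @property
    def layer(self) -> str:
        # Attribute-style access; no setter, so the field is effectively read-only.
        return self._layer

    @property
    def source(self) -> str:
        return self._source


qrc = QueryConfigSketch("source", "oracle")
assert qrc.layer == "source"  # was qrc.get_layer() before the refactor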
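
Taken together, the intended call path after this series is: describe the table pair once in a Table, wrap it with the layer's Schema list in a QueryConfig, and hand that to a builder. The sketch below is distilled from the unit tests above; it assumes the Table fields not set here default to None, as those tests imply, and that the package from this series is importable.

# Usage sketch distilled from tests/unit/reconcile/test_query_builder.py.
from databricks.labs.remorph.reconcile.query_builder import HashQueryBuilder
from databricks.labs.remorph.reconcile.query_config import QueryConfig
from databricks.labs.remorph.reconcile.recon_config import Schema, Table

table_conf = Table(
    source_name="supplier",
    target_name="supplier",
    join_columns=["s_suppkey"],
)
src_schema = [
    Schema("s_suppkey", "number"),
    Schema("s_name", "varchar"),
    Schema("s_address", "varchar"),
]

# One QueryConfig per layer; "oracle" and "databricks" are the source types
# exercised by the tests.
src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
src_query = HashQueryBuilder(src_qrc).build_query()
# src_query has the shape:
#   select <hash_expr> as hash_value__recon, <key column exprs>
#   from {schema_name}.supplier where 1 = 1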