diff --git a/src/databricks/labs/ucx/source_code/pyspark.py b/src/databricks/labs/ucx/source_code/pyspark.py
index f4419bf5b9..9cf51294e9 100644
--- a/src/databricks/labs/ucx/source_code/pyspark.py
+++ b/src/databricks/labs/ucx/source_code/pyspark.py
@@ -14,6 +14,28 @@
 from databricks.labs.ucx.source_code.queries import FromTable
 
 
+class AstHelper:
+    @staticmethod
+    def get_full_function_name(node):
+        if isinstance(node.func, ast.Attribute):
+            return AstHelper._get_value(node.func)
+
+        if isinstance(node.func, ast.Name):
+            return node.func.id
+
+        return None
+
+    @staticmethod
+    def _get_value(node):
+        if isinstance(node.value, ast.Name):
+            return node.value.id + '.' + node.attr
+
+        if isinstance(node.value, ast.Attribute):
+            return AstHelper._get_value(node.value) + '.' + node.attr
+
+        return None
+
+
 @dataclass
 class Matcher(ABC):
     method_name: str
@@ -21,11 +43,14 @@ class Matcher(ABC):
     max_args: int
     table_arg_index: int
     table_arg_name: str | None = None
+    call_context: dict[str, set[str]] | None = None
 
     def matches(self, node: ast.AST):
-        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
-            return False
-        return self._get_table_arg(node) is not None
+        return (
+            isinstance(node, ast.Call)
+            and isinstance(node.func, ast.Attribute)
+            and self._get_table_arg(node) is not None
+        )
 
     @abstractmethod
     def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
@@ -39,9 +64,26 @@ def _get_table_arg(self, node: ast.Call):
         if len(node.args) > 0:
             return node.args[self.table_arg_index] if self.min_args <= len(node.args) <= self.max_args else None
         assert self.table_arg_name is not None
+        if not node.keywords:
+            return None
         arg = next(kw for kw in node.keywords if kw.arg == self.table_arg_name)
         return arg.value if arg is not None else None
 
+    def _check_call_context(self, node: ast.Call) -> bool:
+        assert isinstance(node.func, ast.Attribute)  # Avoid linter warning
+        func_name = node.func.attr
+        qualified_name = AstHelper.get_full_function_name(node)
+
+        # A call_context of None means that all calls to this method are checked
+        if self.call_context is None:
+            return True
+
+        # Get the qualified names allowed for this method from the call_context dictionary
+        qualified_names = self.call_context.get(func_name)
+
+        # Check whether the qualified name of this call is one of the allowed qualified names
+        return qualified_name in qualified_names if qualified_names else False
+
 
 @dataclass
 class QueryMatcher(Matcher):
@@ -78,19 +120,8 @@ class TableNameMatcher(Matcher):
 
     def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
         table_arg = self._get_table_arg(node)
-        if isinstance(table_arg, ast.Constant):
-            dst = self._find_dest(index, table_arg.value, from_table.schema)
-            if dst is not None:
-                yield Deprecation(
-                    code='table-migrate',
-                    message=f"Table {table_arg.value} is migrated to {dst.destination()} in Unity Catalog",
-                    # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
-                    start_line=node.lineno,
-                    start_col=node.col_offset,
-                    end_line=node.end_lineno or 0,
-                    end_col=node.end_col_offset or 0,
-                )
-        else:
+
+        if not isinstance(table_arg, ast.Constant):
             assert isinstance(node.func, ast.Attribute)  # always true, avoids a pylint warning
             yield Advisory(
                 code='table-migrate',
@@ -100,6 +131,21 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) ->
                 end_line=node.end_lineno or 0,
                 end_col=node.end_col_offset or 0,
             )
+            return
+
+        dst = self._find_dest(index, table_arg.value, from_table.schema)
+        if dst is None:
+            return
+
+        yield Deprecation(
+            code='table-migrate',
+            message=f"Table {table_arg.value} is migrated to {dst.destination()} in Unity Catalog",
+            # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
+            start_line=node.lineno,
+            start_col=node.col_offset,
+            end_line=node.end_lineno or 0,
+            end_col=node.end_col_offset or 0,
+        )
 
     def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
         table_arg = self._get_table_arg(node)
@@ -135,7 +181,62 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) ->
         )
 
     def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
-        raise NotImplementedError("Should never get there!")
+        # No transformations to apply
+        return
+
+
+@dataclass
+class DirectFilesystemAccessMatcher(Matcher):
+    _DIRECT_FS_REFS = {
+        "s3a://",
+        "s3n://",
+        "s3://",
+        "wasb://",
+        "wasbs://",
+        "abfs://",
+        "abfss://",
+        "dbfs:/",
+        "hdfs://",
+        "file:/",
+    }
+
+    def matches(self, node: ast.AST):
+        return (
+            isinstance(node, ast.Call)
+            and isinstance(node.func, ast.Attribute)
+            and self._get_table_arg(node) is not None
+        )
+
+    def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]:
+        table_arg = self._get_table_arg(node)
+
+        if not isinstance(table_arg, ast.Constant):
+            return
+
+        if any(table_arg.value.startswith(prefix) for prefix in self._DIRECT_FS_REFS):
+            yield Deprecation(
+                code='direct-filesystem-access',
+                message=f"The use of direct filesystem references is deprecated: {table_arg.value}",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+            return
+
+        if table_arg.value.startswith("/") and self._check_call_context(node):
+            yield Deprecation(
+                code='direct-filesystem-access',
+                message=f"The use of default dbfs: references is deprecated: {table_arg.value}",
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None:
+        # No transformations to apply
+        return
 
 
 class SparkMatchers:
@@ -193,6 +294,37 @@ def __init__(self):
             TableNameMatcher("register", 1, 2, 0, "name"),
         ]
 
+        direct_fs_access_matchers = [
+            DirectFilesystemAccessMatcher("ls", 1, 1, 0, call_context={"ls": {"dbutils.fs.ls"}}),
+            DirectFilesystemAccessMatcher("cp", 1, 2, 0, call_context={"cp": {"dbutils.fs.cp"}}),
+            DirectFilesystemAccessMatcher("rm", 1, 1, 0, call_context={"rm": {"dbutils.fs.rm"}}),
+            DirectFilesystemAccessMatcher("head", 1, 1, 0, call_context={"head": {"dbutils.fs.head"}}),
+            DirectFilesystemAccessMatcher("put", 1, 2, 0, call_context={"put": {"dbutils.fs.put"}}),
+            DirectFilesystemAccessMatcher("mkdirs", 1, 1, 0, call_context={"mkdirs": {"dbutils.fs.mkdirs"}}),
+            DirectFilesystemAccessMatcher("mv", 1, 2, 0, call_context={"mv": {"dbutils.fs.mv"}}),
+            DirectFilesystemAccessMatcher("text", 1, 3, 0),
+            DirectFilesystemAccessMatcher("csv", 1, 1000, 0),
+            DirectFilesystemAccessMatcher("json", 1, 1000, 0),
+            DirectFilesystemAccessMatcher("orc", 1, 1000, 0),
+            DirectFilesystemAccessMatcher("parquet", 1, 1000, 0),
+            DirectFilesystemAccessMatcher("save", 0, 1000, -1, "path"),
+            DirectFilesystemAccessMatcher("load", 0, 1000, -1, "path"),
+            DirectFilesystemAccessMatcher("option", 1, 1000, 1),  # Only .option("path", "xxx://bucket/path") will hit
+            DirectFilesystemAccessMatcher("addFile", 1, 3, 0),
+            DirectFilesystemAccessMatcher("binaryFiles", 1, 2, 0),
+            DirectFilesystemAccessMatcher("binaryRecords", 1, 2, 0),
+            DirectFilesystemAccessMatcher("dump_profiles", 1, 1, 0),
+            DirectFilesystemAccessMatcher("hadoopFile", 1, 8, 0),
+            DirectFilesystemAccessMatcher("newAPIHadoopFile", 1, 8, 0),
+            DirectFilesystemAccessMatcher("pickleFile", 1, 3, 0),
+            DirectFilesystemAccessMatcher("saveAsHadoopFile", 1, 8, 0),
+            DirectFilesystemAccessMatcher("saveAsNewAPIHadoopFile", 1, 7, 0),
+            DirectFilesystemAccessMatcher("saveAsPickleFile", 1, 2, 0),
+            DirectFilesystemAccessMatcher("saveAsSequenceFile", 1, 2, 0),
+            DirectFilesystemAccessMatcher("saveAsTextFile", 1, 2, 0),
+            DirectFilesystemAccessMatcher("load_from_path", 1, 1, 0),
+        ]
+
         # nothing to migrate in UserDefinedFunction, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UserDefinedFunction.html
         # nothing to migrate in UserDefinedTableFunction, see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.UserDefinedTableFunction.html
         self._matchers = {}
@@ -203,6 +335,7 @@ def __init__(self):
             + spark_dataframereader_matchers
             + spark_dataframewriter_matchers
             + spark_udtfregistration_matchers
+            + direct_fs_access_matchers
         ):
             self._matchers[matcher.method_name] = matcher
 
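As a quick illustration of the dotted-name resolution that the `call_context` matching above relies on — a minimal sketch, where the parsed snippet is an arbitrary example, not part of this change:

    import ast

    from databricks.labs.ucx.source_code.pyspark import AstHelper

    # AstHelper walks the nested ast.Attribute chain hanging off node.func:
    # Attribute(Attribute(Name('dbutils'), 'fs'), 'ls') -> 'dbutils.fs.ls'
    call = ast.parse('dbutils.fs.ls("/mnt/foo")').body[0].value
    assert isinstance(call, ast.Call)
    assert AstHelper.get_full_function_name(call) == 'dbutils.fs.ls'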
diff --git a/tests/unit/source_code/test_notebook_linter.py b/tests/unit/source_code/test_notebook_linter.py
index 03d35c69fb..e95359f2b8 100644
--- a/tests/unit/source_code/test_notebook_linter.py
+++ b/tests/unit/source_code/test_notebook_linter.py
@@ -39,6 +39,14 @@
             end_line=4,
             end_col=1024,
         ),
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of default dbfs: references is deprecated: /mnt/things/e/f/g',
+            start_line=14,
+            start_col=8,
+            end_line=14,
+            end_col=43,
+        ),
         Deprecation(
             code='dbfs-usage',
             message='Deprecated file system path in call to: /mnt/things/e/f/g',
@@ -82,6 +90,14 @@
 
     """,
     [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of default dbfs: references is deprecated: /mnt/things/e/f/g',
+            start_line=5,
+            start_col=8,
+            end_line=5,
+            end_col=43,
+        ),
         Deprecation(
             code='dbfs-usage',
             message='Deprecated file system path in call to: /mnt/things/e/f/g',
@@ -154,6 +170,30 @@
       MERGE INTO delta.`/dbfs/...` t USING source ON t.key = source.key WHEN MATCHED THEN DELETE
     """,
     [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of default dbfs: references is deprecated: /mnt/foo/bar',
+            start_line=15,
+            start_col=0,
+            end_line=15,
+            end_col=34,
+        ),
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar',
+            start_line=16,
+            start_col=0,
+            end_line=16,
+            end_col=39,
+        ),
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: dbfs://mnt/foo/bar',
+            start_line=17,
+            start_col=0,
+            end_line=17,
+            end_col=40,
+        ),
         Advisory(
             code='dbfs-usage',
             message='Possible deprecated file system path: dbfs:/...',
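To make the `call_context` semantics concrete before the pyspark tests below — a small sketch that pokes the protected `_check_call_context` helper purely for illustration, mirroring the "ls" registration added above:

    import ast

    from databricks.labs.ucx.source_code.pyspark import DirectFilesystemAccessMatcher

    # Mirrors the "ls" registration above: only dbutils.fs.ls qualifies
    matcher = DirectFilesystemAccessMatcher("ls", 1, 1, 0, call_context={"ls": {"dbutils.fs.ls"}})
    allowed = ast.parse('dbutils.fs.ls("/mnt/data")').body[0].value
    ignored = ast.parse('spark.ls("/mnt/data")').body[0].value
    assert matcher._check_call_context(allowed) is True   # qualified name is in the allow-set
    assert matcher._check_call_context(ignored) is False  # 'spark.ls' is not in the allow-set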
diff --git a/tests/unit/source_code/test_pyspark.py b/tests/unit/source_code/test_pyspark.py
index a9dd0e95e3..9fba44f018 100644
--- a/tests/unit/source_code/test_pyspark.py
+++ b/tests/unit/source_code/test_pyspark.py
@@ -1,7 +1,9 @@
+import ast
+
 import pytest
 
 from databricks.labs.ucx.source_code.base import Advisory, Deprecation, CurrentSessionState
-from databricks.labs.ucx.source_code.pyspark import SparkMatchers, SparkSql
+from databricks.labs.ucx.source_code.pyspark import SparkMatchers, SparkSql, AstHelper, TableNameMatcher
 from databricks.labs.ucx.source_code.queries import FromTable
 
 
@@ -17,7 +19,6 @@ def test_spark_sql_no_match(empty_index):
     sqf = SparkSql(ftf, empty_index)
 
     old_code = """
-spark.read.csv("s3://bucket/path")
 for i in range(10):
     result = spark.sql("SELECT * FROM old.things").collect()
     print(len(result))
@@ -37,6 +38,14 @@ def test_spark_sql_match(migration_index):
     print(len(result))
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Deprecation(
             code='table-migrate',
             message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
@@ -44,7 +53,7 @@ def test_spark_sql_match(migration_index):
             start_col=13,
             end_line=4,
             end_col=50,
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -59,6 +68,14 @@ def test_spark_sql_match_named(migration_index):
     print(len(result))
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Deprecation(
             code='table-migrate',
             message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
@@ -66,7 +83,7 @@ def test_spark_sql_match_named(migration_index):
             start_col=13,
             end_line=4,
             end_col=71,
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -104,6 +121,14 @@ def test_spark_table_match(migration_index, method_name):
     do_stuff_with_df(df)
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Deprecation(
             code='table-migrate',
             message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
@@ -111,7 +136,7 @@ def test_spark_table_match(migration_index, method_name):
             start_col=9,
             end_line=4,
             end_col=17 + len(method_name) + len(args),
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -125,7 +150,6 @@ def test_spark_table_no_match(migration_index, method_name):
     args_list[matcher.table_arg_index] = '"table.we.know.nothing.about"'
     args = ",".join(args_list)
     old_code = f"""
-spark.read.csv("s3://bucket/path")
 for i in range(10):
     df = spark.{method_name}({args})
     do_stuff_with_df(df)
@@ -145,7 +169,6 @@ def test_spark_table_too_many_args(migration_index, method_name):
     args_list[matcher.table_arg_index] = '"table.we.know.nothing.about"'
     args = ",".join(args_list)
     old_code = f"""
-spark.read.csv("s3://bucket/path")
 for i in range(10):
     df = spark.{method_name}({args})
     do_stuff_with_df(df)
@@ -163,6 +186,14 @@ def test_spark_table_named_args(migration_index):
     do_stuff_with_df(df)
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Deprecation(
             code='table-migrate',
             message='Table old.things is migrated to brand.new.stuff in Unity Catalog',
@@ -170,7 +201,7 @@ def test_spark_table_named_args(migration_index):
             start_col=9,
             end_line=4,
             end_col=59,
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -184,6 +215,14 @@ def test_spark_table_variable_arg(migration_index):
     do_stuff_with_df(df)
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Advisory(
             code='table-migrate',
             message="Can't migrate 'saveAsTable' because its table name argument is not a constant",
@@ -191,7 +230,7 @@ def test_spark_table_variable_arg(migration_index):
             start_col=9,
             end_line=4,
             end_col=32,
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -205,6 +244,14 @@ def test_spark_table_fstring_arg(migration_index):
     do_stuff_with_df(df)
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Advisory(
             code='table-migrate',
             message="Can't migrate 'saveAsTable' because its table name argument is not a constant",
@@ -212,7 +259,7 @@ def test_spark_table_fstring_arg(migration_index):
             start_col=9,
             end_line=4,
             end_col=42,
-        )
+        ),
     ] == list(sqf.lint(old_code))
 
 
@@ -225,6 +272,14 @@ def test_spark_table_return_value(migration_index):
     do_stuff_with_table(table)
 """
     assert [
+        Deprecation(
+            code='direct-filesystem-access',
+            message='The use of direct filesystem references is deprecated: s3://bucket/path',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=34,
+        ),
         Advisory(
             code='table-migrate',
             message="Call to 'listTables' will return a list of <catalog>.<database>.<table> instead of <database>.<table>",
.", @@ -232,10 +287,21 @@ def test_spark_table_return_value(migration_index): start_col=13, end_line=3, end_col=31, - ) + ), ] == list(sqf.lint(old_code)) +def test_spark_table_return_value_apply(migration_index): + ftf = FromTable(migration_index, CurrentSessionState()) + sqf = SparkSql(ftf, migration_index) + old_code = """spark.read.csv('s3://bucket/path') +for table in spark.listTables(): + do_stuff_with_table(table)""" + fixed_code = sqf.apply(old_code) + # no transformations to apply, only lint messages + assert fixed_code == old_code + + def test_spark_sql_fix(migration_index): ftf = FromTable(migration_index, CurrentSessionState()) sqf = SparkSql(ftf, migration_index) @@ -253,3 +319,486 @@ def test_spark_sql_fix(migration_index): result = spark.sql('SELECT * FROM brand.new.stuff').collect() print(len(result))""" ) + + +@pytest.mark.parametrize( + "code, expected", + [ + # Test for 'ls' function + ( + """dbutils.fs.ls("s3a://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: s3a://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=34, + ) + ], + ), + # Test for 'cp' function. Note that the current code will stop at the first deprecation found. + ( + """dbutils.fs.cp("s3n://bucket/path", "s3n://another_bucket/another_path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: s3n://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=71, + ) + ], + ), + # Test for 'rm' function + ( + """dbutils.fs.rm("s3://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: s3://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=33, + ) + ], + ), + # Test for 'head' function + ( + """dbutils.fs.head("wasb://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: wasb://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=37, + ) + ], + ), + # Test for 'put' function + ( + """dbutils.fs.put("wasbs://bucket/path", "data")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: wasbs://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=45, + ) + ], + ), + # Test for 'mkdirs' function + ( + """dbutils.fs.mkdirs("abfs://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: abfs://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=39, + ) + ], + ), + # Test for 'move' function + ( + """dbutils.fs.mv("wasb://bucket/path", "wasb://another_bucket/another_path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: wasb://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=73, + ) + ], + ), + # Test for 'text' function + ( + """spark.read.text("wasbs://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem references is deprecated: wasbs://bucket/path", + start_line=1, + start_col=0, + end_line=1, + end_col=38, + ) + ], + ), + # Test for 'csv' function + ( + """spark.read.csv("abfs://bucket/path")""", + [ + Deprecation( + code='direct-filesystem-access', + message="The use of direct filesystem 
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=36,
+                )
+            ],
+        ),
+        # Test for 'option' function
+        (
+            """(df.write
+  .format("parquet")
+  .option("path", "s3a://your_bucket_name/your_directory/")
+  .option("spark.hadoop.fs.s3a.access.key", "your_access_key")
+  .option("spark.hadoop.fs.s3a.secret.key", "your_secret_key")
+  .save())""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://your_bucket_name/your_directory/",
+                    start_line=1,
+                    start_col=1,
+                    end_line=3,
+                    end_col=59,
+                )
+            ],
+        ),
+        # Test for 'json' function
+        (
+            """spark.read.json("abfss://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: abfss://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=38,
+                )
+            ],
+        ),
+        # Test for 'orc' function
+        (
+            """spark.read.orc("dbfs://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: dbfs://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=36,
+                )
+            ],
+        ),
+        # Test for 'parquet' function
+        (
+            """spark.read.parquet("hdfs://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: hdfs://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=40,
+                )
+            ],
+        ),
+        # Test for 'save' function
+        (
+            """spark.write.save("file://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: file://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=38,
+                )
+            ],
+        ),
+        # Test for 'load' function with default to dbfs
+        (
+            """spark.read.load("/bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of default dbfs: references is deprecated: /bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=31,
+                )
+            ],
+        ),
+        # Test for 'addFile' function
+        (
+            """spark.addFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=34,
+                )
+            ],
+        ),
+        # Test for 'binaryFiles' function
+        (
+            """spark.binaryFiles("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=38,
+                )
+            ],
+        ),
+        # Test for 'binaryRecords' function
+        (
+            """spark.binaryRecords("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=40,
+                )
+            ],
+        ),
+        # Test for 'dump_profiles' function
+        (
+            """spark.dump_profiles("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=40,
+                )
+            ],
+        ),
+        # Test for 'hadoopFile' function
+        (
+            """spark.hadoopFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=37,
+                )
+            ],
+        ),
+        # Test for 'newAPIHadoopFile' function
+        (
+            """spark.newAPIHadoopFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=43,
+                )
+            ],
+        ),
+        # Test for 'pickleFile' function
+        (
+            """spark.pickleFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=37,
+                )
+            ],
+        ),
+        # Test for 'saveAsHadoopFile' function
+        (
+            """spark.saveAsHadoopFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=43,
+                )
+            ],
+        ),
+        # Test for 'saveAsNewAPIHadoopFile' function
+        (
+            """spark.saveAsNewAPIHadoopFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=49,
+                )
+            ],
+        ),
+        # Test for 'saveAsPickleFile' function
+        (
+            """spark.saveAsPickleFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=43,
+                )
+            ],
+        ),
+        # Test for 'saveAsSequenceFile' function
+        (
+            """spark.saveAsSequenceFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=45,
+                )
+            ],
+        ),
+        # Test for 'saveAsTextFile' function
+        (
+            """spark.saveAsTextFile("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=41,
+                )
+            ],
+        ),
+        # Test for 'load_from_path' function
+        (
+            """spark.load_from_path("s3a://bucket/path")""",
+            [
+                Deprecation(
+                    code='direct-filesystem-access',
+                    message="The use of direct filesystem references is deprecated: s3a://bucket/path",
+                    start_line=1,
+                    start_col=0,
+                    end_line=1,
+                    end_col=41,
+                )
+            ],
+        ),
+    ],
+)
+def test_spark_cloud_direct_access(empty_index, code, expected):
+    ftf = FromTable(empty_index, CurrentSessionState())
+    sqf = SparkSql(ftf, empty_index)
+    advisories = list(sqf.lint(code))
+    assert advisories == expected
+
+
+FS_FUNCTIONS = [
+    "ls",
+    "cp",
+    "rm",
+    "mv",
+    "head",
+    "put",
+    "mkdirs",
+]
+
+
+@pytest.mark.parametrize("fs_function", FS_FUNCTIONS)
+def test_direct_cloud_access_reports_nothing(empty_index, fs_function):
+    ftf = FromTable(empty_index, CurrentSessionState())
+    sqf = SparkSql(ftf, empty_index)
+    # these filesystem calls are only flagged when made through dbutils.fs, so a bare spark.<fn> is ignored
+    code = f"""spark.{fs_function}("/bucket/path")"""
+    advisories = list(sqf.lint(code))
+    assert not advisories
+
+
+def test_get_full_function_name():
+    # Test when node.func is an instance of ast.Attribute
+    node = ast.Call(func=ast.Attribute(value=ast.Name(id='value'), attr='attr'))
+    assert AstHelper.get_full_function_name(node) == 'value.attr'
+
+    # Test when node.func is an instance of ast.Name
+    node = ast.Call(func=ast.Name(id='name'))
+    assert AstHelper.get_full_function_name(node) == 'name'
+
+    # Test when node.func is neither ast.Attribute nor ast.Name
+    node = ast.Call(func=ast.Constant(value='constant'))
+    assert AstHelper.get_full_function_name(node) is None
+
+    # Test when node.value in _get_value is an instance of ast.Name
+    node = ast.Call(func=ast.Attribute(value=ast.Name(id='name'), attr='attr'))
+    assert AstHelper.get_full_function_name(node) == 'name.attr'
+
+    # Test when node.value in _get_value is an instance of ast.Attribute
+    node = ast.Call(func=ast.Attribute(value=ast.Attribute(value=ast.Name(id='value'), attr='attr'), attr='attr'))
+    assert AstHelper.get_full_function_name(node) == 'value.attr.attr'
+
+    # Test when node.value in _get_value is neither ast.Name nor ast.Attribute
+    node = ast.Call(func=ast.Attribute(value=ast.Constant(value='constant'), attr='attr'))
+    assert AstHelper.get_full_function_name(node) is None
+
+
+def test_apply_table_name_matcher(migration_index):
+    from_table = FromTable(migration_index, CurrentSessionState('old'))
+    matcher = TableNameMatcher('things', 1, 1, 0)
+
+    # Test when table_arg is an ast.Constant but the destination table is not in the migration index
+    node = ast.Call(args=[ast.Constant(value='some.things')])
+    matcher.apply(from_table, migration_index, node)
+    table_constant = node.args[0]
+    assert isinstance(table_constant, ast.Constant)
+    assert table_constant.value == 'some.things'
+
+    # Test when table_arg is an ast.Constant and the destination table exists in the migration index
+    node = ast.Call(args=[ast.Constant(value='old.things')])
+    matcher.apply(from_table, migration_index, node)
+    table_constant = node.args[0]
+    assert isinstance(table_constant, ast.Constant)
+    assert table_constant.value == 'brand.new.stuff'
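End to end, the new checks surface through the same `SparkSql` linter exercised throughout these tests. A minimal usage sketch, assuming `index` stands in for a `MigrationIndex` such as the `empty_index` fixture above:

    from databricks.labs.ucx.source_code.base import CurrentSessionState
    from databricks.labs.ucx.source_code.pyspark import SparkSql
    from databricks.labs.ucx.source_code.queries import FromTable

    # `index` is assumed to be a MigrationIndex, e.g. the empty_index test fixture
    linter = SparkSql(FromTable(index, CurrentSessionState()), index)
    for advice in linter.lint('spark.read.csv("s3a://bucket/path")'):
        print(advice.code, advice.message)
    # -> direct-filesystem-access The use of direct filesystem references is deprecated: s3a://bucket/path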