diff --git a/.github/deploy/steps/25-Provision-Service-Principal.ps1 b/.github/deploy/steps/25-Provision-Service-Principal.ps1 index 33464e66..1e4df7ba 100644 --- a/.github/deploy/steps/25-Provision-Service-Principal.ps1 +++ b/.github/deploy/steps/25-Provision-Service-Principal.ps1 @@ -29,8 +29,9 @@ $dbSpn = Get-SpnWithSecret -spnName $dbDeploySpnName -keyVaultName $keyVaultName $mountSpn = Get-SpnWithSecret -spnName $mountSpnName -keyVaultName $keyVaultName $secrets.addSecret("Databricks--TenantId", $tenantId) -$secrets.addSecret("Databricks--ClientId", $mountSpn.clientId) -$secrets.addSecret("Databricks--ClientSecret", $mountSpn.secretText) +$secrets.addSecret("DatabricksOauthEndpoint", "https://login.microsoftonline.com/$tenantId/oauth2/token") +$secrets.addSecret("DatabricksClientId", $mountSpn.clientId) +$secrets.addSecret("DatabricksClientSecret", $mountSpn.secretText) # there is a chicken-and-egg problem where we want to save the new SPN secret in the # keyvault, but the keyvault may not exist yet. This doesn't matter since the keyvault diff --git a/.github/deploy/steps/91-Create-SparkConf.ps1 b/.github/deploy/steps/91-Create-SparkConf.ps1 new file mode 100644 index 00000000..a785580e --- /dev/null +++ b/.github/deploy/steps/91-Create-SparkConf.ps1 @@ -0,0 +1,26 @@ + +Write-Host "Write cluster configuration for Direct Access..." -ForegroundColor DarkYellow + +$confDirectAccess = [ordered]@{} + +$confDirectAccess["spark.databricks.cluster.profile"]= "singleNode" +$confDirectAccess["spark.databricks.delta.preview.enabled"] = $true +$confDirectAccess["spark.databricks.io.cache.enabled"] = $true +$confDirectAccess["spark.master"]= "local[*, 4]" + + +$url = az storage account show --name $dataLakeName --resource-group $resourceGroupName --query "primaryEndpoints.dfs" --out tsv + +$storageUrl = ([System.Uri]$url).Host + +Write-Host " Adds Direct Access for $storageUrl..." -ForegroundColor DarkYellow + +$confDirectAccess["fs.azure.account.auth.type.$storageUrl"] = "OAuth" +$confDirectAccess["fs.azure.account.oauth.provider.type.$storageUrl"] = "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider" +$confDirectAccess["fs.azure.account.oauth2.client.id.$storageUrl"] = "{{secrets/secrets/DatabricksClientId}}" +$confDirectAccess["fs.azure.account.oauth2.client.endpoint.$storageUrl"] = "{{secrets/secrets/DatabricksOauthEndpoint}}" +$confDirectAccess["fs.azure.account.oauth2.client.secret.$storageUrl"] = "{{secrets/secrets/DatabricksClientSecret}}" + +$values.addSecret("StorageAccount--Url", $storageUrl) + +Set-Content $repoRoot\.github\submit\sparkconf.json ($confDirectAccess | ConvertTo-Json) diff --git a/.github/deploy/steps/90-Databricks-Secrets.ps1 b/.github/deploy/steps/95-Databricks-Secrets.ps1 similarity index 100% rename from .github/deploy/steps/90-Databricks-Secrets.ps1 rename to .github/deploy/steps/95-Databricks-Secrets.ps1 diff --git a/.github/deploy/steps/99-SetupMounts.ps1 b/.github/deploy/steps/99-SetupMounts.ps1 index 83023ec9..0a3948fe 100644 --- a/.github/deploy/steps/99-SetupMounts.ps1 +++ b/.github/deploy/steps/99-SetupMounts.ps1 @@ -1,28 +1,33 @@ -$srcDir = "$PSScriptRoot/../../.." 
-Push-Location -Path $srcDir - -pip install dbx - -dbx configure -copy "$srcDir/.github/submit/sparklibs.json" "$srcDir/tests/cluster/mount/" - -$mountsJson = (,@( - @{ - storageAccountName=$resourceName - secretScope="secrets" - clientIdName="Databricks--ClientId" - clientSecretName="Databricks--ClientSecret" - tenantIdName="Databricks--TenantId" - containers = [array]$($dataLakeContainers | ForEach-Object{ $_.name }) - } -)) - -$mountsJson | ConvertTo-Json -Depth 4 | Set-Content "$srcDir/tests/cluster/mount/mounts.json" - -dbx deploy --deployment-file "$srcDir/tests/cluster/mount/setup_job.yml.j2" - -dbx launch --job="Setup Mounts" --trace --kill-on-sigterm - -Pop-Location +### This entire file is deprecated since we now use direct access and do not mount any +### more. Nevertheless, the code is kept until we remove the class EventHubCapture +### whose code requires a mount point if it is to be tested. +# +# $srcDir = "$PSScriptRoot/../../.." +# +# Push-Location -Path $srcDir +# +# pip install dbx +# +# dbx configure +# copy "$srcDir/.github/submit/sparklibs.json" "$srcDir/tests/cluster/mount/" +# +# $mountsJson = (,@( +# @{ +# storageAccountName=$resourceName +# secretScope="secrets" +# clientIdName="Databricks--ClientId" +# clientSecretName="Databricks--ClientSecret" +# tenantIdName="Databricks--TenantId" +# containers = [array]$($dataLakeContainers | ForEach-Object{ $_.name }) +# } +# )) +# +# $mountsJson | ConvertTo-Json -Depth 4 | Set-Content "$srcDir/tests/cluster/mount/mounts.json" +# +# dbx deploy --deployment-file "$srcDir/tests/cluster/mount/setup_job.yml.j2" +# +# dbx launch --job="Setup Mounts" --trace --kill-on-sigterm +# +# Pop-Location diff --git a/.github/submit/sparkconf_deprecated.json b/.github/submit/sparkconf_deprecated.json new file mode 100644 index 00000000..8c463fb8 --- /dev/null +++ b/.github/submit/sparkconf_deprecated.json @@ -0,0 +1,14 @@ +{ + "spark.databricks.cluster.profile": "singleNode", + "spark.master": "local[*, 4]", + "spark.databricks.delta.preview.enabled": true, + "spark.databricks.io.cache.enabled": true, + + "fs.azure.account.auth.type.$resourceName.dfs.core.windows.net": "OAuth", + "fs.azure.account.oauth.provider.type.$resourceName.dfs.core.windows.net": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider", + "fs.azure.account.oauth2.client.id.$resourceName.dfs.core.windows.net": "{{ secrets/secrets/DatabricksClientId }}", + "fs.azure.account.oauth2.client.secret.$resourceName.dfs.core.windows.net": "{{ secrets/secrets/DatabricksClientSecret }}", + "fs.azure.account.oauth2.client.endpoint.$resourceName.dfs.core.windows.net": "{{ secrets/secrets/DatabricksOauthEndpoint }}" +} +TODO: fix the resource name. 
Not possible in this json +UPDATE: See 20-Create-SparkConf.ps1 \ No newline at end of file diff --git a/.github/submit/submit_test_job.ps1 b/.github/submit/submit_test_job.ps1 index 38129fee..0f7e22d4 100644 --- a/.github/submit/submit_test_job.ps1 +++ b/.github/submit/submit_test_job.ps1 @@ -82,12 +82,7 @@ $run = @{ # single node cluster is sufficient new_cluster= @{ spark_version=$sparkVersion - spark_conf= @{ - "spark.databricks.cluster.profile"= "singleNode" - "spark.master"= "local[*, 4]" - "spark.databricks.delta.preview.enabled"= $true - "spark.databricks.io.cache.enabled"= $true - } + spark_conf = Get-Content "$PSScriptRoot/sparkconf.json" | ConvertFrom-Json azure_attributes=${ "availability"= "ON_DEMAND_AZURE", "first_on_demand": 1, diff --git a/setup.cfg b/setup.cfg index 535aa429..be3a33b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,6 @@ where = src console_scripts = python3 = atc.alias:python3 atc-dataplatform-git-hooks = atc.formatting.git_hooks:main - atc-dataplatform-mounts = atc.mount.main:main [flake8] exclude = .git,__pycache__,docs,build,dist,venv diff --git a/src/atc/delta/db_handle.py b/src/atc/delta/db_handle.py index aeb2ef99..801f4faf 100644 --- a/src/atc/delta/db_handle.py +++ b/src/atc/delta/db_handle.py @@ -50,5 +50,6 @@ def drop_cascade(self) -> None: def create(self) -> None: sql = f"CREATE DATABASE IF NOT EXISTS {self._name} " if self._location: - sql += f" LOCATION '{self._location}'" + sql += f' LOCATION "{self._location}"' + print("execute sql:", sql) Spark.get().sql(sql) diff --git a/src/atc/delta/delta_handle.py b/src/atc/delta/delta_handle.py index 92a962fb..79db959a 100644 --- a/src/atc/delta/delta_handle.py +++ b/src/atc/delta/delta_handle.py @@ -92,7 +92,13 @@ def append(self, df: DataFrame, mergeSchema: bool = None) -> None: return self.write_or_append(df, "append", mergeSchema=mergeSchema) def truncate(self) -> None: - Spark.get().sql(f"TRUNCATE TABLE {self._name};") + if self._location: + Spark.get().sql(f"TRUNCATE TABLE delta.`{self._location}`;") + else: + Spark.get().sql(f"TRUNCATE TABLE {self._name};") + # if the hive table does not exist, this will give a useful error like + # pyspark.sql.utils.AnalysisException: + # Table not found for 'TRUNCATE TABLE': TestDb.TestTbl; def drop(self) -> None: Spark.get().sql(f"DROP TABLE IF EXISTS {self._name};") @@ -105,7 +111,8 @@ def drop_and_delete(self) -> None: def create_hive_table(self) -> None: sql = f"CREATE TABLE IF NOT EXISTS {self._name} " if self._location: - sql += f" USING DELTA LOCATION '{self._location}'" + sql += f' USING DELTA LOCATION "{self._location}"' + print("execute sql:", sql) Spark.get().sql(sql) def recreate_hive_table(self): diff --git a/src/atc/mount/main.py b/src/atc/mount/main.py index ee5f50dd..bdd2d59f 100644 --- a/src/atc/mount/main.py +++ b/src/atc/mount/main.py @@ -2,10 +2,20 @@ import json from types import SimpleNamespace +import deprecation + +import atc from atc.atc_exceptions import AtcException from atc.functions import init_dbutils +@deprecation.deprecated( + deprecated_in="1.0.48", + removed_in="2.0.0", + current_version=atc.__version__, + details="use direct access instead. 
" + "See the atc-dataplatform unittests, for example.", +) def main(): parser = argparse.ArgumentParser(description="atc-dataplatform mountpoint setup.") parser.add_argument( diff --git a/tests/cluster/config/__init__.py b/tests/cluster/config/__init__.py new file mode 100644 index 00000000..0de8a295 --- /dev/null +++ b/tests/cluster/config/__init__.py @@ -0,0 +1,18 @@ +from atc import Configurator +from tests.cluster import values +from tests.cluster.values import storageAccountUrl + + +def InitConfigurator(*, clear=False): + """This example function is how you would use the Configurator in a project.""" + tc = Configurator() + if clear: + tc.clear_all_configurations() + + # This is how you would set yourself up for different environments + # tc.register('ENV','dev') + + tc.register("resourceName", values.resourceName()) + tc.register("storageAccount", f"abfss://silver@{storageAccountUrl()}") + + return tc diff --git a/tests/cluster/cosmos/test_cosmos.py b/tests/cluster/cosmos/test_cosmos.py index c72697f5..11d8497a 100644 --- a/tests/cluster/cosmos/test_cosmos.py +++ b/tests/cluster/cosmos/test_cosmos.py @@ -4,24 +4,24 @@ import atc.cosmos from atc import Configurator -from atc.functions import init_dbutils from atc.spark import Spark +from tests.cluster.config import InitConfigurator +from tests.cluster.secrets import cosmosAccountKey +from tests.cluster.values import cosmosEndpoint class TestCosmos(atc.cosmos.CosmosDb): def __init__(self): - dbutils = init_dbutils() super().__init__( - endpoint=dbutils.secrets.get("values", "Cosmos--Endpoint"), - account_key=dbutils.secrets.get("secrets", "Cosmos--AccountKey"), + endpoint=cosmosEndpoint(), + account_key=cosmosAccountKey(), database="AtcCosmosContainer", ) class CosmosTests(unittest.TestCase): def test_01_tables(self): - tc = Configurator() - tc.clear_all_configurations() + tc = InitConfigurator(clear=True) tc.register( "CmsTbl", { diff --git a/tests/cluster/delta/test_delta_class.py b/tests/cluster/delta/test_delta_class.py index 1fb80c84..5d75f0b0 100644 --- a/tests/cluster/delta/test_delta_class.py +++ b/tests/cluster/delta/test_delta_class.py @@ -1,5 +1,7 @@ +import time import unittest +from py4j.protocol import Py4JJavaError from pyspark.sql.utils import AnalysisException from atc import Configurator @@ -7,39 +9,42 @@ from atc.etl import Orchestrator from atc.etl.extractors import SimpleExtractor from atc.etl.loaders import SimpleLoader +from atc.functions import init_dbutils from atc.spark import Spark +from tests.cluster.config import InitConfigurator class DeltaTests(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - Configurator().clear_all_configurations() + InitConfigurator(clear=True) def test_01_configure(self): tc = Configurator() + tc.register( - "MyDb", {"name": "TestDb{ID}", "path": "/mnt/atc/silver/testdb{ID}"} + "MyDb", {"name": "TestDb{ID}", "path": "{storageAccount}/testdb{ID}"} ) tc.register( "MyTbl", { - "name": "TestDb{ID}.TestTbl", - "path": "/mnt/atc/silver/testdb{ID}/testtbl", + "name": "{MyDb}.TestTbl", + "path": "{MyDb_path}/testtbl", }, ) tc.register( "MyTbl2", { - "name": "TestDb{ID}.TestTbl2", + "name": "{MyDb}.TestTbl2", }, ) tc.register( "MyTbl3", { - "path": "/mnt/atc/silver/testdb/testtbl3", + "path": "{storageAccount}/testdb/testtbl3", }, ) @@ -67,12 +72,40 @@ def test_02_write(self): dh.append(df, mergeSchema=True) + # @unittest.skip("Flaky test") def test_03_create(self): + # print(Configurator().get_all_details()) + # print( + # { + # k: v[:-15] + v[-12:] + # for k, v in 
Spark.get().sparkContext.getConf().getAll() + # if k.startswith("fs.azure.account") + # } + # ) + db = DbHandle.from_tc("MyDb") db.create() dh = DeltaHandle.from_tc("MyTbl") - dh.create_hive_table() + tc = Configurator() + print(init_dbutils().fs.ls(tc.get("MyTbl", "path"))) + print( + init_dbutils().fs.put( + tc.get("MyTbl", "path") + "/some.file.txt", "Hello, ATC!", True + ) + ) + print(init_dbutils().fs.ls(tc.get("MyTbl", "path"))) + for i in range(10, 0, -1): + try: + dh.create_hive_table() + break + except (AnalysisException, Py4JJavaError) as e: + if i > 1: + print(e) + print("trying again in 10 seconds") + time.sleep(10) + else: + raise e # test hive access: df = Spark.get().table("TestDb.TestTbl") diff --git a/tests/cluster/delta/test_sparkexecutor.py b/tests/cluster/delta/test_sparkexecutor.py index 66ec8c7f..60662617 100644 --- a/tests/cluster/delta/test_sparkexecutor.py +++ b/tests/cluster/delta/test_sparkexecutor.py @@ -1,8 +1,8 @@ import unittest -from atc import Configurator from atc.delta import DbHandle, DeltaHandle from atc.spark import Spark +from tests.cluster.config import InitConfigurator from tests.cluster.delta import extras from tests.cluster.delta.SparkExecutor import SparkSqlExecutor @@ -16,7 +16,7 @@ class DeliverySparkExecutorTests(unittest.TestCase): def setUpClass(cls): # Register the delivery table for the table configurator - cls.tc = Configurator() + cls.tc = InitConfigurator() cls.tc.add_resource_path(extras) cls.tc.set_debug() diff --git a/tests/cluster/eh/AtcEh.py b/tests/cluster/eh/AtcEh.py index 85c84258..603bfeaa 100644 --- a/tests/cluster/eh/AtcEh.py +++ b/tests/cluster/eh/AtcEh.py @@ -2,13 +2,13 @@ This file sets up the EventHub that is deployed as part of the atc integration pipeline """ from atc.eh import EventHubStream -from atc.functions import init_dbutils +from tests.cluster.secrets import eventHubConnection class AtcEh(EventHubStream): def __init__(self): super().__init__( - connection_str=init_dbutils().secrets.get("secrets", "EventHubConnection"), + connection_str=eventHubConnection(), entity_path="atceh", consumer_group="$Default", ) diff --git a/tests/cluster/eh/test_eh_json_orchestrator.py b/tests/cluster/eh/test_eh_json_orchestrator.py index c8859cc6..38408ef3 100644 --- a/tests/cluster/eh/test_eh_json_orchestrator.py +++ b/tests/cluster/eh/test_eh_json_orchestrator.py @@ -8,6 +8,7 @@ from atc.eh import EventHubCaptureExtractor from atc.orchestrators.ehjson2delta.EhJsonToDeltaExtractor import EhJsonToDeltaExtractor from atc.spark import Spark +from tests.cluster.config import InitConfigurator class JsonEhOrchestratorUnitTests(unittest.TestCase): @@ -15,8 +16,7 @@ class JsonEhOrchestratorUnitTests(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.tc = TableConfigurator() - cls.tc.clear_all_configurations() + cls.tc = InitConfigurator(clear=True) cls.tc.register("TblYMD", {"name": "TableYMD"}) cls.tc.register("TblYMDH", {"name": "TableYMDH"}) cls.tc.register("TblPdate", {"name": "TablePdate"}) diff --git a/tests/cluster/eh/test_eh_saving.py b/tests/cluster/eh/test_eh_saving.py index 9e235dc6..5e600dbc 100644 --- a/tests/cluster/eh/test_eh_saving.py +++ b/tests/cluster/eh/test_eh_saving.py @@ -14,7 +14,7 @@ from atc.functions import init_dbutils from atc.orchestrators import EhJsonToDeltaOrchestrator from atc.spark import Spark -from tests.cluster.values import resourceName +from tests.cluster.config import InitConfigurator from .AtcEh import AtcEh @@ -22,7 +22,7 @@ class EventHubsTests(unittest.TestCase): @classmethod 
def setUpClass(cls) -> None: - Configurator().clear_all_configurations() + InitConfigurator(clear=True) def test_01_publish(self): eh = AtcEh() @@ -35,33 +35,40 @@ def test_02_wait_for_capture_files(self): # wait until capture file appears dbutils = init_dbutils() + silverContainer = Configurator().get("storageAccount") + resourceName = Configurator().get("resourceName") + limit = datetime.now() + timedelta(minutes=10) + conts = None while datetime.now() < limit: - conts = { - item.name for item in dbutils.fs.ls(f"/mnt/{resourceName()}/silver") - } - if f"{resourceName()}/" in conts: + conts = {item.name for item in dbutils.fs.ls(silverContainer)} + if f"{resourceName}/" in conts: break else: time.sleep(10) continue else: + print("Only found:", conts) self.assertTrue(False, "The capture file never appeared.") self.assertTrue(True, "The capture file has appeared.") + @unittest.skip(reason="EventHubCapture is deprecated and no longer tested.") + # not testing this class allows us to skip the mount point setup job def test_03_read_eh_capture(self): tc = Configurator() + tc.register( "AtcEh", { "name": "AtcEh", - "path": f"/mnt/{resourceName()}/silver/{resourceName()}/atceh", + "path": "{storageAccount}/{resourceName}/atceh", "format": "avro", "partitioning": "ymd", }, ) + eh = EventHubCapture.from_tc("AtcEh") df = eh.read() @@ -76,7 +83,7 @@ def test_04_read_eh_capture_extractor(self): tc.register( "AtcEh", { - "path": f"/mnt/{resourceName()}/silver/{resourceName()}/atceh", + "path": "{storageAccount}/{resourceName}/atceh", "format": "avro", "partitioning": "ymd", }, diff --git a/tests/cluster/etl/test_deltaupsert.py b/tests/cluster/etl/test_deltaupsert.py index 818050ca..81439f36 100644 --- a/tests/cluster/etl/test_deltaupsert.py +++ b/tests/cluster/etl/test_deltaupsert.py @@ -2,9 +2,9 @@ from atc_tools.testing import DataframeTestCase -from atc import Configurator from atc.delta import DbHandle, DeltaHandle from atc.utils import DataframeCreator +from tests.cluster.config import InitConfigurator from tests.cluster.delta import extras from tests.cluster.delta.SparkExecutor import SparkSqlExecutor @@ -37,8 +37,9 @@ class DeltaUpsertTests(DataframeTestCase): @classmethod def setUpClass(cls) -> None: - Configurator().add_resource_path(extras) - Configurator().set_debug() + tc = InitConfigurator(clear=True) + tc.add_resource_path(extras) + tc.set_debug() cls.target_dh_dummy = DeltaHandle.from_tc("UpsertLoaderDummy") @@ -68,7 +69,7 @@ def test_02_can_perform_overwrite_over_existing(self): """The target table is already filled from before. This test does not test .upsert() logic, but ensures that test 03 resembles an upsert after a full load. 
- If one needs to make an full load, use the .overwrite() method""" + If one needs to make a full load, use the .overwrite() method""" self.assertEqual(2, len(self.target_dh_dummy.read().collect())) df_source = DataframeCreator.make_partial( diff --git a/tests/cluster/etl/test_loader.py b/tests/cluster/etl/test_loader.py index dbfd79c7..f6e123d8 100644 --- a/tests/cluster/etl/test_loader.py +++ b/tests/cluster/etl/test_loader.py @@ -8,16 +8,18 @@ class LoaderTests(unittest.TestCase): + loader: Loader + @classmethod - def setUp(self): + def setUp(cls): - self.loader = Loader() - self.loader.save = MagicMock() - self.loader.save_many = MagicMock() + cls.loader = Loader() + cls.loader.save = MagicMock() + cls.loader.save_many = MagicMock() - self.df_1 = create_dataframe() - self.df_2 = create_dataframe() - self.df_3 = create_dataframe() + cls.df_1 = create_dataframe() + cls.df_2 = create_dataframe() + cls.df_3 = create_dataframe() def test_return_inputs(self): # Assert Loader returns ouput diff --git a/tests/cluster/etl/test_upsertloader.py b/tests/cluster/etl/test_upsertloader.py index 4e156b8f..e56775ca 100644 --- a/tests/cluster/etl/test_upsertloader.py +++ b/tests/cluster/etl/test_upsertloader.py @@ -2,10 +2,10 @@ from atc_tools.testing import DataframeTestCase -from atc import Configurator from atc.delta import DbHandle, DeltaHandle from atc.etl.loaders.UpsertLoader import UpsertLoader from atc.utils import DataframeCreator +from tests.cluster.config import InitConfigurator from tests.cluster.delta import extras from tests.cluster.delta.SparkExecutor import SparkSqlExecutor @@ -33,8 +33,9 @@ class UpsertLoaderTests(DataframeTestCase): @classmethod def setUpClass(cls) -> None: - Configurator().add_resource_path(extras) - Configurator().set_debug() + tc = InitConfigurator(clear=True) + tc.add_resource_path(extras) + tc.set_debug() cls.target_dh_dummy = DeltaHandle.from_tc("UpsertLoaderDummy") diff --git a/tests/cluster/mount/README.md b/tests/cluster/mount/README.md new file mode 100644 index 00000000..97cabc17 --- /dev/null +++ b/tests/cluster/mount/README.md @@ -0,0 +1,8 @@ +This entire folder and its functions are deprecated. +Mounting storage accounts is not the recommended way to access them. +The code is kept here to allow it to be reactivated if needed. + +The tests of the class `EventHubCapture` require mount-points, or so it seems. +The external avro SerDe classes do not seem to support direct access. Those tests +are skipped together with the deprecation of the class itself. If those tests are +removed entirely, this folder should probably disappear, too. 
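Editor's note on the direct-access pattern that replaces these mounts: the cluster conf written by 91-Create-SparkConf.ps1 wires the `fs.azure.account.*` keys to the `secrets` scope, so test code can read `abfss://` paths without any mount point. The following is only a minimal session-level sketch of that pattern, not code from this diff; the storage host is a placeholder, and in the deployed pipeline it comes from the `StorageAccount--Url` value.

```python
# Sketch only: session-level direct access to ADLS Gen2, mirroring the
# fs.azure.* keys that 91-Create-SparkConf.ps1 bakes into the cluster conf.
from atc.functions import init_dbutils
from atc.spark import Spark

spark = Spark.get()
dbutils = init_dbutils()

storage_host = "<storageaccount>.dfs.core.windows.net"  # placeholder value

spark.conf.set(f"fs.azure.account.auth.type.{storage_host}", "OAuth")
spark.conf.set(
    f"fs.azure.account.oauth.provider.type.{storage_host}",
    "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
)
spark.conf.set(
    f"fs.azure.account.oauth2.client.id.{storage_host}",
    dbutils.secrets.get("secrets", "DatabricksClientId"),
)
spark.conf.set(
    f"fs.azure.account.oauth2.client.secret.{storage_host}",
    dbutils.secrets.get("secrets", "DatabricksClientSecret"),
)
spark.conf.set(
    f"fs.azure.account.oauth2.client.endpoint.{storage_host}",
    dbutils.secrets.get("secrets", "DatabricksOauthEndpoint"),
)

# With the account configured, abfss paths are readable directly.
df = spark.read.format("delta").load(f"abfss://silver@{storage_host}/testdb/testtbl")
```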
diff --git a/tests/cluster/secrets.py b/tests/cluster/secrets.py new file mode 100644 index 00000000..ab3300ca --- /dev/null +++ b/tests/cluster/secrets.py @@ -0,0 +1,32 @@ +from functools import lru_cache + +from atc.functions import init_dbutils + + +@lru_cache +def getSecret(secret_name: str): + return init_dbutils().secrets.get("secrets", secret_name) + + +def cosmosAccountKey(): + return getSecret("Cosmos--AccountKey") + + +def eventHubConnection(): + return getSecret("EventHubConnection") + + +def sqlServerUser(): + return getSecret("SqlServer--DatabricksUser") + + +def sqlServerUserPassword(): + return getSecret("SqlServer--DatabricksUserPassword") + + +def clientSecret(): + return getSecret("DatabricksClientSecret") + + +def clientId(): + return getSecret("DatabricksClientId") diff --git a/tests/cluster/sql/DeliverySqlServer.py b/tests/cluster/sql/DeliverySqlServer.py index f4b84591..6f4e75e7 100644 --- a/tests/cluster/sql/DeliverySqlServer.py +++ b/tests/cluster/sql/DeliverySqlServer.py @@ -1,6 +1,5 @@ -from atc.functions import init_dbutils from atc.sql.SqlServer import SqlServer -from tests.cluster.values import resourceName +from tests.cluster.secrets import sqlServerUser, sqlServerUserPassword class DeliverySqlServer(SqlServer): @@ -14,20 +13,10 @@ def __init__( ): self.hostname = ( - f"{resourceName()}test.database.windows.net" - if hostname is None - else hostname - ) - self.username = ( - init_dbutils().secrets.get("secrets", "SqlServer--DatabricksUser") - if username is None - else username - ) - self.password = ( - init_dbutils().secrets.get("secrets", "SqlServer--DatabricksUserPassword") - if password is None - else password + "{resourceName}test.database.windows.net" if hostname is None else hostname ) + self.username = sqlServerUser() if username is None else username + self.password = sqlServerUserPassword() if password is None else password self.database = database self.port = port super().__init__( diff --git a/tests/cluster/sql/DeliverySqlServerSpn.py b/tests/cluster/sql/DeliverySqlServerSpn.py index 6a20d451..e4d4cc9e 100644 --- a/tests/cluster/sql/DeliverySqlServerSpn.py +++ b/tests/cluster/sql/DeliverySqlServerSpn.py @@ -1,6 +1,5 @@ -from atc.functions import init_dbutils from atc.sql.SqlServer import SqlServer -from tests.cluster.values import resourceName +from tests.cluster.secrets import clientId, clientSecret class DeliverySqlServerSpn(SqlServer): @@ -8,10 +7,8 @@ def __init__( self, ): super().__init__( - hostname=f"{resourceName()}test.database.windows.net", + hostname="{resourceName}test.database.windows.net", database="Delivery", - spnpassword=init_dbutils().secrets.get( - "secrets", "Databricks--ClientSecret" - ), - spnid=init_dbutils().secrets.get("secrets", "Databricks--ClientId"), + spnpassword=clientSecret(), + spnid=clientId(), ) diff --git a/tests/cluster/sql/test_deliveryexecutor.py b/tests/cluster/sql/test_deliveryexecutor.py index da0189c8..ed0935f8 100644 --- a/tests/cluster/sql/test_deliveryexecutor.py +++ b/tests/cluster/sql/test_deliveryexecutor.py @@ -1,6 +1,6 @@ import unittest -from atc import Configurator +from tests.cluster.config import InitConfigurator from tests.cluster.sql.DeliverySqlExecutor import DeliverySqlExecutor from tests.cluster.sql.DeliverySqlServer import DeliverySqlServer @@ -13,13 +13,14 @@ class DeliverySqlExecutorTests(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sql_server = DeliverySqlServer() # Register the delivery table for the table configurator - cls.tc = Configurator() + cls.tc = 
InitConfigurator(clear=True) cls.tc.add_resource_path(extras) cls.tc.set_debug() + cls.sql_server = DeliverySqlServer() + # Ensure no table is there cls.sql_server.drop_table("SqlTestTable1") cls.sql_server.drop_table("SqlTestTable2") diff --git a/tests/cluster/sql/test_deliverysql.py b/tests/cluster/sql/test_deliverysql.py index 1a06d680..3a83be39 100644 --- a/tests/cluster/sql/test_deliverysql.py +++ b/tests/cluster/sql/test_deliverysql.py @@ -2,9 +2,9 @@ from pyspark.sql import DataFrame from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from atc import Configurator from atc.functions import get_unique_tempview_name from atc.utils import DataframeCreator +from tests.cluster.config import InitConfigurator from tests.cluster.sql.DeliverySqlServer import DeliverySqlServer from . import extras @@ -20,12 +20,12 @@ class DeliverySqlServerTests(DataframeTestCase): @classmethod def setUpClass(cls): - cls.sql_server = DeliverySqlServer() - cls.tc = Configurator() - + cls.tc = InitConfigurator(clear=True) cls.tc.add_resource_path(extras) cls.tc.reset(debug=True) + cls.sql_server = DeliverySqlServer() + @classmethod def tearDownClass(cls) -> None: cls.sql_server.drop_table_by_name(cls.table_name) @@ -106,7 +106,7 @@ def test09_execute_sql_file(self): self.assertTrue(True) def test10_read_w_id(self): - # This might fail if the previous test didnt succeed + # This might fail if the previous test didn't succeed self.sql_server.read_table("SqlTestTable1") self.sql_server.read_table("SqlTestTable2") self.assertTrue(True) diff --git a/tests/cluster/sql/test_simple_sql_etl.py b/tests/cluster/sql/test_simple_sql_etl.py index 10cacca1..c88a5bb8 100644 --- a/tests/cluster/sql/test_simple_sql_etl.py +++ b/tests/cluster/sql/test_simple_sql_etl.py @@ -12,11 +12,11 @@ TimestampType, ) -from atc import Configurator from atc.etl.loaders import SimpleLoader from atc.functions import get_unique_tempview_name from atc.transformers.simple_sql_transformer import SimpleSqlServerTransformer from atc.utils import DataframeCreator +from tests.cluster.config import InitConfigurator from tests.cluster.sql.DeliverySqlServer import DeliverySqlServer @@ -28,8 +28,8 @@ class SimpleSqlServerETLTests(unittest.TestCase): @classmethod def setUpClass(cls): + cls.tc = InitConfigurator() cls.sql_server = DeliverySqlServer() - cls.tc = Configurator() cls.tc.clear_all_configurations() # Register the delivery table for the table configurator diff --git a/tests/cluster/transformations/test_concat_df.py b/tests/cluster/transformations/test_concat_df.py index 4467cdb7..9176c5c9 100644 --- a/tests/cluster/transformations/test_concat_df.py +++ b/tests/cluster/transformations/test_concat_df.py @@ -189,7 +189,7 @@ def create_df3(): def get_number_rows_1(df, id, brand, model, year=None): - # When testing the first transformation, theres only "year" as column + # When testing the first transformation, there's only "year" as column if year is not None: return df.filter( (f.col("id") == id) @@ -201,7 +201,7 @@ def get_number_rows_1(df, id, brand, model, year=None): def get_number_rows_2(df, id, brand, model, year=None, size=None): - # When testing the second transformation, theres only "year" and "size" as column + # When testing the second transformation, there's only "year" and "size" as column if year is not None: return df.filter( (f.col("id") == id) diff --git a/tests/cluster/values.py b/tests/cluster/values.py index aca9bfe1..d066c6c2 100644 --- a/tests/cluster/values.py +++ b/tests/cluster/values.py @@ -10,3 
+10,11 @@ def getValue(secret_name: str): def resourceName(): return getValue("resourceName") + + +def cosmosEndpoint(): + return getValue("Cosmos--Endpoint") + + +def storageAccountUrl(): + return getValue("StorageAccount--Url")
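Taken together, the new values.py and tests/cluster/config helpers let cluster tests run entirely on direct access: `InitConfigurator()` registers `storageAccount` from `StorageAccount--Url`, and the `{storageAccount}` token then resolves inside registered table paths. The short usage sketch below mirrors the registrations already present in tests/cluster/delta/test_delta_class.py and adds no functionality beyond this diff.

```python
# Usage sketch (illustrative, mirroring tests/cluster/delta/test_delta_class.py):
# abstract table keys resolve to abfss paths built from the registered
# {storageAccount} token, so no mount points are needed.
from atc.delta import DbHandle, DeltaHandle
from tests.cluster.config import InitConfigurator

tc = InitConfigurator(clear=True)
tc.register("MyDb", {"name": "TestDb{ID}", "path": "{storageAccount}/testdb{ID}"})
tc.register("MyTbl", {"name": "{MyDb}.TestTbl", "path": "{MyDb_path}/testtbl"})

DbHandle.from_tc("MyDb").create()
dh = DeltaHandle.from_tc("MyTbl")  # reads and writes go straight to the abfss path
```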