From 758886102fec6364ef9a39f8946a3b118ecfaa4b Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Thu, 31 Aug 2023 11:40:24 +0200 Subject: [PATCH 1/3] [DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided --- .../sdk/core/AzureCliCredentialsProvider.java | 46 +++++++- .../databricks/sdk/core/DatabricksConfig.java | 2 +- ...reServicePrincipalCredentialsProvider.java | 29 ++++- .../databricks/sdk/core/utils/AzureUtils.java | 54 +++++++-- .../core/AzureCliCredentialsProviderTest.java | 98 ++++++++++++++++ ...rvicePrincipalCredentialsProviderTest.java | 108 ++++++++++++++++++ 6 files changed, 316 insertions(+), 21 deletions(-) create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index ea2dcbd2e..b32a3124c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -1,6 +1,7 @@ package com.databricks.sdk.core; import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.AzureUtils; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.*; @@ -27,6 +28,15 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); } + @Override + public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + List cmd = + new ArrayList<>( + Arrays.asList( + "az", "account", "get-access-token", "--subscription", subscription, "--resource", resource, "--output", "json")); + return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + } + @Override public HeaderFactory configure(DatabricksConfig config) { if (!config.isAzure()) { @@ -35,20 +45,44 @@ public HeaderFactory configure(DatabricksConfig config) { try { ensureHostPresent(config, mapper); + CliTokenSource tokenSource; + CliTokenSource mgmtTokenSource; String resource = config.getEffectiveAzureLoginAppId(); - CliTokenSource tokenSource = tokenSourceFor(config, resource); - CliTokenSource mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - tokenSource.getToken(); // We need this for checking if Azure CLI is installed. + Optional subscription = getSubscription(config); + + if (subscription.isPresent()) { + try { + // This will fail if the user has access to the workspace, but not to the subscription itself. + // In such case, we fall back to not using the subscription. + tokenSource = tokenSourceFor(config, resource, subscription.get()); + tokenSource.getToken(); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + } catch (DatabricksException e) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + tokenSource = tokenSourceFor(config, resource); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + } else { + LOG.warn("azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + tokenSource = tokenSourceFor(config, resource); + mgmtTokenSource = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + + tokenSource.getToken(); // We need this for checking if Azure CLI is installed try { mgmtTokenSource.getToken(); } catch (Exception e) { LOG.debug("Not including service management token in headers", e); mgmtTokenSource = null; } + TokenSource finalToken = tokenSource; CliTokenSource finalMgmtTokenSource = mgmtTokenSource; return () -> { - Token token = tokenSource.getToken(); + Token token = finalToken.getToken(); Map headers = new HashMap<>(); headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); if (finalMgmtTokenSource != null) { @@ -67,3 +101,5 @@ public HeaderFactory configure(DatabricksConfig config) { } } } + + diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 2fedbcf82..60feac6d5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -74,7 +74,7 @@ public class DatabricksConfig { sensitive = true) private String googleCredentials; - /** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */ + /** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */ @ConfigAttribute( value = "azure_workspace_resource_id", env = "DATABRICKS_AZURE_RESOURCE_ID", diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index ea5a5dee8..e867409fa 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, @@ -27,9 +28,31 @@ public HeaderFactory configure(DatabricksConfig config) { return null; } ensureHostPresent(config, mapper); - RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - RefreshableTokenSource cloud = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + RefreshableTokenSource innerToken; + RefreshableTokenSource cloudToken; + Optional subscription = getSubscription(config); + if (subscription.isPresent()) { + try { + // This will fail if the service principal has access to the workspace, but not to the subscription itself. + // In such case, we fall back to not using the subscription. + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + innerToken.getToken(); + cloudToken.getToken(); + } catch (DatabricksException e) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + } else { + LOG.warn("azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + } + + RefreshableTokenSource inner = innerToken; + RefreshableTokenSource cloud = cloudToken; return () -> { Map headers = new HashMap<>(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 1a73ea630..42d3cd332 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core.utils; +import com.databricks.sdk.core.AzureCliCredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.Request; @@ -11,11 +12,14 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.util.*; public interface AzureUtils { + static final Logger LOG = LoggerFactory.getLogger(AzureUtils.class); /** * Creates a RefreshableTokenSource for the specified Azure resource. @@ -30,20 +34,46 @@ public interface AzureUtils { * Azure resource. */ default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { - String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); - String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; Map endpointParams = new HashMap<>(); endpointParams.put("resource", resource); + return tokenSourceFor(config, endpointParams); + } + + default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + Map endpointParams = new HashMap<>(); + endpointParams.put("resource", resource); + endpointParams.put("subscription", subscription); + return tokenSourceFor(config, endpointParams); + } + + default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, Map endpointParams) { + String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); + String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; return new ClientCredentials.Builder() - .withHttpClient(config.getHttpClient()) - .withClientId(config.getAzureClientId()) - .withClientSecret(config.getAzureClientSecret()) - .withTokenUrl(tokenUrl) - .withEndpointParameters(endpointParams) - .withAuthParameterPosition(AuthParameterPosition.BODY) - .build(); + .withHttpClient(config.getHttpClient()) + .withClientId(config.getAzureClientId()) + .withClientSecret(config.getAzureClientSecret()) + .withTokenUrl(tokenUrl) + .withEndpointParameters(endpointParams) + .withAuthParameterPosition(AuthParameterPosition.BODY) + .build(); + } + + default Optional getSubscription(DatabricksConfig config) { + String resourceId = config.getAzureWorkspaceResourceId(); + if (resourceId == null || resourceId.equals("")) { + return Optional.empty(); + } + String[] components = resourceId.split("/"); + if (components.length < 3) { + LOG.warn("Invalid azure workspace resource ID"); + return Optional.empty(); + } + return Optional.of(components[2]); + } + default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException { JsonNode properties = jsonResponse.get("properties"); if (properties == null) { @@ -69,7 +99,7 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) { } String armEndpoint = config.getAzureEnvironment().getResourceManagerEndpoint(); - Token token = tokenSourceFor(config, armEndpoint).getToken(); + Token token = tokenSourceFor(config, "resource", armEndpoint).getToken(); String requestUrl = armEndpoint + config.getAzureWorkspaceResourceId() + "?api-version=2018-04-01"; Request req = new Request("GET", requestUrl); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java new file mode 100644 index 000000000..82dedb68d --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -0,0 +1,98 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.time.LocalDateTime; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +class AzureCliCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + + private static CliTokenSource mockTokenSource() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + return tokenSource; + } + private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(TokenSource tokenSource) { + + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + + return provider; + } + + + @Test + void testWorkSpaceIDUsage() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); + + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + + HeaderFactory header = provider.configure(config); + + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + +} \ No newline at end of file diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java new file mode 100644 index 000000000..3619bd7cc --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java @@ -0,0 +1,108 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.*; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.time.LocalDateTime; +import java.time.temporal.IsoFields; + +import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +class AzureServicePrincipalCredentialsProviderTest { + + private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + private static RefreshableTokenSource mockTokenSource() { + RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class); + Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS))); + return tokenSource; + } + private static AzureServicePrincipalCredentialsProvider getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) { + AzureServicePrincipalCredentialsProvider provider = Mockito.spy(new AzureServicePrincipalCredentialsProvider()); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + return provider; + } + + + @Test + void testWorkSpaceIDUsage() { + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(tokenSource); + + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID"); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + +} \ No newline at end of file From db10a4359a95c4fb2aacbcecd9f41886f0d30c1e Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Thu, 31 Aug 2023 13:28:04 +0200 Subject: [PATCH 2/3] Fix formatting --- .../sdk/core/AzureCliCredentialsProvider.java | 36 ++- ...reServicePrincipalCredentialsProvider.java | 23 +- .../databricks/sdk/core/utils/AzureUtils.java | 30 ++- .../core/AzureCliCredentialsProviderTest.java | 183 ++++++++------- ...rvicePrincipalCredentialsProviderTest.java | 218 ++++++++++-------- 5 files changed, 272 insertions(+), 218 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index b32a3124c..dce514288 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -29,11 +29,20 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { } @Override - public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + public CliTokenSource tokenSourceFor( + DatabricksConfig config, String resource, String subscription) { List cmd = - new ArrayList<>( - Arrays.asList( - "az", "account", "get-access-token", "--subscription", subscription, "--resource", resource, "--output", "json")); + new ArrayList<>( + Arrays.asList( + "az", + "account", + "get-access-token", + "--subscription", + subscription, + "--resource", + resource, + "--output", + "json")); return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); } @@ -52,24 +61,29 @@ public HeaderFactory configure(DatabricksConfig config) { if (subscription.isPresent()) { try { - // This will fail if the user has access to the workspace, but not to the subscription itself. + // This will fail if the user has access to the workspace, but not to the subscription + // itself. // In such case, we fall back to not using the subscription. tokenSource = tokenSourceFor(config, resource, subscription.get()); tokenSource.getToken(); mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + tokenSourceFor( + config, + config.getAzureEnvironment().getServiceManagementEndpoint(), + subscription.get()); } catch (DatabricksException e) { LOG.warn("Failed to get token for subscription. Using resource only token."); tokenSource = tokenSourceFor(config, resource); mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); } } else { - LOG.warn("azure_workspace_resource_id field not provided. " + - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + LOG.warn( + "azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); tokenSource = tokenSourceFor(config, resource); mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); } tokenSource.getToken(); // We need this for checking if Azure CLI is installed @@ -101,5 +115,3 @@ public HeaderFactory configure(DatabricksConfig config) { } } } - - diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index e867409fa..2a8345570 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -33,22 +33,31 @@ public HeaderFactory configure(DatabricksConfig config) { Optional subscription = getSubscription(config); if (subscription.isPresent()) { try { - // This will fail if the service principal has access to the workspace, but not to the subscription itself. + // This will fail if the service principal has access to the workspace, but not to the + // subscription itself. // In such case, we fall back to not using the subscription. - innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); - cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); + innerToken = + tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); + cloudToken = + tokenSourceFor( + config, + config.getAzureEnvironment().getServiceManagementEndpoint(), + subscription.get()); innerToken.getToken(); cloudToken.getToken(); } catch (DatabricksException e) { LOG.warn("Failed to get token for subscription. Using resource only token."); innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + cloudToken = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); } } else { - LOG.warn("azure_workspace_resource_id field not provided. " + - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + LOG.warn( + "azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); + cloudToken = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); } RefreshableTokenSource inner = innerToken; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 42d3cd332..c36622d50 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -1,6 +1,5 @@ package com.databricks.sdk.core.utils; -import com.databricks.sdk.core.AzureCliCredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.Request; @@ -12,14 +11,13 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public interface AzureUtils { - static final Logger LOG = LoggerFactory.getLogger(AzureUtils.class); + Logger LOG = LoggerFactory.getLogger(AzureUtils.class); /** * Creates a RefreshableTokenSource for the specified Azure resource. @@ -39,24 +37,26 @@ default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String re return tokenSourceFor(config, endpointParams); } - default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { + default RefreshableTokenSource tokenSourceFor( + DatabricksConfig config, String resource, String subscription) { Map endpointParams = new HashMap<>(); endpointParams.put("resource", resource); endpointParams.put("subscription", subscription); return tokenSourceFor(config, endpointParams); } - default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, Map endpointParams) { + default RefreshableTokenSource tokenSourceFor( + DatabricksConfig config, Map endpointParams) { String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; return new ClientCredentials.Builder() - .withHttpClient(config.getHttpClient()) - .withClientId(config.getAzureClientId()) - .withClientSecret(config.getAzureClientSecret()) - .withTokenUrl(tokenUrl) - .withEndpointParameters(endpointParams) - .withAuthParameterPosition(AuthParameterPosition.BODY) - .build(); + .withHttpClient(config.getHttpClient()) + .withClientId(config.getAzureClientId()) + .withClientSecret(config.getAzureClientSecret()) + .withTokenUrl(tokenUrl) + .withEndpointParameters(endpointParams) + .withAuthParameterPosition(AuthParameterPosition.BODY) + .build(); } default Optional getSubscription(DatabricksConfig config) { @@ -70,10 +70,8 @@ default Optional getSubscription(DatabricksConfig config) { return Optional.empty(); } return Optional.of(components[2]); - } - default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException { JsonNode properties = jsonResponse.get("properties"); if (properties == null) { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java index 82dedb68d..e153ca7e2 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -1,12 +1,5 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.Token; -import com.databricks.sdk.core.oauth.TokenSource; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; - -import java.time.LocalDateTime; - import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; @@ -14,85 +7,103 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; -class AzureCliCredentialsProviderTest { - - private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; - private static final String SUBSCRIPTION = "2a2345f8"; - private static final String TOKEN = "t-123"; - private static final String TOKEN_TYPE = "token-type"; - public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; - - - private static CliTokenSource mockTokenSource() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); - return tokenSource; - } - private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(TokenSource tokenSource) { - - AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - - return provider; - } - - - @Test - void testWorkSpaceIDUsage() { - AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - } - - @Test - void testFallbackWhenTailsToGetTokenForSubscription() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); - - AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); - - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - - HeaderFactory header = provider.configure(config); - - - String token = header.headers().get("Authorization"); - assertEquals(token, TOKEN_TYPE + " " + TOKEN); +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; +import java.time.LocalDateTime; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - } +class AzureCliCredentialsProviderTest { - @Test - void testGetTokenWithoutWorkspaceResourceID() { + private static final String WORKSPACE_RESOURCE_ID = + "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + private static CliTokenSource mockTokenSource() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + return tokenSource; + } + + private static AzureCliCredentialsProvider getAzureCliCredentialsProvider( + TokenSource tokenSource) { + + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + + return provider; + } + + @Test + void testWorkSpaceIDUsage() { AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - } - - -} \ No newline at end of file + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenThrow(new DatabricksException("error")) + .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); + + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, TOKEN_TYPE + " " + TOKEN); + Mockito.verify(provider, never()) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java index 3619bd7cc..41b463dd8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java @@ -1,12 +1,5 @@ package com.databricks.sdk.core.oauth; -import com.databricks.sdk.core.*; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; - -import java.time.LocalDateTime; -import java.time.temporal.IsoFields; - import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; @@ -14,95 +7,126 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; +import com.databricks.sdk.core.*; +import java.time.LocalDateTime; +import java.time.temporal.IsoFields; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + class AzureServicePrincipalCredentialsProviderTest { - private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; - private static final String SUBSCRIPTION = "2a2345f8"; - private static final String TOKEN = "t-123"; - private static final String TOKEN_TYPE = "token-type"; - public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; - - private static RefreshableTokenSource mockTokenSource() { - RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class); - Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS))); - return tokenSource; - } - private static AzureServicePrincipalCredentialsProvider getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) { - AzureServicePrincipalCredentialsProvider provider = Mockito.spy(new AzureServicePrincipalCredentialsProvider()); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - return provider; - } - - - @Test - void testWorkSpaceIDUsage() { - AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureClientId("clientID") - .setAzureClientSecret("clientSecret") - .setAzureTenantId("tenantID") - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, "Bearer " + TOKEN); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - } - - @Test - void testFallbackWhenTailsToGetTokenForSubscription() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); - - AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(tokenSource); - - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureClientId("clientID") - .setAzureClientSecret("clientSecret") - .setAzureTenantId("tenantID") - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, "Bearer " + TOKEN); - - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - } - - @Test - void testGetTokenWithoutWorkspaceResourceID() { - AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); - DatabricksConfig config = new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureClientId("clientID") - .setAzureClientSecret("clientSecret") - .setAzureTenantId("tenantID"); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, "Bearer " + TOKEN); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - } - - -} \ No newline at end of file + private static final String WORKSPACE_RESOURCE_ID = + "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; + private static final String SUBSCRIPTION = "2a2345f8"; + private static final String TOKEN = "t-123"; + private static final String TOKEN_TYPE = "token-type"; + public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; + + private static RefreshableTokenSource mockTokenSource() { + RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenReturn( + new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS))); + return tokenSource; + } + + private static AzureServicePrincipalCredentialsProvider + getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) { + AzureServicePrincipalCredentialsProvider provider = + Mockito.spy(new AzureServicePrincipalCredentialsProvider()); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource) + .when(provider) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + return provider; + } + + @Test + void testWorkSpaceIDUsage() { + AzureServicePrincipalCredentialsProvider provider = + getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testFallbackWhenTailsToGetTokenForSubscription() { + CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(tokenSource.getToken()) + .thenThrow(new DatabricksException("error")) + .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + + AzureServicePrincipalCredentialsProvider provider = + getAzureServicePrincipalCredentialsProvider(tokenSource); + + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID") + .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } + + @Test + void testGetTokenWithoutWorkspaceResourceID() { + AzureServicePrincipalCredentialsProvider provider = + getAzureServicePrincipalCredentialsProvider(mockTokenSource()); + DatabricksConfig config = + new DatabricksConfig() + .setHost(".azuredatabricks.") + .setCredentialsProvider(provider) + .setAzureClientId("clientID") + .setAzureClientSecret("clientSecret") + .setAzureTenantId("tenantID"); + + HeaderFactory header = provider.configure(config); + + String token = header.headers().get("Authorization"); + assertEquals(token, "Bearer " + TOKEN); + Mockito.verify(provider, never()) + .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); + Mockito.verify(provider, never()) + .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); + } +} From 8fa4430bb6f7dc70a0b4531b140a8edfb531a459 Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Thu, 31 Aug 2023 14:16:17 +0200 Subject: [PATCH 3/3] Address comments in PR --- .../sdk/core/AzureCliCredentialsProvider.java | 93 ++++++++----------- ...reServicePrincipalCredentialsProvider.java | 38 +------- .../databricks/sdk/core/utils/AzureUtils.java | 38 +------- .../core/AzureCliCredentialsProviderTest.java | 52 +++++------ ...rvicePrincipalCredentialsProviderTest.java | 74 +-------------- 5 files changed, 73 insertions(+), 222 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index dce514288..5a2a7f709 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -1,7 +1,6 @@ package com.databricks.sdk.core; import com.databricks.sdk.core.oauth.Token; -import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.AzureUtils; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.*; @@ -25,25 +24,45 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { new ArrayList<>( Arrays.asList( "az", "account", "get-access-token", "--resource", resource, "--output", "json")); - return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + Optional subscription = getSubscription(config); + if (subscription.isPresent()) { + // This will fail if the user has access to the workspace, but not to the subscription + // itself. + // In such case, we fall back to not using the subscription. + List extendedCmd = new ArrayList<>(cmd); + extendedCmd.addAll(Arrays.asList("--subscription", subscription.get())); + try { + return getToken(config, extendedCmd); + } catch (DatabricksException ex) { + LOG.warn("Failed to get token for subscription. Using resource only token."); + } + } else { + LOG.warn( + "azure_workspace_resource_id field not provided. " + + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); + } + + return getToken(config, cmd); } - @Override - public CliTokenSource tokenSourceFor( - DatabricksConfig config, String resource, String subscription) { - List cmd = - new ArrayList<>( - Arrays.asList( - "az", - "account", - "get-access-token", - "--subscription", - subscription, - "--resource", - resource, - "--output", - "json")); - return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + protected CliTokenSource getToken(DatabricksConfig config, List cmd) { + CliTokenSource token = + new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); + token.getToken(); // We need this to check if the CLI is installed and to validate the config. + return token; + } + + private Optional getSubscription(DatabricksConfig config) { + String resourceId = config.getAzureWorkspaceResourceId(); + if (resourceId == null || resourceId.equals("")) { + return Optional.empty(); + } + String[] components = resourceId.split("/"); + if (components.length < 3) { + LOG.warn("Invalid azure workspace resource ID"); + return Optional.empty(); + } + return Optional.of(components[2]); } @Override @@ -54,49 +73,19 @@ public HeaderFactory configure(DatabricksConfig config) { try { ensureHostPresent(config, mapper); - CliTokenSource tokenSource; - CliTokenSource mgmtTokenSource; String resource = config.getEffectiveAzureLoginAppId(); - Optional subscription = getSubscription(config); - - if (subscription.isPresent()) { - try { - // This will fail if the user has access to the workspace, but not to the subscription - // itself. - // In such case, we fall back to not using the subscription. - tokenSource = tokenSourceFor(config, resource, subscription.get()); - tokenSource.getToken(); - mgmtTokenSource = - tokenSourceFor( - config, - config.getAzureEnvironment().getServiceManagementEndpoint(), - subscription.get()); - } catch (DatabricksException e) { - LOG.warn("Failed to get token for subscription. Using resource only token."); - tokenSource = tokenSourceFor(config, resource); - mgmtTokenSource = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - } - } else { - LOG.warn( - "azure_workspace_resource_id field not provided. " - + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); - tokenSource = tokenSourceFor(config, resource); + CliTokenSource tokenSource = tokenSourceFor(config, resource); + CliTokenSource mgmtTokenSource; + try { mgmtTokenSource = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - } - - tokenSource.getToken(); // We need this for checking if Azure CLI is installed - try { - mgmtTokenSource.getToken(); } catch (Exception e) { LOG.debug("Not including service management token in headers", e); mgmtTokenSource = null; } - TokenSource finalToken = tokenSource; CliTokenSource finalMgmtTokenSource = mgmtTokenSource; return () -> { - Token token = finalToken.getToken(); + Token token = tokenSource.getToken(); Map headers = new HashMap<>(); headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); if (finalMgmtTokenSource != null) { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index 2a8345570..ea5a5dee8 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -5,7 +5,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** * Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, @@ -28,40 +27,9 @@ public HeaderFactory configure(DatabricksConfig config) { return null; } ensureHostPresent(config, mapper); - RefreshableTokenSource innerToken; - RefreshableTokenSource cloudToken; - Optional subscription = getSubscription(config); - if (subscription.isPresent()) { - try { - // This will fail if the service principal has access to the workspace, but not to the - // subscription itself. - // In such case, we fall back to not using the subscription. - innerToken = - tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); - cloudToken = - tokenSourceFor( - config, - config.getAzureEnvironment().getServiceManagementEndpoint(), - subscription.get()); - innerToken.getToken(); - cloudToken.getToken(); - } catch (DatabricksException e) { - LOG.warn("Failed to get token for subscription. Using resource only token."); - innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - cloudToken = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - } - } else { - LOG.warn( - "azure_workspace_resource_id field not provided. " - + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); - innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - cloudToken = - tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - } - - RefreshableTokenSource inner = innerToken; - RefreshableTokenSource cloud = cloudToken; + RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + RefreshableTokenSource cloud = + tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); return () -> { Map headers = new HashMap<>(); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index c36622d50..1a73ea630 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -12,12 +12,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import java.io.IOException; -import java.util.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import java.util.HashMap; +import java.util.Map; public interface AzureUtils { - Logger LOG = LoggerFactory.getLogger(AzureUtils.class); /** * Creates a RefreshableTokenSource for the specified Azure resource. @@ -32,23 +30,10 @@ public interface AzureUtils { * Azure resource. */ default RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { - Map endpointParams = new HashMap<>(); - endpointParams.put("resource", resource); - return tokenSourceFor(config, endpointParams); - } - - default RefreshableTokenSource tokenSourceFor( - DatabricksConfig config, String resource, String subscription) { - Map endpointParams = new HashMap<>(); - endpointParams.put("resource", resource); - endpointParams.put("subscription", subscription); - return tokenSourceFor(config, endpointParams); - } - - default RefreshableTokenSource tokenSourceFor( - DatabricksConfig config, Map endpointParams) { String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; + Map endpointParams = new HashMap<>(); + endpointParams.put("resource", resource); return new ClientCredentials.Builder() .withHttpClient(config.getHttpClient()) .withClientId(config.getAzureClientId()) @@ -59,19 +44,6 @@ default RefreshableTokenSource tokenSourceFor( .build(); } - default Optional getSubscription(DatabricksConfig config) { - String resourceId = config.getAzureWorkspaceResourceId(); - if (resourceId == null || resourceId.equals("")) { - return Optional.empty(); - } - String[] components = resourceId.split("/"); - if (components.length < 3) { - LOG.warn("Invalid azure workspace resource ID"); - return Optional.empty(); - } - return Optional.of(components[2]); - } - default String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException { JsonNode properties = jsonResponse.get("properties"); if (properties == null) { @@ -97,7 +69,7 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) { } String armEndpoint = config.getAzureEnvironment().getResourceManagerEndpoint(); - Token token = tokenSourceFor(config, "resource", armEndpoint).getToken(); + Token token = tokenSourceFor(config, armEndpoint).getToken(); String requestUrl = armEndpoint + config.getAzureWorkspaceResourceId() + "?api-version=2018-04-01"; Request req = new Request("GET", requestUrl); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java index e153ca7e2..6b617f643 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -2,15 +2,16 @@ import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.times; import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.oauth.TokenSource; import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; class AzureCliCredentialsProviderTest { @@ -20,7 +21,6 @@ class AzureCliCredentialsProviderTest { private static final String SUBSCRIPTION = "2a2345f8"; private static final String TOKEN = "t-123"; private static final String TOKEN_TYPE = "token-type"; - public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; private static CliTokenSource mockTokenSource() { CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); @@ -33,18 +33,7 @@ private static AzureCliCredentialsProvider getAzureCliCredentialsProvider( TokenSource tokenSource) { AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); + Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); return provider; } @@ -57,24 +46,27 @@ void testWorkSpaceIDUsage() { .setHost(".azuredatabricks.") .setCredentialsProvider(provider) .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); + ArgumentCaptor> argument = ArgumentCaptor.forClass(List.class); HeaderFactory header = provider.configure(config); String token = header.headers().get("Authorization"); assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + + List value = argument.getValue(); + value = value.subList(value.size() - 2, value.size()); + List expected = Arrays.asList("--subscription", SUBSCRIPTION); + assertEquals(expected, value); } @Test void testFallbackWhenTailsToGetTokenForSubscription() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()) - .thenThrow(new DatabricksException("error")) - .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); + CliTokenSource tokenSource = mockTokenSource(); - AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); + AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); + Mockito.doThrow(new DatabricksException("error")).when(provider).getToken(any(), anyList()); + Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); DatabricksConfig config = new DatabricksConfig() @@ -87,8 +79,6 @@ void testFallbackWhenTailsToGetTokenForSubscription() { String token = header.headers().get("Authorization"); assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); } @@ -98,12 +88,16 @@ void testGetTokenWithoutWorkspaceResourceID() { DatabricksConfig config = new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider); + ArgumentCaptor> argument = ArgumentCaptor.forClass(List.class); + HeaderFactory header = provider.configure(config); String token = header.headers().get("Authorization"); assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, never()) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); + Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + + List value = argument.getValue(); + assertFalse(value.contains("--subscription")); + assertFalse(value.contains(SUBSCRIPTION)); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java index 41b463dd8..e06683308 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProviderTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import com.databricks.sdk.core.*; @@ -14,10 +13,6 @@ import org.mockito.Mockito; class AzureServicePrincipalCredentialsProviderTest { - - private static final String WORKSPACE_RESOURCE_ID = - "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; - private static final String SUBSCRIPTION = "2a2345f8"; private static final String TOKEN = "t-123"; private static final String TOKEN_TYPE = "token-type"; public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; @@ -40,74 +35,11 @@ private static RefreshableTokenSource mockTokenSource() { Mockito.doReturn(tokenSource) .when(provider) .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.doReturn(tokenSource) - .when(provider) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); return provider; } @Test - void testWorkSpaceIDUsage() { - AzureServicePrincipalCredentialsProvider provider = - getAzureServicePrincipalCredentialsProvider(mockTokenSource()); - DatabricksConfig config = - new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureClientId("clientID") - .setAzureClientSecret("clientSecret") - .setAzureTenantId("tenantID") - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, "Bearer " + TOKEN); - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.verify(provider, never()).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - } - - @Test - void testFallbackWhenTailsToGetTokenForSubscription() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()) - .thenThrow(new DatabricksException("error")) - .thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); - - AzureServicePrincipalCredentialsProvider provider = - getAzureServicePrincipalCredentialsProvider(tokenSource); - - DatabricksConfig config = - new DatabricksConfig() - .setHost(".azuredatabricks.") - .setCredentialsProvider(provider) - .setAzureClientId("clientID") - .setAzureClientSecret("clientSecret") - .setAzureTenantId("tenantID") - .setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); - - HeaderFactory header = provider.configure(config); - - String token = header.headers().get("Authorization"); - assertEquals(token, "Bearer " + TOKEN); - - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); - Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); - } - - @Test - void testGetTokenWithoutWorkspaceResourceID() { + void testGetToken() { AzureServicePrincipalCredentialsProvider provider = getAzureServicePrincipalCredentialsProvider(mockTokenSource()); DatabricksConfig config = @@ -122,10 +54,6 @@ void testGetTokenWithoutWorkspaceResourceID() { String token = header.headers().get("Authorization"); assertEquals(token, "Bearer " + TOKEN); - Mockito.verify(provider, never()) - .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); - Mockito.verify(provider, never()) - .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); }