Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided #145

Merged
merged 3 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,45 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
new ArrayList<>(
Arrays.asList(
"az", "account", "get-access-token", "--resource", resource, "--output", "json"));
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
Optional<String> subscription = getSubscription(config);
if (subscription.isPresent()) {
// This will fail if the user has access to the workspace, but not to the subscription
// itself.
// In such case, we fall back to not using the subscription.
List<String> extendedCmd = new ArrayList<>(cmd);
extendedCmd.addAll(Arrays.asList("--subscription", subscription.get()));
try {
return getToken(config, extendedCmd);
} catch (DatabricksException ex) {
LOG.warn("Failed to get token for subscription. Using resource only token.");
}
} else {
LOG.warn(
"azure_workspace_resource_id field not provided. "
+ "It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
}

return getToken(config, cmd);
}

/**
 * Wraps the given Azure CLI command in a {@link CliTokenSource} and requests a token from it once
 * before handing it back.
 *
 * @param config configuration whose environment ({@code config::getAllEnv}) is passed to the token
 *     source
 * @param cmd the CLI command line the token source will execute
 * @return the token source, after one {@code getToken()} call on it has completed
 */
protected CliTokenSource getToken(DatabricksConfig config, List<String> cmd) {
  final CliTokenSource source =
      new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
  // Fetch eagerly: this is how we find out whether the CLI is installed and the config is valid.
  source.getToken();
  return source;
}

/**
 * Extracts the subscription ID from the configured Azure workspace resource ID.
 *
 * <p>The resource ID is expected to have the shape {@code
 * /subscriptions/<subscription-id>/resourceGroups/...}. Anything that does not start with a
 * {@code /subscriptions/<non-empty-id>} prefix is treated as invalid, so that an unrelated path
 * segment is never handed to the CLI as a subscription.
 *
 * @param config configuration to read the workspace resource ID from
 * @return the subscription ID, or {@link Optional#empty()} when the resource ID is not set or is
 *     malformed (a warning is logged in the malformed case)
 */
private Optional<String> getSubscription(DatabricksConfig config) {
  String resourceId = config.getAzureWorkspaceResourceId();
  if (resourceId == null || resourceId.isEmpty()) {
    return Optional.empty();
  }
  // The leading "/" yields an empty first element: ["", "subscriptions", "<id>", ...].
  String[] components = resourceId.split("/");
  boolean wellFormed =
      components.length >= 3
          && components[0].isEmpty()
          && "subscriptions".equalsIgnoreCase(components[1])
          && !components[2].isEmpty();
  if (!wellFormed) {
    LOG.warn("Invalid azure workspace resource ID");
    return Optional.empty();
  }
  return Optional.of(components[2]);
}

@Override
Expand All @@ -37,11 +75,10 @@ public HeaderFactory configure(DatabricksConfig config) {
ensureHostPresent(config, mapper);
String resource = config.getEffectiveAzureLoginAppId();
CliTokenSource tokenSource = tokenSourceFor(config, resource);
CliTokenSource mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
CliTokenSource mgmtTokenSource;
try {
mgmtTokenSource.getToken();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now get the token inside the "tokenSourceFor" function.

mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
} catch (Exception e) {
LOG.debug("Not including service management token in headers", e);
mgmtTokenSource = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class DatabricksConfig {
sensitive = true)
private String googleCredentials;

/** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */
/** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */
@ConfigAttribute(
value = "azure_workspace_resource_id",
env = "DATABRICKS_AZURE_RESOURCE_ID",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package com.databricks.sdk.core;

import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.times;

import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenSource;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

class AzureCliCredentialsProviderTest {

private static final String WORKSPACE_RESOURCE_ID =
"/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws";
private static final String SUBSCRIPTION = "2a2345f8";
private static final String TOKEN = "t-123";
private static final String TOKEN_TYPE = "token-type";

private static CliTokenSource mockTokenSource() {
CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
Mockito.when(tokenSource.getToken())
.thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));
return tokenSource;
}

private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(
TokenSource tokenSource) {

AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList());

return provider;
}

@Test
void testWorkSpaceIDUsage() {
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
DatabricksConfig config =
new DatabricksConfig()
.setHost(".azuredatabricks.")
.setCredentialsProvider(provider)
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);
ArgumentCaptor<List<String>> argument = ArgumentCaptor.forClass(List.class);

HeaderFactory header = provider.configure(config);

String token = header.headers().get("Authorization");
assertEquals(token, TOKEN_TYPE + " " + TOKEN);
Mockito.verify(provider, times(2)).getToken(any(), argument.capture());

List<String> value = argument.getValue();
value = value.subList(value.size() - 2, value.size());
List<String> expected = Arrays.asList("--subscription", SUBSCRIPTION);
assertEquals(expected, value);
}

@Test
void testFallbackWhenTailsToGetTokenForSubscription() {
CliTokenSource tokenSource = mockTokenSource();

AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
Mockito.doThrow(new DatabricksException("error")).when(provider).getToken(any(), anyList());
Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList());

DatabricksConfig config =
new DatabricksConfig()
.setHost(".azuredatabricks.")
.setCredentialsProvider(provider)
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);

HeaderFactory header = provider.configure(config);

String token = header.headers().get("Authorization");
assertEquals(token, TOKEN_TYPE + " " + TOKEN);

Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
}

@Test
void testGetTokenWithoutWorkspaceResourceID() {
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
DatabricksConfig config =
new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider);

ArgumentCaptor<List<String>> argument = ArgumentCaptor.forClass(List.class);

HeaderFactory header = provider.configure(config);

String token = header.headers().get("Authorization");
assertEquals(token, TOKEN_TYPE + " " + TOKEN);
Mockito.verify(provider, times(2)).getToken(any(), argument.capture());

List<String> value = argument.getValue();
assertFalse(value.contains("--subscription"));
assertFalse(value.contains(SUBSCRIPTION));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.databricks.sdk.core.oauth;

import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;

import com.databricks.sdk.core.*;
import java.time.LocalDateTime;
import java.time.temporal.IsoFields;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link AzureServicePrincipalCredentialsProvider}.
 *
 * <p>The provider is wrapped in a Mockito spy and {@code tokenSourceFor} is stubbed for both the
 * Databricks resource ID and the public management endpoint, so no token endpoint is contacted.
 */
class AzureServicePrincipalCredentialsProviderTest {
  private static final String TOKEN = "t-123";
  private static final String TOKEN_TYPE = "token-type";
  public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/";

  /**
   * Returns a mocked token source whose {@code getToken()} yields the test token, with an expiry
   * one week-based year after the current time.
   */
  private static RefreshableTokenSource mockTokenSource() {
    RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class);
    Mockito.when(tokenSource.getToken())
        .thenReturn(
            new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS)));
    return tokenSource;
  }

  /**
   * Returns a spied provider whose {@code tokenSourceFor} returns the given token source for both
   * the Databricks resource ID and the public management endpoint.
   */
  private static AzureServicePrincipalCredentialsProvider
      getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) {
    AzureServicePrincipalCredentialsProvider provider =
        Mockito.spy(new AzureServicePrincipalCredentialsProvider());
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
    Mockito.doReturn(tokenSource)
        .when(provider)
        .tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
    return provider;
  }

  /**
   * Configuring with client ID, secret and tenant produces a bearer header and requests exactly
   * one token source per resource.
   */
  @Test
  void testGetToken() {
    AzureServicePrincipalCredentialsProvider provider =
        getAzureServicePrincipalCredentialsProvider(mockTokenSource());
    DatabricksConfig config =
        new DatabricksConfig()
            .setHost(".azuredatabricks.")
            .setCredentialsProvider(provider)
            .setAzureClientId("clientID")
            .setAzureClientSecret("clientSecret")
            .setAzureTenantId("tenantID");

    HeaderFactory header = provider.configure(config);

    String token = header.headers().get("Authorization");
    // JUnit expects (expected, actual); the reverse order produces misleading failure messages.
    assertEquals("Bearer " + TOKEN, token);
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
    Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
  }
}
Loading