-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided #145
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
package com.databricks.sdk.core; | ||
|
||
import com.databricks.sdk.core.oauth.Token; | ||
import com.databricks.sdk.core.oauth.TokenSource; | ||
import com.databricks.sdk.core.utils.AzureUtils; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import java.util.*; | ||
|
@@ -27,6 +28,15 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { | |
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); | ||
} | ||
|
||
@Override | ||
public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource, String subscription) { | ||
List<String> cmd = | ||
new ArrayList<>( | ||
Arrays.asList( | ||
"az", "account", "get-access-token", "--subscription", subscription, "--resource", resource, "--output", "json")); | ||
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv); | ||
} | ||
|
||
@Override | ||
public HeaderFactory configure(DatabricksConfig config) { | ||
if (!config.isAzure()) { | ||
|
@@ -35,20 +45,44 @@ public HeaderFactory configure(DatabricksConfig config) { | |
|
||
try { | ||
ensureHostPresent(config, mapper); | ||
CliTokenSource tokenSource; | ||
CliTokenSource mgmtTokenSource; | ||
String resource = config.getEffectiveAzureLoginAppId(); | ||
CliTokenSource tokenSource = tokenSourceFor(config, resource); | ||
CliTokenSource mgmtTokenSource = | ||
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
tokenSource.getToken(); // We need this for checking if Azure CLI is installed. | ||
Optional<String> subscription = getSubscription(config); | ||
|
||
if (subscription.isPresent()) { | ||
try { | ||
// This will fail if the user has access to the workspace, but not to the subscription itself. | ||
// In such case, we fall back to not using the subscription. | ||
tokenSource = tokenSourceFor(config, resource, subscription.get()); | ||
tokenSource.getToken(); | ||
mgmtTokenSource = | ||
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); | ||
} catch (DatabricksException e) { | ||
LOG.warn("Failed to get token for subscription. Using resource only token."); | ||
tokenSource = tokenSourceFor(config, resource); | ||
mgmtTokenSource = | ||
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
} | ||
} else { | ||
LOG.warn("azure_workspace_resource_id field not provided. " + | ||
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); | ||
tokenSource = tokenSourceFor(config, resource); | ||
mgmtTokenSource = | ||
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
} | ||
|
||
tokenSource.getToken(); // We need this for checking if Azure CLI is installed | ||
try { | ||
mgmtTokenSource.getToken(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We now get the token inside the "tokenSourceFor" function. |
||
} catch (Exception e) { | ||
LOG.debug("Not including service management token in headers", e); | ||
mgmtTokenSource = null; | ||
} | ||
TokenSource finalToken = tokenSource; | ||
CliTokenSource finalMgmtTokenSource = mgmtTokenSource; | ||
return () -> { | ||
Token token = tokenSource.getToken(); | ||
Token token = finalToken.getToken(); | ||
Map<String, String> headers = new HashMap<>(); | ||
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); | ||
if (finalMgmtTokenSource != null) { | ||
|
@@ -67,3 +101,5 @@ public HeaderFactory configure(DatabricksConfig config) { | |
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
/** | ||
* Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, | ||
|
@@ -27,9 +28,31 @@ public HeaderFactory configure(DatabricksConfig config) { | |
return null; | ||
} | ||
ensureHostPresent(config, mapper); | ||
RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); | ||
RefreshableTokenSource cloud = | ||
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
RefreshableTokenSource innerToken; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I was confused when we were talking about this earlier. We don't need to change this provider at all: the tenant ID must be explicitly specified, see line 27. What I meant was that: if a user is logged into the Azure CLI with a service principal, in AzureCliCredentialsProvider, we still will take the same pathway. |
||
RefreshableTokenSource cloudToken; | ||
Optional<String> subscription = getSubscription(config); | ||
if (subscription.isPresent()) { | ||
try { | ||
// This will fail if the service principal has access to the workspace, but not to the subscription itself. | ||
// In such case, we fall back to not using the subscription. | ||
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId(), subscription.get()); | ||
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint(), subscription.get()); | ||
innerToken.getToken(); | ||
cloudToken.getToken(); | ||
} catch (DatabricksException e) { | ||
LOG.warn("Failed to get token for subscription. Using resource only token."); | ||
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); | ||
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
} | ||
} else { | ||
LOG.warn("azure_workspace_resource_id field not provided. " + | ||
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); | ||
innerToken = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); | ||
cloudToken = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); | ||
} | ||
|
||
RefreshableTokenSource inner = innerToken; | ||
RefreshableTokenSource cloud = cloudToken; | ||
|
||
return () -> { | ||
Map<String, String> headers = new HashMap<>(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package com.databricks.sdk.core; | ||
|
||
import com.databricks.sdk.core.oauth.Token; | ||
import com.databricks.sdk.core.oauth.TokenSource; | ||
import org.junit.jupiter.api.Test; | ||
import org.mockito.Mockito; | ||
|
||
import java.time.LocalDateTime; | ||
|
||
import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID; | ||
import static org.junit.jupiter.api.Assertions.*; | ||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.ArgumentMatchers.eq; | ||
import static org.mockito.Mockito.never; | ||
import static org.mockito.Mockito.times; | ||
|
||
class AzureCliCredentialsProviderTest { | ||
|
||
private static final String WORKSPACE_RESOURCE_ID = "/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws"; | ||
private static final String SUBSCRIPTION = "2a2345f8"; | ||
private static final String TOKEN = "t-123"; | ||
private static final String TOKEN_TYPE = "token-type"; | ||
public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"; | ||
|
||
|
||
private static CliTokenSource mockTokenSource() { | ||
CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); | ||
Mockito.when(tokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); | ||
return tokenSource; | ||
} | ||
private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(TokenSource tokenSource) { | ||
|
||
AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); | ||
Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); | ||
Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); | ||
Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT)); | ||
Mockito.doReturn(tokenSource).when(provider).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT), eq(SUBSCRIPTION)); | ||
|
||
return provider; | ||
} | ||
|
||
|
||
@Test | ||
void testWorkSpaceIDUsage() { | ||
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); | ||
DatabricksConfig config = new DatabricksConfig() | ||
.setHost(".azuredatabricks.") | ||
.setCredentialsProvider(provider) | ||
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); | ||
|
||
HeaderFactory header = provider.configure(config); | ||
|
||
String token = header.headers().get("Authorization"); | ||
assertEquals(token, TOKEN_TYPE + " " + TOKEN); | ||
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); | ||
Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); | ||
} | ||
|
||
@Test | ||
void testFallbackWhenTailsToGetTokenForSubscription() { | ||
CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); | ||
Mockito.when(tokenSource.getToken()).thenThrow(new DatabricksException("error")).thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now())); | ||
|
||
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(tokenSource); | ||
|
||
DatabricksConfig config = new DatabricksConfig() | ||
.setHost(".azuredatabricks.") | ||
.setCredentialsProvider(provider) | ||
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID); | ||
|
||
|
||
HeaderFactory header = provider.configure(config); | ||
|
||
|
||
String token = header.headers().get("Authorization"); | ||
assertEquals(token, TOKEN_TYPE + " " + TOKEN); | ||
|
||
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); | ||
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); | ||
} | ||
|
||
@Test | ||
void testGetTokenWithoutWorkspaceResourceID() { | ||
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource()); | ||
DatabricksConfig config = new DatabricksConfig() | ||
.setHost(".azuredatabricks.") | ||
.setCredentialsProvider(provider); | ||
|
||
HeaderFactory header = provider.configure(config); | ||
|
||
String token = header.headers().get("Authorization"); | ||
assertEquals(token, TOKEN_TYPE + " " + TOKEN); | ||
Mockito.verify(provider, never()).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID), eq(SUBSCRIPTION)); | ||
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID)); | ||
} | ||
|
||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we refactor all of this into the
tokenSourceFor
method? I think that would prevent this configure() method from sprawling, and it seems to belong there in the first place.