From 37bc1dc762978521d03f201a24a49600bdd0d71e Mon Sep 17 00:00:00 2001 From: David Date: Wed, 7 Aug 2024 12:35:00 -0400 Subject: [PATCH] feat!: improve compatibility with databricks sql client (#252) * feat!: add working databricks-cli fallback credentials strategy * test: add tests for external databricks helpers --- examples/connect/.gitignore | 2 +- examples/connect/dash/app.py | 22 +-- examples/connect/dash/requirements.txt | 6 +- examples/connect/fastapi/app.py | 18 ++- examples/connect/fastapi/requirements.txt | 6 +- examples/connect/flask/app.py | 18 ++- examples/connect/flask/requirements.txt | 6 +- examples/connect/shiny-python/app.py | 22 +-- .../connect/shiny-python/requirements.txt | 6 +- examples/connect/streamlit/app.py | 28 ++-- examples/connect/streamlit/requirements.txt | 8 +- src/posit/connect/external/databricks.py | 126 +++++++++++------- .../posit/connect/external/test_databricks.py | 72 ++++++++++ 13 files changed, 231 insertions(+), 109 deletions(-) diff --git a/examples/connect/.gitignore b/examples/connect/.gitignore index afe36a39..5ac2a70e 100644 --- a/examples/connect/.gitignore +++ b/examples/connect/.gitignore @@ -1 +1 @@ -.posit +**/rsconnect-python/ diff --git a/examples/connect/dash/app.py b/examples/connect/dash/app.py index a1bf7893..72ff2a65 100644 --- a/examples/connect/dash/app.py +++ b/examples/connect/dash/app.py @@ -6,9 +6,9 @@ import pandas as pd from dash import Dash, Input, Output, dash_table, html from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import PositCredentialsStrategy DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -38,14 +38,16 @@ def update_page(_): session_token = flask.request.headers.get( "Posit-Connect-User-Session-Token" ) - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) def get_greeting(): - cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider - ) databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me() return f"Hello, {databricks_user_info.display_name}!" @@ -58,8 +60,8 @@ def get_table(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/dash/requirements.txt b/examples/connect/dash/requirements.txt index bbceac3d..090523a4 100644 --- a/examples/connect/dash/requirements.txt +++ b/examples/connect/dash/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 dash==2.15.0 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/fastapi/app.py b/examples/connect/fastapi/app.py index 8be3ece2..c9eeabb7 100644 --- a/examples/connect/fastapi/app.py +++ b/examples/connect/fastapi/app.py @@ -4,9 +4,10 @@ from typing import Annotated from databricks import sql +from databricks.sdk.core import Config, databricks_cli from fastapi import FastAPI, Header from fastapi.responses import JSONResponse -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import PositCredentialsStrategy DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -26,9 +27,14 @@ async def get_fares( """ global rows - credentials_provider = viewer_credentials_provider( - user_session_token=posit_connect_user_session_token - ) + posit_strategy = PositCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=posit_connect_user_session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) if rows is None: query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;" @@ -36,8 +42,8 @@ async def get_fares( with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/fastapi/requirements.txt b/examples/connect/fastapi/requirements.txt index d506dcc6..262e5868 100644 --- a/examples/connect/fastapi/requirements.txt +++ b/examples/connect/fastapi/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 fastapi==0.110.0 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/flask/app.py b/examples/connect/flask/app.py index d95fa19f..f41fd8d1 100644 --- a/examples/connect/flask/app.py +++ b/examples/connect/flask/app.py @@ -3,8 +3,9 @@ import os from databricks import sql +from databricks.sdk.core import Config, databricks_cli from flask import Flask, request -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import PositCredentialsStrategy DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -28,9 +29,14 @@ def get_fares(): global rows session_token = request.headers.get("Posit-Connect-User-Session-Token") - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) if rows is None: query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;" @@ -38,8 +44,8 @@ def get_fares(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/flask/requirements.txt b/examples/connect/flask/requirements.txt index 752130dd..c0b12941 100644 --- a/examples/connect/flask/requirements.txt +++ b/examples/connect/flask/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 flask==3.0.2 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/shiny-python/app.py b/examples/connect/shiny-python/app.py index bf97579e..fb903383 100644 --- a/examples/connect/shiny-python/app.py +++ b/examples/connect/shiny-python/app.py @@ -4,9 +4,9 @@ import pandas as pd from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import PositCredentialsStrategy from shiny import App, Inputs, Outputs, Session, render, ui DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") @@ -24,9 +24,14 @@ def server(i: Inputs, o: Outputs, session: Session): session_token = session.http_conn.headers.get( "Posit-Connect-User-Session-Token" ) - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) @render.data_frame def result(): @@ -35,8 +40,8 @@ def result(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg), ) as connection: with connection.cursor() as cursor: cursor.execute(query) @@ -48,9 +53,6 @@ def result(): @render.text def text(): - cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider - ) databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me() return f"Hello, {databricks_user_info.display_name}!" diff --git a/examples/connect/shiny-python/requirements.txt b/examples/connect/shiny-python/requirements.txt index 44691daa..eedba131 100644 --- a/examples/connect/shiny-python/requirements.txt +++ b/examples/connect/shiny-python/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 shiny==0.7.1 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/streamlit/app.py b/examples/connect/streamlit/app.py index ffc084c4..0be0b773 100644 --- a/examples/connect/streamlit/app.py +++ b/examples/connect/streamlit/app.py @@ -5,34 +5,32 @@ import pandas as pd import streamlit as st from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider -from streamlit.web.server.websocket_headers import _get_websocket_headers +from posit.connect.external.databricks import PositCredentialsStrategy DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" SQL_HTTP_PATH = os.getenv("DATABRICKS_PATH") -session_token = _get_websocket_headers().get( - "Posit-Connect-User-Session-Token" -) - -credentials_provider = viewer_credentials_provider( - user_session_token=session_token -) - +session_token = st.context.headers.get("Posit-Connect-User-Session-Token") +posit_strategy = PositCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider -) + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) + databricks_user = CurrentUserAPI(ApiClient(cfg)).me() st.write(f"Hello, {databricks_user.display_name}!") with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute("SELECT * FROM samples.nyctaxi.trips LIMIT 10;") diff --git a/examples/connect/streamlit/requirements.txt b/examples/connect/streamlit/requirements.txt index 63ccd077..8395b2c1 100644 --- a/examples/connect/streamlit/requirements.txt +++ b/examples/connect/streamlit/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 -streamlit==1.31.1 -git+https://github.com/posit-dev/posit-sdk-py.git +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 +streamlit==1.37.0 +posit-sdk>=0.4.0 diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py index 2e8f634a..2324258e 100644 --- a/src/posit/connect/external/databricks.py +++ b/src/posit/connect/external/databricks.py @@ -5,16 +5,19 @@ from ..client import Client from ..oauth import OAuthIntegration -HeaderFactory = Callable[[], Dict[str, str]] +""" +NOTE: These APIs are provided as a convenience and are subject to breaking changes: +https://github.com/databricks/databricks-sdk-py#interface-stability +""" +# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory +CredentialsProvider = Callable[[], Dict[str, str]] -# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py -# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py -# In order to keep compatibility with the Databricks SDK -class CredentialsProvider(abc.ABC): - """Protocol Databricks authentication. +class CredentialsStrategy(abc.ABC): + """Maintain compatibility with the Databricks SQL/SDK client libraries. - A call-side interface for the Databricks Rest API. + https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/auth/authenticators.py#L19-L33 + https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/credentials_provider.py#L44-L54 """ @abc.abstractmethod @@ -22,56 +25,89 @@ def auth_type(self) -> str: raise NotImplementedError @abc.abstractmethod - def __call__(self, *args, **kwargs) -> HeaderFactory: + def __call__(self, *args, **kwargs) -> CredentialsProvider: raise NotImplementedError -class PositOAuthIntegrationCredentialsProvider(CredentialsProvider): +def _is_local() -> bool: + """Returns true if called from a piece of content running on a Connect server. + + The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`. + We can use this environment variable to determine if the content is running locally + or on a Connect server. + """ + return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" + + +class PositCredentialsProvider: def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str): self.posit_oauth = posit_oauth self.user_session_token = user_session_token - def auth_type(self) -> str: - return "posit-oauth-integration" - - def __call__(self, *args, **kwargs) -> HeaderFactory: - def inner() -> Dict[str, str]: - access_token = self.posit_oauth.get_credentials( - self.user_session_token - )["access_token"] - return {"Authorization": f"Bearer {access_token}"} + def __call__(self) -> Dict[str, str]: + access_token = self.posit_oauth.get_credentials( + self.user_session_token + )["access_token"] + return {"Authorization": f"Bearer {access_token}"} - return inner +class PositCredentialsStrategy(CredentialsStrategy): -# Use this environment variable to determine if the -# client SDK was initialized from a piece of content running on a Connect server. -def is_local() -> bool: - return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" + def __init__(self, + local_strategy: CredentialsStrategy, + user_session_token: Optional[str] = None, + client: Optional[Client] = None + ): + self.user_session_token = user_session_token + self.local_strategy = local_strategy + self.client = client + def sql_credentials_provider(self, *args, **kwargs): + """The sql connector attempts to call the credentials provider w/o any args. -def viewer_credentials_provider( - client: Optional[Client] = None, user_session_token: Optional[str] = None -) -> Optional[CredentialsProvider]: - # If the content is not running on Connect then viewer auth should - # fall back to the locally configured credentials hierarchy - if is_local(): - return None + The SQL client's `ExternalAuthProvider` is not compatible w/ the SDK's implementation of + `CredentialsProvider`, so create a no-arg lambda that wraps the args defined by the real caller. + This way we can pass in a databricks `Config` object required by most of the SDK's `CredentialsProvider` + implementations from where `sql.connect` is called. - if client is None: - client = Client() + https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + """ + return lambda: self.__call__(*args, **kwargs) - # If the user-session-token wasn't provided and we're running on Connect then we raise an exception. - # user_session_token is required to impersonate the viewer. - if user_session_token is None: - raise ValueError( - "The user-session-token is required for viewer authentication." + def auth_type(self) -> str: + """Returns the auth type currently in use. + + The databricks-sdk client uses the configurated auth_type to create + a user-agent string which is used for attribution. We should only + overwrite the auth_type if we are using the PositCredentialsStrategy (non-local), + otherwise, we should return the auth_type of the configured local_strategy instead + to avoid breaking someone elses attribution. + + https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269 + + NOTE: The databricks-sql client does not use auth_type to set the user-agent. + https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219 + """ + if _is_local(): + return self.local_strategy.auth_type() + else: + return "posit-oauth-integration" + + def __call__(self, *args, **kwargs) -> CredentialsProvider: + # If the content is not running on Connect then fall back to local_strategy + if _is_local(): + return self.local_strategy(*args, **kwargs) + + # If the user-session-token wasn't provided and we're running on Connect then we raise an exception. + # user_session_token is required to impersonate the viewer. + if self.user_session_token is None: + raise ValueError( + "The user-session-token is required for viewer authentication." + ) + + if self.client is None: + self.client = Client() + + return PositCredentialsProvider( + self.client.oauth, self.user_session_token ) - - return PositOAuthIntegrationCredentialsProvider( - client.oauth, user_session_token - ) - - -def service_account_credentials_provider(client: Optional[Client] = None): - raise NotImplementedError diff --git a/tests/posit/connect/external/test_databricks.py b/tests/posit/connect/external/test_databricks.py index e69de29b..ad58ff0d 100644 --- a/tests/posit/connect/external/test_databricks.py +++ b/tests/posit/connect/external/test_databricks.py @@ -0,0 +1,72 @@ +from typing import Dict +from unittest.mock import patch + +import responses +from posit.connect import Client +from posit.connect.external.databricks import ( + CredentialsProvider, + PositCredentialsProvider, + PositCredentialsStrategy, +) + + +class mock_strategy: + def auth_type(self) -> str: + return "local" + def __call__(self) -> CredentialsProvider: + def inner() -> Dict[str,str]: + return {"Authorization": "Bearer static-pat-token"} + return inner + + +def register_mocks(): + responses.post( + "https://connect.example/__api__/v1/oauth/integrations/credentials", + match=[ + responses.matchers.urlencoded_params_matcher( + { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token_type": "urn:posit:connect:user-session-token", + "subject_token": "cit", + } + ) + ], + json={ + "access_token": "dynamic-viewer-access-token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + }, + ) + + +class TestPositCredentialsHelpers: + @responses.activate + def test_posit_credentials_provider(self): + register_mocks() + + client = Client(api_key="12345", url="https://connect.example/") + cp = PositCredentialsProvider(posit_oauth=client.oauth, user_session_token="cit") + assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"} + + @responses.activate + @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + def test_posit_credentials_strategy(self): + register_mocks() + + client = Client(api_key="12345", url="https://connect.example/") + cs = PositCredentialsStrategy(local_strategy=mock_strategy(), + user_session_token="cit", + client=client) + cp = cs() + assert cs.auth_type() == "posit-oauth-integration" + assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"} + + def test_posit_credentials_strategy_fallback(self): + # local_strategy is used when the content is running locally + client = Client(api_key="12345", url="https://connect.example/") + cs = PositCredentialsStrategy(local_strategy=mock_strategy(), + user_session_token="cit", + client=client) + cp = cs() + assert cs.auth_type() == "local" + assert cp() == {"Authorization": "Bearer static-pat-token"}