diff --git a/examples/connect/.gitignore b/examples/connect/.gitignore index afe36a39..5ac2a70e 100644 --- a/examples/connect/.gitignore +++ b/examples/connect/.gitignore @@ -1 +1 @@ -.posit +**/rsconnect-python/ diff --git a/examples/connect/dash/app.py b/examples/connect/dash/app.py index a1bf7893..0a728b31 100644 --- a/examples/connect/dash/app.py +++ b/examples/connect/dash/app.py @@ -6,9 +6,11 @@ import pandas as pd from dash import Dash, Input, Output, dash_table, html from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import ( + PositOAuthIntegrationCredentialsStrategy, +) DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -38,14 +40,16 @@ def update_page(_): session_token = flask.request.headers.get( "Posit-Connect-User-Session-Token" ) - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositOAuthIntegrationCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) def get_greeting(): - cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider - ) databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me() return f"Hello, {databricks_user_info.display_name}!" @@ -58,8 +62,8 @@ def get_table(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/dash/requirements.txt b/examples/connect/dash/requirements.txt index bbceac3d..090523a4 100644 --- a/examples/connect/dash/requirements.txt +++ b/examples/connect/dash/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 dash==2.15.0 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/fastapi/app.py b/examples/connect/fastapi/app.py index 8be3ece2..8b39fdfb 100644 --- a/examples/connect/fastapi/app.py +++ b/examples/connect/fastapi/app.py @@ -4,9 +4,12 @@ from typing import Annotated from databricks import sql +from databricks.sdk.core import Config, databricks_cli from fastapi import FastAPI, Header from fastapi.responses import JSONResponse -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import ( + PositOAuthIntegrationCredentialsStrategy, +) DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -26,9 +29,14 @@ async def get_fares( """ global rows - credentials_provider = viewer_credentials_provider( - user_session_token=posit_connect_user_session_token - ) + posit_strategy = PositOAuthIntegrationCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=posit_connect_user_session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) if rows is None: query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;" @@ -36,8 +44,8 @@ async def get_fares( with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/fastapi/requirements.txt b/examples/connect/fastapi/requirements.txt index d506dcc6..262e5868 100644 --- a/examples/connect/fastapi/requirements.txt +++ b/examples/connect/fastapi/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 fastapi==0.110.0 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/flask/app.py b/examples/connect/flask/app.py index d95fa19f..78b9529a 100644 --- a/examples/connect/flask/app.py +++ b/examples/connect/flask/app.py @@ -3,8 +3,11 @@ import os from databricks import sql +from databricks.sdk.core import Config, databricks_cli from flask import Flask, request -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import ( + PositOAuthIntegrationCredentialsStrategy, +) DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" @@ -28,9 +31,14 @@ def get_fares(): global rows session_token = request.headers.get("Posit-Connect-User-Session-Token") - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositOAuthIntegrationCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) if rows is None: query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;" @@ -38,8 +46,8 @@ def get_fares(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute(query) diff --git a/examples/connect/flask/requirements.txt b/examples/connect/flask/requirements.txt index 752130dd..c0b12941 100644 --- a/examples/connect/flask/requirements.txt +++ b/examples/connect/flask/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 flask==3.0.2 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/shiny-python/app.py b/examples/connect/shiny-python/app.py index bf97579e..60369a2d 100644 --- a/examples/connect/shiny-python/app.py +++ b/examples/connect/shiny-python/app.py @@ -4,9 +4,11 @@ import pandas as pd from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider +from posit.connect.external.databricks import ( + PositOAuthIntegrationCredentialsStrategy, +) from shiny import App, Inputs, Outputs, Session, render, ui DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") @@ -24,9 +26,14 @@ def server(i: Inputs, o: Outputs, session: Session): session_token = session.http_conn.headers.get( "Posit-Connect-User-Session-Token" ) - credentials_provider = viewer_credentials_provider( - user_session_token=session_token - ) + posit_strategy = PositOAuthIntegrationCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) + cfg = Config( + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) @render.data_frame def result(): @@ -35,8 +42,8 @@ def result(): with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg), ) as connection: with connection.cursor() as cursor: cursor.execute(query) @@ -48,9 +55,6 @@ def result(): @render.text def text(): - cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider - ) databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me() return f"Hello, {databricks_user_info.display_name}!" diff --git a/examples/connect/shiny-python/requirements.txt b/examples/connect/shiny-python/requirements.txt index 44691daa..eedba131 100644 --- a/examples/connect/shiny-python/requirements.txt +++ b/examples/connect/shiny-python/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 shiny==0.7.1 -git+https://github.com/posit-dev/posit-sdk-py.git +posit-sdk>=0.4.0 diff --git a/examples/connect/streamlit/app.py b/examples/connect/streamlit/app.py index ffc084c4..f0cda8a1 100644 --- a/examples/connect/streamlit/app.py +++ b/examples/connect/streamlit/app.py @@ -5,34 +5,34 @@ import pandas as pd import streamlit as st from databricks import sql -from databricks.sdk.core import ApiClient, Config +from databricks.sdk.core import ApiClient, Config, databricks_cli from databricks.sdk.service.iam import CurrentUserAPI -from posit.connect.external.databricks import viewer_credentials_provider -from streamlit.web.server.websocket_headers import _get_websocket_headers +from posit.connect.external.databricks import ( + PositOAuthIntegrationCredentialsStrategy, +) DATABRICKS_HOST = os.getenv("DATABRICKS_HOST") DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}" SQL_HTTP_PATH = os.getenv("DATABRICKS_PATH") -session_token = _get_websocket_headers().get( - "Posit-Connect-User-Session-Token" -) - -credentials_provider = viewer_credentials_provider( - user_session_token=session_token -) - +session_token = st.context.headers.get("Posit-Connect-User-Session-Token") +posit_strategy = PositOAuthIntegrationCredentialsStrategy( + local_strategy=databricks_cli, + user_session_token=session_token) cfg = Config( - host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider -) + host=DATABRICKS_HOST_URL, + # uses Posit's custom credential_strategy if running on Connect, + # otherwise falls back to the strategy defined by local_strategy + credentials_strategy=posit_strategy) + databricks_user = CurrentUserAPI(ApiClient(cfg)).me() st.write(f"Hello, {databricks_user.display_name}!") with sql.connect( server_hostname=DATABRICKS_HOST, http_path=SQL_HTTP_PATH, - auth_type="databricks-oauth", - credentials_provider=credentials_provider, + # https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + credentials_provider=posit_strategy.sql_credentials_provider(cfg) ) as connection: with connection.cursor() as cursor: cursor.execute("SELECT * FROM samples.nyctaxi.trips LIMIT 10;") diff --git a/examples/connect/streamlit/requirements.txt b/examples/connect/streamlit/requirements.txt index 63ccd077..8395b2c1 100644 --- a/examples/connect/streamlit/requirements.txt +++ b/examples/connect/streamlit/requirements.txt @@ -1,4 +1,4 @@ -databricks-sql-connector==3.0.1 -databricks-sdk==0.20.0 -streamlit==1.31.1 -git+https://github.com/posit-dev/posit-sdk-py.git +databricks-sql-connector==3.3.0 +databricks-sdk==0.29.0 +streamlit==1.37.0 +posit-sdk>=0.4.0 diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py index 2e8f634a..e77cb262 100644 --- a/src/posit/connect/external/databricks.py +++ b/src/posit/connect/external/databricks.py @@ -5,73 +5,88 @@ from ..client import Client from ..oauth import OAuthIntegration -HeaderFactory = Callable[[], Dict[str, str]] - - -# https://github.com/databricks/databricks-sdk-py/blob/v0.20.0/databricks/sdk/credentials_provider.py -# https://github.com/databricks/databricks-sql-python/blob/v3.1.0/src/databricks/sql/auth/authenticators.py -# In order to keep compatibility with the Databricks SDK -class CredentialsProvider(abc.ABC): - """Protocol Databricks authentication. - - A call-side interface for the Databricks Rest API. - """ - +""" +NOTE: These APIs are provided as a convenience and are subject to breaking changes: +https://github.com/databricks/databricks-sdk-py#interface-stability +""" + +# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory +CredentialsProvider = Callable[[], Dict[str, str]] + +# In order to keep compatibility with the Databricks SQL/SDK client libraries: +# https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/auth/authenticators.py#L19-L33 +# https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/credentials_provider.py#L44-L54 +class CredentialsStrategy(abc.ABC): @abc.abstractmethod def auth_type(self) -> str: raise NotImplementedError @abc.abstractmethod - def __call__(self, *args, **kwargs) -> HeaderFactory: + def __call__(self, *args, **kwargs) -> CredentialsProvider: raise NotImplementedError -class PositOAuthIntegrationCredentialsProvider(CredentialsProvider): +# Use this environment variable to determine if the +# client SDK was initialized from a piece of content running on a Connect server. +def _is_local() -> bool: + return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" + + +class PositOAuthIntegrationCredentialsProvider: def __init__(self, posit_oauth: OAuthIntegration, user_session_token: str): self.posit_oauth = posit_oauth self.user_session_token = user_session_token - def auth_type(self) -> str: - return "posit-oauth-integration" - - def __call__(self, *args, **kwargs) -> HeaderFactory: - def inner() -> Dict[str, str]: - access_token = self.posit_oauth.get_credentials( - self.user_session_token - )["access_token"] - return {"Authorization": f"Bearer {access_token}"} - - return inner + def __call__(self) -> Dict[str, str]: + access_token = self.posit_oauth.get_credentials( + self.user_session_token + )["access_token"] + return {"Authorization": f"Bearer {access_token}"} -# Use this environment variable to determine if the -# client SDK was initialized from a piece of content running on a Connect server. -def is_local() -> bool: - return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" +class PositOAuthIntegrationCredentialsStrategy(CredentialsStrategy): + def __init__(self, + local_strategy: CredentialsStrategy, + user_session_token: Optional[str] = None, + client: Optional[Client] = None + ): + self.user_session_token = user_session_token + self.local_strategy = local_strategy + self.client = client -def viewer_credentials_provider( - client: Optional[Client] = None, user_session_token: Optional[str] = None -) -> Optional[CredentialsProvider]: - # If the content is not running on Connect then viewer auth should - # fall back to the locally configured credentials hierarchy - if is_local(): - return None + def sql_credentials_provider(self, *args, **kwargs): + """The sql connector attempts to call the credentials provider w/o any args. - if client is None: - client = Client() + The SQL client's `ExternalAuthProvider` is not compatible w/ the SDK's implementation of `CredentialsProvider`, + so create a no-arg lambda that wraps the args defined by the real caller. + This way we can pass in a databricks `Config` object required by most of the SDK's `CredentialsProvider` impls + from where `sql.connect` is called. + https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365 + """ + return lambda: self.__call__(*args, **kwargs) - # If the user-session-token wasn't provided and we're running on Connect then we raise an exception. - # user_session_token is required to impersonate the viewer. - if user_session_token is None: - raise ValueError( - "The user-session-token is required for viewer authentication." + def auth_type(self) -> str: + if _is_local(): + return self.local_strategy.auth_type() + else: + return "posit-oauth-integration" + + def __call__(self, *args, **kwargs) -> CredentialsProvider: + # If the content is not running on Connect then fall back to local_strategy + if _is_local(): + return self.local_strategy(*args, **kwargs) + + # If the user-session-token wasn't provided and we're running on Connect then we raise an exception. + # user_session_token is required to impersonate the viewer. + if self.user_session_token is None: + raise ValueError( + "The user-session-token is required for viewer authentication." + ) + + if self.client is None: + self.client = Client() + + return PositOAuthIntegrationCredentialsProvider( + self.client.oauth, self.user_session_token ) - - return PositOAuthIntegrationCredentialsProvider( - client.oauth, user_session_token - ) - - -def service_account_credentials_provider(client: Optional[Client] = None): - raise NotImplementedError