From 64ff529e6de271c8d7f6a5d19353af558783f369 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 11 Nov 2024 10:33:58 -0500 Subject: [PATCH 1/2] feat: allow oauth in databricks provider --- crates/goose-cli/src/main.rs | 37 ++- crates/goose/Cargo.toml | 8 + crates/goose/src/providers.rs | 1 + .../goose/src/providers/configs/databricks.rs | 32 +- crates/goose/src/providers/databricks.rs | 19 +- .../goose/src/providers/databricks_oauth.rs | 306 ++++++++++++++++++ 6 files changed, 378 insertions(+), 25 deletions(-) create mode 100644 crates/goose/src/providers/databricks_oauth.rs diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index 389c6554..f5ff1b03 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -151,10 +151,12 @@ fn create_openai_provider(cli: &Cli) -> Result { .or_else(|| env::var("OPENAI_API_KEY").ok()) .context("API key must be provided via --api-key or OPENAI_API_KEY environment variable")?; - Ok(ProviderType::OpenAi(OpenAiProvider::new(OpenAiProviderConfig { - api_key, - host: "https://api.openai.com".to_string(), - })?)) + Ok(ProviderType::OpenAi(OpenAiProvider::new( + OpenAiProviderConfig { + api_key, + host: "https://api.openai.com".to_string(), + }, + )?)) } fn create_databricks_provider(cli: &Cli) -> Result { @@ -164,16 +166,19 @@ fn create_databricks_provider(cli: &Cli) -> Result { .or_else(|| env::var("DATABRICKS_HOST").ok()) .unwrap_or("https://block-lakehouse-production.cloud.databricks.com".to_string()); + // databricks_token is optional. if not provided, we will use OAuth let databricks_token = cli .databricks_token .clone() - .or_else(|| env::var("DATABRICKS_TOKEN").ok()) - .context("Databricks token must be provided via --databricks-token or DATABRICKS_TOKEN environment variable")?; + .or_else(|| env::var("DATABRICKS_TOKEN").ok()); - Ok(ProviderType::Databricks(DatabricksProvider::new(DatabricksProviderConfig { - host: databricks_host, - token: databricks_token, - })?)) + Ok(ProviderType::Databricks(DatabricksProvider::new( + DatabricksProviderConfig { + host: databricks_host, + token: databricks_token, + use_oauth: true, + }, + )?)) } impl Provider for ProviderType { @@ -188,11 +193,15 @@ impl Provider for ProviderType { ) -> Result<(Message, Usage)> { match self { ProviderType::OpenAi(provider) => { - provider.complete(model, system, messages, tools, temperature, max_tokens).await - }, + provider + .complete(model, system, messages, tools, temperature, max_tokens) + .await + } ProviderType::Databricks(provider) => { - provider.complete(model, system, messages, tools, temperature, max_tokens).await - }, + provider + .complete(model, system, messages, tools, temperature, max_tokens) + .await + } } } diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 17de2bf9..443c94fc 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -16,6 +16,14 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" uuid = { version = "1.0", features = ["v4"] } regex = "1.11.1" +chrono = { version = "0.4.38", features = ["serde"] } +sha2 = "0.10.8" +nanoid = "0.4.0" +webbrowser = "1.0.2" +warp = "0.3.7" +base64 = "0.22.1" +serde_urlencoded = "0.7.1" +url = "2.5.3" [dev-dependencies] wiremock = "0.6.0" diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 72e2ad0c..2631ec70 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -1,6 +1,7 @@ pub mod base; pub mod configs; pub mod databricks; +pub mod databricks_oauth; pub mod openai; pub mod types; pub mod utils; diff --git a/crates/goose/src/providers/configs/databricks.rs b/crates/goose/src/providers/configs/databricks.rs index 93ca96c2..c2803fe2 100644 --- a/crates/goose/src/providers/configs/databricks.rs +++ b/crates/goose/src/providers/configs/databricks.rs @@ -3,12 +3,17 @@ use anyhow::Result; pub struct DatabricksProviderConfig { pub host: String, - pub token: String, + pub token: Option, + pub use_oauth: bool, } impl DatabricksProviderConfig { - pub fn new(host: String, token: String) -> Self { - Self { host, token } + pub fn new(host: String, token: Option, use_oauth: bool) -> Self { + Self { + host, + token, + use_oauth, + } } } @@ -16,12 +21,23 @@ impl ProviderConfig for DatabricksProviderConfig { fn from_env() -> Result { // Get required host let host = Self::get_env("DATABRICKS_HOST", true, None)? - .ok_or_else(|| anyhow::anyhow!("Databricks host should be present"))?; + .ok_or_else(|| anyhow::anyhow!("Databricks host must be set"))?; - // Get required token - let token = Self::get_env("DATABRICKS_TOKEN", true, None)? - .ok_or_else(|| anyhow::anyhow!("Databricks token should be present"))?; + // Get optional token + let token = Self::get_env("DATABRICKS_TOKEN", false, None)?; - Ok(Self::new(host, token)) + // Get use_oauth flag + let use_oauth = Self::get_env("DATABRICKS_USE_OAUTH", false, Some("false".to_string()))? + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false); + + // Ensure that either token is set or use_oauth is true + if token.is_none() && !use_oauth { + return Err(anyhow::anyhow!( + "Authentication not configured: set DATABRICKS_TOKEN or DATABRICKS_USE_OAUTH=true" + )); + } + + Ok(Self::new(host, token, use_oauth)) } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index d77ac86d..429767fd 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -14,18 +14,30 @@ use super::{ }, }; +use super::databricks_oauth::get_oauth_token; + pub struct DatabricksProvider { client: Client, config: DatabricksProviderConfig, } impl DatabricksProvider { + pub fn new(config: DatabricksProviderConfig) -> Result { + // Determine the token to use + let token = if let Some(token) = &config.token { + token.clone() + } else if config.use_oauth { + get_oauth_token(&config.host)? + } else { + return Err(anyhow::anyhow!("No authentication method provided")); + }; + let client = Client::builder() - .timeout(Duration::from_secs(600)) // 10 minutes timeout + .timeout(Duration::from_secs(600)) .default_headers({ let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Authorization", format!("Bearer {}", config.token).parse()?); + headers.insert("Authorization", format!("Bearer {}", token).parse()?); headers }) .build()?; @@ -208,7 +220,8 @@ mod tests { // Create the DatabricksProvider with the mock server's URL as the host let config = DatabricksProviderConfig { host: mock_server.uri(), - token: "test_token".to_string(), + token: Some("test_token".to_string()), + use_oauth: false, }; let provider = DatabricksProvider::new(config)?; diff --git a/crates/goose/src/providers/databricks_oauth.rs b/crates/goose/src/providers/databricks_oauth.rs new file mode 100644 index 00000000..e7d5a827 --- /dev/null +++ b/crates/goose/src/providers/databricks_oauth.rs @@ -0,0 +1,306 @@ +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::fs; +use anyhow::Result; +use serde_json::Value; +use sha2::Digest; +use base64::Engine; +use url::Url; + +#[derive(Debug, Clone)] +struct OidcEndpoints { + authorization_endpoint: String, + token_endpoint: String, +} + +#[derive(Serialize, Deserialize)] +struct TokenData { + access_token: String, + expires_at: Option>, +} + +struct TokenCache { + cache_path: PathBuf, +} + +const BASE_PATH: &str = concat!(env!("HOME"), "/.config/goose/databricks/oauth"); + +impl TokenCache { + fn new(host: &str, client_id: &str, scopes: &[String]) -> Self { + let mut hasher = sha2::Sha256::new(); + hasher.update(host.as_bytes()); + hasher.update(client_id.as_bytes()); + hasher.update(scopes.join(",").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + + fs::create_dir_all(BASE_PATH).unwrap(); + let cache_path = PathBuf::from(BASE_PATH).join(format!("{}.json", hash)); + + Self { cache_path } + } + + fn load_token(&self) -> Option { + if let Ok(contents) = fs::read_to_string(&self.cache_path) { + if let Ok(token_data) = serde_json::from_str::(&contents) { + if let Some(expires_at) = token_data.expires_at { + if expires_at > chrono::Utc::now() { + return Some(token_data); + } + } else { + return Some(token_data); + } + } + } + None + } + + fn save_token(&self, token_data: &TokenData) -> Result<(), anyhow::Error> { + if let Some(parent) = self.cache_path.parent() { + fs::create_dir_all(parent)?; + } + let contents = serde_json::to_string(token_data)?; + fs::write(&self.cache_path, contents)?; + Ok(()) + } +} + + + +async fn get_workspace_endpoints(host: &str) -> Result { + let host = host.trim_end_matches('/'); + let oidc_url = format!("{}/oidc/.well-known/oauth-authorization-server", host); + + let client = reqwest::Client::new(); + let resp = client.get(&oidc_url).send().await?; + + if !resp.status().is_success() { + return Err(anyhow::anyhow!("Failed to get OIDC configuration from {}", oidc_url)); + } + + let oidc_config: Value = resp.json().await?; + + let authorization_endpoint = oidc_config + .get("authorization_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))? + .to_string(); + + let token_endpoint = oidc_config + .get("token_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))? + .to_string(); + + Ok(OidcEndpoints { + authorization_endpoint, + token_endpoint, + }) +} + + +struct OAuthClient { + oidc_endpoints: OidcEndpoints, + redirect_url: String, + client_id: String, + scopes: Vec, +} + +impl OAuthClient { + fn new( + oidc_endpoints: OidcEndpoints, + redirect_url: String, + client_id: String, + scopes: Vec, + ) -> Self { + Self { + oidc_endpoints, + redirect_url, + client_id, + scopes, + } + } + + fn initiate_consent(&self) -> Consent { + // Generate state and PKCE verifier/challenge + let state = nanoid::nanoid!(16); + let verifier = nanoid::nanoid!(64); + let challenge = { + let digest = sha2::Sha256::digest(verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) + }; + + // Build authorization URL + let params = [ + ("response_type", "code"), + ("client_id", &self.client_id), + ("redirect_uri", &self.redirect_url), + ("scope", &self.scopes.join(" ")), + ("state", &state), + ("code_challenge", &challenge), + ("code_challenge_method", "S256"), + ]; + let authorization_url = format!( + "{}?{}", + self.oidc_endpoints.authorization_endpoint, + serde_urlencoded::to_string(¶ms).unwrap() + ); + + Consent { + state, + verifier, + authorization_url, + redirect_url: self.redirect_url.clone(), + token_endpoint: self.oidc_endpoints.token_endpoint.clone(), + client_id: self.client_id.clone(), + } + } +} + +struct Consent { + state: String, + verifier: String, + authorization_url: String, + redirect_url: String, + token_endpoint: String, + client_id: String, +} + +impl Consent { + async fn launch_external_browser(&self) -> Result { + // Open the authorization URL in the user's browser + if webbrowser::open(&self.authorization_url).is_err() { + println!("Open this URL in your browser:\n{}", self.authorization_url); + } + + // Start a local server to receive the redirect + use warp::Filter; + use std::sync::{Arc, Mutex}; + use tokio::sync::oneshot; + + let (tx, rx) = oneshot::channel(); + + let state = self.state.clone(); + let tx = Arc::new(Mutex::new(Some(tx))); + + let routes = warp::get() + .and(warp::path::end()) + .and(warp::query::query::>()) + .map(move |params: std::collections::HashMap| { + let code = params.get("code").cloned(); + let received_state = params.get("state").cloned(); + if let (Some(code), Some(received_state)) = (code, received_state) { + if received_state == state { + if let Some(tx) = tx.lock().unwrap().take() { + let _ = tx.send(code); + } + "Authentication successful! You can close this window." + } else { + "State mismatch." + } + } else { + "Authentication failed." + } + }); + + let redirect_url = Url::parse(&self.redirect_url)?; + let port = redirect_url.port().unwrap_or(80); + + let (_addr, server) = warp::serve(routes).bind_ephemeral(([127, 0, 0, 1], port)); + let server_handle = tokio::task::spawn(server); + + // Wait for the authorization code + let code = rx.await?; + + // Stop the server + server_handle.abort(); + + // Exchange the code for a token + self.exchange_code_for_token(&code).await + } + + async fn exchange_code_for_token(&self, code: &str) -> Result { + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", &self.redirect_url), + ("code_verifier", &self.verifier), + ("client_id", &self.client_id), + ]; + + let client = reqwest::Client::new(); + let resp = client + .post(&self.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let err_text = resp.text().await?; + return Err(anyhow::anyhow!("Failed to exchange code for token: {}", err_text)); + } + + let token_response: serde_json::Value = resp.json().await?; + let access_token = token_response + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))? + .to_string(); + + let expires_in = token_response + .get("expires_in") + .and_then(|v| v.as_u64()) + .unwrap_or(3600); + + let expires_at = chrono::Utc::now() + chrono::Duration::seconds(expires_in as i64); + + Ok(TokenData { + access_token, + expires_at: Some(expires_at), + }) + } +} + +pub fn get_oauth_token(host: &str) -> Result { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + get_oauth_token_async(host).await + }) + }) +} + +pub async fn get_oauth_token_async(host: &str) -> Result { + + let client_id = "databricks-cli"; + let redirect_url = "http://localhost:8020"; + + let oidc_endpoints = get_workspace_endpoints(host).await?; + let scopes = vec!["all-apis".to_string()]; + + let token_cache = TokenCache::new(host, client_id, &scopes); + + // Attempt to load token from cache + if let Some(token_data) = token_cache.load_token() { + return Ok(token_data.access_token); + } + + // Create OAuthClient + let oauth_client = OAuthClient::new( + oidc_endpoints, + redirect_url.to_string(), + client_id.to_string(), + scopes, + ); + + // Initiate consent + let consent = oauth_client.initiate_consent(); + + // Launch external browser and get token + let token_data = consent.launch_external_browser().await?; + + // Save token to cache + token_cache.save_token(&token_data)?; + + Ok(token_data.access_token) + +} From 05f43e9ea01e237ca86617544fb2ccb9d367120c Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 11 Nov 2024 10:34:50 -0500 Subject: [PATCH 2/2] run cargo fmt --- crates/goose/src/providers/databricks.rs | 5 +-- .../goose/src/providers/databricks_oauth.rs | 33 ++++++++++--------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 429767fd..4293c19d 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -22,7 +22,6 @@ pub struct DatabricksProvider { } impl DatabricksProvider { - pub fn new(config: DatabricksProviderConfig) -> Result { // Determine the token to use let token = if let Some(token) = &config.token { @@ -207,9 +206,7 @@ mod tests { // Set up the mock to intercept the request and respond with the mocked response Mock::given(method("POST")) - .and(path( - "/serving-endpoints/my-databricks-model/invocations", - )) + .and(path("/serving-endpoints/my-databricks-model/invocations")) .and(header("Authorization", "Bearer test_token")) .and(body_json(expected_request_body.clone())) .respond_with(ResponseTemplate::new(200).set_body_json(mock_response)) diff --git a/crates/goose/src/providers/databricks_oauth.rs b/crates/goose/src/providers/databricks_oauth.rs index e7d5a827..ca37ad66 100644 --- a/crates/goose/src/providers/databricks_oauth.rs +++ b/crates/goose/src/providers/databricks_oauth.rs @@ -1,10 +1,10 @@ -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; -use std::fs; use anyhow::Result; +use base64::Engine; +use serde::{Deserialize, Serialize}; use serde_json::Value; use sha2::Digest; -use base64::Engine; +use std::fs; +use std::path::PathBuf; use url::Url; #[derive(Debug, Clone)] @@ -64,8 +64,6 @@ impl TokenCache { } } - - async fn get_workspace_endpoints(host: &str) -> Result { let host = host.trim_end_matches('/'); let oidc_url = format!("{}/oidc/.well-known/oauth-authorization-server", host); @@ -74,7 +72,10 @@ async fn get_workspace_endpoints(host: &str) -> Result { let resp = client.get(&oidc_url).send().await?; if !resp.status().is_success() { - return Err(anyhow::anyhow!("Failed to get OIDC configuration from {}", oidc_url)); + return Err(anyhow::anyhow!( + "Failed to get OIDC configuration from {}", + oidc_url + )); } let oidc_config: Value = resp.json().await?; @@ -97,7 +98,6 @@ async fn get_workspace_endpoints(host: &str) -> Result { }) } - struct OAuthClient { oidc_endpoints: OidcEndpoints, redirect_url: String, @@ -173,9 +173,9 @@ impl Consent { } // Start a local server to receive the redirect - use warp::Filter; use std::sync::{Arc, Mutex}; use tokio::sync::oneshot; + use warp::Filter; let (tx, rx) = oneshot::channel(); @@ -184,7 +184,9 @@ impl Consent { let routes = warp::get() .and(warp::path::end()) - .and(warp::query::query::>()) + .and(warp::query::query::< + std::collections::HashMap, + >()) .map(move |params: std::collections::HashMap| { let code = params.get("code").cloned(); let received_state = params.get("state").cloned(); @@ -237,7 +239,10 @@ impl Consent { if !resp.status().is_success() { let err_text = resp.text().await?; - return Err(anyhow::anyhow!("Failed to exchange code for token: {}", err_text)); + return Err(anyhow::anyhow!( + "Failed to exchange code for token: {}", + err_text + )); } let token_response: serde_json::Value = resp.json().await?; @@ -263,14 +268,11 @@ impl Consent { pub fn get_oauth_token(host: &str) -> Result { tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - get_oauth_token_async(host).await - }) + tokio::runtime::Handle::current().block_on(async { get_oauth_token_async(host).await }) }) } pub async fn get_oauth_token_async(host: &str) -> Result { - let client_id = "databricks-cli"; let redirect_url = "http://localhost:8020"; @@ -302,5 +304,4 @@ pub async fn get_oauth_token_async(host: &str) -> Result { token_cache.save_token(&token_data)?; Ok(token_data.access_token) - }