Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow oauth in databricks provider #233

Open
wants to merge 2 commits into
base: v1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,12 @@ fn create_openai_provider(cli: &Cli) -> Result<ProviderType> {
.or_else(|| env::var("OPENAI_API_KEY").ok())
.context("API key must be provided via --api-key or OPENAI_API_KEY environment variable")?;

Ok(ProviderType::OpenAi(OpenAiProvider::new(OpenAiProviderConfig {
api_key,
host: "https://api.openai.com".to_string(),
})?))
Ok(ProviderType::OpenAi(OpenAiProvider::new(
OpenAiProviderConfig {
api_key,
host: "https://api.openai.com".to_string(),
},
)?))
}

fn create_databricks_provider(cli: &Cli) -> Result<ProviderType> {
Expand All @@ -164,16 +166,19 @@ fn create_databricks_provider(cli: &Cli) -> Result<ProviderType> {
.or_else(|| env::var("DATABRICKS_HOST").ok())
.unwrap_or("https://block-lakehouse-production.cloud.databricks.com".to_string());

// databricks_token is optional. if not provided, we will use OAuth
let databricks_token = cli
.databricks_token
.clone()
.or_else(|| env::var("DATABRICKS_TOKEN").ok())
.context("Databricks token must be provided via --databricks-token or DATABRICKS_TOKEN environment variable")?;
.or_else(|| env::var("DATABRICKS_TOKEN").ok());

Ok(ProviderType::Databricks(DatabricksProvider::new(DatabricksProviderConfig {
host: databricks_host,
token: databricks_token,
})?))
Ok(ProviderType::Databricks(DatabricksProvider::new(
DatabricksProviderConfig {
host: databricks_host,
token: databricks_token,
use_oauth: true,
},
)?))
}

impl Provider for ProviderType {
Expand All @@ -188,11 +193,15 @@ impl Provider for ProviderType {
) -> Result<(Message, Usage)> {
match self {
ProviderType::OpenAi(provider) => {
provider.complete(model, system, messages, tools, temperature, max_tokens).await
},
provider
.complete(model, system, messages, tools, temperature, max_tokens)
.await
}
ProviderType::Databricks(provider) => {
provider.complete(model, system, messages, tools, temperature, max_tokens).await
},
provider
.complete(model, system, messages, tools, temperature, max_tokens)
.await
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
uuid = { version = "1.0", features = ["v4"] }
regex = "1.11.1"
chrono = { version = "0.4.38", features = ["serde"] }
sha2 = "0.10.8"
nanoid = "0.4.0"
webbrowser = "1.0.2"
warp = "0.3.7"
base64 = "0.22.1"
serde_urlencoded = "0.7.1"
url = "2.5.3"

[dev-dependencies]
wiremock = "0.6.0"
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod base;
pub mod configs;
pub mod databricks;
pub mod databricks_oauth;
pub mod openai;
pub mod types;
pub mod utils;
32 changes: 24 additions & 8 deletions crates/goose/src/providers/configs/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,41 @@ use anyhow::Result;

pub struct DatabricksProviderConfig {
pub host: String,
pub token: String,
pub token: Option<String>,
pub use_oauth: bool,
}

impl DatabricksProviderConfig {
pub fn new(host: String, token: String) -> Self {
Self { host, token }
pub fn new(host: String, token: Option<String>, use_oauth: bool) -> Self {
Self {
host,
token,
use_oauth,
}
}
}

impl ProviderConfig for DatabricksProviderConfig {
fn from_env() -> Result<Self> {
// Get required host
let host = Self::get_env("DATABRICKS_HOST", true, None)?
.ok_or_else(|| anyhow::anyhow!("Databricks host should be present"))?;
.ok_or_else(|| anyhow::anyhow!("Databricks host must be set"))?;

// Get required token
let token = Self::get_env("DATABRICKS_TOKEN", true, None)?
.ok_or_else(|| anyhow::anyhow!("Databricks token should be present"))?;
// Get optional token
let token = Self::get_env("DATABRICKS_TOKEN", false, None)?;

Ok(Self::new(host, token))
// Get use_oauth flag
let use_oauth = Self::get_env("DATABRICKS_USE_OAUTH", false, Some("false".to_string()))?
.map(|s| s.to_lowercase() == "true")
.unwrap_or(false);

// Ensure that either token is set or use_oauth is true
if token.is_none() && !use_oauth {
return Err(anyhow::anyhow!(
"Authentication not configured: set DATABRICKS_TOKEN or DATABRICKS_USE_OAUTH=true"
));
}

Ok(Self::new(host, token, use_oauth))
}
}
22 changes: 16 additions & 6 deletions crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,29 @@ use super::{
},
};

use super::databricks_oauth::get_oauth_token;

pub struct DatabricksProvider {
client: Client,
config: DatabricksProviderConfig,
}

impl DatabricksProvider {
pub fn new(config: DatabricksProviderConfig) -> Result<Self> {
// Determine the token to use
let token = if let Some(token) = &config.token {
token.clone()
} else if config.use_oauth {
get_oauth_token(&config.host)?
} else {
return Err(anyhow::anyhow!("No authentication method provided"));
};

let client = Client::builder()
.timeout(Duration::from_secs(600)) // 10 minutes timeout
.timeout(Duration::from_secs(600))
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Authorization", format!("Bearer {}", config.token).parse()?);
headers.insert("Authorization", format!("Bearer {}", token).parse()?);
headers
})
.build()?;
Expand Down Expand Up @@ -195,9 +206,7 @@ mod tests {

// Set up the mock to intercept the request and respond with the mocked response
Mock::given(method("POST"))
.and(path(
"/serving-endpoints/my-databricks-model/invocations",
))
.and(path("/serving-endpoints/my-databricks-model/invocations"))
.and(header("Authorization", "Bearer test_token"))
.and(body_json(expected_request_body.clone()))
.respond_with(ResponseTemplate::new(200).set_body_json(mock_response))
Expand All @@ -208,7 +217,8 @@ mod tests {
// Create the DatabricksProvider with the mock server's URL as the host
let config = DatabricksProviderConfig {
host: mock_server.uri(),
token: "test_token".to_string(),
token: Some("test_token".to_string()),
use_oauth: false,
};

let provider = DatabricksProvider::new(config)?;
Expand Down
Loading