diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d517b54c..684c2607 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -9,6 +9,24 @@ on: - v1.0 jobs: + format: + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal + override: true + + - name: check format + run: | + cargo fmt --check + build-and-test: runs-on: ubuntu-latest @@ -79,3 +97,7 @@ jobs: run: cargo test --verbose env: OLLAMA_MODEL: "qwen2.5" + + - name: check lint + run: | + cargo clippy diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index c844ff94..4ddcece2 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -10,7 +10,7 @@ pub struct MockAgent; #[async_trait] impl Agent for MockAgent { fn add_system(&mut self, _system: Box) { - (); + () } async fn reply(&self, _messages: &[Message]) -> Result>> { diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index fa889d6c..da14a2c1 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -1,9 +1,8 @@ use crate::commands::expected_config::{get_recommended_models, RecommendedModels}; -use crate::inputs::inputs::get_user_input; -use crate::profile::profile::Profile; -use crate::profile::profile_handler::{find_existing_profile, profile_path, save_profile}; -use crate::profile::provider_helper::{ - select_provider_lists, set_provider_config, PROVIDER_OPEN_AI, +use crate::inputs::get_user_input; +use crate::profile::{ + find_existing_profile, profile_path, save_profile, select_provider_lists, set_provider_config, + Profile, PROVIDER_OPEN_AI, }; use cliclack::spinner; use console::style; @@ -58,8 +57,8 @@ async fn check_configuration(provider_config: ProviderConfig) -> Result<(), Box< Ok(()) } -fn get_existing_profile(profile_name: &String) -> Option { - let existing_profile_result = find_existing_profile(profile_name.as_str()); +fn get_existing_profile(profile_name: &str) -> Option { + let existing_profile_result = find_existing_profile(profile_name); if existing_profile_result.is_some() { println!("Profile already exists. We are going to overwriting the existing profile..."); } else { diff --git a/crates/goose-cli/src/commands/expected_config.rs b/crates/goose-cli/src/commands/expected_config.rs index abdb42cc..44454230 100644 --- a/crates/goose-cli/src/commands/expected_config.rs +++ b/crates/goose-cli/src/commands/expected_config.rs @@ -1,6 +1,6 @@ // This is a temporary file to simulate some configuration data from the backend -use crate::profile::provider_helper::{PROVIDER_DATABRICKS, PROVIDER_OLLAMA, PROVIDER_OPEN_AI}; +use crate::profile::{PROVIDER_DATABRICKS, PROVIDER_OLLAMA, PROVIDER_OPEN_AI}; use goose::providers::ollama::OLLAMA_MODEL; pub struct RecommendedModels { @@ -11,9 +11,13 @@ pub fn get_recommended_models(provider_name: &str) -> RecommendedModels { if provider_name == PROVIDER_OPEN_AI { RecommendedModels { model: "gpt-4o" } } else if provider_name == PROVIDER_DATABRICKS { - RecommendedModels { model: "claude-3-5-sonnet-2" } + RecommendedModels { + model: "claude-3-5-sonnet-2", + } } else if provider_name == PROVIDER_OLLAMA { - RecommendedModels { model: OLLAMA_MODEL } + RecommendedModels { + model: OLLAMA_MODEL, + } } else { panic!("Invalid provider name"); } diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 928a8507..e89fb5ac 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -6,16 +6,14 @@ use goose::models::message::Message; use goose::providers::factory; use crate::commands::expected_config::get_recommended_models; -use crate::profile::profile::Profile; -use crate::profile::profile_handler::{load_profiles, PROFILE_DEFAULT_NAME}; -use crate::profile::provider_helper::set_provider_config; -use crate::profile::provider_helper::PROVIDER_OPEN_AI; +use crate::profile::{ + load_profiles, set_provider_config, Profile, PROFILE_DEFAULT_NAME, PROVIDER_OPEN_AI, +}; use crate::prompt::cliclack::CliclackPrompt; -use crate::prompt::prompt::Prompt; use crate::prompt::rustyline::RustylinePrompt; use crate::prompt::thinking::get_random_goose_action; -use crate::session::session::Session; -use crate::session::session_file::ensure_session_dir; +use crate::prompt::Prompt; +use crate::session::{ensure_session_dir, Session}; pub fn build_session<'a>( session: Option, diff --git a/crates/goose-cli/src/inputs/inputs.rs b/crates/goose-cli/src/inputs.rs similarity index 100% rename from crates/goose-cli/src/inputs/inputs.rs rename to crates/goose-cli/src/inputs.rs diff --git a/crates/goose-cli/src/inputs/mod.rs b/crates/goose-cli/src/inputs/mod.rs deleted file mode 100644 index cefb36d6..00000000 --- a/crates/goose-cli/src/inputs/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod inputs; diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index ed1f0845..ef7fa1ac 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -107,7 +107,7 @@ enum SystemCommands { enum CliProviderVariant { OpenAi, Databricks, - Ollama + Ollama, } #[tokio::main] diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs new file mode 100644 index 00000000..d94b4891 --- /dev/null +++ b/crates/goose-cli/src/profile.rs @@ -0,0 +1,131 @@ +use std::collections::HashMap; +use std::error::Error; +use std::fs; +use std::path::PathBuf; + +use crate::inputs::get_env_value_or_input; +use goose::providers::configs::{ + DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, + ProviderConfig, +}; +use goose::providers::factory::ProviderType; +use goose::providers::ollama::OLLAMA_HOST; +use serde::{Deserialize, Serialize}; +use strum::IntoEnumIterator; + +// Profile types and structures +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Profile { + pub provider: String, + pub model: String, + #[serde(default)] + pub additional_systems: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct Profiles { + pub profile_items: HashMap, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct AdditionalSystem { + pub name: String, + pub location: String, +} + +// Provider helper constants and functions +pub const PROVIDER_OPEN_AI: &str = "openai"; +pub const PROVIDER_DATABRICKS: &str = "databricks"; +pub const PROVIDER_OLLAMA: &str = "ollama"; +pub const PROFILE_DEFAULT_NAME: &str = "default"; + +pub fn select_provider_lists() -> Vec<(&'static str, String, &'static str)> { + ProviderType::iter() + .map(|provider| match provider { + ProviderType::OpenAi => ( + PROVIDER_OPEN_AI, + PROVIDER_OPEN_AI.to_string(), + "Recommended", + ), + ProviderType::Databricks => (PROVIDER_DATABRICKS, PROVIDER_DATABRICKS.to_string(), ""), + ProviderType::Ollama => (PROVIDER_OLLAMA, PROVIDER_OLLAMA.to_string(), ""), + }) + .collect() +} + +pub fn profile_path() -> Result> { + let home_dir = dirs::home_dir().ok_or(anyhow::anyhow!("Could not determine home directory"))?; + let config_dir = home_dir.join(".config").join("goose"); + if !config_dir.exists() { + fs::create_dir_all(&config_dir)?; + } + Ok(config_dir.join("profiles.json")) +} + +pub fn load_profiles() -> Result, Box> { + let path = profile_path()?; + if !path.exists() { + return Ok(HashMap::new()); + } + let content = fs::read_to_string(path)?; + let profiles: Profiles = serde_json::from_str(&content)?; + Ok(profiles.profile_items) +} + +pub fn save_profile(name: &str, profile: Profile) -> Result<(), Box> { + let path = profile_path()?; + let mut profiles = load_profiles()?; + profiles.insert(name.to_string(), profile); + let profiles = Profiles { + profile_items: profiles, + }; + let content = serde_json::to_string_pretty(&profiles)?; + fs::write(path, content)?; + Ok(()) +} + +pub fn find_existing_profile(name: &str) -> Option { + match load_profiles() { + Ok(profiles) => profiles.get(name).cloned(), + Err(_) => None, + } +} + +pub fn set_provider_config(provider_name: &str, model: String) -> ProviderConfig { + match provider_name.to_lowercase().as_str() { + PROVIDER_OPEN_AI => ProviderConfig::OpenAi(OpenAiProviderConfig { + host: "https://api.openai.com".to_string(), + api_key: get_env_value_or_input( + "OPENAI_API_KEY", + "Please enter your OpenAI API key:", + true, + ), + model, + temperature: None, + max_tokens: None, + }), + PROVIDER_DATABRICKS => { + let host = get_env_value_or_input( + "DATABRICKS_HOST", + "Please enter your Databricks host:", + false, + ); + ProviderConfig::Databricks(DatabricksProviderConfig { + host: host.clone(), + // TODO revisit configuration + auth: DatabricksAuth::oauth(host), + model, + temperature: None, + max_tokens: None, + image_format: goose::providers::utils::ImageFormat::Anthropic, + }) + } + PROVIDER_OLLAMA => ProviderConfig::Ollama(OllamaProviderConfig { + host: std::env::var("OLLAMA_HOST").unwrap_or_else(|_| String::from(OLLAMA_HOST)), + model, + temperature: None, + max_tokens: None, + }), + _ => panic!("Invalid provider name"), + } +} diff --git a/crates/goose-cli/src/profile/mod.rs b/crates/goose-cli/src/profile/mod.rs deleted file mode 100644 index 9a7637ca..00000000 --- a/crates/goose-cli/src/profile/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod profile; -pub mod profile_handler; -pub mod provider_helper; diff --git a/crates/goose-cli/src/profile/profile.rs b/crates/goose-cli/src/profile/profile.rs deleted file mode 100644 index 501743aa..00000000 --- a/crates/goose-cli/src/profile/profile.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Profile { - pub provider: String, - pub model: String, - #[serde(default)] - pub additional_systems: Vec, -} - -#[derive(Serialize, Deserialize)] -pub struct Profiles { - pub profile_items: std::collections::HashMap, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct AdditionalSystem { - pub name: String, - pub location: String, -} diff --git a/crates/goose-cli/src/profile/profile_handler.rs b/crates/goose-cli/src/profile/profile_handler.rs deleted file mode 100644 index 8542fd9c..00000000 --- a/crates/goose-cli/src/profile/profile_handler.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::profile::profile::Profile; -use std::collections::HashMap; -use std::error::Error; -use std::fs::{create_dir_all, File}; -use std::io::Write; -use std::path::PathBuf; - -// TODO: set to profile-1.0.yaml temporarily to avoid overriting the existing config -pub const PROFILE_CONFIG_PATH: &str = ".config/goose/profile-1.0.yaml"; -pub const PROFILE_DEFAULT_NAME: &str = "default"; - -fn save_profiles_to_file(profiles: &HashMap) -> Result<(), Box> { - let path = profile_path()?; - - if let Some(parent) = path.parent() { - create_dir_all(parent)?; - } - - let yaml_string = serde_yaml::to_string(profiles)?; - let mut file = File::create(&path)?; - file.write_all(yaml_string.as_bytes())?; - Ok(()) -} - -pub fn profile_path() -> Result> { - let mut path = dirs::home_dir().ok_or("Failed to find home directory")?; - path.push(PROFILE_CONFIG_PATH); - Ok(path) -} - -pub fn save_profile(profile_name: &str, new_profile: Profile) -> Result<(), Box> { - let mut profiles = load_profiles().unwrap(); - profiles.insert(profile_name.to_string(), new_profile); - let _ = save_profiles_to_file(&profiles); - Ok(()) -} - -fn profile_file_exists() -> bool { - profile_path().unwrap().exists() -} -pub fn load_profiles() -> Result, Box> { - let path = profile_path()?; - if !path.exists() { - return Ok(HashMap::new()); - } - let file = File::open(&path)?; - match serde_yaml::from_reader(file) { - Ok(profiles) => Ok(profiles), - Err(e) => { - eprintln!("\x1b[31mFailed to parse profile file: {}\n\nPlease delete {} and recreate it.\n\x1b[0m", e, path.display()); - Err(Box::new(e)) - } - } -} - -pub fn find_existing_profile(profile_name: &str) -> Option { - if profile_file_exists() { - let profiles = load_profiles().unwrap(); - profiles.get(profile_name).cloned() - } else { - None - } -} diff --git a/crates/goose-cli/src/profile/provider_helper.rs b/crates/goose-cli/src/profile/provider_helper.rs deleted file mode 100644 index 1e9a6c53..00000000 --- a/crates/goose-cli/src/profile/provider_helper.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::inputs::inputs::get_env_value_or_input; -use goose::providers::configs::{ - DatabricksAuth, DatabricksProviderConfig, OpenAiProviderConfig, OllamaProviderConfig, ProviderConfig -}; -use goose::providers::factory::ProviderType; -use goose::providers::ollama::OLLAMA_HOST; -use strum::IntoEnumIterator; - -pub const PROVIDER_OPEN_AI: &str = "openai"; -pub const PROVIDER_DATABRICKS: &str = "databricks"; -pub const PROVIDER_OLLAMA: &str = "ollama"; - -pub fn select_provider_lists() -> Vec<(&'static str, String, &'static str)> { - ProviderType::iter() - .map(|provider| match provider { - ProviderType::OpenAi => ( - PROVIDER_OPEN_AI, - PROVIDER_OPEN_AI.to_string(), - "Recommended", - ), - ProviderType::Databricks => (PROVIDER_DATABRICKS, PROVIDER_DATABRICKS.to_string(), ""), - ProviderType::Ollama => (PROVIDER_OLLAMA, PROVIDER_OLLAMA.to_string(), "") - }) - .collect() -} - -pub fn set_provider_config(provider_name: &str, model: String) -> ProviderConfig { - match provider_name.to_lowercase().as_str() { - PROVIDER_OPEN_AI => ProviderConfig::OpenAi(OpenAiProviderConfig { - host: "https://api.openai.com".to_string(), - api_key: get_env_value_or_input( - "OPENAI_API_KEY", - "Please enter your OpenAI API key:", - true, - ), - model, - temperature: None, - max_tokens: None, - }), - PROVIDER_DATABRICKS => { - let host = get_env_value_or_input( - "DATABRICKS_HOST", - "Please enter your Databricks host:", - false, - ); - ProviderConfig::Databricks(DatabricksProviderConfig { - host: host.clone(), - // TODO revisit configuration - auth: DatabricksAuth::oauth(host), - model, - temperature: None, - max_tokens: None, - image_format: goose::providers::utils::ImageFormat::Anthropic, - }) - } - PROVIDER_OLLAMA => ProviderConfig::Ollama(OllamaProviderConfig { - host: std::env::var("OLLAMA_HOST") - .unwrap_or_else(|_| String::from(OLLAMA_HOST)), - model, - temperature: None, - max_tokens: None, - }), - _ => panic!("Invalid provider name"), - } -} diff --git a/crates/goose-cli/src/prompt/prompt.rs b/crates/goose-cli/src/prompt.rs similarity index 94% rename from crates/goose-cli/src/prompt/prompt.rs rename to crates/goose-cli/src/prompt.rs index ed84c4c0..4ab52eb6 100644 --- a/crates/goose-cli/src/prompt/prompt.rs +++ b/crates/goose-cli/src/prompt.rs @@ -1,6 +1,10 @@ use anyhow::Result; use goose::models::message::Message; +pub mod cliclack; +pub mod rustyline; +pub mod thinking; + pub trait Prompt { fn render(&mut self, message: Box); fn get_input(&mut self) -> Result; diff --git a/crates/goose-cli/src/prompt/cliclack.rs b/crates/goose-cli/src/prompt/cliclack.rs index 092f579a..b5206c72 100644 --- a/crates/goose-cli/src/prompt/cliclack.rs +++ b/crates/goose-cli/src/prompt/cliclack.rs @@ -8,10 +8,7 @@ use bat::WrappingMode; use cliclack::{input, set_theme, spinner, Theme as CliclackTheme, ThemeState}; use goose::models::message::{Message, MessageContent, ToolRequest, ToolResponse}; -use super::{ - prompt::{Input, InputType, Prompt, Theme}, - thinking::get_random_thinking_message, -}; +use super::{thinking::get_random_thinking_message, Input, InputType, Prompt, Theme}; pub struct CliclackPrompt { spinner: cliclack::ProgressBar, diff --git a/crates/goose-cli/src/prompt/mod.rs b/crates/goose-cli/src/prompt/mod.rs deleted file mode 100644 index 27ebb571..00000000 --- a/crates/goose-cli/src/prompt/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod cliclack; -pub mod prompt; -pub mod rustyline; -pub mod thinking; diff --git a/crates/goose-cli/src/prompt/rustyline.rs b/crates/goose-cli/src/prompt/rustyline.rs index 4ab4f51f..f33d908b 100644 --- a/crates/goose-cli/src/prompt/rustyline.rs +++ b/crates/goose-cli/src/prompt/rustyline.rs @@ -12,10 +12,7 @@ use goose::models::role::Role; use goose::models::{content::Content, tool::ToolCall}; use serde_json::Value; -use super::{ - prompt::{Input, InputType, Prompt, Theme}, - thinking::get_random_thinking_message, -}; +use super::{thinking::get_random_thinking_message, Input, InputType, Prompt, Theme}; const PROMPT: &str = "\x1b[1m\x1b[38;5;30m( O)> \x1b[0m"; const MAX_STRING_LENGTH: usize = 40; diff --git a/crates/goose-cli/src/session/session.rs b/crates/goose-cli/src/session.rs similarity index 84% rename from crates/goose-cli/src/session/session.rs rename to crates/goose-cli/src/session.rs index 0e0b61f2..3a053c85 100644 --- a/crates/goose-cli/src/session/session.rs +++ b/crates/goose-cli/src/session.rs @@ -1,17 +1,71 @@ use anyhow::Result; use futures::StreamExt; +use serde_json; +use std::fs::{self, File}; +use std::io::{self, BufRead, Write}; use std::path::PathBuf; use crate::agents::agent::Agent; -use crate::prompt::prompt::{InputType, Prompt}; -use crate::session::session_file::{persist_messages, readable_session_file}; +use crate::prompt::{InputType, Prompt}; use crate::systems::goose_hints::GooseHintsSystem; use goose::developer::DeveloperSystem; use goose::models::message::{Message, MessageContent}; use goose::models::role::Role; -use super::session_file::deserialize_messages; +// File management functions +pub fn ensure_session_dir() -> Result { + let home_dir = dirs::home_dir().ok_or(anyhow::anyhow!("Could not determine home directory"))?; + let config_dir = home_dir.join(".config").join("goose").join("sessions"); + if !config_dir.exists() { + fs::create_dir_all(&config_dir)?; + } + + Ok(config_dir) +} + +pub fn readable_session_file(session_file: &PathBuf) -> Result { + match fs::OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(session_file) + { + Ok(file) => Ok(file), + Err(e) => Err(anyhow::anyhow!("Failed to open session file: {}", e)), + } +} + +pub fn persist_messages(session_file: &PathBuf, messages: &[Message]) -> Result<()> { + let file = fs::File::create(session_file)?; // Create or truncate the file + persist_messages_internal(file, messages) +} + +fn persist_messages_internal(session_file: File, messages: &[Message]) -> Result<()> { + let mut writer = std::io::BufWriter::new(session_file); + + for message in messages { + serde_json::to_writer(&mut writer, &message)?; + writeln!(writer)?; + } + + writer.flush()?; + Ok(()) +} + +pub fn deserialize_messages(file: File) -> Result> { + let reader = io::BufReader::new(file); + let mut messages = Vec::new(); + + for line in reader.lines() { + messages.push(serde_json::from_str::(&line?)?); + } + + Ok(messages) +} + +// Session management pub struct Session<'a> { agent: Box, prompt: Box, @@ -176,7 +230,7 @@ fn raw_message(content: &str) -> Box { #[cfg(test)] mod tests { use crate::agents::mock_agent::MockAgent; - use crate::prompt::prompt::{self, Input}; + use crate::prompt::{self, Input}; use super::*; use goose::{errors::AgentResult, models::tool::ToolCall}; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs deleted file mode 100644 index c8cb955f..00000000 --- a/crates/goose-cli/src/session/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod session; -pub mod session_file; diff --git a/crates/goose-cli/src/session/session_file.rs b/crates/goose-cli/src/session/session_file.rs deleted file mode 100644 index 54a9e904..00000000 --- a/crates/goose-cli/src/session/session_file.rs +++ /dev/null @@ -1,356 +0,0 @@ -use anyhow::Result; -use serde_json; -use std::fs::{self, File}; -use std::io::{self, BufRead, Write}; - -use std::path::PathBuf; - -use goose::models::message::Message; - -pub fn ensure_session_dir() -> Result { - let home_dir = - dirs::home_dir().ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?; - let config_dir = home_dir.join(".config").join("goose").join("sessions"); - - if !config_dir.exists() { - fs::create_dir_all(&config_dir)?; - } - - Ok(config_dir) -} - -pub fn readable_session_file(session_file: &PathBuf) -> Result { - match fs::OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(session_file) - { - Ok(file) => Ok(file), - Err(e) => Err(anyhow::anyhow!("Failed to open session file: {}", e)), - } -} - -pub fn persist_messages(session_file: &PathBuf, messages: &[Message]) -> Result<()> { - let file = fs::File::create(session_file)?; // Create or truncate the file - persist_messages_internal(file, messages) -} - -fn persist_messages_internal(session_file: File, messages: &[Message]) -> Result<()> { - let mut writer = std::io::BufWriter::new(session_file); - - for message in messages { - serde_json::to_writer(&mut writer, &message)?; - writeln!(writer)?; - } - - writer.flush()?; - Ok(()) -} - -pub fn deserialize_messages(file: File) -> Result> { - let reader = io::BufReader::new(file); - let mut messages = Vec::new(); - - for line in reader.lines() { - messages.push(serde_json::from_str::(&line?)?); - } - - Ok(messages) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - use tempfile::NamedTempFile; - - use crate::session::session_file::{deserialize_messages, persist_messages_internal}; - use goose::models::content::{Content, ImageContent, TextContent}; - use goose::models::message::{Message, MessageContent}; - use goose::models::message::{ToolRequest, ToolResponse}; - use goose::models::role::Role; - use goose::models::tool::ToolCall; - - #[test] - fn test_persist_text_message() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::User, - created: now, - content: vec![MessageContent::Text(TextContent { - text: "Hello, world!".to_string(), - audience: Some(vec![Role::User]), - priority: Some(1.0), - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::Text(text) = &messages[0].content[0] { - if let MessageContent::Text(deserialized_text) = &deserialized[0].content[0] { - assert_eq!(text.text, deserialized_text.text); - assert_eq!(text.audience, deserialized_text.audience); - assert_eq!(text.priority, deserialized_text.priority); - } else { - panic!("Deserialized content is not text"); - } - } - Ok(()) - } - - #[test] - fn test_persist_tool_request() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::Assistant, - created: now, - content: vec![MessageContent::ToolRequest(ToolRequest { - id: "magic".to_string(), - tool_call: Ok(ToolCall { - name: "test_tool".to_string(), - arguments: json!({"arg": "value"}), - }), - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::ToolRequest(req) = &messages[0].content[0] { - if let MessageContent::ToolRequest(deserialized_req) = &deserialized[0].content[0] { - if let (Ok(call), Ok(deserialized_call)) = - (&req.tool_call, &deserialized_req.tool_call) - { - assert_eq!(req.id, deserialized_req.id); - assert_eq!(call.name, deserialized_call.name); - assert_eq!(call.arguments, deserialized_call.arguments); - } else { - panic!("Tool call results don't match"); - } - } else { - panic!("Deserialized content is not a tool request"); - } - } - Ok(()) - } - - #[test] - fn test_persist_tool_response() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::Assistant, - created: now, - content: vec![MessageContent::ToolResponse(ToolResponse { - id: "test_id".to_string(), - tool_result: Ok(vec![Content::Text(TextContent { - text: "success".to_string(), - audience: None, - priority: None, - })]), - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::ToolResponse(resp) = &messages[0].content[0] { - if let MessageContent::ToolResponse(deserialized_resp) = &deserialized[0].content[0] { - assert_eq!(resp.id, deserialized_resp.id); - assert_eq!(resp.tool_result, deserialized_resp.tool_result); - assert!(deserialized_resp.tool_result.is_ok()); - } else { - panic!("Deserialized content is not a tool response"); - } - } - Ok(()) - } - - #[test] - fn test_persist_tool_response_multiple_content() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::Assistant, - created: now, - content: vec![MessageContent::ToolResponse(ToolResponse { - id: "test_id".to_string(), - tool_result: Ok(vec![ - Content::Text(TextContent { - text: "first result".to_string(), - audience: Some(vec![Role::User]), - priority: Some(1.0), - }), - Content::Text(TextContent { - text: "second result".to_string(), - audience: None, - priority: None, - }), - ]), - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::ToolResponse(resp) = &messages[0].content[0] { - if let MessageContent::ToolResponse(deserialized_resp) = &deserialized[0].content[0] { - assert_eq!(resp.id, deserialized_resp.id); - if let (Ok(original_results), Ok(deserialized_results)) = - (&resp.tool_result, &deserialized_resp.tool_result) - { - assert_eq!(original_results.len(), deserialized_results.len()); - - // Check first result with audience and priority - if let (Content::Text(original_text), Content::Text(deserialized_text)) = - (&original_results[0], &deserialized_results[0]) - { - assert_eq!(original_text.text, deserialized_text.text); - assert_eq!(original_text.audience, deserialized_text.audience); - assert_eq!(original_text.priority, deserialized_text.priority); - } - - // Check second result without audience and priority - if let (Content::Text(original_text), Content::Text(deserialized_text)) = - (&original_results[1], &deserialized_results[1]) - { - assert_eq!(original_text.text, deserialized_text.text); - assert_eq!(original_text.audience, deserialized_text.audience); - assert_eq!(original_text.priority, deserialized_text.priority); - } - } else { - panic!("Tool result is not Ok"); - } - } else { - panic!("Deserialized content is not a tool response"); - } - } - Ok(()) - } - - #[test] - fn test_persist_tool_response_with_image() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::Assistant, - created: now, - content: vec![MessageContent::ToolResponse(ToolResponse { - id: "test_id".to_string(), - tool_result: Ok(vec![ - Content::Text(TextContent { - text: "text result".to_string(), - audience: None, - priority: None, - }), - Content::Image(ImageContent { - mime_type: "image/png".to_string(), - data: "base64data".to_string(), - audience: Some(vec![Role::User]), - priority: Some(1.0), - }), - ]), - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::ToolResponse(resp) = &messages[0].content[0] { - if let MessageContent::ToolResponse(deserialized_resp) = &deserialized[0].content[0] { - assert_eq!(resp.id, deserialized_resp.id); - if let (Ok(original_results), Ok(deserialized_results)) = - (&resp.tool_result, &deserialized_resp.tool_result) - { - assert_eq!(original_results.len(), deserialized_results.len()); - - // Check text content - if let (Content::Text(original_text), Content::Text(deserialized_text)) = - (&original_results[0], &deserialized_results[0]) - { - assert_eq!(original_text.text, deserialized_text.text); - } else { - panic!("First result is not text content"); - } - - // Check image content - if let (Content::Image(original_img), Content::Image(deserialized_img)) = - (&original_results[1], &deserialized_results[1]) - { - assert_eq!(original_img.mime_type, deserialized_img.mime_type); - assert_eq!(original_img.data, deserialized_img.data); - assert_eq!(original_img.audience, deserialized_img.audience); - assert_eq!(original_img.priority, deserialized_img.priority); - } else { - panic!("Second result is not image content"); - } - } - } - } - Ok(()) - } - - #[test] - fn test_persist_image() -> Result<()> { - let temp_file = NamedTempFile::new()?; - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - let messages = vec![Message { - role: Role::User, - created: now, - content: vec![MessageContent::Image(ImageContent { - mime_type: "image/png".to_string(), - data: "base64data".to_string(), - audience: None, - priority: None, - })], - }]; - - persist_messages_internal(temp_file.reopen()?, &messages)?; - let deserialized = deserialize_messages(temp_file.reopen()?)?; - - assert_eq!(messages.len(), deserialized.len()); - if let MessageContent::Image(img) = &messages[0].content[0] { - if let MessageContent::Image(deserialized_img) = &deserialized[0].content[0] { - assert_eq!(img.mime_type, deserialized_img.mime_type); - assert_eq!(img.data, deserialized_img.data); - } else { - panic!("Deserialized content is not an image"); - } - } - Ok(()) - } -} diff --git a/crates/goose-cli/src/systems/system_handler.rs b/crates/goose-cli/src/systems/system_handler.rs index 247009d1..3fd9eac5 100644 --- a/crates/goose-cli/src/systems/system_handler.rs +++ b/crates/goose-cli/src/systems/system_handler.rs @@ -1,5 +1,4 @@ -use crate::profile::profile::AdditionalSystem; -use crate::profile::profile_handler::{load_profiles, save_profile}; +use crate::profile::{load_profiles, save_profile, AdditionalSystem}; use serde_json::Value; use std::error::Error; diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 82baa233..b95c7c1d 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -235,7 +235,8 @@ async fn stream_message( } MessageContent::Text(text) => { for line in text.text.lines() { - tx.send(ProtocolFormatter::format_text(&format!("{}\\n", line))).await?; + tx.send(ProtocolFormatter::format_text(&format!("{}\\n", line))) + .await?; } } MessageContent::Image(_) => { @@ -314,8 +315,6 @@ async fn handler( Ok(SseResponse::new(stream)) } - - #[derive(Debug, Deserialize)] struct AskRequest { prompt: String, @@ -326,13 +325,11 @@ struct AskResponse { response: String, } - // simple ask an AI for a response, non streaming async fn ask_handler( State(state): State, Json(request): Json, ) -> Result, StatusCode> { - let provider = factory::get_provider(state.provider_config) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs index 6d569c23..ee5bbcb4 100644 --- a/crates/goose/src/agent.rs +++ b/crates/goose/src/agent.rs @@ -81,13 +81,16 @@ impl Agent { } /// Find the appropriate system for a tool call based on the prefixed name - fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&Box> { + fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { let parts: Vec<&str> = prefixed_name.split("__").collect(); if parts.len() != 2 { return None; } let system_name = parts[0]; - self.systems.iter().find(|sys| sys.name() == system_name) + self.systems + .iter() + .find(|sys| sys.name() == system_name) + .map(|v| &**v) } /// Dispatch a single tool call to the appropriate system diff --git a/crates/goose/src/developer.rs b/crates/goose/src/developer.rs index 9324a4ac..1fbc91ed 100644 --- a/crates/goose/src/developer.rs +++ b/crates/goose/src/developer.rs @@ -197,10 +197,11 @@ impl DeveloperSystem { .output() .map_err(|e| AgentError::ExecutionError(e.to_string()))?; - let output_str = String::from_utf8_lossy(&output.stdout).to_string(); - if !output.status.success() { - return Err(AgentError::ExecutionError(output_str)); - } + let output_str = format!( + "Finished with Status Code: {}\nOutput:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout) + ); Ok(vec![ Content::text(output_str).with_audience(vec![Role::Assistant]) ]) @@ -456,13 +457,10 @@ impl DeveloperSystem { let mut lines: Vec = content.lines().map(|s| s.to_string()).collect(); if insert_line > lines.len() { - return Err(AgentError::InvalidParameters( - format!( - "The insert line is greater than the length of the file ({} lines)", - lines.len() - ) - .into(), - )); + return Err(AgentError::InvalidParameters(format!( + "The insert line is greater than the length of the file ({} lines)", + lines.len() + ))); } // Save history for undo @@ -537,14 +535,13 @@ impl DeveloperSystem { // Capture the screenshot using xcap let monitors = Monitor::all() .map_err(|_| AgentError::ExecutionError("Failed to access monitors".into()))?; - let monitor = monitors.get(display).ok_or(AgentError::ExecutionError( - format!( + let monitor = monitors + .get(display) + .ok_or(AgentError::ExecutionError(format!( "{} was not an available monitor, {} found.", display, monitors.len() - ) - .into(), - ))?; + )))?; let mut image = monitor.capture_image().map_err(|e| { AgentError::ExecutionError(format!("Failed to capture display {}: {}", display, e))