diff --git a/Cargo.lock b/Cargo.lock index a6b5b0af..3ac7d230 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -895,6 +895,19 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -923,6 +936,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -6791,6 +6813,7 @@ dependencies = [ "async-trait", "clap", "cocoa 0.25.0", + "crossbeam", "komorebi-client", "netdev", "regex", diff --git a/packages/desktop/Cargo.toml b/packages/desktop/Cargo.toml index b545ec88..086fe7ad 100644 --- a/packages/desktop/Cargo.toml +++ b/packages/desktop/Cargo.toml @@ -16,6 +16,7 @@ tauri-build = { version = "2.0", features = [] } anyhow = "1" async-trait = "0.1" clap = { version = "4", features = ["derive"] } +crossbeam = "0.8" reqwest = { version = "0.11", features = ["json"] } tauri = { version = "2.0", features = [ "devtools", diff --git a/packages/desktop/src/commands.rs b/packages/desktop/src/commands.rs index 01ffb595..0b295cbc 100644 --- a/packages/desktop/src/commands.rs +++ b/packages/desktop/src/commands.rs @@ -8,7 +8,10 @@ use crate::common::macos::WindowExtMacOs; use crate::common::windows::WindowExtWindows; use crate::{ config::{Config, WidgetConfig, WidgetPlacement}, - providers::{ProviderConfig, ProviderManager}, + providers::{ + ProviderConfig, ProviderFunction, ProviderFunctionResponse, + ProviderManager, + }, widget_factory::{WidgetFactory, WidgetOpenOptions, WidgetState}, }; @@ -98,7 +101,19 @@ pub async fn unlisten_provider( provider_manager: State<'_, Arc>, ) -> anyhow::Result<(), String> { provider_manager - .destroy(config_hash) + .stop(config_hash) + .await + .map_err(|err| err.to_string()) +} + +#[tauri::command] +pub async fn call_provider_function( + config_hash: String, + function: ProviderFunction, + provider_manager: State<'_, Arc>, +) -> anyhow::Result { + provider_manager + .call_function(config_hash, function) .await .map_err(|err| err.to_string()) } diff --git a/packages/desktop/src/common/interval.rs b/packages/desktop/src/common/interval.rs new file mode 100644 index 00000000..243d9451 --- /dev/null +++ b/packages/desktop/src/common/interval.rs @@ -0,0 +1,67 @@ +use std::time::{Duration, Instant}; + +/// An interval timer for synchronous contexts using crossbeam. +/// +/// For use with crossbeam's `select!` macro. +pub struct SyncInterval { + interval: Duration, + next_tick: Instant, + is_first: bool, +} + +impl SyncInterval { + pub fn new(interval_ms: u64) -> Self { + Self { + interval: Duration::from_millis(interval_ms), + next_tick: Instant::now(), + is_first: true, + } + } + + /// Returns a receiver that will get a message at the next tick time. + pub fn tick(&mut self) -> crossbeam::channel::Receiver { + if self.is_first { + // Emit immediately on the first tick. + self.is_first = false; + crossbeam::channel::after(Duration::from_secs(0)) + } else if let Some(wait_duration) = + self.next_tick.checked_duration_since(Instant::now()) + { + // Wait normally until the next tick. + let timer = crossbeam::channel::after(wait_duration); + self.next_tick += self.interval; + timer + } else { + // We're behind - skip missed ticks to catch up. + while self.next_tick <= Instant::now() { + self.next_tick += self.interval; + } + + crossbeam::channel::after(self.next_tick - Instant::now()) + } + } +} + +/// An interval timer for asynchronous contexts using tokio. +pub struct AsyncInterval { + interval: tokio::time::Interval, +} + +impl AsyncInterval { + pub fn new(interval_ms: u64) -> Self { + let mut interval = + tokio::time::interval(Duration::from_millis(interval_ms)); + + // Skip missed ticks when the interval runs. This prevents a burst + // of backlogged ticks after a delay. + interval + .set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + Self { interval } + } + + /// Returns a future that will complete at the next tick time. + pub async fn tick(&mut self) { + self.interval.tick().await; + } +} diff --git a/packages/desktop/src/common/mod.rs b/packages/desktop/src/common/mod.rs index eb5cf80f..3bf9634a 100644 --- a/packages/desktop/src/common/mod.rs +++ b/packages/desktop/src/common/mod.rs @@ -1,5 +1,6 @@ mod format_bytes; mod fs_util; +mod interval; mod length_value; #[cfg(target_os = "macos")] pub mod macos; @@ -9,5 +10,6 @@ pub mod windows; pub use format_bytes::*; pub use fs_util::*; +pub use interval::*; pub use length_value::*; pub use path_ext::*; diff --git a/packages/desktop/src/main.rs b/packages/desktop/src/main.rs index c5f4be96..9dae60cf 100644 --- a/packages/desktop/src/main.rs +++ b/packages/desktop/src/main.rs @@ -8,10 +8,11 @@ use std::{env, sync::Arc}; use clap::Parser; use cli::MonitorType; use config::{MonitorSelection, WidgetPlacement}; +use providers::ProviderEmission; use tauri::{ async_runtime::block_on, AppHandle, Emitter, Manager, RunEvent, }; -use tokio::task; +use tokio::{sync::mpsc, task}; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::EnvFilter; use widget_factory::WidgetOpenOptions; @@ -88,6 +89,7 @@ async fn main() -> anyhow::Result<()> { commands::update_widget_config, commands::listen_provider, commands::unlisten_provider, + commands::call_provider_function, commands::set_always_on_top, commands::set_skip_taskbar ]) @@ -173,8 +175,8 @@ async fn start_app(app: &mut tauri::App, cli: Cli) -> anyhow::Result<()> { app.handle().plugin(tauri_plugin_dialog::init())?; // Initialize `ProviderManager` in Tauri state. - let manager = Arc::new(ProviderManager::new(app.handle())); - app.manage(manager); + let (manager, emit_rx) = ProviderManager::new(app.handle()); + app.manage(manager.clone()); // Open widgets based on CLI command. open_widgets_by_cli_command(cli, widget_factory.clone()).await?; @@ -184,7 +186,15 @@ async fn start_app(app: &mut tauri::App, cli: Cli) -> anyhow::Result<()> { SysTray::new(app.handle(), config.clone(), widget_factory.clone()) .await?; - listen_events(app.handle(), config, monitor_state, widget_factory, tray); + listen_events( + app.handle(), + config, + monitor_state, + widget_factory, + tray, + manager, + emit_rx, + ); Ok(()) } @@ -194,7 +204,9 @@ fn listen_events( config: Arc, monitor_state: Arc, widget_factory: Arc, - tray: Arc, + tray: SysTray, + manager: Arc, + mut emit_rx: mpsc::UnboundedReceiver, ) { let app_handle = app_handle.clone(); let mut widget_open_rx = widget_factory.open_tx.subscribe(); @@ -231,6 +243,12 @@ fn listen_events( info!("Widget configs changed."); widget_factory.relaunch_by_paths(&changed_configs.keys().cloned().collect()).await }, + Some(provider_emission) = emit_rx.recv() => { + info!("Provider emission: {:?}", provider_emission); + app_handle.emit("provider-emit", provider_emission.clone()); + manager.update_cache(provider_emission).await; + Ok(()) + }, }; if let Err(err) = res { diff --git a/packages/desktop/src/providers/audio/audio_provider.rs b/packages/desktop/src/providers/audio/audio_provider.rs index 3c71e1c4..d150b529 100644 --- a/packages/desktop/src/providers/audio/audio_provider.rs +++ b/packages/desktop/src/providers/audio/audio_provider.rs @@ -5,12 +5,10 @@ use std::{ time::Duration, }; -use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::{ - sync::mpsc::{self, Sender}, + sync::mpsc::{self}, task, - time::sleep, }; use tracing::debug; use windows::Win32::{ @@ -33,10 +31,11 @@ use windows::Win32::{ }; use windows_core::PCWSTR; -use crate::providers::{Provider, ProviderOutput, ProviderResult}; +use crate::providers::{ + CommonProviderState, Provider, ProviderEmitter, RuntimeType, +}; -static PROVIDER_TX: OnceLock> = - OnceLock::new(); +static PROVIDER_TX: OnceLock = OnceLock::new(); static AUDIO_STATE: OnceLock>> = OnceLock::new(); @@ -279,26 +278,29 @@ impl IMMNotificationClient_Impl for MediaDeviceEventHandler_Impl { } pub struct AudioProvider { - _config: AudioProviderConfig, + common: CommonProviderState, } impl AudioProvider { - pub fn new(config: AudioProviderConfig) -> Self { - Self { _config: config } + pub fn new( + _config: AudioProviderConfig, + common: CommonProviderState, + ) -> Self { + Self { common } } fn emit_volume() { if let Some(tx) = PROVIDER_TX.get() { let output = AUDIO_STATE.get().unwrap().lock().unwrap().clone(); - let _ = tx.try_send(Ok(ProviderOutput::Audio(output)).into()); + tx.emit_output(Ok(output)); } } - async fn handle_volume_updates(mut rx: mpsc::Receiver<(String, u32)>) { + fn handle_volume_updates(mut rx: mpsc::Receiver<(String, u32)>) { const PROCESS_DELAY: Duration = Duration::from_millis(50); let mut latest_updates = HashMap::new(); - while let Some((device_id, volume)) = rx.recv().await { + while let Some((device_id, volume)) = rx.blocking_recv() { latest_updates.insert(device_id, volume); // Collect any additional pending updates without waiting. @@ -307,7 +309,7 @@ impl AudioProvider { } // Brief delay to collect more potential updates. - sleep(PROCESS_DELAY).await; + std::thread::sleep(PROCESS_DELAY); // Process all collected updates. if let Some(state) = AUDIO_STATE.get() { @@ -369,11 +371,14 @@ impl AudioProvider { } } -#[async_trait] impl Provider for AudioProvider { - async fn run(&self, emit_result_tx: Sender) { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { PROVIDER_TX - .set(emit_result_tx.clone()) + .set(self.common.emitter.clone()) .expect("Error setting provider tx in audio provider"); AUDIO_STATE @@ -383,22 +388,11 @@ impl Provider for AudioProvider { // Create a channel for volume updates. let (update_tx, update_rx) = mpsc::channel(100); - // Spawn both tasks. - let update_handler = - task::spawn(Self::handle_volume_updates(update_rx)); - - let manager = task::spawn_blocking(move || { - if let Err(err) = Self::create_audio_manager(update_tx) { - emit_result_tx - .blocking_send(Err(err).into()) - .expect("Error with media provider"); - } - }); + // Spawn task for handling volume updates. + task::spawn_blocking(move || Self::handle_volume_updates(update_rx)); - // Wait for either task to complete (though they should run forever). - tokio::select! { - _ = manager => debug!("Audio manager stopped unexpectedly"), - _ = update_handler => debug!("Update handler stopped unexpectedly"), + if let Err(err) = Self::create_audio_manager(update_tx) { + self.common.emitter.emit_output::(Err(err)); } } } diff --git a/packages/desktop/src/providers/battery/battery_provider.rs b/packages/desktop/src/providers/battery/battery_provider.rs index 21316295..83a07e0a 100644 --- a/packages/desktop/src/providers/battery/battery_provider.rs +++ b/packages/desktop/src/providers/battery/battery_provider.rs @@ -8,7 +8,12 @@ use starship_battery::{ Manager, State, }; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::SyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -32,25 +37,25 @@ pub struct BatteryOutput { pub struct BatteryProvider { config: BatteryProviderConfig, + common: CommonProviderState, } impl BatteryProvider { - pub fn new(config: BatteryProviderConfig) -> BatteryProvider { - BatteryProvider { config } + pub fn new( + config: BatteryProviderConfig, + common: CommonProviderState, + ) -> BatteryProvider { + BatteryProvider { config, common } } - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval - } - - async fn run_interval(&self) -> anyhow::Result { + fn run_interval(&self) -> anyhow::Result { let battery = Manager::new()? .batteries() .and_then(|mut batteries| batteries.nth(0).transpose()) .unwrap_or(None) .context("No battery found.")?; - Ok(ProviderOutput::Battery(BatteryOutput { + Ok(BatteryOutput { charge_percent: battery.state_of_charge().get::(), health_percent: battery.state_of_health().get::(), state: battery.state().to_string(), @@ -64,8 +69,30 @@ impl BatteryProvider { power_consumption: battery.energy_rate().get::(), voltage: battery.voltage().get::(), cycle_count: battery.cycle_count(), - })) + }) } } -impl_interval_provider!(BatteryProvider, true); +impl Provider for BatteryProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/cpu/cpu_provider.rs b/packages/desktop/src/providers/cpu/cpu_provider.rs index 5f9e6bdc..fbdd6c6f 100644 --- a/packages/desktop/src/providers/cpu/cpu_provider.rs +++ b/packages/desktop/src/providers/cpu/cpu_provider.rs @@ -1,10 +1,11 @@ -use std::sync::Arc; - use serde::{Deserialize, Serialize}; -use sysinfo::System; -use tokio::sync::Mutex; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::SyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -24,26 +25,22 @@ pub struct CpuOutput { pub struct CpuProvider { config: CpuProviderConfig, - sysinfo: Arc>, + common: CommonProviderState, } impl CpuProvider { pub fn new( config: CpuProviderConfig, - sysinfo: Arc>, + common: CommonProviderState, ) -> CpuProvider { - CpuProvider { config, sysinfo } - } - - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval + CpuProvider { config, common } } - async fn run_interval(&self) -> anyhow::Result { - let mut sysinfo = self.sysinfo.lock().await; + fn run_interval(&self) -> anyhow::Result { + let mut sysinfo = self.common.sysinfo.blocking_lock(); sysinfo.refresh_cpu(); - Ok(ProviderOutput::Cpu(CpuOutput { + Ok(CpuOutput { usage: sysinfo.global_cpu_info().cpu_usage(), frequency: sysinfo.global_cpu_info().frequency(), logical_core_count: sysinfo.cpus().len(), @@ -51,8 +48,30 @@ impl CpuProvider { .physical_core_count() .unwrap_or(sysinfo.cpus().len()), vendor: sysinfo.global_cpu_info().vendor_id().into(), - })) + }) } } -impl_interval_provider!(CpuProvider, true); +impl Provider for CpuProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/disk/disk_provider.rs b/packages/desktop/src/providers/disk/disk_provider.rs index 19f86153..acd18112 100644 --- a/packages/desktop/src/providers/disk/disk_provider.rs +++ b/packages/desktop/src/providers/disk/disk_provider.rs @@ -1,13 +1,11 @@ -use std::sync::Arc; - use serde::{Deserialize, Serialize}; use sysinfo::Disks; -use tokio::sync::Mutex; use crate::{ - common::{to_iec_bytes, to_si_bytes}, - impl_interval_provider, - providers::ProviderOutput, + common::{to_iec_bytes, to_si_bytes, SyncInterval}, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, }; #[derive(Deserialize, Debug)] @@ -36,7 +34,8 @@ pub struct Disk { pub struct DiskProvider { config: DiskProviderConfig, - system: Arc>, + common: CommonProviderState, + disks: Disks, } #[derive(Debug, Clone, PartialEq, Serialize)] @@ -52,20 +51,20 @@ pub struct DiskSizeMeasure { impl DiskProvider { pub fn new( config: DiskProviderConfig, - system: Arc>, + common: CommonProviderState, ) -> DiskProvider { - DiskProvider { config, system } - } - - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval + DiskProvider { + config, + common, + disks: Disks::new_with_refreshed_list(), + } } - async fn run_interval(&self) -> anyhow::Result { - let mut disks = self.system.lock().await; - disks.refresh(); + fn run_interval(&mut self) -> anyhow::Result { + self.disks.refresh(); - let disks = disks + let disks = self + .disks .iter() .map(|disk| -> anyhow::Result { let name = disk.name().to_string_lossy().to_string(); @@ -84,7 +83,7 @@ impl DiskProvider { }) .collect::>>()?; - Ok(ProviderOutput::Disk(DiskOutput { disks })) + Ok(DiskOutput { disks }) } fn to_disk_size_measure(bytes: u64) -> anyhow::Result { @@ -101,4 +100,26 @@ impl DiskProvider { } } -impl_interval_provider!(DiskProvider, true); +impl Provider for DiskProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/host/host_provider.rs b/packages/desktop/src/providers/host/host_provider.rs index 9b26a4b2..a95070bd 100644 --- a/packages/desktop/src/providers/host/host_provider.rs +++ b/packages/desktop/src/providers/host/host_provider.rs @@ -1,7 +1,12 @@ use serde::{Deserialize, Serialize}; use sysinfo::System; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::SyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -22,27 +27,49 @@ pub struct HostOutput { pub struct HostProvider { config: HostProviderConfig, + common: CommonProviderState, } impl HostProvider { - pub fn new(config: HostProviderConfig) -> HostProvider { - HostProvider { config } + pub fn new( + config: HostProviderConfig, + common: CommonProviderState, + ) -> HostProvider { + HostProvider { config, common } } - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval - } - - async fn run_interval(&self) -> anyhow::Result { - Ok(ProviderOutput::Host(HostOutput { + fn run_interval(&mut self) -> anyhow::Result { + Ok(HostOutput { hostname: System::host_name(), os_name: System::name(), os_version: System::os_version(), friendly_os_version: System::long_os_version(), boot_time: System::boot_time() * 1000, uptime: System::uptime() * 1000, - })) + }) } } -impl_interval_provider!(HostProvider, false); +impl Provider for HostProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/ip/ip_provider.rs b/packages/desktop/src/providers/ip/ip_provider.rs index 37b1cdac..eececc5a 100644 --- a/packages/desktop/src/providers/ip/ip_provider.rs +++ b/packages/desktop/src/providers/ip/ip_provider.rs @@ -1,9 +1,15 @@ use anyhow::Context; +use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use super::ipinfo_res::IpinfoRes; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::AsyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -23,24 +29,28 @@ pub struct IpOutput { pub struct IpProvider { config: IpProviderConfig, + common: CommonProviderState, http_client: Client, } impl IpProvider { - pub fn new(config: IpProviderConfig) -> IpProvider { + pub fn new( + config: IpProviderConfig, + common: CommonProviderState, + ) -> IpProvider { IpProvider { config, + common, http_client: Client::new(), } } - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval + async fn run_interval(&mut self) -> anyhow::Result { + Self::query_ip(&self.http_client).await } - pub async fn run_interval(&self) -> anyhow::Result { - let res = self - .http_client + pub async fn query_ip(http_client: &Client) -> anyhow::Result { + let res = http_client .get("https://ipinfo.io/json") .send() .await? @@ -49,7 +59,7 @@ impl IpProvider { let mut loc_parts = res.loc.split(','); - Ok(ProviderOutput::Ip(IpOutput { + Ok(IpOutput { address: res.ip, approx_city: res.city, approx_country: res.country, @@ -61,8 +71,31 @@ impl IpProvider { .next() .and_then(|long| long.parse::().ok()) .context("Failed to parse longitude from IPinfo.")?, - })) + }) } } -impl_interval_provider!(IpProvider, false); +#[async_trait] +impl Provider for IpProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Async + } + + async fn start_async(&mut self) { + let mut interval = AsyncInterval::new(self.config.refresh_interval); + + loop { + tokio::select! { + _ = interval.tick() => { + let output = self.run_interval().await; + self.common.emitter.emit_output(output); + } + Some(message) = self.common.input.async_rx.recv() => { + if let ProviderInputMsg::Stop = message { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/keyboard/keyboard_provider.rs b/packages/desktop/src/providers/keyboard/keyboard_provider.rs index 31d2d00b..513f1d31 100644 --- a/packages/desktop/src/providers/keyboard/keyboard_provider.rs +++ b/packages/desktop/src/providers/keyboard/keyboard_provider.rs @@ -9,7 +9,12 @@ use windows::Win32::{ }, }; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::SyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -25,18 +30,18 @@ pub struct KeyboardOutput { pub struct KeyboardProvider { config: KeyboardProviderConfig, + common: CommonProviderState, } impl KeyboardProvider { - pub fn new(config: KeyboardProviderConfig) -> KeyboardProvider { - KeyboardProvider { config } + pub fn new( + config: KeyboardProviderConfig, + common: CommonProviderState, + ) -> KeyboardProvider { + KeyboardProvider { config, common } } - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval - } - - async fn run_interval(&self) -> anyhow::Result { + fn run_interval(&mut self) -> anyhow::Result { let keyboard_layout = unsafe { GetKeyboardLayout(GetWindowThreadProcessId( GetForegroundWindow(), @@ -62,10 +67,32 @@ impl KeyboardProvider { let layout_name = String::from_utf16_lossy(&locale_name[..result as usize]); - Ok(ProviderOutput::Keyboard(KeyboardOutput { + Ok(KeyboardOutput { layout: layout_name, - })) + }) } } -impl_interval_provider!(KeyboardProvider, false); +impl Provider for KeyboardProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/komorebi/komorebi_provider.rs b/packages/desktop/src/providers/komorebi/komorebi_provider.rs index 900dc7cb..e4bada56 100644 --- a/packages/desktop/src/providers/komorebi/komorebi_provider.rs +++ b/packages/desktop/src/providers/komorebi/komorebi_provider.rs @@ -9,14 +9,13 @@ use komorebi_client::{ Container, Monitor, SocketMessage, Window, Workspace, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::mpsc::Sender, time}; use tracing::debug; use super::{ KomorebiContainer, KomorebiLayout, KomorebiLayoutFlip, KomorebiMonitor, KomorebiWindow, KomorebiWorkspace, }; -use crate::providers::{Provider, ProviderOutput, ProviderResult}; +use crate::providers::{CommonProviderState, Provider, RuntimeType}; const SOCKET_NAME: &str = "zebar.sock"; @@ -32,18 +31,18 @@ pub struct KomorebiOutput { } pub struct KomorebiProvider { - _config: KomorebiProviderConfig, + common: CommonProviderState, } impl KomorebiProvider { - pub fn new(config: KomorebiProviderConfig) -> KomorebiProvider { - KomorebiProvider { _config: config } + pub fn new( + _config: KomorebiProviderConfig, + common: CommonProviderState, + ) -> KomorebiProvider { + KomorebiProvider { common } } - async fn create_socket( - &self, - emit_result_tx: Sender, - ) -> anyhow::Result<()> { + fn create_socket(&mut self) -> anyhow::Result<()> { let socket = komorebi_client::subscribe(SOCKET_NAME) .context("Failed to initialize Komorebi socket.")?; @@ -68,7 +67,7 @@ impl KomorebiProvider { .is_err() { debug!("Attempting to reconnect to Komorebi."); - time::sleep(Duration::from_secs(15)).await; + std::thread::sleep(Duration::from_secs(15)); } } @@ -78,24 +77,14 @@ impl KomorebiProvider { &String::from_utf8(buffer).unwrap(), ) { - emit_result_tx - .send( - Ok(ProviderOutput::Komorebi(Self::transform_response( - notification.state, - ))) - .into(), - ) - .await; + self.common.emitter.emit_output(Ok(Self::transform_response( + notification.state, + ))); } } - Err(_) => { - emit_result_tx - .send( - Err(anyhow::anyhow!("Failed to read Komorebi stream.")) - .into(), - ) - .await; - } + Err(_) => self.common.emitter.emit_output::(Err( + anyhow::anyhow!("Failed to read Komorebi stream."), + )), } } @@ -185,9 +174,13 @@ impl KomorebiProvider { #[async_trait] impl Provider for KomorebiProvider { - async fn run(&self, emit_result_tx: Sender) { - if let Err(err) = self.create_socket(emit_result_tx.clone()).await { - emit_result_tx.send(Err(err).into()).await; + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + if let Err(err) = self.create_socket() { + self.common.emitter.emit_output::(Err(err)); } } } diff --git a/packages/desktop/src/providers/media/media_provider.rs b/packages/desktop/src/providers/media/media_provider.rs index e3f80974..008858ed 100644 --- a/packages/desktop/src/providers/media/media_provider.rs +++ b/packages/desktop/src/providers/media/media_provider.rs @@ -5,7 +5,6 @@ use std::{ use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use tokio::{sync::mpsc::Sender, task}; use tracing::{debug, error}; use windows::{ Foundation::{EventRegistrationToken, TypedEventHandler}, @@ -16,7 +15,9 @@ use windows::{ }, }; -use crate::providers::{Provider, ProviderOutput, ProviderResult}; +use crate::providers::{ + CommonProviderState, Provider, ProviderEmitter, RuntimeType, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -50,26 +51,23 @@ struct EventTokens { } pub struct MediaProvider { - _config: MediaProviderConfig, + common: CommonProviderState, } impl MediaProvider { - pub fn new(config: MediaProviderConfig) -> MediaProvider { - MediaProvider { _config: config } + pub fn new( + _config: MediaProviderConfig, + common: CommonProviderState, + ) -> MediaProvider { + MediaProvider { common } } fn emit_media_info( session: Option<&GsmtcSession>, - emit_result_tx: Sender, + emitter: &ProviderEmitter, ) { - let _ = match Self::media_output(session) { - Ok(media_output) => emit_result_tx - .blocking_send(Ok(ProviderOutput::Media(media_output)).into()), - Err(err) => { - error!("Error retrieving media output: {:?}", err); - emit_result_tx.blocking_send(Err(err).into()) - } - }; + let media_output = Self::media_output(session); + emitter.emit_output(media_output); } fn media_output( @@ -123,9 +121,7 @@ impl MediaProvider { })) } - fn create_session_manager( - emit_result_tx: Sender, - ) -> anyhow::Result<()> { + fn create_session_manager(&mut self) -> anyhow::Result<()> { debug!("Creating media session manager."); // Find the current GSMTC session & add listeners. @@ -133,25 +129,22 @@ impl MediaProvider { let current_session = session_manager.GetCurrentSession().ok(); let event_tokens = match ¤t_session { - Some(session) => Some(Self::add_session_listeners( - session, - emit_result_tx.clone(), - )?), + Some(session) => { + Some(Self::add_session_listeners(session, &self.common.emitter)?) + } None => None, }; // Emit initial media info. - Self::emit_media_info( - current_session.as_ref(), - emit_result_tx.clone(), - ); + Self::emit_media_info(current_session.as_ref(), &self.common.emitter); let current_session = Arc::new(Mutex::new(current_session)); let event_tokens = Arc::new(Mutex::new(event_tokens)); + let emitter = self.common.emitter.clone(); // Clean up & rebind listeners when session changes. - let session_changed_handler = TypedEventHandler::new( - move |session_manager: &Option, _| { + let session_changed_handler = + TypedEventHandler::new(move |_: &Option, _| { { let mut current_session = current_session.lock().unwrap(); let mut event_tokens = event_tokens.lock().unwrap(); @@ -171,23 +164,17 @@ impl MediaProvider { let new_session = GsmtcManager::RequestAsync()?.get()?.GetCurrentSession()?; - let tokens = Self::add_session_listeners( - &new_session, - emit_result_tx.clone(), - )?; + let tokens = + Self::add_session_listeners(&new_session, &emitter)?; - Self::emit_media_info( - Some(&new_session), - emit_result_tx.clone(), - ); + Self::emit_media_info(Some(&new_session), &emitter); *current_session = Some(new_session); *event_tokens = Some(tokens); } Ok(()) - }, - ); + }); session_manager.CurrentSessionChanged(&session_changed_handler)?; @@ -216,18 +203,18 @@ impl MediaProvider { fn add_session_listeners( session: &GsmtcSession, - emit_result_tx: Sender, + emitter: &ProviderEmitter, ) -> windows::core::Result { debug!("Adding session listeners."); let media_properties_changed_handler = { - let emit_result_tx = emit_result_tx.clone(); + let emitter = emitter.clone(); TypedEventHandler::new(move |session: &Option, _| { debug!("Media properties changed event triggered."); if let Some(session) = session { - Self::emit_media_info(Some(session), emit_result_tx.clone()); + Self::emit_media_info(Some(session), &emitter); } Ok(()) @@ -235,13 +222,13 @@ impl MediaProvider { }; let playback_info_changed_handler = { - let emit_result_tx = emit_result_tx.clone(); + let emitter = emitter.clone(); TypedEventHandler::new(move |session: &Option, _| { debug!("Playback info changed event triggered."); if let Some(session) = session { - Self::emit_media_info(Some(session), emit_result_tx.clone()); + Self::emit_media_info(Some(session), &emitter); } Ok(()) @@ -249,45 +236,39 @@ impl MediaProvider { }; let timeline_properties_changed_handler = { - let emit_result_tx = emit_result_tx.clone(); + let emitter = emitter.clone(); TypedEventHandler::new(move |session: &Option, _| { debug!("Timeline properties changed event triggered."); if let Some(session) = session { - Self::emit_media_info(Some(session), emit_result_tx.clone()); + Self::emit_media_info(Some(session), &emitter); } Ok(()) }) }; - let timeline_token = session - .TimelinePropertiesChanged(&timeline_properties_changed_handler)?; - let playback_token = - session.PlaybackInfoChanged(&playback_info_changed_handler)?; - let media_token = - session.MediaPropertiesChanged(&media_properties_changed_handler)?; - - Ok({ - EventTokens { - playback_info_changed_token: playback_token, - media_properties_changed_token: media_token, - timeline_properties_changed_token: timeline_token, - } + Ok(EventTokens { + playback_info_changed_token: session + .PlaybackInfoChanged(&playback_info_changed_handler)?, + media_properties_changed_token: session + .MediaPropertiesChanged(&media_properties_changed_handler)?, + timeline_properties_changed_token: session + .TimelinePropertiesChanged(&timeline_properties_changed_handler)?, }) } } #[async_trait] impl Provider for MediaProvider { - async fn run(&self, emit_result_tx: Sender) { - task::spawn_blocking(move || { - if let Err(err) = - Self::create_session_manager(emit_result_tx.clone()) - { - let _ = emit_result_tx.blocking_send(Err(err).into()); - } - }); + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + if let Err(err) = self.create_session_manager() { + self.common.emitter.emit_output::(Err(err)); + } } } diff --git a/packages/desktop/src/providers/memory/memory_provider.rs b/packages/desktop/src/providers/memory/memory_provider.rs index 077c8f4d..ceaa76d6 100644 --- a/packages/desktop/src/providers/memory/memory_provider.rs +++ b/packages/desktop/src/providers/memory/memory_provider.rs @@ -1,10 +1,11 @@ -use std::sync::Arc; - use serde::{Deserialize, Serialize}; -use sysinfo::System; -use tokio::sync::Mutex; -use crate::{impl_interval_provider, providers::ProviderOutput}; +use crate::{ + common::SyncInterval, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, +}; #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -26,30 +27,26 @@ pub struct MemoryOutput { pub struct MemoryProvider { config: MemoryProviderConfig, - sysinfo: Arc>, + common: CommonProviderState, } impl MemoryProvider { pub fn new( config: MemoryProviderConfig, - sysinfo: Arc>, + common: CommonProviderState, ) -> MemoryProvider { - MemoryProvider { config, sysinfo } - } - - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval + MemoryProvider { config, common } } - async fn run_interval(&self) -> anyhow::Result { - let mut sysinfo = self.sysinfo.lock().await; + fn run_interval(&mut self) -> anyhow::Result { + let mut sysinfo = self.common.sysinfo.blocking_lock(); sysinfo.refresh_memory(); let usage = (sysinfo.used_memory() as f32 / sysinfo.total_memory() as f32) * 100.0; - Ok(ProviderOutput::Memory(MemoryOutput { + Ok(MemoryOutput { usage, free_memory: sysinfo.free_memory(), used_memory: sysinfo.used_memory(), @@ -57,8 +54,30 @@ impl MemoryProvider { free_swap: sysinfo.free_swap(), used_swap: sysinfo.used_swap(), total_swap: sysinfo.total_swap(), - })) + }) } } -impl_interval_provider!(MemoryProvider, true); +impl Provider for MemoryProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/mod.rs b/packages/desktop/src/providers/mod.rs index bad74540..0900643a 100644 --- a/packages/desktop/src/providers/mod.rs +++ b/packages/desktop/src/providers/mod.rs @@ -15,13 +15,13 @@ mod memory; mod network; mod provider; mod provider_config; +mod provider_function; mod provider_manager; mod provider_output; -mod provider_ref; mod weather; pub use provider::*; pub use provider_config::*; +pub use provider_function::*; pub use provider_manager::*; pub use provider_output::*; -pub use provider_ref::*; diff --git a/packages/desktop/src/providers/network/network_provider.rs b/packages/desktop/src/providers/network/network_provider.rs index 48c36176..d08f3927 100644 --- a/packages/desktop/src/providers/network/network_provider.rs +++ b/packages/desktop/src/providers/network/network_provider.rs @@ -1,8 +1,5 @@ -use std::sync::Arc; - use serde::{Deserialize, Serialize}; use sysinfo::Networks; -use tokio::sync::Mutex; use super::{ wifi_hotspot::{default_gateway_wifi, WifiHotstop}, @@ -10,9 +7,10 @@ use super::{ NetworkTrafficMeasure, }; use crate::{ - common::{to_iec_bytes, to_si_bytes}, - impl_interval_provider, - providers::ProviderOutput, + common::{to_iec_bytes, to_si_bytes, SyncInterval}, + providers::{ + CommonProviderState, Provider, ProviderInputMsg, RuntimeType, + }, }; #[derive(Deserialize, Debug)] @@ -32,37 +30,37 @@ pub struct NetworkOutput { pub struct NetworkProvider { config: NetworkProviderConfig, - netinfo: Arc>, + common: CommonProviderState, + netinfo: Networks, } impl NetworkProvider { pub fn new( config: NetworkProviderConfig, - netinfo: Arc>, + common: CommonProviderState, ) -> NetworkProvider { - NetworkProvider { config, netinfo } - } - - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval + NetworkProvider { + config, + common, + netinfo: Networks::new_with_refreshed_list(), + } } - async fn run_interval(&self) -> anyhow::Result { - let mut netinfo = self.netinfo.lock().await; - netinfo.refresh(); + fn run_interval(&mut self) -> anyhow::Result { + self.netinfo.refresh(); let interfaces = netdev::get_interfaces(); let default_interface = netdev::get_default_interface().ok(); - let (received, total_received) = Self::bytes_received(&netinfo); + let (received, total_received) = Self::bytes_received(&self.netinfo); let received_per_sec = received / self.config.refresh_interval * 1000; let (transmitted, total_transmitted) = - Self::bytes_transmitted(&netinfo); + Self::bytes_transmitted(&self.netinfo); let transmitted_per_sec = transmitted / self.config.refresh_interval * 1000; - Ok(ProviderOutput::Network(NetworkOutput { + Ok(NetworkOutput { default_interface: default_interface .as_ref() .map(Self::transform_interface), @@ -87,7 +85,7 @@ impl NetworkProvider { total_transmitted, )?, }, - })) + }) } fn to_network_traffic_measure( @@ -197,4 +195,26 @@ impl NetworkProvider { } } -impl_interval_provider!(NetworkProvider, true); +impl Provider for NetworkProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Sync + } + + fn start_sync(&mut self) { + let mut interval = SyncInterval::new(self.config.refresh_interval); + + loop { + crossbeam::select! { + recv(interval.tick()) -> _ => { + let output = self.run_interval(); + self.common.emitter.emit_output(output); + } + recv(self.common.input.sync_rx) -> input => { + if let Ok(ProviderInputMsg::Stop) = input { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/providers/provider.rs b/packages/desktop/src/providers/provider.rs index 0a391ccf..b0f26cd2 100644 --- a/packages/desktop/src/providers/provider.rs +++ b/packages/desktop/src/providers/provider.rs @@ -1,65 +1,88 @@ use async_trait::async_trait; -use tokio::sync::mpsc::Sender; -use super::ProviderResult; +use super::{ProviderFunction, ProviderFunctionResponse}; #[async_trait] pub trait Provider: Send + Sync { - /// Callback for when the provider is started. - async fn run(&self, emit_result_tx: Sender); + fn runtime_type(&self) -> RuntimeType; - /// Callback for when the provider is stopped. - async fn on_stop(&self) { - // No-op by default. + /// Callback for when the provider is started. + /// + /// # Panics + /// + /// Panics if wrong runtime type is used. + fn start_sync(&mut self) { + match self.runtime_type() { + RuntimeType::Sync => { + unreachable!("Sync providers must implement `start_sync`.") + } + RuntimeType::Async => { + panic!("Cannot call sync function on async provider.") + } + } } -} -/// Implements the `Provider` trait for the given struct. -/// -/// Expects that the struct has a `refresh_interval_ms` and `run_interval` -/// method. -#[macro_export] -macro_rules! impl_interval_provider { - ($type:ty, $allow_identical_emits:expr) => { - #[async_trait::async_trait] - impl crate::providers::Provider for $type { - async fn run( - &self, - emit_result_tx: tokio::sync::mpsc::Sender< - crate::providers::ProviderResult, - >, - ) { - let mut interval = tokio::time::interval( - std::time::Duration::from_millis(self.refresh_interval_ms()), - ); - - // Skip missed ticks when the interval runs. This prevents a burst - // of backlogged ticks after a delay. - interval - .set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - let mut last_interval_res: Option< - crate::providers::ProviderResult, - > = None; - - loop { - interval.tick().await; - - let interval_res = self.run_interval().await.into(); - - if $allow_identical_emits - || last_interval_res.as_ref() != Some(&interval_res) - { - let send_res = emit_result_tx.send(interval_res.clone()).await; + /// Callback for when the provider is started. + /// + /// # Panics + /// + /// Panics if wrong runtime type is used. + async fn start_async(&mut self) { + match self.runtime_type() { + RuntimeType::Async => { + unreachable!("Async providers must implement `start_async`.") + } + RuntimeType::Sync => { + panic!("Cannot call async function on sync provider.") + } + } + } - if let Err(err) = send_res { - tracing::error!("Error sending provider result: {:?}", err); - } + /// Runs the given function. + /// + /// # Panics + /// + /// Panics if wrong runtime type is used. + fn call_function_sync( + &self, + function: ProviderFunction, + ) -> anyhow::Result { + let _function = function; + match self.runtime_type() { + RuntimeType::Sync => { + unreachable!("Sync providers must implement `call_function_sync`.") + } + RuntimeType::Async => { + panic!("Cannot call sync function on async provider.") + } + } + } - last_interval_res = Some(interval_res); - } - } + /// Runs the given function. + /// + /// # Panics + /// + /// Panics if wrong runtime type is used. + async fn call_function_async( + &self, + function: ProviderFunction, + ) -> anyhow::Result { + let _function = function; + match self.runtime_type() { + RuntimeType::Async => { + unreachable!( + "Async providers must implement `call_function_async`." + ) + } + RuntimeType::Sync => { + panic!("Cannot call async function on sync provider.") } } - }; + } +} + +/// Determines whether `start_sync` or `start_async` is called. +pub enum RuntimeType { + Sync, + Async, } diff --git a/packages/desktop/src/providers/provider_function.rs b/packages/desktop/src/providers/provider_function.rs new file mode 100644 index 00000000..8a6378c8 --- /dev/null +++ b/packages/desktop/src/providers/provider_function.rs @@ -0,0 +1,21 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ProviderFunction { + Media(MediaFunction), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MediaFunction { + PlayPause, + Next, + Previous, +} + +pub type ProviderFunctionResult = Result; + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum ProviderFunctionResponse { + Null, +} diff --git a/packages/desktop/src/providers/provider_manager.rs b/packages/desktop/src/providers/provider_manager.rs index 73b61d35..67aa541d 100644 --- a/packages/desktop/src/providers/provider_manager.rs +++ b/packages/desktop/src/providers/provider_manager.rs @@ -1,38 +1,146 @@ use std::{collections::HashMap, sync::Arc}; -use sysinfo::{Disks, Networks, System}; -use tauri::AppHandle; -use tokio::sync::Mutex; -use tracing::warn; - -use super::{ProviderConfig, ProviderRef}; - -/// State shared between providers. -#[derive(Clone)] -pub struct SharedProviderState { - pub sysinfo: Arc>, - pub netinfo: Arc>, - pub diskinfo: Arc>, +use anyhow::{bail, Context}; +use serde::{ser::SerializeStruct, Serialize}; +use tauri::{AppHandle, Emitter}; +use tokio::{ + sync::{mpsc, oneshot, Mutex}, + task, +}; +use tracing::info; + +#[cfg(windows)] +use super::{ + audio::AudioProvider, keyboard::KeyboardProvider, + komorebi::KomorebiProvider, media::MediaProvider, +}; +use super::{ + battery::BatteryProvider, cpu::CpuProvider, disk::DiskProvider, + host::HostProvider, ip::IpProvider, memory::MemoryProvider, + network::NetworkProvider, weather::WeatherProvider, Provider, + ProviderConfig, ProviderFunction, ProviderFunctionResponse, + ProviderFunctionResult, ProviderOutput, RuntimeType, +}; + +/// Common fields for a provider. +pub struct CommonProviderState { + /// Wrapper around the sender channel of provider emissions. + pub emitter: ProviderEmitter, + + /// Wrapper around the receiver channel for incoming inputs to the + /// provider. + pub input: ProviderInput, + + /// Shared `sysinfo` instance. + pub sysinfo: Arc>, +} + +/// Handle for receiving provider inputs. +pub struct ProviderInput { + /// Async receiver channel for incoming inputs to the provider. + pub async_rx: mpsc::Receiver, + + /// Sync receiver channel for incoming inputs to the provider. + pub sync_rx: crossbeam::channel::Receiver, +} + +pub enum ProviderInputMsg { + Function(ProviderFunction, oneshot::Sender), + Stop, +} + +/// Handle for sending provider emissions. +#[derive(Clone, Debug)] +pub struct ProviderEmitter { + /// Sender channel for outgoing provider emissions. + emit_tx: mpsc::UnboundedSender, + + /// Hash of the provider's config. + config_hash: String, +} + +impl ProviderEmitter { + /// Emits an output from a provider. + pub fn emit_output(&self, output: anyhow::Result) + where + T: Into, + { + let send_res = self.emit_tx.send(ProviderEmission { + config_hash: self.config_hash.clone(), + result: output.map(Into::into).map_err(|err| err.to_string()), + }); + + if let Err(err) = send_res { + tracing::error!("Error sending provider result: {:?}", err); + } + } +} + +/// Emission from a provider. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ProviderEmission { + /// Hash of the provider's config. + pub config_hash: String, + + /// A thread-safe `Result` type for provider outputs and errors. + #[serde(serialize_with = "serialize_result")] + pub result: Result, +} + +/// Reference to an active provider. +struct ProviderRef { + /// Sender channel for sending inputs to the provider. + async_input_tx: mpsc::Sender, + + /// Sender channel for sending inputs to the provider. + sync_input_tx: crossbeam::channel::Sender, + + /// Handle to the provider's task. + task_handle: task::JoinHandle<()>, + + /// Runtime type of the provider. + runtime_type: RuntimeType, } /// Manages the creation and cleanup of providers. pub struct ProviderManager { + /// Handle to the Tauri application. app_handle: AppHandle, - providers: Arc>>, - shared_state: SharedProviderState, + + /// Map of active provider refs. + provider_refs: Arc>>, + + /// Cache of provider emissions. + emit_cache: Arc>>, + + /// Sender channel for provider emissions. + emit_tx: mpsc::UnboundedSender, + + /// Shared `sysinfo` instance. + sysinfo: Arc>, } impl ProviderManager { - pub fn new(app_handle: &AppHandle) -> Self { - Self { - app_handle: app_handle.clone(), - providers: Arc::new(Mutex::new(HashMap::new())), - shared_state: SharedProviderState { - sysinfo: Arc::new(Mutex::new(System::new_all())), - netinfo: Arc::new(Mutex::new(Networks::new_with_refreshed_list())), - diskinfo: Arc::new(Mutex::new(Disks::new_with_refreshed_list())), - }, - } + /// Creates a new provider manager. + /// + /// Returns a tuple containing the `ProviderManager` instance and a + /// channel for provider emissions. + pub fn new( + app_handle: &AppHandle, + ) -> (Arc, mpsc::UnboundedReceiver) { + let (emit_tx, emit_rx) = mpsc::unbounded_channel::(); + + ( + Arc::new(Self { + app_handle: app_handle.clone(), + provider_refs: Arc::new(Mutex::new(HashMap::new())), + emit_cache: Arc::new(Mutex::new(HashMap::new())), + sysinfo: Arc::new(Mutex::new(sysinfo::System::new_all())), + emit_tx, + }), + emit_rx, + ) } /// Creates a provider with the given config. @@ -41,46 +149,230 @@ impl ProviderManager { config_hash: String, config: ProviderConfig, ) -> anyhow::Result<()> { + // If a provider with the given config already exists, re-emit its + // latest emission and return early. { - let mut providers = self.providers.lock().await; - - // If a provider with the given config already exists, refresh it - // and return early. - if let Some(found_provider) = providers.get_mut(&config_hash) { - if let Err(err) = found_provider.refresh().await { - warn!("Error refreshing provider: {:?}", err); - } + if let Some(found_emit) = + self.emit_cache.lock().await.get(&config_hash) + { + tracing::info!( + "Emitting cached provider emission for: {}", + config_hash + ); + self.app_handle.emit("provider-emit", found_emit)?; return Ok(()); }; } - let provider_ref = ProviderRef::new( - &self.app_handle, - config, - config_hash.clone(), - self.shared_state.clone(), - ) - .await?; + // Hold the lock for `provider_refs` to prevent duplicate providers + // from potentially being created. + let mut provider_refs = self.provider_refs.lock().await; + + // No-op if the provider has already been created (but has not emitted + // yet). Multiple frontend clients can call `create` for the same + // provider, and all will receive the same output once the provider + // emits. + if provider_refs.contains_key(&config_hash) { + return Ok(()); + } + + tracing::info!("Creating provider: {}", config_hash); + + let (async_input_tx, async_input_rx) = mpsc::channel(1); + let (sync_input_tx, sync_input_rx) = crossbeam::channel::bounded(1); + + let common = CommonProviderState { + input: ProviderInput { + async_rx: async_input_rx, + sync_rx: sync_input_rx, + }, + emitter: ProviderEmitter { + emit_tx: self.emit_tx.clone(), + config_hash: config_hash.clone(), + }, + sysinfo: self.sysinfo.clone(), + }; + + let (task_handle, runtime_type) = + self.create_instance(config, config_hash.clone(), common)?; + + let provider_ref = ProviderRef { + async_input_tx, + sync_input_tx, + task_handle, + runtime_type, + }; - let mut providers = self.providers.lock().await; - providers.insert(config_hash, provider_ref); + provider_refs.insert(config_hash, provider_ref); Ok(()) } + /// Creates a new provider instance. + fn create_instance( + &self, + config: ProviderConfig, + config_hash: String, + common: CommonProviderState, + ) -> anyhow::Result<(task::JoinHandle<()>, RuntimeType)> { + let mut provider: Box = match config { + #[cfg(windows)] + ProviderConfig::Audio(config) => { + Box::new(AudioProvider::new(config, common)) + } + ProviderConfig::Battery(config) => { + Box::new(BatteryProvider::new(config, common)) + } + ProviderConfig::Cpu(config) => { + Box::new(CpuProvider::new(config, common)) + } + ProviderConfig::Host(config) => { + Box::new(HostProvider::new(config, common)) + } + ProviderConfig::Ip(config) => { + Box::new(IpProvider::new(config, common)) + } + #[cfg(windows)] + ProviderConfig::Komorebi(config) => { + Box::new(KomorebiProvider::new(config, common)) + } + #[cfg(windows)] + ProviderConfig::Media(config) => { + Box::new(MediaProvider::new(config, common)) + } + ProviderConfig::Memory(config) => { + Box::new(MemoryProvider::new(config, common)) + } + ProviderConfig::Disk(config) => { + Box::new(DiskProvider::new(config, common)) + } + ProviderConfig::Network(config) => { + Box::new(NetworkProvider::new(config, common)) + } + ProviderConfig::Weather(config) => { + Box::new(WeatherProvider::new(config, common)) + } + #[cfg(windows)] + ProviderConfig::Keyboard(config) => { + Box::new(KeyboardProvider::new(config, common)) + } + #[allow(unreachable_patterns)] + _ => bail!("Provider not supported on this operating system."), + }; + + // Spawn the provider's task based on its runtime type. + let runtime_type = provider.runtime_type(); + let task_handle = match &runtime_type { + RuntimeType::Async => task::spawn(async move { + provider.start_async().await; + info!("Provider stopped: {}", config_hash); + }), + RuntimeType::Sync => task::spawn_blocking(move || { + provider.start_sync(); + info!("Provider stopped: {}", config_hash); + }), + }; + + Ok((task_handle, runtime_type)) + } + + /// Sends a function call through a channel to be executed by the + /// provider. + /// + /// Returns the result of the function execution. + pub async fn call_function( + &self, + config_hash: String, + function: ProviderFunction, + ) -> anyhow::Result { + let provider_refs = self.provider_refs.lock().await; + let provider_ref = provider_refs + .get(&config_hash) + .context("No provider found with config.")?; + + let (tx, rx) = oneshot::channel(); + match provider_ref.runtime_type { + RuntimeType::Async => { + provider_ref + .async_input_tx + .send(ProviderInputMsg::Function(function, tx)) + .await + .context("Failed to send function call to provider.")?; + } + RuntimeType::Sync => { + provider_ref + .sync_input_tx + .send(ProviderInputMsg::Function(function, tx)) + .context("Failed to send function call to provider.")?; + } + } + + rx.await?.map_err(anyhow::Error::msg) + } + /// Destroys and cleans up the provider with the given config. - pub async fn destroy(&self, config_hash: String) -> anyhow::Result<()> { - let mut providers = self.providers.lock().await; + pub async fn stop(&self, config_hash: String) -> anyhow::Result<()> { + let provider_ref = { + let mut provider_refs = self.provider_refs.lock().await; - if let Some(found_provider) = providers.get_mut(&config_hash) { - if let Err(err) = found_provider.stop().await { - warn!("Error stopping provider: {:?}", err); + // Evict the provider's emission from cache. Hold the lock for + // `provider_refs` to avoid a race condition with provider + // creation. + let mut provider_cache = self.emit_cache.lock().await; + let _ = provider_cache.remove(&config_hash); + + provider_refs + .remove(&config_hash) + .context("No provider found with config.")? + }; + + // Send shutdown signal to the provider. + match provider_ref.runtime_type { + RuntimeType::Async => { + provider_ref + .async_input_tx + .send(ProviderInputMsg::Stop) + .await + .context("Failed to send shutdown signal to provider.")?; + } + RuntimeType::Sync => { + provider_ref + .sync_input_tx + .send(ProviderInputMsg::Stop) + .context("Failed to send shutdown signal to provider.")?; } } - providers.remove(&config_hash); + // Wait for the provider to stop. + provider_ref.task_handle.await?; Ok(()) } + + /// Updates the cache with the given provider emission. + pub async fn update_cache(&self, emission: ProviderEmission) { + let mut cache = self.emit_cache.lock().await; + cache.insert(emission.config_hash.clone(), emission); + } +} + +/// Custom serializer for Result that converts: +/// - Ok(output) -> {"output": output} +/// - Err(error) -> {"error": error} +fn serialize_result( + result: &Result, + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + let mut state = serializer.serialize_struct("Result", 1)?; + + match result { + Ok(output) => state.serialize_field("output", output)?, + Err(error) => state.serialize_field("error", error)?, + } + + state.end() } diff --git a/packages/desktop/src/providers/provider_output.rs b/packages/desktop/src/providers/provider_output.rs index c922fdd8..b10d96e1 100644 --- a/packages/desktop/src/providers/provider_output.rs +++ b/packages/desktop/src/providers/provider_output.rs @@ -11,6 +11,19 @@ use super::{ network::NetworkOutput, weather::WeatherOutput, }; +/// Implements `From` for `ProviderOutput` for each given variant. +macro_rules! impl_provider_output { + ($($variant:ident($type:ty)),* $(,)?) => { + $( + impl From<$type> for ProviderOutput { + fn from(value: $type) -> Self { + Self::$variant(value) + } + } + )* + }; +} + #[derive(Debug, Clone, PartialEq, Serialize)] #[serde(untagged)] pub enum ProviderOutput { @@ -31,3 +44,22 @@ pub enum ProviderOutput { #[cfg(windows)] Keyboard(KeyboardOutput), } + +impl_provider_output! { + Battery(BatteryOutput), + Cpu(CpuOutput), + Host(HostOutput), + Ip(IpOutput), + Memory(MemoryOutput), + Disk(DiskOutput), + Network(NetworkOutput), + Weather(WeatherOutput) +} + +#[cfg(windows)] +impl_provider_output! { + Audio(AudioOutput), + Komorebi(KomorebiOutput), + Media(MediaOutput), + Keyboard(KeyboardOutput) +} diff --git a/packages/desktop/src/providers/provider_ref.rs b/packages/desktop/src/providers/provider_ref.rs deleted file mode 100644 index b8b94cae..00000000 --- a/packages/desktop/src/providers/provider_ref.rs +++ /dev/null @@ -1,233 +0,0 @@ -use std::sync::Arc; - -use anyhow::bail; -use serde::Serialize; -use serde_json::json; -use tauri::{AppHandle, Emitter}; -use tokio::{ - sync::{mpsc, Mutex}, - task, -}; -use tracing::{info, warn}; - -#[cfg(windows)] -use super::{ - audio::AudioProvider, keyboard::KeyboardProvider, - komorebi::KomorebiProvider, media::MediaProvider, -}; -use super::{ - battery::BatteryProvider, cpu::CpuProvider, disk::DiskProvider, - host::HostProvider, ip::IpProvider, memory::MemoryProvider, - network::NetworkProvider, weather::WeatherProvider, Provider, - ProviderConfig, ProviderOutput, SharedProviderState, -}; - -/// Reference to an active provider. -pub struct ProviderRef { - /// Cache for provider output. - cache: Arc>>>, - - /// Sender channel for emitting provider output/error to frontend - /// clients. - emit_result_tx: mpsc::Sender, - - /// Sender channel for stopping the provider. - stop_tx: mpsc::Sender<()>, -} - -/// Provider output/error emitted to frontend clients. -/// -/// This is used instead of a normal `Result` type in order to serialize it -/// in a nicer way. -#[derive(Debug, Clone, PartialEq, Serialize)] -#[serde(rename_all = "camelCase")] -pub enum ProviderResult { - Output(ProviderOutput), - Error(String), -} - -/// Implements conversion from `anyhow::Result`. -impl From> for ProviderResult { - fn from(result: anyhow::Result) -> Self { - match result { - Ok(output) => ProviderResult::Output(output), - Err(err) => ProviderResult::Error(err.to_string()), - } - } -} - -impl ProviderRef { - /// Creates a new `ProviderRef` instance. - pub async fn new( - app_handle: &AppHandle, - config: ProviderConfig, - config_hash: String, - shared_state: SharedProviderState, - ) -> anyhow::Result { - let cache = Arc::new(Mutex::new(None)); - - let (stop_tx, stop_rx) = mpsc::channel::<()>(1); - let (emit_result_tx, emit_result_rx) = - mpsc::channel::(1); - - Self::start_output_listener( - app_handle.clone(), - config_hash.clone(), - cache.clone(), - emit_result_rx, - ); - - Self::start_provider( - config, - config_hash, - shared_state, - emit_result_tx.clone(), - stop_rx, - )?; - - Ok(Self { - cache, - emit_result_tx, - stop_tx, - }) - } - - fn start_output_listener( - app_handle: AppHandle, - config_hash: String, - cache: Arc>>>, - mut emit_result_rx: mpsc::Receiver, - ) { - task::spawn(async move { - while let Some(output) = emit_result_rx.recv().await { - info!("Emitting for provider: {}", config_hash); - - let output = Box::new(output); - let payload = json!({ - "configHash": config_hash.clone(), - "result": *output.clone(), - }); - - if let Err(err) = app_handle.emit("provider-emit", payload) { - warn!("Error emitting provider output: {:?}", err); - } - - // Update the provider's output cache. - if let Ok(mut providers) = cache.try_lock() { - *providers = Some(output); - } else { - warn!("Failed to update provider output cache."); - } - } - }); - } - - /// Starts the provider in a separate task. - fn start_provider( - config: ProviderConfig, - config_hash: String, - shared_state: SharedProviderState, - emit_result_tx: mpsc::Sender, - mut stop_rx: mpsc::Receiver<()>, - ) -> anyhow::Result<()> { - let provider = Self::create_provider(config, shared_state)?; - - task::spawn(async move { - // TODO: Add arc `should_stop` to be passed to `run`. - - let run = provider.run(emit_result_tx); - tokio::pin!(run); - - // Ref: https://tokio.rs/tokio/tutorial/select#resuming-an-async-operation - loop { - tokio::select! { - // Default match arm which continuously runs the provider. - _ = run => break, - - // On stop, perform any necessary clean up and exit the loop. - Some(_) = stop_rx.recv() => { - info!("Stopping provider: {}", config_hash); - _ = provider.on_stop().await; - break; - }, - } - } - - info!("Provider stopped: {}", config_hash); - }); - - Ok(()) - } - - fn create_provider( - config: ProviderConfig, - shared_state: SharedProviderState, - ) -> anyhow::Result> { - let provider: Box = match config { - #[cfg(windows)] - ProviderConfig::Audio(config) => { - Box::new(AudioProvider::new(config)) - } - ProviderConfig::Battery(config) => { - Box::new(BatteryProvider::new(config)) - } - ProviderConfig::Cpu(config) => { - Box::new(CpuProvider::new(config, shared_state.sysinfo.clone())) - } - ProviderConfig::Host(config) => Box::new(HostProvider::new(config)), - ProviderConfig::Ip(config) => Box::new(IpProvider::new(config)), - #[cfg(windows)] - ProviderConfig::Komorebi(config) => { - Box::new(KomorebiProvider::new(config)) - } - #[cfg(windows)] - ProviderConfig::Media(config) => { - Box::new(MediaProvider::new(config)) - } - ProviderConfig::Memory(config) => { - Box::new(MemoryProvider::new(config, shared_state.sysinfo.clone())) - } - ProviderConfig::Disk(config) => { - Box::new(DiskProvider::new(config, shared_state.diskinfo.clone())) - } - ProviderConfig::Network(config) => Box::new(NetworkProvider::new( - config, - shared_state.netinfo.clone(), - )), - ProviderConfig::Weather(config) => { - Box::new(WeatherProvider::new(config)) - } - #[cfg(windows)] - ProviderConfig::Keyboard(config) => { - Box::new(KeyboardProvider::new(config)) - } - #[allow(unreachable_patterns)] - _ => bail!("Provider not supported on this operating system."), - }; - - Ok(provider) - } - - /// Re-emits the latest provider output. - /// - /// No-ops if the provider hasn't outputted yet, since the provider will - /// anyways emit its output after initialization. - pub async fn refresh(&self) -> anyhow::Result<()> { - let cache = { self.cache.lock().await.clone() }; - - if let Some(cache) = cache { - self.emit_result_tx.send(*cache).await?; - } - - Ok(()) - } - - /// Stops the given provider. - /// - /// This triggers any necessary cleanup. - pub async fn stop(&self) -> anyhow::Result<()> { - self.stop_tx.send(()).await?; - - Ok(()) - } -} diff --git a/packages/desktop/src/providers/weather/weather_provider.rs b/packages/desktop/src/providers/weather/weather_provider.rs index 8b57d300..22ed5692 100644 --- a/packages/desktop/src/providers/weather/weather_provider.rs +++ b/packages/desktop/src/providers/weather/weather_provider.rs @@ -1,12 +1,13 @@ +use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use super::open_meteo_res::OpenMeteoRes; use crate::{ - impl_interval_provider, + common::AsyncInterval, providers::{ - ip::{IpProvider, IpProviderConfig}, - ProviderOutput, + ip::IpProvider, CommonProviderState, Provider, ProviderInputMsg, + RuntimeType, }, }; @@ -47,38 +48,29 @@ pub enum WeatherStatus { pub struct WeatherProvider { config: WeatherProviderConfig, + common: CommonProviderState, http_client: Client, } impl WeatherProvider { - pub fn new(config: WeatherProviderConfig) -> WeatherProvider { + pub fn new( + config: WeatherProviderConfig, + common: CommonProviderState, + ) -> WeatherProvider { WeatherProvider { config, + common, http_client: Client::new(), } } - fn refresh_interval_ms(&self) -> u64 { - self.config.refresh_interval - } - - async fn run_interval(&self) -> anyhow::Result { + async fn run_interval(&self) -> anyhow::Result { let (latitude, longitude) = { match (self.config.latitude, self.config.longitude) { (Some(lat), Some(lon)) => (lat, lon), _ => { - let ip_output = IpProvider::new(IpProviderConfig { - refresh_interval: 0, - }) - .run_interval() - .await?; - - match ip_output { - ProviderOutput::Ip(ip_output) => { - (ip_output.approx_latitude, ip_output.approx_longitude) - } - _ => anyhow::bail!("Unexpected output from IP provider."), - } + let ip_output = IpProvider::query_ip(&self.http_client).await?; + (ip_output.approx_latitude, ip_output.approx_longitude) } } }; @@ -102,7 +94,7 @@ impl WeatherProvider { let current_weather = res.current_weather; let is_daytime = current_weather.is_day == 1; - Ok(ProviderOutput::Weather(WeatherOutput { + Ok(WeatherOutput { is_daytime, status: Self::get_weather_status( current_weather.weather_code, @@ -113,7 +105,7 @@ impl WeatherProvider { current_weather.temperature, ), wind_speed: current_weather.wind_speed, - })) + }) } fn celsius_to_fahrenheit(celsius_temp: f32) -> f32 { @@ -159,4 +151,27 @@ impl WeatherProvider { } } -impl_interval_provider!(WeatherProvider, true); +#[async_trait] +impl Provider for WeatherProvider { + fn runtime_type(&self) -> RuntimeType { + RuntimeType::Async + } + + async fn start_async(&mut self) { + let mut interval = AsyncInterval::new(self.config.refresh_interval); + + loop { + tokio::select! { + _ = interval.tick() => { + let output = self.run_interval().await; + self.common.emitter.emit_output(output); + } + Some(message) = self.common.input.async_rx.recv() => { + if let ProviderInputMsg::Stop = message { + break; + } + } + } + } + } +} diff --git a/packages/desktop/src/sys_tray.rs b/packages/desktop/src/sys_tray.rs index 8e2ec538..d83d580b 100644 --- a/packages/desktop/src/sys_tray.rs +++ b/packages/desktop/src/sys_tray.rs @@ -116,7 +116,7 @@ impl SysTray { app_handle: &AppHandle, config: Arc, widget_factory: Arc, - ) -> anyhow::Result> { + ) -> anyhow::Result { let mut sys_tray = Self { app_handle: app_handle.clone(), config, @@ -126,7 +126,7 @@ impl SysTray { sys_tray.tray_icon = Some(sys_tray.create_tray_icon().await?); - Ok(Arc::new(sys_tray)) + Ok(sys_tray) } async fn create_tray_icon(&self) -> anyhow::Result {