Skip to content

Commit

Permalink
wip add call_function on provider manager
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-berger committed Nov 16, 2024
1 parent 81ee6ea commit 0821c1b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
14 changes: 3 additions & 11 deletions packages/desktop/src/providers/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ pub trait Provider: Send + Sync {
/// # Panics
///
/// Panics if wrong runtime type is used.
fn start_sync(
&mut self,
_emit_result_tx: mpsc::Sender<ProviderResult>,
_stop_rx: mpsc::Receiver<()>,
) {
fn start_sync(&mut self) {
match self.runtime_type() {
RuntimeType::Sync => {
unreachable!("Sync providers must implement `start_sync`.")
Expand All @@ -32,11 +28,7 @@ pub trait Provider: Send + Sync {
/// # Panics
///
/// Panics if wrong runtime type is used.
async fn start_async(
&mut self,
_emit_result_tx: mpsc::Sender<ProviderResult>,
_stop_rx: mpsc::Receiver<()>,
) {
async fn start_async(&mut self) {
match self.runtime_type() {
RuntimeType::Async => {
unreachable!("Async providers must implement `start_async`.")
Expand Down Expand Up @@ -88,7 +80,7 @@ pub trait Provider: Send + Sync {
}
}

/// Determines whether `run_sync` or `run_async` is called.`
/// Determines whether `start_sync` or `start_async` is called.
pub enum RuntimeType {
Sync,
Async,
Expand Down
38 changes: 29 additions & 9 deletions packages/desktop/src/providers/provider_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc};

use anyhow::Context;
use sysinfo::{Disks, Networks, System};
use tauri::AppHandle;
use tokio::sync::Mutex;
Expand All @@ -11,26 +12,22 @@ use super::{ProviderConfig, ProviderRef};
#[derive(Clone)]
pub struct SharedProviderState {
pub sysinfo: Arc<Mutex<System>>,
pub netinfo: Arc<Mutex<Networks>>,
pub diskinfo: Arc<Mutex<Disks>>,
}

/// Manages the creation and cleanup of providers.
pub struct ProviderManager {
app_handle: AppHandle,
providers: Arc<Mutex<HashMap<String, ProviderRef>>>,
provider_refs: Arc<Mutex<HashMap<String, ProviderRef>>>,
shared_state: SharedProviderState,
}

impl ProviderManager {
pub fn new(app_handle: &AppHandle) -> Self {
Self {
app_handle: app_handle.clone(),
providers: Arc::new(Mutex::new(HashMap::new())),
provider_refs: Arc::new(Mutex::new(HashMap::new())),
shared_state: SharedProviderState {
sysinfo: Arc::new(Mutex::new(System::new_all())),
netinfo: Arc::new(Mutex::new(Networks::new_with_refreshed_list())),
diskinfo: Arc::new(Mutex::new(Disks::new_with_refreshed_list())),
},
}
}
Expand All @@ -42,7 +39,7 @@ impl ProviderManager {
config: ProviderConfig,
) -> anyhow::Result<()> {
{
let mut providers = self.providers.lock().await;
let mut providers = self.provider_refs.lock().await;

// If a provider with the given config already exists, refresh it
// and return early.
Expand All @@ -63,15 +60,38 @@ impl ProviderManager {
)
.await?;

let mut providers = self.providers.lock().await;
let mut providers = self.provider_refs.lock().await;
providers.insert(config_hash, provider_ref);

Ok(())
}

/// Calls the given function on the provider with the given config hash.
async fn call_function(
&self,
config_hash: &str,
function: ProviderFunction,
) -> anyhow::Result<ProviderFunctionResult> {
let mut providers = self.provider_refs.lock().await;
let found_provider = providers
.get_mut(&config_hash)
.context("No provider found with config.")?;

match found_provider.runtime_type() {
RuntimeType::Async => {
found_provider.call_async_function(function).await
}
RuntimeType::Sync => task::spawn_blocking(move || {
found_provider.call_sync_function(function)
})
.await
.map_err(|err| format!("Function execution failed: {}", err))?,
}
}

/// Destroys and cleans up the provider with the given config.
pub async fn destroy(&self, config_hash: String) -> anyhow::Result<()> {
let mut providers = self.providers.lock().await;
let mut providers = self.provider_refs.lock().await;

if let Some(found_provider) = providers.get_mut(&config_hash) {
if let Err(err) = found_provider.stop().await {
Expand Down

0 comments on commit 0821c1b

Please sign in to comment.