Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-berger committed Nov 28, 2024
1 parent 8011a6d commit 8fe63ee
Showing 1 changed file with 145 additions and 126 deletions.
271 changes: 145 additions & 126 deletions packages/desktop/src/providers/audio/audio_provider.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
ops::Mul,
sync::{Arc, Mutex, OnceLock},
time::Duration,
Expand Down Expand Up @@ -89,12 +89,19 @@ enum AudioEvent {
VolumeChanged(String, f32),
}

// Main provider implementation
/// Holds the state of an audio device.
#[derive(Clone)]
struct DeviceState {
device: AudioDevice,
volume_callback: IAudioEndpointVolume,
}

pub struct AudioProvider {
common: CommonProviderState,
enumerator: Option<IMMDeviceEnumerator>,
device_volumes: HashMap<String, IAudioEndpointVolume>,
current_state: AudioOutput,
default_playback_id: Option<String>,
default_recording_id: Option<String>,
devices: HashMap<String, DeviceState>,
event_sender: channel::Sender<AudioEvent>,
event_receiver: channel::Receiver<AudioEvent>,
}
Expand All @@ -109,18 +116,45 @@ impl AudioProvider {
Self {
common,
enumerator: None,
device_volumes: HashMap::new(),
current_state: AudioOutput {
playback_devices: Vec::new(),
recording_devices: Vec::new(),
default_playback_device: None,
default_recording_device: None,
},
default_playback_id: None,
default_recording_id: None,
devices: HashMap::new(),
event_sender,
event_receiver,
}
}

fn create_audio_manager(&mut self) -> anyhow::Result<()> {
unsafe {
let _ = CoInitializeEx(None, COINIT_MULTITHREADED);

let enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?;

// Register device callback
let device_callback = DeviceCallback {
event_sender: self.event_sender.clone(),
};
enumerator.RegisterEndpointNotificationCallback(
&IMMNotificationClient::from(device_callback),
)?;

self.enumerator = Some(enumerator);

// Initial state update
self.update_device_state()?;

// Event loop
while let Ok(event) = self.event_receiver.recv() {
if let Err(e) = self.handle_event(event) {
debug!("Error handling audio event: {}", e);
}
}

Ok(())
}
}

fn get_device_properties(
&self,
device: &IMMDevice,
Expand Down Expand Up @@ -156,152 +190,137 @@ impl AudioProvider {
}
}

fn build_output(&self) -> AudioOutput {
let mut playback_devices = Vec::new();
let mut recording_devices = Vec::new();
let mut default_playback_device = None;
let mut default_recording_device = None;

for (id, state) in &self.devices {
match &state.device.device_type {
DeviceType::Playback => {
if Some(id) == self.default_playback_id.as_ref() {
default_playback_device = Some(state.device.clone());
}
playback_devices.push(state.device.clone());
}
DeviceType::Recording => {
if Some(id) == self.default_recording_id.as_ref() {
default_recording_device = Some(state.device.clone());
}
recording_devices.push(state.device.clone());
}
}
}

// Sort devices by name for consistent ordering.
playback_devices.sort_by(|a, b| a.name.cmp(&b.name));
recording_devices.sort_by(|a, b| a.name.cmp(&b.name));

AudioOutput {
playback_devices,
recording_devices,
default_playback_device,
default_recording_device,
}
}

fn update_device_state(&mut self) -> anyhow::Result<()> {
let enumerator = self
.enumerator
.as_ref()
.context("Enumerator not initialized")?;
let mut active_devices = HashSet::new();

unsafe {
let collection =
enumerator.EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE)?;
let default_device = enumerator
.GetDefaultAudioEndpoint(eRender, eMultimedia)
.ok();
// Process both playback and recording devices
for flow in [eRender, eCapture] {
let devices = self.enumerate_devices(flow)?;
let default_device = self.get_default_device(flow).ok();
let default_id = default_device
.as_ref()
.and_then(|d| d.GetId().ok())
.and_then(|id| id.to_string().ok());

let mut new_devices = Vec::new();
let mut new_volumes = HashMap::new();

for i in 0..collection.GetCount()? {
if let Ok(device) = collection.Item(i) {
let (device_id, name) = self.get_device_properties(&device)?;

// Register/get volume interface
let endpoint_volume =
if let Some(existing) = self.device_volumes.get(&device_id) {
existing.clone()
} else {
self.register_volume_callback(&device, device_id.clone())?
};

let volume =
endpoint_volume.GetMasterVolumeLevelScalar()? * 100.0;
let is_default =
default_id.as_ref().map_or(false, |id| *id == device_id);

new_volumes.insert(device_id.clone(), endpoint_volume);

let device_info = AudioDevice {
name,
device_id,
volume: volume.round() as u32,
is_default,
.and_then(|d| unsafe { d.GetId().ok() })
.and_then(|id| unsafe { id.to_string().ok() });

// Update default device IDs
match flow {
e if e == eRender => self.default_playback_id = default_id.clone(),
e if e == eCapture => self.default_recording_id = default_id,
_ => {}
}

for device in devices {
let (device_id, _) = self.get_device_info(&device, flow)?;
active_devices.insert(device_id.clone());

let endpoint_volume =
if let Some(state) = self.devices.get(&device_id) {
state.volume_callback.clone()
} else {
self.register_volume_callback(&device, device_id.clone())?
};

if is_default {
self.current_state.default_playback_device =
Some(device_info.clone());
let is_default = match flow {
e if e == eRender => {
Some(&device_id) == self.default_playback_id.as_ref()
}
new_devices.push(device_info);
}
e if e == eCapture => {
Some(&device_id) == self.default_recording_id.as_ref()
}
_ => false,
};

let device_info = self.create_audio_device(
&device,
flow,
is_default,
&endpoint_volume,
)?;

self.devices.insert(
device_id,
DeviceState {
device: device_info,
volume_callback: endpoint_volume,
},
);
}

self.current_state.playback_devices = new_devices;
self.device_volumes = new_volumes;

self
.common
.emitter
.emit_output(Ok(self.current_state.clone()));
}

// Remove devices that are no longer active
self.devices.retain(|id, _| active_devices.contains(id));

// Emit updated state
self.common.emitter.emit_output(Ok(self.build_output()));
Ok(())
}

fn handle_event(&mut self, event: AudioEvent) -> anyhow::Result<()> {
match event {
AudioEvent::DeviceAdded(_)
| AudioEvent::DeviceRemoved(_)
| AudioEvent::DeviceStateChanged(_, _)
| AudioEvent::DefaultDeviceChanged(_) => {
AudioEvent::DeviceAdded(_, _)
| AudioEvent::DeviceRemoved(_, _)
| AudioEvent::DeviceStateChanged(_, _, _)
| AudioEvent::DefaultDeviceChanged(_, _) => {
self.update_device_state()?;
}
AudioEvent::VolumeChanged(device_id, new_volume) => {
let volume = (new_volume * 100.0).round() as u32;

// Update volume in current state
for device in &mut self.current_state.playback_devices {
if device.device_id == device_id {
device.volume = volume;
if let Some(default_device) =
&mut self.current_state.default_playback_device
{
if default_device.device_id == device_id {
default_device.volume = volume;
}
}
break;
}
if let Some(state) = self.devices.get_mut(&device_id) {
state.device.volume = (new_volume * 100.0).round() as u32;
self.common.emitter.emit_output(Ok(self.build_output()));
}

self
.common
.emitter
.emit_output(Ok(self.current_state.clone()));
}
}

Ok(())
}

fn create_audio_manager(&mut self) -> anyhow::Result<()> {
unsafe {
let _ = CoInitializeEx(None, COINIT_MULTITHREADED);

let enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)?;

// Register device callback
let device_callback = DeviceCallback {
event_sender: self.event_sender.clone(),
};
enumerator.RegisterEndpointNotificationCallback(
&IMMNotificationClient::from(device_callback),
)?;

self.enumerator = Some(enumerator);

// Initial state update
self.update_device_state()?;

// Event loop
while let Ok(event) = self.event_receiver.recv() {
if let Err(e) = self.handle_event(event) {
debug!("Error handling audio event: {}", e);
}
}

Ok(())
}
}
}

impl Drop for AudioProvider {
fn drop(&mut self) {
// Clean up volume callbacks
for volume in self.device_volumes.values() {
// Deregister volume callbacks.
for state in self.devices.values() {
unsafe {
let _ = volume.UnregisterControlChangeNotify(
&IAudioEndpointVolumeCallback::null(),
let _ = state.volume_callback.UnregisterControlChangeNotify(
&IAudioEndpointVolumeCallback::from(&state.volume_callback),
);
}
}

// Clean up device notification callback
// Deregister device notification callback.
if let Some(enumerator) = &self.enumerator {
unsafe {
let _ = enumerator.UnregisterEndpointNotificationCallback(
Expand Down Expand Up @@ -409,7 +428,7 @@ impl IMMNotificationClient_Impl for DeviceCallback_Impl {

fn OnPropertyValueChanged(
&self,
_pwstrDeviceId: &windows::core::PCWSTR,
_device_id: &windows::core::PCWSTR,
_key: &windows::Win32::UI::Shell::PropertiesSystem::PROPERTYKEY,
) -> windows::core::Result<()> {
Ok(())
Expand Down

0 comments on commit 8fe63ee

Please sign in to comment.