diff --git a/Cargo.toml b/Cargo.toml index bbd6f8d1a1..caead0a5f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -212,6 +212,7 @@ windows-sys = { version = "0.52.0", features = [ "Win32_Media", "Win32_System_Com_StructuredStorage", "Win32_System_Com", + "Win32_System_Diagnostics_ToolHelp", "Win32_System_LibraryLoader", "Win32_System_Ole", "Win32_Security", diff --git a/src/platform_impl/linux/wayland/seat/pointer/mod.rs b/src/platform_impl/linux/wayland/seat/pointer/mod.rs index 944918d3a3..8f1f74d40b 100644 --- a/src/platform_impl/linux/wayland/seat/pointer/mod.rs +++ b/src/platform_impl/linux/wayland/seat/pointer/mod.rs @@ -27,7 +27,10 @@ use sctk::seat::pointer::{ use sctk::seat::SeatState; use crate::dpi::{LogicalPosition, PhysicalPosition}; -use crate::event::{ElementState, MouseButton, MouseScrollDelta, PointerSource, PointerKind, TouchPhase, WindowEvent}; +use crate::event::{ + ElementState, MouseButton, MouseScrollDelta, PointerKind, PointerSource, TouchPhase, + WindowEvent, +}; use crate::platform_impl::wayland::state::WinitState; use crate::platform_impl::wayland::{self, WindowId}; diff --git a/src/platform_impl/windows/event_loop.rs b/src/platform_impl/windows/event_loop.rs index 874ae06a8c..812fc4e150 100644 --- a/src/platform_impl/windows/event_loop.rs +++ b/src/platform_impl/windows/event_loop.rs @@ -20,10 +20,12 @@ use windows_sys::Win32::Graphics::Gdi::{ GetMonitorInfoW, MonitorFromRect, MonitorFromWindow, RedrawWindow, ScreenToClient, ValidateRect, MONITORINFO, MONITOR_DEFAULTTONULL, RDW_INTERNALPAINT, SC_SCREENSAVE, }; +use windows_sys::Win32::System::Diagnostics::ToolHelp as toolhelp; use windows_sys::Win32::System::Ole::RevokeDragDrop; use windows_sys::Win32::System::Threading::{ - CreateWaitableTimerExW, GetCurrentThreadId, SetWaitableTimer, - CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, TIMER_ALL_ACCESS, + CreateWaitableTimerExW, GetCurrentProcessId, GetCurrentThreadId, GetThreadTimes, OpenThread, + SetWaitableTimer, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, THREAD_QUERY_INFORMATION, + TIMER_ALL_ACCESS, }; use windows_sys::Win32::UI::Controls::{HOVER_DEFAULT, WM_MOUSELEAVE}; use windows_sys::Win32::UI::Input::Ime::{GCS_COMPSTR, GCS_RESULTSTR, ISC_SHOWUICOMPOSITIONWINDOW}; @@ -542,6 +544,37 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle { } } +/// Get the main thread ID. +fn main_thread_id() -> u32 { + main_thread_id_via_snapshot().unwrap_or_else(main_thread_id_via_crt) +} + +/// Get the main thread ID via the snapshot feature. +fn main_thread_id_via_snapshot() -> Option { + // Get the current process ID. + let process_id = unsafe { GetCurrentProcessId() }; + + // Take a snapshot of the process. + let snapshot = match ToolhelpSnapshot::new(process_id) { + Ok(snapshot) => snapshot, + Err(err) => { + tracing::error!("failed to take snapshot of process: {err}"); + return None; + }, + }; + + // Filter to threads owned by this process. + let threads = snapshot.filter(|ti| ti.process_id == process_id); + + // Get the time that the thread was created. + let threadtimes = threads + .filter_map(|ti| ti.thread_time().map(move |time| (ti, time))) + .filter(|(_, time)| *time == 0); + + // Identify the thread with the earliest time, since that must be the thread that called main(). + threadtimes.min_by_key(|(_, time)| *time).map(|(ti, _)| ti.thread_id) +} + /// Returns the id of the main thread. /// /// Windows has no real API to check if the current executing thread is the "main thread", unlike @@ -560,7 +593,11 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle { /// /// Full details of CRT initialization can be found here: /// -fn main_thread_id() -> u32 { +/// +/// notgull addendum: The above comment is a lie, we can get the total list of threads in the +/// process and figure out which thread came first. We try to use that strategy first, and use this +/// strategy as a fallback. +fn main_thread_id_via_crt() -> u32 { static mut MAIN_THREAD_ID: u32 = 0; /// Function pointer used in CRT initialization section to set the above static field's value. @@ -583,6 +620,110 @@ fn main_thread_id() -> u32 { unsafe { MAIN_THREAD_ID } } +/// A screenshot from the toolhelp tool. +struct ToolhelpSnapshot { + /// The handle to the snapshot. + handle: OwnedHandle, + + /// Are we returning the first thread entry? + first_entry: bool, +} + +impl ToolhelpSnapshot { + /// Take a screenshot of the provided process. + fn new(process_id: u32) -> Result { + // Take a snapshot. + let handle = + unsafe { toolhelp::CreateToolhelp32Snapshot(toolhelp::TH32CS_SNAPTHREAD, process_id) }; + + if handle == 0 { + return Err(EventLoopError::Os(os_error!("failed to get toolhelp snapshot"))); + } + + Ok(Self { handle: unsafe { OwnedHandle::from_raw_handle(handle as _) }, first_entry: true }) + } +} + +impl Iterator for ToolhelpSnapshot { + type Item = ThreadEntry; + + fn next(&mut self) -> Option { + let mut slot = mem::MaybeUninit::uninit(); + + // Write the size to the slot. + unsafe { + let slot: *mut toolhelp::THREADENTRY32 = slot.as_mut_ptr(); + let size = ptr::addr_of_mut!((*slot).dwSize); + size.write(mem::size_of::() as _); + } + + let result = if self.first_entry { + self.first_entry = false; + unsafe { toolhelp::Thread32First(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) } + } else { + unsafe { toolhelp::Thread32Next(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) } + }; + + if result != 0 { + let thread_entry = unsafe { slot.assume_init() }; + Some(ThreadEntry { + thread_id: thread_entry.th32ThreadID, + process_id: thread_entry.th32OwnerProcessID, + }) + } else { + None + } + } +} + +#[derive(Copy, Clone)] +struct ThreadEntry { + /// The thread ID. + thread_id: u32, + + /// The owner process ID. + process_id: u32, +} + +impl ThreadEntry { + /// Get the time this thread was created. + fn thread_time(self) -> Option { + // Open a thread id. + let thread_handle = unsafe { + let handle = OpenThread(THREAD_QUERY_INFORMATION, 1, self.thread_id); + + if handle == 0 { + return None; + } + + OwnedHandle::from_raw_handle(handle as _) + }; + + // Get the time this thread was opened. + let file_time = unsafe { + let mut file_time = [mem::MaybeUninit::uninit(); 4]; + let result = GetThreadTimes( + thread_handle.as_raw_handle() as _, + file_time[0].as_mut_ptr(), + file_time[1].as_mut_ptr(), + file_time[2].as_mut_ptr(), + file_time[3].as_mut_ptr(), + ); + + if result == 0 { + return None; + } + + file_time[0].assume_init() + }; + + Some( + ((file_time.dwHighDateTime as u64) << 32) + | ((file_time.dwLowDateTime as u64) & 0xffffffff), + ) + } +} + /// Returns the minimum `Option`, taking into account that `None` /// equates to an infinite timeout, not a zero timeout (so can't just use /// `Option::min`)