From 9c9dde9d2edda92125b4bf0fa35585fd5d8f7e23 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 22 Nov 2024 19:38:46 -0800 Subject: [PATCH] m: Use better strategy for main thread detection We can get a list of the threads in the process, and determine which thread came first. This thread will be the one who called the main function. So use this instead of the current strategy if it's available. Signed-off-by: John Nunley --- Cargo.toml | 1 + .../linux/wayland/seat/pointer/mod.rs | 5 +- src/platform_impl/windows/event_loop.rs | 147 +++++++++++++++++- 3 files changed, 149 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bbd6f8d1a1..caead0a5f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -212,6 +212,7 @@ windows-sys = { version = "0.52.0", features = [ "Win32_Media", "Win32_System_Com_StructuredStorage", "Win32_System_Com", + "Win32_System_Diagnostics_ToolHelp", "Win32_System_LibraryLoader", "Win32_System_Ole", "Win32_Security", diff --git a/src/platform_impl/linux/wayland/seat/pointer/mod.rs b/src/platform_impl/linux/wayland/seat/pointer/mod.rs index 944918d3a3..8f1f74d40b 100644 --- a/src/platform_impl/linux/wayland/seat/pointer/mod.rs +++ b/src/platform_impl/linux/wayland/seat/pointer/mod.rs @@ -27,7 +27,10 @@ use sctk::seat::pointer::{ use sctk::seat::SeatState; use crate::dpi::{LogicalPosition, PhysicalPosition}; -use crate::event::{ElementState, MouseButton, MouseScrollDelta, PointerSource, PointerKind, TouchPhase, WindowEvent}; +use crate::event::{ + ElementState, MouseButton, MouseScrollDelta, PointerKind, PointerSource, TouchPhase, + WindowEvent, +}; use crate::platform_impl::wayland::state::WinitState; use crate::platform_impl::wayland::{self, WindowId}; diff --git a/src/platform_impl/windows/event_loop.rs b/src/platform_impl/windows/event_loop.rs index 874ae06a8c..ff700d83a4 100644 --- a/src/platform_impl/windows/event_loop.rs +++ b/src/platform_impl/windows/event_loop.rs @@ -20,10 +20,12 @@ use windows_sys::Win32::Graphics::Gdi::{ GetMonitorInfoW, MonitorFromRect, MonitorFromWindow, RedrawWindow, ScreenToClient, ValidateRect, MONITORINFO, MONITOR_DEFAULTTONULL, RDW_INTERNALPAINT, SC_SCREENSAVE, }; +use windows_sys::Win32::System::Diagnostics::ToolHelp as toolhelp; use windows_sys::Win32::System::Ole::RevokeDragDrop; use windows_sys::Win32::System::Threading::{ - CreateWaitableTimerExW, GetCurrentThreadId, SetWaitableTimer, - CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, TIMER_ALL_ACCESS, + CreateWaitableTimerExW, GetCurrentProcessId, GetCurrentThreadId, GetThreadTimes, OpenThread, + SetWaitableTimer, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, THREAD_QUERY_INFORMATION, + TIMER_ALL_ACCESS, }; use windows_sys::Win32::UI::Controls::{HOVER_DEFAULT, WM_MOUSELEAVE}; use windows_sys::Win32::UI::Input::Ime::{GCS_COMPSTR, GCS_RESULTSTR, ISC_SHOWUICOMPOSITIONWINDOW}; @@ -542,6 +544,37 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle { } } +/// Get the main thread ID. +fn main_thread_id() -> u32 { + main_thread_id_via_snapshot().unwrap_or_else(main_thread_id_via_crt) +} + +/// Get the main thread ID via the snapshot feature. +fn main_thread_id_via_snapshot() -> Option { + // Get the current process ID. + let process_id = unsafe { GetCurrentProcessId() }; + + // Take a snapshot of the process. + let snapshot = match ToolhelpSnapshot::new(process_id) { + Ok(snapshot) => snapshot, + Err(err) => { + tracing::error!("failed to take snapshot of process: {err}"); + return None; + }, + }; + + // Filter to threads owned by this process. + let threads = snapshot.filter(|ti| ti.process_id == process_id); + + // Get the time that the thread was created. + let threadtimes = threads + .filter_map(|ti| ti.thread_time().map(move |time| (ti, time))) + .filter(|(_, time)| *time != 0); + + // Identify the thread with the earliest time, since that must be the thread that called main(). + threadtimes.min_by_key(|(_, time)| *time).map(|(ti, _)| ti.thread_id) +} + /// Returns the id of the main thread. /// /// Windows has no real API to check if the current executing thread is the "main thread", unlike @@ -560,7 +593,11 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle { /// /// Full details of CRT initialization can be found here: /// -fn main_thread_id() -> u32 { +/// +/// notgull addendum: The above comment is a lie, we can get the total list of threads in the +/// process and figure out which thread came first. We try to use that strategy first, and use this +/// strategy as a fallback. +fn main_thread_id_via_crt() -> u32 { static mut MAIN_THREAD_ID: u32 = 0; /// Function pointer used in CRT initialization section to set the above static field's value. @@ -583,6 +620,110 @@ fn main_thread_id() -> u32 { unsafe { MAIN_THREAD_ID } } +/// A screenshot from the toolhelp tool. +struct ToolhelpSnapshot { + /// The handle to the snapshot. + handle: OwnedHandle, + + /// Are we returning the first thread entry? + first_entry: bool, +} + +impl ToolhelpSnapshot { + /// Take a screenshot of the provided process. + fn new(process_id: u32) -> Result { + // Take a snapshot. + let handle = + unsafe { toolhelp::CreateToolhelp32Snapshot(toolhelp::TH32CS_SNAPTHREAD, process_id) }; + + if handle == 0 { + return Err(EventLoopError::Os(os_error!("failed to get toolhelp snapshot"))); + } + + Ok(Self { handle: unsafe { OwnedHandle::from_raw_handle(handle as _) }, first_entry: true }) + } +} + +impl Iterator for ToolhelpSnapshot { + type Item = ThreadEntry; + + fn next(&mut self) -> Option { + let mut slot = mem::MaybeUninit::uninit(); + + // Write the size to the slot. + unsafe { + let slot: *mut toolhelp::THREADENTRY32 = slot.as_mut_ptr(); + let size = ptr::addr_of_mut!((*slot).dwSize); + size.write(mem::size_of::() as _); + } + + let result = if self.first_entry { + self.first_entry = false; + unsafe { toolhelp::Thread32First(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) } + } else { + unsafe { toolhelp::Thread32Next(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) } + }; + + if result != 0 { + let thread_entry = unsafe { slot.assume_init() }; + Some(ThreadEntry { + thread_id: thread_entry.th32ThreadID, + process_id: thread_entry.th32OwnerProcessID, + }) + } else { + None + } + } +} + +#[derive(Copy, Clone)] +struct ThreadEntry { + /// The thread ID. + thread_id: u32, + + /// The owner process ID. + process_id: u32, +} + +impl ThreadEntry { + /// Get the time this thread was created. + fn thread_time(self) -> Option { + // Open a thread id. + let thread_handle = unsafe { + let handle = OpenThread(THREAD_QUERY_INFORMATION, 1, self.thread_id); + + if handle == 0 { + return None; + } + + OwnedHandle::from_raw_handle(handle as _) + }; + + // Get the time this thread was opened. + let file_time = unsafe { + let mut file_time = [mem::MaybeUninit::uninit(); 4]; + let result = GetThreadTimes( + thread_handle.as_raw_handle() as _, + file_time[0].as_mut_ptr(), + file_time[1].as_mut_ptr(), + file_time[2].as_mut_ptr(), + file_time[3].as_mut_ptr(), + ); + + if result == 0 { + return None; + } + + file_time[0].assume_init() + }; + + Some( + ((file_time.dwHighDateTime as u64) << 32) + | ((file_time.dwLowDateTime as u64) & 0xffffffff), + ) + } +} + /// Returns the minimum `Option`, taking into account that `None` /// equates to an infinite timeout, not a zero timeout (so can't just use /// `Option::min`)