Skip to content

Commit

Permalink
m: Use better strategy for main thread detection
Browse files Browse the repository at this point in the history
We can get a list of the threads in the process, and determine which
thread came first. This thread will be the one who called the main
function. So use this instead of the current strategy if it's available.

Signed-off-by: John Nunley <[email protected]>
  • Loading branch information
notgull committed Nov 23, 2024
1 parent fc6cf89 commit 9c9dde9
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ windows-sys = { version = "0.52.0", features = [
"Win32_Media",
"Win32_System_Com_StructuredStorage",
"Win32_System_Com",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Ole",
"Win32_Security",
Expand Down
5 changes: 4 additions & 1 deletion src/platform_impl/linux/wayland/seat/pointer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use sctk::seat::pointer::{
use sctk::seat::SeatState;

use crate::dpi::{LogicalPosition, PhysicalPosition};
use crate::event::{ElementState, MouseButton, MouseScrollDelta, PointerSource, PointerKind, TouchPhase, WindowEvent};
use crate::event::{
ElementState, MouseButton, MouseScrollDelta, PointerKind, PointerSource, TouchPhase,
WindowEvent,
};

use crate::platform_impl::wayland::state::WinitState;
use crate::platform_impl::wayland::{self, WindowId};
Expand Down
147 changes: 144 additions & 3 deletions src/platform_impl/windows/event_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use windows_sys::Win32::Graphics::Gdi::{
GetMonitorInfoW, MonitorFromRect, MonitorFromWindow, RedrawWindow, ScreenToClient,
ValidateRect, MONITORINFO, MONITOR_DEFAULTTONULL, RDW_INTERNALPAINT, SC_SCREENSAVE,
};
use windows_sys::Win32::System::Diagnostics::ToolHelp as toolhelp;
use windows_sys::Win32::System::Ole::RevokeDragDrop;
use windows_sys::Win32::System::Threading::{
CreateWaitableTimerExW, GetCurrentThreadId, SetWaitableTimer,
CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, TIMER_ALL_ACCESS,
CreateWaitableTimerExW, GetCurrentProcessId, GetCurrentThreadId, GetThreadTimes, OpenThread,
SetWaitableTimer, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, THREAD_QUERY_INFORMATION,
TIMER_ALL_ACCESS,
};
use windows_sys::Win32::UI::Controls::{HOVER_DEFAULT, WM_MOUSELEAVE};
use windows_sys::Win32::UI::Input::Ime::{GCS_COMPSTR, GCS_RESULTSTR, ISC_SHOWUICOMPOSITIONWINDOW};
Expand Down Expand Up @@ -542,6 +544,37 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle {
}
}

/// Get the main thread ID.
fn main_thread_id() -> u32 {
main_thread_id_via_snapshot().unwrap_or_else(main_thread_id_via_crt)
}

/// Get the main thread ID via the snapshot feature.
fn main_thread_id_via_snapshot() -> Option<u32> {
// Get the current process ID.
let process_id = unsafe { GetCurrentProcessId() };

// Take a snapshot of the process.
let snapshot = match ToolhelpSnapshot::new(process_id) {
Ok(snapshot) => snapshot,
Err(err) => {
tracing::error!("failed to take snapshot of process: {err}");
return None;
},
};

// Filter to threads owned by this process.
let threads = snapshot.filter(|ti| ti.process_id == process_id);

// Get the time that the thread was created.
let threadtimes = threads
.filter_map(|ti| ti.thread_time().map(move |time| (ti, time)))
.filter(|(_, time)| *time != 0);

// Identify the thread with the earliest time, since that must be the thread that called main().
threadtimes.min_by_key(|(_, time)| *time).map(|(ti, _)| ti.thread_id)
}

/// Returns the id of the main thread.
///
/// Windows has no real API to check if the current executing thread is the "main thread", unlike
Expand All @@ -560,7 +593,11 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle {
///
/// Full details of CRT initialization can be found here:
/// <https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-initialization?view=msvc-160>
fn main_thread_id() -> u32 {
///
/// notgull addendum: The above comment is a lie, we can get the total list of threads in the
/// process and figure out which thread came first. We try to use that strategy first, and use this
/// strategy as a fallback.
fn main_thread_id_via_crt() -> u32 {
static mut MAIN_THREAD_ID: u32 = 0;

/// Function pointer used in CRT initialization section to set the above static field's value.
Expand All @@ -583,6 +620,110 @@ fn main_thread_id() -> u32 {
unsafe { MAIN_THREAD_ID }
}

/// A screenshot from the toolhelp tool.
struct ToolhelpSnapshot {
/// The handle to the snapshot.
handle: OwnedHandle,

/// Are we returning the first thread entry?
first_entry: bool,
}

impl ToolhelpSnapshot {
/// Take a screenshot of the provided process.
fn new(process_id: u32) -> Result<Self, EventLoopError> {
// Take a snapshot.
let handle =
unsafe { toolhelp::CreateToolhelp32Snapshot(toolhelp::TH32CS_SNAPTHREAD, process_id) };

if handle == 0 {
return Err(EventLoopError::Os(os_error!("failed to get toolhelp snapshot")));
}

Ok(Self { handle: unsafe { OwnedHandle::from_raw_handle(handle as _) }, first_entry: true })
}
}

impl Iterator for ToolhelpSnapshot {
type Item = ThreadEntry;

fn next(&mut self) -> Option<Self::Item> {
let mut slot = mem::MaybeUninit::uninit();

// Write the size to the slot.
unsafe {
let slot: *mut toolhelp::THREADENTRY32 = slot.as_mut_ptr();
let size = ptr::addr_of_mut!((*slot).dwSize);
size.write(mem::size_of::<toolhelp::THREADENTRY32>() as _);
}

let result = if self.first_entry {
self.first_entry = false;
unsafe { toolhelp::Thread32First(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) }
} else {
unsafe { toolhelp::Thread32Next(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) }
};

if result != 0 {
let thread_entry = unsafe { slot.assume_init() };
Some(ThreadEntry {
thread_id: thread_entry.th32ThreadID,
process_id: thread_entry.th32OwnerProcessID,
})
} else {
None
}
}
}

#[derive(Copy, Clone)]
struct ThreadEntry {
/// The thread ID.
thread_id: u32,

/// The owner process ID.
process_id: u32,
}

impl ThreadEntry {
/// Get the time this thread was created.
fn thread_time(self) -> Option<u64> {
// Open a thread id.
let thread_handle = unsafe {
let handle = OpenThread(THREAD_QUERY_INFORMATION, 1, self.thread_id);

if handle == 0 {
return None;
}

OwnedHandle::from_raw_handle(handle as _)
};

// Get the time this thread was opened.
let file_time = unsafe {
let mut file_time = [mem::MaybeUninit::uninit(); 4];
let result = GetThreadTimes(
thread_handle.as_raw_handle() as _,
file_time[0].as_mut_ptr(),
file_time[1].as_mut_ptr(),
file_time[2].as_mut_ptr(),
file_time[3].as_mut_ptr(),
);

if result == 0 {
return None;
}

file_time[0].assume_init()
};

Some(
((file_time.dwHighDateTime as u64) << 32)
| ((file_time.dwLowDateTime as u64) & 0xffffffff),
)
}
}

/// Returns the minimum `Option<Duration>`, taking into account that `None`
/// equates to an infinite timeout, not a zero timeout (so can't just use
/// `Option::min`)
Expand Down

0 comments on commit 9c9dde9

Please sign in to comment.