Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a better strategy on Windows for main thread detection #4006

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ windows-sys = { version = "0.52.0", features = [
"Win32_Media",
"Win32_System_Com_StructuredStorage",
"Win32_System_Com",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Ole",
"Win32_Security",
Expand Down
5 changes: 4 additions & 1 deletion src/platform_impl/linux/wayland/seat/pointer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use sctk::seat::pointer::{
use sctk::seat::SeatState;

use crate::dpi::{LogicalPosition, PhysicalPosition};
use crate::event::{ElementState, MouseButton, MouseScrollDelta, PointerSource, PointerKind, TouchPhase, WindowEvent};
use crate::event::{
ElementState, MouseButton, MouseScrollDelta, PointerKind, PointerSource, TouchPhase,
WindowEvent,
};

use crate::platform_impl::wayland::state::WinitState;
use crate::platform_impl::wayland::{self, WindowId};
Expand Down
147 changes: 144 additions & 3 deletions src/platform_impl/windows/event_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use windows_sys::Win32::Graphics::Gdi::{
GetMonitorInfoW, MonitorFromRect, MonitorFromWindow, RedrawWindow, ScreenToClient,
ValidateRect, MONITORINFO, MONITOR_DEFAULTTONULL, RDW_INTERNALPAINT, SC_SCREENSAVE,
};
use windows_sys::Win32::System::Diagnostics::ToolHelp as toolhelp;
use windows_sys::Win32::System::Ole::RevokeDragDrop;
use windows_sys::Win32::System::Threading::{
CreateWaitableTimerExW, GetCurrentThreadId, SetWaitableTimer,
CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, TIMER_ALL_ACCESS,
CreateWaitableTimerExW, GetCurrentProcessId, GetCurrentThreadId, GetThreadTimes, OpenThread,
SetWaitableTimer, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, INFINITE, THREAD_QUERY_INFORMATION,
TIMER_ALL_ACCESS,
};
use windows_sys::Win32::UI::Controls::{HOVER_DEFAULT, WM_MOUSELEAVE};
use windows_sys::Win32::UI::Input::Ime::{GCS_COMPSTR, GCS_RESULTSTR, ISC_SHOWUICOMPOSITIONWINDOW};
Expand Down Expand Up @@ -542,6 +544,37 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle {
}
}

/// Get the main thread ID.
fn main_thread_id() -> u32 {
main_thread_id_via_snapshot().unwrap_or_else(main_thread_id_via_crt)
}

/// Get the main thread ID via the snapshot feature.
fn main_thread_id_via_snapshot() -> Option<u32> {
// Get the current process ID.
let process_id = unsafe { GetCurrentProcessId() };

// Take a snapshot of the process.
let snapshot = match ToolhelpSnapshot::new(process_id) {
Ok(snapshot) => snapshot,
Err(err) => {
tracing::error!("failed to take snapshot of process: {err}");
return None;
},
};

// Filter to threads owned by this process.
let threads = snapshot.filter(|ti| ti.process_id == process_id);

// Get the time that the thread was created.
let threadtimes = threads
.filter_map(|ti| ti.thread_time().map(move |time| (ti, time)))
.filter(|(_, time)| *time != 0);

// Identify the thread with the earliest time, since that must be the thread that called main().
threadtimes.min_by_key(|(_, time)| *time).map(|(ti, _)| ti.thread_id)
}

/// Returns the id of the main thread.
///
/// Windows has no real API to check if the current executing thread is the "main thread", unlike
Expand All @@ -560,7 +593,11 @@ impl rwh_06::HasDisplayHandle for OwnedDisplayHandle {
///
/// Full details of CRT initialization can be found here:
/// <https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-initialization?view=msvc-160>
fn main_thread_id() -> u32 {
///
/// notgull addendum: The above comment is a lie, we can get the total list of threads in the
/// process and figure out which thread came first. We try to use that strategy first, and use this
/// strategy as a fallback.
fn main_thread_id_via_crt() -> u32 {
static mut MAIN_THREAD_ID: u32 = 0;

/// Function pointer used in CRT initialization section to set the above static field's value.
Expand All @@ -583,6 +620,110 @@ fn main_thread_id() -> u32 {
unsafe { MAIN_THREAD_ID }
}

/// A screenshot from the toolhelp tool.
struct ToolhelpSnapshot {
/// The handle to the snapshot.
handle: OwnedHandle,

/// Are we returning the first thread entry?
first_entry: bool,
}

impl ToolhelpSnapshot {
/// Take a screenshot of the provided process.
fn new(process_id: u32) -> Result<Self, EventLoopError> {
// Take a snapshot.
let handle =
unsafe { toolhelp::CreateToolhelp32Snapshot(toolhelp::TH32CS_SNAPTHREAD, process_id) };
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can pass 0 to CreateToolhelp32Snapshot, no need to pass the current process ID.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the process ID in other places, so I figure why not.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the docs:

This parameter can be zero to indicate the current process.

So it's not really necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figure it's slightly clearer to whoever is reading who may not be familiar with this API.


if handle == 0 {
return Err(EventLoopError::Os(os_error!("failed to get toolhelp snapshot")));
}

Ok(Self { handle: unsafe { OwnedHandle::from_raw_handle(handle as _) }, first_entry: true })
}
}

impl Iterator for ToolhelpSnapshot {
type Item = ThreadEntry;

fn next(&mut self) -> Option<Self::Item> {
let mut slot = mem::MaybeUninit::uninit();

// Write the size to the slot.
unsafe {
let slot: *mut toolhelp::THREADENTRY32 = slot.as_mut_ptr();
let size = ptr::addr_of_mut!((*slot).dwSize);
size.write(mem::size_of::<toolhelp::THREADENTRY32>() as _);
}

let result = if self.first_entry {
self.first_entry = false;
unsafe { toolhelp::Thread32First(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) }
} else {
unsafe { toolhelp::Thread32Next(self.handle.as_raw_handle() as _, slot.as_mut_ptr()) }
};

if result != 0 {
let thread_entry = unsafe { slot.assume_init() };
Some(ThreadEntry {
thread_id: thread_entry.th32ThreadID,
process_id: thread_entry.th32OwnerProcessID,
})
} else {
None
}
}
}

#[derive(Copy, Clone)]
struct ThreadEntry {
/// The thread ID.
thread_id: u32,

/// The owner process ID.
process_id: u32,
}

impl ThreadEntry {
/// Get the time this thread was created.
fn thread_time(self) -> Option<u64> {
// Open a thread id.
let thread_handle = unsafe {
let handle = OpenThread(THREAD_QUERY_INFORMATION, 1, self.thread_id);

if handle == 0 {
return None;
}

OwnedHandle::from_raw_handle(handle as _)
};

// Get the time this thread was opened.
let file_time = unsafe {
let mut file_time = [mem::MaybeUninit::uninit(); 4];
let result = GetThreadTimes(
thread_handle.as_raw_handle() as _,
file_time[0].as_mut_ptr(),
file_time[1].as_mut_ptr(),
file_time[2].as_mut_ptr(),
file_time[3].as_mut_ptr(),
);

if result == 0 {
return None;
}

file_time[0].assume_init()
};

Some(
((file_time.dwHighDateTime as u64) << 32)
| ((file_time.dwLowDateTime as u64) & 0xffffffff),
)
}
}

/// Returns the minimum `Option<Duration>`, taking into account that `None`
/// equates to an infinite timeout, not a zero timeout (so can't just use
/// `Option::min`)
Expand Down
Loading