From 62155beb9f5d8fbdbc9e5cf834daa2466ec07299 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 14:20:40 +0100
Subject: [PATCH 01/21] don't zero newly accepted memory

The TDX-module already zeroes out newly accepted pages.
---
 tee/supervisor-tdx/src/dynamic.rs | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/tee/supervisor-tdx/src/dynamic.rs b/tee/supervisor-tdx/src/dynamic.rs
index 7bf0cdb4..15957d8f 100644
--- a/tee/supervisor-tdx/src/dynamic.rs
+++ b/tee/supervisor-tdx/src/dynamic.rs
@@ -2,10 +2,7 @@ use bit_field::BitField;
 use constants::{physical_address::DYNAMIC_2MIB, MEMORY_PORT};
 use supervisor_services::allocation_buffer::SlotIndex;
 use tdx_types::tdcall::GpaAttr;
-use x86_64::{
-    structures::paging::{Page, PageSize, PhysFrame, Size2MiB, Size4KiB},
-    VirtAddr,
-};
+use x86_64::structures::paging::{PhysFrame, Size4KiB};
 
 use crate::tdcall::{Tdcall, Vmcall};
 
@@ -77,17 +74,6 @@ impl HostAllocator {
             Tdcall::mem_page_accept(gpa);
         }
 
-        // Zero out the memory.
-        let base = Page::<Size2MiB>::from_start_address(VirtAddr::new(0x200000000000)).unwrap();
-        let page = base + u64::from(slot_idx.get());
-        unsafe {
-            core::ptr::write_bytes(
-                page.start_address().as_mut_ptr::<u8>(),
-                0,
-                Size2MiB::SIZE as usize,
-            );
-        }
-
         // Make the frame accessible to the L2 VM.
         let start = PhysFrame::<Size4KiB>::from_start_address(gpa.start_address()).unwrap();
         let end = start + 511;

From edf1c7836dc19642c8fa9b240ba0fef560fea8c2 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 14:21:27 +0100
Subject: [PATCH 02/21] don't map dynamic memory into supervisor

The TDX supervisor never needs to access the dynamic memory, so let's
just not map it.
---
 tee/supervisor-tdx/src/pagetable.rs | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/tee/supervisor-tdx/src/pagetable.rs b/tee/supervisor-tdx/src/pagetable.rs
index 92579eb2..0943a236 100644
--- a/tee/supervisor-tdx/src/pagetable.rs
+++ b/tee/supervisor-tdx/src/pagetable.rs
@@ -4,7 +4,7 @@ use constants::{
     physical_address::{
         self,
         supervisor::{tdx::*, LOG_BUFFER},
-        DYNAMIC, INIT_FILE, INPUT_FILE,
+        INIT_FILE, INPUT_FILE,
     },
     MAX_APS_COUNT,
 };
@@ -22,7 +22,6 @@ use crate::reset_vector::STACK_SIZE;
 static PML4: StaticPml4 = {
     let mut page_table = StaticPageTable::new();
     page_table.set_table(0, &PDP_0, flags!(WRITE));
-    page_table.set_table(64, &PDP_64, flags!(WRITE | EXECUTE_DISABLE));
     page_table
 };
 
@@ -101,13 +100,6 @@ static PD_0_3: StaticPd = {
     page_table
 };
 
-#[link_section = ".pagetables"]
-static PDP_64: StaticPdp = {
-    let mut page_table = StaticPageTable::new();
-    page_table.set_page_range(0, DYNAMIC, flags!(WRITE | EXECUTE_DISABLE));
-    page_table
-};
-
 /// Create static variables that are shared with the host.
 #[macro_export]
 macro_rules! shared {

From b2f4f23a1fe7d1b2d7b3f04b8be5d75957937699 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 14:31:28 +0100
Subject: [PATCH 03/21] remove redundant check

This check does the same as check_user_address.
---
 tee/kernel/src/memory/pagetable.rs | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs
index 5a6eaf46..ffeb312b 100644
--- a/tee/kernel/src/memory/pagetable.rs
+++ b/tee/kernel/src/memory/pagetable.rs
@@ -616,12 +616,6 @@ impl Pagetables {
             return Ok(());
         }
 
-        // Make sure that even the end is still in the lower half.
-        let end_inclusive = Step::forward_checked(dest, src.len() - 1).ok_or(())?;
-        if end_inclusive.as_u64().get_bit(63) {
-            return Err(());
-        }
-
         check_user_address(dest, src.len()).map_err(|_| ())?;
 
         let _guard = self.activate();

From 19116eefdab07271f82a1abc11b55a458fde050b Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 14:31:45 +0100
Subject: [PATCH 04/21] remove code comment

---
 tee/kernel/src/memory/pagetable.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs
index ffeb312b..60e0c5bc 100644
--- a/tee/kernel/src/memory/pagetable.rs
+++ b/tee/kernel/src/memory/pagetable.rs
@@ -249,8 +249,6 @@ fn try_read_fast(src: VirtAddr, dest: NonNull<[u8]>) -> Result<(), ()> {
         );
     }
 
-    // assert_eq!(failed, 0);
-
     if failed == 0 {
         Ok(())
     } else {

From 425be2dc846aef396d9352f735722a4cdd90a7d9 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 17:47:04 +0100
Subject: [PATCH 05/21] abort increase_reference_count if table is initializing

---
 tee/kernel/src/memory/pagetable.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs
index 60e0c5bc..7d01f74b 100644
--- a/tee/kernel/src/memory/pagetable.rs
+++ b/tee/kernel/src/memory/pagetable.rs
@@ -920,11 +920,11 @@ where
         let mut current_entry = atomic_load(&self.entry);
 
         loop {
-            // If the entry is being initialized right now, spin.
+            // If the entry is being initialized right now, this means that
+            // there's no page table yet and the caller likely isn't interested
+            // in a page table that will only exist in a short while.
             if current_entry.get_bit(INITIALIZING_BIT) {
-                core::hint::spin_loop();
-                current_entry = atomic_load(&self.entry);
-                continue;
+                return Err(());
             }
 
             // Verify that the entry was already initialized.

From 1ead6409abefda787948bd7da683a5d376dc5d94 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 17:49:03 +0100
Subject: [PATCH 06/21] remove need to flush when creating page table

---
 tee/kernel/src/memory/pagetable.rs | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs
index 7d01f74b..b2ea1e1e 100644
--- a/tee/kernel/src/memory/pagetable.rs
+++ b/tee/kernel/src/memory/pagetable.rs
@@ -889,12 +889,6 @@ where
             // Write the entry back.
             atomic_store(&self.entry, new_entry);
 
-            // Flush the entry for the page table. There's a short window
-            // where another thread has removed the entry, but hasn't yet
-            // flushed the entry on this thread yet, which would lead to
-            // this thread using a stale entry.
-            self.flush(true);
-
             // Zero out the page table.
             let table_ptr = self.as_table_ptr().cast_mut();
             unsafe {
@@ -996,8 +990,18 @@ where
             // The reference count hit zero. Zero out the entry and free
             // the frame.
 
+            // We remove the page table in three steps:
+            // 1. Zero out the entry, but set `INITIALIZING_BIT`. This
+            //    prevents other threads from changing anything until step
+            //    2 is complete.
+            // 2. Flush the page table from all APs.
+            // 3. Write zero to the entry.
+
+            let new_entry = 1 << INITIALIZING_BIT;
+
+            // Step 1:
             // First try to commit the zeroing.
-            let res = atomic_compare_exchange(&self.entry, current_entry, 0);
+            let res = atomic_compare_exchange(&self.entry, current_entry, new_entry);
             match res {
                 Ok(_) => {
                     // Success!

+            // Step 2:
             self.flush(true);
 
+            // Step 3:
+            atomic_store(&self.entry, 0);
+
             // Extract the freed frame and return it.
             let phys_addr = PhysAddr::new_truncate(current_entry);
             let frame = PhysFrame::containing_address(phys_addr);

From 9217950e9b28b455889fcb3475535e7729aaff2b Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sat, 2 Nov 2024 18:08:28 +0100
Subject: [PATCH 07/21] add fast-path for decreasing reference count

---
 tee/kernel/src/memory/pagetable.rs | 56 ++++++++++++++++++++++++++----
 1 file changed, 49 insertions(+), 7 deletions(-)

diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs
index b2ea1e1e..1e2e00a3 100644
--- a/tee/kernel/src/memory/pagetable.rs
+++ b/tee/kernel/src/memory/pagetable.rs
@@ -1027,6 +1027,22 @@ where
         }
     }
 
+    /// Only decrease the reference count, but don't do any resource
+    /// management.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that `release` is only called after the
+    /// `acquire` is no longer needed. Additionally, the caller must
+    /// ensure that the reference count doesn't hit zero.
+    unsafe fn release_reference_count_fast(&self) {
+        if self.is_static_entry() {
+            return;
+        }
+
+        fetch_sub(&self.entry, 1 << REFERENCE_COUNT_BITS.start);
+    }
+
     fn as_table_ptr(&self) -> *const ActivePageTable<<L as TableLevel>::Next> {
         let addr = VirtAddr::from_ptr(self);
         let p4_index = addr.p3_index();
@@ -1112,9 +1128,9 @@ impl ActivePageTableEntry {
         let old_entry = PresentPageTableEntry::try_from(old_entry).unwrap();
         self.flush(old_entry.global());
 
-        // FIXME: Free up the frame.
-        let _maybe_frame = unsafe { self.parent_table_entry().release_reference_count() };
-        assert_eq!(_maybe_frame, None);
+        unsafe {
+            self.parent_table_entry().release_reference_count_fast();
+        }
 
         old_entry
     }
@@ -1123,8 +1139,9 @@ impl ActivePageTableEntry {
     pub unsafe fn try_unmap(&self) {
         let old_entry = atomic_swap(&self.entry, 0);
         if PresentPageTableEntry::try_from(old_entry).is_ok() {
-            let frame = unsafe { self.parent_table_entry().release_reference_count() };
-            assert_eq!(frame, None);
+            unsafe {
+                self.parent_table_entry().release_reference_count_fast();
+            }
         }
     }
 
@@ -1151,8 +1168,9 @@ where
     L: HasParentLevel + TableLevel,
 {
     unsafe fn release_parent(&self) {
-        let frame = unsafe { self.parent_table_entry().release_reference_count() };
-        assert_eq!(frame, None);
+        unsafe {
+            self.parent_table_entry().release_reference_count_fast();
+        }
     }
 }
 
@@ -1485,6 +1503,30 @@ fn atomic_fetch_and(entry: &AtomicU64, val: u64) -> u64 {
     }
 }
 
+/// Wrapper around `AtomicU64::fetch_add` without address sanitizer checks.
+#[inline(always)]
+fn fetch_add(entry: &AtomicU64, val: u64) -> u64 {
+    if cfg!(sanitize = "address") {
+        let out;
+        unsafe {
+            asm! {
+                "lock xadd [{ptr}], {out}",
+                out = inout(reg) val => out,
+                ptr = in(reg) entry.as_ptr(),
+            }
+        }
+        out
+    } else {
+        entry.fetch_add(val, Ordering::SeqCst)
+    }
+}
+
+/// Wrapper around `AtomicU64::fetch_sub` without address sanitizer checks.
+#[inline(always)]
+fn fetch_sub(entry: &AtomicU64, val: u64) -> u64 {
+    fetch_add(entry, (-(val as i64)) as u64)
+}
+
 /// Wrapper around `core::ptr::read` without address sanitizer checks.
 ///
 /// # Safety

From 07b0dcf075c3aaee9281aeb71a982ca768b96452 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Sun, 3 Nov 2024 08:11:05 +0100
Subject: [PATCH 08/21] set cr0 and cr4 read shadows

If we don't set up these values, the guest can't read the correct
values. This can lead to the guest failing to detect activated features,
e.g. PCID.
--- common/tdx-types/src/tdcall.rs | 2 ++ tee/supervisor-tdx/src/vcpu.rs | 46 ++++++++++++++++------------------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/common/tdx-types/src/tdcall.rs b/common/tdx-types/src/tdcall.rs index 1d4b8765..42501d3c 100644 --- a/common/tdx-types/src/tdcall.rs +++ b/common/tdx-types/src/tdcall.rs @@ -42,6 +42,8 @@ impl MdFieldId { pub const VMX_VM_ENTRY_CONTROL: Self = Self::vmcs1(0x4012); pub const VMX_VM_EXECUTION_CONTROL_SECONDARY_PROC_BASED: Self = Self::vmcs1(0x401E); pub const VMX_GUEST_CS_ARBYTE: Self = Self::vmcs1(0x4816); + pub const VMX_CR0_READ_SHADOW: Self = Self::vmcs1(0x6004); + pub const VMX_CR4_READ_SHADOW: Self = Self::vmcs1(0x6006); pub const VMX_GUEST_CR0: Self = Self::vmcs1(0x6800); pub const VMX_GUEST_CR3: Self = Self::vmcs1(0x6802); pub const VMX_GUEST_CR4: Self = Self::vmcs1(0x6804); diff --git a/tee/supervisor-tdx/src/vcpu.rs b/tee/supervisor-tdx/src/vcpu.rs index 62022778..f6df8b3f 100644 --- a/tee/supervisor-tdx/src/vcpu.rs +++ b/tee/supervisor-tdx/src/vcpu.rs @@ -100,34 +100,30 @@ pub unsafe fn init_vcpu(apic: &mut Apic) { !0, ); - Tdcall::vp_wr( - MdFieldId::VMX_GUEST_CR4, - Cr4Flags::PHYSICAL_ADDRESS_EXTENSION.bits() - | Cr4Flags::MACHINE_CHECK_EXCEPTION.bits() - | Cr4Flags::PAGE_GLOBAL.bits() - | Cr4Flags::OSFXSR.bits() - | Cr4Flags::OSXMMEXCPT_ENABLE.bits() - | Cr4Flags::VIRTUAL_MACHINE_EXTENSIONS.bits() - | Cr4Flags::FSGSBASE.bits() - | Cr4Flags::PCID.bits() - | Cr4Flags::OSXSAVE.bits() - | Cr4Flags::SUPERVISOR_MODE_EXECUTION_PROTECTION.bits() - | Cr4Flags::SUPERVISOR_MODE_ACCESS_PREVENTION.bits(), - !0, - ); + let cr4_flags = Cr4Flags::PHYSICAL_ADDRESS_EXTENSION.bits() + | Cr4Flags::MACHINE_CHECK_EXCEPTION.bits() + | Cr4Flags::PAGE_GLOBAL.bits() + | Cr4Flags::OSFXSR.bits() + | Cr4Flags::OSXMMEXCPT_ENABLE.bits() + | Cr4Flags::VIRTUAL_MACHINE_EXTENSIONS.bits() + | Cr4Flags::FSGSBASE.bits() + | Cr4Flags::PCID.bits() + | Cr4Flags::OSXSAVE.bits() + | Cr4Flags::SUPERVISOR_MODE_EXECUTION_PROTECTION.bits() + | Cr4Flags::SUPERVISOR_MODE_ACCESS_PREVENTION.bits(); + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR4, cr4_flags, !0); + Tdcall::vp_wr(MdFieldId::VMX_CR4_READ_SHADOW, cr4_flags, !0); Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR3, 0x100_0000_1000, !0); - Tdcall::vp_wr( - MdFieldId::VMX_GUEST_CR0, - Cr0Flags::PROTECTED_MODE_ENABLE.bits() - | Cr0Flags::MONITOR_COPROCESSOR.bits() - | Cr0Flags::EXTENSION_TYPE.bits() - | Cr0Flags::NUMERIC_ERROR.bits() - | Cr0Flags::WRITE_PROTECT.bits() - | Cr0Flags::PAGING.bits(), - !0, - ); + let cr0_flags = Cr0Flags::PROTECTED_MODE_ENABLE.bits() + | Cr0Flags::MONITOR_COPROCESSOR.bits() + | Cr0Flags::EXTENSION_TYPE.bits() + | Cr0Flags::NUMERIC_ERROR.bits() + | Cr0Flags::WRITE_PROTECT.bits() + | Cr0Flags::PAGING.bits(); + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR0, cr0_flags, !0); + Tdcall::vp_wr(MdFieldId::VMX_CR0_READ_SHADOW, cr0_flags, !0); Tdcall::vp_wr(MdFieldId::STAR_WRITE, 0, MdFieldId::STAR_WRITE_MASK); Tdcall::vp_wr(MdFieldId::LSTAR_WRITE, 0, MdFieldId::LSTAR_WRITE_MASK); From f489072b36aaba69849b62e522512e81dea9f78a Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 16:21:56 +0100 Subject: [PATCH 09/21] introduce AP index and bitmap types --- common/Cargo.lock | 1 + common/constants/Cargo.toml | 3 +- common/constants/src/lib.rs | 251 +++++++++++++++++- common/profiler-types/src/lib.rs | 9 +- common/supervisor-services/Cargo.toml | 1 + .../src/allocation_buffer.rs | 4 +- .../supervisor-services/src/command_buffer.rs | 4 +- .../src/notification_buffer.rs | 31 +-- host/Cargo.lock 
| 1 +
 host/mushroom/src/insecure.rs | 2 +-
 host/mushroom/src/profiler.rs | 15 +-
 tee/Cargo.lock | 1 +
 tee/kernel/Cargo.toml | 2 +-
 tee/kernel/src/main.rs | 4 +-
 tee/kernel/src/per_cpu.rs | 8 +-
 tee/kernel/src/profiler.rs | 7 +-
 tee/kernel/src/supervisor.rs | 31 +--
 tee/supervisor-snp/src/ap.rs | 18 +-
 tee/supervisor-snp/src/services.rs | 2 +-
 tee/supervisor-tdx/src/main.rs | 2 +-
 tee/supervisor-tdx/src/per_cpu.rs | 7 +-
 tee/supervisor-tdx/src/reset_vector.rs | 4 +-
 tee/supervisor-tdx/src/services.rs | 6 +-
 tee/supervisor-tdx/src/tlb.rs | 25 +-
 tee/supervisor-tdx/src/vcpu.rs | 2 +-
 25 files changed, 335 insertions(+), 106 deletions(-)

diff --git a/common/Cargo.lock b/common/Cargo.lock
index b4cdc88e..3ddf6ef5 100644
--- a/common/Cargo.lock
+++ b/common/Cargo.lock
@@ -523,6 +523,7 @@ version = "0.1.0"
 dependencies = [
  "bit_field",
  "bytemuck",
+ "constants",
 ]
 
 [[package]]
diff --git a/common/constants/Cargo.toml b/common/constants/Cargo.toml
index 5ba199c3..76130392 100644
--- a/common/constants/Cargo.toml
+++ b/common/constants/Cargo.toml
@@ -3,7 +3,8 @@ name = "constants"
 version = "0.1.0"
 edition = "2021"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+[features]
+nightly = []
 
 [dependencies]
 x86_64 = { version = "0.15.1", default-features = false }
diff --git a/common/constants/src/lib.rs b/common/constants/src/lib.rs
index dd5f4d86..edd60dfb 100644
--- a/common/constants/src/lib.rs
+++ b/common/constants/src/lib.rs
@@ -1,8 +1,13 @@
-//! This crate contains constants shared between the kernel, loader and host executable.
+//! This crate contains constants and related types shared between the kernel,
+//! loader and host executable.
 #![cfg_attr(not(test), no_std)]
-#![forbid(unsafe_code)]
 
-use core::{marker::PhantomData, ops::RangeInclusive};
+use core::{
+    fmt::{self, Debug, Display},
+    marker::PhantomData,
+    ops::{BitAnd, BitAndAssign, BitOrAssign, Index, RangeInclusive},
+    sync::atomic::{AtomicU32, Ordering},
+};
 
 use x86_64::{
     structures::paging::{
@@ -13,6 +18,234 @@ use x86_64::{
 
 pub const MAX_APS_COUNT: u8 = 32;
 
+/// `ApIndex` represents the index of one vCPU thread running the workload
+/// kernel. Its maximum value is capped at compile time.
+#[derive(Clone, Copy, PartialEq, Eq)]
+#[repr(transparent)]
+pub struct ApIndex(u8);
+
+impl ApIndex {
+    /// Create a new `ApIndex` from an integer.
+    ///
+    /// # Panics
+    ///
+    /// This function panics if `idx` is not less than [`MAX_APS_COUNT`].
+    #[must_use]
+    pub const fn new(idx: u8) -> Self {
+        assert!(idx < MAX_APS_COUNT);
+        Self(idx)
+    }
+
+    /// Create a new `ApIndex` from an integer or return `None` if `idx` is
+    /// not less than [`MAX_APS_COUNT`].
+    #[must_use]
+    pub const fn try_new(idx: u8) -> Option<Self> {
+        if idx < MAX_APS_COUNT {
+            Some(Self::new(idx))
+        } else {
+            None
+        }
+    }
+
+    /// Returns `true` for the first AP that starts running.
+    #[must_use]
+    pub const fn is_first(&self) -> bool {
+        self.0 == 0
+    }
+
+    #[must_use]
+    pub fn as_u8(&self) -> u8 {
+        self.0
+    }
+}
+
+impl Display for ApIndex {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        Display::fmt(&self.0, f)
+    }
+}
+
+impl Debug for ApIndex {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        Debug::fmt(&self.0, f)
+    }
+}
+
+impl<T> Index<ApIndex> for [T; MAX_APS_COUNT as usize] {
+    type Output = T;
+
+    fn index(&self, index: ApIndex) -> &Self::Output {
+        unsafe { self.get_unchecked(usize::from(index.as_u8())) }
+    }
+}
+
+type BitmapType = u32;
+type AtomicBitmapType = AtomicU32;
+
+// Make sure that both types are of the same size.
+const _: () = assert!(size_of::<BitmapType>() == size_of::<AtomicBitmapType>());
+
+// Make sure that the bitmap type can fit all bits.
+const _: () = assert!((MAX_APS_COUNT as usize).div_ceil(8) <= size_of::<BitmapType>());
+
+/// A bitmap containing one bit for every vCPU thread running the workload.
+/// Its size is capped by [`MAX_APS_COUNT`] at compile-time.
+#[derive(Clone, Copy)]
+#[repr(transparent)]
+pub struct ApBitmap(BitmapType);
+
+impl ApBitmap {
+    /// Create a new bitmap with all bits set to `false`.
+    #[must_use]
+    pub const fn empty() -> Self {
+        Self(0)
+    }
+
+    /// Create a new bitmap with all bits set to `true`.
+    #[must_use]
+    pub const fn all() -> Self {
+        let mut bits = 0;
+        let mut i = 0;
+        while i < MAX_APS_COUNT {
+            bits |= 1 << i;
+            i += 1;
+        }
+        Self(bits)
+    }
+
+    /// Returns the bit for the given AP.
+    #[must_use]
+    pub const fn get(&self, idx: ApIndex) -> bool {
+        self.0 & (1 << idx.0) != 0
+    }
+
+    /// Sets the bit for the given AP.
+    #[cfg(feature = "nightly")] // TODO: Remove this when Rust 1.83 is released.
+    pub const fn set(&mut self, idx: ApIndex, value: bool) {
+        if value {
+            self.0 |= 1 << idx.0;
+        } else {
+            self.0 &= !(1 << idx.0);
+        }
+    }
+
+    /// Sets the bit for the given AP.
+    #[cfg(not(feature = "nightly"))] // TODO: Remove this when Rust 1.83 is released.
+    pub fn set(&mut self, idx: ApIndex, value: bool) {
+        if value {
+            self.0 |= 1 << idx.0;
+        } else {
+            self.0 &= !(1 << idx.0);
+        }
+    }
+
+    /// Returns whether all bits are `false`.
+    #[must_use]
+    pub const fn is_empty(&self) -> bool {
+        self.0 == 0
+    }
+
+    /// Returns the index of the first AP whose bit is not set.
+    #[must_use]
+    pub fn first_unset(&self) -> Option<ApIndex> {
+        let idx = self.0.trailing_ones() as u8;
+        ApIndex::try_new(idx)
+    }
+}
+
+impl BitOrAssign for ApBitmap {
+    fn bitor_assign(&mut self, rhs: Self) {
+        self.0 |= rhs.0;
+    }
+}
+
+impl BitAnd for ApBitmap {
+    type Output = Self;
+
+    fn bitand(self, rhs: Self) -> Self::Output {
+        Self(self.0 & rhs.0)
+    }
+}
+
+impl BitAndAssign for ApBitmap {
+    fn bitand_assign(&mut self, rhs: Self) {
+        *self = *self & rhs;
+    }
+}
+
+impl IntoIterator for ApBitmap {
+    type Item = ApIndex;
+    type IntoIter = ApBitmapIter;
+
+    /// Returns the indices of all APs whose bit is `true`.
+    fn into_iter(self) -> Self::IntoIter {
+        ApBitmapIter(self)
+    }
+}
+
+/// Returns the indices of all APs whose bit is `true`.
+pub struct ApBitmapIter(ApBitmap);
+
+impl Iterator for ApBitmapIter {
+    type Item = ApIndex;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let idx = self.0 .0.trailing_zeros();
+        let idx = ApIndex::try_new(idx as u8)?;
+        self.0.set(idx, false);
+        Some(idx)
+    }
+}
+
+impl Debug for ApBitmap {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_set().entries(*self).finish()
+    }
+}
+
+/// The atomic equivalent of [`ApBitmap`].
+#[repr(transparent)]
+pub struct AtomicApBitmap(AtomicBitmapType);
+
+impl AtomicApBitmap {
+    /// Create a new bitmap with all bits set to `false`.
+    pub const fn empty() -> Self {
+        Self(AtomicBitmapType::new(0))
+    }
+
+    /// Returns the bit for the given AP.
+    pub fn get(&self, idx: ApIndex) -> bool {
+        self.0.load(Ordering::SeqCst) & (1 << idx.0) != 0
+    }
+
+    /// Returns a copy of all bits.
+    pub fn get_all(&self) -> ApBitmap {
+        ApBitmap(self.0.load(Ordering::SeqCst))
+    }
+
+    /// Sets the bit for the given AP to `true`.
+    pub fn set(&self, idx: ApIndex) -> bool {
+        let mask = 1 << idx.0;
+        self.0.fetch_or(mask, Ordering::SeqCst) & mask != 0
+    }
+
+    /// Sets the bits for the given APs to `true`.
+    pub fn set_all(&self, aps: ApBitmap) {
+        self.0.fetch_or(aps.0, Ordering::SeqCst);
+    }
+
+    /// Atomically clears the bit for the given AP and returns its value.
+    pub fn take(&self, idx: ApIndex) -> bool {
+        let mask = 1 << idx.0;
+        self.0.fetch_and(!mask, Ordering::SeqCst) & mask != 0
+    }
+
+    /// Atomically clears the bits for all APs and returns their values.
+    pub fn take_all(&self) -> ApBitmap {
+        ApBitmap(self.0.swap(0, Ordering::SeqCst))
+    }
+}
+
 pub const FIRST_AP: u8 = 0x80;
 
 pub const EXIT_PORT: u16 = 0xf4;
@@ -120,7 +353,7 @@ where
 mod tests {
     use x86_64::{structures::paging::Page, VirtAddr};
 
-    use crate::{check_ranges, PageRange};
+    use crate::{check_ranges, ApBitmap, PageRange, MAX_APS_COUNT};
 
     #[test]
     fn test_address_range() {
@@ -171,4 +404,14 @@ mod tests {
             PageRange::<Size4KiB>::new(0x2000..=0x2fff),
         ])
     }
+
+    #[test]
+    fn test_bitmap_range() {
+        let bitmap = ApBitmap::all();
+        let iter = bitmap.into_iter();
+        for (i, idx) in iter.enumerate() {
+            assert_eq!(i, usize::from(idx.as_u8()));
+        }
+        assert_eq!(bitmap.into_iter().count(), usize::from(MAX_APS_COUNT));
+    }
 }
diff --git a/common/profiler-types/src/lib.rs b/common/profiler-types/src/lib.rs
index 6cc108c9..64d4cd23 100644
--- a/common/profiler-types/src/lib.rs
+++ b/common/profiler-types/src/lib.rs
@@ -1,12 +1,9 @@
 #![no_std]
 
-use core::{mem::size_of, sync::atomic::AtomicU8};
+use core::mem::size_of;
 
 use bytemuck::{NoUninit, Zeroable};
-use constants::MAX_APS_COUNT;
-
-const NOTIFY_BITS: usize = MAX_APS_COUNT as usize;
-pub const NOTIFY_BYTES: usize = NOTIFY_BITS.div_ceil(8);
+use constants::{AtomicApBitmap, MAX_APS_COUNT};
 
 #[repr(C)]
 pub struct ProfilerControl {
@@ -14,7 +11,7 @@ pub struct ProfilerControl {
     /// kernel after it writes to a header. Set to `false` by the host after
     /// reading and processing the header. This mechanism aims to reduce
     /// contention.
-    pub notify_flags: [AtomicU8; NOTIFY_BYTES],
+    pub notify_flags: AtomicApBitmap,
     pub headers: [PerCpuHeader; MAX_APS_COUNT as usize],
     /// The effective frequency in MHz of the guest view of TSC.
     pub tsc_mhz: u64,
diff --git a/common/supervisor-services/Cargo.toml b/common/supervisor-services/Cargo.toml
index 29cf375d..65706559 100644
--- a/common/supervisor-services/Cargo.toml
+++ b/common/supervisor-services/Cargo.toml
@@ -11,3 +11,4 @@ kernel = []
 [dependencies]
 bit_field = "0.10.2"
 bytemuck = { version = "1.15.0", features = ["derive"] }
+constants = { workspace = true }
diff --git a/common/supervisor-services/src/allocation_buffer.rs b/common/supervisor-services/src/allocation_buffer.rs
index ebdb6bc8..9cfe7748 100644
--- a/common/supervisor-services/src/allocation_buffer.rs
+++ b/common/supervisor-services/src/allocation_buffer.rs
@@ -6,7 +6,9 @@
 //! the buffer by issuing the [`AllocateMemory`](crate::command_buffer::AllocateMemory)
 //! command.
-use core::sync::atomic::{AtomicU16, Ordering}; +use core::sync::atomic::AtomicU16; +#[cfg(any(feature = "kernel", feature = "supervisor"))] +use core::sync::atomic::Ordering; use bytemuck::{Pod, Zeroable}; diff --git a/common/supervisor-services/src/command_buffer.rs b/common/supervisor-services/src/command_buffer.rs index 8306c786..c6b0c360 100644 --- a/common/supervisor-services/src/command_buffer.rs +++ b/common/supervisor-services/src/command_buffer.rs @@ -4,7 +4,9 @@ #[cfg(feature = "kernel")] use core::iter::once; -use core::sync::atomic::{AtomicU16, AtomicU8, Ordering}; +#[cfg(any(feature = "kernel", feature = "supervisor"))] +use core::sync::atomic::Ordering; +use core::sync::atomic::{AtomicU16, AtomicU8}; #[cfg(feature = "kernel")] use bytemuck::bytes_of; diff --git a/common/supervisor-services/src/notification_buffer.rs b/common/supervisor-services/src/notification_buffer.rs index e08cb6bc..4901a638 100644 --- a/common/supervisor-services/src/notification_buffer.rs +++ b/common/supervisor-services/src/notification_buffer.rs @@ -1,39 +1,28 @@ -use core::sync::atomic::{AtomicU64, Ordering}; - #[cfg(feature = "supervisor")] -use bit_field::BitField; +use constants::ApBitmap; +#[cfg(feature = "kernel")] +use constants::ApIndex; +use constants::AtomicApBitmap; #[repr(C, align(64))] -pub struct NotificationBuffer { - bits: [AtomicU64; 2], -} +pub struct NotificationBuffer(AtomicApBitmap); impl NotificationBuffer { #[cfg(feature = "kernel")] pub(crate) const fn new() -> Self { - Self { - bits: [const { AtomicU64::new(0) }; 2], - } + Self(AtomicApBitmap::empty()) } /// Tell the supervisor to wake up the vCPU after it's done processing /// commands. #[cfg(feature = "kernel")] - pub fn arm(&self, vcpu: usize) { - let word_index = vcpu / 64; - let bit_index = vcpu % 64; - self.bits[word_index].fetch_or(1 << bit_index, Ordering::SeqCst); + pub fn arm(&self, vcpu: ApIndex) { + self.0.set(vcpu); } /// Return an iterator yielding all vCPUs that requested to be woken up. 
#[cfg(feature = "supervisor")] - pub fn reset(&self) -> impl Iterator + '_ { - self.bits - .iter() - .map(|bits| bits.swap(0, Ordering::SeqCst)) - .flat_map(|bits| (0..64).map(move |i| bits.get_bit(i))) - .enumerate() - .filter(|(_, armed)| *armed) - .map(|(i, _)| i) + pub fn reset(&self) -> ApBitmap { + self.0.take_all() } } diff --git a/host/Cargo.lock b/host/Cargo.lock index 65a9aa24..b80e32a1 100644 --- a/host/Cargo.lock +++ b/host/Cargo.lock @@ -1441,6 +1441,7 @@ version = "0.1.0" dependencies = [ "bit_field", "bytemuck", + "constants", ] [[package]] diff --git a/host/mushroom/src/insecure.rs b/host/mushroom/src/insecure.rs index dfe2943f..bac683ed 100644 --- a/host/mushroom/src/insecure.rs +++ b/host/mushroom/src/insecure.rs @@ -195,7 +195,7 @@ pub fn main( while command_buffer_reader.handle(&mut handler) {} for i in supervisor_services.notification_buffer.reset() { - ap_threads[i].thread().unpark(); + ap_threads[usize::from(i.as_u8())].thread().unpark(); } if let Some(finish_status) = handler.finish_status { diff --git a/host/mushroom/src/profiler.rs b/host/mushroom/src/profiler.rs index 3fc9b81d..fd1a5a62 100644 --- a/host/mushroom/src/profiler.rs +++ b/host/mushroom/src/profiler.rs @@ -5,14 +5,13 @@ use std::{ mem::size_of, path::Path, process::{Command, Stdio}, - sync::{atomic::Ordering, Arc, Condvar, Mutex}, + sync::{Arc, Condvar, Mutex}, }; use anyhow::{ensure, Context, Error, Result}; -use bit_field::BitField; use bitflags::bitflags; use bytemuck::{bytes_of, cast_slice, zeroed_box, NoUninit}; -use constants::MAX_APS_COUNT; +use constants::{ApIndex, MAX_APS_COUNT}; use profiler_types::{AllEntries, Entry, PerCpuEntries, ProfilerControl, CALL_STACK_CAPACITY}; use rand::random; use tracing::warn; @@ -108,11 +107,7 @@ fn notification_poll_thread( continue; } - let byte_idx = i / 8; - let bit_idx = i % 8; - let bits = notify_flags_ptr[byte_idx].load(Ordering::SeqCst); - let bit = bits.get_bit(bit_idx); - if bit { + if notify_flags_ptr.get(ApIndex::new(i as u8)) { *state = State::Notified; } } @@ -140,9 +135,7 @@ fn notification_poll_thread( *state = State::Idle; // Unset the notify bit. - let byte_idx = i / 8; - let bit_idx = i % 8; - notify_flags_ptr[byte_idx].fetch_and(!(1 << bit_idx), Ordering::SeqCst); + notify_flags_ptr.take(ApIndex::new(i as u8)); // Re-add the idx to the list of available collector threads. available_collector_thread_controls.push(idx); diff --git a/tee/Cargo.lock b/tee/Cargo.lock index 7f0b9da4..75fb8447 100644 --- a/tee/Cargo.lock +++ b/tee/Cargo.lock @@ -716,6 +716,7 @@ version = "0.1.0" dependencies = [ "bit_field", "bytemuck", + "constants", ] [[package]] diff --git a/tee/kernel/Cargo.toml b/tee/kernel/Cargo.toml index 08e2a2e6..d338dcf2 100644 --- a/tee/kernel/Cargo.toml +++ b/tee/kernel/Cargo.toml @@ -19,7 +19,7 @@ async-trait = "0.1.77" bit_field = "0.10.2" bitflags = { version = "2.4.2", features = ["bytemuck"] } bytemuck = { version = "1.15.0", features = ["derive", "min_const_generics"] } -constants = { workspace = true } +constants = { workspace = true, features = ["nightly"] } crossbeam-queue = { version = "0.3.11", default-features = false, features = ["alloc"] } crossbeam-utils = { version = "0.8.19", default-features = false } futures = { version = "0.3.30", default-features = false, features = ["async-await", "alloc"] } diff --git a/tee/kernel/src/main.rs b/tee/kernel/src/main.rs index 054f8b52..d2ee54bc 100644 --- a/tee/kernel/src/main.rs +++ b/tee/kernel/src/main.rs @@ -70,7 +70,7 @@ unsafe fn main() -> ! 
{
     PerCpu::init();
 
     #[cfg(feature = "profiling")]
-    if PerCpu::get().idx == 0 {
+    if PerCpu::get().idx.is_first() {
         unsafe {
             crate::profiler::init();
         }
@@ -90,7 +90,7 @@ extern "C" fn init() -> ! {
     }
 
     // The first AP does some extra initialization work.
-    if PerCpu::get().idx == 0 {
+    if PerCpu::get().idx.is_first() {
         user::process::start_init_process();
     }
 
diff --git a/tee/kernel/src/per_cpu.rs b/tee/kernel/src/per_cpu.rs
index 459f5625..cda8235f 100644
--- a/tee/kernel/src/per_cpu.rs
+++ b/tee/kernel/src/per_cpu.rs
@@ -6,7 +6,7 @@ use core::{
 };
 
 use alloc::sync::Arc;
-use constants::MAX_APS_COUNT;
+use constants::{ApIndex, MAX_APS_COUNT};
 use x86_64::{
     registers::segmentation::{Segment64, GS},
     structures::{gdt::GlobalDescriptorTable, paging::Page, tss::TaskStateSegment},
@@ -25,7 +25,7 @@ static mut STORAGE: [PerCpu; MAX_APS_COUNT as usize] =
 #[repr(C)]
 pub struct PerCpu {
     this: *mut PerCpu,
-    pub idx: usize,
+    pub idx: ApIndex,
     pub kernel_registers: Cell<KernelRegisters>,
     pub new_userspace_registers: Cell<Registers>,
     pub temporary_mapping: OnceCell>,
@@ -44,7 +44,7 @@ impl PerCpu {
     pub const fn new() -> Self {
         Self {
             this: null_mut(),
-            idx: 0,
+            idx: ApIndex::new(0),
             kernel_registers: Cell::new(KernelRegisters::ZERO),
             new_userspace_registers: Cell::new(Registers::ZERO),
             temporary_mapping: OnceCell::new(),
@@ -78,7 +78,7 @@ impl PerCpu {
         let idx = COUNT.fetch_add(1, Ordering::SeqCst);
         let ptr = unsafe { &mut STORAGE[idx] };
         ptr.this = ptr;
-        ptr.idx = idx;
+        ptr.idx = ApIndex::new(u8::try_from(idx).unwrap());
 
         let addr = VirtAddr::from_ptr(ptr);
         unsafe {
diff --git a/tee/kernel/src/profiler.rs b/tee/kernel/src/profiler.rs
index 1c38f2d2..53c0fd33 100644
--- a/tee/kernel/src/profiler.rs
+++ b/tee/kernel/src/profiler.rs
@@ -1,11 +1,10 @@
 use core::arch::{asm, global_asm};
 use core::mem::{offset_of, size_of};
-use core::sync::atomic::AtomicU8;
 
-use constants::MAX_APS_COUNT;
+use constants::{AtomicApBitmap, MAX_APS_COUNT};
 use profiler_types::{
     AllEntries, Entry, PerCpuEntries, PerCpuHeader, ProfilerControl, CALL_STACK_CAPACITY,
-    NOTIFY_BYTES, PROFILER_ENTRIES,
+    PROFILER_ENTRIES,
 };
 use x86_64::registers::model_specific::Msr;
@@ -22,7 +21,7 @@ pub fn flush() {
 
 #[link_section = ".profiler_control"]
 static mut PROFILER_CONTROL: ProfilerControl = ProfilerControl {
-    notify_flags: [const { AtomicU8::new(0) }; NOTIFY_BYTES],
+    notify_flags: AtomicApBitmap::empty(),
     headers: [PerCpuHeader {
         start_idx: 0,
         len: 0,
diff --git a/tee/kernel/src/supervisor.rs b/tee/kernel/src/supervisor.rs
index 3913d08a..cfc7cd1a 100644
--- a/tee/kernel/src/supervisor.rs
+++ b/tee/kernel/src/supervisor.rs
@@ -3,7 +3,7 @@ use core::arch::asm;
 use crate::spin::mutex::Mutex;
 use arrayvec::ArrayVec;
 use bit_field::BitField;
-use constants::{physical_address::DYNAMIC_2MIB, MAX_APS_COUNT};
+use constants::{physical_address::DYNAMIC_2MIB, ApBitmap, ApIndex};
 use supervisor_services::{
     allocation_buffer::SlotIndex,
     command_buffer::{
@@ -197,7 +197,7 @@ pub struct Scheduler(Mutex<SchedulerState>);
 
 struct SchedulerState {
     /// One bit for every vCPU.
-    bits: u128,
+    bits: ApBitmap,
     /// The number of vCPUs that have finished being launched.
     launched: u8,
     /// Whether a vCPU is being launched right now.
     is_launching: bool,
 }
 
 impl Scheduler {
     pub const fn new() -> Self {
+        let mut bits = ApBitmap::empty();
+        bits.set(ApIndex::new(0), true);
+
         Self(Mutex::new(SchedulerState {
-            bits: 1,
+            bits,
             launched: 0,
             is_launching: true,
         }))
@@ -215,18 +218,18 @@ impl Scheduler {
 
     fn pick_any(&self) -> Option<ScheduledCpu> {
         let mut state = self.0.lock();
-        let idx = state.bits.trailing_ones() as u8;
-        if idx < state.launched {
-            state.bits.set_bit(usize::from(idx), true);
+        let idx = state.bits.first_unset()?;
+        if idx.as_u8() < state.launched {
+            state.bits.set(idx, true);
             Some(ScheduledCpu::Existing(idx))
         } else {
-            if state.is_launching || idx >= MAX_APS_COUNT {
+            if state.is_launching {
                 return None;
             }
 
             state.is_launching = true;
-            state.bits.set_bit(usize::from(idx), true);
+            state.bits.set(idx, true);
 
             Some(ScheduledCpu::New)
         }
@@ -235,9 +238,9 @@ impl Scheduler {
     fn halt(&self) -> Result<(), LastRunningVcpuError> {
         let mut state = self.0.lock();
         let mut new_bits = state.bits;
-        new_bits.set_bit(PerCpu::get().idx, false);
+        new_bits.set(PerCpu::get().idx, false);
         // Ensure that this vCPU isn't the last one running.
-        if new_bits == 0 {
+        if new_bits.is_empty() {
             return Err(LastRunningVcpuError);
         }
         state.bits = new_bits;
@@ -246,7 +249,7 @@ impl Scheduler {
 
     fn resume(&self) {
         let mut state = self.0.lock();
-        state.bits.set_bit(PerCpu::get().idx, true);
+        state.bits.set(PerCpu::get().idx, true);
     }
 
     pub fn finish_launch(&self) {
@@ -258,7 +261,7 @@ impl Scheduler {
 }
 
 enum ScheduledCpu {
-    Existing(u8),
+    Existing(ApIndex),
     New,
 }
@@ -270,9 +273,7 @@ pub fn schedule_vcpu() {
 
     match cpu {
         ScheduledCpu::Existing(cpu) => {
-            SUPERVISOR_SERVICES
-                .notification_buffer
-                .arm(usize::from(cpu));
+            SUPERVISOR_SERVICES.notification_buffer.arm(cpu);
             kick_supervisor(true);
         }
         ScheduledCpu::New => start_next_ap(),
diff --git a/tee/supervisor-snp/src/ap.rs b/tee/supervisor-snp/src/ap.rs
index 2c462c56..7b8637b6 100644
--- a/tee/supervisor-snp/src/ap.rs
+++ b/tee/supervisor-snp/src/ap.rs
@@ -1,6 +1,6 @@
 use core::cell::Cell;
 
-use constants::{FIRST_AP, KICK_AP_PORT, MAX_APS_COUNT};
+use constants::{ApIndex, FIRST_AP, KICK_AP_PORT};
 use snp_types::vmsa::SevFeatures;
 
 use crate::{
@@ -21,26 +21,26 @@ const SEV_FEATURES: SevFeatures = SevFeatures::from_bits_truncate(
 pub fn start_next_ap() {
     static APIC_COUNTER: FakeSync<Cell<u8>> = FakeSync::new(Cell::new(0));
 
-    let apic_id = APIC_COUNTER.get();
-    if apic_id >= MAX_APS_COUNT {
+    let counter = APIC_COUNTER.get();
+    let Some(apic_id) = ApIndex::try_new(counter) else {
         return;
-    }
-    APIC_COUNTER.set(apic_id + 1);
+    };
+    APIC_COUNTER.set(counter + 1);
 
     // Initialize the VMSA.
-    let mut vmsa = InitializedVmsa::new(vmsa_tweak_bitmap(), u32::from(apic_id));
+    let mut vmsa = InitializedVmsa::new(vmsa_tweak_bitmap(), u32::from(apic_id.as_u8()));
     unsafe {
         vmsa.set_runnable(true);
     }
 
     // Tell the host about the new VMSA.
     let vmsa_pa = vmsa.phys_addr();
-    create_ap(u32::from(FIRST_AP + apic_id), vmsa_pa, SEV_FEATURES);
+    create_ap(u32::from(FIRST_AP + apic_id.as_u8()), vmsa_pa, SEV_FEATURES);
 
     // Start the AP.
     kick(apic_id);
 }
 
-pub fn kick(apic_id: u8) {
-    ioio_write(KICK_AP_PORT, u32::from(FIRST_AP + apic_id));
+pub fn kick(apic_id: ApIndex) {
+    ioio_write(KICK_AP_PORT, u32::from(FIRST_AP + apic_id.as_u8()));
 }
diff --git a/tee/supervisor-snp/src/services.rs b/tee/supervisor-snp/src/services.rs
index 617e3f4e..c21ad83a 100644
--- a/tee/supervisor-snp/src/services.rs
+++ b/tee/supervisor-snp/src/services.rs
@@ -32,7 +32,7 @@ pub fn run() -> ! {
         // command.
for id in supervisor_services().notification_buffer.reset() { - kick(id as u8); + kick(id); } wait_for_command(); diff --git a/tee/supervisor-tdx/src/main.rs b/tee/supervisor-tdx/src/main.rs index 1c702a5d..a24eda0d 100644 --- a/tee/supervisor-tdx/src/main.rs +++ b/tee/supervisor-tdx/src/main.rs @@ -38,7 +38,7 @@ fn main() -> ! { setup_idt(); - if PerCpu::current_vcpu_index() == 0 { + if PerCpu::current_vcpu_index().is_first() { input::init(); } diff --git a/tee/supervisor-tdx/src/per_cpu.rs b/tee/supervisor-tdx/src/per_cpu.rs index f8a6514d..2447718b 100644 --- a/tee/supervisor-tdx/src/per_cpu.rs +++ b/tee/supervisor-tdx/src/per_cpu.rs @@ -1,16 +1,17 @@ use core::{arch::asm, cell::Cell}; +use constants::ApIndex; use x86_64::instructions::interrupts; #[repr(C)] pub struct PerCpu { this: *mut Self, - pub vcpu_index: usize, + pub vcpu_index: ApIndex, pub pending_flushes: Cell, } impl PerCpu { - pub fn new(this: *mut Self, vcpu_index: usize) -> Self { + pub fn new(this: *mut Self, vcpu_index: ApIndex) -> Self { Self { this, vcpu_index, @@ -29,7 +30,7 @@ impl PerCpu { }) } - pub fn current_vcpu_index() -> usize { + pub fn current_vcpu_index() -> ApIndex { Self::with(|per_cpu| per_cpu.vcpu_index) } } diff --git a/tee/supervisor-tdx/src/reset_vector.rs b/tee/supervisor-tdx/src/reset_vector.rs index 905348fe..e8337375 100644 --- a/tee/supervisor-tdx/src/reset_vector.rs +++ b/tee/supervisor-tdx/src/reset_vector.rs @@ -1,6 +1,6 @@ use core::{arch::global_asm, mem::MaybeUninit}; -use constants::MAX_APS_COUNT; +use constants::{ApIndex, MAX_APS_COUNT}; use x86_64::{registers::model_specific::FsBase, VirtAddr}; use crate::{main, per_cpu::PerCpu}; @@ -14,7 +14,7 @@ global_asm!( ); #[export_name = "_start"] -extern "sysv64" fn premain(vcpu_index: usize) { +extern "sysv64" fn premain(vcpu_index: ApIndex) { // Setup a `PerCpu` instance for the current cpu. 
let mut per_cpu = MaybeUninit::uninit(); let ptr = per_cpu.as_mut_ptr(); diff --git a/tee/supervisor-tdx/src/services.rs b/tee/supervisor-tdx/src/services.rs index 8b8f29c7..9b023a26 100644 --- a/tee/supervisor-tdx/src/services.rs +++ b/tee/supervisor-tdx/src/services.rs @@ -29,6 +29,8 @@ static HANDLER: Lazy> = Lazy::new(|| Mutex::new(Handler::new())); pub fn handle(resume: bool) { interrupts::disable(); + let idx = PerCpu::current_vcpu_index(); + if let Some(mut handler) = HANDLER.try_lock() { let mut command_buffer_reader = CommandBufferReader::new(&supervisor_services().command_buffer); @@ -37,10 +39,10 @@ pub fn handle(resume: bool) { let mut saw_self = false; for id in supervisor_services().notification_buffer.reset() { - if id == PerCpu::current_vcpu_index() { + if id == idx { saw_self = true; } else { - send_ipi(id as u32, WAKEUP_VECTOR); + send_ipi(u32::from(id.as_u8()), WAKEUP_VECTOR); } } if saw_self { diff --git a/tee/supervisor-tdx/src/tlb.rs b/tee/supervisor-tdx/src/tlb.rs index f4635a19..515f7933 100644 --- a/tee/supervisor-tdx/src/tlb.rs +++ b/tee/supervisor-tdx/src/tlb.rs @@ -1,7 +1,4 @@ -use core::sync::atomic::{AtomicU64, Ordering}; - -use bit_field::BitField; -use constants::MAX_APS_COUNT; +use constants::AtomicApBitmap; use spin::mutex::SpinMutex; use x86_64::structures::idt::InterruptStackFrame; @@ -11,28 +8,26 @@ use crate::{ }; static GUARD: SpinMutex<()> = SpinMutex::new(()); -static COUNTER: AtomicU64 = AtomicU64::new(0); -static RAN: AtomicU64 = AtomicU64::new(0); +static COUNTER: AtomicApBitmap = AtomicApBitmap::empty(); +static RAN: AtomicApBitmap = AtomicApBitmap::empty(); /// This function must be called before entering the vCPU. pub fn pre_enter() { - RAN.fetch_or(1 << PerCpu::current_vcpu_index(), Ordering::SeqCst); + RAN.set(PerCpu::current_vcpu_index()); } /// Flush the entire TLB on all vCPUs. 
pub fn flush() { let _guard = GUARD.lock(); - let mask = RAN.swap(0, Ordering::Relaxed); - COUNTER.fetch_or(mask, Ordering::Relaxed); + let mask = RAN.take_all(); + COUNTER.set_all(mask); drop(_guard); - for i in 0..MAX_APS_COUNT { - if mask.get_bit(usize::from(i)) { - send_ipi(u32::from(i), FLUSH_VECTOR); - } + for idx in mask { + send_ipi(u32::from(idx.as_u8()), FLUSH_VECTOR); } - while COUNTER.load(Ordering::SeqCst) != 0 {} + while !COUNTER.get_all().is_empty() {} } pub extern "x86-interrupt" fn flush_handler(_frame: InterruptStackFrame) { @@ -40,7 +35,7 @@ pub extern "x86-interrupt" fn flush_handler(_frame: InterruptStackFrame) { per_cpu.pending_flushes.set(true); per_cpu.vcpu_index }); - COUNTER.fetch_and(!(1 << vcpu_index), core::sync::atomic::Ordering::Relaxed); + COUNTER.take(vcpu_index); eoi(); } diff --git a/tee/supervisor-tdx/src/vcpu.rs b/tee/supervisor-tdx/src/vcpu.rs index f6df8b3f..9b9672ce 100644 --- a/tee/supervisor-tdx/src/vcpu.rs +++ b/tee/supervisor-tdx/src/vcpu.rs @@ -51,7 +51,7 @@ pub fn wait_for_vcpu_start() { interrupts::disable(); let ready = READY.load(Ordering::Relaxed); - if ready == PerCpu::current_vcpu_index() { + if ready == usize::from(PerCpu::current_vcpu_index().as_u8()) { break; } From 35a4cf5a00c4ef50ae232c39faa531accd4e8e3f Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Tue, 12 Nov 2024 18:20:19 +0100 Subject: [PATCH 10/21] enable interrupts during userspace execution --- tee/kernel/src/user/process/syscall/cpu_state.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tee/kernel/src/user/process/syscall/cpu_state.rs b/tee/kernel/src/user/process/syscall/cpu_state.rs index 57f003e7..a59a83bc 100644 --- a/tee/kernel/src/user/process/syscall/cpu_state.rs +++ b/tee/kernel/src/user/process/syscall/cpu_state.rs @@ -615,7 +615,7 @@ impl Registers { fs: 0, gs: 0, ss: 0x23, - rflags: 2, + rflags: 2 | 1 << 9, ..Self::ZERO }; } @@ -763,7 +763,7 @@ struct FpxSwBytes { pub fn init() { LStar::write(VirtAddr::new(syscall_entry as usize as u64)); - SFMask::write(RFlags::DIRECTION_FLAG); + SFMask::write(RFlags::DIRECTION_FLAG | RFlags::INTERRUPT_FLAG); } unsafe extern "sysv64" { @@ -801,6 +801,9 @@ global_asm!( "pop rax", "mov gs:[{K_RFLAGS_OFFSET}], rax", + // Disable interrupts. + "cli", + // Restore user state. // Restore segment registers. "xor rax, rax", From 4aa7958cc3a932eb1b2a3299d9110ed9832e3144 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 16:30:20 +0100 Subject: [PATCH 11/21] calling load_gdt directly --- tee/kernel/src/exception.rs | 10 ---------- tee/kernel/src/main.rs | 6 +----- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/tee/kernel/src/exception.rs b/tee/kernel/src/exception.rs index 511801b1..aac63e71 100644 --- a/tee/kernel/src/exception.rs +++ b/tee/kernel/src/exception.rs @@ -31,16 +31,6 @@ use x86_64::{ use crate::{memory::pagetable::entry_for_page, per_cpu::PerCpu}; -/// Initialize exception handling. -/// -/// # Safety -/// -/// This function must only be called once by main. -pub unsafe fn init() { - load_gdt(); - load_idt(); -} - pub fn switch_stack(f: extern "C" fn() -> !) -> ! { let stack = allocate_stack(); diff --git a/tee/kernel/src/main.rs b/tee/kernel/src/main.rs index d2ee54bc..20e58761 100644 --- a/tee/kernel/src/main.rs +++ b/tee/kernel/src/main.rs @@ -83,11 +83,7 @@ unsafe fn main() -> ! { } extern "C" fn init() -> ! { - unsafe { - // SAFETY: We're the only ones calling these functions and we're only - // called once. 
- exception::init(); - } + exception::load_gdt(); // The first AP does some extra initialization work. if PerCpu::get().idx.is_first() { From 85df06397ac1b4576ce28d91f3df841cfe255fcf Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 17:43:03 +0100 Subject: [PATCH 12/21] flush less --- tee/kernel/src/memory/heap/huge_allocator.rs | 4 ++-- tee/kernel/src/memory/pagetable.rs | 25 +++++++++++++------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/tee/kernel/src/memory/heap/huge_allocator.rs b/tee/kernel/src/memory/heap/huge_allocator.rs index 5a64acfe..3cec5910 100644 --- a/tee/kernel/src/memory/heap/huge_allocator.rs +++ b/tee/kernel/src/memory/heap/huge_allocator.rs @@ -12,7 +12,7 @@ use x86_64::{ use crate::memory::{ frame::{allocate_frame, deallocate_frame}, - pagetable::{map_page, unmap_page, PageTableFlags, PresentPageTableEntry}, + pagetable::{map_page, unmap_page_no_flush, PageTableFlags, PresentPageTableEntry}, }; pub struct HugeAllocator { @@ -91,7 +91,7 @@ unsafe impl Allocator for HugeAllocator { let base = Page::::from_start_address(addr).unwrap(); for page in (base..).take(pages) { - let entry = unsafe { unmap_page(page) }; + let entry = unsafe { unmap_page_no_flush(page) }; let frame = entry.frame(); unsafe { deallocate_frame(frame); diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs index 1e2e00a3..8aa86c44 100644 --- a/tee/kernel/src/memory/pagetable.rs +++ b/tee/kernel/src/memory/pagetable.rs @@ -188,12 +188,12 @@ pub unsafe fn map_page(page: Page, entry: PresentPageTableEntry) -> Result<()> { Ok(()) } -/// Unmap a page. +/// Unmap a page without flushing it from the TLB. /// /// # Panics /// /// This function panics if the page is not mapped. -pub unsafe fn unmap_page(page: Page) -> PresentPageTableEntry { +pub unsafe fn unmap_page_no_flush(page: Page) -> PresentPageTableEntry { trace!("unmapping page {page:?}"); let level4 = ActivePageTable::get(); @@ -211,7 +211,19 @@ pub unsafe fn unmap_page(page: Page) -> PresentPageTableEntry { let level1 = &*level1_guard; let level1_entry = &level1[page.p1_index()]; - unsafe { level1_entry.unmap() } + unsafe { level1_entry.unmap_no_flush() } +} + +/// Unmap a page. +/// +/// # Panics +/// +/// This function panics if the page is not mapped. +#[cfg(sanitize = "address")] +pub unsafe fn unmap_page(page: Page) -> PresentPageTableEntry { + let entry = unsafe { unmap_page_no_flush(page) }; + GlobalFlushGuard.flush_page(page); + entry } pub fn entry_for_page(page: Page) -> Option { @@ -1104,8 +1116,6 @@ impl ActivePageTableEntry { self.parent_table_entry() .increase_reference_count() .unwrap(); - - self.flush(true); } /// Map a new page or replace an existing page. @@ -1116,17 +1126,14 @@ impl ActivePageTableEntry { .increase_reference_count() .unwrap(); } - - self.flush(true); } /// # Panics /// /// Panics if the page isn't mapped. 
- pub unsafe fn unmap(&self) -> PresentPageTableEntry { + pub unsafe fn unmap_no_flush(&self) -> PresentPageTableEntry { let old_entry = atomic_swap(&self.entry, 0); let old_entry = PresentPageTableEntry::try_from(old_entry).unwrap(); - self.flush(old_entry.global()); unsafe { self.parent_table_entry().release_reference_count_fast(); From f97b168a3a74cb62ca2ca214a48588bf970396ed Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 17:46:36 +0100 Subject: [PATCH 13/21] decode more vm exit information --- common/tdx-types/src/tdcall.rs | 8 +++--- tee/supervisor-tdx/src/tdcall.rs | 44 ++++++++++++++++++++++++++++++-- tee/supervisor-tdx/src/vcpu.rs | 17 ++++++------ 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/common/tdx-types/src/tdcall.rs b/common/tdx-types/src/tdcall.rs index 42501d3c..db782313 100644 --- a/common/tdx-types/src/tdcall.rs +++ b/common/tdx-types/src/tdcall.rs @@ -3,10 +3,10 @@ use bitflags::bitflags; use bytemuck::{Pod, Zeroable}; use x86_64::registers::rflags::RFlags; -pub const TDX_SUCCESS: u32 = 0x00000000; -pub const TDX_L2_EXIT_HOST_ROUTED_ASYNC: u32 = 0x00001100; -pub const TDX_L2_EXIT_PENDING_INTERRUPT: u32 = 0x00001102; -pub const TDX_PENDING_INTERRUPT: u32 = 0x00001120; +pub const TDX_SUCCESS: u16 = 0x0000; +pub const TDX_L2_EXIT_HOST_ROUTED_ASYNC: u16 = 0x1100; +pub const TDX_L2_EXIT_PENDING_INTERRUPT: u16 = 0x1102; +pub const TDX_PENDING_INTERRUPT: u16 = 0x1120; #[derive(Debug, Clone, Copy)] #[repr(C, align(256))] diff --git a/tee/supervisor-tdx/src/tdcall.rs b/tee/supervisor-tdx/src/tdcall.rs index c585ae01..8c12b158 100644 --- a/tee/supervisor-tdx/src/tdcall.rs +++ b/tee/supervisor-tdx/src/tdcall.rs @@ -232,7 +232,7 @@ impl Tdcall { invd_translations: InvdTranslations, guest_state: &mut GuestState, with_sti: bool, - ) -> (u64, u32) { + ) -> VmExit { let mut tdcall = Self::new(25); tdcall.rcx.set_bits(0..=1, invd_translations as u64); tdcall.rcx.set_bits(52..=53, index as u64); @@ -242,7 +242,25 @@ impl Tdcall { tdcall.execute(); } - (tdcall.rax, (tdcall.r11 >> 32) as u32) + VmExit { + class: tdcall.rax.get_bits(32..=47) as u16, + exit_reason: tdcall.rax as u32, + exit_qualification: tdcall.rcx, + guest_linear_address: tdcall.rdx, + cs_selector: tdcall.rsi.get_bits(0..=15) as u16, + cs_ar_bit: tdcall.rsi.get_bits(16..=31) as u16, + cs_limit: tdcall.rsi.get_bits(32..) as u32, + cs_base: tdcall.rdi, + guest_physical_address: tdcall.r8, + vm_exit_interruption_information: tdcall.r9.get_bits(..=31) as u32, + vm_exit_interruption_error_code: tdcall.r9.get_bits(32..) as u32, + idt_vectoring_information: tdcall.r10.get_bits(..=31) as u32, + idt_vectoring_error_code: tdcall.r10.get_bits(32..) as u32, + vm_exit_instruction_information: tdcall.r11.get_bits(..=31) as u32, + vm_exit_instruction_length: tdcall.r11.get_bits(32..) 
as u32, + cpl: tdcall.r12.get_bits(0..=1) as u8, + extended_exit_qualification: tdcall.r13.get_bits(..=3) as u8, + } } pub fn vp_veinfo_get() -> VeInfo { @@ -263,6 +281,28 @@ impl Tdcall { } } +#[derive(Debug)] +#[expect(dead_code)] +pub struct VmExit { + pub class: u16, + pub exit_reason: u32, + pub exit_qualification: u64, + pub guest_linear_address: u64, + pub cs_selector: u16, + pub cs_ar_bit: u16, + pub cs_limit: u32, + pub cs_base: u64, + pub guest_physical_address: u64, + pub vm_exit_interruption_information: u32, + pub vm_exit_interruption_error_code: u32, + pub idt_vectoring_information: u32, + pub idt_vectoring_error_code: u32, + pub vm_exit_instruction_information: u32, + pub vm_exit_instruction_length: u32, + pub cpl: u8, + pub extended_exit_qualification: u8, +} + #[expect(dead_code)] pub struct VeInfo { pub exit_reason: u32, diff --git a/tee/supervisor-tdx/src/vcpu.rs b/tee/supervisor-tdx/src/vcpu.rs index 9b9672ce..4fb766f3 100644 --- a/tee/supervisor-tdx/src/vcpu.rs +++ b/tee/supervisor-tdx/src/vcpu.rs @@ -162,10 +162,9 @@ pub fn run_vcpu() -> ! { } else { InvdTranslations::None }; - let (exit_reason, instruction_length) = - Tdcall::vp_enter(VmIndex::One, flush, &mut guest_state, true); + let vm_exit = Tdcall::vp_enter(VmIndex::One, flush, &mut guest_state, true); - match (exit_reason >> 32) as u32 { + match vm_exit.class { TDX_SUCCESS => {} TDX_L2_EXIT_HOST_ROUTED_ASYNC => continue, TDX_L2_EXIT_PENDING_INTERRUPT => continue, @@ -173,7 +172,7 @@ pub fn run_vcpu() -> ! { reason => unimplemented!("{reason:#010x}"), } - match exit_reason as u32 { + match vm_exit.exit_reason { VMEXIT_REASON_CPUID_INSTRUCTION => { let result = unsafe { __cpuid_count(guest_state.rax as u32, guest_state.rcx as u32) }; @@ -181,11 +180,11 @@ pub fn run_vcpu() -> ! { guest_state.rbx = u64::from(result.ebx); guest_state.rcx = u64::from(result.ecx); guest_state.rdx = u64::from(result.edx); - guest_state.rip += u64::from(instruction_length); + guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } VMEXIT_REASON_HLT_INSTRUCTION => { handle(guest_state.rax != 0); - guest_state.rip += u64::from(instruction_length); + guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } VMEXIT_REASON_VMCALL_INSTRUCTION => { // The kernel currently only executes vmcalls to flush the TLB. @@ -204,7 +203,7 @@ pub fn run_vcpu() -> ! { tlb::flush(); guest_state.rax = 0; - guest_state.rip += u64::from(instruction_length); + guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } VMEXIT_REASON_MSR_WRITE => { match guest_state.rcx { @@ -213,9 +212,9 @@ pub fn run_vcpu() -> ! 
{
                 }
                 rcx => panic!("{rcx:#x}"),
             }
-            guest_state.rip += u64::from(instruction_length);
+            guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length);
         }
-        unknown => panic!("{unknown:#x} {guest_state:x?}"),
+        unknown => panic!("{unknown:#x} {guest_state:x?} {vm_exit:x?}"),
     }
 }
 }

From 871617653ccbef722e4ada7fd999fb9b31737965 Mon Sep 17 00:00:00 2001
From: Tom Dohrmann
Date: Mon, 11 Nov 2024 17:58:41 +0100
Subject: [PATCH 14/21] add support for sending IPIs in supervisor-tdx

---
 common/tdx-types/src/tdcall.rs | 76 ++++++++++-
 tee/kernel/src/exception.rs | 2 +
 tee/kernel/src/memory/pagetable.rs | 2 +
 tee/kernel/src/memory/pagetable/flush.rs | 78 ++++++++++++
 tee/kernel/src/supervisor.rs | 2 +-
 tee/supervisor-tdx/src/main.rs | 4 +-
 tee/supervisor-tdx/src/services.rs | 14 +-
 tee/supervisor-tdx/src/vcpu.rs | 155 ++++++++++++++++-------
 8 files changed, 269 insertions(+), 64 deletions(-)
 create mode 100644 tee/kernel/src/memory/pagetable/flush.rs

diff --git a/common/tdx-types/src/tdcall.rs b/common/tdx-types/src/tdcall.rs
index db782313..3df87bbf 100644
--- a/common/tdx-types/src/tdcall.rs
+++ b/common/tdx-types/src/tdcall.rs
@@ -1,3 +1,5 @@
+use core::sync::atomic::{AtomicU32, Ordering};
+
 use bit_field::BitField;
 use bitflags::bitflags;
 use bytemuck::{Pod, Zeroable};
@@ -57,6 +59,9 @@ impl MdFieldId {
     pub const SFMASK_WRITE: Self = Self::msr_bitmaps1(0xC000_0084, true);
     pub const SFMASK_WRITE_MASK: u64 = Self::msr_bitmaps_mask(0xC000_0084);
 
+    pub const X2APIC_EOI_WRITE: Self = Self::msr_bitmaps1(0x80b, true);
+    pub const X2APIC_EOI_WRITE_MASK: u64 = Self::msr_bitmaps_mask(0x80b);
+
     pub const TDVPS_L2_CTLS1: Self = Self::new(
         81,
         ElementSizeCode::SixtyFour,
@@ -200,10 +205,77 @@ pub enum VmIndex {
 }
 
 #[repr(C, align(4096))]
-pub struct Apic([u8; 4096]);
+pub struct Apic([AtomicU32; 512]);
+
+impl Apic {
+    pub const fn new() -> Self {
+        Self([const { AtomicU32::new(0) }; 512])
+    }
+
+    /// Set the local APIC id.
+    pub fn set_id(&self, value: u32) {
+        self.0[0x20 / 4].store(value, Ordering::SeqCst);
+    }
+
+    /// Returns the highest priority requested interrupt.
+    pub fn pending_vectora_todo(&self) -> Option<u8> {
+        (0..8).rev().find_map(|i| {
+            let offset = 0x100 | (i * 16);
+            let idx = offset / 4;
+            let irr = self.0[idx].load(Ordering::SeqCst);
+            (irr != 0).then(|| i as u8 * 32 + 31 - irr.leading_zeros() as u8)
+        })
+    }
+
+    /// Returns the highest priority requested interrupt.
+    pub fn pending_vector(&self) -> Option<u8> {
+        (0..8).rev().find_map(|i| {
+            let offset = 0x200 | (i * 16);
+            let idx = offset / 4;
+            let irr = self.0[idx].load(Ordering::SeqCst);
+            (irr != 0).then(|| i as u8 * 32 + 31 - irr.leading_zeros() as u8)
+        })
+    }
+
+    /// Sets the IRR bit for the given vector and returns the previous value.
+    pub fn set_irr(&self, vector: u8) -> bool {
+        let offset = 0x200 | ((usize::from(vector) & 0xe0) >> 1);
+        let idx = offset / 4;
+        let mask = 1u32.wrapping_shl(u32::from(vector));
+        self.0[idx].fetch_or(mask, Ordering::SeqCst) & mask != 0
+    }
+}
 
 impl Default for Apic {
     fn default() -> Self {
-        Self([0; 4096])
+        Self::new()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::Apic;
+
+    #[test]
+    fn rvi() {
+        // Check that the value is correct if only a single bit is set.
+        for i in 0..=255 {
+            let apic = Apic::new();
+            assert_eq!(apic.pending_vector(), None);
+            apic.set_irr(i);
+            assert_eq!(apic.pending_vector(), Some(i));
+        }
+
+        // Check that higher vectors are prioritized.
+ for i in 0..=255 { + for j in 0..i { + let apic = Apic::new(); + assert_eq!(apic.pending_vector(), None); + apic.set_irr(j); + assert_eq!(apic.pending_vector(), Some(j)); + apic.set_irr(i); + assert_eq!(apic.pending_vector(), Some(i)); + } + } } } diff --git a/tee/kernel/src/exception.rs b/tee/kernel/src/exception.rs index aac63e71..2c0c6bb2 100644 --- a/tee/kernel/src/exception.rs +++ b/tee/kernel/src/exception.rs @@ -7,6 +7,7 @@ use core::{ ptr::null_mut, }; +use crate::memory::pagetable::flush::{tlb_shootdown_handler, TLB_VECTOR}; use crate::spin::lazy::Lazy; use crate::user::process::syscall::cpu_state::exception_entry; use alloc::alloc::alloc; @@ -146,6 +147,7 @@ pub fn load_idt() { .set_handler_fn(general_protection_fault_handler); idt.page_fault.set_handler_fn(page_fault_handler); idt.vmm_communication_exception.set_handler_fn(vc_handler); + idt[TLB_VECTOR].set_handler_fn(tlb_shootdown_handler); idt[0x80] .set_handler_fn(int0x80_handler) diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs index 8aa86c44..ca95f575 100644 --- a/tee/kernel/src/memory/pagetable.rs +++ b/tee/kernel/src/memory/pagetable.rs @@ -39,6 +39,8 @@ use super::{ temporary::copy_into_frame, }; +pub mod flush; + const RECURSIVE_INDEX: PageTableIndex = PageTableIndex::new(510); #[used] diff --git a/tee/kernel/src/memory/pagetable/flush.rs b/tee/kernel/src/memory/pagetable/flush.rs new file mode 100644 index 00000000..8717b98f --- /dev/null +++ b/tee/kernel/src/memory/pagetable/flush.rs @@ -0,0 +1,78 @@ +use core::arch::asm; + +use bit_field::BitField; +use constants::{ApBitmap, ApIndex, AtomicApBitmap}; +use x86_64::{ + instructions::tlb, + registers::{ + control::{Cr3, Cr4, Cr4Flags}, + model_specific::Msr, + }, + structures::idt::InterruptStackFrame, +}; + +pub const TLB_VECTOR: u8 = 0x20; + +static PENDING_TLB_SHOOTDOWN: AtomicApBitmap = AtomicApBitmap::empty(); +static PENDING_GLOBAL_TLB_SHOOTDOWN: AtomicApBitmap = AtomicApBitmap::empty(); + +fn process_flushes(idx: ApIndex) { + let need_global_flush = PENDING_GLOBAL_TLB_SHOOTDOWN.take(idx); + let need_non_global_flush = PENDING_TLB_SHOOTDOWN.take(idx); + if need_global_flush { + if Cr4::read().contains(Cr4Flags::PCID) { + unsafe { + tlb::flush_pcid(tlb::InvPicdCommand::All); + } + } else { + tlb::flush_all(); + } + } else if need_non_global_flush { + if Cr4::read().contains(Cr4Flags::PCID) { + // Flush the entire PCID. + // TODO: Flush less. + let (_, pcid) = Cr3::read_pcid(); + unsafe { + tlb::flush_pcid(tlb::InvPicdCommand::Single(pcid)); + } + } else { + tlb::flush_all(); + } + } +} + +pub extern "x86-interrupt" fn tlb_shootdown_handler(_: InterruptStackFrame) { + // This handler is only used for TDX. The value returned by `rdpid` can be + // controlled by the host on SNP, so if we ever need to use this handler on + // SNP, we'll have to use something else. + let ap_id: u64; + unsafe { + asm!( + "rdpid {}", + out(reg) ap_id, + ); + } + let idx = ApIndex::new(ap_id as u8); + + process_flushes(idx); + + // Signal EOI. 
+ unsafe { + Msr::new(0x80b).write(0); + } +} + +fn send_tlb_ipis(aps: ApBitmap) { + for ap in aps { + let mut bits = 0; + bits.set_bits(0..8, u64::from(TLB_VECTOR)); + bits.set_bits(8..11, 0); // Delivery Mode: Fixed + bits.set_bit(11, false); // Destination Mode: Physical + bits.set_bit(14, true); // Level: Assert + bits.set_bits(18..20, 0b00); // Destination Shorthand: Destination + bits.set_bits(32.., u64::from(ap.as_u8())); // Destination + unsafe { + Msr::new(0x830).write(bits); + } + } +} diff --git a/tee/kernel/src/supervisor.rs b/tee/kernel/src/supervisor.rs index cfc7cd1a..96b2881c 100644 --- a/tee/kernel/src/supervisor.rs +++ b/tee/kernel/src/supervisor.rs @@ -76,7 +76,7 @@ where } } -/// Push a command, immediatly tell the supervisor about it, but don't wait for +/// Push a command, immediately tell the supervisor about it, but don't wait for /// it to complete. fn push_async_command(command: C) where diff --git a/tee/supervisor-tdx/src/main.rs b/tee/supervisor-tdx/src/main.rs index a24eda0d..8266a34c 100644 --- a/tee/supervisor-tdx/src/main.rs +++ b/tee/supervisor-tdx/src/main.rs @@ -7,7 +7,6 @@ use exception::setup_idt; use log::{debug, LevelFilter}; use per_cpu::PerCpu; use spin::Once; -use tdx_types::tdcall::Apic; use vcpu::{init_vcpu, run_vcpu, wait_for_vcpu_start}; use crate::logging::SerialLogger; @@ -42,8 +41,7 @@ fn main() -> ! { input::init(); } - let mut apic = Apic::default(); - unsafe { init_vcpu(&mut apic) }; + init_vcpu(); wait_for_vcpu_start(); run_vcpu() } diff --git a/tee/supervisor-tdx/src/services.rs b/tee/supervisor-tdx/src/services.rs index 9b023a26..651e9eba 100644 --- a/tee/supervisor-tdx/src/services.rs +++ b/tee/supervisor-tdx/src/services.rs @@ -26,29 +26,21 @@ fn supervisor_services() -> &'static SupervisorServices { static HANDLER: Lazy> = Lazy::new(|| Mutex::new(Handler::new())); -pub fn handle(resume: bool) { - interrupts::disable(); - - let idx = PerCpu::current_vcpu_index(); - +pub fn handle(mut resume: bool) { if let Some(mut handler) = HANDLER.try_lock() { let mut command_buffer_reader = CommandBufferReader::new(&supervisor_services().command_buffer); while command_buffer_reader.handle(&mut *handler) {} drop(handler); - let mut saw_self = false; + let idx = PerCpu::current_vcpu_index(); for id in supervisor_services().notification_buffer.reset() { if id == idx { - saw_self = true; + resume = true; } else { send_ipi(u32::from(id.as_u8()), WAKEUP_VECTOR); } } - if saw_self { - interrupts::enable(); - return; - } } if resume { diff --git a/tee/supervisor-tdx/src/vcpu.rs b/tee/supervisor-tdx/src/vcpu.rs index 4fb766f3..b6c0c10b 100644 --- a/tee/supervisor-tdx/src/vcpu.rs +++ b/tee/supervisor-tdx/src/vcpu.rs @@ -4,7 +4,8 @@ use core::{ sync::atomic::{AtomicUsize, Ordering}, }; -use constants::MAX_APS_COUNT; +use bit_field::BitField; +use constants::{ApIndex, MAX_APS_COUNT}; use tdx_types::{ tdcall::{ Apic, GuestState, InvdTranslations, MdFieldId, VmIndex, TDX_L2_EXIT_HOST_ROUTED_ASYNC, @@ -62,43 +63,56 @@ pub fn wait_for_vcpu_start() { } /// Initialize the L2 VM. -/// -/// # Safety -/// -/// The caller must ensure the `apic` is valid until the end of time. -pub unsafe fn init_vcpu(apic: &mut Apic) { - let apic = core::ptr::from_mut(apic) as u64; - +pub fn init_vcpu() { // Enable access to the shared EPT. - Tdcall::vp_wr( - MdFieldId::TDVPS_L2_CTLS1, - u64::from(cfg!(not(feature = "harden"))), - 1, - ); + unsafe { + Tdcall::vp_wr( + MdFieldId::TDVPS_L2_CTLS1, + u64::from(cfg!(not(feature = "harden"))), + 1, + ); + } // Enable 64-bit mode. 
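+    // (Bit 9 of the VM-entry controls is "IA-32e mode guest", i.e. the L2
+    // VM starts out in 64-bit long mode.)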
- Tdcall::vp_wr(MdFieldId::VMX_VM_ENTRY_CONTROL, 1 << 9, 1 << 9); + unsafe { + Tdcall::vp_wr(MdFieldId::VMX_VM_ENTRY_CONTROL, 1 << 9, 1 << 9); + } // Enabled mode-based execute control for EPT. - Tdcall::vp_wr( - MdFieldId::VMX_VM_EXECUTION_CONTROL_SECONDARY_PROC_BASED, - 1 << 22, - 1 << 22, - ); + unsafe { + Tdcall::vp_wr( + MdFieldId::VMX_VM_EXECUTION_CONTROL_SECONDARY_PROC_BASED, + 1 << 22, + 1 << 22, + ); + } // Adjust CS segment. - Tdcall::vp_wr(MdFieldId::VMX_GUEST_CS_ARBYTE, 0xa09b, !0); + unsafe { + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CS_ARBYTE, 0xa09b, !0); + } - Tdcall::vp_wr(MdFieldId::VMX_VIRTUAL_APIC_PAGE_ADDRESS, apic, !0); + let idx = PerCpu::current_vcpu_index(); + let apic = &APICS[idx]; + apic.set_id(u32::from(idx.as_u8())); + unsafe { + Tdcall::vp_wr( + MdFieldId::VMX_VIRTUAL_APIC_PAGE_ADDRESS, + apic as *const _ as u64, + !0, + ); + } - Tdcall::vp_wr( - MdFieldId::VMX_GUEST_IA32_EFER, - EferFlags::SYSTEM_CALL_EXTENSIONS.bits() - | EferFlags::LONG_MODE_ENABLE.bits() - | EferFlags::LONG_MODE_ACTIVE.bits() - | EferFlags::NO_EXECUTE_ENABLE.bits(), - !0, - ); + unsafe { + Tdcall::vp_wr( + MdFieldId::VMX_GUEST_IA32_EFER, + EferFlags::SYSTEM_CALL_EXTENSIONS.bits() + | EferFlags::LONG_MODE_ENABLE.bits() + | EferFlags::LONG_MODE_ACTIVE.bits() + | EferFlags::NO_EXECUTE_ENABLE.bits(), + !0, + ); + } let cr4_flags = Cr4Flags::PHYSICAL_ADDRESS_EXTENSION.bits() | Cr4Flags::MACHINE_CHECK_EXCEPTION.bits() @@ -111,10 +125,14 @@ pub unsafe fn init_vcpu(apic: &mut Apic) { | Cr4Flags::OSXSAVE.bits() | Cr4Flags::SUPERVISOR_MODE_EXECUTION_PROTECTION.bits() | Cr4Flags::SUPERVISOR_MODE_ACCESS_PREVENTION.bits(); - Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR4, cr4_flags, !0); - Tdcall::vp_wr(MdFieldId::VMX_CR4_READ_SHADOW, cr4_flags, !0); + unsafe { + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR4, cr4_flags, !0); + Tdcall::vp_wr(MdFieldId::VMX_CR4_READ_SHADOW, cr4_flags, !0); + } - Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR3, 0x100_0000_1000, !0); + unsafe { + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR3, 0x100_0000_1000, !0); + } let cr0_flags = Cr0Flags::PROTECTED_MODE_ENABLE.bits() | Cr0Flags::MONITOR_COPROCESSOR.bits() @@ -122,14 +140,25 @@ pub unsafe fn init_vcpu(apic: &mut Apic) { | Cr0Flags::NUMERIC_ERROR.bits() | Cr0Flags::WRITE_PROTECT.bits() | Cr0Flags::PAGING.bits(); - Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR0, cr0_flags, !0); - Tdcall::vp_wr(MdFieldId::VMX_CR0_READ_SHADOW, cr0_flags, !0); + unsafe { + Tdcall::vp_wr(MdFieldId::VMX_GUEST_CR0, cr0_flags, !0); + Tdcall::vp_wr(MdFieldId::VMX_CR0_READ_SHADOW, cr0_flags, !0); + } - Tdcall::vp_wr(MdFieldId::STAR_WRITE, 0, MdFieldId::STAR_WRITE_MASK); - Tdcall::vp_wr(MdFieldId::LSTAR_WRITE, 0, MdFieldId::LSTAR_WRITE_MASK); - Tdcall::vp_wr(MdFieldId::SFMASK_WRITE, 0, MdFieldId::SFMASK_WRITE_MASK); + unsafe { + Tdcall::vp_wr(MdFieldId::STAR_WRITE, 0, MdFieldId::STAR_WRITE_MASK); + Tdcall::vp_wr(MdFieldId::LSTAR_WRITE, 0, MdFieldId::LSTAR_WRITE_MASK); + Tdcall::vp_wr(MdFieldId::SFMASK_WRITE, 0, MdFieldId::SFMASK_WRITE_MASK); + Tdcall::vp_wr( + MdFieldId::X2APIC_EOI_WRITE, + 0, + MdFieldId::X2APIC_EOI_WRITE_MASK, + ); + } } +static APICS: [Apic; MAX_APS_COUNT as usize] = [const { Apic::new() }; MAX_APS_COUNT as usize]; + pub fn run_vcpu() -> ! { let mut guest_state = GuestState { rax: 0, @@ -154,8 +183,17 @@ pub fn run_vcpu() -> ! { guest_interrupt_status: 0, }; + let idx = PerCpu::current_vcpu_index(); + loop { interrupts::disable(); + + // Update the RVI field. 
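+        // (RVI, the low byte of the guest interrupt status, is the highest
+        // pending vector; the CPU injects it once the L2 guest can accept
+        // interrupts. The high byte is SVI, the in-service vector.)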
+ let rvi = APICS[idx].pending_vector().unwrap_or_default(); + guest_state + .guest_interrupt_status + .set_bits(0..8, u16::from(rvi)); + tlb::pre_enter(); let flush = if PerCpu::with(|per_cpu| per_cpu.pending_flushes.take()) { InvdTranslations::All @@ -183,7 +221,39 @@ pub fn run_vcpu() -> ! { guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } VMEXIT_REASON_HLT_INSTRUCTION => { - handle(guest_state.rax != 0); + interrupts::disable(); + let resume = guest_state.rax != 0 || APICS[idx].pending_vector().is_some(); + handle(resume); + guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); + } + VMEXIT_REASON_MSR_WRITE => { + let value = guest_state.rax.get_bits(..32) | (guest_state.rdx << 32); + match guest_state.rcx { + 0x40000000 => { + // Ignore writes to HV_X64_MSR_GUEST_OS_ID. + } + // IA32_X2APIC_ICR + 0x830 => { + // We don't support all options. Check that we support the fields. + assert_eq!(value.get_bits(8..11), 0); // Delivery Mode: Fixed + assert!(!value.get_bit(11)); // Destination Mode: Physical + assert!(value.get_bit(14)); // Level: Assert + assert_eq!(value.get_bits(18..20), 0b00); // Destination Shorthand: Destination + + // Set the IRR bit in the APIC page. + let vector = value.get_bits(..8) as u8; + let destination = value.get_bits(32..) as u32; + let idx = ApIndex::new(destination as u8); + let was_set = APICS[idx].set_irr(vector); + + // If the bit was not already set, send an IPI to the + // supervisor, so that it re-evaluates the RVI. + if !was_set { + send_ipi(destination, WAKEUP_VECTOR); + } + } + rcx => unimplemented!("MSR write: {rcx:#x}"), + } guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } VMEXIT_REASON_VMCALL_INSTRUCTION => { @@ -205,15 +275,6 @@ pub fn run_vcpu() -> ! { guest_state.rax = 0; guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } - VMEXIT_REASON_MSR_WRITE => { - match guest_state.rcx { - 0x40000000 => { - // Ignore writes to HV_X64_MSR_GUEST_OS_ID. - } - rcx => panic!("{rcx:#x}"), - } - guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); - } unknown => panic!("{unknown:#x} {guest_state:x?} {vm_exit:x?}"), } } From 055dab34966e7909bd084ce65d4594eb8e1e25fc Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 18:44:29 +0100 Subject: [PATCH 15/21] read notifications buffer before processing commands --- host/mushroom/src/insecure.rs | 7 +++++-- tee/supervisor-snp/src/services.rs | 8 +++++--- tee/supervisor-tdx/src/services.rs | 7 +++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/host/mushroom/src/insecure.rs b/host/mushroom/src/insecure.rs index bac683ed..f04c6795 100644 --- a/host/mushroom/src/insecure.rs +++ b/host/mushroom/src/insecure.rs @@ -192,9 +192,12 @@ pub fn main( finish_status: None, }; let finish_status = loop { - while command_buffer_reader.handle(&mut handler) {} + let mut pending = supervisor_services.notification_buffer.reset(); + while command_buffer_reader.handle(&mut handler) { + pending |= supervisor_services.notification_buffer.reset(); + } - for i in supervisor_services.notification_buffer.reset() { + for i in pending { ap_threads[usize::from(i.as_u8())].thread().unpark(); } diff --git a/tee/supervisor-snp/src/services.rs b/tee/supervisor-snp/src/services.rs index c21ad83a..b7e4f7b4 100644 --- a/tee/supervisor-snp/src/services.rs +++ b/tee/supervisor-snp/src/services.rs @@ -26,12 +26,14 @@ pub fn run() -> ! { loop { // Handle all pending commands. 
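+        // (Snapshotting the notification buffer between handling batches of
+        // commands ensures that a notification posted while a command is
+        // being processed is folded into `pending` rather than lost.)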
- while command_buffer_reader.handle(&mut handler) {} + let mut pending = supervisor_services().notification_buffer.reset(); + while command_buffer_reader.handle(&mut handler) { + pending |= supervisor_services().notification_buffer.reset(); + } // Notify the APs that requested a notification and wait for the next // command. - - for id in supervisor_services().notification_buffer.reset() { + for id in pending { kick(id); } diff --git a/tee/supervisor-tdx/src/services.rs b/tee/supervisor-tdx/src/services.rs index 651e9eba..af404b96 100644 --- a/tee/supervisor-tdx/src/services.rs +++ b/tee/supervisor-tdx/src/services.rs @@ -30,11 +30,14 @@ pub fn handle(mut resume: bool) { if let Some(mut handler) = HANDLER.try_lock() { let mut command_buffer_reader = CommandBufferReader::new(&supervisor_services().command_buffer); - while command_buffer_reader.handle(&mut *handler) {} + let mut pending = supervisor_services().notification_buffer.reset(); + while command_buffer_reader.handle(&mut *handler) { + pending |= supervisor_services().notification_buffer.reset(); + } drop(handler); let idx = PerCpu::current_vcpu_index(); - for id in supervisor_services().notification_buffer.reset() { + for id in pending { if id == idx { resume = true; } else { From 1f88919b8e7c1f01ba1de42e2f3378ff2fcb7446 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 19:06:33 +0100 Subject: [PATCH 16/21] bump x86_64 rev This new rev contains some invpcid related patches. --- common/Cargo.lock | 2 +- common/Cargo.toml | 2 +- host/Cargo.lock | 2 +- host/Cargo.toml | 2 +- tee/Cargo.lock | 2 +- tee/Cargo.toml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/common/Cargo.lock b/common/Cargo.lock index 3ddf6ef5..59ffdf31 100644 --- a/common/Cargo.lock +++ b/common/Cargo.lock @@ -648,7 +648,7 @@ dependencies = [ [[package]] name = "x86_64" version = "0.15.1" -source = "git+https://github.com/rust-osdev/x86_64.git?rev=c5bc9fc#c5bc9fcefdeb99286fd00793a0c840b142114349" +source = "git+https://github.com/rust-osdev/x86_64.git?rev=3fc9106#3fc91064a3d0952b78c97eebba4e2e5c440217f2" dependencies = [ "bit_field", "bitflags", diff --git a/common/Cargo.toml b/common/Cargo.toml index b37e1129..33fca31a 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -9,4 +9,4 @@ snp-types = { path = "snp-types" } tdx-types = { path = "tdx-types" } [patch.crates-io] -x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "c5bc9fc" } +x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "3fc9106" } diff --git a/host/Cargo.lock b/host/Cargo.lock index b80e32a1..a8310916 100644 --- a/host/Cargo.lock +++ b/host/Cargo.lock @@ -2011,7 +2011,7 @@ dependencies = [ [[package]] name = "x86_64" version = "0.15.1" -source = "git+https://github.com/rust-osdev/x86_64.git?rev=c5bc9fc#c5bc9fcefdeb99286fd00793a0c840b142114349" +source = "git+https://github.com/rust-osdev/x86_64.git?rev=3fc9106#3fc91064a3d0952b78c97eebba4e2e5c440217f2" dependencies = [ "bit_field", "bitflags", diff --git a/host/Cargo.toml b/host/Cargo.toml index 8e203d9e..409b4587 100644 --- a/host/Cargo.toml +++ b/host/Cargo.toml @@ -16,7 +16,7 @@ tdx-types = { path = "../common/tdx-types", features = ["quote"] } vcek-kds = { path = "vcek-kds" } [patch.crates-io] -x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "c5bc9fc" } +x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "3fc9106" } [profile.dev.package.sha2] opt-level = 3 diff --git a/tee/Cargo.lock b/tee/Cargo.lock index 75fb8447..4c0bc129 100644 --- 
a/tee/Cargo.lock +++ b/tee/Cargo.lock @@ -937,7 +937,7 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "x86_64" version = "0.15.1" -source = "git+https://github.com/rust-osdev/x86_64.git?rev=c5bc9fc#c5bc9fcefdeb99286fd00793a0c840b142114349" +source = "git+https://github.com/rust-osdev/x86_64.git?rev=3fc9106#3fc91064a3d0952b78c97eebba4e2e5c440217f2" dependencies = [ "bit_field", "bitflags", diff --git a/tee/Cargo.toml b/tee/Cargo.toml index 9c949ad7..81d51095 100644 --- a/tee/Cargo.toml +++ b/tee/Cargo.toml @@ -42,4 +42,4 @@ debug-assertions = true overflow-checks = true [patch.crates-io] -x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "c5bc9fc" } +x86_64 = { git = "https://github.com/rust-osdev/x86_64.git", rev = "3fc9106" } From 8c52393831d421110011eb7c1bb78260888b8def Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 18:50:14 +0100 Subject: [PATCH 17/21] rewrite TLB flushing --- tee/kernel/src/main.rs | 4 + tee/kernel/src/memory/pagetable.rs | 179 +++++++++++++------ tee/kernel/src/memory/pagetable/flush.rs | 209 ++++++++++++++++++++++- tee/kernel/src/supervisor.rs | 6 +- tee/supervisor-tdx/src/main.rs | 6 + 5 files changed, 347 insertions(+), 57 deletions(-) diff --git a/tee/kernel/src/main.rs b/tee/kernel/src/main.rs index 20e58761..547073a4 100644 --- a/tee/kernel/src/main.rs +++ b/tee/kernel/src/main.rs @@ -34,7 +34,9 @@ compiler_error!("Hardened kernels can't be profiled."); extern crate alloc; use exception::switch_stack; +use memory::pagetable::flush; use supervisor::SCHEDULER; +use x86_64::instructions::interrupts; use crate::per_cpu::PerCpu; @@ -68,6 +70,7 @@ unsafe fn main() -> ! { } PerCpu::init(); + flush::init(); #[cfg(feature = "profiling")] if PerCpu::get().idx.is_first() { @@ -78,6 +81,7 @@ unsafe fn main() -> ! 
{ exception::load_early_gdt(); exception::load_idt(); + interrupts::enable(); switch_stack(init) } diff --git a/tee/kernel/src/memory/pagetable.rs b/tee/kernel/src/memory/pagetable.rs index ca95f575..f3d14c4e 100644 --- a/tee/kernel/src/memory/pagetable.rs +++ b/tee/kernel/src/memory/pagetable.rs @@ -23,7 +23,11 @@ use crate::spin::lazy::Lazy; use alloc::sync::Arc; use bit_field::BitField; use bitflags::bitflags; -use constants::physical_address::{kernel::*, *}; +use constants::{ + physical_address::{kernel::*, *}, + ApBitmap, +}; +use flush::{FlushGuard, GlobalFlushGuard}; use log::trace; use static_page_tables::{flags, StaticPageTable, StaticPd, StaticPdp, StaticPml4, StaticPt}; use x86_64::{ @@ -35,7 +39,6 @@ use x86_64::{ use super::{ frame::{allocate_frame, deallocate_frame}, - invlpgb::INVLPGB, temporary::copy_into_frame, }; @@ -171,15 +174,15 @@ pub unsafe fn map_page(page: Page, entry: PresentPageTableEntry) -> Result<()> { let level4 = ActivePageTable::get(); let level4_entry = &level4[page.p4_index()]; - let level3_guard = level4_entry.acquire(entry.flags())?; + let level3_guard = level4_entry.acquire(entry.flags(), &GlobalFlushGuard)?; let level3 = &*level3_guard; let level3_entry = &level3[page.p3_index()]; - let level2_guard = level3_entry.acquire(entry.flags())?; + let level2_guard = level3_entry.acquire(entry.flags(), &GlobalFlushGuard)?; let level2 = &*level2_guard; let level2_entry = &level2[page.p2_index()]; - let level1_guard = level2_entry.acquire(entry.flags())?; + let level1_guard = level2_entry.acquire(entry.flags(), &GlobalFlushGuard)?; let level1 = &*level1_guard; let level1_entry = &level1[page.p1_index()]; @@ -201,15 +204,15 @@ pub unsafe fn unmap_page_no_flush(page: Page) -> PresentPageTableEntry { let level4 = ActivePageTable::get(); let level4_entry = &level4[page.p4_index()]; - let level3_guard = level4_entry.acquire_existing().unwrap(); + let level3_guard = level4_entry.acquire_existing(&GlobalFlushGuard).unwrap(); let level3 = &*level3_guard; let level3_entry = &level3[page.p3_index()]; - let level2_guard = level3_entry.acquire_existing().unwrap(); + let level2_guard = level3_entry.acquire_existing(&GlobalFlushGuard).unwrap(); let level2 = &*level2_guard; let level2_entry = &level2[page.p2_index()]; - let level1_guard = level2_entry.acquire_existing().unwrap(); + let level1_guard = level2_entry.acquire_existing(&GlobalFlushGuard).unwrap(); let level1 = &*level1_guard; let level1_entry = &level1[page.p1_index()]; @@ -231,11 +234,11 @@ pub unsafe fn unmap_page(page: Page) -> PresentPageTableEntry { pub fn entry_for_page(page: Page) -> Option { let pml4 = ActivePageTable::get(); let pml4e = &pml4[page.p4_index()]; - let pdp = pml4e.acquire_existing()?; + let pdp = pml4e.acquire_existing(&GlobalFlushGuard)?; let pdpe = &pdp[page.p3_index()]; - let pd = pdpe.acquire_existing()?; + let pd = pdpe.acquire_existing(&GlobalFlushGuard)?; let pde = &pd[page.p2_index()]; - let pt = pde.acquire_existing()?; + let pt = pde.acquire_existing(&GlobalFlushGuard)?; let pte = &pt[page.p1_index()]; pte.entry() } @@ -325,14 +328,38 @@ pub fn check_user_address(addr: VirtAddr, len: usize) -> Result<()> { check_user_page(page) } +struct FlushState { + /// A bitmap containing all APs currently using the page table. + active: ApBitmap, + /// A bitmap containing all APs that activated the page tables in the past + /// or now and have not yet flushed. + used: ApBitmap, + /// A bitmap containing all APs that need to flush the PCID the next time + /// they activate the page tables. 
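+    /// (Bits are set for all APs in `used` when a flush is requested and
+    /// cleared per-AP once that AP reloads CR3 with a flush on activation.)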
+ needs_flush: ApBitmap, +} + +impl FlushState { + pub fn new() -> Self { + Self { + active: ApBitmap::empty(), + used: ApBitmap::empty(), + needs_flush: ApBitmap::all(), + } + } +} + pub struct PagetablesAllocations { pml4: PhysFrame, + flush_state: Mutex, /// None if PCID is not supported. pcid_allocation: Option, } impl Drop for PagetablesAllocations { fn drop(&mut self) { + assert!(self.flush_state.get_mut().active.is_empty()); + unsafe { deallocate_frame(self.pml4); } @@ -380,6 +407,7 @@ impl Pagetables { .then(|| ALLOCATIONS.lock().allocate()); let allocations = PagetablesAllocations { pml4: frame, + flush_state: Mutex::new(FlushState::new()), pcid_allocation, }; let allocations = Arc::new(allocations); @@ -406,10 +434,26 @@ impl Pagetables { .as_ref() .is_some_and(|existing| Arc::ptr_eq(existing, allocations)); - if update_required { + let mut flush_state_guard = allocations.flush_state.lock(); + let ap_index = PerCpu::get().idx; + flush_state_guard.active.set(ap_index, true); + flush_state_guard.used.set(ap_index, true); + let needs_flush = flush_state_guard.needs_flush.get(ap_index); + if needs_flush { + flush_state_guard.needs_flush.set(ap_index, false); + } + drop(flush_state_guard); + + if update_required || needs_flush { if let Some(pcid_allocation) = allocations.pcid_allocation.as_ref() { - unsafe { - Cr3::write_pcid_no_flush(allocations.pml4, pcid_allocation.pcid); + if needs_flush { + unsafe { + Cr3::write_pcid(allocations.pml4, pcid_allocation.pcid); + } + } else { + unsafe { + Cr3::write_pcid_no_flush(allocations.pml4, pcid_allocation.pcid); + } } } else { unsafe { @@ -417,12 +461,17 @@ impl Pagetables { } } - *guard = Some(allocations.clone()); + if update_required { + *guard = Some(allocations.clone()); + } } + let guard = RefMut::map(guard, |a| a.as_mut().unwrap()); + ActivePageTableGuard { - _guard: guard, + guard, pml4: ActivePageTable::get(), + _marker: PhantomData, } } @@ -435,15 +484,15 @@ impl Pagetables { let level4 = self.activate(); let level4_entry = &level4[page.p4_index()]; - let level3_guard = level4_entry.acquire(entry.flags())?; + let level3_guard = level4_entry.acquire(entry.flags(), &level4)?; let level3 = &*level3_guard; let level3_entry = &level3[page.p3_index()]; - let level2_guard = level3_entry.acquire(entry.flags())?; + let level2_guard = level3_entry.acquire(entry.flags(), &level4)?; let level2 = &*level2_guard; let level2_entry = &level2[page.p2_index()]; - let level1_guard = level2_entry.acquire(entry.flags())?; + let level1_guard = level2_entry.acquire(entry.flags(), &level4)?; let level1 = &*level1_guard; let level1_entry = &level1[page.p1_index()]; @@ -465,7 +514,7 @@ impl Pagetables { } } - flush_current_pcid(); + level4.flush_all(); } /// Unmap a page if it's mapped. 
@@ -507,7 +556,7 @@ impl Pagetables { let pml4 = self.activate(); for p4_index in start.p4_index()..=end.p4_index() { let pml4e = &pml4[p4_index]; - let Some(pdp) = pml4e.acquire_existing() else { + let Some(pdp) = pml4e.acquire_existing(&pml4) else { continue; }; @@ -532,7 +581,7 @@ impl Pagetables { for p3_index in start.p3_index()..=end.p3_index() { let pdpe = &pdp[p3_index]; - let Some(pd) = pdpe.acquire_existing() else { + let Some(pd) = pdpe.acquire_existing(&pml4) else { continue; }; @@ -557,7 +606,7 @@ impl Pagetables { for p2_index in start.p2_index()..=end.p2_index() { let pde = &pd[p2_index]; - let Some(pt) = pde.acquire_existing() else { + let Some(pt) = pde.acquire_existing(&pml4) else { continue; }; @@ -590,10 +639,7 @@ impl Pagetables { } } - let (_, pcid) = Cr3::read_pcid(); - unsafe { - INVLPGB.flush_user_pages(pcid, start..=end); - } + pml4.flush_pages(start..=end); } /// Try to copy user memory from `src` into `dest`. @@ -637,8 +683,10 @@ impl Pagetables { } struct ActivePageTableGuard { - _guard: RefMut<'static, Option>>, + guard: RefMut<'static, Arc>, pml4: &'static ActivePageTable, + // Make sure the type is neither `Send` nor `Sync`. + _marker: PhantomData<*const ()>, } impl Deref for ActivePageTableGuard { @@ -649,13 +697,10 @@ impl Deref for ActivePageTableGuard { } } -fn flush_current_pcid() { - let cr4 = Cr4::read(); - if cr4.contains(Cr4Flags::PCID) { - let (_, pcid) = Cr3::read_pcid(); - INVLPGB.flush_pcid(pcid); - } else { - INVLPGB.flush_all(); +impl Drop for ActivePageTableGuard { + fn drop(&mut self) { + let mut guard = self.guard.flush_state.lock(); + guard.active.set(PerCpu::get().idx, false); } } @@ -787,14 +832,27 @@ where } impl ActivePageTableEntry { - pub fn acquire(&self, flags: PageTableFlags) -> Result> { + pub fn acquire<'a, F>( + &'a self, + flags: PageTableFlags, + guard: &'a F, + ) -> Result> + where + F: FlushGuard, + { self.acquire_reference_count(flags).unwrap(); - Ok(ActivePageTableEntryGuard { entry: self }) + Ok(ActivePageTableEntryGuard { entry: self, guard }) } - pub fn acquire_existing(&self) -> Option> { + pub fn acquire_existing<'a, F>( + &'a self, + guard: &'a F, + ) -> Option> + where + F: FlushGuard, + { self.increase_reference_count().ok()?; - Some(ActivePageTableEntryGuard { entry: self }) + Some(ActivePageTableEntryGuard { entry: self, guard }) } } @@ -802,7 +860,14 @@ impl ActivePageTableEntry where L: HasParentLevel + TableLevel, { - pub fn acquire(&self, flags: PageTableFlags) -> Result> { + pub fn acquire<'a, F>( + &'a self, + flags: PageTableFlags, + guard: &'a F, + ) -> Result> + where + F: FlushGuard, + { let initialized = self.acquire_reference_count(flags).unwrap(); if initialized { @@ -810,12 +875,18 @@ where parent_entry.increase_reference_count().unwrap(); } - Ok(ActivePageTableEntryGuard { entry: self }) + Ok(ActivePageTableEntryGuard { entry: self, guard }) } - pub fn acquire_existing(&self) -> Option> { + pub fn acquire_existing<'a, F>( + &'a self, + guard: &'a F, + ) -> Option> + where + F: FlushGuard, + { self.increase_reference_count().ok()?; - Some(ActivePageTableEntryGuard { entry: self }) + Some(ActivePageTableEntryGuard { entry: self, guard }) } } @@ -968,7 +1039,7 @@ where /// /// The caller must ensure that the entry is that `release` is only called /// after the `acquire` is no longer needed. 
- unsafe fn release_reference_count(&self) -> Option { + unsafe fn release_reference_count(&self, guard: &impl FlushGuard) -> Option { if self.is_static_entry() { return None; } @@ -1028,7 +1099,7 @@ where } // Step 2: - self.flush(true); + self.flush(guard); // Step 3: atomic_store(&self.entry, 0); @@ -1070,8 +1141,8 @@ where } impl ActivePageTableEntry { - fn flush(&self, global: bool) { - INVLPGB.flush_page(self.page(), global); + fn flush(&self, guard: &impl FlushGuard) { + guard.flush_page(self.page()); } pub fn page(&self) -> Page { @@ -1184,18 +1255,21 @@ where } #[must_use] -struct ActivePageTableEntryGuard<'a, L> +struct ActivePageTableEntryGuard<'a, L, F> where L: TableLevel, ActivePageTableEntry: ParentEntry, + F: FlushGuard, { entry: &'a ActivePageTableEntry, + guard: &'a F, } -impl Deref for ActivePageTableEntryGuard<'_, L> +impl Deref for ActivePageTableEntryGuard<'_, L, F> where L: TableLevel, ActivePageTableEntry: ParentEntry, + F: FlushGuard, { type Target = ActivePageTable; @@ -1205,17 +1279,18 @@ where } } -impl Drop for ActivePageTableEntryGuard<'_, L> +impl Drop for ActivePageTableEntryGuard<'_, L, F> where L: TableLevel, ActivePageTableEntry: ParentEntry, + F: FlushGuard, { fn drop(&mut self) { // Release reference count. let frame = unsafe { // SAFETY: We're releasing the reference count acquired in // ActivePageTableEntry::acquire`. - self.entry.release_reference_count() + self.entry.release_reference_count(self.guard) }; // Check if the entry was freed. @@ -1593,8 +1668,6 @@ impl PcidAllocations { } unsafe fn deallocate(&mut self, pcid: Pcid) { - INVLPGB.flush_pcid(pcid); - self.in_use[usize::from(pcid.value())] = false; } } diff --git a/tee/kernel/src/memory/pagetable/flush.rs b/tee/kernel/src/memory/pagetable/flush.rs index 8717b98f..0994483f 100644 --- a/tee/kernel/src/memory/pagetable/flush.rs +++ b/tee/kernel/src/memory/pagetable/flush.rs @@ -1,21 +1,42 @@ -use core::arch::asm; +use core::{arch::asm, ops::RangeInclusive}; use bit_field::BitField; use constants::{ApBitmap, ApIndex, AtomicApBitmap}; use x86_64::{ - instructions::tlb, + instructions::tlb::{self, InvPicdCommand, Invlpgb}, registers::{ control::{Cr3, Cr4, Cr4Flags}, model_specific::Msr, }, - structures::idt::InterruptStackFrame, + structures::{idt::InterruptStackFrame, paging::Page}, }; +use crate::{per_cpu::PerCpu, spin::lazy::Lazy}; + +use super::ActivePageTableGuard; + pub const TLB_VECTOR: u8 = 0x20; +static INVLPGB: Lazy> = Lazy::new(Invlpgb::new); + +static ACTIVE_APS: AtomicApBitmap = AtomicApBitmap::empty(); static PENDING_TLB_SHOOTDOWN: AtomicApBitmap = AtomicApBitmap::empty(); static PENDING_GLOBAL_TLB_SHOOTDOWN: AtomicApBitmap = AtomicApBitmap::empty(); +pub fn init() { + post_halt(); +} + +pub fn pre_halt() { + ACTIVE_APS.take(PerCpu::get().idx); +} + +pub fn post_halt() { + let idx = PerCpu::get().idx; + ACTIVE_APS.set(idx); + process_flushes(idx); +} + fn process_flushes(idx: ApIndex) { let need_global_flush = PENDING_GLOBAL_TLB_SHOOTDOWN.take(idx); let need_non_global_flush = PENDING_TLB_SHOOTDOWN.take(idx); @@ -76,3 +97,185 @@ fn send_tlb_ipis(aps: ApBitmap) { } } } + +impl ActivePageTableGuard { + pub fn flush_all(&self) { + let idx = PerCpu::get().idx; + + let mut guard = self.guard.flush_state.lock(); + let state = &mut *guard; + + state.needs_flush |= state.used; + state.used = state.active; + // Unmark this vCPU from needing to be flushed. We'll flush the current + // AP immediately. + state.needs_flush.set(idx, false); + + // Check if the current AP is the only active AP. 
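+        // (`active` is the set of APs currently running on these page
+        // tables, `used` the APs that ran on them since their last flush,
+        // and `needs_flush` the APs that must flush before reactivating.)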
+ let mut other_active_aps = state.active; + other_active_aps.set(idx, false); + + // If the current AP is the only active AP, we don't need to tell other + // APs to flush immediately. We only need to flush the TLB on the + // current AP. + if other_active_aps.is_empty() { + drop(guard); + return self.flush_all_local(); + } + + // We have to flush the TLB on other APs :( + + // If invlpgb is supported, use it. + if let Some(invlpgb) = INVLPGB.as_ref() { + // Note that we don't drop `guard` until we're done flushing. + + let pcid_allocation = self.guard.pcid_allocation.as_ref().unwrap(); + unsafe { + invlpgb.build().pcid(pcid_allocation.pcid).flush(); + } + invlpgb.tlbsync(); + + state.needs_flush = ApBitmap::empty(); + + return; + } + + // We've run out of optimizations :( + // Flush on the current processor and send IPIs to the other relevant + // APs. + + drop(guard); + + PENDING_TLB_SHOOTDOWN.set_all(other_active_aps); + send_tlb_ipis(other_active_aps); + + self.flush_all_local(); + + let mut remaining_aps = other_active_aps; + while !remaining_aps.is_empty() { + remaining_aps &= PENDING_TLB_SHOOTDOWN.get_all(); + } + } + + pub fn flush_all_local(&self) { + if let Some(pcid_allocation) = self.guard.pcid_allocation.as_ref() { + unsafe { + tlb::flush_pcid(InvPicdCommand::Single(pcid_allocation.pcid)); + } + } else { + tlb::flush_all(); + } + } + + pub fn flush_pages(&self, pages: RangeInclusive) { + let num_pages = pages.clone().count(); + if num_pages > 32 { + return self.flush_all(); + } + + let idx = PerCpu::get().idx; + + let mut guard = self.guard.flush_state.lock(); + let state = &mut *guard; + + // Check if the current AP is the only active AP. + let mut other_active_aps = state.active; + other_active_aps.set(idx, false); + + // If the current AP is the only active AP, we don't need to tell other + // APs to flush immediately. We only need to flush the TLB on the + // current AP. + if other_active_aps.is_empty() { + drop(guard); + return self.flush_pages_local(pages); + } + + // We have to flush the TLB on other APs :( + + // If invlpgb is supported, use it. + if let Some(invlpgb) = INVLPGB.as_ref() { + // Note that we don't drop `guard` until we're done flushing. + + let pcid_allocation = self.guard.pcid_allocation.as_ref().unwrap(); + let mut builder = invlpgb.build(); + unsafe { + builder.pcid(pcid_allocation.pcid); + } + if num_pages < usize::from(invlpgb.invlpgb_count_max()) { + builder = builder.pages(Page::range(*pages.start(), *pages.end() + 1)); + } + builder.flush(); + invlpgb.tlbsync(); + + return; + } + + // We've run out of optimizations :( + // Flush on the current processor and send IPIs to the other relevant + // APs. + state.needs_flush |= state.used; + drop(guard); + + PENDING_TLB_SHOOTDOWN.set_all(other_active_aps); + send_tlb_ipis(other_active_aps); + + self.flush_pages_local(pages); + + let mut remaining_aps = other_active_aps; + while !remaining_aps.is_empty() { + remaining_aps &= PENDING_TLB_SHOOTDOWN.get_all(); + } + } + + pub fn flush_pages_local(&self, pages: RangeInclusive) { + for page in pages { + tlb::flush(page.start_address()); + } + } +} + +pub(super) trait FlushGuard { + fn flush_page(&self, page: Page); +} + +impl FlushGuard for ActivePageTableGuard { + fn flush_page(&self, page: Page) { + // TODO: Check that the page is a userspace page. + // TODO: Check that the pml4 is active. 
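+        // (Delegating to `flush_pages` keeps all shootdown logic (local
+        // flush, invlpgb broadcast, IPIs) in one place; for a single page
+        // the 32-page fallback to `flush_all` never triggers.)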
+ self.flush_pages(page..=page); + } +} + +pub(super) struct GlobalFlushGuard; + +impl FlushGuard for GlobalFlushGuard { + fn flush_page(&self, page: Page) { + if let Some(invlpgb) = &*INVLPGB { + invlpgb + .build() + .pages(Page::range(page, page + 1)) + .include_global() + .flush(); + invlpgb.tlbsync(); + return; + } + + // Tell all other APs to flush their entire TLBs. + let mut all_other_aps = ApBitmap::all(); + all_other_aps.set(PerCpu::get().idx, false); + PENDING_GLOBAL_TLB_SHOOTDOWN.set_all(all_other_aps); + + // Send IPIs to all other currently active APs. + let other_active_aps = ACTIVE_APS.get_all() & all_other_aps; + send_tlb_ipis(other_active_aps); + + // Flush the local TLB entry. + tlb::flush(page.start_address()); + + // Wait for the APS to acknowledge the IPI. + let mut remaining_aps = other_active_aps; + while !remaining_aps.is_empty() { + remaining_aps &= PENDING_GLOBAL_TLB_SHOOTDOWN.get_all(); + } + } +} diff --git a/tee/kernel/src/supervisor.rs b/tee/kernel/src/supervisor.rs index 96b2881c..deab1b31 100644 --- a/tee/kernel/src/supervisor.rs +++ b/tee/kernel/src/supervisor.rs @@ -1,6 +1,6 @@ use core::arch::asm; -use crate::spin::mutex::Mutex; +use crate::{memory::pagetable::flush, spin::mutex::Mutex}; use arrayvec::ArrayVec; use bit_field::BitField; use constants::{physical_address::DYNAMIC_2MIB, ApBitmap, ApIndex}; @@ -180,8 +180,12 @@ pub fn halt() -> Result<(), LastRunningVcpuError> { #[cfg(feature = "profiling")] crate::profiler::flush(); + flush::pre_halt(); + kick_supervisor(false); + flush::post_halt(); + SCHEDULER.resume(); Ok(()) diff --git a/tee/supervisor-tdx/src/main.rs b/tee/supervisor-tdx/src/main.rs index 8266a34c..2c67a4e2 100644 --- a/tee/supervisor-tdx/src/main.rs +++ b/tee/supervisor-tdx/src/main.rs @@ -8,6 +8,7 @@ use log::{debug, LevelFilter}; use per_cpu::PerCpu; use spin::Once; use vcpu::{init_vcpu, run_vcpu, wait_for_vcpu_start}; +use x86_64::registers::model_specific::Msr; use crate::logging::SerialLogger; @@ -35,6 +36,11 @@ fn main() -> ! 
{ }); } + const IA32_TSC_AUX: u32 = 0xC000_0103; + unsafe { + Msr::new(IA32_TSC_AUX).write(PerCpu::current_vcpu_index().as_u8() as u64); + } + setup_idt(); if PerCpu::current_vcpu_index().is_first() { From 670becda184f04841a97fefff2fffcacee695059 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 18:54:37 +0100 Subject: [PATCH 18/21] add proper support for Hyper-V TLB flushing --- tee/kernel/src/memory.rs | 1 - tee/kernel/src/memory/invlpgb.rs | 156 -------------- tee/kernel/src/memory/pagetable/flush.rs | 247 ++++++++++++++++++++++- tee/supervisor-tdx/src/exception.rs | 7 +- tee/supervisor-tdx/src/main.rs | 1 - tee/supervisor-tdx/src/per_cpu.rs | 9 +- tee/supervisor-tdx/src/tlb.rs | 41 ---- tee/supervisor-tdx/src/vcpu.rs | 33 +-- 8 files changed, 248 insertions(+), 247 deletions(-) delete mode 100644 tee/kernel/src/memory/invlpgb.rs delete mode 100644 tee/supervisor-tdx/src/tlb.rs diff --git a/tee/kernel/src/memory.rs b/tee/kernel/src/memory.rs index 15f76505..5bd5b3db 100644 --- a/tee/kernel/src/memory.rs +++ b/tee/kernel/src/memory.rs @@ -1,6 +1,5 @@ pub mod frame; mod heap; -pub mod invlpgb; pub mod page; pub mod pagetable; pub mod temporary; diff --git a/tee/kernel/src/memory/invlpgb.rs b/tee/kernel/src/memory/invlpgb.rs deleted file mode 100644 index 570d9bca..00000000 --- a/tee/kernel/src/memory/invlpgb.rs +++ /dev/null @@ -1,156 +0,0 @@ -use core::{ - arch::{asm, x86_64::__cpuid}, - iter::Step, - ops::RangeInclusive, -}; - -use bit_field::BitField; -use bitflags::bitflags; -use x86_64::{ - instructions::tlb::{Invlpgb, Pcid}, - registers::model_specific::Msr, - structures::paging::Page, -}; - -use crate::spin::lazy::Lazy; - -pub static INVLPGB: Lazy = Lazy::new(InvlpgbCompat::new); - -pub enum InvlpgbCompat { - /// Use the native `invlpgb` and `tlbsync` instructions. This should always - /// be available on EPYC CPUs supporting SEV-SNP. - Invlpgb(Invlpgb), - /// Fall back to using Hyper-V instructions to emulate `invlpgb` and - /// `tlbsync`. - HyperV, -} - -impl InvlpgbCompat { - fn new() -> Self { - Invlpgb::new().map_or(Self::HyperV, Self::Invlpgb) - } - - pub fn flush_all(&self) { - match self { - InvlpgbCompat::Invlpgb(invlpgb) => { - invlpgb.build().flush(); - invlpgb.tlbsync(); - } - InvlpgbCompat::HyperV => hv_flush_all(), - } - } - - pub fn flush_pcid(&self, pcid: Pcid) { - match self { - InvlpgbCompat::Invlpgb(invlpgb) => { - unsafe { - invlpgb.build().pcid(pcid).flush(); - } - invlpgb.tlbsync(); - } - InvlpgbCompat::HyperV => hv_flush_all(), - } - } - - pub fn flush_page(&self, page: Page, global: bool) { - match self { - InvlpgbCompat::Invlpgb(invlpgb) => { - let flush = invlpgb.build(); - let next_page = Step::forward(page, 1); - let mut flush = flush.pages(Page::range(page, next_page)); - if global { - flush.include_global(); - } - flush.flush(); - invlpgb.tlbsync(); - } - InvlpgbCompat::HyperV => hv_flush_all(), - } - } - - /// Flush a range of pages. 
- pub unsafe fn flush_user_pages(&self, pcid: Pcid, pages: RangeInclusive) { - match self { - InvlpgbCompat::Invlpgb(invlpgb) => { - let mut flush = invlpgb.build(); - - unsafe { - flush.pcid(pcid); - } - - if pages.clone().count() < 64 { - let exlusive_end = Step::forward(*pages.end(), 1); - let page_range = Page::range(*pages.start(), exlusive_end); - flush = flush.pages(page_range); - } - - flush.flush(); - invlpgb.tlbsync(); - } - InvlpgbCompat::HyperV => hv_flush_all(), - } - } -} - -enum Hypercall { - Vmmcall, - Vmcall, -} - -static HYPERCALL: Lazy = Lazy::new(|| { - let svm = unsafe { __cpuid(0x8000_0001) }.ecx.get_bit(2); - if svm { - Hypercall::Vmmcall - } else { - Hypercall::Vmcall - } -}); - -fn hv_flush_all() { - const HV_X64_MSR_GUEST_OS_ID: u32 = 0x40000000; - unsafe { - Msr::new(HV_X64_MSR_GUEST_OS_ID).write(1); - } - - let mut hypercall_input = 0; - hypercall_input.set_bits(0..=15, 0x0002); // call code: HvCallFlushVirtualAddressSpace - hypercall_input.set_bit(16, true); // fast - - let flags = HvCallFlushVirtualAddressSpaceFlags::HV_FLUSH_ALL_PROCESSORS - | HvCallFlushVirtualAddressSpaceFlags::HV_FLUSH_ALL_VIRTUAL_ADDRESS_SPACES; - - let result: u64; - - match *HYPERCALL { - Hypercall::Vmmcall => unsafe { - asm! { - "vpxor xmm0, xmm0, xmm0", - "vmmcall", - in("rcx") hypercall_input, - in("rdx") flags.bits(), - inout("rax") 0x5a5a5a5a5a5a5a5au64 => result, - }; - }, - Hypercall::Vmcall => unsafe { - asm! { - "vpxor xmm0, xmm0, xmm0", - "vmcall", - in("rcx") hypercall_input, - in("rdx") flags.bits(), - inout("rax") 0x5a5a5a5a5a5a5a5au64 => result, - }; - }, - } - - assert_eq!(result.get_bits(0..=15), 0); -} - -bitflags! { - #[derive(Clone, Copy)] - #[repr(transparent)] - pub struct HvCallFlushVirtualAddressSpaceFlags: u64 { - const HV_FLUSH_ALL_PROCESSORS = 1 << 0; - const HV_FLUSH_ALL_VIRTUAL_ADDRESS_SPACES = 1 << 1; - const HV_FLUSH_NON_GLOBAL_MAPPINGS_ONLY = 1 << 2; - } -} diff --git a/tee/kernel/src/memory/pagetable/flush.rs b/tee/kernel/src/memory/pagetable/flush.rs index 0994483f..4e9f4847 100644 --- a/tee/kernel/src/memory/pagetable/flush.rs +++ b/tee/kernel/src/memory/pagetable/flush.rs @@ -1,7 +1,10 @@ -use core::{arch::asm, ops::RangeInclusive}; +use core::{ + arch::{asm, x86_64::__cpuid}, + ops::RangeInclusive, +}; -use bit_field::BitField; -use constants::{ApBitmap, ApIndex, AtomicApBitmap}; +use bit_field::{BitArray, BitField}; +use constants::{ApBitmap, ApIndex, AtomicApBitmap, MAX_APS_COUNT}; use x86_64::{ instructions::tlb::{self, InvPicdCommand, Invlpgb}, registers::{ @@ -140,6 +143,14 @@ impl ActivePageTableGuard { return; } + // If the hypervisor supports Hyper-V hypercalls, use them. + if let Some(hyper_v) = *HYPER_V { + hyper_v.flush_all(state.needs_flush); + self.flush_all_local(); + state.needs_flush = ApBitmap::empty(); + return; + } + // We've run out of optimizations :( // Flush on the current processor and send IPIs to the other relevant // APs. @@ -210,6 +221,15 @@ impl ActivePageTableGuard { return; } + // If the hypervisor supports Hyper-V hypercalls, use them. + if let Some(hyper_v) = *HYPER_V { + drop(guard); + + hyper_v.flush_address_list(pages.clone(), other_active_aps); + self.flush_pages_local(pages); + return; + } + // We've run out of optimizations :( // Flush on the current processor and send IPIs to the other relevant // APs. @@ -250,6 +270,7 @@ pub(super) struct GlobalFlushGuard; impl FlushGuard for GlobalFlushGuard { fn flush_page(&self, page: Page) { + // If invlpgb is supported, use it. 
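+        // (invlpgb broadcasts the invalidation to all cores and tlbsync
+        // waits for their acknowledgement, so no IPIs are needed on this
+        // path.)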
if let Some(invlpgb) = &*INVLPGB { invlpgb .build() @@ -260,11 +281,19 @@ impl FlushGuard for GlobalFlushGuard { return; } - // Tell all other APs to flush their entire TLBs. let mut all_other_aps = ApBitmap::all(); all_other_aps.set(PerCpu::get().idx, false); PENDING_GLOBAL_TLB_SHOOTDOWN.set_all(all_other_aps); + // If the hypervisor supports Hyper-V hypercalls, use them. + if let Some(hyper_v) = *HYPER_V { + hyper_v.flush_address_list(page..=page, all_other_aps); + + // Flush the local TLB entry. + tlb::flush(page.start_address()); + return; + } + // Send IPIs to all other currently active APs. let other_active_aps = ACTIVE_APS.get_all() & all_other_aps; send_tlb_ipis(other_active_aps); @@ -279,3 +308,213 @@ impl FlushGuard for GlobalFlushGuard { } } } + +#[derive(Clone, Copy)] +enum Hypercall { + Vmmcall, + Vmcall, +} + +impl Hypercall { + fn get() -> Self { + let svm = unsafe { __cpuid(0x8000_0001) }.ecx.get_bit(2); + if svm { + Hypercall::Vmmcall + } else { + Hypercall::Vmcall + } + } +} + +static HYPER_V: Lazy> = Lazy::new(HyperV::new); + +#[derive(Clone, Copy)] +struct HyperV(Hypercall); + +impl HyperV { + pub fn new() -> Option { + // Make sure the hypervisor supports the HyperV hypercall ABI. + let cpuid_result = unsafe { + // SAFETY: If `cpuid` isn't available, we have bigger problems. + __cpuid(0x40000001) + }; + // Check the interface id. + if cpuid_result.eax != 0x31237648 { + return None; + } + + // Enable HyperV hypercalls. + const HV_X64_MSR_GUEST_OS_ID: u32 = 0x40000000; + unsafe { + Msr::new(HV_X64_MSR_GUEST_OS_ID).write(1); + } + + Some(Self(Hypercall::get())) + } + + const HV_FLUSH_ALL_VIRTUAL_ADDRESS_SPACES: u64 = 1 << 1; + + pub fn flush_address_list(&self, range: RangeInclusive, aps: ApBitmap) { + let count = range.clone().count(); + let gva_range = match count { + 0 => { + // There's nothing to flush. + return; + } + 1..1024 => range.start().start_address().as_u64() + (count as u64 - 1), + _ => { + // We can't encode the range. Fall back to flushing the entire TLB. + return self.flush_all(aps); + } + }; + + let mut input_value = 0; + input_value.set_bits(0..16, 0x0014); // Call Code: 0x0014 + input_value.set_bit(16, true); // Fast: true + input_value.set_bits(17..27, NUM_BANKS); // Variable header size: NUM_BANKS + input_value.set_bit(31, false); // Is Nested: false + input_value.set_bits(32..44, 1); // Rep Count: 1 + input_value.set_bits(48..60, 0); // Rep Start Index: 0 + + #[repr(C, align(16))] + struct HvFlushVirtualAddressListEx { + // header + address_space: u64, + flags: u64, + processor_set: HvVpSet, + // list + gva_range: u64, + } + + let input = HvFlushVirtualAddressListEx { + address_space: 0, + flags: Self::HV_FLUSH_ALL_VIRTUAL_ADDRESS_SPACES, + processor_set: HvVpSet::from_iter(aps), + gva_range, + }; + + // Assert that we can fit `HvFlushVirtualAddressListEx` into two GPRs + // and 2 XMM registers. 
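+        // (Per the Hyper-V TLFS, fast extended hypercalls pass their input
+        // in RDX, R8 and XMM0..=XMM5; hence the two GPR and two XMM loads
+        // in the asm block below.)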
+        const {
+            assert!(size_of::<HvFlushVirtualAddressListEx>().div_ceil(16) - 1 == 2);
+        }
+
+        let output: u64;
+        unsafe {
+            asm!(
+                "mov rdx, qword ptr [{input} + 0]",
+                "mov r8, qword ptr [{input} + 8]",
+                "movdqa xmm0, xmmword ptr [{input} + 16]",
+                "movdqa xmm1, xmmword ptr [{input} + 32]",
+                "test {variant}, {VMCALL}",
+                "jnz 65f",
+                "vmmcall",
+                "jmp 66f",
+                "65:",
+                "vmcall",
+                "66:",
+                inout("rcx") input_value => _,
+                input = in(reg) &input,
+                variant = in(reg) self.0 as u64,
+                VMCALL = const Hypercall::Vmcall as u8,
+                out("rax") output,
+                out("rdx") _,
+                out("r8") _,
+                out("xmm0") _,
+                out("xmm1") _,
+                options(nostack),
+            );
+        }
+
+        assert_eq!(output.get_bits(0..16), 0); // Check result
+        assert_eq!(output.get_bits(32..44), 1); // Check reps completed
+    }
+
+    pub fn flush_all(&self, aps: ApBitmap) {
+        let mut input_value = 0;
+        input_value.set_bits(0..16, 0x0013); // Call Code: 0x0013
+        input_value.set_bit(16, true); // Fast: true
+        input_value.set_bits(17..27, NUM_BANKS); // Variable header size: NUM_BANKS
+        input_value.set_bit(31, false); // Is Nested: false
+        input_value.set_bits(32..44, 0); // Rep Count: 0
+        input_value.set_bits(48..60, 0); // Rep Start Index: 0
+
+        #[repr(C, align(16))]
+        struct HvCallFlushVirtualAddressSpaceEx {
+            // header
+            address_space: u64,
+            flags: u64,
+            processor_set: HvVpSet,
+        }
+
+        let input = HvCallFlushVirtualAddressSpaceEx {
+            address_space: 0,
+            flags: Self::HV_FLUSH_ALL_VIRTUAL_ADDRESS_SPACES,
+            processor_set: HvVpSet::from_iter(aps),
+        };
+
+        // Assert that we can fit `HvCallFlushVirtualAddressSpaceEx` into two GPRs
+        // and 2 XMM registers.
+        const {
+            assert!(size_of::<HvCallFlushVirtualAddressSpaceEx>().div_ceil(16) - 1 == 2);
+        }
+
+        let output: u64;
+        unsafe {
+            asm!(
+                "mov rdx, qword ptr [{input} + 0]",
+                "mov r8, qword ptr [{input} + 8]",
+                "movdqa xmm0, xmmword ptr [{input} + 16]",
+                "movdqa xmm1, xmmword ptr [{input} + 32]",
+                "test {variant}, {VMCALL}",
+                "jnz 65f",
+                "vmmcall",
+                "jmp 66f",
+                "65:",
+                "vmcall",
+                "66:",
+                inout("rcx") input_value => _,
+                input = in(reg) &input,
+                variant = in(reg) self.0 as u64,
+                VMCALL = const Hypercall::Vmcall as u8,
+                out("rax") output,
+                out("rdx") _,
+                out("r8") _,
+                out("xmm0") _,
+                out("xmm1") _,
+                options(nostack),
+            );
+        }
+
+        assert_eq!(output.get_bits(0..16), 0); // Check result
+    }
+}
+
+const NUM_BANKS: usize = (MAX_APS_COUNT as usize).div_ceil(64);
+
+#[repr(C)]
+struct HvVpSet {
+    format: u64,
+    valid_banks_mask: u64,
+    bank_contents: [u64; NUM_BANKS],
+}
+
+impl FromIterator<ApIndex> for HvVpSet {
+    fn from_iter<T: IntoIterator<Item = ApIndex>>(iter: T) -> Self {
+        let mut this = Self::default();
+        for ap in iter {
+            this.bank_contents.set_bit(usize::from(ap.as_u8()), true);
+        }
+        this
+    }
+}
+
+impl Default for HvVpSet {
+    fn default() -> Self {
+        Self {
+            format: 0,
+            valid_banks_mask: (1 << NUM_BANKS) - 1,
+            bank_contents: [0; NUM_BANKS],
+        }
+    }
+}
diff --git a/tee/supervisor-tdx/src/exception.rs b/tee/supervisor-tdx/src/exception.rs
index c93bbe5b..d17f99fe 100644
--- a/tee/supervisor-tdx/src/exception.rs
+++ b/tee/supervisor-tdx/src/exception.rs
@@ -8,13 +8,9 @@ use x86_64::{
     structures::idt::{InterruptDescriptorTable, InterruptStackFrame},
 };

-use crate::{
-    tdcall::{Tdcall, Vmcall},
-    tlb::flush_handler,
-};
+use crate::tdcall::{Tdcall, Vmcall};

 pub const WAKEUP_VECTOR: u8 = 0x60;
-pub const FLUSH_VECTOR: u8 = 0x61;

 pub fn setup_idt() {
     IDT.load();
@@ -26,7 +22,6 @@ static IDT: Lazy<InterruptDescriptorTable> = Lazy::new(|| {
     let mut idt = InterruptDescriptorTable::new();
     idt.virtualization.set_handler_fn(virtualization_handler);
     idt[WAKEUP_VECTOR].set_handler_fn(wakeup_handler);
-
idt[FLUSH_VECTOR].set_handler_fn(flush_handler); idt }); diff --git a/tee/supervisor-tdx/src/main.rs b/tee/supervisor-tdx/src/main.rs index 2c67a4e2..8fa81398 100644 --- a/tee/supervisor-tdx/src/main.rs +++ b/tee/supervisor-tdx/src/main.rs @@ -23,7 +23,6 @@ mod per_cpu; mod reset_vector; mod services; mod tdcall; -mod tlb; mod vcpu; fn main() -> ! { diff --git a/tee/supervisor-tdx/src/per_cpu.rs b/tee/supervisor-tdx/src/per_cpu.rs index 2447718b..452d40db 100644 --- a/tee/supervisor-tdx/src/per_cpu.rs +++ b/tee/supervisor-tdx/src/per_cpu.rs @@ -1,4 +1,4 @@ -use core::{arch::asm, cell::Cell}; +use core::arch::asm; use constants::ApIndex; use x86_64::instructions::interrupts; @@ -7,16 +7,11 @@ use x86_64::instructions::interrupts; pub struct PerCpu { this: *mut Self, pub vcpu_index: ApIndex, - pub pending_flushes: Cell, } impl PerCpu { pub fn new(this: *mut Self, vcpu_index: ApIndex) -> Self { - Self { - this, - vcpu_index, - pending_flushes: Cell::new(false), - } + Self { this, vcpu_index } } pub fn with(f: impl FnOnce(&Self) -> R) -> R { diff --git a/tee/supervisor-tdx/src/tlb.rs b/tee/supervisor-tdx/src/tlb.rs deleted file mode 100644 index 515f7933..00000000 --- a/tee/supervisor-tdx/src/tlb.rs +++ /dev/null @@ -1,41 +0,0 @@ -use constants::AtomicApBitmap; -use spin::mutex::SpinMutex; -use x86_64::structures::idt::InterruptStackFrame; - -use crate::{ - exception::{eoi, send_ipi, FLUSH_VECTOR}, - per_cpu::PerCpu, -}; - -static GUARD: SpinMutex<()> = SpinMutex::new(()); -static COUNTER: AtomicApBitmap = AtomicApBitmap::empty(); -static RAN: AtomicApBitmap = AtomicApBitmap::empty(); - -/// This function must be called before entering the vCPU. -pub fn pre_enter() { - RAN.set(PerCpu::current_vcpu_index()); -} - -/// Flush the entire TLB on all vCPUs. -pub fn flush() { - let _guard = GUARD.lock(); - let mask = RAN.take_all(); - COUNTER.set_all(mask); - drop(_guard); - - for idx in mask { - send_ipi(u32::from(idx.as_u8()), FLUSH_VECTOR); - } - - while !COUNTER.get_all().is_empty() {} -} - -pub extern "x86-interrupt" fn flush_handler(_frame: InterruptStackFrame) { - let vcpu_index = PerCpu::with(|per_cpu| { - per_cpu.pending_flushes.set(true); - per_cpu.vcpu_index - }); - COUNTER.take(vcpu_index); - - eoi(); -} diff --git a/tee/supervisor-tdx/src/vcpu.rs b/tee/supervisor-tdx/src/vcpu.rs index b6c0c10b..c2674314 100644 --- a/tee/supervisor-tdx/src/vcpu.rs +++ b/tee/supervisor-tdx/src/vcpu.rs @@ -13,7 +13,6 @@ use tdx_types::{ }, vmexit::{ VMEXIT_REASON_CPUID_INSTRUCTION, VMEXIT_REASON_HLT_INSTRUCTION, VMEXIT_REASON_MSR_WRITE, - VMEXIT_REASON_VMCALL_INSTRUCTION, }, }; use x86_64::{ @@ -29,7 +28,6 @@ use crate::{ per_cpu::PerCpu, services::handle, tdcall::{Tdcall, Vmcall}, - tlb, }; static READY: AtomicUsize = AtomicUsize::new(0); @@ -194,13 +192,8 @@ pub fn run_vcpu() -> ! { .guest_interrupt_status .set_bits(0..8, u16::from(rvi)); - tlb::pre_enter(); - let flush = if PerCpu::with(|per_cpu| per_cpu.pending_flushes.take()) { - InvdTranslations::All - } else { - InvdTranslations::None - }; - let vm_exit = Tdcall::vp_enter(VmIndex::One, flush, &mut guest_state, true); + let vm_exit = + Tdcall::vp_enter(VmIndex::One, InvdTranslations::None, &mut guest_state, true); match vm_exit.class { TDX_SUCCESS => {} @@ -229,9 +222,6 @@ pub fn run_vcpu() -> ! { VMEXIT_REASON_MSR_WRITE => { let value = guest_state.rax.get_bits(..32) | (guest_state.rdx << 32); match guest_state.rcx { - 0x40000000 => { - // Ignore writes to HV_X64_MSR_GUEST_OS_ID. - } // IA32_X2APIC_ICR 0x830 => { // We don't support all options. 
Check that we support the fields. @@ -256,25 +246,6 @@ pub fn run_vcpu() -> ! { } guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); } - VMEXIT_REASON_VMCALL_INSTRUCTION => { - // The kernel currently only executes vmcalls to flush the TLB. - // Double-check this. - assert_eq!( - guest_state.rcx, 0x10002, - "unsupported request: {:#x}", - guest_state.rcx - ); - assert_eq!( - guest_state.rdx, 3, - "unsupported flags: {:#x}", - guest_state.rdx - ); - - tlb::flush(); - - guest_state.rax = 0; - guest_state.rip += u64::from(vm_exit.vm_exit_instruction_length); - } unknown => panic!("{unknown:#x} {guest_state:x?} {vm_exit:x?}"), } } From 3e748c3ae1eb987ca3f1bfecdf4cad036023db2c Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 19:04:35 +0100 Subject: [PATCH 19/21] remove #[repr(C)] from PerCpu --- tee/kernel/src/per_cpu.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tee/kernel/src/per_cpu.rs b/tee/kernel/src/per_cpu.rs index cda8235f..eef5d43f 100644 --- a/tee/kernel/src/per_cpu.rs +++ b/tee/kernel/src/per_cpu.rs @@ -1,6 +1,7 @@ use core::{ arch::asm, cell::{Cell, OnceCell, RefCell}, + mem::offset_of, ptr::null_mut, sync::atomic::{AtomicUsize, Ordering}, }; @@ -22,7 +23,6 @@ static COUNT: AtomicUsize = AtomicUsize::new(0); static mut STORAGE: [PerCpu; MAX_APS_COUNT as usize] = [const { PerCpu::new() }; MAX_APS_COUNT as usize]; -#[repr(C)] pub struct PerCpu { this: *mut PerCpu, pub idx: ApIndex, @@ -65,7 +65,12 @@ impl PerCpu { unsafe { // SAFETY: If the GS segment wasn't programmed yet, this will cause // a page fault, which is a safe thing to do. - asm!("mov {}, gs:[0]", out(reg) addr, options(pure, nomem, preserves_flags, nostack)); + asm!( + "mov {}, gs:[{THIS_OFFSET}]", + out(reg) addr, + THIS_OFFSET = const offset_of!(Self, this), + options(pure, nomem, preserves_flags, nostack), + ); } let ptr = addr as *const Self; unsafe { &*ptr } From 63ecb087f420cf316003115af4f0241dd6c38909 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Mon, 11 Nov 2024 19:05:02 +0100 Subject: [PATCH 20/21] remove unused fields in PerCpu --- tee/kernel/src/per_cpu.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tee/kernel/src/per_cpu.rs b/tee/kernel/src/per_cpu.rs index eef5d43f..88a22b94 100644 --- a/tee/kernel/src/per_cpu.rs +++ b/tee/kernel/src/per_cpu.rs @@ -10,7 +10,7 @@ use alloc::sync::Arc; use constants::{ApIndex, MAX_APS_COUNT}; use x86_64::{ registers::segmentation::{Segment64, GS}, - structures::{gdt::GlobalDescriptorTable, paging::Page, tss::TaskStateSegment}, + structures::{gdt::GlobalDescriptorTable, tss::TaskStateSegment}, VirtAddr, }; @@ -28,10 +28,8 @@ pub struct PerCpu { pub idx: ApIndex, pub kernel_registers: Cell, pub new_userspace_registers: Cell, - pub temporary_mapping: OnceCell>, pub tss: OnceCell, pub gdt: OnceCell, - pub int0x80_handler: Cell, pub exit_with_sysret: Cell, pub exit: Cell, pub vector: Cell, @@ -47,10 +45,8 @@ impl PerCpu { idx: ApIndex::new(0), kernel_registers: Cell::new(KernelRegisters::ZERO), new_userspace_registers: Cell::new(Registers::ZERO), - temporary_mapping: OnceCell::new(), tss: OnceCell::new(), gdt: OnceCell::new(), - int0x80_handler: Cell::new(0), exit_with_sysret: Cell::new(false), exit: Cell::new(RawExit::Syscall), vector: Cell::new(0), From 7aaaeca2b81a5b172e40710a359adeac84fb839b Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Wed, 13 Nov 2024 09:34:45 +0100 Subject: [PATCH 21/21] decrease stack size in supervisor-tdx --- tee/supervisor-tdx/src/reset_vector.rs | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/tee/supervisor-tdx/src/reset_vector.rs b/tee/supervisor-tdx/src/reset_vector.rs index e8337375..64a0132c 100644 --- a/tee/supervisor-tdx/src/reset_vector.rs +++ b/tee/supervisor-tdx/src/reset_vector.rs @@ -5,7 +5,7 @@ use x86_64::{registers::model_specific::FsBase, VirtAddr}; use crate::{main, per_cpu::PerCpu}; -pub const STACK_SIZE: usize = 16; +pub const STACK_SIZE: usize = 8; global_asm!( include_str!("reset_vector.s"),