Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify interrupt handler #564

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agb/examples/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ static COUNT: Static<u32> = Static::new(0);
#[agb::entry]
fn main(_gba: agb::Gba) -> ! {
let _a = unsafe {
agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| {
agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, || {
let cur_count = COUNT.read();
agb::println!("Hello, world, frame = {}", cur_count);
COUNT.write(cur_count + 1);
Expand Down
14 changes: 8 additions & 6 deletions agb/examples/wave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use agb::{
fixnum::FixedNum,
interrupt::{free, Interrupt},
};
use bare_metal::{CriticalSection, Mutex};
use bare_metal::Mutex;

struct BackCosines {
cosines: [u16; 32],
Expand All @@ -36,11 +36,13 @@ fn main(mut gba: agb::Gba) -> ! {
example_logo::display_logo(&mut background, &mut vram);

let _a = unsafe {
agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| {
let mut back = BACK.borrow(key).borrow_mut();
let deflection = back.cosines[back.row % 32];
((0x0400_0010) as *mut u16).write_volatile(deflection);
back.row += 1;
agb::interrupt::add_interrupt_handler(Interrupt::HBlank, || {
free(|key| {
let mut back = BACK.borrow(key).borrow_mut();
let deflection = back.cosines[back.row % 32];
((0x0400_0010) as *mut u16).write_volatile(deflection);
back.row += 1;
});
})
};

Expand Down
218 changes: 64 additions & 154 deletions agb/src/interrupt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::{cell::Cell, marker::PhantomPinned, pin::Pin};
use core::cell::Cell;

use alloc::boxed::Box;
use alloc::{rc::Rc, vec::Vec};
use bare_metal::CriticalSection;

use crate::{display::DISPLAY_STATUS, memory_mapped::MemoryMapped, sync::Static};
Expand All @@ -24,16 +24,14 @@ pub enum Interrupt {
}

impl Interrupt {
fn enable(self) {
let _interrupt_token = temporary_interrupt_disable();
fn enable(self, _cs: CriticalSection) {
self.other_things_to_enable_interrupt();
let interrupt = self as usize;
let enabled = ENABLED_INTERRUPTS.get() | (1 << (interrupt as u16));
ENABLED_INTERRUPTS.set(enabled);
}

fn disable(self) {
let _interrupt_token = temporary_interrupt_disable();
fn disable(self, _cs: CriticalSection) {
self.other_things_to_disable_interrupt();
let interrupt = self as usize;
let enabled = ENABLED_INTERRUPTS.get() & !(1 << (interrupt as u16));
Expand Down Expand Up @@ -66,77 +64,39 @@ impl Interrupt {
}

const ENABLED_INTERRUPTS: MemoryMapped<u16> = unsafe { MemoryMapped::new(0x04000200) };
const INTERRUPTS_ENABLED: MemoryMapped<u16> = unsafe { MemoryMapped::new(0x04000208) };
const INTERRUPTS_ENABLED: MemoryMapped<u32> = unsafe { MemoryMapped::new(0x04000208) };

struct Disable {
pre: u16,
}

impl Drop for Disable {
fn drop(&mut self) {
INTERRUPTS_ENABLED.set(self.pre);
}
}

fn temporary_interrupt_disable() -> Disable {
let d = Disable {
pre: INTERRUPTS_ENABLED.get(),
};
disable_interrupts();
d
}

fn disable_interrupts() {
INTERRUPTS_ENABLED.set(0);
extern "C" {
static mut __INTERRUPT_NEST: u32;
}

struct InterruptRoot {
next: Cell<*const InterruptInner>,
count: Cell<i32>,
interrupt: Interrupt,
interrupts: Vec<Rc<dyn Fn() + Send + Sync>>,
}

impl InterruptRoot {
const fn new(interrupt: Interrupt) -> Self {
const fn new() -> Self {
InterruptRoot {
next: Cell::new(core::ptr::null()),
count: Cell::new(0),
interrupt,
interrupts: Vec::new(),
}
}

fn reduce(&self) {
let new_count = self.count.get() - 1;
if new_count == 0 {
self.interrupt.disable();
}
self.count.set(new_count);
}

fn add(&self) {
let count = self.count.get();
if count == 0 {
self.interrupt.enable();
}
self.count.set(count + 1);
}
}

static mut INTERRUPT_TABLE: [InterruptRoot; 14] = [
InterruptRoot::new(Interrupt::VBlank),
InterruptRoot::new(Interrupt::HBlank),
InterruptRoot::new(Interrupt::VCounter),
InterruptRoot::new(Interrupt::Timer0),
InterruptRoot::new(Interrupt::Timer1),
InterruptRoot::new(Interrupt::Timer2),
InterruptRoot::new(Interrupt::Timer3),
InterruptRoot::new(Interrupt::Serial),
InterruptRoot::new(Interrupt::Dma0),
InterruptRoot::new(Interrupt::Dma1),
InterruptRoot::new(Interrupt::Dma2),
InterruptRoot::new(Interrupt::Dma3),
InterruptRoot::new(Interrupt::Keypad),
InterruptRoot::new(Interrupt::Gamepak),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
InterruptRoot::new(),
];

#[no_mangle]
Expand All @@ -150,81 +110,39 @@ extern "C" fn __RUST_INTERRUPT_HANDLER(interrupt: u16) -> u16 {
interrupt
}

struct InterruptInner {
next: Cell<*const InterruptInner>,
root: *const InterruptRoot,
closure: *const dyn Fn(CriticalSection),
_pin: PhantomPinned,
}

unsafe fn create_interrupt_inner(
c: impl Fn(CriticalSection),
root: *const InterruptRoot,
) -> Pin<Box<InterruptInner>> {
let c = Box::new(c);
let c: &dyn Fn(CriticalSection) = Box::leak(c);
let c: &dyn Fn(CriticalSection) = core::mem::transmute(c);
Box::pin(InterruptInner {
next: Cell::new(core::ptr::null()),
root,
closure: c,
_pin: PhantomPinned,
})
pub struct InterruptHandler {
kind: Interrupt,
closure: Rc<dyn Fn() + Send + Sync + 'static>,
}

impl Drop for InterruptInner {
impl Drop for InterruptHandler {
fn drop(&mut self) {
inner_drop(unsafe { Pin::new_unchecked(self) });
#[allow(clippy::needless_pass_by_value)] // needed for safety reasons
fn inner_drop(this: Pin<&mut InterruptInner>) {
// drop the closure allocation safely
let _closure_box =
unsafe { Box::from_raw(this.closure as *mut dyn Fn(&CriticalSection)) };

// perform the rest of the drop sequence
let root = unsafe { &*this.root };
root.reduce();
let mut c = root.next.get();
let own_pointer = &*this as *const _;
if c == own_pointer {
unsafe { &*this.root }.next.set(this.next.get());
return;
}
loop {
let p = unsafe { &*c }.next.get();
if p == own_pointer {
unsafe { &*c }.next.set(this.next.get());
return;
}
c = p;
free(|cs| {
let root = unsafe { interrupt_to_root(self.kind) };
root.interrupts.retain(|x| {
!core::ptr::eq::<dyn Fn() + Send + Sync + 'static>(&**x, &*self.closure)
});
if root.interrupts.is_empty() {
self.kind.disable(cs);
}
}
});
}
}

pub struct InterruptHandler {
_inner: Pin<Box<InterruptInner>>,
}

impl InterruptRoot {
fn trigger_interrupts(&self) {
let mut c = self.next.get();
while !c.is_null() {
let closure_ptr = unsafe { &*c }.closure;
let closure_ref = unsafe { &*closure_ptr };
closure_ref(unsafe { CriticalSection::new() });
c = unsafe { &*c }.next.get();
for interrupt in self.interrupts.iter() {
(interrupt)();
}
}
}

fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot {
unsafe { &INTERRUPT_TABLE[interrupt as usize] }
unsafe fn interrupt_to_root(interrupt: Interrupt) -> &'static mut InterruptRoot {
unsafe { &mut INTERRUPT_TABLE[interrupt as usize] }
}

#[must_use]
/// Adds an interrupt handler as long as the returned value is alive. The
/// closure takes a [`CriticalSection`] which can be used for mutexes.
/// Adds an interrupt handler as long as the returned value is alive.
///
/// # Safety
/// * You *must not* allocate in an interrupt.
Expand All @@ -234,53 +152,45 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot {
/// * The closure must be static because forgetting the interrupt handler would
/// cause a use after free.
///
/// [`CriticalSection`]: bare_metal::CriticalSection
///
/// # Examples
///
/// ```rust,no_run
/// # #![no_std]
/// # #![no_main]
/// # fn foo() {
/// use bare_metal::CriticalSection;
/// use agb::interrupt::{add_interrupt_handler, Interrupt};
/// // Safety: doesn't allocate
/// let _a = unsafe {
/// add_interrupt_handler(Interrupt::VBlank, |_: CriticalSection| {
/// add_interrupt_handler(Interrupt::VBlank, || {
/// agb::println!("Woah there! There's been a vblank!");
/// })
/// };
/// # }
/// ```
pub unsafe fn add_interrupt_handler(
interrupt: Interrupt,
handler: impl Fn(CriticalSection) + Send + Sync + 'static,
handler: impl Fn() + Send + Sync + 'static,
) -> InterruptHandler {
fn do_with_inner(interrupt: Interrupt, inner: Pin<Box<InterruptInner>>) -> InterruptHandler {
free(|_| {
let root = interrupt_to_root(interrupt);
root.add();
let mut c = root.next.get();
if c.is_null() {
root.next.set((&*inner) as *const _);
return;
}
loop {
let p = unsafe { &*c }.next.get();
if p.is_null() {
unsafe { &*c }.next.set((&*inner) as *const _);
return;
}

c = p;
}
});
fn inner(
interrupt: Interrupt,
handle: Rc<dyn Fn() + Send + Sync + 'static>,
cs: CriticalSection,
) -> InterruptHandler {
let interrupts = unsafe { interrupt_to_root(interrupt) };

if interrupts.interrupts.is_empty() {
interrupt.enable(cs);
}

InterruptHandler { _inner: inner }
interrupts.interrupts.push(handle.clone());

InterruptHandler {
kind: interrupt,
closure: handle,
}
}
let root = interrupt_to_root(interrupt) as *const _;
let inner = unsafe { create_interrupt_inner(handler, root) };
do_with_inner(interrupt, inner)

free(|cs| inner(interrupt, Rc::new(handler), cs))
}

/// How you can access mutexes outside of interrupts by being given a
Expand All @@ -293,7 +203,7 @@ where
{
let enabled = INTERRUPTS_ENABLED.get();

disable_interrupts();
INTERRUPTS_ENABLED.set(0);

// prevents the contents of the function from being reordered before IME is disabled.
crate::sync::memory_write_hint(&mut f);
Expand Down Expand Up @@ -323,7 +233,7 @@ impl VBlank {
if !HAS_CREATED_INTERRUPT.read() {
// safety: we don't allocate in the interrupt
let handler = unsafe {
add_interrupt_handler(Interrupt::VBlank, |_| {
add_interrupt_handler(Interrupt::VBlank, || {
NUM_VBLANKS.write(NUM_VBLANKS.read() + 1);
})
};
Expand Down Expand Up @@ -365,7 +275,7 @@ pub fn profiler(timer: &mut crate::timer::Timer, period: u16) -> InterruptHandle
timer.set_enabled(true);

unsafe {
add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| {
add_interrupt_handler(timer.interrupt(), || {
crate::println!("{:#010x}", crate::program_counter_before_interrupt());
})
}
Expand Down
6 changes: 4 additions & 2 deletions agb/src/sound/mixer/sw_mixer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ impl Mixer<'_> {
let buffer_pointer_for_interrupt_handler: &MixerBuffer =
unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) };
let interrupt_handler = unsafe {
add_interrupt_handler(interrupt_timer.interrupt(), |cs| {
buffer_pointer_for_interrupt_handler.swap(cs);
add_interrupt_handler(interrupt_timer.interrupt(), || {
free(|cs| {
buffer_pointer_for_interrupt_handler.swap(cs);
});
})
};

Expand Down
3 changes: 1 addition & 2 deletions agb/src/sync/statics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ unsafe impl<T> Sync for Static<T> {}

#[cfg(test)]
mod test {
use crate::interrupt::Interrupt;
use crate::sync::Static;
use crate::timer::Divider;
use crate::Gba;
Expand All @@ -282,7 +281,7 @@ mod test {
timer.set_enabled(true);

let _int = unsafe {
crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| {
crate::interrupt::add_interrupt_handler(timer.interrupt(), || {
VALUE.write(SENTINEL);
})
};
Expand Down