
Move mutexes into shared data structure
maneatingape committed Sep 7, 2024
1 parent 420f718 commit 8a85454
Showing 8 changed files with 86 additions and 71 deletions.
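
Every file below gets the same refactoring: the Mutex<Exclusive> that was previously passed to each worker as a separate argument becomes a field of the Shared struct, so workers take a single &Shared reference. A minimal standalone sketch of the resulting shape (field names and the use of std scoped threads are illustrative stand-ins for the repository's types and its util::thread::spawn helper):

use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::thread;

// Lock-free flags and the mutex-guarded data now live in one struct,
// so a worker only needs a single shared reference.
struct Shared {
    done: AtomicBool,
    counter: AtomicU32,
    mutex: Mutex<Exclusive>,
}

// Regular data structures still need the mutex for exclusive access.
struct Exclusive {
    found: Vec<u32>,
}

fn worker(shared: &Shared) {
    while !shared.done.load(Ordering::Relaxed) {
        // Claim the next candidate with a lock-free counter.
        let n = shared.counter.fetch_add(1, Ordering::Relaxed);

        if n % 1000 == 0 {
            // Record a result under the lock, then tell the other workers to stop.
            let mut exclusive = shared.mutex.lock().unwrap();
            exclusive.found.push(n);
            shared.done.store(true, Ordering::Relaxed);
        }
    }
}

fn main() {
    let shared = Shared {
        done: AtomicBool::new(false),
        counter: AtomicU32::new(1),
        mutex: Mutex::new(Exclusive { found: Vec::new() }),
    };

    // Scoped threads stand in for the repository's util::thread::spawn helper.
    thread::scope(|scope| {
        for _ in 0..4 {
            scope.spawn(|| worker(&shared));
        }
    });

    let found = shared.mutex.into_inner().unwrap().found;
    assert!(!found.is_empty());
}

Folding the mutex into Shared means the closures passed to spawn capture one reference instead of two, and the results are still recovered without locking via into_inner once every thread has joined.
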
28 changes: 14 additions & 14 deletions src/year2016/day05.rs
@@ -13,6 +13,7 @@ struct Shared {
prefix: String,
done: AtomicBool,
counter: AtomicU32,
mutex: Mutex<Exclusive>,
}

struct Exclusive {
@@ -25,24 +26,24 @@ pub fn parse(input: &str) -> Vec<u32> {
prefix: input.trim().to_owned(),
done: AtomicBool::new(false),
counter: AtomicU32::new(1000),
mutex: Mutex::new(Exclusive { found: vec![], mask: 0 }),
};
let mutex = Mutex::new(Exclusive { found: vec![], mask: 0 });

// Handle the first 999 numbers specially as the number of digits varies.
for n in 1..1000 {
let (mut buffer, size) = format_string(&shared.prefix, n);
check_hash(&mut buffer, size, n, &shared, &mutex);
check_hash(&mut buffer, size, n, &shared);
}

// Use as many cores as possible to parallelize the remaining search.
spawn(|| {
#[cfg(not(feature = "simd"))]
worker(&shared, &mutex);
worker(&shared);
#[cfg(feature = "simd")]
simd::worker(&shared, &mutex);
simd::worker(&shared);
});

let mut found = mutex.into_inner().unwrap().found;
let mut found = shared.mutex.into_inner().unwrap().found;
found.sort_unstable();
found.iter().map(|&(_, n)| n).collect()
}
@@ -79,11 +80,11 @@ fn format_string(prefix: &str, n: u32) -> ([u8; 64], usize) {
(buffer, size)
}

fn check_hash(buffer: &mut [u8], size: usize, n: u32, shared: &Shared, mutex: &Mutex<Exclusive>) {
fn check_hash(buffer: &mut [u8], size: usize, n: u32, shared: &Shared) {
let (result, ..) = hash(buffer, size);

if result & 0xfffff000 == 0 {
let mut exclusive = mutex.lock().unwrap();
let mut exclusive = shared.mutex.lock().unwrap();

exclusive.found.push((n, result));
exclusive.mask |= 1 << (result >> 8);
@@ -95,7 +96,7 @@ }
}

#[cfg(not(feature = "simd"))]
fn worker(shared: &Shared, mutex: &Mutex<Exclusive>) {
fn worker(shared: &Shared) {
while !shared.done.load(Ordering::Relaxed) {
let offset = shared.counter.fetch_add(1000, Ordering::Relaxed);
let (mut buffer, size) = format_string(&shared.prefix, offset);
@@ -106,7 +107,7 @@ fn worker(shared: &Shared, mutex: &Mutex<Exclusive>) {
buffer[size - 2] = b'0' + ((n / 10) % 10) as u8;
buffer[size - 1] = b'0' + (n % 10) as u8;

check_hash(&mut buffer, size, offset + n, shared, mutex);
check_hash(&mut buffer, size, offset + n, shared);
}
}
}
@@ -124,7 +125,6 @@ mod simd {
start: u32,
offset: u32,
shared: &Shared,
mutex: &Mutex<Exclusive>,
) where
LaneCount<N>: SupportedLaneCount,
{
@@ -140,7 +140,7 @@

for i in 0..N {
if result[i] & 0xfffff000 == 0 {
let mut exclusive = mutex.lock().unwrap();
let mut exclusive = shared.mutex.lock().unwrap();

exclusive.found.push((start + offset + i as u32, result[i]));
exclusive.mask |= 1 << (result[i] >> 8);
@@ -152,17 +152,17 @@ }
}
}

pub(super) fn worker(shared: &Shared, mutex: &Mutex<Exclusive>) {
pub(super) fn worker(shared: &Shared) {
while !shared.done.load(Ordering::Relaxed) {
let start = shared.counter.fetch_add(1000, Ordering::Relaxed);
let (prefix, size) = format_string(&shared.prefix, start);
let mut buffers = [prefix; 32];

for offset in (0..992).step_by(32) {
check_hash_simd::<32>(&mut buffers, size, start, offset, shared, mutex);
check_hash_simd::<32>(&mut buffers, size, start, offset, shared);
}

check_hash_simd::<8>(&mut buffers, size, start, 992, shared, mutex);
check_hash_simd::<8>(&mut buffers, size, start, 992, shared);
}
}
}
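
Day 5 also keeps its compile-time choice between the scalar and SIMD workers; after this change both share the plain &Shared signature. A small sketch of that cfg-gated dispatch, with empty worker bodies and a unit Shared type standing in for the real ones:

// A unit stand-in for the solution's shared state.
struct Shared;

// Scalar fallback, compiled when the crate's `simd` feature is disabled.
#[cfg(not(feature = "simd"))]
fn worker(_shared: &Shared) {
    // Hash one candidate number at a time.
}

// Wide version, compiled only when the `simd` feature is enabled.
#[cfg(feature = "simd")]
mod simd {
    pub(super) fn worker(_shared: &super::Shared) {
        // Hash 32 candidate numbers per iteration.
    }
}

fn main() {
    let shared = Shared;

    // Exactly one of these calls survives compilation.
    #[cfg(not(feature = "simd"))]
    worker(&shared);
    #[cfg(feature = "simd")]
    simd::worker(&shared);
}
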
31 changes: 19 additions & 12 deletions src/year2016/day14.rs
@@ -11,8 +11,10 @@ use std::sync::Mutex;
/// Atomics can be safely shared between threads.
struct Shared<'a> {
input: &'a str,
part_two: bool,
done: AtomicBool,
counter: AtomicI32,
mutex: Mutex<Exclusive>,
}

/// Regular data structures need to be protected by a mutex.
@@ -38,20 +40,25 @@ pub fn part2(input: &str) -> i32 {

/// Find the first 64 keys that satisfy the rules.
fn generate_pad(input: &str, part_two: bool) -> i32 {
let shared = Shared { input, done: AtomicBool::new(false), counter: AtomicI32::new(0) };
let exclusive =
Exclusive { threes: BTreeMap::new(), fives: BTreeMap::new(), found: BTreeSet::new() };
let mutex = Mutex::new(exclusive);
let shared = Shared {
input,
part_two,
done: AtomicBool::new(false),
counter: AtomicI32::new(0),
mutex: Mutex::new(exclusive),
};

// Use as many cores as possible to parallelize the search.
spawn(|| worker(&shared, &mutex, part_two));
spawn(|| worker(&shared));

let exclusive = mutex.into_inner().unwrap();
let exclusive = shared.mutex.into_inner().unwrap();
*exclusive.found.iter().nth(63).unwrap()
}

#[cfg(not(feature = "simd"))]
fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {
fn worker(shared: &Shared<'_>) {
while !shared.done.load(Ordering::Relaxed) {
// Get the next key to check.
let n = shared.counter.fetch_add(1, Ordering::Relaxed);
@@ -60,7 +67,7 @@ fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {
let (mut buffer, size) = format_string(shared.input, n);
let mut result = hash(&mut buffer, size);

if part_two {
if shared.part_two {
for _ in 0..2016 {
buffer[0..8].copy_from_slice(&to_ascii(result.0));
buffer[8..16].copy_from_slice(&to_ascii(result.1));
@@ -70,14 +77,14 @@ fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {
}
}

check(shared, mutex, n, result);
check(shared, n, result);
}
}

/// Use SIMD to compute hashes in parallel in blocks of 32.
#[cfg(feature = "simd")]
#[allow(clippy::needless_range_loop)]
fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {
fn worker(shared: &Shared<'_>) {
let mut result = ([0; 32], [0; 32], [0; 32], [0; 32]);
let mut buffers = [[0; 64]; 32];

@@ -96,7 +103,7 @@ fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {
result.3[i] = d;
}

if part_two {
if shared.part_two {
for _ in 0..2016 {
for i in 0..32 {
buffers[i][0..8].copy_from_slice(&to_ascii(result.0[i]));
@@ -110,13 +117,13 @@ fn worker(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, part_two: bool) {

for i in 0..32 {
let hash = (result.0[i], result.1[i], result.2[i], result.3[i]);
check(shared, mutex, start + i as i32, hash);
check(shared, start + i as i32, hash);
}
}
}

/// Check for sequences of 3 or 5 consecutive matching digits.
fn check(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, n: i32, hash: (u32, u32, u32, u32)) {
fn check(shared: &Shared<'_>, n: i32, hash: (u32, u32, u32, u32)) {
let (a, b, c, d) = hash;

let mut prev = u32::MAX;
@@ -147,7 +154,7 @@ fn check(shared: &Shared<'_>, mutex: &Mutex<Exclusive>, n: i32, hash: (u32, u32,
}

if three != 0 || five != 0 {
let mut exclusive = mutex.lock().unwrap();
let mut exclusive = shared.mutex.lock().unwrap();
let mut candidates = Vec::new();

// Compare against all 5 digit sequences.
17 changes: 10 additions & 7 deletions src/year2017/day14.rs
@@ -13,6 +13,7 @@ use std::sync::Mutex;
pub struct Shared {
prefix: String,
counter: AtomicUsize,
mutex: Mutex<Exclusive>,
}

/// Regular data structures need to be protected by a mutex.
@@ -22,14 +23,16 @@ struct Exclusive {

/// Parallelize the hashing as each row is independent.
pub fn parse(input: &str) -> Vec<u8> {
let shared = Shared { prefix: input.trim().to_owned(), counter: AtomicUsize::new(0) };
let exclusive = Exclusive { grid: vec![0; 0x4000] };
let mutex = Mutex::new(exclusive);
let shared = Shared {
prefix: input.trim().to_owned(),
counter: AtomicUsize::new(0),
mutex: Mutex::new(Exclusive { grid: vec![0; 0x4000] }),
};

// Use as many cores as possible to parallelize the hashing.
spawn(|| worker(&shared, &mutex));
spawn(|| worker(&shared));

mutex.into_inner().unwrap().grid
shared.mutex.into_inner().unwrap().grid
}

pub fn part1(input: &[u8]) -> u32 {
@@ -53,7 +56,7 @@ pub fn part2(input: &[u8]) -> u32 {

/// Each worker thread chooses the next available index then computes the hash and patches the
/// final vec with the result.
fn worker(shared: &Shared, mutex: &Mutex<Exclusive>) {
fn worker(shared: &Shared) {
loop {
let index = shared.counter.fetch_add(1, Ordering::Relaxed);
if index >= 128 {
@@ -64,7 +67,7 @@ fn worker(shared: &Shared, mutex: &Mutex<Exclusive>) {
let start = index * 128;
let end = start + 128;

let mut exclusive = mutex.lock().unwrap();
let mut exclusive = shared.mutex.lock().unwrap();
exclusive.grid[start..end].copy_from_slice(&row);
}
}
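
For 2017 day 14 the workers claim row indices with the atomic counter and only take the lock to patch a finished 128-byte row into the shared grid. A condensed sketch of that claim-then-patch loop, with a stand-in fill instead of the real knot hash:

use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

struct Shared {
    counter: AtomicUsize,
    mutex: Mutex<Vec<u8>>,
}

fn worker(shared: &Shared) {
    loop {
        // Claim the next unprocessed row index without locking.
        let index = shared.counter.fetch_add(1, Ordering::Relaxed);
        if index >= 128 {
            break;
        }

        // Stand-in for the real knot hash of this row.
        let row = [index as u8; 128];

        // Lock only long enough to copy the finished row into place.
        let start = index * 128;
        let mut grid = shared.mutex.lock().unwrap();
        grid[start..start + 128].copy_from_slice(&row);
    }
}

fn main() {
    let shared =
        Shared { counter: AtomicUsize::new(0), mutex: Mutex::new(vec![0; 128 * 128]) };

    thread::scope(|scope| {
        for _ in 0..4 {
            scope.spawn(|| worker(&shared));
        }
    });

    let grid = shared.mutex.into_inner().unwrap();
    assert_eq!(grid[127 * 128], 127);
}
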
17 changes: 11 additions & 6 deletions src/year2018/day11.rs
@@ -16,6 +16,11 @@ pub struct Result {
power: i32,
}

struct Shared {
sat: Vec<i32>,
mutex: Mutex<Vec<Result>>,
}

pub fn parse(input: &str) -> Vec<Result> {
let grid_serial_number: i32 = input.signed();

@@ -45,9 +50,9 @@ pub fn parse(input: &str) -> Vec<Result> {
// * 2, 6, 10, ..
// * 3, 7, 11, ..
// * 4, 8, 12, ..
let mutex = Mutex::new(Vec::new());
spawn_batches((1..301).collect(), |batch| worker(batch, &sat, &mutex));
mutex.into_inner().unwrap()
let shared = Shared { sat, mutex: Mutex::new(Vec::new()) };
spawn_batches((1..301).collect(), |batch| worker(&shared, batch));
shared.mutex.into_inner().unwrap()
}

pub fn part1(input: &[Result]) -> String {
Expand All @@ -60,16 +65,16 @@ pub fn part2(input: &[Result]) -> String {
format!("{x},{y},{size}")
}

fn worker(batch: Vec<usize>, sat: &[i32], mutex: &Mutex<Vec<Result>>) {
fn worker(shared: &Shared, batch: Vec<usize>) {
let result: Vec<_> = batch
.into_iter()
.map(|size| {
let (power, x, y) = square(sat, size);
let (power, x, y) = square(&shared.sat, size);
Result { x, y, size, power }
})
.collect();

mutex.lock().unwrap().extend(result);
shared.mutex.lock().unwrap().extend(result);
}

/// Find the (x,y) coordinates and max power for a square of the specified size.
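
The 2018 day 11 workers map their batch of square sizes without holding the lock, then lock once to extend the shared result vector. A simplified sketch of that collect-then-extend pattern; the summed-area-table lookup is replaced with a toy computation and scoped threads approximate the repository's spawn_batches helper:

use std::sync::Mutex;
use std::thread;

struct Shared {
    sat: Vec<i32>,                    // read-only input shared by every worker
    mutex: Mutex<Vec<(usize, i32)>>,  // results gathered from all batches
}

fn worker(shared: &Shared, batch: Vec<usize>) {
    // Compute the whole batch without holding the lock.
    let results: Vec<_> = batch
        .into_iter()
        .map(|size| (size, shared.sat.iter().take(size).sum()))
        .collect();

    // One lock acquisition per batch rather than per item.
    shared.mutex.lock().unwrap().extend(results);
}

fn main() {
    let shared = Shared { sat: (1..=10).collect(), mutex: Mutex::new(Vec::new()) };
    let batches = vec![vec![1, 2, 3], vec![4, 5, 6]];

    thread::scope(|scope| {
        for batch in batches {
            let shared = &shared;
            scope.spawn(move || worker(shared, batch));
        }
    });

    let mut results = shared.mutex.into_inner().unwrap();
    results.sort_unstable();
    assert_eq!(results[0], (1, 1));
}
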
13 changes: 6 additions & 7 deletions src/year2021/day18.rs
@@ -28,7 +28,7 @@
//! `2i + 1`, right child at index `2i + 2` and parent at index `i / 2`. As leaf nodes are
//! always greater than or equal to zero, `-1` is used as a special sentinel value for non-leaf nodes.
use crate::util::thread::*;
use std::sync::Mutex;
use std::sync::atomic::{AtomicI32, Ordering};

type Snailfish = [i32; 63];

@@ -85,21 +85,20 @@ pub fn part2(input: &[Snailfish]) -> i32 {

// Use as many cores as possible to parallelize the calculation,
// breaking the work into roughly equally size batches.
let mutex = Mutex::new(0);
spawn_batches(pairs, |batch| worker(&batch, &mutex));
mutex.into_inner().unwrap()
let shared = AtomicI32::new(0);
spawn_batches(pairs, |batch| worker(&shared, &batch));
shared.load(Ordering::Relaxed)
}

/// Pair addition is independent so we can parallelize across multiple threads.
fn worker(batch: &[(&Snailfish, &Snailfish)], mutex: &Mutex<i32>) {
fn worker(shared: &AtomicI32, batch: &[(&Snailfish, &Snailfish)]) {
let mut partial = 0;

for (a, b) in batch {
partial = partial.max(magnitude(&mut add(a, b)));
}

let mut result = mutex.lock().unwrap();
*result = result.max(partial);
shared.fetch_max(partial, Ordering::Relaxed);
}

/// Add two snailfish numbers.
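
For 2021 day 18 the Mutex<i32> guarding the running maximum is dropped entirely in favor of an AtomicI32 with fetch_max, so no locking is needed at all. A minimal sketch of that lock-free max reduction (the batch data and chunking are illustrative):

use std::sync::atomic::{AtomicI32, Ordering};
use std::thread;

// Each worker reduces its batch locally, then publishes one atomic update.
fn worker(shared: &AtomicI32, batch: &[i32]) {
    let partial = batch.iter().copied().max().unwrap_or(i32::MIN);
    shared.fetch_max(partial, Ordering::Relaxed);
}

fn main() {
    let shared = AtomicI32::new(0);
    let data = [3, 1, 4, 1, 5, 9, 2, 6];

    // Scoped threads and fixed chunks stand in for the repository's spawn_batches helper.
    thread::scope(|scope| {
        let shared = &shared;
        for batch in data.chunks(4) {
            scope.spawn(move || worker(shared, batch));
        }
    });

    assert_eq!(shared.load(Ordering::Relaxed), 9);
}

Because fetch_max is a single atomic read-modify-write, the workers never block each other, which is why the Mutex import disappears from this file entirely.
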
23 changes: 13 additions & 10 deletions src/year2022/day11.rs
@@ -62,6 +62,11 @@ pub enum Operation {
type Pair = (usize, u64);
type Business = [u64; 8];

struct Shared<'a> {
monkeys: &'a [Monkey],
mutex: Mutex<Exclusive>,
}

struct Exclusive {
pairs: Vec<Pair>,
business: Business,
@@ -125,30 +130,28 @@ fn sequential(monkeys: &[Monkey], pairs: Vec<Pair>) -> Business {

/// Play 10,000 rounds adjusting the worry level modulo the product of all the monkey's test values.
fn parallel(monkeys: &[Monkey], pairs: Vec<Pair>) -> Business {
let business = [0; 8];
let exclusive = Exclusive { pairs, business };
let mutex = Mutex::new(exclusive);
let shared = Shared { monkeys, mutex: Mutex::new(Exclusive { pairs, business: [0; 8] }) };

// Use as many cores as possible to parallelize the calculation.
spawn(|| worker(monkeys, &mutex));
spawn(|| worker(&shared));

mutex.into_inner().unwrap().business
shared.mutex.into_inner().unwrap().business
}

/// Multiple worker functions are executed in parallel, one per thread.
fn worker(monkeys: &[Monkey], mutex: &Mutex<Exclusive>) {
let product: u64 = monkeys.iter().map(|m| m.test).product();
fn worker(shared: &Shared<'_>) {
let product: u64 = shared.monkeys.iter().map(|m| m.test).product();

loop {
// Take an item from the queue until empty, using the mutex to allow access
// to a single thread at a time.
let Some(pair) = mutex.lock().unwrap().pairs.pop() else {
let Some(pair) = shared.mutex.lock().unwrap().pairs.pop() else {
break;
};

let extra = play(monkeys, 10000, |x| x % product, pair);
let extra = play(shared.monkeys, 10000, |x| x % product, pair);

let mut exclusive = mutex.lock().unwrap();
let mut exclusive = shared.mutex.lock().unwrap();
exclusive.business.iter_mut().enumerate().for_each(|(i, b)| *b += extra[i]);
}
}
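
In 2022 day 11 the shared mutex also serves as a work queue: each worker pops the next starting pair under the lock, runs the expensive simulation outside it, then locks again to accumulate the result. A standalone sketch of that pop-under-lock queue, with the monkey simulation replaced by a placeholder doubling:

use std::sync::Mutex;
use std::thread;

struct Exclusive {
    queue: Vec<u64>,  // work items still waiting to be processed
    total: u64,       // accumulated result
}

struct Shared {
    mutex: Mutex<Exclusive>,
}

fn worker(shared: &Shared) {
    loop {
        // Hold the lock only long enough to take the next item.
        let Some(item) = shared.mutex.lock().unwrap().queue.pop() else {
            break;
        };

        // The expensive part runs outside the lock.
        let processed = item * 2;

        // Lock again briefly to accumulate the result.
        shared.mutex.lock().unwrap().total += processed;
    }
}

fn main() {
    let shared =
        Shared { mutex: Mutex::new(Exclusive { queue: (1..=10).collect(), total: 0 }) };

    thread::scope(|scope| {
        for _ in 0..4 {
            scope.spawn(|| worker(&shared));
        }
    });

    assert_eq!(shared.mutex.into_inner().unwrap().total, 110);
}
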