Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace fastrand with getrandom and base64 #188

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ repository = "https://github.com/Stebalien/tempfile"
description = "A library for managing temporary files and directories."

[dependencies]
base64 = "0.13.0"
cfg-if = "1"
fastrand = "1.6.0"
getrandom = "0.2.7"
remove_dir_all = "0.5"

[target.'cfg(any(unix, target_os = "wasi"))'.dependencies]
Expand Down
35 changes: 31 additions & 4 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
use std::ffi::{OsStr, OsString};
use std::path::{Path, PathBuf};
use std::{io, iter::repeat_with};
use std::io;

use crate::error::IoResultExt;

fn calculate_rand_buf_len(alphanumeric_len: usize) -> usize {
let expected_non_alphanumeric_chars = alphanumeric_len / 32;
(alphanumeric_len + expected_non_alphanumeric_chars) * 3 / 4 + 3
}

fn calculate_base64_len(binary_len: usize) -> usize {
binary_len * 4 / 3 + 4
}

fn fill_with_random_base64(rand_buf: &mut [u8], char_buf: &mut Vec<u8>) {
getrandom::getrandom(rand_buf).expect("calling getrandom failed");
char_buf.resize(calculate_base64_len(rand_buf.len()), 0);
base64::encode_config_slice(rand_buf, base64::STANDARD_NO_PAD, char_buf);
}

fn tmpname(prefix: &OsStr, suffix: &OsStr, rand_len: usize) -> OsString {
let mut buf = OsString::with_capacity(prefix.len() + suffix.len() + rand_len);
buf.push(prefix);
let mut char_buf = [0u8; 4];
for c in repeat_with(fastrand::alphanumeric).take(rand_len) {
buf.push(c.encode_utf8(&mut char_buf));

let mut rand_buf = vec![0; calculate_rand_buf_len(rand_len)];
let mut char_buf = vec![0; calculate_base64_len(rand_buf.len())];
let mut remaining_chars = rand_len;
loop {
fill_with_random_base64(&mut rand_buf, &mut char_buf);
char_buf.retain(|&c| (c != b'+') & (c != b'/') & (c != 0));
if char_buf.len() >= remaining_chars {
buf.push(std::str::from_utf8(&char_buf[..remaining_chars]).unwrap());
break;
} else {
buf.push(std::str::from_utf8(&char_buf).unwrap());
remaining_chars -= char_buf.len();
}
}

buf.push(suffix);
buf
}
Expand Down
52 changes: 0 additions & 52 deletions tests/namedtempfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,55 +387,3 @@ fn test_make_uds() {

assert!(temp_sock.path().exists());
}

#[cfg(unix)]
#[test]
fn test_make_uds_conflict() {
Copy link

@5225225 5225225 Aug 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test can presumably still exist, but less deterministically, by setting the rand_bytes to 1, then trying to generate 62 files. (Which will require that some of the threads retry unless they all get incredibly lucky, which won't happen).

I'd also use a https://doc.rust-lang.org/std/sync/struct.Barrier.html inside the thread but before the generation, to ensure they all wait until everyone's ready, so that if thread spawning is slower than file generation, it's not just sequential.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is possible, but the number of retries will no longer be deterministic.

use std::os::unix::net::UnixListener;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Check that retries happen correctly by racing N different threads.

const NTHREADS: usize = 20;

// The number of times our callback was called.
let tries = Arc::new(AtomicUsize::new(0));

let mut threads = Vec::with_capacity(NTHREADS);

for _ in 0..NTHREADS {
let tries = tries.clone();
threads.push(std::thread::spawn(move || {
// Ensure that every thread uses the same seed so we are guaranteed
// to retry. Note that fastrand seeds are thread-local.
fastrand::seed(42);

Builder::new()
.prefix("tmp")
.suffix(".sock")
.rand_bytes(12)
.make(|path| {
tries.fetch_add(1, Ordering::Relaxed);
UnixListener::bind(path)
})
}));
}

// Join all threads, but don't drop the temp file yet. Otherwise, we won't
// get a deterministic number of `tries`.
let sockets: Vec<_> = threads
.into_iter()
.map(|thread| thread.join().unwrap().unwrap())
.collect();

// Number of tries is exactly equal to (n*(n+1))/2.
assert_eq!(
tries.load(Ordering::Relaxed),
(NTHREADS * (NTHREADS + 1)) / 2
);

for socket in sockets {
assert!(socket.path().exists());
}
}