Skip to content

Commit

Permalink
Add a method to check if type is a CStr
Browse files Browse the repository at this point in the history
  • Loading branch information
celinval committed Dec 15, 2023
1 parent 3f39cae commit 86451ba
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 0 deletions.
6 changes: 6 additions & 0 deletions compiler/rustc_smir/src/rustc_smir/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ impl<'tcx> Context for TablesWrapper<'tcx> {
def.internal(&mut *tables).repr().simd()
}

fn adt_is_cstr(&self, def: AdtDef) -> bool {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);
tables.tcx.lang_items().c_str() == Some(def_id)
}

fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);
Expand Down
3 changes: 3 additions & 0 deletions compiler/stable_mir/src/compiler_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ pub trait Context {
/// Returns whether this ADT is simd.
fn adt_is_simd(&self, def: AdtDef) -> bool;

/// Returns whether this definition is a C string.
fn adt_is_cstr(&self, def: AdtDef) -> bool;

/// Retrieve the function signature for the given generic arguments.
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig;

Expand Down
6 changes: 6 additions & 0 deletions compiler/stable_mir/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ impl TyKind {
*self == TyKind::RigidTy(RigidTy::Str)
}

#[inline]
pub fn is_cstr(&self) -> bool {
let TyKind::RigidTy(RigidTy::Adt(def, _)) = self else { return false };
with(|cx| cx.adt_is_cstr(*def))
}

#[inline]
pub fn is_slice(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Slice(_)))
Expand Down
22 changes: 22 additions & 0 deletions tests/ui-fulldeps/stable-mir/check_allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use std::ascii::Char;
use std::assert_matches::assert_matches;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::ffi::CStr;
use std::io::Write;
use std::ops::ControlFlow;

Expand All @@ -45,6 +46,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
check_foo(*get_item(&items, (ItemKind::Static, "FOO")).unwrap());
check_bar(*get_item(&items, (ItemKind::Static, "BAR")).unwrap());
check_len(*get_item(&items, (ItemKind::Static, "LEN")).unwrap());
check_cstr(*get_item(&items, (ItemKind::Static, "C_STR")).unwrap());
check_other_consts(*get_item(&items, (ItemKind::Fn, "other_consts")).unwrap());
check_type_id(*get_item(&items, (ItemKind::Fn, "check_type_id")).unwrap());
ControlFlow::Continue(())
Expand Down Expand Up @@ -86,6 +88,24 @@ fn check_bar(item: CrateItem) {
assert_eq!(std::str::from_utf8(&allocation.raw_bytes().unwrap()), Ok("Bar"));
}

/// Check the allocation data for static `C_STR`.
///
/// ```no_run
/// static C_STR: &core::ffi::cstr = c"cstr";
/// ```
fn check_cstr(item: CrateItem) {
let def = StaticDef::try_from(item).unwrap();
let alloc = def.eval_initializer().unwrap();
assert_eq!(alloc.provenance.ptrs.len(), 1);
let deref = item.ty().kind().builtin_deref(true).unwrap();
assert!(deref.ty.kind().is_cstr(), "Expected CStr, but got: {:?}", item.ty());

let alloc_id_0 = alloc.provenance.ptrs[0].1.0;
let GlobalAlloc::Memory(allocation) = GlobalAlloc::from(alloc_id_0) else { unreachable!() };
assert_eq!(allocation.bytes.len(), 5);
assert_eq!(CStr::from_bytes_until_nul(&allocation.raw_bytes().unwrap()), Ok(c"cstr"));
}

/// Check the allocation data for constants used in `other_consts` function.
fn check_other_consts(item: CrateItem) {
// Instance body will force constant evaluation.
Expand Down Expand Up @@ -206,6 +226,7 @@ fn main() {
generate_input(&path).unwrap();
let args = vec![
"rustc".to_string(),
"--edition=2021".to_string(),
"--crate-name".to_string(),
CRATE_NAME.to_string(),
path.to_string(),
Expand All @@ -224,6 +245,7 @@ fn generate_input(path: &str) -> std::io::Result<()> {
static LEN: usize = 2;
static FOO: [&str; 2] = ["hi", "there"];
static BAR: &str = "Bar";
static C_STR: &std::ffi::CStr = c"cstr";
const NULL: *const u8 = std::ptr::null();
const TUPLE: (u32, u32) = (10, u32::MAX);
Expand Down

0 comments on commit 86451ba

Please sign in to comment.