-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Insert alignment checks for pointer dereferences when debug assertions are enabled #98112
Changes from all commits
8ccf533
5896f86
c9c1346
a71b808
67540ec
9f17ede
f93ef09
7507078
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
use crate::MirPass; | ||
use rustc_hir::def_id::DefId; | ||
use rustc_index::vec::IndexVec; | ||
use rustc_middle::mir::*; | ||
use rustc_middle::mir::{ | ||
interpret::{ConstValue, Scalar}, | ||
visit::{PlaceContext, Visitor}, | ||
}; | ||
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut}; | ||
use rustc_session::Session; | ||
|
||
pub struct CheckAlignment; | ||
|
||
impl<'tcx> MirPass<'tcx> for CheckAlignment { | ||
fn is_enabled(&self, sess: &Session) -> bool { | ||
sess.opts.debug_assertions | ||
} | ||
|
||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
let basic_blocks = body.basic_blocks.as_mut(); | ||
let local_decls = &mut body.local_decls; | ||
|
||
for block in (0..basic_blocks.len()).rev() { | ||
let block = block.into(); | ||
for statement_index in (0..basic_blocks[block].statements.len()).rev() { | ||
let location = Location { block, statement_index }; | ||
let statement = &basic_blocks[block].statements[statement_index]; | ||
let source_info = statement.source_info; | ||
|
||
let mut finder = PointerFinder { | ||
local_decls, | ||
tcx, | ||
pointers: Vec::new(), | ||
def_id: body.source.def_id(), | ||
}; | ||
for (pointer, pointee_ty) in finder.find_pointers(statement) { | ||
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty); | ||
|
||
let new_block = split_block(basic_blocks, location); | ||
insert_alignment_check( | ||
tcx, | ||
local_decls, | ||
&mut basic_blocks[block], | ||
pointer, | ||
pointee_ty, | ||
source_info, | ||
new_block, | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
impl<'tcx, 'a> PointerFinder<'tcx, 'a> { | ||
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> { | ||
self.pointers.clear(); | ||
self.visit_statement(statement, Location::START); | ||
core::mem::take(&mut self.pointers) | ||
} | ||
} | ||
|
||
struct PointerFinder<'tcx, 'a> { | ||
local_decls: &'a mut LocalDecls<'tcx>, | ||
tcx: TyCtxt<'tcx>, | ||
def_id: DefId, | ||
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, | ||
} | ||
|
||
impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { | ||
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { | ||
if let PlaceContext::NonUse(_) = context { | ||
return; | ||
} | ||
if !place.is_indirect() { | ||
return; | ||
} | ||
|
||
let pointer = Place::from(place.local); | ||
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty; | ||
|
||
// We only want to check unsafe pointers | ||
if !pointer_ty.is_unsafe_ptr() { | ||
trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
let Some(pointee) = pointer_ty.builtin_deref(true) else { | ||
debug!("Indirect but no builtin deref: {:?}", pointer_ty); | ||
return; | ||
}; | ||
let mut pointee_ty = pointee.ty; | ||
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() { | ||
oli-obk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pointee_ty = pointee_ty.sequence_element_type(self.tcx); | ||
} | ||
|
||
if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) { | ||
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_] | ||
.contains(&pointee_ty) | ||
{ | ||
debug!("Trivially aligned pointee type: {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
self.pointers.push((pointer, pointee_ty)) | ||
} | ||
} | ||
|
||
fn split_block( | ||
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, | ||
location: Location, | ||
) -> BasicBlock { | ||
let block_data = &mut basic_blocks[location.block]; | ||
|
||
// Drain every statement after this one and move the current terminator to a new basic block | ||
let new_block = BasicBlockData { | ||
statements: block_data.statements.split_off(location.statement_index), | ||
terminator: block_data.terminator.take(), | ||
is_cleanup: block_data.is_cleanup, | ||
}; | ||
|
||
basic_blocks.push(new_block) | ||
} | ||
|
||
fn insert_alignment_check<'tcx>( | ||
tcx: TyCtxt<'tcx>, | ||
local_decls: &mut LocalDecls<'tcx>, | ||
block_data: &mut BasicBlockData<'tcx>, | ||
pointer: Place<'tcx>, | ||
pointee_ty: Ty<'tcx>, | ||
source_info: SourceInfo, | ||
new_block: BasicBlock, | ||
) { | ||
// Cast the pointer to a *const () | ||
let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not }); | ||
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); | ||
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); | ||
block_data | ||
.statements | ||
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); | ||
|
||
// Transmute the pointer to a usize (equivalent to `ptr.addr()`) | ||
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); | ||
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
block_data | ||
.statements | ||
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); | ||
|
||
// Get the alignment of the pointee | ||
let alignment = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new((alignment, rvalue))), | ||
}); | ||
|
||
// Subtract 1 from the alignment to get the alignment mask | ||
let alignment_mask = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
let one = Operand::Constant(Box::new(Constant { | ||
span: source_info.span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), | ||
tcx.types.usize, | ||
), | ||
})); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
alignment_mask, | ||
Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))), | ||
))), | ||
}); | ||
|
||
// BitAnd the alignment mask with the pointer | ||
scottmcm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let alignment_bits = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
alignment_bits, | ||
Rvalue::BinaryOp( | ||
BinOp::BitAnd, | ||
Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))), | ||
), | ||
))), | ||
}); | ||
|
||
// Check if the alignment bits are all zero | ||
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); | ||
let zero = Operand::Constant(Box::new(Constant { | ||
span: source_info.span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), | ||
tcx.types.usize, | ||
), | ||
})); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
is_ok, | ||
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))), | ||
))), | ||
}); | ||
|
||
// Set this block's terminator to our assert, continuing to new_block if we pass | ||
block_data.terminator = Some(Terminator { | ||
source_info, | ||
kind: TerminatorKind::Assert { | ||
cond: Operand::Copy(is_ok), | ||
expected: true, | ||
target: new_block, | ||
msg: AssertKind::MisalignedPointerDereference { | ||
required: Operand::Copy(alignment), | ||
found: Operand::Copy(addr), | ||
}, | ||
cleanup: None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discovered this while rebasing my PR #102906. I don't think I feel that this check should be inserted during MIR build, or, this pass needs to be run before drop elaboration. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ideal outcome is to produce an aborting panic here, because this check can add surprising new unwinding sites and make code that was unwind-safe not unwind-safe. Unwinding can be a significantly worse outcome than an unaligned read. With your PR, is it possible to make this abort? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense. I can just change it to |
||
}, | ||
}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't this do the wrong thing for something like
**x
withx: &*const i32
? Herex
is safe but we still deref a raw ptr that needs checking.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at this stage there are no places like
**x
. They have all been converted toThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh wow, that is a very subtle non-local invariant. At the very least the code needs a detailed comment explaining that. Also we better make sure some check (MIR validation?) ICEs if that invariant is ever violated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we already have a MIR validation that only the first projection may be a Deref. I have tripped it often working on MIR opts:
rust/compiler/rustc_const_eval/src/transform/validate.rs
Line 482 in a732883