Auto merge of rust-lang#120193 - x17jiri:cold_match_arms, r=<try>
#[cold] on match arms

### Edit

This should be in T-lang. I'm not sure how I can change it.

There is discussion: https://rust-lang.zulipchat.com/#narrow/stream/213817-t-lang/topic/Allow.20.23.5Bcold.5D.20on.20match.20and.20if.20arms

### Summary

Allows the `#[cold]` attribute on match arms, hinting to the optimizer that the arm is unlikely to be taken.

### Motivation

These hints are sometimes thought to help branch prediction, but the effect is probably marginal. Modern CPUs generally don't support hints on conditional branch instructions. They either have the current branch in the BTB (branch target buffer) or they don't, in which case the branch is predicted not to be taken.

These hints are, however, helpful in telling the compiler which path is the fast path, so that it can be optimized at the expense of the slow path.

`grep`-ing the LLVM code for `BlockFrequencyInfo` and `BranchProbabilityInfo` shows that these hints are used in many places in the optimizer, such as:
- block placement - improves locality by keeping the fast path compact and moving everything else out of the way
- inlining, loop unrolling - these optimizations can be less aggressive on the cold path, which reduces code size
- register allocation - data needed on the fast path is preferentially kept in registers

### History

RFC 1131 (rust-lang#26179) added the `likely` and `unlikely` intrinsics, which get converted to `llvm.expect.i8`. However, this LLVM intrinsic is fragile and may get removed by some optimization passes. The problems with the intrinsics have been reported several times: rust-lang#96276, rust-lang#96275, rust-lang#88767

### Other languages

Clang and GCC C++ compilers provide `__builtin_expect`. Since C++20, it is also possible to use `[[likely]]` and `[[unlikely]]` attributes.

Use:
```
if (__builtin_expect(condition, false)) { ... this branch is UNlikely ... }

if (condition) [[likely]] { ... this branch is likely... }
```

Note that while clang provides `__builtin_expect`, it does not convert it to `llvm.expect.i8`. Instead, it looks at the surrounding code and if there is a condition, emits branch weight metadata for conditional branches.

### Design

Properly implementing `likely`/`unlikely`-style functions so that they emit branch weights would add significant complexity to the compiler. Additionally, these functions are not easy to use with `match` arms.

Replacing the functions with attributes is easier to implement and will also work with `match`.

A question remains whether these attributes should be named `likely`/`unlikely` as in C++, or if we could reuse the already existing `#[cold]` attribute. `#[cold]` has the same meaning as `unlikely`, i.e., marking the slow path, but it can currently only be used on entire functions.
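For reference, the function-level `#[cold]` that exists today can already be used to approximate a cold branch by calling an empty cold function inside the unlikely arm. The following is a minimal sketch of that workaround and is not part of this PR; `cold_path` and `digit_value` are hypothetical names, and whether the optimizer actually honors the hint is not guaranteed.

```rust
/// Marker function (hypothetical name). A call to a `#[cold]` function tells
/// the optimizer that the code path containing the call is unlikely.
#[cold]
fn cold_path() {}

/// Parses a decimal digit, treating non-digits as the rare case.
fn digit_value(c: char) -> Option<u32> {
    match c.to_digit(10) {
        Some(d) => Some(d),
        None => {
            // Hint only: this call has no effect on the returned value.
            cold_path();
            None
        }
    }
}
```

The hint changes nothing about the result: `digit_value('7')` is `Some(7)` and `digit_value('x')` is `None` with or without the call to `cold_path`.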

I personally prefer `#[cold]` because it already exists in Rust and is a short word that looks better in code. It has one disadvantage though.
This code:
```
if cond #[likely] { ... }
```
becomes:
```
if cond { ... } #[cold] { ... empty cold branch ... }
```

This PR implements support for the `#[cold]` attribute on match arms. It is used as follows:
```
match x {
    #[cold] true => { ... } // the true arm is UNlikely
    _ => { ... } // the false arm is likely
}
```
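The diff in this commit lowers the hint to LLVM `branch_weights` metadata, giving the cold side a weight of 1 and the other side a weight of 2000, and it passes the hint around as an `Option<bool>`. The standalone sketch below models only that mapping; the function and constant names are illustrative assumptions and do not exist in rustc.

```rust
/// Weights used by the diff for the cold and the non-cold side.
const COLD_WEIGHT: u32 = 1;
const HOT_WEIGHT: u32 = 2000;

/// Returns `(then_weight, else_weight)` for a conditional branch.
/// `Some(true)` means the 'then' side is cold, `Some(false)` means the
/// 'else' side is cold, and `None` means there is no hint at all.
fn branch_weights(cold_br: Option<bool>) -> Option<(u32, u32)> {
    cold_br.map(|then_is_cold| {
        if then_is_cold {
            (COLD_WEIGHT, HOT_WEIGHT)
        } else {
            (HOT_WEIGHT, COLD_WEIGHT)
        }
    })
}

/// When the two branch targets are swapped (for example because the test
/// compares against zero), the cold side has to be flipped as well.
fn swap_targets(cold_br: Option<bool>) -> Option<bool> {
    cold_br.map(|then_is_cold| !then_is_cold)
}
```

With these weights the cold side accounts for 1 out of 2001 units of weight, roughly 0.05%. Swapping the targets turns `Some(true)` into `Some(false)`, so the pair `(1, 2000)` becomes `(2000, 1)`, while `None` stays `None` and produces no metadata.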

### Limitations
The implementation only works on `bool`, or on integers when the match has a single value arm and an otherwise arm. Extending it to other types and to `if` statements should not be too difficult.
bors committed Jan 22, 2024
2 parents 366d112 + 71b694c commit 4d4892d
Showing 15 changed files with 231 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.lock
@@ -4210,6 +4210,7 @@ dependencies = [
  "rustc_target",
  "rustc_trait_selection",
  "smallvec",
+ "thin-vec",
  "tracing",
 ]

32 changes: 32 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
@@ -197,6 +197,38 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         }
     }

+    fn cond_br_with_cold_br(
+        &mut self,
+        cond: &'ll Value,
+        then_llbb: Self::BasicBlock,
+        else_llbb: Self::BasicBlock,
+        cold_br: Option<bool>,
+    ) {
+        // emit the branch instruction
+        let n = unsafe { llvm::LLVMBuildCondBr(self.llbuilder, cond, then_llbb, else_llbb) };
+
+        // if one of the branches is cold, emit metadata with branch weights
+        if let Some(cold_br) = cold_br {
+            unsafe {
+                let s = "branch_weights";
+                let v = [
+                    llvm::LLVMMDStringInContext(
+                        self.cx.llcx,
+                        s.as_ptr() as *const c_char,
+                        s.len() as c_uint,
+                    ),
+                    self.cx.const_u32(if cold_br { 1 } else { 2000 }), // 'then' branch weight
+                    self.cx.const_u32(if cold_br { 2000 } else { 1 }), // 'else' branch weight
+                ];
+                llvm::LLVMSetMetadata(
+                    n,
+                    llvm::MD_prof as c_uint,
+                    llvm::LLVMMDNodeInContext(self.cx.llcx, v.as_ptr(), v.len() as c_uint),
+                );
+            }
+        }
+    }
+
     fn switch(
         &mut self,
         v: &'ll Value,

12 changes: 9 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
@@ -327,18 +327,24 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
             let (test_value, target) = target_iter.next().unwrap();
             let lltrue = helper.llbb_with_cleanup(self, target);
             let llfalse = helper.llbb_with_cleanup(self, targets.otherwise());
+            let cold_br =
+                targets.cold_target().and_then(|t| if t == 0 { Some(true) } else { Some(false) });
+
             if switch_ty == bx.tcx().types.bool {
                 // Don't generate trivial icmps when switching on bool.
                 match test_value {
-                    0 => bx.cond_br(discr.immediate(), llfalse, lltrue),
-                    1 => bx.cond_br(discr.immediate(), lltrue, llfalse),
+                    0 => {
+                        let cold_br = cold_br.and_then(|t| Some(!t));
+                        bx.cond_br_with_cold_br(discr.immediate(), llfalse, lltrue, cold_br);
+                    }
+                    1 => bx.cond_br_with_cold_br(discr.immediate(), lltrue, llfalse, cold_br),
                     _ => bug!(),
                 }
             } else {
                 let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
                 let llval = bx.const_uint_big(switch_llty, test_value);
                 let cmp = bx.icmp(IntPredicate::IntEQ, discr.immediate(), llval);
-                bx.cond_br(cmp, lltrue, llfalse);
+                bx.cond_br_with_cold_br(cmp, lltrue, llfalse, cold_br);
             }
         } else if self.cx.sess().opts.optimize == OptLevel::No
             && target_iter.len() == 2

7 changes: 7 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
@@ -64,6 +64,13 @@ pub trait BuilderMethods<'a, 'tcx>:
         then_llbb: Self::BasicBlock,
         else_llbb: Self::BasicBlock,
     );
+    fn cond_br_with_cold_br(
+        &mut self,
+        cond: Self::Value,
+        then_llbb: Self::BasicBlock,
+        else_llbb: Self::BasicBlock,
+        cold_br: Option<bool>,
+    );
     fn switch(
         &mut self,
         v: Self::Value,

11 changes: 11 additions & 0 deletions compiler/rustc_data_structures/src/stable_hasher.rs
@@ -6,6 +6,7 @@ use std::fmt;
 use std::hash::{BuildHasher, Hash, Hasher};
 use std::marker::PhantomData;
 use std::mem;
+use thin_vec::ThinVec;

 #[cfg(test)]
 mod tests;
@@ -499,6 +500,16 @@ where
     }
 }

+impl<T, CTX> HashStable<CTX> for ThinVec<T>
+where
+    T: HashStable<CTX>,
+{
+    #[inline]
+    fn hash_stable(&self, ctx: &mut CTX, hasher: &mut StableHasher) {
+        self[..].hash_stable(ctx, hasher);
+    }
+}
+
 impl<T: ?Sized + HashStable<CTX>, CTX> HashStable<CTX> for Box<T> {
     #[inline]
     fn hash_stable(&self, ctx: &mut CTX, hasher: &mut StableHasher) {

7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/mir/syntax.rs
@@ -25,6 +25,7 @@ use rustc_span::symbol::Symbol;
 use rustc_span::Span;
 use rustc_target::asm::InlineAsmRegOrRegClass;
 use smallvec::SmallVec;
+use thin_vec::ThinVec;

 /// Represents the "flavors" of MIR.
 ///
@@ -844,6 +845,12 @@ pub struct SwitchTargets {
     // However we’ve decided to keep this as-is until we figure a case
     // where some other approach seems to be strictly better than other.
     pub(super) targets: SmallVec<[BasicBlock; 2]>,
+
+    // Targets that are marked 'cold', if any.
+    // This vector contains indices into `targets`.
+    // It can also contain 'targets.len()' to indicate that the otherwise
+    // branch is cold.
+    pub(super) cold_targets: ThinVec<usize>,
 }

 /// Action to be taken when a stack unwind happens.

51 changes: 49 additions & 2 deletions compiler/rustc_middle/src/mir/terminator.rs
@@ -5,6 +5,7 @@ use smallvec::SmallVec;
 use super::TerminatorKind;
 use rustc_macros::HashStable;
 use std::slice;
+use thin_vec::{thin_vec, ThinVec};

 use super::*;

@@ -16,13 +17,36 @@ impl SwitchTargets {
     pub fn new(targets: impl Iterator<Item = (u128, BasicBlock)>, otherwise: BasicBlock) -> Self {
         let (values, mut targets): (SmallVec<_>, SmallVec<_>) = targets.unzip();
         targets.push(otherwise);
-        Self { values, targets }
+        Self { values, targets, cold_targets: ThinVec::new() }
     }

     /// Builds a switch targets definition that jumps to `then` if the tested value equals `value`,
     /// and to `else_` if not.
     pub fn static_if(value: u128, then: BasicBlock, else_: BasicBlock) -> Self {
-        Self { values: smallvec![value], targets: smallvec![then, else_] }
+        Self {
+            values: smallvec![value],
+            targets: smallvec![then, else_],
+            cold_targets: ThinVec::new(),
+        }
     }
+
+    /// Builds a switch targets definition that jumps to `then` if the tested value equals `value`,
+    /// and to `else_` if not.
+    /// If cold_br is some bool value, the given outcome is considered cold (i.e., unlikely).
+    pub fn static_if_with_cold_br(
+        value: u128,
+        then: BasicBlock,
+        else_: BasicBlock,
+        cold_br: Option<bool>,
+    ) -> Self {
+        Self {
+            values: smallvec![value],
+            targets: smallvec![then, else_],
+            cold_targets: match cold_br {
+                Some(br) => thin_vec![if br { 0 } else { 1 }],
+                None => ThinVec::new(),
+            },
+        }
+    }

     /// Inverse of `SwitchTargets::static_if`.
@@ -37,6 +61,11 @@ impl SwitchTargets {
         }
     }

+    // If this switch has exactly one cold target, returns it.
+    pub fn cold_target(&self) -> Option<usize> {
+        if self.cold_targets.len() == 1 { Some(self.cold_targets[0]) } else { None }
+    }
+
     /// Returns the fallback target that is jumped to when none of the values match the operand.
     #[inline]
     pub fn otherwise(&self) -> BasicBlock {
@@ -365,6 +394,24 @@ impl<'tcx> TerminatorKind<'tcx> {
         TerminatorKind::SwitchInt { discr: cond, targets: SwitchTargets::static_if(0, f, t) }
     }

+    pub fn if_with_cold_br(
+        cond: Operand<'tcx>,
+        t: BasicBlock,
+        f: BasicBlock,
+        cold_branch: Option<bool>,
+    ) -> TerminatorKind<'tcx> {
+        TerminatorKind::SwitchInt {
+            discr: cond,
+            targets: SwitchTargets::static_if_with_cold_br(
+                0,
+                f,
+                t,
+                // we compare to zero, so have to invert the branch
+                cold_branch.and_then(|b| Some(!b)),
+            ),
+        }
+    }
+
     #[inline]
     pub fn successors(&self) -> Successors<'_> {
         use self::TerminatorKind::*;

1 change: 1 addition & 0 deletions compiler/rustc_middle/src/thir.rs
@@ -522,6 +522,7 @@ pub struct Arm<'tcx> {
     pub lint_level: LintLevel,
     pub scope: region::Scope,
     pub span: Span,
+    pub is_cold: bool,
 }

 #[derive(Copy, Clone, Debug, HashStable)]

1 change: 1 addition & 0 deletions compiler/rustc_mir_build/Cargo.toml
@@ -24,5 +24,6 @@ rustc_span = { path = "../rustc_span" }
 rustc_target = { path = "../rustc_target" }
 rustc_trait_selection = { path = "../rustc_trait_selection" }
 smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
+thin-vec = "0.2.12"
 tracing = "0.1"
 # tidy-alphabetical-end
2 changes: 2 additions & 0 deletions compiler/rustc_mir_build/src/build/expr/into.rs
@@ -85,6 +85,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
     condition_scope,
     source_info,
     true,
+    None,
 ));

 this.expr_into_dest(destination, then_blk, then)
@@ -176,6 +177,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         condition_scope,
         source_info,
         true,
+        None,
     )
 });
 let (short_circuit, continuation, constant) = match op {