Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rustc_codegen_ssa: Better code generation for niche discriminants. #102872

Merged
merged 1 commit into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 145 additions & 49 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
bx: &mut Bx,
cast_to: Ty<'tcx>,
) -> V {
let cast_to = bx.cx().immediate_backend_type(bx.cx().layout_of(cast_to));
let cast_to_layout = bx.cx().layout_of(cast_to);
let cast_to_size = cast_to_layout.layout.size();
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
if self.layout.abi.is_uninhabited() {
return bx.cx().const_undef(cast_to);
}
Expand All @@ -229,7 +231,8 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {

// Read the tag/niche-encoded discriminant from memory.
let tag = self.project_field(bx, tag_field);
let tag = bx.load_operand(tag);
let tag_op = bx.load_operand(tag);
let tag_imm = tag_op.immediate();

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
Expand All @@ -242,68 +245,161 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag.immediate(), cast_to, signed)
bx.intcast(tag_imm, cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.
let niche_llty = bx.cx().immediate_backend_type(tag.layout);
let tag = tag.immediate();

// We first compute the "relative discriminant" (wrt `niche_variants`),
// that is, if `n = niche_variants.end() - niche_variants.start()`,
// we remap `niche_start..=niche_start + n` (which may wrap around)
// to (non-wrap-around) `0..=n`, to be able to check whether the
// discriminant corresponds to a niche variant with one comparison.
// We also can't go directly to the (variant index) discriminant
// and check that it is in the range `niche_variants`, because
// that might not fit in the same type, on top of needing an extra
// comparison (see also the comment on `let niche_discr`).
let relative_discr = if niche_start == 0 {
// Avoid subtracting `0`, which wouldn't work for pointers.
// FIXME(eddyb) check the actual primitive type here.
tag
// Cast to an integer so we don't have to treat a pointer as a
// special case.
let (tag, tag_llty) = if tag_scalar.primitive().is_ptr() {
let t = bx.type_isize();
let tag = bx.ptrtoint(tag_imm, t);
(tag, t)
} else {
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
(tag_imm, bx.cx().immediate_backend_type(tag_op.layout))
};

let tag_size = tag_scalar.size(bx.cx());
let max_unsigned = tag_size.unsigned_int_max();
let max_signed = tag_size.signed_int_max() as u128;
let min_signed = max_signed + 1;
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
let niche_end = niche_start.wrapping_add(relative_max as u128) & max_unsigned;
let range = tag_scalar.valid_range(bx.cx());

let sle = |lhs: u128, rhs: u128| -> bool {
// Signed and unsigned comparisons give the same results,
// except that in signed comparisons an integer with the
// sign bit set is less than one with the sign bit clear.
// Toggle the sign bit to do a signed comparison.
(lhs ^ min_signed) <= (rhs ^ min_signed)
};

// We have a subrange `niche_start..=niche_end` inside `range`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would debug_assert!ing this make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, I think it should probably happen when a niche is created.

// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.

// Find the least and greatest values in `range`, considered
// both as signed and unsigned.
let (low_unsigned, high_unsigned) = if range.start <= range.end {
(range.start, range.end)
} else {
(0, max_unsigned)
};
let (low_signed, high_signed) = if sle(range.start, range.end) {
(range.start, range.end)
} else {
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
(min_signed, max_signed)
};

let niches_ule = niche_start <= niche_end;
let niches_sle = sle(niche_start, niche_end);
let cast_smaller = cast_to_size <= tag_size;

// In the algorithm above, we can change
// cast(relative_tag) + niche_variants.start()
// into
// cast(tag) + (niche_variants.start() - niche_start)
// if either the casted type is no larger than the original
// type, or if the niche values are contiguous (in either the
// signed or unsigned sense).
let can_incr_after_cast = cast_smaller || niches_ule || niches_sle;

let data_for_boundary_niche = || -> Option<(IntPredicate, u128)> {
if !can_incr_after_cast {
None
} else if niche_start == low_unsigned {
Some((IntPredicate::IntULE, niche_end))
} else if niche_end == high_unsigned {
Some((IntPredicate::IntUGE, niche_start))
} else if niche_start == low_signed {
Some((IntPredicate::IntSLE, niche_end))
} else if niche_end == high_signed {
Some((IntPredicate::IntSGE, niche_start))
} else {
None
}
};

// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
// indices (e.g. `u8` niche with more than `256` variants,
// but enough uninhabited variants so that the remaining variants
// fit in the niche).
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
bx.cx().const_uint(cast_to, 0)
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else if let Some((predicate, constant)) = data_for_boundary_niche() {
// The niche values are either the lowest or the highest in
// `range`. We can avoid the first subtraction in the
// algorithm.
// The algorithm is now this:
// is_niche = tag <= niche_end
// discr = if is_niche {
// cast(tag) + (niche_variants.start() - niche_start)
// } else {
// untagged_variant
// }
// (the first line may instead be tag >= niche_start,
// and may be a signed or unsigned comparison)
let is_niche =
bx.icmp(predicate, tag, bx.cx().const_uint_big(tag_llty, constant));
let cast_tag = if cast_smaller {
bx.intcast(tag, cast_to, false)
} else if niches_ule {
bx.zext(tag, cast_to)
} else {
bx.intcast(relative_discr, cast_to, false)
bx.sext(tag, cast_to)
};
bx.add(

let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
(is_niche, cast_tag, delta)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64),
)
bx.cx().const_uint(tag_llty, relative_max as u64),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};

bx.select(
let tagged_discr = if delta == 0 {
tagged_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};

let discr = bx.select(
is_niche,
niche_discr,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
)
);

// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this seems to be a pessimization.

discr
}
}
}
Expand Down
112 changes: 112 additions & 0 deletions src/test/codegen/enum-match.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// compile-flags: -Copt-level=1
// only-x86_64

#![crate_type = "lib"]

// Check each of the 3 cases for `codegen_get_discr`.

// Case 0: One tagged variant.
pub enum Enum0 {
A(bool),
B,
}

// CHECK: define i8 @match0{{.*}}
// CHECK-NEXT: start:
// CHECK-NEXT: %1 = icmp eq i8 %0, 2
// CHECK-NEXT: %2 = and i8 %0, 1
// CHECK-NEXT: %.0 = select i1 %1, i8 13, i8 %2
#[no_mangle]
pub fn match0(e: Enum0) -> u8 {
use Enum0::*;
match e {
A(b) => b as u8,
B => 13,
}
}

// Case 1: Niche values are on a boundary for `range`.
pub enum Enum1 {
A(bool),
B,
C,
}

// CHECK: define i8 @match1{{.*}}
// CHECK-NEXT: start:
// CHECK-NEXT: %1 = icmp ugt i8 %0, 1
// CHECK-NEXT: %2 = zext i8 %0 to i64
// CHECK-NEXT: %3 = add nsw i64 %2, -1
// CHECK-NEXT: %_2 = select i1 %1, i64 %3, i64 0
// CHECK-NEXT: switch i64 %_2, label {{.*}} [
#[no_mangle]
pub fn match1(e: Enum1) -> u8 {
use Enum1::*;
match e {
A(b) => b as u8,
B => 13,
C => 100,
}
}

// Case 2: Special cases don't apply.
pub enum X {
_2=2, _3, _4, _5, _6, _7, _8, _9, _10, _11,
_12, _13, _14, _15, _16, _17, _18, _19, _20,
_21, _22, _23, _24, _25, _26, _27, _28, _29,
_30, _31, _32, _33, _34, _35, _36, _37, _38,
_39, _40, _41, _42, _43, _44, _45, _46, _47,
_48, _49, _50, _51, _52, _53, _54, _55, _56,
_57, _58, _59, _60, _61, _62, _63, _64, _65,
_66, _67, _68, _69, _70, _71, _72, _73, _74,
_75, _76, _77, _78, _79, _80, _81, _82, _83,
_84, _85, _86, _87, _88, _89, _90, _91, _92,
_93, _94, _95, _96, _97, _98, _99, _100, _101,
_102, _103, _104, _105, _106, _107, _108, _109,
_110, _111, _112, _113, _114, _115, _116, _117,
_118, _119, _120, _121, _122, _123, _124, _125,
_126, _127, _128, _129, _130, _131, _132, _133,
_134, _135, _136, _137, _138, _139, _140, _141,
_142, _143, _144, _145, _146, _147, _148, _149,
_150, _151, _152, _153, _154, _155, _156, _157,
_158, _159, _160, _161, _162, _163, _164, _165,
_166, _167, _168, _169, _170, _171, _172, _173,
_174, _175, _176, _177, _178, _179, _180, _181,
_182, _183, _184, _185, _186, _187, _188, _189,
_190, _191, _192, _193, _194, _195, _196, _197,
_198, _199, _200, _201, _202, _203, _204, _205,
_206, _207, _208, _209, _210, _211, _212, _213,
_214, _215, _216, _217, _218, _219, _220, _221,
_222, _223, _224, _225, _226, _227, _228, _229,
_230, _231, _232, _233, _234, _235, _236, _237,
_238, _239, _240, _241, _242, _243, _244, _245,
_246, _247, _248, _249, _250, _251, _252, _253,
}

pub enum Enum2 {
A(X),
B,
C,
D,
E,
}

// CHECK: define i8 @match2{{.*}}
// CHECK-NEXT: start:
// CHECK-NEXT: %1 = add i8 %0, 2
// CHECK-NEXT: %2 = zext i8 %1 to i64
// CHECK-NEXT: %3 = icmp ult i8 %1, 4
// CHECK-NEXT: %4 = add nuw nsw i64 %2, 1
// CHECK-NEXT: %_2 = select i1 %3, i64 %4, i64 0
// CHECK-NEXT: switch i64 %_2, label {{.*}} [
#[no_mangle]
pub fn match2(e: Enum2) -> u8 {
use Enum2::*;
match e {
A(b) => b as u8,
B => 13,
C => 100,
D => 200,
E => 250,
}
}
Loading