Commit 1d601d6
Auto merge of #74695 - alexcrichton:more-wasm-float-cast-fixes, r=nagisa
rustc: Improving safe wasm float->int casts

This commit improves code generation for WebAssembly targets when translating floating-point to integer casts. The improvement is only relevant when the `nontrapping-fptoint` target feature is not enabled, which is currently the default. Additionally, it only affects safe casts, since unchecked casts were already improved in #74659.

More background on this issue is available in #73591, but the general gist is that LLVM's `fptosi` and `fptoui` instructions are defined to return an `undef` value when executed on out-of-bounds inputs; notably, they do not trap. To implement these instructions for WebAssembly, the LLVM backend must generate quite a few instructions around `i32.trunc_f32_s` (for example), because that WebAssembly instruction traps on out-of-bounds values. This lowering into wasm instructions happens very late in the code generator, so what ends up happening is that rustc inserts its own codegen to implement Rust's saturating semantics, and then LLVM additionally inserts its own codegen to ensure that the `fptosi` instruction doesn't trap.
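As a reminder of the saturating semantics rustc has to implement here (this example is ours, not part of the commit): safe `as` casts truncate in-range values toward zero, clamp out-of-range values to the integer type's bounds, and map NaN to zero.

```rust
fn main() {
    // In-range values truncate toward zero.
    assert_eq!(42.9_f64 as u32, 42);
    // Out-of-range values saturate to the bounds of the target type...
    assert_eq!(5e9_f64 as u32, u32::MAX);
    assert_eq!(-1.5_f64 as u32, 0);
    // ...and NaN becomes zero.
    assert_eq!(f64::NAN as u32, 0);
}
```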
Overall this means that a function like this:

```rust
#[no_mangle]
pub unsafe extern "C" fn cast(x: f64) -> u32 {
    x as u32
}
```

will generate this WebAssembly today:

```wat
(func $cast (type 0) (param f64) (result i32)
  (local i32 i32)
  local.get 0
  f64.const 0x1.fffffffep+31 (;=4.29497e+09;)
  f64.gt
  local.set 1
  block  ;; label = @1
    block  ;; label = @2
      local.get 0
      f64.const 0x0p+0 (;=0;)
      local.get 0
      f64.const 0x0p+0 (;=0;)
      f64.gt
      select
      local.tee 0
      f64.const 0x1p+32 (;=4.29497e+09;)
      f64.lt
      local.get 0
      f64.const 0x0p+0 (;=0;)
      f64.ge
      i32.and
      i32.eqz
      br_if 0 (;@2;)
      local.get 0
      i32.trunc_f64_u
      local.set 2
      br 1 (;@1;)
    end
    i32.const 0
    local.set 2
  end
  i32.const -1
  local.get 2
  local.get 1
  select)
```

This PR improves the situation by updating rustc's code generation for float-to-int conversions, specifically only for WebAssembly targets and only in some situations (float-to-u8 codegen is still not great). The fix is to use basic blocks and control flow to avoid speculatively executing `fptosi`, using LLVM's raw intrinsic for the WebAssembly instruction instead. This effectively extends the support added in #74659 to checked casts.
After this commit the codegen for the above Rust function looks like:

```wat
(func $cast (type 0) (param f64) (result i32)
  (local i32)
  block  ;; label = @1
    local.get 0
    f64.const 0x0p+0 (;=0;)
    f64.ge
    local.tee 1
    i32.const 1
    i32.xor
    br_if 0 (;@1;)
    local.get 0
    f64.const 0x1.fffffffep+31 (;=4.29497e+09;)
    f64.le
    i32.eqz
    br_if 0 (;@1;)
    local.get 0
    i32.trunc_f64_u
    return
  end
  i32.const -1
  i32.const 0
  local.get 1
  select)
```

For reference, in Rust 1.44, which did not have saturating float-to-integer casts, the codegen LLVM would emit is:

```wat
(func $cast (type 0) (param f64) (result i32)
  block  ;; label = @1
    local.get 0
    f64.const 0x1p+32 (;=4.29497e+09;)
    f64.lt
    local.get 0
    f64.const 0x0p+0 (;=0;)
    f64.ge
    i32.and
    i32.eqz
    br_if 0 (;@1;)
    local.get 0
    i32.trunc_f64_u
    return
  end
  i32.const 0)
```

So we're relatively close to the original codegen, although it's slightly different because the function's semantics changed: we're now emulating the `i32.trunc_sat_f32_s` instruction rather than always replacing out-of-bounds values with zero.

There is still work that could be done to improve casts such as `f32` to `u8`. That form of cast still uses the `fptosi` instruction, which generates lots of branch-y code. This seems less important to tackle now, though. In the meantime this should take care of most use cases of floating-point conversion and as a result I'm going to speculate that this... Closes #73591
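For intuition, the control-flow lowering shown in the wasm above behaves like this plain-Rust sketch (our illustration, not the commit's code; the function name is made up). The potentially-trapping truncation only runs when the input is provably in bounds; everything else is routed to the saturated values:

```rust
// Plain-Rust model of the branchy f64 -> u32 lowering. The real codegen
// emits this shape as basic blocks and br_if, not Rust `if`s.
fn cast_f64_to_u32(x: f64) -> u32 {
    // Both comparisons are "ordered": they are false when x is NaN.
    let inbound_lower = x >= 0.0;
    let inbound_upper = x <= u32::MAX as f64; // 4294967295.0, exact in f64
    if inbound_lower && inbound_upper {
        // "convert" block: only here may the trapping wasm truncation run.
        x as u32
    } else if inbound_lower {
        // Not NaN and above the range: saturate up.
        u32::MAX
    } else {
        // NaN or below the range: saturate down.
        0
    }
}

fn main() {
    assert_eq!(cast_f64_to_u32(42.7), 42);
    assert_eq!(cast_f64_to_u32(5e9), u32::MAX);
    assert_eq!(cast_f64_to_u32(-1.0), 0);
    assert_eq!(cast_f64_to_u32(f64::NAN), 0);
}
```

Note that the `else if` arm also catches values just above `u32::MAX as f64`, which a saturating cast clamps to `u32::MAX` anyway, so the sketch agrees with `x as u32` on all inputs.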
2 parents d8cbd9c + 2c1b046 commit 1d601d6

5 files changed: +235 −112 lines

src/librustc_codegen_llvm/builder.rs (+61)
```diff
@@ -703,11 +703,67 @@ impl BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         None
     }

+    fn fptosui_may_trap(&self, val: &'ll Value, dest_ty: &'ll Type) -> bool {
+        // Most of the time we'll be generating the `fptosi` or `fptoui`
+        // instruction for floating-point-to-integer conversions. These
+        // instructions by definition in LLVM do not trap. For the WebAssembly
+        // target, however, we'll lower in some cases to intrinsic calls instead
+        // which may trap. If we detect that this is a situation where we'll be
+        // using the intrinsics then we report that the call may trap, which
+        // callers might need to handle.
+        if !self.wasm_and_missing_nontrapping_fptoint() {
+            return false;
+        }
+        let src_ty = self.cx.val_ty(val);
+        let float_width = self.cx.float_width(src_ty);
+        let int_width = self.cx.int_width(dest_ty);
+        match (int_width, float_width) {
+            (32, 32) | (32, 64) | (64, 32) | (64, 64) => true,
+            _ => false,
+        }
+    }
+
     fn fptoui(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
+        // When we can, use the native wasm intrinsics which have tighter
+        // codegen. Note that this has a semantic difference in that the
+        // intrinsic can trap whereas `fptoui` never traps. That difference,
+        // however, is handled by `fptosui_may_trap` above.
+        if self.wasm_and_missing_nontrapping_fptoint() {
+            let src_ty = self.cx.val_ty(val);
+            let float_width = self.cx.float_width(src_ty);
+            let int_width = self.cx.int_width(dest_ty);
+            let name = match (int_width, float_width) {
+                (32, 32) => Some("llvm.wasm.trunc.unsigned.i32.f32"),
+                (32, 64) => Some("llvm.wasm.trunc.unsigned.i32.f64"),
+                (64, 32) => Some("llvm.wasm.trunc.unsigned.i64.f32"),
+                (64, 64) => Some("llvm.wasm.trunc.unsigned.i64.f64"),
+                _ => None,
+            };
+            if let Some(name) = name {
+                let intrinsic = self.get_intrinsic(name);
+                return self.call(intrinsic, &[val], None);
+            }
+        }
         unsafe { llvm::LLVMBuildFPToUI(self.llbuilder, val, dest_ty, UNNAMED) }
     }

     fn fptosi(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
+        if self.wasm_and_missing_nontrapping_fptoint() {
+            let src_ty = self.cx.val_ty(val);
+            let float_width = self.cx.float_width(src_ty);
+            let int_width = self.cx.int_width(dest_ty);
+            let name = match (int_width, float_width) {
+                (32, 32) => Some("llvm.wasm.trunc.signed.i32.f32"),
+                (32, 64) => Some("llvm.wasm.trunc.signed.i32.f64"),
+                (64, 32) => Some("llvm.wasm.trunc.signed.i64.f32"),
+                (64, 64) => Some("llvm.wasm.trunc.signed.i64.f64"),
+                _ => None,
+            };
+            if let Some(name) = name {
+                let intrinsic = self.get_intrinsic(name);
+                return self.call(intrinsic, &[val], None);
+            }
+        }
         unsafe { llvm::LLVMBuildFPToSI(self.llbuilder, val, dest_ty, UNNAMED) }
     }

@@ -1349,4 +1405,9 @@ impl Builder<'a, 'll, 'tcx> {
             llvm::LLVMAddIncoming(phi, &val, &bb, 1 as c_uint);
         }
     }
+
+    fn wasm_and_missing_nontrapping_fptoint(&self) -> bool {
+        self.sess().target.target.arch == "wasm32"
+            && !self.sess().target_features.contains(&sym::nontrapping_dash_fptoint)
+    }
 }
```
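The width dispatch shared by the two methods above can be viewed as a small pure function. This standalone sketch (ours, with a hypothetical name, not part of the commit) shows the mapping and why narrow destinations such as u8 fall through to the plain LLVM instructions:

```rust
// Hypothetical helper mirroring the match tables in `fptoui`/`fptosi` above.
// Only 32/64-bit integer and float widths have single-instruction wasm
// lowerings; anything else returns None and keeps using fptoui/fptosi.
fn wasm_trunc_intrinsic(int_width: u32, float_width: u32, signed: bool) -> Option<&'static str> {
    match (int_width, float_width, signed) {
        (32, 32, true) => Some("llvm.wasm.trunc.signed.i32.f32"),
        (32, 64, true) => Some("llvm.wasm.trunc.signed.i32.f64"),
        (64, 32, true) => Some("llvm.wasm.trunc.signed.i64.f32"),
        (64, 64, true) => Some("llvm.wasm.trunc.signed.i64.f64"),
        (32, 32, false) => Some("llvm.wasm.trunc.unsigned.i32.f32"),
        (32, 64, false) => Some("llvm.wasm.trunc.unsigned.i32.f64"),
        (64, 32, false) => Some("llvm.wasm.trunc.unsigned.i64.f32"),
        (64, 64, false) => Some("llvm.wasm.trunc.unsigned.i64.f64"),
        // No single wasm instruction exists for e.g. u8/i16 destinations.
        _ => None,
    }
}

fn main() {
    assert_eq!(wasm_trunc_intrinsic(32, 64, false), Some("llvm.wasm.trunc.unsigned.i32.f64"));
    assert_eq!(wasm_trunc_intrinsic(64, 32, true), Some("llvm.wasm.trunc.signed.i64.f32"));
    assert_eq!(wasm_trunc_intrinsic(8, 32, true), None);
}
```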

src/librustc_codegen_llvm/intrinsic.rs (+15 −55)
```diff
@@ -634,22 +634,19 @@ impl IntrinsicCallMethods<'tcx> for Builder<'a, 'll, 'tcx> {
             }

             sym::float_to_int_unchecked => {
-                let float_width = match float_type_width(arg_tys[0]) {
-                    Some(width) => width,
-                    None => {
-                        span_invalid_monomorphization_error(
-                            tcx.sess,
-                            span,
-                            &format!(
-                                "invalid monomorphization of `float_to_int_unchecked` \
+                if float_type_width(arg_tys[0]).is_none() {
+                    span_invalid_monomorphization_error(
+                        tcx.sess,
+                        span,
+                        &format!(
+                            "invalid monomorphization of `float_to_int_unchecked` \
                              intrinsic: expected basic float type, \
                              found `{}`",
-                                arg_tys[0]
-                            ),
-                        );
-                        return;
-                    }
-                };
+                            arg_tys[0]
+                        ),
+                    );
+                    return;
+                }
                 let (width, signed) = match int_type_width_signed(ret_ty, self.cx) {
                     Some(pair) => pair,
                     None => {
@@ -666,48 +663,11 @@
                         return;
                     }
                 };
-
-                // The LLVM backend can reorder and speculate `fptosi` and
-                // `fptoui`, so on WebAssembly the codegen for this instruction
-                // is quite heavyweight. To avoid this heavyweight codegen we
-                // instead use the raw wasm intrinsics which will lower to one
-                // instruction in WebAssembly (`iNN.trunc_fMM_{s,u}`). This one
-                // instruction will trap if the operand is out of bounds, but
-                // that's ok since this intrinsic is UB if the operands are out
-                // of bounds, so the behavior can be different on WebAssembly
-                // than other targets.
-                //
-                // Note, however, that when the `nontrapping-fptoint` feature is
-                // enabled in LLVM then LLVM will lower `fptosi` to
-                // `iNN.trunc_sat_fMM_{s,u}`, so if that's the case we don't
-                // bother with intrinsics.
-                let mut result = None;
-                if self.sess().target.target.arch == "wasm32"
-                    && !self.sess().target_features.contains(&sym::nontrapping_dash_fptoint)
-                {
-                    let name = match (width, float_width, signed) {
-                        (32, 32, true) => Some("llvm.wasm.trunc.signed.i32.f32"),
-                        (32, 64, true) => Some("llvm.wasm.trunc.signed.i32.f64"),
-                        (64, 32, true) => Some("llvm.wasm.trunc.signed.i64.f32"),
-                        (64, 64, true) => Some("llvm.wasm.trunc.signed.i64.f64"),
-                        (32, 32, false) => Some("llvm.wasm.trunc.unsigned.i32.f32"),
-                        (32, 64, false) => Some("llvm.wasm.trunc.unsigned.i32.f64"),
-                        (64, 32, false) => Some("llvm.wasm.trunc.unsigned.i64.f32"),
-                        (64, 64, false) => Some("llvm.wasm.trunc.unsigned.i64.f64"),
-                        _ => None,
-                    };
-                    if let Some(name) = name {
-                        let intrinsic = self.get_intrinsic(name);
-                        result = Some(self.call(intrinsic, &[args[0].immediate()], None));
-                    }
+                if signed {
+                    self.fptosi(args[0].immediate(), self.cx.type_ix(width))
+                } else {
+                    self.fptoui(args[0].immediate(), self.cx.type_ix(width))
                 }
-                result.unwrap_or_else(|| {
-                    if signed {
-                        self.fptosi(args[0].immediate(), self.cx.type_ix(width))
-                    } else {
-                        self.fptoui(args[0].immediate(), self.cx.type_ix(width))
-                    }
-                })
             }

             sym::discriminant_value => {
```

src/librustc_codegen_ssa/mir/rvalue.rs (+134 −33)
```diff
@@ -11,7 +11,7 @@ use rustc_apfloat::{ieee, Float, Round, Status};
 use rustc_hir::lang_items::ExchangeMallocFnLangItem;
 use rustc_middle::mir;
 use rustc_middle::ty::cast::{CastTy, IntTy};
-use rustc_middle::ty::layout::HasTyCtxt;
+use rustc_middle::ty::layout::{HasTyCtxt, TyAndLayout};
 use rustc_middle::ty::{self, adjustment::PointerCast, Instance, Ty, TyCtxt};
 use rustc_span::source_map::{Span, DUMMY_SP};
 use rustc_span::symbol::sym;
@@ -369,10 +369,10 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                 bx.inttoptr(usize_llval, ll_t_out)
             }
             (CastTy::Float, CastTy::Int(IntTy::I)) => {
-                cast_float_to_int(&mut bx, true, llval, ll_t_in, ll_t_out)
+                cast_float_to_int(&mut bx, true, llval, ll_t_in, ll_t_out, cast)
             }
             (CastTy::Float, CastTy::Int(_)) => {
-                cast_float_to_int(&mut bx, false, llval, ll_t_in, ll_t_out)
+                cast_float_to_int(&mut bx, false, llval, ll_t_in, ll_t_out, cast)
             }
             _ => bug!("unsupported cast: {:?} to {:?}", operand.layout.ty, cast.ty),
         };
@@ -772,6 +772,7 @@ fn cast_float_to_int<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
     x: Bx::Value,
     float_ty: Bx::Type,
     int_ty: Bx::Type,
+    int_layout: TyAndLayout<'tcx>,
 ) -> Bx::Value {
     if let Some(false) = bx.cx().sess().opts.debugging_opts.saturating_float_casts {
         return if signed { bx.fptosi(x, int_ty) } else { bx.fptoui(x, int_ty) };
@@ -782,8 +783,6 @@ fn cast_float_to_int<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
         return try_sat_result;
     }

-    let fptosui_result = if signed { bx.fptosi(x, int_ty) } else { bx.fptoui(x, int_ty) };
-
     let int_width = bx.cx().int_width(int_ty);
     let float_width = bx.cx().float_width(float_ty);
     // LLVM's fpto[su]i returns undef when the input x is infinite, NaN, or does not fit into the
@@ -870,36 +869,138 @@ fn cast_float_to_int<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
     // int_ty::MIN and therefore the return value of int_ty::MIN is correct.
     // QED.

-    // Step 1 was already performed above.
-
-    // Step 2: We use two comparisons and two selects, with %s1 being the result:
-    //     %less_or_nan = fcmp ult %x, %f_min
-    //     %greater = fcmp olt %x, %f_max
-    //     %s0 = select %less_or_nan, int_ty::MIN, %fptosi_result
-    //     %s1 = select %greater, int_ty::MAX, %s0
-    // Note that %less_or_nan uses an *unordered* comparison. This comparison is true if the
-    // operands are not comparable (i.e., if x is NaN). The unordered comparison ensures that s1
-    // becomes int_ty::MIN if x is NaN.
-    // Performance note: Unordered comparison can be lowered to a "flipped" comparison and a
-    // negation, and the negation can be merged into the select. Therefore, it is not necessarily
-    // any more expensive than an ordered ("normal") comparison. Whether these optimizations will
-    // be performed is ultimately up to the backend, but at least x86 does perform them.
-    let less_or_nan = bx.fcmp(RealPredicate::RealULT, x, f_min);
-    let greater = bx.fcmp(RealPredicate::RealOGT, x, f_max);
     let int_max = bx.cx().const_uint_big(int_ty, int_max(signed, int_width));
     let int_min = bx.cx().const_uint_big(int_ty, int_min(signed, int_width) as u128);
-    let s0 = bx.select(less_or_nan, int_min, fptosui_result);
-    let s1 = bx.select(greater, int_max, s0);
-
-    // Step 3: NaN replacement.
-    // For unsigned types, the above step already yielded int_ty::MIN == 0 if x is NaN.
-    // Therefore we only need to execute this step for signed integer types.
-    if signed {
-        // LLVM has no isNaN predicate, so we use (x == x) instead
-        let zero = bx.cx().const_uint(int_ty, 0);
-        let cmp = bx.fcmp(RealPredicate::RealOEQ, x, x);
-        bx.select(cmp, s1, zero)
+    let zero = bx.cx().const_uint(int_ty, 0);
+
+    // The codegen here differs quite a bit depending on whether our builder's
+    // `fptosi` and `fptoui` instructions may trap for out-of-bounds values. If
+    // they don't trap then we can start doing everything inline with a
+    // `select` instruction because it's ok to execute `fptosi` and `fptoui`
+    // even if we don't use the results.
+    if !bx.fptosui_may_trap(x, int_ty) {
+        // Step 1 ...
+        let fptosui_result = if signed { bx.fptosi(x, int_ty) } else { bx.fptoui(x, int_ty) };
+        let less_or_nan = bx.fcmp(RealPredicate::RealULT, x, f_min);
+        let greater = bx.fcmp(RealPredicate::RealOGT, x, f_max);
+
+        // Step 2: We use two comparisons and two selects, with %s1 being the
+        // result:
+        //     %less_or_nan = fcmp ult %x, %f_min
+        //     %greater = fcmp olt %x, %f_max
+        //     %s0 = select %less_or_nan, int_ty::MIN, %fptosi_result
+        //     %s1 = select %greater, int_ty::MAX, %s0
+        // Note that %less_or_nan uses an *unordered* comparison. This
+        // comparison is true if the operands are not comparable (i.e., if x is
+        // NaN). The unordered comparison ensures that s1 becomes int_ty::MIN if
+        // x is NaN.
+        //
+        // Performance note: Unordered comparison can be lowered to a "flipped"
+        // comparison and a negation, and the negation can be merged into the
+        // select. Therefore, it is not necessarily any more expensive than an
+        // ordered ("normal") comparison. Whether these optimizations will be
+        // performed is ultimately up to the backend, but at least x86 does
+        // perform them.
+        let s0 = bx.select(less_or_nan, int_min, fptosui_result);
+        let s1 = bx.select(greater, int_max, s0);
+
+        // Step 3: NaN replacement.
+        // For unsigned types, the above step already yielded int_ty::MIN == 0 if x is NaN.
+        // Therefore we only need to execute this step for signed integer types.
+        if signed {
+            // LLVM has no isNaN predicate, so we use (x == x) instead
+            let cmp = bx.fcmp(RealPredicate::RealOEQ, x, x);
+            bx.select(cmp, s1, zero)
+        } else {
+            s1
+        }
     } else {
-        s1
+        // In this case we cannot execute `fptosi` or `fptoui` and then later
+        // discard the result. The builder is telling us that these instructions
+        // will trap on out-of-bounds values, so we need to use basic blocks and
+        // control flow to avoid executing the `fptosi` and `fptoui`
+        // instructions.
+        //
+        // The general idea of what we're constructing here is, for f64 -> i32:
+        //
+        //      ;; block so far... %0 is the argument
+        //      %result = alloca i32, align 4
+        //      %inbound_lower = fcmp oge double %0, 0xC1E0000000000000
+        //      %inbound_upper = fcmp ole double %0, 0x41DFFFFFFFC00000
+        //      ;; match (inbound_lower, inbound_upper) {
+        //      ;;     (true, true) => %0 can be converted without trapping
+        //      ;;     (false, false) => %0 is a NaN
+        //      ;;     (true, false) => %0 is too large
+        //      ;;     (false, true) => %0 is too small
+        //      ;; }
+        //      ;;
+        //      ;; The (true, true) check, go to %convert if so.
+        //      %inbounds = and i1 %inbound_lower, %inbound_upper
+        //      br i1 %inbounds, label %convert, label %specialcase
+        //
+        //  convert:
+        //      %cvt = call i32 @llvm.wasm.trunc.signed.i32.f64(double %0)
+        //      store i32 %cvt, i32* %result, align 4
+        //      br label %done
+        //
+        //  specialcase:
+        //      ;; Handle the cases where the number is NaN, too large or too small
+        //
+        //      ;; Either (true, false) or (false, true)
+        //      %is_not_nan = or i1 %inbound_lower, %inbound_upper
+        //      ;; Figure out which saturated value we are interested in if not `NaN`
+        //      %saturated = select i1 %inbound_lower, i32 2147483647, i32 -2147483648
+        //      ;; Figure out between saturated and NaN representations
+        //      %result_nan = select i1 %is_not_nan, i32 %saturated, i32 0
+        //      store i32 %result_nan, i32* %result, align 4
+        //      br label %done
+        //
+        //  done:
+        //      %r = load i32, i32* %result, align 4
+        //      ;; ...
+        let done = bx.build_sibling_block("float_cast_done");
+        let mut convert = bx.build_sibling_block("float_cast_convert");
+        let mut specialcase = bx.build_sibling_block("float_cast_specialcase");
+
+        let result = PlaceRef::alloca(bx, int_layout);
+        result.storage_live(bx);
+
+        // Use control flow to figure out whether we can execute `fptosi` in a
+        // basic block, or whether we go to a different basic block to implement
+        // the saturating logic.
+        let inbound_lower = bx.fcmp(RealPredicate::RealOGE, x, f_min);
+        let inbound_upper = bx.fcmp(RealPredicate::RealOLE, x, f_max);
+        let inbounds = bx.and(inbound_lower, inbound_upper);
+        bx.cond_br(inbounds, convert.llbb(), specialcase.llbb());
+
+        // Translation of the `convert` basic block
+        let cvt = if signed { convert.fptosi(x, int_ty) } else { convert.fptoui(x, int_ty) };
+        convert.store(cvt, result.llval, result.align);
+        convert.br(done.llbb());
+
+        // Translation of the `specialcase` basic block. Note that like above
+        // we try to be a bit clever here for unsigned conversions. In those
+        // cases the `int_min` is zero so we don't need two select instructions,
+        // just one to choose whether we need `int_max` or not. If
+        // `inbound_lower` is true then we're guaranteed to not be `NaN` and
+        // since we're greater than zero we must be saturating to `int_max`. If
+        // `inbound_lower` is false then we're either NaN or less than zero, so
+        // we saturate to zero.
+        let result_nan = if signed {
+            let is_not_nan = specialcase.or(inbound_lower, inbound_upper);
+            let saturated = specialcase.select(inbound_lower, int_max, int_min);
+            specialcase.select(is_not_nan, saturated, zero)
+        } else {
+            specialcase.select(inbound_lower, int_max, int_min)
+        };
+        specialcase.store(result_nan, result.llval, result.align);
+        specialcase.br(done.llbb());
+
+        // Translation of the `done` basic block, positioning ourselves to
+        // continue from that point as well.
+        *bx = done;
+        let ret = bx.load(result.llval, result.align);
+        result.storage_dead(bx);
+        ret
     }
 }
```
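To make the branchless path concrete, here is a plain-Rust model of the two-select-plus-NaN-check sequence (our illustration, not the commit's code; the function name is made up, and Rust's own saturating `as` stands in for the in-bounds `fptosi` result):

```rust
// Plain-Rust model of the select-based saturation for f64 -> i32.
// For this pair of types both bounds happen to be exactly representable
// as f64, which keeps the sketch simple.
fn saturate_f64_to_i32(x: f64) -> i32 {
    let f_min = i32::MIN as f64; // -2^31, exact in f64
    let f_max = i32::MAX as f64; //  2^31 - 1, exact in f64
    // %less_or_nan = fcmp ult %x, %f_min  (unordered: also true for NaN)
    let less_or_nan = !(x >= f_min);
    // %greater = fcmp ogt %x, %f_max
    let greater = x > f_max;
    // Stand-in for the fptosi result; only meaningful when x is in bounds.
    let fptosi_result = if less_or_nan || greater { 0 } else { x as i32 };
    let s0 = if less_or_nan { i32::MIN } else { fptosi_result };
    let s1 = if greater { i32::MAX } else { s0 };
    // Step 3: NaN replacement -- (x == x) is false only for NaN.
    if x == x { s1 } else { 0 }
}

fn main() {
    assert_eq!(saturate_f64_to_i32(-42.7), -42);
    assert_eq!(saturate_f64_to_i32(1e10), i32::MAX);
    assert_eq!(saturate_f64_to_i32(-1e10), i32::MIN);
    assert_eq!(saturate_f64_to_i32(f64::NAN), 0);
}
```

In the real codegen the `fptosi` runs unconditionally and merely yields `undef` out of bounds, which is exactly why this form is only usable when the builder reports that the instruction cannot trap.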

src/librustc_codegen_ssa/traits/builder.rs (+1)

```diff
@@ -160,6 +160,7 @@ pub trait BuilderMethods<'a, 'tcx>:
     fn sext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;
     fn fptoui_sat(&mut self, val: Self::Value, dest_ty: Self::Type) -> Option<Self::Value>;
     fn fptosi_sat(&mut self, val: Self::Value, dest_ty: Self::Type) -> Option<Self::Value>;
+    fn fptosui_may_trap(&self, val: Self::Value, dest_ty: Self::Type) -> bool;
     fn fptoui(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;
     fn fptosi(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;
     fn uitofp(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;
```
