Skip to content

Commit 81a0cc6

Browse files
committed
only use accurate powf function
The only advantage the powi intrinsic has over calling powf is speed, and it gets that speed by being inaccurate. We don't need that trade-off. In the cases where powi would be equally accurate (e.g. tiny constant powers), LLVM already recognizes and optimizes any call to a function named `powf`, producing the same speedup. Fixes #19872.
1 parent 6eaec18 commit 81a0cc6

8 files changed

+25
-52
lines changed

base/fastmath.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module FastMath
2323

2424
export @fastmath
2525

26-
import Core.Intrinsics: box, unbox, powi_llvm, sqrt_llvm_fast
26+
import Core.Intrinsics: box, unbox, powf_llvm, sqrt_llvm_fast
2727

2828
const fast_op =
2929
Dict(# basic arithmetic
@@ -250,9 +250,9 @@ end
250250

251251
# builtins
252252

253-
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, Int32(y))
254-
pow_fast{T<:FloatTypes}(x::T, y::Int32) =
255-
box(T, Base.powi_llvm(unbox(T,x), unbox(Int32,y)))
253+
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, convert(T, y))
254+
pow_fast{T<:FloatTypes}(x::T, y::T) =
255+
box(T, Base.powf_llvm(unbox(T,x), unbox(T,y)))
256256

257257
# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
258258
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic

base/math.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin,
3131
significand_mask, significand_bits, exponent_bits, exponent_bias
3232

3333

34-
import Core.Intrinsics: sqrt_llvm, box, unbox, powi_llvm
34+
import Core.Intrinsics: sqrt_llvm, box, unbox, powf_llvm
3535

3636
# non-type specific math functions
3737

@@ -648,9 +648,9 @@ end
648648
^(x::Float32, y::Float32) = nan_dom_err(ccall((:powf,libm), Float32, (Float32,Float32), x, y), x+y)
649649

650650
^(x::Float64, y::Integer) =
651-
box(Float64, powi_llvm(unbox(Float64,x), unbox(Int32,Int32(y))))
651+
box(Float64, powf_llvm(unbox(Float64,x), unbox(Float64,Float64(y))))
652652
^(x::Float32, y::Integer) =
653-
box(Float32, powi_llvm(unbox(Float32,x), unbox(Int32,Int32(y))))
653+
box(Float32, powf_llvm(unbox(Float32,x), unbox(Float32,Float32(y))))
654654
^(x::Float16, y::Integer) = Float16(Float32(x)^y)
655655

656656
function angle_restrict_symm(theta)

src/codegen.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,8 @@ static Function *expect_func;
389389
static Function *jldlsym_func;
390390
static Function *jlnewbits_func;
391391
static Function *jltypeassert_func;
392-
#if JL_LLVM_VERSION < 30600
393392
static Function *jlpow_func;
394393
static Function *jlpowf_func;
395-
#endif
396394
//static Function *jlgetnthfield_func;
397395
static Function *jlgetnthfieldchecked_func;
398396
//static Function *jlsetnthfield_func;
@@ -5946,7 +5944,6 @@ static void init_julia_llvm_env(Module *m)
59465944
"jl_gc_diff_total_bytes", m);
59475945
add_named_global(diff_gc_total_bytes_func, *jl_gc_diff_total_bytes);
59485946

5949-
#if JL_LLVM_VERSION < 30600
59505947
Type *powf_type[2] = { T_float32, T_float32 };
59515948
jlpowf_func = Function::Create(FunctionType::get(T_float32, powf_type, false),
59525949
Function::ExternalLinkage,
@@ -5964,7 +5961,7 @@ static void init_julia_llvm_env(Module *m)
59645961
&pow,
59655962
#endif
59665963
false);
5967-
#endif
5964+
59685965
std::vector<Type*> array_owner_args(0);
59695966
array_owner_args.push_back(T_pjlvalue);
59705967
jlarray_data_owner_func =

src/intrinsics.cpp

+5-14
Original file line numberDiff line numberDiff line change
@@ -1452,23 +1452,14 @@ static Value *emit_untyped_intrinsic(intrinsic f, Value *x, Value *y, Value *z,
14521452
ArrayRef<Type*>(x->getType())),
14531453
x);
14541454
}
1455-
case powi_llvm: {
1455+
case powf_llvm: {
14561456
x = FP(x);
1457-
y = JL_INT(y);
1458-
Type *tx = x->getType(); // TODO: LLVM expects this to be i32
1459-
#if JL_LLVM_VERSION >= 30600
1460-
Type *ts[1] = { tx };
1461-
Value *powi = Intrinsic::getDeclaration(jl_Module, Intrinsic::powi,
1462-
ArrayRef<Type*>(ts));
1457+
y = FP(y);
1458+
Function *powf = (x->getType() == T_float64 ? jlpow_func : jlpowf_func);
14631459
#if JL_LLVM_VERSION >= 30700
1464-
return builder.CreateCall(powi, {x, y});
1460+
return builder.CreateCall(prepare_call(powf), {x, y});
14651461
#else
1466-
return builder.CreateCall2(powi, x, y);
1467-
#endif
1468-
#else
1469-
// issue #6506
1470-
return builder.CreateCall2(prepare_call(tx == T_float64 ? jlpow_func : jlpowf_func),
1471-
x, builder.CreateSIToFP(y, tx));
1462+
return builder.CreateCall2(prepare_call(powf), x, y);
14721463
#endif
14731464
}
14741465
case sqrt_llvm_fast: {

src/intrinsics.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
ADD_I(trunc_llvm, 1) \
9393
ADD_I(rint_llvm, 1) \
9494
ADD_I(sqrt_llvm, 1) \
95-
ADD_I(powi_llvm, 2) \
95+
ADD_I(powf_llvm, 2) \
9696
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
9797
/* pointer access */ \
9898
ADD_I(pointerref, 3) \

src/julia_internal.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ JL_DLLEXPORT jl_value_t *jl_floor_llvm(jl_value_t *a);
678678
JL_DLLEXPORT jl_value_t *jl_trunc_llvm(jl_value_t *a);
679679
JL_DLLEXPORT jl_value_t *jl_rint_llvm(jl_value_t *a);
680680
JL_DLLEXPORT jl_value_t *jl_sqrt_llvm(jl_value_t *a);
681-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b);
681+
JL_DLLEXPORT jl_value_t *jl_powf_llvm(jl_value_t *a, jl_value_t *b);
682682
JL_DLLEXPORT jl_value_t *jl_abs_float(jl_value_t *a);
683683
JL_DLLEXPORT jl_value_t *jl_copysign_float(jl_value_t *a, jl_value_t *b);
684684
JL_DLLEXPORT jl_value_t *jl_flipsign_int(jl_value_t *a, jl_value_t *b);

src/runtime_intrinsics.c

+3-25
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,8 @@ bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
947947
*pr = fp_select(a, sqrt)
948948
#define copysign_float(a, b) \
949949
fp_select2(a, b, copysign)
950+
#define pow_float(a, b) \
951+
fp_select2(a, b, pow)
950952

951953
un_fintrinsic(abs_float,abs_float)
952954
bi_fintrinsic(copysign_float,copysign_float)
@@ -955,31 +957,7 @@ un_fintrinsic(floor_float,floor_llvm)
955957
un_fintrinsic(trunc_float,trunc_llvm)
956958
un_fintrinsic(rint_float,rint_llvm)
957959
un_fintrinsic(sqrt_float,sqrt_llvm)
958-
959-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b)
960-
{
961-
jl_ptls_t ptls = jl_get_ptls_states();
962-
jl_value_t *ty = jl_typeof(a);
963-
if (!jl_is_bitstype(ty))
964-
jl_error("powi_llvm: a is not a bitstype");
965-
if (!jl_is_bitstype(jl_typeof(b)) || jl_datatype_size(jl_typeof(b)) != 4)
966-
jl_error("powi_llvm: b is not a 32-bit bitstype");
967-
int sz = jl_datatype_size(ty);
968-
jl_value_t *newv = jl_gc_alloc(ptls, sz, ty);
969-
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
970-
switch (sz) {
971-
/* choose the right size c-type operation */
972-
case 4:
973-
*(float*)pr = powf(*(float*)pa, (float)jl_unbox_int32(b));
974-
break;
975-
case 8:
976-
*(double*)pr = pow(*(double*)pa, (double)jl_unbox_int32(b));
977-
break;
978-
default:
979-
jl_error("powi_llvm: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
980-
}
981-
return newv;
982-
}
960+
bi_fintrinsic(pow_float,powf_llvm)
983961

984962
JL_DLLEXPORT jl_value_t *jl_select_value(jl_value_t *isfalse, jl_value_t *a, jl_value_t *b)
985963
{

test/math.jl

+7
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,13 @@ end
964964
end
965965
end
966966

967+
@testset "issue #19872" begin
968+
f19872(x) = x ^ 3
969+
@test issubnormal(2.0 ^ (-1024))
970+
@test f19872(2.0) === 8.0
971+
@test !issubnormal(0.0)
972+
end
973+
967974
@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
968975
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.))
969976

0 commit comments

Comments
 (0)