Skip to content

Commit e577f8e

Browse files
committed
only use accurate powf function
The powi intrinsic is only faster than calling powf because it is less accurate, and we don't need that trade-off. When it is equally accurate (e.g. tiny constant powers), LLVM will already recognize and optimize any call to a function named `powf`, and produce the same speedup. Fixes #19872.
1 parent 8a75199 commit e577f8e

10 files changed

+27
-96
lines changed

base/fastmath.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module FastMath
2323

2424
export @fastmath
2525

26-
import Core.Intrinsics: powi_llvm, sqrt_llvm_fast, neg_float_fast,
26+
import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
2727
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast,
2828
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
2929

@@ -243,8 +243,8 @@ end
243243

244244
# builtins
245245

246-
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, Int32(y))
247-
pow_fast{T<:FloatTypes}(x::T, y::Int32) = Base.powi_llvm(x, y)
246+
pow_fast(x::Float32, y::Integer) = ccall("llvm.powi.f32", llvmcall, Float32, (Float32, Int32), x, y)
247+
pow_fast(x::Float64, y::Integer) = ccall("llvm.powi.f64", llvmcall, Float64, (Float64, Int32), x, y)
248248

249249
# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
250250
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic

base/inference.jl

-1
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,6 @@ add_tfunc(floor_llvm, 1, 1, math_tfunc)
447447
add_tfunc(trunc_llvm, 1, 1, math_tfunc)
448448
add_tfunc(rint_llvm, 1, 1, math_tfunc)
449449
add_tfunc(sqrt_llvm, 1, 1, math_tfunc)
450-
add_tfunc(powi_llvm, 2, 2, math_tfunc)
451450
add_tfunc(sqrt_llvm_fast, 1, 1, math_tfunc)
452451
## same-type comparisons ##
453452
cmp_tfunc(x::ANY, y::ANY) = Bool

base/math.jl

+11-12
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Base: sign_mask, exponent_mask, exponent_one, exponent_bias,
3232
exponent_half, exponent_max, exponent_raw_max, fpinttype,
3333
significand_mask, significand_bits, exponent_bits
3434

35-
using Core.Intrinsics: sqrt_llvm, powi_llvm
35+
using Core.Intrinsics: sqrt_llvm
3636

3737
# non-type specific math functions
3838

@@ -308,6 +308,8 @@ exp10(x::Float32) = 10.0f0^x
308308
exp10(x::Integer) = exp10(float(x))
309309

310310
# utility for converting NaN return to DomainError
311+
# the branch in nan_dom_err prevents its callers from being inlined, so be sure to force their inlining
312+
# until the heuristics can be improved
311313
@inline nan_dom_err(f, x) = isnan(f) & !isnan(x) ? throw(DomainError()) : f
312314

313315
# functions that return NaN on non-NaN argument for domain error
@@ -414,9 +416,9 @@ log1p(x)
414416
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
415417
:lgamma, :log1p)
416418
@eval begin
417-
($f)(x::Float64) = nan_dom_err(ccall(($(string(f)),libm), Float64, (Float64,), x), x)
418-
($f)(x::Float32) = nan_dom_err(ccall(($(string(f,"f")),libm), Float32, (Float32,), x), x)
419-
($f)(x::Real) = ($f)(float(x))
419+
@inline ($f)(x::Float64) = nan_dom_err(ccall(($(string(f)),libm), Float64, (Float64,), x), x)
420+
@inline ($f)(x::Float32) = nan_dom_err(ccall(($(string(f, "f")), libm), Float32, (Float32,), x), x)
421+
@inline ($f)(x::Real) = ($f)(float(x))
420422
end
421423
end
422424

@@ -677,14 +679,11 @@ function modf(x::Float64)
677679
f, _modf_temp[]
678680
end
679681

680-
^(x::Float64, y::Float64) = nan_dom_err(ccall((:pow,libm), Float64, (Float64,Float64), x, y), x+y)
681-
^(x::Float32, y::Float32) = nan_dom_err(ccall((:powf,libm), Float32, (Float32,Float32), x, y), x+y)
682-
683-
^(x::Float64, y::Integer) = x^Int32(y)
684-
^(x::Float64, y::Int32) = powi_llvm(x, y)
685-
^(x::Float32, y::Integer) = x^Int32(y)
686-
^(x::Float32, y::Int32) = powi_llvm(x, y)
687-
^(x::Float16, y::Integer) = Float16(Float32(x)^y)
682+
@inline ^(x::Float64, y::Float64) = nan_dom_err(ccall("llvm.pow.f64", llvmcall, Float64, (Float64, Float64), x, y), x + y)
683+
@inline ^(x::Float32, y::Float32) = nan_dom_err(ccall("llvm.pow.f32", llvmcall, Float32, (Float32, Float32), x, y), x + y)
684+
@inline ^(x::Float64, y::Integer) = x ^ Float64(y)
685+
@inline ^(x::Float32, y::Integer) = x ^ Float32(y)
686+
@inline ^(x::Float16, y::Integer) = Float16(Float32(x) ^ Float32(y))
688687

689688
function angle_restrict_symm(theta)
690689
const P1 = 4 * 7.8539812564849853515625e-01

base/special/exp.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ exp_small_thres(::Type{Float32}) = 2.0f0^-13
6565
Compute the natural base exponential of `x`, in other words ``e^x``.
6666
"""
6767
function exp{T<:Union{Float32,Float64}}(x::T)
68+
^ = Base.FastMath.pow_fast
69+
# we don't need the accurate power function here since we only compute 2^c,
70+
# where c is a constant
71+
6872
xa = reinterpret(Unsigned, x) & ~sign_mask(T)
6973
xsb = signbit(x)
7074

@@ -114,13 +118,13 @@ function exp{T<:Union{Float32,Float64}}(x::T)
114118
# scale back
115119
if k > -significand_bits(T)
116120
# multiply by 2.0 first to prevent overflow, which helps extend the range
117-
k == exponent_max(T) && return y*T(2.0)*T(2.0)^(exponent_max(T) - 1)
121+
k == exponent_max(T) && return y * T(2.0) * T(2.0)^(exponent_max(T) - 1)
118122
twopk = reinterpret(T, rem(exponent_bias(T) + k, fpinttype(T)) << significand_bits(T))
119123
return y*twopk
120124
else
121125
# add significand_bits(T) + 1 to lift the range outside the subnormals
122126
twopk = reinterpret(T, rem(exponent_bias(T) + significand_bits(T) + 1 + k, fpinttype(T)) << significand_bits(T))
123-
return y*twopk*T(2.0)^(-significand_bits(T) - 1)
127+
return y * twopk * T(2.0)^(-significand_bits(T) - 1)
124128
end
125129
elseif xa < reinterpret(Unsigned, exp_small_thres(T)) # |x| < exp_small_thres
126130
# Taylor approximation for small x

src/codegen.cpp

-23
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,6 @@ static Function *jldlsym_func;
395395
static Function *jlnewbits_func;
396396
static Function *jltypeassert_func;
397397
static Function *jldepwarnpi_func;
398-
#if JL_LLVM_VERSION < 30600
399-
static Function *jlpow_func;
400-
static Function *jlpowf_func;
401-
#endif
402398
//static Function *jlgetnthfield_func;
403399
static Function *jlgetnthfieldchecked_func;
404400
//static Function *jlsetnthfield_func;
@@ -5985,25 +5981,6 @@ static void init_julia_llvm_env(Module *m)
59855981
"jl_gc_diff_total_bytes", m);
59865982
add_named_global(diff_gc_total_bytes_func, *jl_gc_diff_total_bytes);
59875983

5988-
#if JL_LLVM_VERSION < 30600
5989-
Type *powf_type[2] = { T_float32, T_float32 };
5990-
jlpowf_func = Function::Create(FunctionType::get(T_float32, powf_type, false),
5991-
Function::ExternalLinkage,
5992-
"powf", m);
5993-
add_named_global(jlpowf_func, &powf, false);
5994-
5995-
Type *pow_type[2] = { T_float64, T_float64 };
5996-
jlpow_func = Function::Create(FunctionType::get(T_float64, pow_type, false),
5997-
Function::ExternalLinkage,
5998-
"pow", m);
5999-
add_named_global(jlpow_func,
6000-
#ifdef _COMPILER_MICROSOFT_
6001-
static_cast<double (*)(double, double)>(&pow),
6002-
#else
6003-
&pow,
6004-
#endif
6005-
false);
6006-
#endif
60075984
std::vector<Type*> array_owner_args(0);
60085985
array_owner_args.push_back(T_pjlvalue);
60095986
jlarray_data_owner_func =

src/intrinsics.cpp

-28
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ static void jl_init_intrinsic_functions_codegen(Module *m)
7171
float_func[rint_llvm] = true;
7272
float_func[sqrt_llvm] = true;
7373
float_func[sqrt_llvm_fast] = true;
74-
float_func[powi_llvm] = true;
7574
}
7675

7776
extern "C"
@@ -915,33 +914,6 @@ static jl_cgval_t emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
915914
return mark_julia_type(ans, false, x.typ, ctx);
916915
}
917916

918-
case powi_llvm: {
919-
const jl_cgval_t &x = argv[0];
920-
const jl_cgval_t &y = argv[1];
921-
if (!jl_is_bitstype(x.typ) || !jl_is_bitstype(y.typ) || jl_datatype_size(y.typ) != 4)
922-
return emit_runtime_call(f, argv, nargs, ctx);
923-
Type *xt = FLOATT(bitstype_to_llvm(x.typ));
924-
Type *yt = T_int32;
925-
if (!xt)
926-
return emit_runtime_call(f, argv, nargs, ctx);
927-
928-
Value *xv = emit_unbox(xt, x, x.typ);
929-
Value *yv = emit_unbox(yt, y, y.typ);
930-
#if JL_LLVM_VERSION >= 30600
931-
Value *powi = Intrinsic::getDeclaration(jl_Module, Intrinsic::powi, makeArrayRef(xt));
932-
#if JL_LLVM_VERSION >= 30700
933-
Value *ans = builder.CreateCall(powi, {xv, yv});
934-
#else
935-
Value *ans = builder.CreateCall2(powi, xv, yv);
936-
#endif
937-
#else
938-
// issue #6506
939-
Value *ans = builder.CreateCall2(prepare_call(xt == T_float64 ? jlpow_func : jlpowf_func),
940-
xv, builder.CreateSIToFP(yv, xt));
941-
#endif
942-
return mark_julia_type(ans, false, x.typ, ctx);
943-
}
944-
945917
default: {
946918
assert(nargs >= 1 && "invalid nargs for intrinsic call");
947919
const jl_cgval_t &xinfo = argv[0];

src/intrinsics.h

-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@
9191
ADD_I(trunc_llvm, 1) \
9292
ADD_I(rint_llvm, 1) \
9393
ADD_I(sqrt_llvm, 1) \
94-
ADD_I(powi_llvm, 2) \
9594
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
9695
/* pointer access */ \
9796
ADD_I(pointerref, 3) \

src/julia_internal.h

-1
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,6 @@ JL_DLLEXPORT jl_value_t *jl_floor_llvm(jl_value_t *a);
677677
JL_DLLEXPORT jl_value_t *jl_trunc_llvm(jl_value_t *a);
678678
JL_DLLEXPORT jl_value_t *jl_rint_llvm(jl_value_t *a);
679679
JL_DLLEXPORT jl_value_t *jl_sqrt_llvm(jl_value_t *a);
680-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b);
681680
JL_DLLEXPORT jl_value_t *jl_abs_float(jl_value_t *a);
682681
JL_DLLEXPORT jl_value_t *jl_copysign_float(jl_value_t *a, jl_value_t *b);
683682
JL_DLLEXPORT jl_value_t *jl_flipsign_int(jl_value_t *a, jl_value_t *b);

src/runtime_intrinsics.c

-25
Original file line numberDiff line numberDiff line change
@@ -956,31 +956,6 @@ un_fintrinsic(trunc_float,trunc_llvm)
956956
un_fintrinsic(rint_float,rint_llvm)
957957
un_fintrinsic(sqrt_float,sqrt_llvm)
958958

959-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b)
960-
{
961-
jl_ptls_t ptls = jl_get_ptls_states();
962-
jl_value_t *ty = jl_typeof(a);
963-
if (!jl_is_bitstype(ty))
964-
jl_error("powi_llvm: a is not a bitstype");
965-
if (!jl_is_bitstype(jl_typeof(b)) || jl_datatype_size(jl_typeof(b)) != 4)
966-
jl_error("powi_llvm: b is not a 32-bit bitstype");
967-
int sz = jl_datatype_size(ty);
968-
jl_value_t *newv = jl_gc_alloc(ptls, sz, ty);
969-
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
970-
switch (sz) {
971-
/* choose the right size c-type operation */
972-
case 4:
973-
*(float*)pr = powf(*(float*)pa, (float)jl_unbox_int32(b));
974-
break;
975-
case 8:
976-
*(double*)pr = pow(*(double*)pa, (double)jl_unbox_int32(b));
977-
break;
978-
default:
979-
jl_error("powi_llvm: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
980-
}
981-
return newv;
982-
}
983-
984959
JL_DLLEXPORT jl_value_t *jl_select_value(jl_value_t *isfalse, jl_value_t *a, jl_value_t *b)
985960
{
986961
JL_TYPECHK(isfalse, bool, isfalse);

test/math.jl

+7
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,13 @@ end
996996
end
997997
end
998998

999+
@testset "issue #19872" begin
1000+
f19872(x) = x ^ 3
1001+
@test issubnormal(2.0 ^ (-1024))
1002+
@test f19872(2.0) === 8.0
1003+
@test !issubnormal(0.0)
1004+
end
1005+
9991006
@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
10001007
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.))
10011008

0 commit comments

Comments
 (0)