Skip to content

Commit b72c604

Browse files
committed
Implement AArch64 ABI
1 parent f8c6aac commit b72c604

8 files changed

+308
-10
lines changed

src/abi_aarch64.cpp

+282
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
// This file is a part of Julia. License is MIT: http://julialang.org/license
2+
3+
//===----------------------------------------------------------------------===//
4+
//
5+
// The ABI implementation used for AArch64 targets.
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// The Procedure Call Standard can be found here:
10+
// http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
namespace {
15+
16+
typedef bool AbiState;
17+
static const AbiState default_abi_state = 0;
18+
19+
static Type *get_llvm_fptype(jl_datatype_t *dt)
20+
{
21+
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
22+
if (dt->mutabl || jl_datatype_nfields(dt) >= 0)
23+
return NULL;
24+
Type *lltype;
25+
// Check size first since it's cheaper.
26+
switch (dt->size) {
27+
case 2:
28+
lltype = T_float16;
29+
break;
30+
case 4:
31+
lltype = T_float32;
32+
break;
33+
case 8:
34+
lltype = T_float64;
35+
break;
36+
case 16:
37+
lltype = T_float128;
38+
break;
39+
default:
40+
return NULL;
41+
}
42+
return jl_is_floattype(dt) ? lltype : NULL;
43+
}
44+
45+
// Whether a type is a homogeneous floating-point aggregates (HFA) or a
46+
// homogeneous short-vector aggregates (HVA). Returns the number of members.
47+
// We only handle HFA of HP, SP and DP here since these are the only ones we
48+
// have (no QP).
49+
static size_t isHFAorHVA(jl_datatype_t *dt)
50+
{
51+
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
52+
53+
// An Homogeneous Floating-point Aggregate (HFA) is an Homogeneous Aggregate
54+
// with a Fundamental Data Type that is a Floating-Point type and at most
55+
// four uniquely addressable members.
56+
// An Homogeneous Short-Vector Aggregate (HVA) is an Homogeneous Aggregate
57+
// with a Fundamental Data Type that is a Short-Vector type and at most four
58+
// uniquely addressable members.
59+
size_t members = jl_datatype_nfields(dt);
60+
if (members < 1 || members > 4)
61+
return 0;
62+
// There's at least one member
63+
jl_value_t *ftype = jl_field_type(dt, 0);
64+
if (!get_llvm_fptype((jl_datatype_t*)ftype))
65+
return 0;
66+
for (size_t i = 1;i < members;i++) {
67+
if (ftype != jl_field_type(dt, i)) {
68+
return 0;
69+
}
70+
}
71+
return members;
72+
}
73+
74+
void needPassByRef(AbiState*, jl_value_t *ty, bool *byRef, bool*)
75+
{
76+
// Assume jl_is_datatype(ty) && !jl_is_abstracttype(ty)
77+
jl_datatype_t *dt = (jl_datatype_t*)ty;
78+
// B.2
79+
// If the argument type is an HFA or an HVA, then the argument is used
80+
// unmodified.
81+
if (isHFAorHVA(dt))
82+
return;
83+
// B.3
84+
// If the argument type is a Composite Type that is larger than 16 bytes,
85+
// then the argument is copied to memory allocated by the caller and the
86+
// argument is replaced by a pointer to the copy.
87+
// We only check for the total size and not whether it is a composite type
88+
// since there's no corresponding C type and we just treat such large
89+
// bitstype as a composite type of the right size.
90+
*byRef = dt->size > 16;
91+
// B.4
92+
// If the argument type is a Composite Type then the size of the argument
93+
// is rounded up to the nearest multiple of 8 bytes.
94+
}
95+
96+
bool need_private_copy(jl_value_t*, bool)
97+
{
98+
return false;
99+
}
100+
101+
// Determine which kind of register the argument will be passed in and
102+
// if the argument has to be passed on stack (including by reference).
103+
//
104+
// If the argument should be passed in SIMD and floating-point registers,
105+
// we may need to rewrite the argument types to [n x ftype].
106+
// If the argument should be passed in general purpose registers, we may need
107+
// to rewrite the argument types to [n x i64].
108+
//
109+
// If the argument has to be passed on stack, we need to use sret.
110+
//
111+
// All the out parameters should be default to `false`.
112+
static void classify_arg(jl_value_t *ty, bool *fpreg, bool *onstack,
113+
bool *need_rewrite)
114+
{
115+
// Assume jl_is_datatype(ty) && !jl_is_abstracttype(ty)
116+
jl_datatype_t *dt = (jl_datatype_t*)ty;
117+
118+
// Based on section 5.4 C of the Procedure Call Standard
119+
// C.1
120+
// If the argument is a Half-, Single-, Double- or Quad- precision
121+
// Floating-point or Short Vector Type and the NSRN is less than 8, then
122+
// the argument is allocated to the least significant bits of register
123+
// v[NSRN]. The NSRN is incremented by one. The argument has now been
124+
// allocated.
125+
// Note that this is missing QP float as well as short vector types since we
126+
// don't really have those types.
127+
if (get_llvm_fptype(dt)) {
128+
*fpreg = true;
129+
return;
130+
}
131+
132+
// C.2
133+
// If the argument is an HFA or an HVA and there are sufficient
134+
// unallocated SIMD and Floating-point registers (NSRN + number of
135+
// members <= 8), then the argument is allocated to SIMD and
136+
// Floating-point Registers (with one register per member of the HFA
137+
// or HVA). The NSRN is incremented by the number of registers used.
138+
// The argument has now been allocated.
139+
if (isHFAorHVA(dt)) { // HFA and HVA have <= 4 members
140+
*fpreg = true;
141+
*need_rewrite = true;
142+
return;
143+
}
144+
145+
// Check if the argument needs to be passed by reference. This should be
146+
// done before starting step C but we do this here to avoid checking for
147+
// HFA and HVA twice. We don't check whether it is a composite type.
148+
// See `needPassByRef` above.
149+
if (dt->size > 16) {
150+
*onstack = true;
151+
return;
152+
}
153+
154+
// C.3
155+
// If the argument is an HFA or an HVA then the NSRN is set to 8 and the
156+
// size of the argument is rounded up to the nearest multiple of 8 bytes.
157+
// C.4
158+
// If the argument is an HFA, an HVA, a Quad-precision Floating-point or
159+
// Short Vector Type then the NSAA is rounded up to the larger of 8 or
160+
// the Natural Alignment of the argument’s type.
161+
// C.5
162+
// If the argument is a Half- or Single- precision Floating Point type,
163+
// then the size of the argument is set to 8 bytes. The effect is as if
164+
// the argument had been copied to the least significant bits of a 64-bit
165+
// register and the remaining bits filled with unspecified values.
166+
// C.6
167+
// If the argument is an HFA, an HVA, a Half-, Single-, Double- or
168+
// Quad- precision Floating-point or Short Vector Type, then the argument
169+
// is copied to memory at the adjusted NSAA. The NSAA is incremented
170+
// by the size of the argument. The argument has now been allocated.
171+
// <already included in the C.2 case above>
172+
// C.7
173+
// If the argument is an Integral or Pointer Type, the size of the
174+
// argument is less than or equal to 8 bytes and the NGRN is less than 8,
175+
// the argument is copied to the least significant bits in x[NGRN].
176+
// The NGRN is incremented by one. The argument has now been allocated.
177+
// Here we treat any bitstype of the right size as integers or pointers
178+
// This is needed for types like Cstring which should be treated as
179+
// pointers. We don't need to worry about floating points here since they
180+
// are handled above.
181+
if (jl_is_immutable(dt) && jl_datatype_nfields(dt) == 0 &&
182+
(dt->size == 1 || dt->size == 2 || dt->size == 4 ||
183+
dt->size == 8 || dt->size == 16))
184+
return;
185+
186+
// C.8
187+
// If the argument has an alignment of 16 then the NGRN is rounded up to
188+
// the next even number.
189+
// C.9
190+
// If the argument is an Integral Type, the size of the argument is equal
191+
// to 16 and the NGRN is less than 7, the argument is copied to x[NGRN]
192+
// and x[NGRN+1]. x[NGRN] shall contain the lower addressed double-word
193+
// of the memory representation of the argument. The NGRN is incremented
194+
// by two. The argument has now been allocated.
195+
// <merged into C.7 above>
196+
// C.10
197+
// If the argument is a Composite Type and the size in double-words of
198+
// the argument is not more than 8 minus NGRN, then the argument is
199+
// copied into consecutive general-purpose registers, starting at x[NGRN].
200+
// The argument is passed as though it had been loaded into the registers
201+
// from a double-word-aligned address with an appropriate sequence of LDR
202+
// instructions loading consecutive registers from memory (the contents of
203+
// any unused parts of the registers are unspecified by this standard).
204+
// The NGRN is incremented by the number of registers used. The argument
205+
// has now been allocated.
206+
// We don't check for composite types here since the ones that have
207+
// corresponding C types are already handled and we just treat the ones
208+
// with weird size as a black box composite type.
209+
// The type can fit in 8 x 8 bytes since it is handled by
210+
// need_pass_by_ref otherwise.
211+
*need_rewrite = true;
212+
213+
// C.11
214+
// The NGRN is set to 8.
215+
// C.12
216+
// The NSAA is rounded up to the larger of 8 or the Natural Alignment
217+
// of the argument’s type.
218+
// C.13
219+
// If the argument is a composite type then the argument is copied to
220+
// memory at the adjusted NSAA. The NSAA is incremented by the size of
221+
// the argument. The argument has now been allocated.
222+
// <handled by C.10 above>
223+
// C.14
224+
// If the size of the argument is less than 8 bytes then the size of the
225+
// argument is set to 8 bytes. The effect is as if the argument was
226+
// copied to the least significant bits of a 64-bit register and the
227+
// remaining bits filled with unspecified values.
228+
// C.15
229+
// The argument is copied to memory at the adjusted NSAA. The NSAA is
230+
// incremented by the size of the argument. The argument has now been
231+
// allocated.
232+
// <handled by C.10 above>
233+
}
234+
235+
bool use_sret(AbiState*, jl_value_t *ty)
236+
{
237+
// Assume jl_is_datatype(ty) && !jl_is_abstracttype(ty)
238+
// Section 5.5
239+
// If the type, T, of the result of a function is such that
240+
//
241+
// void func(T arg)
242+
//
243+
// would require that arg be passed as a value in a register (or set of
244+
// registers) according to the rules in section 5.4 Parameter Passing,
245+
// then the result is returned in the same registers as would be used for
246+
// such an argument.
247+
bool fpreg = false;
248+
bool onstack = false;
249+
bool need_rewrite = false;
250+
classify_arg(ty, &fpreg, &onstack, &need_rewrite);
251+
return onstack;
252+
}
253+
254+
Type *preferred_llvm_type(jl_value_t *ty, bool)
255+
{
256+
if (!jl_is_datatype(ty) || jl_is_abstracttype(ty))
257+
return NULL;
258+
jl_datatype_t *dt = (jl_datatype_t*)ty;
259+
if (Type *fptype = get_llvm_fptype(dt))
260+
return fptype;
261+
bool fpreg = false;
262+
bool onstack = false;
263+
bool need_rewrite = false;
264+
classify_arg(ty, &fpreg, &onstack, &need_rewrite);
265+
if (!need_rewrite)
266+
return NULL;
267+
if (fpreg) {
268+
// Rewrite to [n x fptype] where n is the number of field
269+
// This only happens for isHFAorHVA
270+
size_t members = jl_datatype_nfields(dt);
271+
assert(members > 0 && members <= 4);
272+
jl_datatype_t *eltype = (jl_datatype_t*)jl_field_type(dt, 0);
273+
return ArrayType::get(get_llvm_fptype(eltype), members);
274+
}
275+
else {
276+
// Rewrite to [n x Int64] where n is the **size in dword**
277+
assert(dt->size <= 16); // Should be pass by reference otherwise
278+
return ArrayType::get(T_int64, (dt->size + 7) >> 3);
279+
}
280+
}
281+
282+
}

src/ccall.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ static Value *runtime_sym_lookup(PointerType *funcptype, const char *f_lib, cons
141141
# else
142142
# include "abi_x86.cpp"
143143
# endif
144+
#elif defined _CPU_AARCH64_
145+
# include "abi_aarch64.cpp"
144146
#else
145147
# warning "ccall is defaulting to llvm ABI, since no platform ABI has been defined for this CPU/OS combination"
146148
# include "abi_llvm.cpp"
@@ -900,8 +902,12 @@ static std::string generate_func_sig(
900902
// Note that even though the LLVM argument is called ByVal
901903
// this really means that the thing we're passing is pointing to
902904
// the thing we want to pass by value
905+
#ifndef _CPU_AARCH64_
906+
// the aarch64 backend seems to interpret ByVal as
907+
// implicitly passed on stack.
903908
if (byRef)
904909
paramattrs[i + sret].addAttribute(Attribute::ByVal);
910+
#endif
905911
if (inReg)
906912
paramattrs[i + sret].addAttribute(Attribute::InReg);
907913
if (av != Attribute::None)

src/cgutils.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -923,15 +923,15 @@ JL_DLLEXPORT Type *julia_type_to_llvm(jl_value_t *jt, bool *isboxed)
923923
if (jl_is_floattype(jt)) {
924924
#ifndef DISABLE_FLOAT16
925925
if (nb == 2)
926-
return Type::getHalfTy(jl_LLVMContext);
926+
return T_float16;
927927
else
928928
#endif
929929
if (nb == 4)
930-
return Type::getFloatTy(jl_LLVMContext);
930+
return T_float32;
931931
else if (nb == 8)
932-
return Type::getDoubleTy(jl_LLVMContext);
932+
return T_float64;
933933
else if (nb == 16)
934-
return Type::getFP128Ty(jl_LLVMContext);
934+
return T_float128;
935935
}
936936
return Type::getIntNTy(jl_LLVMContext, nb*8);
937937
}

src/codegen.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,10 @@ static IntegerType *T_uint64;
246246
static IntegerType *T_char;
247247
static IntegerType *T_size;
248248

249+
static Type *T_float16;
249250
static Type *T_float32;
250251
static Type *T_float64;
252+
static Type *T_float128;
251253

252254
static Type *T_pint8;
253255
static Type *T_pint16;
@@ -5512,10 +5514,12 @@ static void init_julia_llvm_env(Module *m)
55125514
else
55135515
T_size = T_uint32;
55145516
T_psize = PointerType::get(T_size, 0);
5517+
T_float16 = Type::getHalfTy(getGlobalContext());
55155518
T_float32 = Type::getFloatTy(getGlobalContext());
55165519
T_pfloat32 = PointerType::get(T_float32, 0);
55175520
T_float64 = Type::getDoubleTy(getGlobalContext());
55185521
T_pfloat64 = PointerType::get(T_float64, 0);
5522+
T_float128 = Type::getFP128Ty(getGlobalContext());
55195523
T_void = Type::getVoidTy(jl_LLVMContext);
55205524
T_pvoidfunc = FunctionType::get(T_void, /*isVarArg*/false)->getPointerTo();
55215525

src/init.c

+1
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,7 @@ void jl_get_builtin_hooks(void)
801801
jl_uint32_type = (jl_datatype_t*)core("UInt32");
802802
jl_uint64_type = (jl_datatype_t*)core("UInt64");
803803

804+
jl_float16_type = (jl_datatype_t*)core("Float16");
804805
jl_float32_type = (jl_datatype_t*)core("Float32");
805806
jl_float64_type = (jl_datatype_t*)core("Float64");
806807
jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat");

src/intrinsics.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ static Type *FTnbits(size_t nb)
6161
{
6262
#ifndef DISABLE_FLOAT16
6363
if (nb == 16)
64-
return Type::getHalfTy(jl_LLVMContext);
64+
return T_float16;
6565
else
6666
#endif
6767
if (nb == 32)
68-
return Type::getFloatTy(jl_LLVMContext);
68+
return T_float32;
6969
else if (nb == 64)
70-
return Type::getDoubleTy(jl_LLVMContext);
70+
return T_float64;
7171
else if (nb == 128)
72-
return Type::getFP128Ty(jl_LLVMContext);
72+
return T_float128;
7373
else
7474
jl_error("Unsupported Float Size");
7575
}
@@ -107,7 +107,7 @@ static jl_value_t *JL_JLUINTT(Type *t)
107107
assert(!t->isIntegerTy());
108108
if (t == T_float32) return (jl_value_t*)jl_uint32_type;
109109
if (t == T_float64) return (jl_value_t*)jl_uint64_type;
110-
if (t == Type::getHalfTy(jl_LLVMContext)) return (jl_value_t*)jl_uint16_type;
110+
if (t == T_float16) return (jl_value_t*)jl_uint16_type;
111111
assert(t == T_void);
112112
return jl_bottom_type;
113113
}
@@ -116,7 +116,7 @@ static jl_value_t *JL_JLSINTT(Type *t)
116116
assert(!t->isIntegerTy());
117117
if (t == T_float32) return (jl_value_t*)jl_int32_type;
118118
if (t == T_float64) return (jl_value_t*)jl_int64_type;
119-
if (t == Type::getHalfTy(jl_LLVMContext)) return (jl_value_t*)jl_int16_type;
119+
if (t == T_float16) return (jl_value_t*)jl_int16_type;
120120
assert(t == T_void);
121121
return jl_bottom_type;
122122
}

0 commit comments

Comments
 (0)