Skip to content

Commit a0c2543

Browse files
committed
WIP: Implement function multi versioning in sysimg
1 parent f28f57c commit a0c2543

File tree

6 files changed

+345
-2
lines changed

6 files changed

+345
-2
lines changed

base/sysimg.jl

+18
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,24 @@ end
400400
INCLUDE_STATE = 3 # include = include_from_node1
401401
include("precompile.jl")
402402

403+
# Add up every element of `a`, starting from `zero(eltype(a))`.
# The loop is annotated with `@inbounds @simd` and the function is marked
# `@noinline` so that it stays a separate function containing a loop.
@noinline function test_clone_f(a)
    acc = zero(eltype(a))
    @inbounds @simd for idx in 1:length(a)
        acc += a[idx]
    end
    return acc
end
410+
411+
# Call `test_clone_f(a)` `n` times and accumulate the results, starting from
# `zero(eltype(a))`. Marked `@noinline` so the call to `test_clone_f` is made
# from a separate function.
@noinline function test_clone_g(a, n)
    total = zero(eltype(a))
    for rep in 1:n
        total += test_clone_f(a)
    end
    return total
end

# Run once on an empty Float64 vector.
# NOTE(review): presumably this is here so both functions get compiled into
# the system image for Float64 arrays -- confirm.
test_clone_g(Float64[], 1)
420+
403421
end # baremodule Base
404422

405423
using Base

src/Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ endif
5454
LLVMLINK :=
5555

5656
ifeq ($(JULIACODEGEN),LLVM)
57-
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot cgmemmgr
57+
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-gcroot llvm-mv cgmemmgr
5858
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
5959
LLVM_LIBS := all
6060
ifeq ($(USE_POLLY),1)

src/dump.c

+27
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,11 @@ JL_DLLEXPORT int jl_running_on_valgrind(void)
221221
return RUNNING_ON_VALGRIND;
222222
}
223223

224+
// Pack two 32-bit register values into one 64-bit mask: `hi` supplies bits
// 63..32 and `lo` supplies bits 31..0.
//
// Both arguments are truncated to their low 32 bits before being combined.
// Callers pass `int32_t` CPUID registers, which are sign-extended when
// converted to `uint64_t`; without the truncation a `lo` value with bit 31
// set would also set every bit of the high half of the result.
STATIC_INLINE uint64_t i32_to_i64(uint64_t hi, uint64_t lo)
{
    return ((hi & 0xffffffffu) << 32) | (lo & 0xffffffffu);
}
228+
224229
static void jl_load_sysimg_so(void)
225230
{
226231
#ifndef _OS_WINDOWS_
@@ -242,6 +247,28 @@ static void jl_load_sysimg_so(void)
242247
*sysimg_gvars[tls_offset_idx - 1] =
243248
(jl_value_t*)(uintptr_t)(jl_tls_offset == -1 ? 0 : jl_tls_offset);
244249
#endif
250+
typedef void (*dispatch_t)(uint64_t, uint64_t, uint64_t, size_t*, void***, size_t**);
251+
dispatch_t dispatchf = (dispatch_t)jl_dlsym(jl_sysimg_handle,
252+
"jl_dispatch_sysimg_fvars");
253+
if (dispatchf) {
254+
int32_t info[4];
255+
jl_cpuid(info, 1);
256+
int32_t infoex[4];
257+
jl_cpuidex(infoex, 7, 0);
258+
uint64_t mask = i32_to_i64(info[3], info[2]);
259+
uint64_t emask1 = i32_to_i64(infoex[1], infoex[2]);
260+
uint64_t emask2 = i32_to_i64(infoex[3], 0);
261+
size_t nfunc = 0;
262+
void **fptrs = NULL;
263+
size_t *fidxs = NULL;
264+
dispatchf(mask, emask1, emask2, &nfunc, &fptrs, &fidxs);
265+
if (nfunc && fptrs && fidxs) {
266+
for (size_t i = 0; i < nfunc; i++) {
267+
size_t fi = fidxs[i];
268+
sysimg_fvars[fi] = fptrs[i];
269+
}
270+
}
271+
}
245272
const char *cpu_target = (const char*)jl_dlsym(jl_sysimg_handle, "jl_sysimg_cpu_target");
246273
if (strcmp(cpu_target,jl_options.cpu_target) != 0)
247274
jl_error("Julia and the system image were compiled for different architectures.\n"

src/jitlayers.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ void addOptimizationPasses(PassManager *PM)
172172
// Let the InstCombine pass remove the unnecessary load of
173173
// safepoint address first
174174
PM->add(createLowerPTLSPass(imaging_mode));
175+
PM->add(createJuliaMVPass());
175176
PM->add(createSROAPass()); // Break up aggregate allocas
176177
#ifndef INSTCOMBINE_BUG
177178
PM->add(createInstructionCombiningPass()); // Cleanup for scalarrepl.
@@ -1088,7 +1089,7 @@ static void jl_gen_llvm_globaldata(llvm::Module *mod, ValueToValueMapTy &VMap,
10881089
ArrayType *fvars_type = ArrayType::get(T_pvoidfunc, jl_sysimg_fvars.size());
10891090
addComdat(new GlobalVariable(*mod,
10901091
fvars_type,
1091-
true,
1092+
false,
10921093
GlobalVariable::ExternalLinkage,
10931094
MapValue(ConstantArray::get(fvars_type, ArrayRef<Constant*>(jl_sysimg_fvars)), VMap),
10941095
"jl_sysimg_fvars"));

src/jitlayers.h

+1
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ JL_DLLEXPORT extern LLVMContext &jl_LLVMContext;
248248

249249
Pass *createLowerPTLSPass(bool imaging_mode);
250250
Pass *createLowerGCFramePass();
251+
Pass *createJuliaMVPass();
251252
// Whether the Function is an llvm or julia intrinsic.
252253
static inline bool isIntrinsicFunction(Function *F)
253254
{

src/llvm-mv.cpp

+296
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
// Function multi-versioning
4+
#define DEBUG_TYPE "julia_mv"
5+
#undef DEBUG
6+
7+
// LLVM pass to clone function for different archs
8+
9+
#include "llvm-version.h"
10+
#include "support/dtypes.h"
11+
12+
#include <llvm/Pass.h>
13+
#include <llvm/IR/Module.h>
14+
#include <llvm/IR/Function.h>
15+
#include <llvm/IR/Instructions.h>
16+
#include <llvm/IR/Constants.h>
17+
#include <llvm/IR/LLVMContext.h>
18+
#include <llvm/Analysis/LoopInfo.h>
19+
#if JL_LLVM_VERSION >= 30700
20+
#include <llvm/IR/LegacyPassManager.h>
21+
#else
22+
#include <llvm/PassManager.h>
23+
#endif
24+
#include <llvm/IR/MDBuilder.h>
25+
#include <llvm/IR/IRBuilder.h>
26+
#include <llvm/Transforms/Utils/Cloning.h>
27+
#include "fix_llvm_assert.h"
28+
29+
#include "julia.h"
30+
#include "julia_internal.h"
31+
32+
#include <unordered_map>
33+
#include <vector>
34+
35+
using namespace llvm;
36+
37+
extern std::pair<MDNode*,MDNode*> tbaa_make_child(const char *name, MDNode *parent=nullptr, bool isConstant=false);
38+
extern "C" void jl_dump_llvm_value(void *v);
39+
40+
namespace {
41+
42+
struct JuliaMV: public ModulePass {
43+
static char ID;
44+
JuliaMV()
45+
: ModulePass(ID)
46+
{}
47+
48+
private:
49+
bool runOnModule(Module &M) override;
50+
void getAnalysisUsage(AnalysisUsage &AU) const override
51+
{
52+
AU.addRequired<LoopInfoWrapperPass>();
53+
AU.setPreservesAll();
54+
}
55+
bool shouldClone(Function &F);
56+
bool checkUses(Function &F, Constant *fary);
57+
bool checkUses(Function &F, Constant *V, Constant *fary, bool &inFVars);
58+
bool checkConstantUse(Function &F, Constant *V, Constant *fary, bool &inFVars);
59+
};
60+
61+
// Decide whether `F` is worth cloning.
// Returns false for functions without a body. Otherwise returns true when the
// function contains at least one loop, or when it directly calls a function
// whose name starts with "llvm.muladd." or "llvm.fma.".
bool JuliaMV::shouldClone(Function &F)
{
    // Declarations have no body to clone.
    if (F.empty())
        return false;
    // Any loop at all makes the function a candidate.
    if (!getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo().empty())
        return true;
    for (BasicBlock &block: F) {
        for (Instruction &inst: block) {
            auto *call = dyn_cast<CallInst>(&inst);
            if (!call)
                continue;
            // Indirect calls have no statically known callee.
            Function *target = call->getCalledFunction();
            if (!target)
                continue;
            StringRef targetName = target->getName();
            if (targetName.startswith("llvm.muladd.") || targetName.startswith("llvm.fma."))
                return true;
        }
    }
    return false;
}
82+
83+
// Returns true only if all non-instruction uses of `F` pass the recursive
// check and at least one of them reached the `fary` array.
bool JuliaMV::checkUses(Function &F, Constant *fary)
{
    bool foundInFVars = false;
    if (!checkUses(F, &F, fary, foundInFVars))
        return false;
    return foundInFVars;
}
89+
90+
// Check a single constant user `V` of the function being examined.
// - If `V` is the `fary` array, record that in `inFVars` and accept.
// - If `V` is a bitcast constant expression, accept iff all of its own users
//   are accepted.
// - Anything else is rejected.
bool JuliaMV::checkConstantUse(Function &F, Constant *V, Constant *fary, bool &inFVars)
{
    if (V == fary) {
        inFVars = true;
        return true;
    }
    auto *expr = dyn_cast<ConstantExpr>(V);
    const bool isBitCast = expr && expr->getOpcode() == Instruction::BitCast;
    return isBitCast && checkUses(F, V, fary, inFVars);
}
103+
104+
// Walk every user of `V`. Instruction users are ignored; every other user
// must be a constant accepted by checkConstantUse(). Returns false at the
// first user that is rejected, true if none is.
bool JuliaMV::checkUses(Function &F, Constant *V, Constant *fary, bool &inFVars)
{
    for (User *U: V->users()) {
        if (isa<Instruction>(U))
            continue;
        if (auto *constUser = dyn_cast<Constant>(U)) {
            if (checkConstantUse(F, constUser, fary, inFVars))
                continue;
        }
        // Either not a constant at all, or a constant use we do not accept.
        return false;
    }
    return true;
}
116+
117+
// Strip any number of bitcast constant expressions from `v` and return the
// function underneath, or nullptr if something other than a function (or a
// bitcast of one) is found.
static Function *getFunction(Value *v)
{
    Value *cur = v;
    for (;;) {
        if (auto *fn = dyn_cast<Function>(cur))
            return fn;
        auto *expr = dyn_cast<ConstantExpr>(cur);
        if (!expr || expr->getOpcode() != Instruction::BitCast)
            return nullptr;
        // Look through the cast at its operand.
        cur = expr->getOperand(0);
    }
}
128+
129+
// Set the "target-features" attribute of `F` to a fixed list of x86 features
// followed by whatever string value the attribute already had (if any).
static void addFeatures(Function *F)
{
    std::string features =
        "+avx2,+avx,+fma,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3";
    const Attribute existing = F->getFnAttribute("target-features");
    if (existing.isStringAttribute()) {
        // Keep the previous features after the new ones.
        features += ",";
        features += existing.getValueAsString();
    }
    F->addFnAttr("target-features", features);
}
140+
141+
// Clone eligible functions of the system image and emit the dispatch function.
//
// Steps:
//  1. Find `jl_sysimg_fvars`; if it is absent or has no constant-array
//     initializer the module is left untouched and false is returned.
//  2. Declare a ".avx2" twin for every function accepted by shouldClone() and
//     checkUses(), then clone the bodies of those that appear in the array
//     and add the extra target features to the clones.
//  3. In functions that were not cloned, replace direct references to a
//     cloned function's original by a load of its `jl_sysimg_fvars` slot.
//  4. Emit two internal tables (clone pointers and their slot indices) and
//     the exported `jl_dispatch_sysimg_fvars`, which stores the table size
//     and, depending on the feature masks passed in, either the table
//     addresses or null pointers through its out-parameters.
//
// Returns true once the module has been modified.
bool JuliaMV::runOnModule(Module &M)
{
    MDNode *tbaa_const = tbaa_make_child("jtbaa_const", nullptr, true).first;
    GlobalVariable *fvars = M.getGlobalVariable("jl_sysimg_fvars");
    // This makes sure this only runs during sysimg generation.
    // Nothing has been changed on these paths, so report "not modified".
    if (!fvars || !fvars->hasInitializer())
        return false;
    auto *fary = dyn_cast<ConstantArray>(fvars->getInitializer());
    if (!fary)
        return false;
    LLVMContext &ctx = M.getContext();
    ValueToValueMapTy VMap;
    // Declare the clones. Functions created here are appended to the module
    // and visited by this same loop, but they have no body yet and are
    // therefore rejected by shouldClone().
    for (auto &F: M) {
        if (shouldClone(F) && checkUses(F, fary)) {
            Function *NF = Function::Create(cast<FunctionType>(F.getValueType()),
                                            F.getLinkage(), F.getName() + ".avx2", &M);
            NF->copyAttributesFrom(&F);
            VMap[&F] = NF;
        }
    }
    // Map each cloned function to its slot in `jl_sysimg_fvars`.
    std::unordered_map<Function*,size_t> idx_map;
    size_t nf = fary->getNumOperands();
    for (size_t i = 0; i < nf; i++) {
        if (Function *ele = getFunction(fary->getOperand(i))) {
            auto it = VMap.find(ele);
            if (it != VMap.end()) {
                idx_map[ele] = i;
            }
        }
    }
    // Fill in the clone bodies.
    for (const auto &I: idx_map) {
        auto oldF = I.first;
        auto newF = cast<Function>(VMap[oldF]);
        Function::arg_iterator DestI = newF->arg_begin();
        for (Function::const_arg_iterator J = oldF->arg_begin(); J != oldF->arg_end(); ++J) {
            DestI->setName(J->getName());
            VMap[&*J] = &*DestI++;
        }
        SmallVector<ReturnInst*,8> Returns;
        CloneFunctionInto(newF, oldF, VMap, false, Returns);
        addFeatures(newF);
    }
    std::vector<Constant*> ptrs;
    std::vector<Constant*> idxs;
    auto T_void = Type::getVoidTy(ctx);
    auto T_pvoidfunc = FunctionType::get(T_void, false)->getPointerTo();
    auto T_size = (sizeof(size_t) == 8 ? Type::getInt64Ty(ctx) : Type::getInt32Ty(ctx));
    // True if `inst` still has `v` among its operands.
    auto stillUses = [] (Instruction *inst, Value *v) {
        for (auto &op: inst->operands()) {
            if (op.get() == v)
                return true;
        }
        return false;
    };
    for (const auto &I: idx_map) {
        auto oldF = I.first;
        auto idx = I.second;
        auto newF = cast<Function>(VMap[oldF]);
        ptrs.push_back(ConstantExpr::getBitCast(newF, T_pvoidfunc));
        auto offset = ConstantInt::get(T_size, idx);
        idxs.push_back(offset);
        // Snapshot the instructions to rewrite first: replaceUsesOfWith()
        // edits oldF's use list, which must not happen while iterating it.
        SmallVector<Instruction*,8> rewrites;
        for (auto user: oldF->users()) {
            auto inst = dyn_cast<Instruction>(user);
            if (!inst)
                continue;
            // Instructions inside functions that were themselves cloned keep
            // referring to the original function directly.
            auto encloseF = inst->getParent()->getParent();
            if (VMap.find(encloseF) != VMap.end())
                continue;
            rewrites.push_back(inst);
        }
        for (auto inst: rewrites) {
            // An instruction using oldF in several operands is listed once
            // per use; the first rewrite already replaced all of them.
            if (!stillUses(inst, oldF))
                continue;
            // NOTE(review): the load is inserted directly before `inst`; this
            // would be invalid if `inst` were a PHI node -- confirm that
            // cannot happen for these functions.
            // Address element `idx` of the array: the leading zero steps
            // through the pointer to the array, `offset` selects the slot.
            Value *gep_idx[] = {ConstantInt::get(T_size, 0), offset};
            auto slot = GetElementPtrInst::Create(fvars->getValueType(), fvars,
                                                  gep_idx, "", inst);
            Instruction *ptr = new LoadInst(slot, "", inst);
            ptr->setMetadata(llvm::LLVMContext::MD_tbaa, tbaa_const);
            ptr = new BitCastInst(ptr, oldF->getType(), "", inst);
            inst->replaceUsesOfWith(oldF, ptr);
        }
    }
    // Tables handed out by the dispatch function: clone pointers and the
    // `jl_sysimg_fvars` slot each one belongs to.
    ArrayType *fvars_type = ArrayType::get(T_pvoidfunc, ptrs.size());
    auto ptr_gv = new GlobalVariable(M, fvars_type, true, GlobalVariable::InternalLinkage,
                                     ConstantArray::get(fvars_type, ptrs));
    ArrayType *idxs_type = ArrayType::get(T_size, idxs.size());
    auto idx_gv = new GlobalVariable(M, idxs_type, true, GlobalVariable::InternalLinkage,
                                     ConstantArray::get(idxs_type, idxs));

    std::vector<Type*> dispatch_args(0);
    dispatch_args.push_back(Type::getInt64Ty(ctx)); // Feature mask
    dispatch_args.push_back(Type::getInt64Ty(ctx)); // Extended feature mask1
    dispatch_args.push_back(Type::getInt64Ty(ctx)); // Extended feature mask2
    dispatch_args.push_back(T_size->getPointerTo()); // out: number of clones
    dispatch_args.push_back(fvars_type->getPointerTo()->getPointerTo()); // out: clone pointers
    dispatch_args.push_back(idxs_type->getPointerTo()->getPointerTo()); // out: slot indices
    Function *dispatchF = Function::Create(FunctionType::get(T_void, dispatch_args, false),
                                           Function::ExternalLinkage,
                                           "jl_dispatch_sysimg_fvars", &M);
    IRBuilder<> builder(ctx);
    BasicBlock *b0 = BasicBlock::Create(ctx, "top", dispatchF);
    builder.SetInsertPoint(b0);
    DebugLoc noDbg;
    builder.SetCurrentDebugLocation(noDbg);

    std::vector<Argument*> args;
    for (auto &arg: dispatchF->args())
        args.push_back(&arg);

    auto sz_arg = args[3];
    auto fvars_arg = args[4];
    auto idxs_arg = args[5];

    // Hard code for now
    // NOTE(review): the bit positions below presumably are the CPUID feature
    // bits matching the list in addFeatures() -- confirm against the CPUID
    // documentation.
    // EDX:ECX
    uint64_t mask = 1 | (1 << 9) | (1 << 12) | (1 << 19) | (1 << 20) | (1 << 23) | (1 << 28);
    // EBX:ECX
    uint64_t emask1 = uint64_t(1) << (5 + 32);
    // EDX:0
    uint64_t emask2 = 0;

    // The table size is stored unconditionally, before the feature check.
    builder.CreateStore(ConstantInt::get(T_size, ptrs.size()), sz_arg);

    // Emits `(v & mask) == mask`, i.e. "all required bits are set".
    auto createMaskCmp = [&] (Value *v, uint64_t mask) {
        auto maskv = ConstantInt::get(v->getType(), mask);
        return builder.CreateICmpEQ(builder.CreateAnd(v, maskv), maskv);
    };

    auto match_mask = createMaskCmp(args[0], mask);
    auto match_emask1 = createMaskCmp(args[1], emask1);
    auto match_emask2 = createMaskCmp(args[2], emask2);

    auto match = builder.CreateAnd(match_mask, match_emask1);
    match = builder.CreateAnd(match, match_emask2);

    BasicBlock *match_bb = BasicBlock::Create(ctx, "match");
    BasicBlock *fail_bb = BasicBlock::Create(ctx, "fail");
    builder.CreateCondBr(match, match_bb, fail_bb);

    // All features present: hand out the tables.
    dispatchF->getBasicBlockList().push_back(match_bb);
    builder.SetInsertPoint(match_bb);
    builder.CreateStore(ptr_gv, fvars_arg);
    builder.CreateStore(idx_gv, idxs_arg);
    builder.CreateRetVoid();

    // Some feature missing: report null tables.
    dispatchF->getBasicBlockList().push_back(fail_bb);
    builder.SetInsertPoint(fail_bb);
    builder.CreateStore(ConstantPointerNull::get(fvars_type->getPointerTo()), fvars_arg);
    builder.CreateStore(ConstantPointerNull::get(idxs_type->getPointerTo()), idxs_arg);
    builder.CreateRetVoid();

    // jl_dump_llvm_value(dispatchF);
    // jl_dump_llvm_value(ptr_gv);
    // jl_dump_llvm_value(idx_gv);

    return true;
}
285+
286+
char JuliaMV::ID = 0;
287+
static RegisterPass<JuliaMV> X("JuliaMV", "JuliaMV Pass",
288+
false /* Only looks at CFG */,
289+
false /* Analysis Pass */);
290+
291+
}
292+
293+
// Factory for the multi-versioning pass; returns a newly allocated JuliaMV
// instance as a generic Pass pointer.
Pass *createJuliaMVPass()
{
    Pass *mvPass = new JuliaMV();
    return mvPass;
}

0 commit comments

Comments
 (0)