Code generation changes to support JIT compilation via LLJIT
Now, let’s take a brief look at some of the changes we have made within CodeGen.cpp to support our JIT-based calculator:
- As previously mentioned, the code generation class has two important methods. The first, compileToIR(), compiles the user-defined function into LLVM IR and prints the IR to the console. The second, prepareCalculationCallFunc(), prepares the calculation evaluation function, calc_expr_func, which contains a call to the original user-defined function for evaluation. It also prints the resulting IR to the user:
void CodeGen::compileToIR(AST *Tree, Module *M,
                          StringMap<size_t> &JITtedFunctions) {
  // Lower the function definition in Tree into M and record it in JITtedFunctions.
  ToIRVisitor ToIR(M, JITtedFunctions);
  ToIR.run(Tree);
  M->print(outs(), nullptr);
}

void CodeGen::prepareCalculationCallFunc(AST *FuncCall,
                                         Module *M, llvm::StringRef FnName,
                                         StringMap<size_t> &JITtedFunctions) {
  // Emit calc_expr_func, which calls the previously defined function.
  ToIRVisitor ToIR(M, JITtedFunctions);
  ToIR.genFuncEvaluationCall(FuncCall);
  M->print(outs(), nullptr);
}
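calc_expr_func always has the same shape (no parameters, one integer result), whatever the arity of the function it wraps. One plausible benefit, inferred here rather than stated by the code, is that the host only ever has to invoke one fixed function type. The plain C++ sketch below models that idea without LLVM; `UserFn`, `CalcExprFunc`, `makeCalcExprFunc`, and `add` are hypothetical names, and a `std::function<int32_t()>` stands in for the JIT-compiled calc_expr_func:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a user-defined function of any arity.
using UserFn = std::function<int32_t(const std::vector<int32_t> &)>;

// Hypothetical stand-in for calc_expr_func: zero arguments, one 32-bit result,
// regardless of how many parameters the wrapped function takes.
using CalcExprFunc = std::function<int32_t()>;

// Bake the call-site arguments into a zero-argument wrapper, similar in spirit
// to how calc_expr_func embeds the constant arguments of the call expression.
CalcExprFunc makeCalcExprFunc(UserFn Callee, std::vector<int32_t> Args) {
  return [Callee = std::move(Callee), Args = std::move(Args)]() {
    return Callee(Args);
  };
}

// Example user-defined function with two parameters: x + y.
int32_t add(const std::vector<int32_t> &A) { return A.at(0) + A.at(1); }
```

For example, `makeCalcExprFunc(add, {3, 4})()` evaluates to 7, and the caller never needs to know that `add` takes two parameters.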
- As the preceding code shows, both functions create a local ToIRVisitor instance, whose constructor takes our module and a reference to the JITtedFunctions map:
class ToIRVisitor : public ASTVisitor {
  Module *M;
  IRBuilder<> Builder;
  // A reference, not a copy: entries outlive this ToIRVisitor instance.
  StringMap<size_t> &JITtedFunctionsMap;
  . . .
public:
  ToIRVisitor(Module *M,
              StringMap<size_t> &JITtedFunctions)
      : M(M), Builder(M->getContext()), JITtedFunctionsMap(JITtedFunctions) {
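Since compileToIR() and prepareCalculationCallFunc() each construct and then discard their own ToIRVisitor, the entries recorded while visiting a definition reach the caller's map, and are therefore visible to later visitors, only because the visitor stores a reference rather than a copy. The sketch below contrasts the two choices with standard-library types; `SharingVisitor` and `CopyingVisitor` are hypothetical, and `std::map<std::string, std::size_t>` stands in for `StringMap<size_t>`:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Stand-in for StringMap<size_t>: function name -> number of parameters.
using ArityMap = std::map<std::string, std::size_t>;

// Hypothetical visitor that stores a reference, as ToIRVisitor does.
struct SharingVisitor {
  ArityMap &Functions; // writes go straight into the caller's map
  explicit SharingVisitor(ArityMap &F) : Functions(F) {}
  void define(const std::string &Name, std::size_t N) { Functions[Name] = N; }
};

// Hypothetical counter-example that stores a copy instead.
struct CopyingVisitor {
  ArityMap Functions; // private copy: updates are lost with the visitor
  explicit CopyingVisitor(const ArityMap &F) : Functions(F) {}
  void define(const std::string &Name, std::size_t N) { Functions[Name] = N; }
};
```

With the reference member, a definition recorded by one short-lived visitor is still present when the next visitor is created; with the copy, it disappears.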
- Ultimately, this information is used either to generate IR or to evaluate a function whose IR was generated earlier. When generating the IR, the code generator expects to see a DefDecl node, which represents the definition of a new function. The function name and its number of parameters are stored in JITtedFunctionsMap:
virtual void visit(DefDecl &Node) override {
  llvm::StringRef FnName = Node.getFnName();
  llvm::SmallVector<llvm::StringRef, 8> FunctionVars = Node.getVars();
  // Record the arity so genUserDefinedFunction() can rebuild the signature.
  JITtedFunctionsMap[FnName] = FunctionVars.size();
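Only the parameter count is recorded, which is enough here because every parameter and the return value are 32-bit integers, so the count determines the whole signature. Assigning through `operator[]` inserts a new entry or overwrites an existing one, so redefining a function replaces its recorded count. A standard-library model of this bookkeeping, with `recordDefinition` as a hypothetical helper and `std::map` standing in for `StringMap`:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

using ArityMap = std::map<std::string, std::size_t>;

// Hypothetical model of the DefDecl bookkeeping: store the number of
// parameter names under the function name. operator[] inserts or overwrites,
// so a redefinition replaces the previously recorded count.
void recordDefinition(ArityMap &Functions, const std::string &FnName,
                      const std::vector<std::string> &FunctionVars) {
  Functions[FnName] = FunctionVars.size();
}
```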
- Afterward, the llvm::Function for the definition is created by the genUserDefinedFunction() call. At this point it has no body; the body is generated in the following steps:
  Function *DefFunc = genUserDefinedFunction(FnName);
- Within genUserDefinedFunction(), the first step is to check whether the function already exists in the module; if it does, it is returned as is. Otherwise, we look up the function name in our map. If the name is found, we use the recorded argument count to build a function type with that many 32-bit integer parameters and a single 32-bit integer return value, and create a function with that name in the module. If the name is not in the map, nullptr is returned:
Function *genUserDefinedFunction(llvm::StringRef Name) {
  // Reuse the function if this module already contains it.
  if (Function *F = M->getFunction(Name))
    return F;
  Function *UserDefinedFunction = nullptr;
  auto FnNameToArgCount = JITtedFunctionsMap.find(Name);
  if (FnNameToArgCount != JITtedFunctionsMap.end()) {
    // One i32 parameter per user-declared argument; the result is also i32.
    std::vector<Type *> IntArgs(FnNameToArgCount->second, Int32Ty);
    FunctionType *FuncType = FunctionType::get(Int32Ty, IntArgs, false);
    UserDefinedFunction =
        Function::Create(FuncType, GlobalValue::ExternalLinkage, Name, M);
  }
  // nullptr if the name was never defined.
  return UserDefinedFunction;
}
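The lookup order in genUserDefinedFunction() (module first, then the arity map, otherwise nullptr) can be modeled without LLVM. In this sketch, a set of names stands in for the module's symbol table, the returned string imitates the textual form of the resulting declaration, and an empty `std::optional` plays the role of the nullptr return; `ToyModule` and `declareUserFunction` are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <set>
#include <string>

using ArityMap = std::map<std::string, std::size_t>;

// Hypothetical stand-in for llvm::Module: only the names of its functions.
struct ToyModule {
  std::set<std::string> FunctionNames;
};

// Mirrors the control flow of genUserDefinedFunction():
//  1. reuse the function if the module already has it;
//  2. otherwise build "i32 (i32, ..., i32)" from the recorded parameter count;
//  3. return an empty optional (standing in for nullptr) if the name is unknown.
std::optional<std::string> declareUserFunction(ToyModule &M,
                                               const ArityMap &Functions,
                                               const std::string &Name) {
  if (M.FunctionNames.count(Name))
    return "existing @" + Name;
  auto It = Functions.find(Name);
  if (It == Functions.end())
    return std::nullopt;
  std::string Params;
  for (std::size_t I = 0; I < It->second; ++I)
    Params += (I == 0 ? "i32" : ", i32");
  M.FunctionNames.insert(Name); // analogous to Function::Create(..., Name, M)
  return "declare i32 @" + Name + "(" + Params + ")";
}
```

A second request for the same name returns the existing entry instead of creating a duplicate, which matches the early `M->getFunction(Name)` check.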
- After genUserDefinedFunction() returns, a new basic block named entry is created inside the function, and the builder's insertion point is set to that block. Each function argument also corresponds to a name chosen by the user, so we record every argument in nameMap and set its name accordingly. Finally, visiting the function's body expression generates the arithmetic operations on those arguments:
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "entry", DefFunc);
  Builder.SetInsertPoint(BB);
  // Map each user-visible parameter name to its llvm::Argument and name it.
  unsigned FIdx = 0;
  for (auto &FArg : DefFunc->args()) {
    nameMap[FunctionVars[FIdx]] = &FArg;
    FArg.setName(FunctionVars[FIdx++]);
  }
  // Emit the instructions for the body expression.
  Node.getExpr()->accept(*this);
}
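The loop pairs the i-th formal argument with the i-th name from the definition, so nameMap lets code generation for the body resolve a variable to the matching argument. A standard-library model of that pairing, where an argument index stands in for the `&FArg` pointer stored in the real nameMap and `bindParameters`/`lookupVariable` are hypothetical helpers:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of nameMap: parameter name -> position of the argument.
using NameMap = std::map<std::string, std::size_t>;

// Pair the i-th argument slot with the i-th user-written name, as the loop
// over DefFunc->args() does (there, FArg.setName() also names the IR value).
NameMap bindParameters(const std::vector<std::string> &FunctionVars) {
  NameMap Names;
  for (std::size_t FIdx = 0; FIdx < FunctionVars.size(); ++FIdx)
    Names[FunctionVars[FIdx]] = FIdx;
  return Names;
}

// Resolve a variable used in the body to the value passed for it;
// an empty optional means the name is not a parameter of this function.
std::optional<int> lookupVariable(const NameMap &Names,
                                  const std::vector<int> &ArgValues,
                                  const std::string &Var) {
  auto It = Names.find(Var);
  if (It == Names.end())
    return std::nullopt;
  return ArgValues.at(It->second);
}
```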
- When evaluating the user-defined function, the code generator in our example expects a FuncCallFromDef node. First, we define the evaluation function and name it calc_expr_func; it takes no arguments and returns a single 32-bit integer:
virtual void visit(FuncCallFromDef &Node) override {
  llvm::StringRef CalcExprFunName = "calc_expr_func";
  // i32 calc_expr_func(): no parameters, one i32 result.
  FunctionType *CalcExprFunTy = FunctionType::get(Int32Ty, {}, false);
  Function *CalcExprFun = Function::Create(
      CalcExprFunTy, GlobalValue::ExternalLinkage, CalcExprFunName, M);
- Next, we create a new basic block inside calc_expr_func and set the builder's insertion point to it:
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "entry", CalcExprFun);
  Builder.SetInsertPoint(BB);
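The direction of containment matters here: passing CalcExprFun to BasicBlock::Create() appends the new block to calc_expr_func, and SetInsertPoint(BB) makes the builder append subsequent instructions to the end of that block. The toy model below captures only those two relationships; `ToyBlock`, `ToyFunction`, `ToyBuilder`, and `createBlock` are hypothetical types, not LLVM classes:

```cpp
#include <cassert>
#include <list>
#include <string>
#include <vector>

struct ToyBlock {
  std::string Label;
  std::vector<std::string> Instructions;
};

struct ToyFunction {
  std::string Name;
  // std::list keeps references to existing blocks valid when more are added.
  std::list<ToyBlock> Blocks;
};

class ToyBuilder {
  ToyBlock *InsertBlock = nullptr;

public:
  void setInsertPoint(ToyBlock &BB) { InsertBlock = &BB; }
  // New instructions go to the end of the current insertion block.
  void emit(const std::string &Inst) {
    assert(InsertBlock && "no insertion point set");
    InsertBlock->Instructions.push_back(Inst);
  }
};

// Analogous to BasicBlock::Create(Ctx, "entry", CalcExprFun): the block is
// created inside its parent function.
ToyBlock &createBlock(ToyFunction &F, const std::string &Label) {
  F.Blocks.push_back(ToyBlock{Label, {}});
  return F.Blocks.back();
}
```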
- Similar to before, the user-defined function is retrieved by genUserDefinedFunction(). If the current module does not already contain it, a function with that name is created from the argument count recorded in JITtedFunctionsMap:
  llvm::StringRef CalleeFnName = Node.getFnName();
  Function *CalleeFn = genUserDefinedFunction(CalleeFnName);
- Once we have the llvm::Function instance available, we convert each argument of the call from its textual form into a 32-bit integer constant, use IRBuilder to create a call to the user-defined function with these constants, and return the result of the call so that it can be printed to the user at the end:
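  auto CalleeFnVars = Node.getArgs();
  llvm::SmallVector<Value *, 8> IntParams;
  for (unsigned i = 0, end = CalleeFnVars.size(); i != end; ++i) {
    // Parse the decimal argument text. getAsInteger() returns true on failure,
    // which is not checked here, so start from a defined value.
    int ArgsToIntType = 0;
    CalleeFnVars[i].getAsInteger(10, ArgsToIntType);
    Value *IntParam = ConstantInt::get(Int32Ty, ArgsToIntType, true);
    IntParams.push_back(IntParam);
  }
  // Call the user-defined function and return its result from calc_expr_func.
  Builder.CreateRet(Builder.CreateCall(CalleeFn, IntParams, "calc_expr_res"));
}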
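Two details of this loop are easy to miss: each argument arrives as text and is parsed in base 10, and the result of getAsInteger() is not checked. The sketch below models the parsing, using `std::from_chars` to reject malformed input, and the overall shape of calc_expr_func as roughly printed IR. `parseArg` and `renderCalcExprFunc` are hypothetical helpers, and the IR text is an approximation, not output captured from LLVM:

```cpp
#include <cassert>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <system_error>
#include <vector>

// Parse one call argument as a base-10 32-bit integer, rejecting anything
// that is not a complete, in-range number.
std::optional<int32_t> parseArg(const std::string &Text) {
  int32_t Value = 0;
  const char *Begin = Text.data();
  const char *End = Begin + Text.size();
  auto [Ptr, Ec] = std::from_chars(Begin, End, Value, 10);
  if (Ec != std::errc() || Ptr != End)
    return std::nullopt;
  return Value;
}

// Approximate the printed calc_expr_func: call the user function with
// constant i32 arguments and return the call's result.
std::optional<std::string> renderCalcExprFunc(const std::string &Callee,
                                              const std::vector<std::string> &Args) {
  std::string Params;
  for (std::size_t I = 0; I < Args.size(); ++I) {
    std::optional<int32_t> V = parseArg(Args[I]);
    if (!V)
      return std::nullopt; // malformed argument: emit nothing
    Params += (I == 0 ? "" : ", ");
    Params += "i32 " + std::to_string(*V);
  }
  return "define i32 @calc_expr_func() {\n"
         "entry:\n"
         "  %calc_expr_res = call i32 @" + Callee + "(" + Params + ")\n"
         "  ret i32 %calc_expr_res\n"
         "}\n";
}
```

Checking the parse result, as here, would let the calculator report a bad argument instead of silently passing a default value.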