14f962b0dSLang Hames //===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
24f962b0dSLang Hames //
34f962b0dSLang Hames // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44f962b0dSLang Hames // See https://llvm.org/LICENSE.txt for license information.
54f962b0dSLang Hames // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64f962b0dSLang Hames //
74f962b0dSLang Hames //===----------------------------------------------------------------------===//
84f962b0dSLang Hames //
94f962b0dSLang Hames // In this example we will use an IR transform to optimize a module as it
104f962b0dSLang Hames // passes through LLJIT's IRTransformLayer.
114f962b0dSLang Hames //
124f962b0dSLang Hames //===----------------------------------------------------------------------===//
134f962b0dSLang Hames 
144f962b0dSLang Hames #include "llvm/ExecutionEngine/Orc/LLJIT.h"
154f962b0dSLang Hames #include "llvm/IR/LegacyPassManager.h"
163c4410dfSserge-sans-paille #include "llvm/Pass.h"
174f962b0dSLang Hames #include "llvm/Support/InitLLVM.h"
184f962b0dSLang Hames #include "llvm/Support/TargetSelect.h"
194f962b0dSLang Hames #include "llvm/Support/raw_ostream.h"
204f962b0dSLang Hames #include "llvm/Transforms/IPO.h"
214f962b0dSLang Hames #include "llvm/Transforms/Scalar.h"
224f962b0dSLang Hames 
234f962b0dSLang Hames #include "../ExampleModules.h"
244f962b0dSLang Hames 
254f962b0dSLang Hames using namespace llvm;
264f962b0dSLang Hames using namespace llvm::orc;
274f962b0dSLang Hames 
// Shared error handler for this example. main() wraps every fallible call in
// it to unwrap the result, and sets its banner to the program name; per LLVM's
// ExitOnError, a failure is reported and the process exits.
ExitOnError ExitOnErr;
294f962b0dSLang Hames 
// Example IR module.
//
// This IR contains a recursive definition of the factorial function:
//
// fac(n) | n == 0    = 1
//        | otherwise = n * fac(n - 1)
//
// It also contains an entry function which calls the factorial function with
// an input value of 5.
//
// The module is kept as textual IR and is parsed at run time by
// parseExampleModule (see main).
//
// We expect the IR optimization transform that we build below to transform
// fac into a non-recursive factorial function.
// NOTE(review): the pipeline below only runs tail call elimination and CFG
// simplification (no inlining), so entry presumably still calls fac at run
// time rather than being folded to a constant -- confirm against the
// "AFTER OPTIMIZATION" dump. Either way the call evaluates to 5!, or 120.

const llvm::StringRef MainMod =
    R"(

  define i32 @fac(i32 %n) {
  entry:
    %tobool = icmp eq i32 %n, 0
    br i1 %tobool, label %return, label %if.then

  if.then:                                          ; preds = %entry
    %arg = add nsw i32 %n, -1
    %call_result = call i32 @fac(i32 %arg)
    %result = mul nsw i32 %n, %call_result
    br label %return

  return:                                           ; preds = %entry, %if.then
    %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
    ret i32 %final_result
  }

  define i32 @entry() {
  entry:
    %result = call i32 @fac(i32 5)
    ret i32 %result
  }

)";
704f962b0dSLang Hames 
714f962b0dSLang Hames // A function object that creates a simple pass pipeline to apply to each
724f962b0dSLang Hames // module as it passes through the IRTransformLayer.
734f962b0dSLang Hames class MyOptimizationTransform {
744f962b0dSLang Hames public:
MyOptimizationTransform()754f962b0dSLang Hames   MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
764f962b0dSLang Hames     PM->add(createTailCallEliminationPass());
774f962b0dSLang Hames     PM->add(createCFGSimplificationPass());
784f962b0dSLang Hames   }
794f962b0dSLang Hames 
operator ()(ThreadSafeModule TSM,MaterializationResponsibility & R)804f962b0dSLang Hames   Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
814f962b0dSLang Hames                                         MaterializationResponsibility &R) {
824f962b0dSLang Hames     TSM.withModuleDo([this](Module &M) {
834f962b0dSLang Hames       dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
844f962b0dSLang Hames       PM->run(M);
854f962b0dSLang Hames       dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
864f962b0dSLang Hames     });
874f962b0dSLang Hames     return std::move(TSM);
884f962b0dSLang Hames   }
894f962b0dSLang Hames 
904f962b0dSLang Hames private:
914f962b0dSLang Hames   std::unique_ptr<legacy::PassManager> PM;
924f962b0dSLang Hames };
934f962b0dSLang Hames 
main(int argc,char * argv[])944f962b0dSLang Hames int main(int argc, char *argv[]) {
954f962b0dSLang Hames   // Initialize LLVM.
964f962b0dSLang Hames   InitLLVM X(argc, argv);
974f962b0dSLang Hames 
984f962b0dSLang Hames   InitializeNativeTarget();
994f962b0dSLang Hames   InitializeNativeTargetAsmPrinter();
1004f962b0dSLang Hames 
1014f962b0dSLang Hames   ExitOnErr.setBanner(std::string(argv[0]) + ": ");
1024f962b0dSLang Hames 
1034f962b0dSLang Hames   // (1) Create LLJIT instance.
1044f962b0dSLang Hames   auto J = ExitOnErr(LLJITBuilder().create());
1054f962b0dSLang Hames 
1064f962b0dSLang Hames   // (2) Install transform to optimize modules when they're materialized.
1074f962b0dSLang Hames   J->getIRTransformLayer().setTransform(MyOptimizationTransform());
1084f962b0dSLang Hames 
1094f962b0dSLang Hames   // (3) Add modules.
1104f962b0dSLang Hames   ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod"))));
1114f962b0dSLang Hames 
1124f962b0dSLang Hames   // (4) Look up the JIT'd function and call it.
113*16dcbb53SLang Hames   auto EntryAddr = ExitOnErr(J->lookup("entry"));
114*16dcbb53SLang Hames   auto *Entry = EntryAddr.toPtr<int()>();
1154f962b0dSLang Hames 
1164f962b0dSLang Hames   int Result = Entry();
1174f962b0dSLang Hames   outs() << "--- Result ---\n"
1184f962b0dSLang Hames          << "entry() = " << Result << "\n";
1194f962b0dSLang Hames 
1204f962b0dSLang Hames   return 0;
1214f962b0dSLang Hames }
122