//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//
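
// As a minimal illustration of the first point above, user code such as
//
//   int A = omp_get_thread_num();
//   ...
//   int B = omp_get_thread_num();
//
// can, when the two calls are provably equivalent, be rewritten to query the
// runtime once and reuse the cached result for both A and B.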

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool>
    DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
                           cl::desc("Disable function internalization."),
                           cl::Hidden, cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

enum class AddressSpace : unsigned {
  Generic = 0,
  Global = 1,
  Shared = 3,
  Constant = 4,
  Local = 5,
};
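
// Note: the numeric values above are assumed to follow the address-space
// numbering used by the GPU targets this pass is concerned with (for example,
// shared/LDS memory is address space 3 on both NVPTX and AMDGPU).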

struct AAHeapToShared;

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL Function corresponding to the override clause of this ICV
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }
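
    // A minimal usage sketch (the callback body is illustrative only): visit
    // all uses of this runtime function in the functions of an SCC and forget
    // those the callback handled.
    //
    //   RFI.foreachUse(SCCFunctions, [&](Use &U, Function &F) {
    //     if (auto *CI = dyn_cast<CallInst>(U.getUser())) {
    //       // ... inspect or rewrite the call ...
    //       return true; // forget this use afterwards
    //     }
    //     return false;  // keep the use
    //   });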

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

  public:
    /// Iterators for the uses of this runtime function.
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from function declarations/definitions to their runtime enum type.
  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      F->removeFnAttr(Attribute::NoInline);                                    \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;

  /// Collection of known OpenMP runtime functions.
  DenseSet<const Function *> RTLFunctions;
};

template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

private:
  /// A set to keep track of elements.
  SetVector<Ty> Set;

public:
  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  typename decltype(Set)::const_iterator end() const { return Set.end(); }
};
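
// A minimal usage sketch for the state above (the surrounding names are
// illustrative only): inserting an element records it and, because
// InsertInvalidates defaults to true, also calls indicatePessimisticFixpoint()
// on the underlying BooleanState.
//
//   BooleanStateWithSetVector<Function *> ReachedFns;
//   ReachedFns.insert(&Callee);
//   for (Function *Fn : ReachedFns)
//     /* visit Fn */;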

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;

struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track the parallel level of the associated
  /// function. We will give up tracking if we encounter an unknown caller or
  /// if the caller is __kmpc_parallel_51.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Abstract State interface
  ///{

  KernelInfoState() {}
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    return true;
  }
  /// Return the best possible (fully optimistic) state.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return the worst possible (fully pessimistic) state.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        indicatePessimisticFixpoint();
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        indicatePessimisticFixpoint();
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }

  ///}
};
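
// A minimal sketch of the clamp semantics above (the call sites are
// illustrative): merging a state that carries a different __kmpc_target_init
// call site than the current one drives the result to a pessimistic fixpoint.
//
//   KernelInfoState A = KernelInfoState::getBestState();
//   A.KernelInitCB = InitCallA; // hypothetical call site
//   KernelInfoState B;
//   B.KernelInitCB = InitCallB; // a different hypothetical call site
//   A ^= B;                     // conflicting init calls -> pessimistic state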

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //  BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};
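
// A minimal usage sketch (RuntimeCall and BasePtrsAlloca are hypothetical
// names): given an offloading runtime call, e.g. from the __tgt_target_data_*
// family, and the alloca passed as its base-pointers argument, the values
// last stored into that array before the call can be inspected as follows.
//
//   OffloadArray BasePtrs;
//   if (BasePtrs.initialize(*BasePtrsAlloca, *RuntimeCall))
//     for (Value *V : BasePtrs.StoredValues)
//       /* examine the mapped value */;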
692e8d8bef9SDimitry Andric 
6935ffd83dbSDimitry Andric struct OpenMPOpt {
6945ffd83dbSDimitry Andric 
6955ffd83dbSDimitry Andric   using OptimizationRemarkGetter =
6965ffd83dbSDimitry Andric       function_ref<OptimizationRemarkEmitter &(Function *)>;
6975ffd83dbSDimitry Andric 
6985ffd83dbSDimitry Andric   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
6995ffd83dbSDimitry Andric             OptimizationRemarkGetter OREGetter,
7005ffd83dbSDimitry Andric             OMPInformationCache &OMPInfoCache, Attributor &A)
7015ffd83dbSDimitry Andric       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
7025ffd83dbSDimitry Andric         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
7035ffd83dbSDimitry Andric 
704e8d8bef9SDimitry Andric   /// Check if any remarks are enabled for openmp-opt
705e8d8bef9SDimitry Andric   bool remarksEnabled() {
706e8d8bef9SDimitry Andric     auto &Ctx = M.getContext();
707e8d8bef9SDimitry Andric     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
708e8d8bef9SDimitry Andric   }
709e8d8bef9SDimitry Andric 
7105ffd83dbSDimitry Andric   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
711*fe6060f1SDimitry Andric   bool run(bool IsModulePass) {
7125ffd83dbSDimitry Andric     if (SCC.empty())
7135ffd83dbSDimitry Andric       return false;
7145ffd83dbSDimitry Andric 
7155ffd83dbSDimitry Andric     bool Changed = false;
7165ffd83dbSDimitry Andric 
7175ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
7185ffd83dbSDimitry Andric                       << " functions in a slice with "
7195ffd83dbSDimitry Andric                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
7205ffd83dbSDimitry Andric 
721*fe6060f1SDimitry Andric     if (IsModulePass) {
722*fe6060f1SDimitry Andric       Changed |= runAttributor(IsModulePass);
723*fe6060f1SDimitry Andric 
724*fe6060f1SDimitry Andric       // Recollect uses, in case Attributor deleted any.
725*fe6060f1SDimitry Andric       OMPInfoCache.recollectUses();
726*fe6060f1SDimitry Andric 
727*fe6060f1SDimitry Andric       // TODO: This should be folded into buildCustomStateMachine.
728*fe6060f1SDimitry Andric       Changed |= rewriteDeviceCodeStateMachine();
729*fe6060f1SDimitry Andric 
730*fe6060f1SDimitry Andric       if (remarksEnabled())
731*fe6060f1SDimitry Andric         analysisGlobalization();
732*fe6060f1SDimitry Andric     } else {
7335ffd83dbSDimitry Andric       if (PrintICVValues)
7345ffd83dbSDimitry Andric         printICVs();
7355ffd83dbSDimitry Andric       if (PrintOpenMPKernels)
7365ffd83dbSDimitry Andric         printKernels();
7375ffd83dbSDimitry Andric 
738*fe6060f1SDimitry Andric       Changed |= runAttributor(IsModulePass);
7395ffd83dbSDimitry Andric 
7405ffd83dbSDimitry Andric       // Recollect uses, in case Attributor deleted any.
7415ffd83dbSDimitry Andric       OMPInfoCache.recollectUses();
7425ffd83dbSDimitry Andric 
7435ffd83dbSDimitry Andric       Changed |= deleteParallelRegions();
744*fe6060f1SDimitry Andric 
745e8d8bef9SDimitry Andric       if (HideMemoryTransferLatency)
746e8d8bef9SDimitry Andric         Changed |= hideMemTransfersLatency();
747e8d8bef9SDimitry Andric       Changed |= deduplicateRuntimeCalls();
748e8d8bef9SDimitry Andric       if (EnableParallelRegionMerging) {
749e8d8bef9SDimitry Andric         if (mergeParallelRegions()) {
750e8d8bef9SDimitry Andric           deduplicateRuntimeCalls();
751e8d8bef9SDimitry Andric           Changed = true;
752e8d8bef9SDimitry Andric         }
753e8d8bef9SDimitry Andric       }
754*fe6060f1SDimitry Andric     }
7555ffd83dbSDimitry Andric 
7565ffd83dbSDimitry Andric     return Changed;
7575ffd83dbSDimitry Andric   }
7585ffd83dbSDimitry Andric 
7595ffd83dbSDimitry Andric   /// Print initial ICV values for testing.
7605ffd83dbSDimitry Andric   /// FIXME: This should be done from the Attributor once it is added.
7615ffd83dbSDimitry Andric   void printICVs() const {
762e8d8bef9SDimitry Andric     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
763e8d8bef9SDimitry Andric                                  ICV_proc_bind};
7645ffd83dbSDimitry Andric 
7655ffd83dbSDimitry Andric     for (Function *F : OMPInfoCache.ModuleSlice) {
7665ffd83dbSDimitry Andric       for (auto ICV : ICVs) {
7675ffd83dbSDimitry Andric         auto ICVInfo = OMPInfoCache.ICVs[ICV];
768*fe6060f1SDimitry Andric         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
769*fe6060f1SDimitry Andric           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
7705ffd83dbSDimitry Andric                      << " Value: "
7715ffd83dbSDimitry Andric                      << (ICVInfo.InitValue
772*fe6060f1SDimitry Andric                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
7735ffd83dbSDimitry Andric                              : "IMPLEMENTATION_DEFINED");
7745ffd83dbSDimitry Andric         };
7755ffd83dbSDimitry Andric 
776*fe6060f1SDimitry Andric         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
7775ffd83dbSDimitry Andric       }
7785ffd83dbSDimitry Andric     }
7795ffd83dbSDimitry Andric   }
7805ffd83dbSDimitry Andric 
7815ffd83dbSDimitry Andric   /// Print OpenMP GPU kernels for testing.
7825ffd83dbSDimitry Andric   void printKernels() const {
7835ffd83dbSDimitry Andric     for (Function *F : SCC) {
7845ffd83dbSDimitry Andric       if (!OMPInfoCache.Kernels.count(F))
7855ffd83dbSDimitry Andric         continue;
7865ffd83dbSDimitry Andric 
787*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
788*fe6060f1SDimitry Andric         return ORA << "OpenMP GPU kernel "
7895ffd83dbSDimitry Andric                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
7905ffd83dbSDimitry Andric       };
7915ffd83dbSDimitry Andric 
792*fe6060f1SDimitry Andric       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
7935ffd83dbSDimitry Andric     }
7945ffd83dbSDimitry Andric   }
7955ffd83dbSDimitry Andric 
7965ffd83dbSDimitry Andric   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
7975ffd83dbSDimitry Andric   /// given it has to be the callee or a nullptr is returned.
7985ffd83dbSDimitry Andric   static CallInst *getCallIfRegularCall(
7995ffd83dbSDimitry Andric       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
8005ffd83dbSDimitry Andric     CallInst *CI = dyn_cast<CallInst>(U.getUser());
8015ffd83dbSDimitry Andric     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
802*fe6060f1SDimitry Andric         (!RFI ||
803*fe6060f1SDimitry Andric          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
8045ffd83dbSDimitry Andric       return CI;
8055ffd83dbSDimitry Andric     return nullptr;
8065ffd83dbSDimitry Andric   }
8075ffd83dbSDimitry Andric 
8085ffd83dbSDimitry Andric   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
8095ffd83dbSDimitry Andric   /// the callee or a nullptr is returned.
8105ffd83dbSDimitry Andric   static CallInst *getCallIfRegularCall(
8115ffd83dbSDimitry Andric       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
8125ffd83dbSDimitry Andric     CallInst *CI = dyn_cast<CallInst>(&V);
8135ffd83dbSDimitry Andric     if (CI && !CI->hasOperandBundles() &&
814*fe6060f1SDimitry Andric         (!RFI ||
815*fe6060f1SDimitry Andric          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
8165ffd83dbSDimitry Andric       return CI;
8175ffd83dbSDimitry Andric     return nullptr;
8185ffd83dbSDimitry Andric   }
8195ffd83dbSDimitry Andric 
8205ffd83dbSDimitry Andric private:
821e8d8bef9SDimitry Andric   /// Merge parallel regions when it is safe.
822e8d8bef9SDimitry Andric   bool mergeParallelRegions() {
823e8d8bef9SDimitry Andric     const unsigned CallbackCalleeOperand = 2;
824e8d8bef9SDimitry Andric     const unsigned CallbackFirstArgOperand = 3;
825e8d8bef9SDimitry Andric     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
826e8d8bef9SDimitry Andric 
827e8d8bef9SDimitry Andric     // Check if there are any __kmpc_fork_call calls to merge.
828e8d8bef9SDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &RFI =
829e8d8bef9SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
830e8d8bef9SDimitry Andric 
831e8d8bef9SDimitry Andric     if (!RFI.Declaration)
832e8d8bef9SDimitry Andric       return false;
833e8d8bef9SDimitry Andric 
834e8d8bef9SDimitry Andric     // Unmergable calls that prevent merging a parallel region.
835e8d8bef9SDimitry Andric     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
836e8d8bef9SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
837e8d8bef9SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
838e8d8bef9SDimitry Andric     };
839e8d8bef9SDimitry Andric 
840e8d8bef9SDimitry Andric     bool Changed = false;
841e8d8bef9SDimitry Andric     LoopInfo *LI = nullptr;
842e8d8bef9SDimitry Andric     DominatorTree *DT = nullptr;
843e8d8bef9SDimitry Andric 
844e8d8bef9SDimitry Andric     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
845e8d8bef9SDimitry Andric 
846e8d8bef9SDimitry Andric     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
847e8d8bef9SDimitry Andric     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
848e8d8bef9SDimitry Andric                          BasicBlock &ContinuationIP) {
849e8d8bef9SDimitry Andric       BasicBlock *CGStartBB = CodeGenIP.getBlock();
850e8d8bef9SDimitry Andric       BasicBlock *CGEndBB =
851e8d8bef9SDimitry Andric           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
852e8d8bef9SDimitry Andric       assert(StartBB != nullptr && "StartBB should not be null");
853e8d8bef9SDimitry Andric       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
854e8d8bef9SDimitry Andric       assert(EndBB != nullptr && "EndBB should not be null");
855e8d8bef9SDimitry Andric       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
856e8d8bef9SDimitry Andric     };
857e8d8bef9SDimitry Andric 
858e8d8bef9SDimitry Andric     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
859e8d8bef9SDimitry Andric                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
860e8d8bef9SDimitry Andric       ReplacementValue = &Inner;
861e8d8bef9SDimitry Andric       return CodeGenIP;
862e8d8bef9SDimitry Andric     };
863e8d8bef9SDimitry Andric 
864e8d8bef9SDimitry Andric     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
865e8d8bef9SDimitry Andric 
866e8d8bef9SDimitry Andric     /// Create a sequential execution region within a merged parallel region,
867e8d8bef9SDimitry Andric     /// encapsulated in a master construct with a barrier for synchronization.
868e8d8bef9SDimitry Andric     auto CreateSequentialRegion = [&](Function *OuterFn,
869e8d8bef9SDimitry Andric                                       BasicBlock *OuterPredBB,
870e8d8bef9SDimitry Andric                                       Instruction *SeqStartI,
871e8d8bef9SDimitry Andric                                       Instruction *SeqEndI) {
872e8d8bef9SDimitry Andric       // Isolate the instructions of the sequential region to a separate
873e8d8bef9SDimitry Andric       // block.
874e8d8bef9SDimitry Andric       BasicBlock *ParentBB = SeqStartI->getParent();
875e8d8bef9SDimitry Andric       BasicBlock *SeqEndBB =
876e8d8bef9SDimitry Andric           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
877e8d8bef9SDimitry Andric       BasicBlock *SeqAfterBB =
878e8d8bef9SDimitry Andric           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
879e8d8bef9SDimitry Andric       BasicBlock *SeqStartBB =
880e8d8bef9SDimitry Andric           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
881e8d8bef9SDimitry Andric 
882e8d8bef9SDimitry Andric       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
883e8d8bef9SDimitry Andric              "Expected a different CFG");
884e8d8bef9SDimitry Andric       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
885e8d8bef9SDimitry Andric       ParentBB->getTerminator()->eraseFromParent();
886e8d8bef9SDimitry Andric 
887e8d8bef9SDimitry Andric       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
888e8d8bef9SDimitry Andric                            BasicBlock &ContinuationIP) {
889e8d8bef9SDimitry Andric         BasicBlock *CGStartBB = CodeGenIP.getBlock();
890e8d8bef9SDimitry Andric         BasicBlock *CGEndBB =
891e8d8bef9SDimitry Andric             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
892e8d8bef9SDimitry Andric         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
893e8d8bef9SDimitry Andric         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
894e8d8bef9SDimitry Andric         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
895e8d8bef9SDimitry Andric         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
896e8d8bef9SDimitry Andric       };
897e8d8bef9SDimitry Andric       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
898e8d8bef9SDimitry Andric 
899e8d8bef9SDimitry Andric       // Find outputs from the sequential region to outside users and
900e8d8bef9SDimitry Andric       // broadcast their values to them.
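      // Illustrative sketch (hypothetical value names, not from the original
      // code; the ".seq.output" suffixes mirror the names used below): a value
      // %v defined in the sequential region but used after it is broadcast
      // through memory, roughly as
      //
      //   %v.seq.output.alloc = alloca ...       ; in the outer function entry
      //   store %v, %v.seq.output.alloc          ; in the sequential block
      //   %v.seq.output.load = load ...          ; at each outside user
      //
      // so users outside the sequential block observe the value via memory
      // rather than a direct SSA use.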
901e8d8bef9SDimitry Andric       for (Instruction &I : *SeqStartBB) {
902e8d8bef9SDimitry Andric         SmallPtrSet<Instruction *, 4> OutsideUsers;
903e8d8bef9SDimitry Andric         for (User *Usr : I.users()) {
904e8d8bef9SDimitry Andric           Instruction &UsrI = *cast<Instruction>(Usr);
905e8d8bef9SDimitry Andric           // Ignore outputs to lifetime intrinsics; code extraction for the
906e8d8bef9SDimitry Andric           // merged parallel region will fix them.

907e8d8bef9SDimitry Andric           if (UsrI.isLifetimeStartOrEnd())
908e8d8bef9SDimitry Andric             continue;
909e8d8bef9SDimitry Andric 
910e8d8bef9SDimitry Andric           if (UsrI.getParent() != SeqStartBB)
911e8d8bef9SDimitry Andric             OutsideUsers.insert(&UsrI);
912e8d8bef9SDimitry Andric         }
913e8d8bef9SDimitry Andric 
914e8d8bef9SDimitry Andric         if (OutsideUsers.empty())
915e8d8bef9SDimitry Andric           continue;
916e8d8bef9SDimitry Andric 
917e8d8bef9SDimitry Andric         // Emit an alloca in the outer region to store the broadcasted
918e8d8bef9SDimitry Andric         // value.
919e8d8bef9SDimitry Andric         const DataLayout &DL = M.getDataLayout();
920e8d8bef9SDimitry Andric         AllocaInst *AllocaI = new AllocaInst(
921e8d8bef9SDimitry Andric             I.getType(), DL.getAllocaAddrSpace(), nullptr,
922e8d8bef9SDimitry Andric             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
923e8d8bef9SDimitry Andric 
924e8d8bef9SDimitry Andric         // Emit a store instruction in the sequential BB to update the
925e8d8bef9SDimitry Andric         // value.
926e8d8bef9SDimitry Andric         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
927e8d8bef9SDimitry Andric 
928e8d8bef9SDimitry Andric         // Emit a load instruction and replace the use of the output value
929e8d8bef9SDimitry Andric         // with it.
930e8d8bef9SDimitry Andric         for (Instruction *UsrI : OutsideUsers) {
931*fe6060f1SDimitry Andric           LoadInst *LoadI = new LoadInst(
932*fe6060f1SDimitry Andric               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
933e8d8bef9SDimitry Andric           UsrI->replaceUsesOfWith(&I, LoadI);
934e8d8bef9SDimitry Andric         }
935e8d8bef9SDimitry Andric       }
936e8d8bef9SDimitry Andric 
937e8d8bef9SDimitry Andric       OpenMPIRBuilder::LocationDescription Loc(
938e8d8bef9SDimitry Andric           InsertPointTy(ParentBB, ParentBB->end()), DL);
939e8d8bef9SDimitry Andric       InsertPointTy SeqAfterIP =
940e8d8bef9SDimitry Andric           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
941e8d8bef9SDimitry Andric 
942e8d8bef9SDimitry Andric       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
943e8d8bef9SDimitry Andric 
944e8d8bef9SDimitry Andric       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
945e8d8bef9SDimitry Andric 
946e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
947e8d8bef9SDimitry Andric                         << "\n");
948e8d8bef9SDimitry Andric     };
949e8d8bef9SDimitry Andric 
950e8d8bef9SDimitry Andric     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
951e8d8bef9SDimitry Andric     // contained in BB and only separated by instructions that can be
952e8d8bef9SDimitry Andric     // redundantly executed in parallel. The block BB is split before the first
953e8d8bef9SDimitry Andric     // call (in MergableCIs) and after the last so the entire region we merge
954e8d8bef9SDimitry Andric     // into a single parallel region is contained in a single basic block
955e8d8bef9SDimitry Andric     // without any other instructions. We use the OpenMPIRBuilder to outline
956e8d8bef9SDimitry Andric     // that block and call the resulting function via __kmpc_fork_call.
957e8d8bef9SDimitry Andric     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
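      // Illustrative sketch (not from the original source): two adjacent fork
      // calls in BB such as
      //
      //   call @__kmpc_fork_call(ident, n, @cb0, ...)
      //   call @__kmpc_fork_call(ident, n, @cb1, ...)
      //
      // become, roughly, a single fork call of a newly outlined function
      // whose body calls the callbacks directly and separates them with the
      // explicit barrier emitted further below:
      //
      //   call @__kmpc_fork_call(ident, n, @outlined, ...)
      //   ; where @outlined calls @cb0, then __kmpc_barrier, then @cb1.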
958e8d8bef9SDimitry Andric       // TODO: Change the interface to allow single CIs expanded, e.g, to
959e8d8bef9SDimitry Andric       // include an outer loop.
960e8d8bef9SDimitry Andric       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
961e8d8bef9SDimitry Andric 
962e8d8bef9SDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
963*fe6060f1SDimitry Andric         OR << "Parallel region merged with parallel region"
964*fe6060f1SDimitry Andric            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
965e8d8bef9SDimitry Andric         for (auto *CI : llvm::drop_begin(MergableCIs)) {
966e8d8bef9SDimitry Andric           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
967e8d8bef9SDimitry Andric           if (CI != MergableCIs.back())
968e8d8bef9SDimitry Andric             OR << ", ";
969e8d8bef9SDimitry Andric         }
970*fe6060f1SDimitry Andric         return OR << ".";
971e8d8bef9SDimitry Andric       };
972e8d8bef9SDimitry Andric 
973*fe6060f1SDimitry Andric       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
974e8d8bef9SDimitry Andric 
975e8d8bef9SDimitry Andric       Function *OriginalFn = BB->getParent();
976e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
977e8d8bef9SDimitry Andric                         << " parallel regions in " << OriginalFn->getName()
978e8d8bef9SDimitry Andric                         << "\n");
979e8d8bef9SDimitry Andric 
980e8d8bef9SDimitry Andric       // Isolate the calls to merge in a separate block.
981e8d8bef9SDimitry Andric       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
982e8d8bef9SDimitry Andric       BasicBlock *AfterBB =
983e8d8bef9SDimitry Andric           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
984e8d8bef9SDimitry Andric       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
985e8d8bef9SDimitry Andric                            "omp.par.merged");
986e8d8bef9SDimitry Andric 
987e8d8bef9SDimitry Andric       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
988e8d8bef9SDimitry Andric       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
989e8d8bef9SDimitry Andric       BB->getTerminator()->eraseFromParent();
990e8d8bef9SDimitry Andric 
991e8d8bef9SDimitry Andric       // Create sequential regions for sequential instructions that are
992e8d8bef9SDimitry Andric       // in-between mergable parallel regions.
993e8d8bef9SDimitry Andric       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
994e8d8bef9SDimitry Andric            It != End; ++It) {
995e8d8bef9SDimitry Andric         Instruction *ForkCI = *It;
996e8d8bef9SDimitry Andric         Instruction *NextForkCI = *(It + 1);
997e8d8bef9SDimitry Andric 
998e8d8bef9SDimitry Andric         // Continue if there are no in-between instructions.
999e8d8bef9SDimitry Andric         if (ForkCI->getNextNode() == NextForkCI)
1000e8d8bef9SDimitry Andric           continue;
1001e8d8bef9SDimitry Andric 
1002e8d8bef9SDimitry Andric         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1003e8d8bef9SDimitry Andric                                NextForkCI->getPrevNode());
1004e8d8bef9SDimitry Andric       }
1005e8d8bef9SDimitry Andric 
1006e8d8bef9SDimitry Andric       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1007e8d8bef9SDimitry Andric                                                DL);
1008e8d8bef9SDimitry Andric       IRBuilder<>::InsertPoint AllocaIP(
1009e8d8bef9SDimitry Andric           &OriginalFn->getEntryBlock(),
1010e8d8bef9SDimitry Andric           OriginalFn->getEntryBlock().getFirstInsertionPt());
1011e8d8bef9SDimitry Andric       // Create the merged parallel region with default proc binding, to
1012e8d8bef9SDimitry Andric       // avoid overriding binding settings, and without explicit cancellation.
1013e8d8bef9SDimitry Andric       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1014e8d8bef9SDimitry Andric           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1015e8d8bef9SDimitry Andric           OMP_PROC_BIND_default, /* IsCancellable */ false);
1016e8d8bef9SDimitry Andric       BranchInst::Create(AfterBB, AfterIP.getBlock());
1017e8d8bef9SDimitry Andric 
1018e8d8bef9SDimitry Andric       // Perform the actual outlining.
1019*fe6060f1SDimitry Andric       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1020*fe6060f1SDimitry Andric                                        /* AllowExtractorSinking */ true);
1021e8d8bef9SDimitry Andric 
1022e8d8bef9SDimitry Andric       Function *OutlinedFn = MergableCIs.front()->getCaller();
1023e8d8bef9SDimitry Andric 
1024e8d8bef9SDimitry Andric       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1025e8d8bef9SDimitry Andric       // callbacks.
1026e8d8bef9SDimitry Andric       SmallVector<Value *, 8> Args;
1027e8d8bef9SDimitry Andric       for (auto *CI : MergableCIs) {
1028e8d8bef9SDimitry Andric         Value *Callee =
1029e8d8bef9SDimitry Andric             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1030e8d8bef9SDimitry Andric         FunctionType *FT =
1031e8d8bef9SDimitry Andric             cast<FunctionType>(Callee->getType()->getPointerElementType());
1032e8d8bef9SDimitry Andric         Args.clear();
1033e8d8bef9SDimitry Andric         Args.push_back(OutlinedFn->getArg(0));
1034e8d8bef9SDimitry Andric         Args.push_back(OutlinedFn->getArg(1));
1035e8d8bef9SDimitry Andric         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1036e8d8bef9SDimitry Andric              U < E; ++U)
1037e8d8bef9SDimitry Andric           Args.push_back(CI->getArgOperand(U));
1038e8d8bef9SDimitry Andric 
1039e8d8bef9SDimitry Andric         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1040e8d8bef9SDimitry Andric         if (CI->getDebugLoc())
1041e8d8bef9SDimitry Andric           NewCI->setDebugLoc(CI->getDebugLoc());
1042e8d8bef9SDimitry Andric 
1043e8d8bef9SDimitry Andric         // Forward parameter attributes from the callback to the callee.
1044e8d8bef9SDimitry Andric         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1045e8d8bef9SDimitry Andric              U < E; ++U)
1046e8d8bef9SDimitry Andric           for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
1047e8d8bef9SDimitry Andric             NewCI->addParamAttr(
1048e8d8bef9SDimitry Andric                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1049e8d8bef9SDimitry Andric 
1050e8d8bef9SDimitry Andric         // Emit an explicit barrier to replace the implicit fork-join barrier.
1051e8d8bef9SDimitry Andric         if (CI != MergableCIs.back()) {
1052e8d8bef9SDimitry Andric           // TODO: Remove barrier if the merged parallel region includes the
1053e8d8bef9SDimitry Andric           // 'nowait' clause.
1054e8d8bef9SDimitry Andric           OMPInfoCache.OMPBuilder.createBarrier(
1055e8d8bef9SDimitry Andric               InsertPointTy(NewCI->getParent(),
1056e8d8bef9SDimitry Andric                             NewCI->getNextNode()->getIterator()),
1057e8d8bef9SDimitry Andric               OMPD_parallel);
1058e8d8bef9SDimitry Andric         }
1059e8d8bef9SDimitry Andric 
1060e8d8bef9SDimitry Andric         CI->eraseFromParent();
1061e8d8bef9SDimitry Andric       }
1062e8d8bef9SDimitry Andric 
1063e8d8bef9SDimitry Andric       assert(OutlinedFn != OriginalFn && "Outlining failed");
1064e8d8bef9SDimitry Andric       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1065e8d8bef9SDimitry Andric       CGUpdater.reanalyzeFunction(*OriginalFn);
1066e8d8bef9SDimitry Andric 
1067e8d8bef9SDimitry Andric       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1068e8d8bef9SDimitry Andric 
1069e8d8bef9SDimitry Andric       return true;
1070e8d8bef9SDimitry Andric     };
1071e8d8bef9SDimitry Andric 
1072e8d8bef9SDimitry Andric     // Helper function that identifies sequences of
1073e8d8bef9SDimitry Andric     // __kmpc_fork_call uses in a basic block.
1074e8d8bef9SDimitry Andric     auto DetectPRsCB = [&](Use &U, Function &F) {
1075e8d8bef9SDimitry Andric       CallInst *CI = getCallIfRegularCall(U, &RFI);
1076e8d8bef9SDimitry Andric       BB2PRMap[CI->getParent()].insert(CI);
1077e8d8bef9SDimitry Andric 
1078e8d8bef9SDimitry Andric       return false;
1079e8d8bef9SDimitry Andric     };
1080e8d8bef9SDimitry Andric 
1081e8d8bef9SDimitry Andric     BB2PRMap.clear();
1082e8d8bef9SDimitry Andric     RFI.foreachUse(SCC, DetectPRsCB);
1083e8d8bef9SDimitry Andric     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1084e8d8bef9SDimitry Andric     // Find mergable parallel regions within a basic block that are
1085e8d8bef9SDimitry Andric     // safe to merge, that is, any in-between instructions can safely
1086e8d8bef9SDimitry Andric     // execute in parallel after merging.
1087e8d8bef9SDimitry Andric     // TODO: support merging across basic-blocks.
1088e8d8bef9SDimitry Andric     for (auto &It : BB2PRMap) {
1089e8d8bef9SDimitry Andric       auto &CIs = It.getSecond();
1090e8d8bef9SDimitry Andric       if (CIs.size() < 2)
1091e8d8bef9SDimitry Andric         continue;
1092e8d8bef9SDimitry Andric 
1093e8d8bef9SDimitry Andric       BasicBlock *BB = It.getFirst();
1094e8d8bef9SDimitry Andric       SmallVector<CallInst *, 4> MergableCIs;
1095e8d8bef9SDimitry Andric 
1096e8d8bef9SDimitry Andric       /// Returns true if the instruction is mergable, false otherwise.
1097e8d8bef9SDimitry Andric       /// A terminator instruction is unmergable by definition since merging
1098e8d8bef9SDimitry Andric       /// works within a BB. Instructions before the mergable region are
1099e8d8bef9SDimitry Andric       /// mergable if they are not calls to OpenMP runtime functions that may
1100e8d8bef9SDimitry Andric       /// set different execution parameters for subsequent parallel regions.
1101e8d8bef9SDimitry Andric       /// Instructions in-between parallel regions are mergable if they are not
1102e8d8bef9SDimitry Andric       /// calls to any non-intrinsic function since that may call a non-mergable
1103e8d8bef9SDimitry Andric       /// OpenMP runtime function.
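      // For illustration (hypothetical example, not from the original
      // source): a plain arithmetic or memory instruction between two fork
      // calls does not prevent merging, while any non-intrinsic call, e.g.
      // `foo();`, does, because foo may reach an OpenMP runtime call.
      // Before the first fork call, indirect calls and calls to
      // __kmpc_push_proc_bind or __kmpc_push_num_threads (emitted for
      // proc_bind/num_threads clauses) prevent merging.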
1104e8d8bef9SDimitry Andric       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1105e8d8bef9SDimitry Andric         // We do not merge across BBs, hence return false (unmergable) if the
1106e8d8bef9SDimitry Andric         // instruction is a terminator.
1107e8d8bef9SDimitry Andric         if (I.isTerminator())
1108e8d8bef9SDimitry Andric           return false;
1109e8d8bef9SDimitry Andric 
1110e8d8bef9SDimitry Andric         if (!isa<CallInst>(&I))
1111e8d8bef9SDimitry Andric           return true;
1112e8d8bef9SDimitry Andric 
1113e8d8bef9SDimitry Andric         CallInst *CI = cast<CallInst>(&I);
1114e8d8bef9SDimitry Andric         if (IsBeforeMergableRegion) {
1115e8d8bef9SDimitry Andric           Function *CalledFunction = CI->getCalledFunction();
1116e8d8bef9SDimitry Andric           if (!CalledFunction)
1117e8d8bef9SDimitry Andric             return false;
1118e8d8bef9SDimitry Andric           // Return false (unmergable) if the call before the parallel
1119e8d8bef9SDimitry Andric           // region is to a compiler-generated function that sets an explicit
1120e8d8bef9SDimitry Andric           // affinity (proc_bind) or number of threads (num_threads). Those
1121e8d8bef9SDimitry Andric           // settings may be incompatible with subsequent parallel regions.
1122e8d8bef9SDimitry Andric           // TODO: ICV tracking to detect compatibility.
1123e8d8bef9SDimitry Andric           for (const auto &RFI : UnmergableCallsInfo) {
1124e8d8bef9SDimitry Andric             if (CalledFunction == RFI.Declaration)
1125e8d8bef9SDimitry Andric               return false;
1126e8d8bef9SDimitry Andric           }
1127e8d8bef9SDimitry Andric         } else {
1128e8d8bef9SDimitry Andric           // Return false (unmergable) if there is a call instruction
1129e8d8bef9SDimitry Andric           // in-between parallel regions when it is not an intrinsic. It
1130e8d8bef9SDimitry Andric           // may call an unmergable OpenMP runtime function in its callpath.
1131e8d8bef9SDimitry Andric           // TODO: Keep track of possible OpenMP calls in the callpath.
1132e8d8bef9SDimitry Andric           if (!isa<IntrinsicInst>(CI))
1133e8d8bef9SDimitry Andric             return false;
1134e8d8bef9SDimitry Andric         }
1135e8d8bef9SDimitry Andric 
1136e8d8bef9SDimitry Andric         return true;
1137e8d8bef9SDimitry Andric       };
1138e8d8bef9SDimitry Andric       // Find maximal number of parallel region CIs that are safe to merge.
1139e8d8bef9SDimitry Andric       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1140e8d8bef9SDimitry Andric         Instruction &I = *It;
1141e8d8bef9SDimitry Andric         ++It;
1142e8d8bef9SDimitry Andric 
1143e8d8bef9SDimitry Andric         if (CIs.count(&I)) {
1144e8d8bef9SDimitry Andric           MergableCIs.push_back(cast<CallInst>(&I));
1145e8d8bef9SDimitry Andric           continue;
1146e8d8bef9SDimitry Andric         }
1147e8d8bef9SDimitry Andric 
1148e8d8bef9SDimitry Andric         // Continue expanding if the instruction is mergable.
1149e8d8bef9SDimitry Andric         if (IsMergable(I, MergableCIs.empty()))
1150e8d8bef9SDimitry Andric           continue;
1151e8d8bef9SDimitry Andric 
1152e8d8bef9SDimitry Andric         // Forward the instruction iterator to skip the next parallel region
1153e8d8bef9SDimitry Andric         // since there is an unmergable instruction which can affect it.
1154e8d8bef9SDimitry Andric         for (; It != End; ++It) {
1155e8d8bef9SDimitry Andric           Instruction &SkipI = *It;
1156e8d8bef9SDimitry Andric           if (CIs.count(&SkipI)) {
1157e8d8bef9SDimitry Andric             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1158e8d8bef9SDimitry Andric                               << " due to " << I << "\n");
1159e8d8bef9SDimitry Andric             ++It;
1160e8d8bef9SDimitry Andric             break;
1161e8d8bef9SDimitry Andric           }
1162e8d8bef9SDimitry Andric         }
1163e8d8bef9SDimitry Andric 
1164e8d8bef9SDimitry Andric         // Store mergable regions found.
1165e8d8bef9SDimitry Andric         if (MergableCIs.size() > 1) {
1166e8d8bef9SDimitry Andric           MergableCIsVector.push_back(MergableCIs);
1167e8d8bef9SDimitry Andric           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1168e8d8bef9SDimitry Andric                             << " parallel regions in block " << BB->getName()
1169e8d8bef9SDimitry Andric                             << " of function " << BB->getParent()->getName()
1170e8d8bef9SDimitry Andric                             << "\n";);
1171e8d8bef9SDimitry Andric         }
1172e8d8bef9SDimitry Andric 
1173e8d8bef9SDimitry Andric         MergableCIs.clear();
1174e8d8bef9SDimitry Andric       }
1175e8d8bef9SDimitry Andric 
1176e8d8bef9SDimitry Andric       if (!MergableCIsVector.empty()) {
1177e8d8bef9SDimitry Andric         Changed = true;
1178e8d8bef9SDimitry Andric 
1179e8d8bef9SDimitry Andric         for (auto &MergableCIs : MergableCIsVector)
1180e8d8bef9SDimitry Andric           Merge(MergableCIs, BB);
1181*fe6060f1SDimitry Andric         MergableCIsVector.clear();
1182e8d8bef9SDimitry Andric       }
1183e8d8bef9SDimitry Andric     }
1184e8d8bef9SDimitry Andric 
1185e8d8bef9SDimitry Andric     if (Changed) {
1186e8d8bef9SDimitry Andric       /// Re-collect uses for fork calls, emitted barrier calls, and
1187e8d8bef9SDimitry Andric       /// any emitted master/end_master calls.
1188e8d8bef9SDimitry Andric       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1189e8d8bef9SDimitry Andric       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1190e8d8bef9SDimitry Andric       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1191e8d8bef9SDimitry Andric       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1192e8d8bef9SDimitry Andric     }
1193e8d8bef9SDimitry Andric 
1194e8d8bef9SDimitry Andric     return Changed;
1195e8d8bef9SDimitry Andric   }
1196e8d8bef9SDimitry Andric 
11975ffd83dbSDimitry Andric   /// Try to delete parallel regions if possible.
11985ffd83dbSDimitry Andric   bool deleteParallelRegions() {
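    // Illustrative sketch (hypothetical source, not from the original code):
    // a parallel region whose outlined body only reads memory and is
    // guaranteed to return, e.g.
    //
    //   #pragma omp parallel
    //   { int tmp = A[0]; (void)tmp; }
    //
    // has no observable side effects, so the corresponding
    // __kmpc_fork_call can be deleted entirely.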
11995ffd83dbSDimitry Andric     const unsigned CallbackCalleeOperand = 2;
12005ffd83dbSDimitry Andric 
12015ffd83dbSDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &RFI =
12025ffd83dbSDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
12035ffd83dbSDimitry Andric 
12045ffd83dbSDimitry Andric     if (!RFI.Declaration)
12055ffd83dbSDimitry Andric       return false;
12065ffd83dbSDimitry Andric 
12075ffd83dbSDimitry Andric     bool Changed = false;
12085ffd83dbSDimitry Andric     auto DeleteCallCB = [&](Use &U, Function &) {
12095ffd83dbSDimitry Andric       CallInst *CI = getCallIfRegularCall(U);
12105ffd83dbSDimitry Andric       if (!CI)
12115ffd83dbSDimitry Andric         return false;
12125ffd83dbSDimitry Andric       auto *Fn = dyn_cast<Function>(
12135ffd83dbSDimitry Andric           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
12145ffd83dbSDimitry Andric       if (!Fn)
12155ffd83dbSDimitry Andric         return false;
12165ffd83dbSDimitry Andric       if (!Fn->onlyReadsMemory())
12175ffd83dbSDimitry Andric         return false;
12185ffd83dbSDimitry Andric       if (!Fn->hasFnAttribute(Attribute::WillReturn))
12195ffd83dbSDimitry Andric         return false;
12205ffd83dbSDimitry Andric 
12215ffd83dbSDimitry Andric       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
12225ffd83dbSDimitry Andric                         << CI->getCaller()->getName() << "\n");
12235ffd83dbSDimitry Andric 
12245ffd83dbSDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
1225*fe6060f1SDimitry Andric         return OR << "Removing parallel region with no side-effects.";
12265ffd83dbSDimitry Andric       };
1227*fe6060f1SDimitry Andric       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
12285ffd83dbSDimitry Andric 
12295ffd83dbSDimitry Andric       CGUpdater.removeCallSite(*CI);
12305ffd83dbSDimitry Andric       CI->eraseFromParent();
12315ffd83dbSDimitry Andric       Changed = true;
12325ffd83dbSDimitry Andric       ++NumOpenMPParallelRegionsDeleted;
12335ffd83dbSDimitry Andric       return true;
12345ffd83dbSDimitry Andric     };
12355ffd83dbSDimitry Andric 
12365ffd83dbSDimitry Andric     RFI.foreachUse(SCC, DeleteCallCB);
12375ffd83dbSDimitry Andric 
12385ffd83dbSDimitry Andric     return Changed;
12395ffd83dbSDimitry Andric   }
12405ffd83dbSDimitry Andric 
12415ffd83dbSDimitry Andric   /// Try to eliminate runtime calls by reusing existing ones.
12425ffd83dbSDimitry Andric   bool deduplicateRuntimeCalls() {
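    // Illustrative sketch (hypothetical source, not from the original code):
    // repeated calls to a deduplicable runtime function in one function, e.g.
    //
    //   int a = omp_get_thread_limit();
    //   ...
    //   int b = omp_get_thread_limit();
    //
    // return the same value, so the second call can be removed and its uses
    // can reuse the result of the first.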
12435ffd83dbSDimitry Andric     bool Changed = false;
12445ffd83dbSDimitry Andric 
12455ffd83dbSDimitry Andric     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
12465ffd83dbSDimitry Andric         OMPRTL_omp_get_num_threads,
12475ffd83dbSDimitry Andric         OMPRTL_omp_in_parallel,
12485ffd83dbSDimitry Andric         OMPRTL_omp_get_cancellation,
12495ffd83dbSDimitry Andric         OMPRTL_omp_get_thread_limit,
12505ffd83dbSDimitry Andric         OMPRTL_omp_get_supported_active_levels,
12515ffd83dbSDimitry Andric         OMPRTL_omp_get_level,
12525ffd83dbSDimitry Andric         OMPRTL_omp_get_ancestor_thread_num,
12535ffd83dbSDimitry Andric         OMPRTL_omp_get_team_size,
12545ffd83dbSDimitry Andric         OMPRTL_omp_get_active_level,
12555ffd83dbSDimitry Andric         OMPRTL_omp_in_final,
12565ffd83dbSDimitry Andric         OMPRTL_omp_get_proc_bind,
12575ffd83dbSDimitry Andric         OMPRTL_omp_get_num_places,
12585ffd83dbSDimitry Andric         OMPRTL_omp_get_num_procs,
12595ffd83dbSDimitry Andric         OMPRTL_omp_get_place_num,
12605ffd83dbSDimitry Andric         OMPRTL_omp_get_partition_num_places,
12615ffd83dbSDimitry Andric         OMPRTL_omp_get_partition_place_nums};
12625ffd83dbSDimitry Andric 
12635ffd83dbSDimitry Andric     // Global-tid is handled separately.
12645ffd83dbSDimitry Andric     SmallSetVector<Value *, 16> GTIdArgs;
12655ffd83dbSDimitry Andric     collectGlobalThreadIdArguments(GTIdArgs);
12665ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
12675ffd83dbSDimitry Andric                       << " global thread ID arguments\n");
12685ffd83dbSDimitry Andric 
12695ffd83dbSDimitry Andric     for (Function *F : SCC) {
12705ffd83dbSDimitry Andric       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1271e8d8bef9SDimitry Andric         Changed |= deduplicateRuntimeCalls(
1272e8d8bef9SDimitry Andric             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
12735ffd83dbSDimitry Andric 
12745ffd83dbSDimitry Andric       // __kmpc_global_thread_num is special as we can replace it with an
12755ffd83dbSDimitry Andric       // argument in enough cases to make it worth trying.
12765ffd83dbSDimitry Andric       Value *GTIdArg = nullptr;
12775ffd83dbSDimitry Andric       for (Argument &Arg : F->args())
12785ffd83dbSDimitry Andric         if (GTIdArgs.count(&Arg)) {
12795ffd83dbSDimitry Andric           GTIdArg = &Arg;
12805ffd83dbSDimitry Andric           break;
12815ffd83dbSDimitry Andric         }
12825ffd83dbSDimitry Andric       Changed |= deduplicateRuntimeCalls(
12835ffd83dbSDimitry Andric           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
12845ffd83dbSDimitry Andric     }
12855ffd83dbSDimitry Andric 
12865ffd83dbSDimitry Andric     return Changed;
12875ffd83dbSDimitry Andric   }
12885ffd83dbSDimitry Andric 
1289e8d8bef9SDimitry Andric   /// Tries to hide the latency of runtime calls that involve host to
1290e8d8bef9SDimitry Andric   /// device memory transfers by splitting them into their "issue" and "wait"
1291e8d8bef9SDimitry Andric   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1292e8d8bef9SDimitry Andric   /// moved downwards as much as possible. The "issue" starts the memory
1293e8d8bef9SDimitry Andric   /// transfer asynchronously, returning a handle. The "wait" waits on the
1294e8d8bef9SDimitry Andric   /// returned handle for the memory transfer to finish.
1295e8d8bef9SDimitry Andric   bool hideMemTransfersLatency() {
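    // Illustrative sketch (hypothetical source, not from the original code):
    // for a mapping such as
    //
    //   #pragma omp target data map(to: A[0:N])
    //   {
    //     independent_work();     // does not touch A
    //     use_A_on_device();
    //   }
    //
    // the __tgt_target_data_begin_mapper call that copies A to the device is
    // split so the copy is issued before independent_work() and only waited
    // on right before A is needed, overlapping the transfer with host work.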
1296e8d8bef9SDimitry Andric     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1297e8d8bef9SDimitry Andric     bool Changed = false;
1298e8d8bef9SDimitry Andric     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1299e8d8bef9SDimitry Andric       auto *RTCall = getCallIfRegularCall(U, &RFI);
1300e8d8bef9SDimitry Andric       if (!RTCall)
1301e8d8bef9SDimitry Andric         return false;
1302e8d8bef9SDimitry Andric 
1303e8d8bef9SDimitry Andric       OffloadArray OffloadArrays[3];
1304e8d8bef9SDimitry Andric       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1305e8d8bef9SDimitry Andric         return false;
1306e8d8bef9SDimitry Andric 
1307e8d8bef9SDimitry Andric       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1308e8d8bef9SDimitry Andric 
1309e8d8bef9SDimitry Andric       // TODO: Check if can be moved upwards.
1310e8d8bef9SDimitry Andric       bool WasSplit = false;
1311e8d8bef9SDimitry Andric       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1312e8d8bef9SDimitry Andric       if (WaitMovementPoint)
1313e8d8bef9SDimitry Andric         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1314e8d8bef9SDimitry Andric 
1315e8d8bef9SDimitry Andric       Changed |= WasSplit;
1316e8d8bef9SDimitry Andric       return WasSplit;
1317e8d8bef9SDimitry Andric     };
1318e8d8bef9SDimitry Andric     RFI.foreachUse(SCC, SplitMemTransfers);
1319e8d8bef9SDimitry Andric 
1320e8d8bef9SDimitry Andric     return Changed;
1321e8d8bef9SDimitry Andric   }
1322e8d8bef9SDimitry Andric 
1323e8d8bef9SDimitry Andric   void analysisGlobalization() {
1324*fe6060f1SDimitry Andric     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1325e8d8bef9SDimitry Andric 
1326e8d8bef9SDimitry Andric     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1327e8d8bef9SDimitry Andric       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1328*fe6060f1SDimitry Andric         auto Remark = [&](OptimizationRemarkMissed ORM) {
1329*fe6060f1SDimitry Andric           return ORM
1330e8d8bef9SDimitry Andric                  << "Found thread data sharing on the GPU. "
1331e8d8bef9SDimitry Andric                  << "Expect degraded performance due to data globalization.";
1332e8d8bef9SDimitry Andric         };
1333*fe6060f1SDimitry Andric         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1334e8d8bef9SDimitry Andric       }
1335e8d8bef9SDimitry Andric 
1336e8d8bef9SDimitry Andric       return false;
1337e8d8bef9SDimitry Andric     };
1338e8d8bef9SDimitry Andric 
1339e8d8bef9SDimitry Andric     RFI.foreachUse(SCC, CheckGlobalization);
1340e8d8bef9SDimitry Andric   }
1341e8d8bef9SDimitry Andric 
1342e8d8bef9SDimitry Andric   /// Maps the values stored in the offload arrays passed as arguments to
1343e8d8bef9SDimitry Andric   /// \p RuntimeCall into the offload arrays in \p OAs.
1344e8d8bef9SDimitry Andric   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1345e8d8bef9SDimitry Andric                                 MutableArrayRef<OffloadArray> OAs) {
1346e8d8bef9SDimitry Andric     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1347e8d8bef9SDimitry Andric 
1348e8d8bef9SDimitry Andric     // A runtime call that involves memory offloading looks something like:
1349e8d8bef9SDimitry Andric     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1350e8d8bef9SDimitry Andric     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1351e8d8bef9SDimitry Andric     // ...)
1352e8d8bef9SDimitry Andric     // So, the idea is to access the allocas that allocate space for these
1353e8d8bef9SDimitry Andric     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1354e8d8bef9SDimitry Andric     // Therefore:
1355e8d8bef9SDimitry Andric     // i8** %offload_baseptrs.
1356e8d8bef9SDimitry Andric     Value *BasePtrsArg =
1357e8d8bef9SDimitry Andric         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1358e8d8bef9SDimitry Andric     // i8** %offload_ptrs.
1359e8d8bef9SDimitry Andric     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1360e8d8bef9SDimitry Andric     // i8** %offload_sizes.
1361e8d8bef9SDimitry Andric     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1362e8d8bef9SDimitry Andric 
1363e8d8bef9SDimitry Andric     // Get values stored in **offload_baseptrs.
1364e8d8bef9SDimitry Andric     auto *V = getUnderlyingObject(BasePtrsArg);
1365e8d8bef9SDimitry Andric     if (!isa<AllocaInst>(V))
1366e8d8bef9SDimitry Andric       return false;
1367e8d8bef9SDimitry Andric     auto *BasePtrsArray = cast<AllocaInst>(V);
1368e8d8bef9SDimitry Andric     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1369e8d8bef9SDimitry Andric       return false;
1370e8d8bef9SDimitry Andric 
1371e8d8bef9SDimitry Andric     // Get values stored in **offload_ptrs.
1372e8d8bef9SDimitry Andric     V = getUnderlyingObject(PtrsArg);
1373e8d8bef9SDimitry Andric     if (!isa<AllocaInst>(V))
1374e8d8bef9SDimitry Andric       return false;
1375e8d8bef9SDimitry Andric     auto *PtrsArray = cast<AllocaInst>(V);
1376e8d8bef9SDimitry Andric     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1377e8d8bef9SDimitry Andric       return false;
1378e8d8bef9SDimitry Andric 
1379e8d8bef9SDimitry Andric     // Get values stored in **offload_sizes.
1380e8d8bef9SDimitry Andric     V = getUnderlyingObject(SizesArg);
1381e8d8bef9SDimitry Andric     // If it's a [constant] global array, don't analyze it.
1382e8d8bef9SDimitry Andric     if (isa<GlobalValue>(V))
1383e8d8bef9SDimitry Andric       return isa<Constant>(V);
1384e8d8bef9SDimitry Andric     if (!isa<AllocaInst>(V))
1385e8d8bef9SDimitry Andric       return false;
1386e8d8bef9SDimitry Andric 
1387e8d8bef9SDimitry Andric     auto *SizesArray = cast<AllocaInst>(V);
1388e8d8bef9SDimitry Andric     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1389e8d8bef9SDimitry Andric       return false;
1390e8d8bef9SDimitry Andric 
1391e8d8bef9SDimitry Andric     return true;
1392e8d8bef9SDimitry Andric   }
1393e8d8bef9SDimitry Andric 
1394e8d8bef9SDimitry Andric   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1395e8d8bef9SDimitry Andric   /// For now this is a way to test that the function getValuesInOffloadArrays
1396e8d8bef9SDimitry Andric   /// is working properly.
1397e8d8bef9SDimitry Andric   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1398e8d8bef9SDimitry Andric   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1399e8d8bef9SDimitry Andric     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1400e8d8bef9SDimitry Andric 
1401e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1402e8d8bef9SDimitry Andric     std::string ValuesStr;
1403e8d8bef9SDimitry Andric     raw_string_ostream Printer(ValuesStr);
1404e8d8bef9SDimitry Andric     std::string Separator = " --- ";
1405e8d8bef9SDimitry Andric 
1406e8d8bef9SDimitry Andric     for (auto *BP : OAs[0].StoredValues) {
1407e8d8bef9SDimitry Andric       BP->print(Printer);
1408e8d8bef9SDimitry Andric       Printer << Separator;
1409e8d8bef9SDimitry Andric     }
1410e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1411e8d8bef9SDimitry Andric     ValuesStr.clear();
1412e8d8bef9SDimitry Andric 
1413e8d8bef9SDimitry Andric     for (auto *P : OAs[1].StoredValues) {
1414e8d8bef9SDimitry Andric       P->print(Printer);
1415e8d8bef9SDimitry Andric       Printer << Separator;
1416e8d8bef9SDimitry Andric     }
1417e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1418e8d8bef9SDimitry Andric     ValuesStr.clear();
1419e8d8bef9SDimitry Andric 
1420e8d8bef9SDimitry Andric     for (auto *S : OAs[2].StoredValues) {
1421e8d8bef9SDimitry Andric       S->print(Printer);
1422e8d8bef9SDimitry Andric       Printer << Separator;
1423e8d8bef9SDimitry Andric     }
1424e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1425e8d8bef9SDimitry Andric   }
1426e8d8bef9SDimitry Andric 
1427e8d8bef9SDimitry Andric   /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1428e8d8bef9SDimitry Andric   /// be moved. Returns nullptr if the movement is not possible, or not worth it.
1429e8d8bef9SDimitry Andric   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1430e8d8bef9SDimitry Andric     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1431e8d8bef9SDimitry Andric     //  Make it traverse the CFG.
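    // Illustrative sketch (not from the original source): the "wait" can
    // slide past instructions that neither read nor write memory, e.g.
    //
    //   call @__tgt_target_data_begin_mapper(...)   ; becomes the "issue"
    //   %x = add i32 %a, %b                         ; safe to skip over
    //   store i32 %x, i32* %p                       ; first memory access
    //
    // so the "wait" is placed right before the store; if nothing in the
    // block touches memory, it is placed at the block terminator instead.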
1432e8d8bef9SDimitry Andric 
1433e8d8bef9SDimitry Andric     Instruction *CurrentI = &RuntimeCall;
1434e8d8bef9SDimitry Andric     bool IsWorthIt = false;
1435e8d8bef9SDimitry Andric     while ((CurrentI = CurrentI->getNextNode())) {
1436e8d8bef9SDimitry Andric 
1437e8d8bef9SDimitry Andric       // TODO: Once we detect the regions to be offloaded we should use the
1438e8d8bef9SDimitry Andric       //  alias analysis manager to check if CurrentI may modify one of
1439e8d8bef9SDimitry Andric       //  the offloaded regions.
1440e8d8bef9SDimitry Andric       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1441e8d8bef9SDimitry Andric         if (IsWorthIt)
1442e8d8bef9SDimitry Andric           return CurrentI;
1443e8d8bef9SDimitry Andric 
1444e8d8bef9SDimitry Andric         return nullptr;
1445e8d8bef9SDimitry Andric       }
1446e8d8bef9SDimitry Andric 
1447e8d8bef9SDimitry Andric       // FIXME: For now, moving it over anything without side effects
1448e8d8bef9SDimitry Andric       //  is considered worth it.
1449e8d8bef9SDimitry Andric       IsWorthIt = true;
1450e8d8bef9SDimitry Andric     }
1451e8d8bef9SDimitry Andric 
1452e8d8bef9SDimitry Andric     // Return end of BasicBlock.
1453e8d8bef9SDimitry Andric     return RuntimeCall.getParent()->getTerminator();
1454e8d8bef9SDimitry Andric   }
1455e8d8bef9SDimitry Andric 
1456e8d8bef9SDimitry Andric   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1457e8d8bef9SDimitry Andric   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1458e8d8bef9SDimitry Andric                                Instruction &WaitMovementPoint) {
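    // Illustrative sketch (not from the original source): the synchronous
    // call
    //
    //   call void @__tgt_target_data_begin_mapper(args...)
    //
    // is rewritten, roughly, into
    //
    //   %handle = alloca %struct.__tgt_async_info
    //   call void @__tgt_target_data_begin_mapper_issue(args..., %handle)
    //   ...                                      ; latency-hiding code
    //   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)
    //
    // with the "wait" placed at the WaitMovementPoint computed by the caller.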
1459e8d8bef9SDimitry Andric     // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1460e8d8bef9SDimitry Andric     // the function. It is used to store information about the async transfer,
1461e8d8bef9SDimitry Andric     // allowing us to wait on it later.
1462e8d8bef9SDimitry Andric     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1463e8d8bef9SDimitry Andric     auto *F = RuntimeCall.getCaller();
1464e8d8bef9SDimitry Andric     Instruction *FirstInst = &(F->getEntryBlock().front());
1465e8d8bef9SDimitry Andric     AllocaInst *Handle = new AllocaInst(
1466e8d8bef9SDimitry Andric         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1467e8d8bef9SDimitry Andric 
1468e8d8bef9SDimitry Andric     // Add "issue" runtime call declaration:
1469e8d8bef9SDimitry Andric     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1470e8d8bef9SDimitry Andric     //   i8**, i8**, i64*, i64*)
1471e8d8bef9SDimitry Andric     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1472e8d8bef9SDimitry Andric         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1473e8d8bef9SDimitry Andric 
1474e8d8bef9SDimitry Andric     // Change RuntimeCall call site for its asynchronous version.
1475e8d8bef9SDimitry Andric     SmallVector<Value *, 16> Args;
1476e8d8bef9SDimitry Andric     for (auto &Arg : RuntimeCall.args())
1477e8d8bef9SDimitry Andric       Args.push_back(Arg.get());
1478e8d8bef9SDimitry Andric     Args.push_back(Handle);
1479e8d8bef9SDimitry Andric 
1480e8d8bef9SDimitry Andric     CallInst *IssueCallsite =
1481e8d8bef9SDimitry Andric         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1482e8d8bef9SDimitry Andric     RuntimeCall.eraseFromParent();
1483e8d8bef9SDimitry Andric 
1484e8d8bef9SDimitry Andric     // Add "wait" runtime call declaration:
1485e8d8bef9SDimitry Andric     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1486e8d8bef9SDimitry Andric     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1487e8d8bef9SDimitry Andric         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1488e8d8bef9SDimitry Andric 
1489e8d8bef9SDimitry Andric     Value *WaitParams[2] = {
1490e8d8bef9SDimitry Andric         IssueCallsite->getArgOperand(
1491e8d8bef9SDimitry Andric             OffloadArray::DeviceIDArgNum), // device_id.
1492e8d8bef9SDimitry Andric         Handle                             // handle to wait on.
1493e8d8bef9SDimitry Andric     };
1494e8d8bef9SDimitry Andric     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1495e8d8bef9SDimitry Andric 
1496e8d8bef9SDimitry Andric     return true;
1497e8d8bef9SDimitry Andric   }
1498e8d8bef9SDimitry Andric 
14995ffd83dbSDimitry Andric   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
15005ffd83dbSDimitry Andric                                     bool GlobalOnly, bool &SingleChoice) {
15015ffd83dbSDimitry Andric     if (CurrentIdent == NextIdent)
15025ffd83dbSDimitry Andric       return CurrentIdent;
15035ffd83dbSDimitry Andric 
15045ffd83dbSDimitry Andric     // TODO: Figure out how to actually combine multiple debug locations. For
15055ffd83dbSDimitry Andric     //       now we just keep an existing one if there is a single choice.
15065ffd83dbSDimitry Andric     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
15075ffd83dbSDimitry Andric       SingleChoice = !CurrentIdent;
15085ffd83dbSDimitry Andric       return NextIdent;
15095ffd83dbSDimitry Andric     }
15105ffd83dbSDimitry Andric     return nullptr;
15115ffd83dbSDimitry Andric   }
15125ffd83dbSDimitry Andric 
15135ffd83dbSDimitry Andric   /// Return a `struct ident_t*` value that represents the ones used in the
15145ffd83dbSDimitry Andric   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
15155ffd83dbSDimitry Andric   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
15165ffd83dbSDimitry Andric   /// return value we create one from scratch. We also do not yet combine
15175ffd83dbSDimitry Andric   /// information, e.g., the source locations, see combinedIdentStruct.
15185ffd83dbSDimitry Andric   Value *
15195ffd83dbSDimitry Andric   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
15205ffd83dbSDimitry Andric                                  Function &F, bool GlobalOnly) {
15215ffd83dbSDimitry Andric     bool SingleChoice = true;
15225ffd83dbSDimitry Andric     Value *Ident = nullptr;
15235ffd83dbSDimitry Andric     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
15245ffd83dbSDimitry Andric       CallInst *CI = getCallIfRegularCall(U, &RFI);
15255ffd83dbSDimitry Andric       if (!CI || &F != &Caller)
15265ffd83dbSDimitry Andric         return false;
15275ffd83dbSDimitry Andric       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
15285ffd83dbSDimitry Andric                                   /* GlobalOnly */ true, SingleChoice);
15295ffd83dbSDimitry Andric       return false;
15305ffd83dbSDimitry Andric     };
15315ffd83dbSDimitry Andric     RFI.foreachUse(SCC, CombineIdentStruct);
15325ffd83dbSDimitry Andric 
15335ffd83dbSDimitry Andric     if (!Ident || !SingleChoice) {
15345ffd83dbSDimitry Andric       // The IRBuilder uses the insertion block to get to the module; this is
15355ffd83dbSDimitry Andric       // unfortunate, but we work around it for now.
15365ffd83dbSDimitry Andric       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
15375ffd83dbSDimitry Andric         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
15385ffd83dbSDimitry Andric             &F.getEntryBlock(), F.getEntryBlock().begin()));
15395ffd83dbSDimitry Andric       // Create a fallback location if none was found.
15405ffd83dbSDimitry Andric       // TODO: Use the debug locations of the calls instead.
15415ffd83dbSDimitry Andric       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
15425ffd83dbSDimitry Andric       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
15435ffd83dbSDimitry Andric     }
15445ffd83dbSDimitry Andric     return Ident;
15455ffd83dbSDimitry Andric   }
15465ffd83dbSDimitry Andric 
15475ffd83dbSDimitry Andric   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
15485ffd83dbSDimitry Andric   /// \p ReplVal if given.
15495ffd83dbSDimitry Andric   bool deduplicateRuntimeCalls(Function &F,
15505ffd83dbSDimitry Andric                                OMPInformationCache::RuntimeFunctionInfo &RFI,
15515ffd83dbSDimitry Andric                                Value *ReplVal = nullptr) {
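    // Illustrative sketch (not from the original source): with ReplVal set to
    // a global thread id argument %gtid of F, every
    //
    //   %tid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
    //
    // in F is deleted and its uses are rewired to %gtid. Without ReplVal, the
    // first movable call is hoisted (to the entry block, or after
    // __kmpc_target_init in kernels) and reused by the remaining calls.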
15525ffd83dbSDimitry Andric     auto *UV = RFI.getUseVector(F);
15535ffd83dbSDimitry Andric     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
15545ffd83dbSDimitry Andric       return false;
15555ffd83dbSDimitry Andric 
15565ffd83dbSDimitry Andric     LLVM_DEBUG(
15575ffd83dbSDimitry Andric         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
15585ffd83dbSDimitry Andric                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
15595ffd83dbSDimitry Andric 
15605ffd83dbSDimitry Andric     assert((!ReplVal || (isa<Argument>(ReplVal) &&
15615ffd83dbSDimitry Andric                          cast<Argument>(ReplVal)->getParent() == &F)) &&
15625ffd83dbSDimitry Andric            "Unexpected replacement value!");
15635ffd83dbSDimitry Andric 
15645ffd83dbSDimitry Andric     // TODO: Use dominance to find a good position instead.
15655ffd83dbSDimitry Andric     auto CanBeMoved = [this](CallBase &CB) {
15665ffd83dbSDimitry Andric       unsigned NumArgs = CB.getNumArgOperands();
15675ffd83dbSDimitry Andric       if (NumArgs == 0)
15685ffd83dbSDimitry Andric         return true;
15695ffd83dbSDimitry Andric       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
15705ffd83dbSDimitry Andric         return false;
15715ffd83dbSDimitry Andric       for (unsigned u = 1; u < NumArgs; ++u)
15725ffd83dbSDimitry Andric         if (isa<Instruction>(CB.getArgOperand(u)))
15735ffd83dbSDimitry Andric           return false;
15745ffd83dbSDimitry Andric       return true;
15755ffd83dbSDimitry Andric     };
15765ffd83dbSDimitry Andric 
15775ffd83dbSDimitry Andric     if (!ReplVal) {
15785ffd83dbSDimitry Andric       for (Use *U : *UV)
15795ffd83dbSDimitry Andric         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
15805ffd83dbSDimitry Andric           if (!CanBeMoved(*CI))
15815ffd83dbSDimitry Andric             continue;
15825ffd83dbSDimitry Andric 
1583*fe6060f1SDimitry Andric           // If the function is a kernel, dedup will move
1584*fe6060f1SDimitry Andric           // the runtime call right after the kernel init callsite. Otherwise,
1585*fe6060f1SDimitry Andric           // it will move it to the beginning of the caller function.
1586*fe6060f1SDimitry Andric           if (isKernel(F)) {
1587*fe6060f1SDimitry Andric             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1588*fe6060f1SDimitry Andric             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
15895ffd83dbSDimitry Andric 
1590*fe6060f1SDimitry Andric             if (KernelInitUV->empty())
1591*fe6060f1SDimitry Andric               continue;
1592*fe6060f1SDimitry Andric 
1593*fe6060f1SDimitry Andric             assert(KernelInitUV->size() == 1 &&
1594*fe6060f1SDimitry Andric                    "Expected a single __kmpc_target_init in kernel\n");
1595*fe6060f1SDimitry Andric 
1596*fe6060f1SDimitry Andric             CallInst *KernelInitCI =
1597*fe6060f1SDimitry Andric                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1598*fe6060f1SDimitry Andric             assert(KernelInitCI &&
1599*fe6060f1SDimitry Andric                    "Expected a call to __kmpc_target_init in kernel\n");
1600*fe6060f1SDimitry Andric 
1601*fe6060f1SDimitry Andric             CI->moveAfter(KernelInitCI);
1602*fe6060f1SDimitry Andric           } else
16035ffd83dbSDimitry Andric             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
16045ffd83dbSDimitry Andric           ReplVal = CI;
16055ffd83dbSDimitry Andric           break;
16065ffd83dbSDimitry Andric         }
16075ffd83dbSDimitry Andric       if (!ReplVal)
16085ffd83dbSDimitry Andric         return false;
16095ffd83dbSDimitry Andric     }
16105ffd83dbSDimitry Andric 
16115ffd83dbSDimitry Andric     // If we use a call as a replacement value, we need to make sure the ident
16125ffd83dbSDimitry Andric     // is valid at the new location. For now we just pick a global one, either
16135ffd83dbSDimitry Andric     // existing and used by one of the calls, or created from scratch.
16145ffd83dbSDimitry Andric     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
16155ffd83dbSDimitry Andric       if (CI->getNumArgOperands() > 0 &&
16165ffd83dbSDimitry Andric           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
16175ffd83dbSDimitry Andric         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
16185ffd83dbSDimitry Andric                                                       /* GlobalOnly */ true);
16195ffd83dbSDimitry Andric         CI->setArgOperand(0, Ident);
16205ffd83dbSDimitry Andric       }
16215ffd83dbSDimitry Andric     }
16225ffd83dbSDimitry Andric 
16235ffd83dbSDimitry Andric     bool Changed = false;
16245ffd83dbSDimitry Andric     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
16255ffd83dbSDimitry Andric       CallInst *CI = getCallIfRegularCall(U, &RFI);
16265ffd83dbSDimitry Andric       if (!CI || CI == ReplVal || &F != &Caller)
16275ffd83dbSDimitry Andric         return false;
16285ffd83dbSDimitry Andric       assert(CI->getCaller() == &F && "Unexpected call!");
16295ffd83dbSDimitry Andric 
16305ffd83dbSDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
16315ffd83dbSDimitry Andric         return OR << "OpenMP runtime call "
1632*fe6060f1SDimitry Andric                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
16335ffd83dbSDimitry Andric       };
1634*fe6060f1SDimitry Andric       if (CI->getDebugLoc())
1635*fe6060f1SDimitry Andric         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1636*fe6060f1SDimitry Andric       else
1637*fe6060f1SDimitry Andric         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
16385ffd83dbSDimitry Andric 
16395ffd83dbSDimitry Andric       CGUpdater.removeCallSite(*CI);
16405ffd83dbSDimitry Andric       CI->replaceAllUsesWith(ReplVal);
16415ffd83dbSDimitry Andric       CI->eraseFromParent();
16425ffd83dbSDimitry Andric       ++NumOpenMPRuntimeCallsDeduplicated;
16435ffd83dbSDimitry Andric       Changed = true;
16445ffd83dbSDimitry Andric       return true;
16455ffd83dbSDimitry Andric     };
16465ffd83dbSDimitry Andric     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
16475ffd83dbSDimitry Andric 
16485ffd83dbSDimitry Andric     return Changed;
16495ffd83dbSDimitry Andric   }
16505ffd83dbSDimitry Andric 
16515ffd83dbSDimitry Andric   /// Collect arguments that represent the global thread id in \p GTIdArgs.
16525ffd83dbSDimitry Andric   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
16535ffd83dbSDimitry Andric     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
16545ffd83dbSDimitry Andric     //       initialization. We could define an AbstractAttribute instead and
16555ffd83dbSDimitry Andric     //       run the Attributor here once it can be run as an SCC pass.
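    // Illustrative sketch (hypothetical callee name, not from the original
    // source): starting from
    //
    //   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
    //   call void @helper(i32 %gtid, ...)
    //
    // the argument of @helper that receives a known global thread id at every
    // call site is itself recorded as a GTId and used to seed further
    // propagation.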
16565ffd83dbSDimitry Andric 
16575ffd83dbSDimitry Andric     // Helper to check the argument \p ArgNo at all call sites of \p F for
16585ffd83dbSDimitry Andric     // a GTId.
16595ffd83dbSDimitry Andric     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
16605ffd83dbSDimitry Andric       if (!F.hasLocalLinkage())
16615ffd83dbSDimitry Andric         return false;
16625ffd83dbSDimitry Andric       for (Use &U : F.uses()) {
16635ffd83dbSDimitry Andric         if (CallInst *CI = getCallIfRegularCall(U)) {
16645ffd83dbSDimitry Andric           Value *ArgOp = CI->getArgOperand(ArgNo);
16655ffd83dbSDimitry Andric           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
16665ffd83dbSDimitry Andric               getCallIfRegularCall(
16675ffd83dbSDimitry Andric                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
16685ffd83dbSDimitry Andric             continue;
16695ffd83dbSDimitry Andric         }
16705ffd83dbSDimitry Andric         return false;
16715ffd83dbSDimitry Andric       }
16725ffd83dbSDimitry Andric       return true;
16735ffd83dbSDimitry Andric     };
16745ffd83dbSDimitry Andric 
16755ffd83dbSDimitry Andric     // Helper to identify uses of a GTId as GTId arguments.
16765ffd83dbSDimitry Andric     auto AddUserArgs = [&](Value &GTId) {
16775ffd83dbSDimitry Andric       for (Use &U : GTId.uses())
16785ffd83dbSDimitry Andric         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
16795ffd83dbSDimitry Andric           if (CI->isArgOperand(&U))
16805ffd83dbSDimitry Andric             if (Function *Callee = CI->getCalledFunction())
16815ffd83dbSDimitry Andric               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
16825ffd83dbSDimitry Andric                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
16835ffd83dbSDimitry Andric     };
16845ffd83dbSDimitry Andric 
16855ffd83dbSDimitry Andric     // The argument users of __kmpc_global_thread_num calls are GTIds.
16865ffd83dbSDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
16875ffd83dbSDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
16885ffd83dbSDimitry Andric 
16895ffd83dbSDimitry Andric     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
16905ffd83dbSDimitry Andric       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
16915ffd83dbSDimitry Andric         AddUserArgs(*CI);
16925ffd83dbSDimitry Andric       return false;
16935ffd83dbSDimitry Andric     });
16945ffd83dbSDimitry Andric 
16955ffd83dbSDimitry Andric     // Transitively search for more arguments by looking at the users of the
16965ffd83dbSDimitry Andric     // ones we know already. During the search the GTIdArgs vector is extended
16975ffd83dbSDimitry Andric     // so we cannot cache the size nor can we use a range-based for loop.
16985ffd83dbSDimitry Andric     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
16995ffd83dbSDimitry Andric       AddUserArgs(*GTIdArgs[u]);
17005ffd83dbSDimitry Andric   }
17015ffd83dbSDimitry Andric 
17025ffd83dbSDimitry Andric   /// Kernel (=GPU) optimizations and utility functions
17035ffd83dbSDimitry Andric   ///
17045ffd83dbSDimitry Andric   ///{{
17055ffd83dbSDimitry Andric 
17065ffd83dbSDimitry Andric   /// Check if \p F is a kernel, hence entry point for target offloading.
17075ffd83dbSDimitry Andric   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
17085ffd83dbSDimitry Andric 
17095ffd83dbSDimitry Andric   /// Cache to remember the unique kernel for a function.
17105ffd83dbSDimitry Andric   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
17115ffd83dbSDimitry Andric 
17125ffd83dbSDimitry Andric   /// Find the unique kernel that will execute \p F, if any.
17135ffd83dbSDimitry Andric   Kernel getUniqueKernelFor(Function &F);
17145ffd83dbSDimitry Andric 
17155ffd83dbSDimitry Andric   /// Find the unique kernel that will execute \p I, if any.
17165ffd83dbSDimitry Andric   Kernel getUniqueKernelFor(Instruction &I) {
17175ffd83dbSDimitry Andric     return getUniqueKernelFor(*I.getFunction());
17185ffd83dbSDimitry Andric   }
17195ffd83dbSDimitry Andric 
17205ffd83dbSDimitry Andric   /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
17215ffd83dbSDimitry Andric   /// the cases where we can avoid taking the address of a function.
17225ffd83dbSDimitry Andric   bool rewriteDeviceCodeStateMachine();
17235ffd83dbSDimitry Andric 
17245ffd83dbSDimitry Andric   ///
17255ffd83dbSDimitry Andric   ///}}
17265ffd83dbSDimitry Andric 
17275ffd83dbSDimitry Andric   /// Emit a remark generically
17285ffd83dbSDimitry Andric   ///
17295ffd83dbSDimitry Andric   /// This template function can be used to generically emit a remark. The
17305ffd83dbSDimitry Andric   /// RemarkKind should be one of the following:
17315ffd83dbSDimitry Andric   ///   - OptimizationRemark to indicate a successful optimization attempt
17325ffd83dbSDimitry Andric   ///   - OptimizationRemarkMissed to report a failed optimization attempt
17335ffd83dbSDimitry Andric   ///   - OptimizationRemarkAnalysis to provide additional information about an
17345ffd83dbSDimitry Andric   ///     optimization attempt
17355ffd83dbSDimitry Andric   ///
17365ffd83dbSDimitry Andric   /// The remark is built using a callback function provided by the caller that
17375ffd83dbSDimitry Andric   /// takes a RemarkKind as input and returns a RemarkKind.
1738*fe6060f1SDimitry Andric   template <typename RemarkKind, typename RemarkCallBack>
1739*fe6060f1SDimitry Andric   void emitRemark(Instruction *I, StringRef RemarkName,
17405ffd83dbSDimitry Andric                   RemarkCallBack &&RemarkCB) const {
1741*fe6060f1SDimitry Andric     Function *F = I->getParent()->getParent();
17425ffd83dbSDimitry Andric     auto &ORE = OREGetter(F);
17435ffd83dbSDimitry Andric 
1744*fe6060f1SDimitry Andric     if (RemarkName.startswith("OMP"))
17455ffd83dbSDimitry Andric       ORE.emit([&]() {
1746*fe6060f1SDimitry Andric         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1747*fe6060f1SDimitry Andric                << " [" << RemarkName << "]";
17485ffd83dbSDimitry Andric       });
1749*fe6060f1SDimitry Andric     else
1750*fe6060f1SDimitry Andric       ORE.emit(
1751*fe6060f1SDimitry Andric           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
17525ffd83dbSDimitry Andric   }
17535ffd83dbSDimitry Andric 
1754*fe6060f1SDimitry Andric   /// Emit a remark on a function.
1755*fe6060f1SDimitry Andric   template <typename RemarkKind, typename RemarkCallBack>
1756*fe6060f1SDimitry Andric   void emitRemark(Function *F, StringRef RemarkName,
1757*fe6060f1SDimitry Andric                   RemarkCallBack &&RemarkCB) const {
1758*fe6060f1SDimitry Andric     auto &ORE = OREGetter(F);
1759*fe6060f1SDimitry Andric 
1760*fe6060f1SDimitry Andric     if (RemarkName.startswith("OMP"))
1761*fe6060f1SDimitry Andric       ORE.emit([&]() {
1762*fe6060f1SDimitry Andric         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1763*fe6060f1SDimitry Andric                << " [" << RemarkName << "]";
1764*fe6060f1SDimitry Andric       });
1765*fe6060f1SDimitry Andric     else
1766*fe6060f1SDimitry Andric       ORE.emit(
1767*fe6060f1SDimitry Andric           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1768*fe6060f1SDimitry Andric   }
1769*fe6060f1SDimitry Andric 
1770*fe6060f1SDimitry Andric   /// RAII struct to temporarily change an RTL function's linkage to external.
1771*fe6060f1SDimitry Andric   /// This prevents it from being mistakenly removed by other optimizations.
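  ///
  /// A minimal usage sketch, mirroring how runAttributor below scopes it (the
  /// guard name and the chosen runtime function are just examples):
  ///
  /// \code
  ///   {
  ///     ExternalizationRAII Guard(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
  ///     // ... run the Attributor while the declaration must stay visible ...
  ///   } // original linkage is restored here
  /// \endcode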
1772*fe6060f1SDimitry Andric   struct ExternalizationRAII {
1773*fe6060f1SDimitry Andric     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1774*fe6060f1SDimitry Andric                         RuntimeFunction RFKind)
1775*fe6060f1SDimitry Andric         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1776*fe6060f1SDimitry Andric       if (!Declaration)
1777*fe6060f1SDimitry Andric         return;
1778*fe6060f1SDimitry Andric 
1779*fe6060f1SDimitry Andric       LinkageType = Declaration->getLinkage();
1780*fe6060f1SDimitry Andric       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1781*fe6060f1SDimitry Andric     }
1782*fe6060f1SDimitry Andric 
1783*fe6060f1SDimitry Andric     ~ExternalizationRAII() {
1784*fe6060f1SDimitry Andric       if (!Declaration)
1785*fe6060f1SDimitry Andric         return;
1786*fe6060f1SDimitry Andric 
1787*fe6060f1SDimitry Andric       Declaration->setLinkage(LinkageType);
1788*fe6060f1SDimitry Andric     }
1789*fe6060f1SDimitry Andric 
1790*fe6060f1SDimitry Andric     Function *Declaration;
1791*fe6060f1SDimitry Andric     GlobalValue::LinkageTypes LinkageType;
1792*fe6060f1SDimitry Andric   };
1793*fe6060f1SDimitry Andric 
17945ffd83dbSDimitry Andric   /// The underlying module.
17955ffd83dbSDimitry Andric   Module &M;
17965ffd83dbSDimitry Andric 
17975ffd83dbSDimitry Andric   /// The SCC we are operating on.
17985ffd83dbSDimitry Andric   SmallVectorImpl<Function *> &SCC;
17995ffd83dbSDimitry Andric 
18005ffd83dbSDimitry Andric   /// Callback to update the call graph; the first argument is a removed call,
18015ffd83dbSDimitry Andric   /// the second an optional replacement call.
18025ffd83dbSDimitry Andric   CallGraphUpdater &CGUpdater;
18035ffd83dbSDimitry Andric 
18045ffd83dbSDimitry Andric   /// Callback to get an OptimizationRemarkEmitter from a Function *
18055ffd83dbSDimitry Andric   OptimizationRemarkGetter OREGetter;
18065ffd83dbSDimitry Andric 
18075ffd83dbSDimitry Andric   /// OpenMP-specific information cache. Also used for Attributor runs.
18085ffd83dbSDimitry Andric   OMPInformationCache &OMPInfoCache;
18095ffd83dbSDimitry Andric 
18105ffd83dbSDimitry Andric   /// Attributor instance.
18115ffd83dbSDimitry Andric   Attributor &A;
18125ffd83dbSDimitry Andric 
18135ffd83dbSDimitry Andric   /// Helper function to run Attributor on SCC.
1814*fe6060f1SDimitry Andric   bool runAttributor(bool IsModulePass) {
18155ffd83dbSDimitry Andric     if (SCC.empty())
18165ffd83dbSDimitry Andric       return false;
18175ffd83dbSDimitry Andric 
1818*fe6060f1SDimitry Andric     // Temporarily make these functions have external linkage so the Attributor
1819*fe6060f1SDimitry Andric     // doesn't remove them when we try to look them up later.
1820*fe6060f1SDimitry Andric     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1821*fe6060f1SDimitry Andric     ExternalizationRAII EndParallel(OMPInfoCache,
1822*fe6060f1SDimitry Andric                                     OMPRTL___kmpc_kernel_end_parallel);
1823*fe6060f1SDimitry Andric     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1824*fe6060f1SDimitry Andric                                     OMPRTL___kmpc_barrier_simple_spmd);
1825*fe6060f1SDimitry Andric 
1826*fe6060f1SDimitry Andric     registerAAs(IsModulePass);
18275ffd83dbSDimitry Andric 
18285ffd83dbSDimitry Andric     ChangeStatus Changed = A.run();
18295ffd83dbSDimitry Andric 
18305ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
18315ffd83dbSDimitry Andric                       << " functions, result: " << Changed << ".\n");
18325ffd83dbSDimitry Andric 
18335ffd83dbSDimitry Andric     return Changed == ChangeStatus::CHANGED;
18345ffd83dbSDimitry Andric   }
18355ffd83dbSDimitry Andric 
1836*fe6060f1SDimitry Andric   void registerFoldRuntimeCall(RuntimeFunction RF);
1837*fe6060f1SDimitry Andric 
18385ffd83dbSDimitry Andric   /// Populate the Attributor with abstract attribute opportunities in the
18395ffd83dbSDimitry Andric   /// function.
1840*fe6060f1SDimitry Andric   void registerAAs(bool IsModulePass);
18415ffd83dbSDimitry Andric };
18425ffd83dbSDimitry Andric 
18435ffd83dbSDimitry Andric Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
18445ffd83dbSDimitry Andric   if (!OMPInfoCache.ModuleSlice.count(&F))
18455ffd83dbSDimitry Andric     return nullptr;
18465ffd83dbSDimitry Andric 
18475ffd83dbSDimitry Andric   // Use a scope to keep the lifetime of the CachedKernel short.
18485ffd83dbSDimitry Andric   {
18495ffd83dbSDimitry Andric     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
18505ffd83dbSDimitry Andric     if (CachedKernel)
18515ffd83dbSDimitry Andric       return *CachedKernel;
18525ffd83dbSDimitry Andric 
18535ffd83dbSDimitry Andric     // TODO: We should use an AA to create an (optimistic and callback
18545ffd83dbSDimitry Andric     //       call-aware) call graph. For now we stick to simple patterns that
18555ffd83dbSDimitry Andric     //       are less powerful, basically the worst fixpoint.
18565ffd83dbSDimitry Andric     if (isKernel(F)) {
18575ffd83dbSDimitry Andric       CachedKernel = Kernel(&F);
18585ffd83dbSDimitry Andric       return *CachedKernel;
18595ffd83dbSDimitry Andric     }
18605ffd83dbSDimitry Andric 
18615ffd83dbSDimitry Andric     CachedKernel = nullptr;
1862e8d8bef9SDimitry Andric     if (!F.hasLocalLinkage()) {
1863e8d8bef9SDimitry Andric 
1864e8d8bef9SDimitry Andric       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1865*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1866*fe6060f1SDimitry Andric         return ORA << "Potentially unknown OpenMP target region caller.";
1867e8d8bef9SDimitry Andric       };
1868*fe6060f1SDimitry Andric       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1869e8d8bef9SDimitry Andric 
18705ffd83dbSDimitry Andric       return nullptr;
18715ffd83dbSDimitry Andric     }
1872e8d8bef9SDimitry Andric   }
18735ffd83dbSDimitry Andric 
18745ffd83dbSDimitry Andric   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
18755ffd83dbSDimitry Andric     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
18765ffd83dbSDimitry Andric       // Allow use in equality comparisons.
18775ffd83dbSDimitry Andric       if (Cmp->isEquality())
18785ffd83dbSDimitry Andric         return getUniqueKernelFor(*Cmp);
18795ffd83dbSDimitry Andric       return nullptr;
18805ffd83dbSDimitry Andric     }
18815ffd83dbSDimitry Andric     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
18825ffd83dbSDimitry Andric       // Allow direct calls.
18835ffd83dbSDimitry Andric       if (CB->isCallee(&U))
18845ffd83dbSDimitry Andric         return getUniqueKernelFor(*CB);
1885*fe6060f1SDimitry Andric 
1886*fe6060f1SDimitry Andric       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1887*fe6060f1SDimitry Andric           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1888*fe6060f1SDimitry Andric       // Allow the use in __kmpc_parallel_51 calls.
1889*fe6060f1SDimitry Andric       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
18905ffd83dbSDimitry Andric         return getUniqueKernelFor(*CB);
18915ffd83dbSDimitry Andric       return nullptr;
18925ffd83dbSDimitry Andric     }
18935ffd83dbSDimitry Andric     // Disallow every other use.
18945ffd83dbSDimitry Andric     return nullptr;
18955ffd83dbSDimitry Andric   };
18965ffd83dbSDimitry Andric 
18975ffd83dbSDimitry Andric   // TODO: In the future we want to track more than just a unique kernel.
18985ffd83dbSDimitry Andric   SmallPtrSet<Kernel, 2> PotentialKernels;
1899e8d8bef9SDimitry Andric   OMPInformationCache::foreachUse(F, [&](const Use &U) {
19005ffd83dbSDimitry Andric     PotentialKernels.insert(GetUniqueKernelForUse(U));
19015ffd83dbSDimitry Andric   });
19025ffd83dbSDimitry Andric 
19035ffd83dbSDimitry Andric   Kernel K = nullptr;
19045ffd83dbSDimitry Andric   if (PotentialKernels.size() == 1)
19055ffd83dbSDimitry Andric     K = *PotentialKernels.begin();
19065ffd83dbSDimitry Andric 
19075ffd83dbSDimitry Andric   // Cache the result.
19085ffd83dbSDimitry Andric   UniqueKernelMap[&F] = K;
19095ffd83dbSDimitry Andric 
19105ffd83dbSDimitry Andric   return K;
19115ffd83dbSDimitry Andric }
19125ffd83dbSDimitry Andric 
19135ffd83dbSDimitry Andric bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1914*fe6060f1SDimitry Andric   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1915*fe6060f1SDimitry Andric       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
19165ffd83dbSDimitry Andric 
19175ffd83dbSDimitry Andric   bool Changed = false;
1918*fe6060f1SDimitry Andric   if (!KernelParallelRFI)
19195ffd83dbSDimitry Andric     return Changed;
19205ffd83dbSDimitry Andric 
19215ffd83dbSDimitry Andric   for (Function *F : SCC) {
19225ffd83dbSDimitry Andric 
1923*fe6060f1SDimitry Andric     // Check if the function is used in a __kmpc_parallel_51 call at
19245ffd83dbSDimitry Andric     // all.
19255ffd83dbSDimitry Andric     bool UnknownUse = false;
1926*fe6060f1SDimitry Andric     bool KernelParallelUse = false;
19275ffd83dbSDimitry Andric     unsigned NumDirectCalls = 0;
19285ffd83dbSDimitry Andric 
19295ffd83dbSDimitry Andric     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1930e8d8bef9SDimitry Andric     OMPInformationCache::foreachUse(*F, [&](Use &U) {
19315ffd83dbSDimitry Andric       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
19325ffd83dbSDimitry Andric         if (CB->isCallee(&U)) {
19335ffd83dbSDimitry Andric           ++NumDirectCalls;
19345ffd83dbSDimitry Andric           return;
19355ffd83dbSDimitry Andric         }
19365ffd83dbSDimitry Andric 
19375ffd83dbSDimitry Andric       if (isa<ICmpInst>(U.getUser())) {
19385ffd83dbSDimitry Andric         ToBeReplacedStateMachineUses.push_back(&U);
19395ffd83dbSDimitry Andric         return;
19405ffd83dbSDimitry Andric       }
1941*fe6060f1SDimitry Andric 
1942*fe6060f1SDimitry Andric       // Find wrapper functions that represent parallel kernels.
1943*fe6060f1SDimitry Andric       CallInst *CI =
1944*fe6060f1SDimitry Andric           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1945*fe6060f1SDimitry Andric       const unsigned int WrapperFunctionArgNo = 6;
1946*fe6060f1SDimitry Andric       if (!KernelParallelUse && CI &&
1947*fe6060f1SDimitry Andric           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1948*fe6060f1SDimitry Andric         KernelParallelUse = true;
19495ffd83dbSDimitry Andric         ToBeReplacedStateMachineUses.push_back(&U);
19505ffd83dbSDimitry Andric         return;
19515ffd83dbSDimitry Andric       }
19525ffd83dbSDimitry Andric       UnknownUse = true;
19535ffd83dbSDimitry Andric     });
19545ffd83dbSDimitry Andric 
1955*fe6060f1SDimitry Andric     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
19565ffd83dbSDimitry Andric     // use.
1957*fe6060f1SDimitry Andric     if (!KernelParallelUse)
19585ffd83dbSDimitry Andric       continue;
19595ffd83dbSDimitry Andric 
19605ffd83dbSDimitry Andric     // If this ever hits, we should investigate.
19615ffd83dbSDimitry Andric     // TODO: Checking the number of uses is not a necessary restriction and
19625ffd83dbSDimitry Andric     // should be lifted.
19635ffd83dbSDimitry Andric     if (UnknownUse || NumDirectCalls != 1 ||
1964*fe6060f1SDimitry Andric         ToBeReplacedStateMachineUses.size() > 2) {
1965*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1966*fe6060f1SDimitry Andric         return ORA << "Parallel region is used in "
19675ffd83dbSDimitry Andric                    << (UnknownUse ? "unknown" : "unexpected")
1968*fe6060f1SDimitry Andric                    << " ways. Will not attempt to rewrite the state machine.";
19695ffd83dbSDimitry Andric       };
1970*fe6060f1SDimitry Andric       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
19715ffd83dbSDimitry Andric       continue;
19725ffd83dbSDimitry Andric     }
19735ffd83dbSDimitry Andric 
1974*fe6060f1SDimitry Andric     // Even if we have __kmpc_parallel_51 calls, we (for now) give
19755ffd83dbSDimitry Andric     // up if the function is not called from a unique kernel.
19765ffd83dbSDimitry Andric     Kernel K = getUniqueKernelFor(*F);
19775ffd83dbSDimitry Andric     if (!K) {
1978*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1979*fe6060f1SDimitry Andric         return ORA << "Parallel region is not called from a unique kernel. "
1980*fe6060f1SDimitry Andric                       "Will not attempt to rewrite the state machine.";
19815ffd83dbSDimitry Andric       };
1982*fe6060f1SDimitry Andric       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
19835ffd83dbSDimitry Andric       continue;
19845ffd83dbSDimitry Andric     }
19855ffd83dbSDimitry Andric 
19865ffd83dbSDimitry Andric     // We now know F is a parallel body function called only from the kernel K.
19875ffd83dbSDimitry Andric     // We also identified the state machine uses in which we will replace the
19885ffd83dbSDimitry Andric     // function pointer with a new global symbol for identification purposes.
19895ffd83dbSDimitry Andric     // This ensures only direct calls to the function are left.
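    // The .ID global created below is only a unique, comparable token: the
    // equality comparisons and the __kmpc_parallel_51 wrapper argument are
    // redirected to its address, so the parallel region's own address is no
    // longer taken and every remaining use of F is a direct call.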
19905ffd83dbSDimitry Andric 
19915ffd83dbSDimitry Andric     Module &M = *F->getParent();
19925ffd83dbSDimitry Andric     Type *Int8Ty = Type::getInt8Ty(M.getContext());
19935ffd83dbSDimitry Andric 
19945ffd83dbSDimitry Andric     auto *ID = new GlobalVariable(
19955ffd83dbSDimitry Andric         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
19965ffd83dbSDimitry Andric         UndefValue::get(Int8Ty), F->getName() + ".ID");
19975ffd83dbSDimitry Andric 
19985ffd83dbSDimitry Andric     for (Use *U : ToBeReplacedStateMachineUses)
19995ffd83dbSDimitry Andric       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
20005ffd83dbSDimitry Andric 
20015ffd83dbSDimitry Andric     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
20025ffd83dbSDimitry Andric 
20035ffd83dbSDimitry Andric     Changed = true;
20045ffd83dbSDimitry Andric   }
20055ffd83dbSDimitry Andric 
20065ffd83dbSDimitry Andric   return Changed;
20075ffd83dbSDimitry Andric }
20085ffd83dbSDimitry Andric 
20095ffd83dbSDimitry Andric /// Abstract Attribute for tracking ICV values.
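/// For a trackable ICV (currently only nthreads, see TrackableICVs below) the
/// tracker records the value passed to the ICV's setter runtime call so that
/// later getter calls can be folded to that value whenever it is unique.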
20105ffd83dbSDimitry Andric struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
20115ffd83dbSDimitry Andric   using Base = StateWrapper<BooleanState, AbstractAttribute>;
20125ffd83dbSDimitry Andric   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
20135ffd83dbSDimitry Andric 
2014e8d8bef9SDimitry Andric   void initialize(Attributor &A) override {
2015e8d8bef9SDimitry Andric     Function *F = getAnchorScope();
2016e8d8bef9SDimitry Andric     if (!F || !A.isFunctionIPOAmendable(*F))
2017e8d8bef9SDimitry Andric       indicatePessimisticFixpoint();
2018e8d8bef9SDimitry Andric   }
2019e8d8bef9SDimitry Andric 
20205ffd83dbSDimitry Andric   /// Returns true if value is assumed to be tracked.
20215ffd83dbSDimitry Andric   bool isAssumedTracked() const { return getAssumed(); }
20225ffd83dbSDimitry Andric 
20235ffd83dbSDimitry Andric   /// Returns true if value is known to be tracked.
20245ffd83dbSDimitry Andric   bool isKnownTracked() const { return getAssumed(); }
20255ffd83dbSDimitry Andric 
20265ffd83dbSDimitry Andric   /// Create an abstract attribute view for the position \p IRP.
20275ffd83dbSDimitry Andric   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
20285ffd83dbSDimitry Andric 
20295ffd83dbSDimitry Andric   /// Return the value with which \p I can be replaced for specific \p ICV.
2030e8d8bef9SDimitry Andric   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2031e8d8bef9SDimitry Andric                                                 const Instruction *I,
2032e8d8bef9SDimitry Andric                                                 Attributor &A) const {
2033e8d8bef9SDimitry Andric     return None;
2034e8d8bef9SDimitry Andric   }
2035e8d8bef9SDimitry Andric 
2036e8d8bef9SDimitry Andric   /// Return an assumed unique ICV value if a single candidate is found. If
2037e8d8bef9SDimitry Andric   /// there cannot be one, return nullptr. If it is not clear yet, return
2038e8d8bef9SDimitry Andric   /// None.
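  ///
  /// A minimal consumer sketch of the resulting tri-state (variable names are
  /// illustrative):
  ///
  /// \code
  ///   Optional<Value *> R = ICVTrackingAA.getUniqueReplacementValue(ICV_nthreads);
  ///   if (!R.hasValue()) {
  ///     // Not known yet, the fixpoint iteration may still refine it.
  ///   } else if (!*R) {
  ///     // Known: there is no single unique value.
  ///   } else {
  ///     // *R is the assumed unique ICV value.
  ///   }
  /// \endcode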
2039e8d8bef9SDimitry Andric   virtual Optional<Value *>
2040e8d8bef9SDimitry Andric   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2041e8d8bef9SDimitry Andric 
2042e8d8bef9SDimitry Andric   // Currently only nthreads is being tracked.
2043e8d8bef9SDimitry Andric   // This array will only grow with time.
2044e8d8bef9SDimitry Andric   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
20455ffd83dbSDimitry Andric 
20465ffd83dbSDimitry Andric   /// See AbstractAttribute::getName()
20475ffd83dbSDimitry Andric   const std::string getName() const override { return "AAICVTracker"; }
20485ffd83dbSDimitry Andric 
20495ffd83dbSDimitry Andric   /// See AbstractAttribute::getIdAddr()
20505ffd83dbSDimitry Andric   const char *getIdAddr() const override { return &ID; }
20515ffd83dbSDimitry Andric 
20525ffd83dbSDimitry Andric   /// This function should return true if the type of the \p AA is AAICVTracker
20535ffd83dbSDimitry Andric   static bool classof(const AbstractAttribute *AA) {
20545ffd83dbSDimitry Andric     return (AA->getIdAddr() == &ID);
20555ffd83dbSDimitry Andric   }
20565ffd83dbSDimitry Andric 
20575ffd83dbSDimitry Andric   static const char ID;
20585ffd83dbSDimitry Andric };
20595ffd83dbSDimitry Andric 
20605ffd83dbSDimitry Andric struct AAICVTrackerFunction : public AAICVTracker {
20615ffd83dbSDimitry Andric   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
20625ffd83dbSDimitry Andric       : AAICVTracker(IRP, A) {}
20635ffd83dbSDimitry Andric 
20645ffd83dbSDimitry Andric   // FIXME: come up with better string.
2065e8d8bef9SDimitry Andric   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
20665ffd83dbSDimitry Andric 
20675ffd83dbSDimitry Andric   // FIXME: come up with some stats.
20685ffd83dbSDimitry Andric   void trackStatistics() const override {}
20695ffd83dbSDimitry Andric 
2070e8d8bef9SDimitry Andric   /// We don't manifest anything for this AA.
20715ffd83dbSDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2072e8d8bef9SDimitry Andric     return ChangeStatus::UNCHANGED;
20735ffd83dbSDimitry Andric   }
20745ffd83dbSDimitry Andric 
20755ffd83dbSDimitry Andric   // Map of ICV to their values at specific program point.
2076e8d8bef9SDimitry Andric   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
20775ffd83dbSDimitry Andric                   InternalControlVar::ICV___last>
2078e8d8bef9SDimitry Andric       ICVReplacementValuesMap;
20795ffd83dbSDimitry Andric 
20805ffd83dbSDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
20815ffd83dbSDimitry Andric     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
20825ffd83dbSDimitry Andric 
20835ffd83dbSDimitry Andric     Function *F = getAnchorScope();
20845ffd83dbSDimitry Andric 
20855ffd83dbSDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
20865ffd83dbSDimitry Andric 
20875ffd83dbSDimitry Andric     for (InternalControlVar ICV : TrackableICVs) {
20885ffd83dbSDimitry Andric       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
20895ffd83dbSDimitry Andric 
2090e8d8bef9SDimitry Andric       auto &ValuesMap = ICVReplacementValuesMap[ICV];
20915ffd83dbSDimitry Andric       auto TrackValues = [&](Use &U, Function &) {
20925ffd83dbSDimitry Andric         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
20935ffd83dbSDimitry Andric         if (!CI)
20945ffd83dbSDimitry Andric           return false;
20955ffd83dbSDimitry Andric 
20965ffd83dbSDimitry Andric         // FIXME: handle setters with more than one argument.
20975ffd83dbSDimitry Andric         /// Track new value.
2098e8d8bef9SDimitry Andric         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
20995ffd83dbSDimitry Andric           HasChanged = ChangeStatus::CHANGED;
21005ffd83dbSDimitry Andric 
21015ffd83dbSDimitry Andric         return false;
21025ffd83dbSDimitry Andric       };
21035ffd83dbSDimitry Andric 
2104e8d8bef9SDimitry Andric       auto CallCheck = [&](Instruction &I) {
2105e8d8bef9SDimitry Andric         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2106e8d8bef9SDimitry Andric         if (ReplVal.hasValue() &&
2107e8d8bef9SDimitry Andric             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2108e8d8bef9SDimitry Andric           HasChanged = ChangeStatus::CHANGED;
2109e8d8bef9SDimitry Andric 
2110e8d8bef9SDimitry Andric         return true;
2111e8d8bef9SDimitry Andric       };
2112e8d8bef9SDimitry Andric 
2113e8d8bef9SDimitry Andric       // Track all changes of an ICV.
21145ffd83dbSDimitry Andric       SetterRFI.foreachUse(TrackValues, F);
2115e8d8bef9SDimitry Andric 
2116*fe6060f1SDimitry Andric       bool UsedAssumedInformation = false;
2117e8d8bef9SDimitry Andric       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2118*fe6060f1SDimitry Andric                                 UsedAssumedInformation,
2119e8d8bef9SDimitry Andric                                 /* CheckBBLivenessOnly */ true);
2120e8d8bef9SDimitry Andric 
2121e8d8bef9SDimitry Andric       /// TODO: Figure out a way to avoid adding entry in
2122e8d8bef9SDimitry Andric       /// ICVReplacementValuesMap
2123e8d8bef9SDimitry Andric       Instruction *Entry = &F->getEntryBlock().front();
2124e8d8bef9SDimitry Andric       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2125e8d8bef9SDimitry Andric         ValuesMap.insert(std::make_pair(Entry, nullptr));
21265ffd83dbSDimitry Andric     }
21275ffd83dbSDimitry Andric 
21285ffd83dbSDimitry Andric     return HasChanged;
21295ffd83dbSDimitry Andric   }
21305ffd83dbSDimitry Andric 
2131e8d8bef9SDimitry Andric   /// Helper to check if \p I is a call and get the value for it if it is
2132e8d8bef9SDimitry Andric   /// unique.
2133e8d8bef9SDimitry Andric   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2134e8d8bef9SDimitry Andric                                     InternalControlVar &ICV) const {
21355ffd83dbSDimitry Andric 
2136e8d8bef9SDimitry Andric     const auto *CB = dyn_cast<CallBase>(I);
2137e8d8bef9SDimitry Andric     if (!CB || CB->hasFnAttr("no_openmp") ||
2138e8d8bef9SDimitry Andric         CB->hasFnAttr("no_openmp_routines"))
2139e8d8bef9SDimitry Andric       return None;
2140e8d8bef9SDimitry Andric 
21415ffd83dbSDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
21425ffd83dbSDimitry Andric     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2143e8d8bef9SDimitry Andric     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2144e8d8bef9SDimitry Andric     Function *CalledFunction = CB->getCalledFunction();
21455ffd83dbSDimitry Andric 
2146e8d8bef9SDimitry Andric     // Indirect call, assume ICV changes.
2147e8d8bef9SDimitry Andric     if (CalledFunction == nullptr)
2148e8d8bef9SDimitry Andric       return nullptr;
2149e8d8bef9SDimitry Andric     if (CalledFunction == GetterRFI.Declaration)
2150e8d8bef9SDimitry Andric       return None;
2151e8d8bef9SDimitry Andric     if (CalledFunction == SetterRFI.Declaration) {
2152e8d8bef9SDimitry Andric       if (ICVReplacementValuesMap[ICV].count(I))
2153e8d8bef9SDimitry Andric         return ICVReplacementValuesMap[ICV].lookup(I);
2154e8d8bef9SDimitry Andric 
2155e8d8bef9SDimitry Andric       return nullptr;
2156e8d8bef9SDimitry Andric     }
2157e8d8bef9SDimitry Andric 
2158e8d8bef9SDimitry Andric     // Since we don't know, assume it changes the ICV.
2159e8d8bef9SDimitry Andric     if (CalledFunction->isDeclaration())
2160e8d8bef9SDimitry Andric       return nullptr;
2161e8d8bef9SDimitry Andric 
2162*fe6060f1SDimitry Andric     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2163*fe6060f1SDimitry Andric         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2164e8d8bef9SDimitry Andric 
2165e8d8bef9SDimitry Andric     if (ICVTrackingAA.isAssumedTracked())
2166e8d8bef9SDimitry Andric       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2167e8d8bef9SDimitry Andric 
2168e8d8bef9SDimitry Andric     // If we don't know, assume it changes.
2169e8d8bef9SDimitry Andric     return nullptr;
2170e8d8bef9SDimitry Andric   }
2171e8d8bef9SDimitry Andric 
2172e8d8bef9SDimitry Andric   // We don't check for a unique value for a function, so return None.
2173e8d8bef9SDimitry Andric   Optional<Value *>
2174e8d8bef9SDimitry Andric   getUniqueReplacementValue(InternalControlVar ICV) const override {
2175e8d8bef9SDimitry Andric     return None;
2176e8d8bef9SDimitry Andric   }
2177e8d8bef9SDimitry Andric 
2178e8d8bef9SDimitry Andric   /// Return the value with which \p I can be replaced for specific \p ICV.
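  /// This walks backwards from \p I: first over the instructions preceding it
  /// in its block, then over predecessor terminators via a worklist, taking
  /// the first known setter or call value on each path and collapsing
  /// conflicting values to nullptr (unknown).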
2179e8d8bef9SDimitry Andric   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2180e8d8bef9SDimitry Andric                                         const Instruction *I,
2181e8d8bef9SDimitry Andric                                         Attributor &A) const override {
2182e8d8bef9SDimitry Andric     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2183e8d8bef9SDimitry Andric     if (ValuesMap.count(I))
2184e8d8bef9SDimitry Andric       return ValuesMap.lookup(I);
2185e8d8bef9SDimitry Andric 
2186e8d8bef9SDimitry Andric     SmallVector<const Instruction *, 16> Worklist;
2187e8d8bef9SDimitry Andric     SmallPtrSet<const Instruction *, 16> Visited;
2188e8d8bef9SDimitry Andric     Worklist.push_back(I);
2189e8d8bef9SDimitry Andric 
2190e8d8bef9SDimitry Andric     Optional<Value *> ReplVal;
2191e8d8bef9SDimitry Andric 
2192e8d8bef9SDimitry Andric     while (!Worklist.empty()) {
2193e8d8bef9SDimitry Andric       const Instruction *CurrInst = Worklist.pop_back_val();
2194e8d8bef9SDimitry Andric       if (!Visited.insert(CurrInst).second)
21955ffd83dbSDimitry Andric         continue;
21965ffd83dbSDimitry Andric 
2197e8d8bef9SDimitry Andric       const BasicBlock *CurrBB = CurrInst->getParent();
2198e8d8bef9SDimitry Andric 
2199e8d8bef9SDimitry Andric       // Go up and look for all potential setters/calls that might change the
2200e8d8bef9SDimitry Andric       // ICV.
2201e8d8bef9SDimitry Andric       while ((CurrInst = CurrInst->getPrevNode())) {
2202e8d8bef9SDimitry Andric         if (ValuesMap.count(CurrInst)) {
2203e8d8bef9SDimitry Andric           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2204e8d8bef9SDimitry Andric           // Unknown value, track new.
2205e8d8bef9SDimitry Andric           if (!ReplVal.hasValue()) {
2206e8d8bef9SDimitry Andric             ReplVal = NewReplVal;
2207e8d8bef9SDimitry Andric             break;
2208e8d8bef9SDimitry Andric           }
2209e8d8bef9SDimitry Andric 
2210e8d8bef9SDimitry Andric           // If we found a new value, we can't know the ICV value anymore.
2211e8d8bef9SDimitry Andric           if (NewReplVal.hasValue())
2212e8d8bef9SDimitry Andric             if (ReplVal != NewReplVal)
22135ffd83dbSDimitry Andric               return nullptr;
22145ffd83dbSDimitry Andric 
2215e8d8bef9SDimitry Andric           break;
22165ffd83dbSDimitry Andric         }
22175ffd83dbSDimitry Andric 
2218e8d8bef9SDimitry Andric         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2219e8d8bef9SDimitry Andric         if (!NewReplVal.hasValue())
2220e8d8bef9SDimitry Andric           continue;
2221e8d8bef9SDimitry Andric 
2222e8d8bef9SDimitry Andric         // Unknown value, track new.
2223e8d8bef9SDimitry Andric         if (!ReplVal.hasValue()) {
2224e8d8bef9SDimitry Andric           ReplVal = NewReplVal;
2225e8d8bef9SDimitry Andric           break;
22265ffd83dbSDimitry Andric         }
22275ffd83dbSDimitry Andric 
2228e8d8bef9SDimitry Andric         // NewReplVal is known to have a value at this point (see the continue
2229e8d8bef9SDimitry Andric         // above). We found a new value, so we can't know the ICV value anymore.
2230e8d8bef9SDimitry Andric         if (ReplVal != NewReplVal)
22315ffd83dbSDimitry Andric           return nullptr;
22325ffd83dbSDimitry Andric       }
2233e8d8bef9SDimitry Andric 
2234e8d8bef9SDimitry Andric       // If we are in the same BB and we have a value, we are done.
2235e8d8bef9SDimitry Andric       if (CurrBB == I->getParent() && ReplVal.hasValue())
2236e8d8bef9SDimitry Andric         return ReplVal;
2237e8d8bef9SDimitry Andric 
2238e8d8bef9SDimitry Andric       // Go through all predecessors and add terminators for analysis.
2239e8d8bef9SDimitry Andric       for (const BasicBlock *Pred : predecessors(CurrBB))
2240e8d8bef9SDimitry Andric         if (const Instruction *Terminator = Pred->getTerminator())
2241e8d8bef9SDimitry Andric           Worklist.push_back(Terminator);
2242e8d8bef9SDimitry Andric     }
2243e8d8bef9SDimitry Andric 
2244e8d8bef9SDimitry Andric     return ReplVal;
2245e8d8bef9SDimitry Andric   }
2246e8d8bef9SDimitry Andric };
2247e8d8bef9SDimitry Andric 
2248e8d8bef9SDimitry Andric struct AAICVTrackerFunctionReturned : AAICVTracker {
2249e8d8bef9SDimitry Andric   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2250e8d8bef9SDimitry Andric       : AAICVTracker(IRP, A) {}
2251e8d8bef9SDimitry Andric 
2252e8d8bef9SDimitry Andric   // FIXME: come up with better string.
2253e8d8bef9SDimitry Andric   const std::string getAsStr() const override {
2254e8d8bef9SDimitry Andric     return "ICVTrackerFunctionReturned";
2255e8d8bef9SDimitry Andric   }
2256e8d8bef9SDimitry Andric 
2257e8d8bef9SDimitry Andric   // FIXME: come up with some stats.
2258e8d8bef9SDimitry Andric   void trackStatistics() const override {}
2259e8d8bef9SDimitry Andric 
2260e8d8bef9SDimitry Andric   /// We don't manifest anything for this AA.
2261e8d8bef9SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2262e8d8bef9SDimitry Andric     return ChangeStatus::UNCHANGED;
2263e8d8bef9SDimitry Andric   }
2264e8d8bef9SDimitry Andric 
2265e8d8bef9SDimitry Andric   // Map of ICVs to their values at specific program points.
2266e8d8bef9SDimitry Andric   EnumeratedArray<Optional<Value *>, InternalControlVar,
2267e8d8bef9SDimitry Andric                   InternalControlVar::ICV___last>
2268e8d8bef9SDimitry Andric       ICVReplacementValuesMap;
2269e8d8bef9SDimitry Andric 
2270e8d8bef9SDimitry Andric   /// Return the value with which \p I can be replaced for specific \p ICV.
2271e8d8bef9SDimitry Andric   Optional<Value *>
2272e8d8bef9SDimitry Andric   getUniqueReplacementValue(InternalControlVar ICV) const override {
2273e8d8bef9SDimitry Andric     return ICVReplacementValuesMap[ICV];
2274e8d8bef9SDimitry Andric   }
2275e8d8bef9SDimitry Andric 
2276e8d8bef9SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
2277e8d8bef9SDimitry Andric     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2278e8d8bef9SDimitry Andric     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2279*fe6060f1SDimitry Andric         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2280e8d8bef9SDimitry Andric 
2281e8d8bef9SDimitry Andric     if (!ICVTrackingAA.isAssumedTracked())
2282e8d8bef9SDimitry Andric       return indicatePessimisticFixpoint();
2283e8d8bef9SDimitry Andric 
2284e8d8bef9SDimitry Andric     for (InternalControlVar ICV : TrackableICVs) {
2285e8d8bef9SDimitry Andric       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2286e8d8bef9SDimitry Andric       Optional<Value *> UniqueICVValue;
2287e8d8bef9SDimitry Andric 
2288e8d8bef9SDimitry Andric       auto CheckReturnInst = [&](Instruction &I) {
2289e8d8bef9SDimitry Andric         Optional<Value *> NewReplVal =
2290e8d8bef9SDimitry Andric             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2291e8d8bef9SDimitry Andric 
2292e8d8bef9SDimitry Andric         // If we found a second ICV value there is no unique returned value.
2293e8d8bef9SDimitry Andric         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2294e8d8bef9SDimitry Andric           return false;
2295e8d8bef9SDimitry Andric 
2296e8d8bef9SDimitry Andric         UniqueICVValue = NewReplVal;
2297e8d8bef9SDimitry Andric 
2298e8d8bef9SDimitry Andric         return true;
2299e8d8bef9SDimitry Andric       };
2300e8d8bef9SDimitry Andric 
2301*fe6060f1SDimitry Andric       bool UsedAssumedInformation = false;
2302e8d8bef9SDimitry Andric       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2303*fe6060f1SDimitry Andric                                      UsedAssumedInformation,
2304e8d8bef9SDimitry Andric                                      /* CheckBBLivenessOnly */ true))
2305e8d8bef9SDimitry Andric         UniqueICVValue = nullptr;
2306e8d8bef9SDimitry Andric 
2307e8d8bef9SDimitry Andric       if (UniqueICVValue == ReplVal)
2308e8d8bef9SDimitry Andric         continue;
2309e8d8bef9SDimitry Andric 
2310e8d8bef9SDimitry Andric       ReplVal = UniqueICVValue;
2311e8d8bef9SDimitry Andric       Changed = ChangeStatus::CHANGED;
2312e8d8bef9SDimitry Andric     }
2313e8d8bef9SDimitry Andric 
2314e8d8bef9SDimitry Andric     return Changed;
2315e8d8bef9SDimitry Andric   }
2316e8d8bef9SDimitry Andric };
2317e8d8bef9SDimitry Andric 
2318e8d8bef9SDimitry Andric struct AAICVTrackerCallSite : AAICVTracker {
2319e8d8bef9SDimitry Andric   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2320e8d8bef9SDimitry Andric       : AAICVTracker(IRP, A) {}
2321e8d8bef9SDimitry Andric 
2322e8d8bef9SDimitry Andric   void initialize(Attributor &A) override {
2323e8d8bef9SDimitry Andric     Function *F = getAnchorScope();
2324e8d8bef9SDimitry Andric     if (!F || !A.isFunctionIPOAmendable(*F))
2325e8d8bef9SDimitry Andric       indicatePessimisticFixpoint();
2326e8d8bef9SDimitry Andric 
2327e8d8bef9SDimitry Andric     // We only initialize this AA for getters, so we need to know which ICV it
2328e8d8bef9SDimitry Andric     // gets.
2329e8d8bef9SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2330e8d8bef9SDimitry Andric     for (InternalControlVar ICV : TrackableICVs) {
2331e8d8bef9SDimitry Andric       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2332e8d8bef9SDimitry Andric       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2333e8d8bef9SDimitry Andric       if (Getter.Declaration == getAssociatedFunction()) {
2334e8d8bef9SDimitry Andric         AssociatedICV = ICVInfo.Kind;
2335e8d8bef9SDimitry Andric         return;
2336e8d8bef9SDimitry Andric       }
2337e8d8bef9SDimitry Andric     }
2338e8d8bef9SDimitry Andric 
2339e8d8bef9SDimitry Andric     /// Unknown ICV.
2340e8d8bef9SDimitry Andric     indicatePessimisticFixpoint();
2341e8d8bef9SDimitry Andric   }
2342e8d8bef9SDimitry Andric 
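  /// If a unique ICV value is known at this getter call site, replace the
  /// call's result with that value and delete the call.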
2343e8d8bef9SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2344e8d8bef9SDimitry Andric     if (!ReplVal.hasValue() || !ReplVal.getValue())
2345e8d8bef9SDimitry Andric       return ChangeStatus::UNCHANGED;
2346e8d8bef9SDimitry Andric 
2347e8d8bef9SDimitry Andric     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2348e8d8bef9SDimitry Andric     A.deleteAfterManifest(*getCtxI());
2349e8d8bef9SDimitry Andric 
2350e8d8bef9SDimitry Andric     return ChangeStatus::CHANGED;
2351e8d8bef9SDimitry Andric   }
2352e8d8bef9SDimitry Andric 
2353e8d8bef9SDimitry Andric   // FIXME: come up with better string.
2354e8d8bef9SDimitry Andric   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2355e8d8bef9SDimitry Andric 
2356e8d8bef9SDimitry Andric   // FIXME: come up with some stats.
2357e8d8bef9SDimitry Andric   void trackStatistics() const override {}
2358e8d8bef9SDimitry Andric 
2359e8d8bef9SDimitry Andric   InternalControlVar AssociatedICV;
2360e8d8bef9SDimitry Andric   Optional<Value *> ReplVal;
2361e8d8bef9SDimitry Andric 
2362e8d8bef9SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
2363e8d8bef9SDimitry Andric     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2364*fe6060f1SDimitry Andric         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2365e8d8bef9SDimitry Andric 
2366e8d8bef9SDimitry Andric     // We don't have any information, so we assume it changes the ICV.
2367e8d8bef9SDimitry Andric     if (!ICVTrackingAA.isAssumedTracked())
2368e8d8bef9SDimitry Andric       return indicatePessimisticFixpoint();
2369e8d8bef9SDimitry Andric 
2370e8d8bef9SDimitry Andric     Optional<Value *> NewReplVal =
2371e8d8bef9SDimitry Andric         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2372e8d8bef9SDimitry Andric 
2373e8d8bef9SDimitry Andric     if (ReplVal == NewReplVal)
2374e8d8bef9SDimitry Andric       return ChangeStatus::UNCHANGED;
2375e8d8bef9SDimitry Andric 
2376e8d8bef9SDimitry Andric     ReplVal = NewReplVal;
2377e8d8bef9SDimitry Andric     return ChangeStatus::CHANGED;
2378e8d8bef9SDimitry Andric   }
2379e8d8bef9SDimitry Andric 
2380e8d8bef9SDimitry Andric   // Return the value with which the associated value can be replaced for a
2381e8d8bef9SDimitry Andric   // specific \p ICV.
2382e8d8bef9SDimitry Andric   Optional<Value *>
2383e8d8bef9SDimitry Andric   getUniqueReplacementValue(InternalControlVar ICV) const override {
2384e8d8bef9SDimitry Andric     return ReplVal;
2385e8d8bef9SDimitry Andric   }
2386e8d8bef9SDimitry Andric };
2387e8d8bef9SDimitry Andric 
2388e8d8bef9SDimitry Andric struct AAICVTrackerCallSiteReturned : AAICVTracker {
2389e8d8bef9SDimitry Andric   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2390e8d8bef9SDimitry Andric       : AAICVTracker(IRP, A) {}
2391e8d8bef9SDimitry Andric 
2392e8d8bef9SDimitry Andric   // FIXME: come up with better string.
2393e8d8bef9SDimitry Andric   const std::string getAsStr() const override {
2394e8d8bef9SDimitry Andric     return "ICVTrackerCallSiteReturned";
2395e8d8bef9SDimitry Andric   }
2396e8d8bef9SDimitry Andric 
2397e8d8bef9SDimitry Andric   // FIXME: come up with some stats.
2398e8d8bef9SDimitry Andric   void trackStatistics() const override {}
2399e8d8bef9SDimitry Andric 
2400e8d8bef9SDimitry Andric   /// We don't manifest anything for this AA.
2401e8d8bef9SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2402e8d8bef9SDimitry Andric     return ChangeStatus::UNCHANGED;
2403e8d8bef9SDimitry Andric   }
2404e8d8bef9SDimitry Andric 
2405e8d8bef9SDimitry Andric   // Map of ICVs to their values at specific program points.
2406e8d8bef9SDimitry Andric   EnumeratedArray<Optional<Value *>, InternalControlVar,
2407e8d8bef9SDimitry Andric                   InternalControlVar::ICV___last>
2408e8d8bef9SDimitry Andric       ICVReplacementValuesMap;
2409e8d8bef9SDimitry Andric 
2410e8d8bef9SDimitry Andric   /// Return the value with which the associated value can be replaced for a
2411e8d8bef9SDimitry Andric   /// specific \p ICV.
2412e8d8bef9SDimitry Andric   Optional<Value *>
2413e8d8bef9SDimitry Andric   getUniqueReplacementValue(InternalControlVar ICV) const override {
2414e8d8bef9SDimitry Andric     return ICVReplacementValuesMap[ICV];
2415e8d8bef9SDimitry Andric   }
2416e8d8bef9SDimitry Andric 
2417e8d8bef9SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
2418e8d8bef9SDimitry Andric     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2419e8d8bef9SDimitry Andric     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2420*fe6060f1SDimitry Andric         *this, IRPosition::returned(*getAssociatedFunction()),
2421*fe6060f1SDimitry Andric         DepClassTy::REQUIRED);
2422e8d8bef9SDimitry Andric 
2423e8d8bef9SDimitry Andric     // We don't have any information, so we assume it changes the ICV.
2424e8d8bef9SDimitry Andric     if (!ICVTrackingAA.isAssumedTracked())
2425e8d8bef9SDimitry Andric       return indicatePessimisticFixpoint();
2426e8d8bef9SDimitry Andric 
2427e8d8bef9SDimitry Andric     for (InternalControlVar ICV : TrackableICVs) {
2428e8d8bef9SDimitry Andric       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2429e8d8bef9SDimitry Andric       Optional<Value *> NewReplVal =
2430e8d8bef9SDimitry Andric           ICVTrackingAA.getUniqueReplacementValue(ICV);
2431e8d8bef9SDimitry Andric 
2432e8d8bef9SDimitry Andric       if (ReplVal == NewReplVal)
2433e8d8bef9SDimitry Andric         continue;
2434e8d8bef9SDimitry Andric 
2435e8d8bef9SDimitry Andric       ReplVal = NewReplVal;
2436e8d8bef9SDimitry Andric       Changed = ChangeStatus::CHANGED;
2437e8d8bef9SDimitry Andric     }
2438e8d8bef9SDimitry Andric     return Changed;
2439e8d8bef9SDimitry Andric   }
24405ffd83dbSDimitry Andric };
2441*fe6060f1SDimitry Andric 
2442*fe6060f1SDimitry Andric struct AAExecutionDomainFunction : public AAExecutionDomain {
2443*fe6060f1SDimitry Andric   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2444*fe6060f1SDimitry Andric       : AAExecutionDomain(IRP, A) {}
2445*fe6060f1SDimitry Andric 
2446*fe6060f1SDimitry Andric   const std::string getAsStr() const override {
2447*fe6060f1SDimitry Andric     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2448*fe6060f1SDimitry Andric            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2449*fe6060f1SDimitry Andric   }
2450*fe6060f1SDimitry Andric 
2451*fe6060f1SDimitry Andric   /// See AbstractAttribute::trackStatistics().
2452*fe6060f1SDimitry Andric   void trackStatistics() const override {}
2453*fe6060f1SDimitry Andric 
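  /// Optimistically assume every block is executed only by the initial thread;
  /// updateImpl then drops blocks that may also be reached from multi-threaded
  /// contexts.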
2454*fe6060f1SDimitry Andric   void initialize(Attributor &A) override {
2455*fe6060f1SDimitry Andric     Function *F = getAnchorScope();
2456*fe6060f1SDimitry Andric     for (const auto &BB : *F)
2457*fe6060f1SDimitry Andric       SingleThreadedBBs.insert(&BB);
2458*fe6060f1SDimitry Andric     NumBBs = SingleThreadedBBs.size();
2459*fe6060f1SDimitry Andric   }
2460*fe6060f1SDimitry Andric 
2461*fe6060f1SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2462*fe6060f1SDimitry Andric     LLVM_DEBUG({
2463*fe6060f1SDimitry Andric       for (const BasicBlock *BB : SingleThreadedBBs)
2464*fe6060f1SDimitry Andric         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2465*fe6060f1SDimitry Andric                << BB->getName() << " is executed by a single thread.\n";
2466*fe6060f1SDimitry Andric     });
2467*fe6060f1SDimitry Andric     return ChangeStatus::UNCHANGED;
2468*fe6060f1SDimitry Andric   }
2469*fe6060f1SDimitry Andric 
2470*fe6060f1SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override;
2471*fe6060f1SDimitry Andric 
2472*fe6060f1SDimitry Andric   /// Check if an instruction is executed by a single thread.
2473*fe6060f1SDimitry Andric   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2474*fe6060f1SDimitry Andric     return isExecutedByInitialThreadOnly(*I.getParent());
2475*fe6060f1SDimitry Andric   }
2476*fe6060f1SDimitry Andric 
2477*fe6060f1SDimitry Andric   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2478*fe6060f1SDimitry Andric     return isValidState() && SingleThreadedBBs.contains(&BB);
2479*fe6060f1SDimitry Andric   }
2480*fe6060f1SDimitry Andric 
2481*fe6060f1SDimitry Andric   /// Set of basic blocks that are executed by a single thread.
2482*fe6060f1SDimitry Andric   DenseSet<const BasicBlock *> SingleThreadedBBs;
2483*fe6060f1SDimitry Andric 
2484*fe6060f1SDimitry Andric   /// Total number of basic blocks in this function.
2485*fe6060f1SDimitry Andric   long unsigned NumBBs;
2486*fe6060f1SDimitry Andric };
2487*fe6060f1SDimitry Andric 
2488*fe6060f1SDimitry Andric ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2489*fe6060f1SDimitry Andric   Function *F = getAnchorScope();
2490*fe6060f1SDimitry Andric   ReversePostOrderTraversal<Function *> RPOT(F);
2491*fe6060f1SDimitry Andric   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2492*fe6060f1SDimitry Andric 
2493*fe6060f1SDimitry Andric   bool AllCallSitesKnown;
2494*fe6060f1SDimitry Andric   auto PredForCallSite = [&](AbstractCallSite ACS) {
2495*fe6060f1SDimitry Andric     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2496*fe6060f1SDimitry Andric         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2497*fe6060f1SDimitry Andric         DepClassTy::REQUIRED);
2498*fe6060f1SDimitry Andric     return ACS.isDirectCall() &&
2499*fe6060f1SDimitry Andric            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2500*fe6060f1SDimitry Andric                *ACS.getInstruction());
2501*fe6060f1SDimitry Andric   };
2502*fe6060f1SDimitry Andric 
2503*fe6060f1SDimitry Andric   if (!A.checkForAllCallSites(PredForCallSite, *this,
2504*fe6060f1SDimitry Andric                               /* RequiresAllCallSites */ true,
2505*fe6060f1SDimitry Andric                               AllCallSitesKnown))
2506*fe6060f1SDimitry Andric     SingleThreadedBBs.erase(&F->getEntryBlock());
2507*fe6060f1SDimitry Andric 
2508*fe6060f1SDimitry Andric   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2509*fe6060f1SDimitry Andric   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2510*fe6060f1SDimitry Andric 
2511*fe6060f1SDimitry Andric   // Check if the edge into the successor block compares the __kmpc_target_init
2512*fe6060f1SDimitry Andric   // result with -1. If we are in non-SPMD mode, that signals that only the
2513*fe6060f1SDimitry Andric   // main thread will execute the edge.
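  // Roughly, the source pattern being matched looks like this (simplified,
  // C-like sketch; names and the elided arguments are illustrative):
  //
  //   ThreadKind = __kmpc_target_init(Ident, /*IsSPMD=*/false, ...);
  //   if (ThreadKind == -1) {
  //     // Only the initial (main) thread executes this successor.
  //   }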
2514*fe6060f1SDimitry Andric   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2515*fe6060f1SDimitry Andric     if (!Edge || !Edge->isConditional())
2516*fe6060f1SDimitry Andric       return false;
2517*fe6060f1SDimitry Andric     if (Edge->getSuccessor(0) != SuccessorBB)
2518*fe6060f1SDimitry Andric       return false;
2519*fe6060f1SDimitry Andric 
2520*fe6060f1SDimitry Andric     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2521*fe6060f1SDimitry Andric     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2522*fe6060f1SDimitry Andric       return false;
2523*fe6060f1SDimitry Andric 
2524*fe6060f1SDimitry Andric     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2525*fe6060f1SDimitry Andric     if (!C)
2526*fe6060f1SDimitry Andric       return false;
2527*fe6060f1SDimitry Andric 
2528*fe6060f1SDimitry Andric     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2529*fe6060f1SDimitry Andric     if (C->isAllOnesValue()) {
2530*fe6060f1SDimitry Andric       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2531*fe6060f1SDimitry Andric       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2532*fe6060f1SDimitry Andric       if (!CB)
2533*fe6060f1SDimitry Andric         return false;
2534*fe6060f1SDimitry Andric       const int InitIsSPMDArgNo = 1;
2535*fe6060f1SDimitry Andric       auto *IsSPMDModeCI =
2536*fe6060f1SDimitry Andric           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2537*fe6060f1SDimitry Andric       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2538*fe6060f1SDimitry Andric     }
2539*fe6060f1SDimitry Andric 
2540*fe6060f1SDimitry Andric     return false;
2541*fe6060f1SDimitry Andric   };
2542*fe6060f1SDimitry Andric 
2543*fe6060f1SDimitry Andric   // Merge all the predecessor states into the current basic block. A basic
2544*fe6060f1SDimitry Andric   // block is executed by a single thread if all of its predecessors are.
2545*fe6060f1SDimitry Andric   auto MergePredecessorStates = [&](BasicBlock *BB) {
2546*fe6060f1SDimitry Andric     if (pred_begin(BB) == pred_end(BB))
2547*fe6060f1SDimitry Andric       return SingleThreadedBBs.contains(BB);
2548*fe6060f1SDimitry Andric 
2549*fe6060f1SDimitry Andric     bool IsInitialThread = true;
2550*fe6060f1SDimitry Andric     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2551*fe6060f1SDimitry Andric          PredBB != PredEndBB; ++PredBB) {
2552*fe6060f1SDimitry Andric       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2553*fe6060f1SDimitry Andric                                BB))
2554*fe6060f1SDimitry Andric         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2555*fe6060f1SDimitry Andric     }
2556*fe6060f1SDimitry Andric 
2557*fe6060f1SDimitry Andric     return IsInitialThread;
2558*fe6060f1SDimitry Andric   };
2559*fe6060f1SDimitry Andric 
2560*fe6060f1SDimitry Andric   for (auto *BB : RPOT) {
2561*fe6060f1SDimitry Andric     if (!MergePredecessorStates(BB))
2562*fe6060f1SDimitry Andric       SingleThreadedBBs.erase(BB);
2563*fe6060f1SDimitry Andric   }
2564*fe6060f1SDimitry Andric 
2565*fe6060f1SDimitry Andric   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2566*fe6060f1SDimitry Andric              ? ChangeStatus::UNCHANGED
2567*fe6060f1SDimitry Andric              : ChangeStatus::CHANGED;
2568*fe6060f1SDimitry Andric }
2569*fe6060f1SDimitry Andric 
2570*fe6060f1SDimitry Andric /// Try to replace memory allocation calls called by a single thread with a
2571*fe6060f1SDimitry Andric /// static buffer of shared memory.
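///
/// Conceptually, an eligible __kmpc_alloc_shared call is rewritten to use a
/// new internal global byte array of the allocation size, and its unique
/// matching __kmpc_free_shared call becomes removable (see manifest below).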
2572*fe6060f1SDimitry Andric struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2573*fe6060f1SDimitry Andric   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2574*fe6060f1SDimitry Andric   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2575*fe6060f1SDimitry Andric 
2576*fe6060f1SDimitry Andric   /// Create an abstract attribute view for the position \p IRP.
2577*fe6060f1SDimitry Andric   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2578*fe6060f1SDimitry Andric                                            Attributor &A);
2579*fe6060f1SDimitry Andric 
2580*fe6060f1SDimitry Andric   /// Returns true if HeapToShared conversion is assumed to be possible.
2581*fe6060f1SDimitry Andric   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2582*fe6060f1SDimitry Andric 
2583*fe6060f1SDimitry Andric   /// Returns true if HeapToShared conversion is assumed and the CB is a
2584*fe6060f1SDimitry Andric   /// callsite to a free operation to be removed.
2585*fe6060f1SDimitry Andric   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2586*fe6060f1SDimitry Andric 
2587*fe6060f1SDimitry Andric   /// See AbstractAttribute::getName().
2588*fe6060f1SDimitry Andric   const std::string getName() const override { return "AAHeapToShared"; }
2589*fe6060f1SDimitry Andric 
2590*fe6060f1SDimitry Andric   /// See AbstractAttribute::getIdAddr().
2591*fe6060f1SDimitry Andric   const char *getIdAddr() const override { return &ID; }
2592*fe6060f1SDimitry Andric 
2593*fe6060f1SDimitry Andric   /// This function should return true if the type of the \p AA is
2594*fe6060f1SDimitry Andric   /// AAHeapToShared.
2595*fe6060f1SDimitry Andric   static bool classof(const AbstractAttribute *AA) {
2596*fe6060f1SDimitry Andric     return (AA->getIdAddr() == &ID);
2597*fe6060f1SDimitry Andric   }
2598*fe6060f1SDimitry Andric 
2599*fe6060f1SDimitry Andric   /// Unique ID (due to the unique address)
2600*fe6060f1SDimitry Andric   static const char ID;
2601*fe6060f1SDimitry Andric };
2602*fe6060f1SDimitry Andric 
2603*fe6060f1SDimitry Andric struct AAHeapToSharedFunction : public AAHeapToShared {
2604*fe6060f1SDimitry Andric   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2605*fe6060f1SDimitry Andric       : AAHeapToShared(IRP, A) {}
2606*fe6060f1SDimitry Andric 
2607*fe6060f1SDimitry Andric   const std::string getAsStr() const override {
2608*fe6060f1SDimitry Andric     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2609*fe6060f1SDimitry Andric            " malloc calls eligible.";
2610*fe6060f1SDimitry Andric   }
2611*fe6060f1SDimitry Andric 
2612*fe6060f1SDimitry Andric   /// See AbstractAttribute::trackStatistics().
2613*fe6060f1SDimitry Andric   void trackStatistics() const override {}
2614*fe6060f1SDimitry Andric 
2615*fe6060f1SDimitry Andric   /// This function finds free calls that will be removed by the
2616*fe6060f1SDimitry Andric   /// HeapToShared transformation.
2617*fe6060f1SDimitry Andric   void findPotentialRemovedFreeCalls(Attributor &A) {
2618*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2619*fe6060f1SDimitry Andric     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2620*fe6060f1SDimitry Andric 
2621*fe6060f1SDimitry Andric     PotentialRemovedFreeCalls.clear();
2622*fe6060f1SDimitry Andric     // Collect the unique free-call user, if any, of each found malloc call.
2623*fe6060f1SDimitry Andric     for (CallBase *CB : MallocCalls) {
2624*fe6060f1SDimitry Andric       SmallVector<CallBase *, 4> FreeCalls;
2625*fe6060f1SDimitry Andric       for (auto *U : CB->users()) {
2626*fe6060f1SDimitry Andric         CallBase *C = dyn_cast<CallBase>(U);
2627*fe6060f1SDimitry Andric         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2628*fe6060f1SDimitry Andric           FreeCalls.push_back(C);
2629*fe6060f1SDimitry Andric       }
2630*fe6060f1SDimitry Andric 
2631*fe6060f1SDimitry Andric       if (FreeCalls.size() != 1)
2632*fe6060f1SDimitry Andric         continue;
2633*fe6060f1SDimitry Andric 
2634*fe6060f1SDimitry Andric       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2635*fe6060f1SDimitry Andric     }
2636*fe6060f1SDimitry Andric   }
2637*fe6060f1SDimitry Andric 
2638*fe6060f1SDimitry Andric   void initialize(Attributor &A) override {
2639*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2640*fe6060f1SDimitry Andric     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2641*fe6060f1SDimitry Andric 
2642*fe6060f1SDimitry Andric     for (User *U : RFI.Declaration->users())
2643*fe6060f1SDimitry Andric       if (CallBase *CB = dyn_cast<CallBase>(U))
2644*fe6060f1SDimitry Andric         MallocCalls.insert(CB);
2645*fe6060f1SDimitry Andric 
2646*fe6060f1SDimitry Andric     findPotentialRemovedFreeCalls(A);
2647*fe6060f1SDimitry Andric   }
2648*fe6060f1SDimitry Andric 
2649*fe6060f1SDimitry Andric   bool isAssumedHeapToShared(CallBase &CB) const override {
2650*fe6060f1SDimitry Andric     return isValidState() && MallocCalls.count(&CB);
2651*fe6060f1SDimitry Andric   }
2652*fe6060f1SDimitry Andric 
2653*fe6060f1SDimitry Andric   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2654*fe6060f1SDimitry Andric     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2655*fe6060f1SDimitry Andric   }
2656*fe6060f1SDimitry Andric 
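  /// Replace each eligible globalization call with a buffer in device shared
  /// memory. A minimal sketch of the intended rewrite, assuming simplified IR
  /// shapes and illustrative names:
  ///
  ///   %p = call i8* @__kmpc_alloc_shared(i64 8)  ; constant size, executed by
  ///   ...                                        ; the initial thread only
  ///   call void @__kmpc_free_shared(...)         ; unique free of %p
  ///
  /// becomes
  ///
  ///   @shared_buf = internal addrspace(3) global [8 x i8] undef  ; e.g. NVPTX
  ///
  /// with all uses of %p rewritten to a pointer cast of @shared_buf and both
  /// runtime calls deleted.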
2657*fe6060f1SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2658*fe6060f1SDimitry Andric     if (MallocCalls.empty())
2659*fe6060f1SDimitry Andric       return ChangeStatus::UNCHANGED;
2660*fe6060f1SDimitry Andric 
2661*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2662*fe6060f1SDimitry Andric     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2663*fe6060f1SDimitry Andric 
2664*fe6060f1SDimitry Andric     Function *F = getAnchorScope();
2665*fe6060f1SDimitry Andric     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2666*fe6060f1SDimitry Andric                                             DepClassTy::OPTIONAL);
2667*fe6060f1SDimitry Andric 
2668*fe6060f1SDimitry Andric     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2669*fe6060f1SDimitry Andric     for (CallBase *CB : MallocCalls) {
2670*fe6060f1SDimitry Andric       // Skip replacing this if HeapToStack has already claimed it.
2671*fe6060f1SDimitry Andric       if (HS && HS->isAssumedHeapToStack(*CB))
2672*fe6060f1SDimitry Andric         continue;
2673*fe6060f1SDimitry Andric 
2674*fe6060f1SDimitry Andric       // Find the unique free call so we can remove it as well.
2675*fe6060f1SDimitry Andric       SmallVector<CallBase *, 4> FreeCalls;
2676*fe6060f1SDimitry Andric       for (auto *U : CB->users()) {
2677*fe6060f1SDimitry Andric         CallBase *C = dyn_cast<CallBase>(U);
2678*fe6060f1SDimitry Andric         if (C && C->getCalledFunction() == FreeCall.Declaration)
2679*fe6060f1SDimitry Andric           FreeCalls.push_back(C);
2680*fe6060f1SDimitry Andric       }
2681*fe6060f1SDimitry Andric       if (FreeCalls.size() != 1)
2682*fe6060f1SDimitry Andric         continue;
2683*fe6060f1SDimitry Andric 
2684*fe6060f1SDimitry Andric       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2685*fe6060f1SDimitry Andric 
2686*fe6060f1SDimitry Andric       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
2687*fe6060f1SDimitry Andric                         << CB->getCaller()->getName() << " with "
2688*fe6060f1SDimitry Andric                         << AllocSize->getZExtValue()
2689*fe6060f1SDimitry Andric                         << " bytes of shared memory\n");
2690*fe6060f1SDimitry Andric 
2691*fe6060f1SDimitry Andric       // Create a new shared memory buffer of the same size as the allocation
2692*fe6060f1SDimitry Andric       // and replace all the uses of the original allocation with it.
2693*fe6060f1SDimitry Andric       Module *M = CB->getModule();
2694*fe6060f1SDimitry Andric       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2695*fe6060f1SDimitry Andric       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2696*fe6060f1SDimitry Andric       auto *SharedMem = new GlobalVariable(
2697*fe6060f1SDimitry Andric           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2698*fe6060f1SDimitry Andric           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2699*fe6060f1SDimitry Andric           GlobalValue::NotThreadLocal,
2700*fe6060f1SDimitry Andric           static_cast<unsigned>(AddressSpace::Shared));
2701*fe6060f1SDimitry Andric       auto *NewBuffer =
2702*fe6060f1SDimitry Andric           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2703*fe6060f1SDimitry Andric 
2704*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
2705*fe6060f1SDimitry Andric         return OR << "Replaced globalized variable with "
2706*fe6060f1SDimitry Andric                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2707*fe6060f1SDimitry Andric                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2708*fe6060f1SDimitry Andric                   << "of shared memory.";
2709*fe6060f1SDimitry Andric       };
2710*fe6060f1SDimitry Andric       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2711*fe6060f1SDimitry Andric 
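      // Over-align the new buffer; 32 bytes is a conservative choice that is
      // assumed to cover the alignment of whatever was stored in the original
      // globalized allocation.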
2712*fe6060f1SDimitry Andric       SharedMem->setAlignment(MaybeAlign(32));
2713*fe6060f1SDimitry Andric 
2714*fe6060f1SDimitry Andric       A.changeValueAfterManifest(*CB, *NewBuffer);
2715*fe6060f1SDimitry Andric       A.deleteAfterManifest(*CB);
2716*fe6060f1SDimitry Andric       A.deleteAfterManifest(*FreeCalls.front());
2717*fe6060f1SDimitry Andric 
2718*fe6060f1SDimitry Andric       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2719*fe6060f1SDimitry Andric       Changed = ChangeStatus::CHANGED;
2720*fe6060f1SDimitry Andric     }
2721*fe6060f1SDimitry Andric 
2722*fe6060f1SDimitry Andric     return Changed;
2723*fe6060f1SDimitry Andric   }
2724*fe6060f1SDimitry Andric 
2725*fe6060f1SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
2726*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2727*fe6060f1SDimitry Andric     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2728*fe6060f1SDimitry Andric     Function *F = getAnchorScope();
2729*fe6060f1SDimitry Andric 
2730*fe6060f1SDimitry Andric     auto NumMallocCalls = MallocCalls.size();
2731*fe6060f1SDimitry Andric 
2732*fe6060f1SDimitry Andric     // Only keep malloc calls with a constant size executed by the initial thread.
2733*fe6060f1SDimitry Andric     for (User *U : RFI.Declaration->users()) {
2734*fe6060f1SDimitry Andric       const auto &ED = A.getAAFor<AAExecutionDomain>(
2735*fe6060f1SDimitry Andric           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2736*fe6060f1SDimitry Andric       if (CallBase *CB = dyn_cast<CallBase>(U))
2737*fe6060f1SDimitry Andric         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2738*fe6060f1SDimitry Andric             !ED.isExecutedByInitialThreadOnly(*CB))
2739*fe6060f1SDimitry Andric           MallocCalls.erase(CB);
2740*fe6060f1SDimitry Andric     }
2741*fe6060f1SDimitry Andric 
2742*fe6060f1SDimitry Andric     findPotentialRemovedFreeCalls(A);
2743*fe6060f1SDimitry Andric 
2744*fe6060f1SDimitry Andric     if (NumMallocCalls != MallocCalls.size())
2745*fe6060f1SDimitry Andric       return ChangeStatus::CHANGED;
2746*fe6060f1SDimitry Andric 
2747*fe6060f1SDimitry Andric     return ChangeStatus::UNCHANGED;
2748*fe6060f1SDimitry Andric   }
2749*fe6060f1SDimitry Andric 
2750*fe6060f1SDimitry Andric   /// Collection of all malloc calls in a function.
2751*fe6060f1SDimitry Andric   SmallPtrSet<CallBase *, 4> MallocCalls;
2752*fe6060f1SDimitry Andric   /// Collection of potentially removed free calls in a function.
2753*fe6060f1SDimitry Andric   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2754*fe6060f1SDimitry Andric };
2755*fe6060f1SDimitry Andric 
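/// Abstract attribute wrapping a KernelInfoState, i.e., everything this pass
/// tracks about an OpenMP device kernel: SPMD compatibility, known and unknown
/// reached parallel regions, reaching kernel entries, and parallel levels.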
2756*fe6060f1SDimitry Andric struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2757*fe6060f1SDimitry Andric   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2758*fe6060f1SDimitry Andric   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2759*fe6060f1SDimitry Andric 
2760*fe6060f1SDimitry Andric   /// Statistics are tracked as part of manifest for now.
2761*fe6060f1SDimitry Andric   void trackStatistics() const override {}
2762*fe6060f1SDimitry Andric 
2763*fe6060f1SDimitry Andric   /// See AbstractAttribute::getAsStr()
2764*fe6060f1SDimitry Andric   const std::string getAsStr() const override {
2765*fe6060f1SDimitry Andric     if (!isValidState())
2766*fe6060f1SDimitry Andric       return "<invalid>";
2767*fe6060f1SDimitry Andric     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2768*fe6060f1SDimitry Andric                                                             : "generic") +
2769*fe6060f1SDimitry Andric            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2770*fe6060f1SDimitry Andric                                                                : "") +
2771*fe6060f1SDimitry Andric            std::string(" #PRs: ") +
2772*fe6060f1SDimitry Andric            std::to_string(ReachedKnownParallelRegions.size()) +
2773*fe6060f1SDimitry Andric            ", #Unknown PRs: " +
2774*fe6060f1SDimitry Andric            std::to_string(ReachedUnknownParallelRegions.size());
2775*fe6060f1SDimitry Andric   }
2776*fe6060f1SDimitry Andric 
2777*fe6060f1SDimitry Andric   /// Create an abstract attribute view for the position \p IRP.
2778*fe6060f1SDimitry Andric   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2779*fe6060f1SDimitry Andric 
2780*fe6060f1SDimitry Andric   /// See AbstractAttribute::getName()
2781*fe6060f1SDimitry Andric   const std::string getName() const override { return "AAKernelInfo"; }
2782*fe6060f1SDimitry Andric 
2783*fe6060f1SDimitry Andric   /// See AbstractAttribute::getIdAddr()
2784*fe6060f1SDimitry Andric   const char *getIdAddr() const override { return &ID; }
2785*fe6060f1SDimitry Andric 
2786*fe6060f1SDimitry Andric   /// This function should return true if the type of the \p AA is AAKernelInfo
2787*fe6060f1SDimitry Andric   static bool classof(const AbstractAttribute *AA) {
2788*fe6060f1SDimitry Andric     return (AA->getIdAddr() == &ID);
2789*fe6060f1SDimitry Andric   }
2790*fe6060f1SDimitry Andric 
2791*fe6060f1SDimitry Andric   static const char ID;
2792*fe6060f1SDimitry Andric };
2793*fe6060f1SDimitry Andric 
2794*fe6060f1SDimitry Andric /// The function kernel info abstract attribute, basically, what can we say
2795*fe6060f1SDimitry Andric /// about a function with regards to the KernelInfoState.
2796*fe6060f1SDimitry Andric struct AAKernelInfoFunction : AAKernelInfo {
2797*fe6060f1SDimitry Andric   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2798*fe6060f1SDimitry Andric       : AAKernelInfo(IRP, A) {}
2799*fe6060f1SDimitry Andric 
2800*fe6060f1SDimitry Andric   /// See AbstractAttribute::initialize(...).
2801*fe6060f1SDimitry Andric   void initialize(Attributor &A) override {
2802*fe6060f1SDimitry Andric     // This is a high-level transform that might change the constant arguments
2803*fe6060f1SDimitry Andric     // of the init and deinit calls. We need to tell the Attributor about this
2804*fe6060f1SDimitry Andric     // to avoid other parts using the current constant value for simplification.
2805*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2806*fe6060f1SDimitry Andric 
2807*fe6060f1SDimitry Andric     Function *Fn = getAnchorScope();
2808*fe6060f1SDimitry Andric     if (!OMPInfoCache.Kernels.count(Fn))
2809*fe6060f1SDimitry Andric       return;
2810*fe6060f1SDimitry Andric 
2811*fe6060f1SDimitry Andric     // Add the kernel itself to ReachingKernelEntries and set IsKernelEntry.
2812*fe6060f1SDimitry Andric     ReachingKernelEntries.insert(Fn);
2813*fe6060f1SDimitry Andric     IsKernelEntry = true;
2814*fe6060f1SDimitry Andric 
2815*fe6060f1SDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2816*fe6060f1SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2817*fe6060f1SDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2818*fe6060f1SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2819*fe6060f1SDimitry Andric 
2820*fe6060f1SDimitry Andric     // For kernels we perform more initialization work, first we find the init
2821*fe6060f1SDimitry Andric     // and deinit calls.
2822*fe6060f1SDimitry Andric     auto StoreCallBase = [](Use &U,
2823*fe6060f1SDimitry Andric                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2824*fe6060f1SDimitry Andric                             CallBase *&Storage) {
2825*fe6060f1SDimitry Andric       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2826*fe6060f1SDimitry Andric       assert(CB &&
2827*fe6060f1SDimitry Andric              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2828*fe6060f1SDimitry Andric       assert(!Storage &&
2829*fe6060f1SDimitry Andric              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2830*fe6060f1SDimitry Andric       Storage = CB;
2831*fe6060f1SDimitry Andric       return false;
2832*fe6060f1SDimitry Andric     };
2833*fe6060f1SDimitry Andric     InitRFI.foreachUse(
2834*fe6060f1SDimitry Andric         [&](Use &U, Function &) {
2835*fe6060f1SDimitry Andric           StoreCallBase(U, InitRFI, KernelInitCB);
2836*fe6060f1SDimitry Andric           return false;
2837*fe6060f1SDimitry Andric         },
2838*fe6060f1SDimitry Andric         Fn);
2839*fe6060f1SDimitry Andric     DeinitRFI.foreachUse(
2840*fe6060f1SDimitry Andric         [&](Use &U, Function &) {
2841*fe6060f1SDimitry Andric           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2842*fe6060f1SDimitry Andric           return false;
2843*fe6060f1SDimitry Andric         },
2844*fe6060f1SDimitry Andric         Fn);
2845*fe6060f1SDimitry Andric 
2846*fe6060f1SDimitry Andric     assert((KernelInitCB && KernelDeinitCB) &&
2847*fe6060f1SDimitry Andric            "Kernel without __kmpc_target_init or __kmpc_target_deinit!");
2848*fe6060f1SDimitry Andric 
2849*fe6060f1SDimitry Andric     // For kernels we might need to initialize/finalize the IsSPMD state and
2850*fe6060f1SDimitry Andric     // we need to register a simplification callback so that the Attributor
2851*fe6060f1SDimitry Andric     // knows the constant arguments to __kmpc_target_init and
2852*fe6060f1SDimitry Andric     // __kmpc_target_deinit might actually change.
2853*fe6060f1SDimitry Andric 
2854*fe6060f1SDimitry Andric     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2855*fe6060f1SDimitry Andric         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2856*fe6060f1SDimitry Andric             bool &UsedAssumedInformation) -> Optional<Value *> {
2857*fe6060f1SDimitry Andric       // IRP represents the "use generic state machine" argument of an
2858*fe6060f1SDimitry Andric       // __kmpc_target_init call. We will answer this one with the internal
2859*fe6060f1SDimitry Andric       // state. As long as we are not in an invalid state, we will create a
2860*fe6060f1SDimitry Andric       // custom state machine so the value should be an `i1 false`. If we are
2861*fe6060f1SDimitry Andric       // in an invalid state, we won't change the value that is in the IR.
2862*fe6060f1SDimitry Andric       if (!isValidState())
2863*fe6060f1SDimitry Andric         return nullptr;
2864*fe6060f1SDimitry Andric       if (AA)
2865*fe6060f1SDimitry Andric         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2866*fe6060f1SDimitry Andric       UsedAssumedInformation = !isAtFixpoint();
2867*fe6060f1SDimitry Andric       auto *FalseVal =
2868*fe6060f1SDimitry Andric           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2869*fe6060f1SDimitry Andric       return FalseVal;
2870*fe6060f1SDimitry Andric     };
2871*fe6060f1SDimitry Andric 
2872*fe6060f1SDimitry Andric     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2873*fe6060f1SDimitry Andric         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2874*fe6060f1SDimitry Andric             bool &UsedAssumedInformation) -> Optional<Value *> {
2875*fe6060f1SDimitry Andric       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2876*fe6060f1SDimitry Andric       // __kmpc_target_deinit call. We will answer this one with the internal
2877*fe6060f1SDimitry Andric       // state of the SPMDCompatibilityTracker.
2879*fe6060f1SDimitry Andric       if (!SPMDCompatibilityTracker.isValidState())
2880*fe6060f1SDimitry Andric         return nullptr;
2881*fe6060f1SDimitry Andric       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2882*fe6060f1SDimitry Andric         if (AA)
2883*fe6060f1SDimitry Andric           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2884*fe6060f1SDimitry Andric         UsedAssumedInformation = true;
2885*fe6060f1SDimitry Andric       } else {
2886*fe6060f1SDimitry Andric         UsedAssumedInformation = false;
2887*fe6060f1SDimitry Andric       }
2888*fe6060f1SDimitry Andric       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2889*fe6060f1SDimitry Andric                                        SPMDCompatibilityTracker.isAssumed());
2890*fe6060f1SDimitry Andric       return Val;
2891*fe6060f1SDimitry Andric     };
2892*fe6060f1SDimitry Andric 
2893*fe6060f1SDimitry Andric     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2894*fe6060f1SDimitry Andric         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2895*fe6060f1SDimitry Andric             bool &UsedAssumedInformation) -> Optional<Value *> {
2896*fe6060f1SDimitry Andric       // IRP represents the "RequiresFullRuntime" argument of an
2897*fe6060f1SDimitry Andric       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2898*fe6060f1SDimitry Andric       // one with the internal state of the SPMDCompatibilityTracker, so if
2899*fe6060f1SDimitry Andric       // generic then true, if SPMD then false.
2900*fe6060f1SDimitry Andric       if (!SPMDCompatibilityTracker.isValidState())
2901*fe6060f1SDimitry Andric         return nullptr;
2902*fe6060f1SDimitry Andric       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2903*fe6060f1SDimitry Andric         if (AA)
2904*fe6060f1SDimitry Andric           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2905*fe6060f1SDimitry Andric         UsedAssumedInformation = true;
2906*fe6060f1SDimitry Andric       } else {
2907*fe6060f1SDimitry Andric         UsedAssumedInformation = false;
2908*fe6060f1SDimitry Andric       }
2909*fe6060f1SDimitry Andric       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2910*fe6060f1SDimitry Andric                                        !SPMDCompatibilityTracker.isAssumed());
2911*fe6060f1SDimitry Andric       return Val;
2912*fe6060f1SDimitry Andric     };
2913*fe6060f1SDimitry Andric 
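    // The argument indices below reflect the runtime entry points as this
    // pass understands them; an illustrative (simplified, assumed) layout:
    //   __kmpc_target_init(%ident, i1 IsSPMD, i1 UseGenericStateMachine,
    //                      i1 RequiresFullRuntime)
    //   __kmpc_target_deinit(%ident, i1 IsSPMD, i1 RequiresFullRuntime)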
2914*fe6060f1SDimitry Andric     constexpr const int InitIsSPMDArgNo = 1;
2915*fe6060f1SDimitry Andric     constexpr const int DeinitIsSPMDArgNo = 1;
2916*fe6060f1SDimitry Andric     constexpr const int InitUseStateMachineArgNo = 2;
2917*fe6060f1SDimitry Andric     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2918*fe6060f1SDimitry Andric     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2919*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
2920*fe6060f1SDimitry Andric         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2921*fe6060f1SDimitry Andric         StateMachineSimplifyCB);
2922*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
2923*fe6060f1SDimitry Andric         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2924*fe6060f1SDimitry Andric         IsSPMDModeSimplifyCB);
2925*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
2926*fe6060f1SDimitry Andric         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2927*fe6060f1SDimitry Andric         IsSPMDModeSimplifyCB);
2928*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
2929*fe6060f1SDimitry Andric         IRPosition::callsite_argument(*KernelInitCB,
2930*fe6060f1SDimitry Andric                                       InitRequiresFullRuntimeArgNo),
2931*fe6060f1SDimitry Andric         IsGenericModeSimplifyCB);
2932*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
2933*fe6060f1SDimitry Andric         IRPosition::callsite_argument(*KernelDeinitCB,
2934*fe6060f1SDimitry Andric                                       DeinitRequiresFullRuntimeArgNo),
2935*fe6060f1SDimitry Andric         IsGenericModeSimplifyCB);
2936*fe6060f1SDimitry Andric 
2937*fe6060f1SDimitry Andric     // Check if we know we are in SPMD-mode already.
2938*fe6060f1SDimitry Andric     ConstantInt *IsSPMDArg =
2939*fe6060f1SDimitry Andric         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2940*fe6060f1SDimitry Andric     if (IsSPMDArg && !IsSPMDArg->isZero())
2941*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2942*fe6060f1SDimitry Andric   }
2943*fe6060f1SDimitry Andric 
2944*fe6060f1SDimitry Andric   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
2945*fe6060f1SDimitry Andric   /// finished now.
2946*fe6060f1SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
2947*fe6060f1SDimitry Andric     // If we are not looking at a kernel with __kmpc_target_init and
2948*fe6060f1SDimitry Andric     // __kmpc_target_deinit calls, we cannot actually manifest the information.
2949*fe6060f1SDimitry Andric     if (!KernelInitCB || !KernelDeinitCB)
2950*fe6060f1SDimitry Andric       return ChangeStatus::UNCHANGED;
2951*fe6060f1SDimitry Andric 
2952*fe6060f1SDimitry Andric     // Known SPMD-mode kernels need no manifest changes.
2953*fe6060f1SDimitry Andric     if (SPMDCompatibilityTracker.isKnown())
2954*fe6060f1SDimitry Andric       return ChangeStatus::UNCHANGED;
2955*fe6060f1SDimitry Andric 
2956*fe6060f1SDimitry Andric     // If we can, we change the execution mode to SPMD-mode; otherwise we build
2957*fe6060f1SDimitry Andric     // a custom state machine.
2958*fe6060f1SDimitry Andric     if (!changeToSPMDMode(A))
2959*fe6060f1SDimitry Andric       buildCustomStateMachine(A);
2960*fe6060f1SDimitry Andric 
2961*fe6060f1SDimitry Andric     return ChangeStatus::CHANGED;
2962*fe6060f1SDimitry Andric   }
2963*fe6060f1SDimitry Andric 
2964*fe6060f1SDimitry Andric   bool changeToSPMDMode(Attributor &A) {
2965*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2966*fe6060f1SDimitry Andric 
2967*fe6060f1SDimitry Andric     if (!SPMDCompatibilityTracker.isAssumed()) {
2968*fe6060f1SDimitry Andric       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
2969*fe6060f1SDimitry Andric         if (!NonCompatibleI)
2970*fe6060f1SDimitry Andric           continue;
2971*fe6060f1SDimitry Andric 
2972*fe6060f1SDimitry Andric         // Skip diagnostics on calls to known OpenMP runtime functions for now.
2973*fe6060f1SDimitry Andric         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
2974*fe6060f1SDimitry Andric           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
2975*fe6060f1SDimitry Andric             continue;
2976*fe6060f1SDimitry Andric 
2977*fe6060f1SDimitry Andric         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2978*fe6060f1SDimitry Andric           ORA << "Value has potential side effects preventing SPMD-mode "
2979*fe6060f1SDimitry Andric                  "execution";
2980*fe6060f1SDimitry Andric           if (isa<CallBase>(NonCompatibleI)) {
2981*fe6060f1SDimitry Andric             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
2982*fe6060f1SDimitry Andric                    "the called function to override";
2983*fe6060f1SDimitry Andric           }
2984*fe6060f1SDimitry Andric           return ORA << ".";
2985*fe6060f1SDimitry Andric         };
2986*fe6060f1SDimitry Andric         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
2987*fe6060f1SDimitry Andric                                                  Remark);
2988*fe6060f1SDimitry Andric 
2989*fe6060f1SDimitry Andric         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
2990*fe6060f1SDimitry Andric                           << *NonCompatibleI << "\n");
2991*fe6060f1SDimitry Andric       }
2992*fe6060f1SDimitry Andric 
2993*fe6060f1SDimitry Andric       return false;
2994*fe6060f1SDimitry Andric     }
2995*fe6060f1SDimitry Andric 
2996*fe6060f1SDimitry Andric     // Adjust the global exec mode flag that tells the runtime what mode this
2997*fe6060f1SDimitry Andric     // kernel is executed in.
2998*fe6060f1SDimitry Andric     Function *Kernel = getAnchorScope();
2999*fe6060f1SDimitry Andric     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3000*fe6060f1SDimitry Andric         (Kernel->getName() + "_exec_mode").str());
3001*fe6060f1SDimitry Andric     assert(ExecMode && "Kernel without exec mode?");
3002*fe6060f1SDimitry Andric     assert(ExecMode->getInitializer() &&
3003*fe6060f1SDimitry Andric            ExecMode->getInitializer()->isOneValue() &&
3004*fe6060f1SDimitry Andric            "Initially non-SPMD kernel has SPMD exec mode!");
3005*fe6060f1SDimitry Andric 
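    // Exec-mode encoding assumed here (inferred from the assert above): a
    // value of 1 marks a generic-mode kernel, 2 marks a generic kernel that
    // has been converted to SPMD (SPMD-Generic).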
3006*fe6060f1SDimitry Andric     // Set the global exec mode flag to indicate SPMD-Generic mode.
3007*fe6060f1SDimitry Andric     constexpr int SPMDGeneric = 2;
3008*fe6060f1SDimitry Andric     if (!ExecMode->getInitializer()->isZeroValue())
3009*fe6060f1SDimitry Andric       ExecMode->setInitializer(
3010*fe6060f1SDimitry Andric           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3011*fe6060f1SDimitry Andric 
3012*fe6060f1SDimitry Andric     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3013*fe6060f1SDimitry Andric     const int InitIsSPMDArgNo = 1;
3014*fe6060f1SDimitry Andric     const int DeinitIsSPMDArgNo = 1;
3015*fe6060f1SDimitry Andric     const int InitUseStateMachineArgNo = 2;
3016*fe6060f1SDimitry Andric     const int InitRequiresFullRuntimeArgNo = 3;
3017*fe6060f1SDimitry Andric     const int DeinitRequiresFullRuntimeArgNo = 2;
3018*fe6060f1SDimitry Andric 
3019*fe6060f1SDimitry Andric     auto &Ctx = getAnchorValue().getContext();
3020*fe6060f1SDimitry Andric     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3021*fe6060f1SDimitry Andric                              *ConstantInt::getBool(Ctx, 1));
3022*fe6060f1SDimitry Andric     A.changeUseAfterManifest(
3023*fe6060f1SDimitry Andric         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3024*fe6060f1SDimitry Andric         *ConstantInt::getBool(Ctx, 0));
3025*fe6060f1SDimitry Andric     A.changeUseAfterManifest(
3026*fe6060f1SDimitry Andric         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3027*fe6060f1SDimitry Andric         *ConstantInt::getBool(Ctx, 1));
3028*fe6060f1SDimitry Andric     A.changeUseAfterManifest(
3029*fe6060f1SDimitry Andric         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3030*fe6060f1SDimitry Andric         *ConstantInt::getBool(Ctx, 0));
3031*fe6060f1SDimitry Andric     A.changeUseAfterManifest(
3032*fe6060f1SDimitry Andric         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3033*fe6060f1SDimitry Andric         *ConstantInt::getBool(Ctx, 0));
3034*fe6060f1SDimitry Andric 
3035*fe6060f1SDimitry Andric     ++NumOpenMPTargetRegionKernelsSPMD;
3036*fe6060f1SDimitry Andric 
3037*fe6060f1SDimitry Andric     auto Remark = [&](OptimizationRemark OR) {
3038*fe6060f1SDimitry Andric       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3039*fe6060f1SDimitry Andric     };
3040*fe6060f1SDimitry Andric     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3041*fe6060f1SDimitry Andric     return true;
3042*fe6060f1SDimitry Andric   }
3043*fe6060f1SDimitry Andric 
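  /// Rewrite the kernel prologue so that worker threads execute a state
  /// machine specialized for the parallel regions reachable from this kernel;
  /// the intended CFG is sketched in the comment inside this function.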
3044*fe6060f1SDimitry Andric   ChangeStatus buildCustomStateMachine(Attributor &A) {
3045*fe6060f1SDimitry Andric     assert(ReachedKnownParallelRegions.isValidState() &&
3046*fe6060f1SDimitry Andric            "Custom state machine with invalid parallel region states?");
3047*fe6060f1SDimitry Andric 
3048*fe6060f1SDimitry Andric     const int InitIsSPMDArgNo = 1;
3049*fe6060f1SDimitry Andric     const int InitUseStateMachineArgNo = 2;
3050*fe6060f1SDimitry Andric 
3051*fe6060f1SDimitry Andric     // Check if the current configuration is non-SPMD mode with the generic
3052*fe6060f1SDimitry Andric     // state machine. If we already have SPMD mode or a custom state machine,
3053*fe6060f1SDimitry Andric     // we do not need to go any further. If either argument is anything but a
3054*fe6060f1SDimitry Andric     // constant, something is weird and we give up.
3055*fe6060f1SDimitry Andric     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3056*fe6060f1SDimitry Andric         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3057*fe6060f1SDimitry Andric     ConstantInt *IsSPMD =
3058*fe6060f1SDimitry Andric         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3059*fe6060f1SDimitry Andric 
3060*fe6060f1SDimitry Andric     // If we are stuck with generic mode, try to create a custom device (=GPU)
3061*fe6060f1SDimitry Andric     // state machine which is specialized for the parallel regions that are
3062*fe6060f1SDimitry Andric     // reachable by the kernel.
3063*fe6060f1SDimitry Andric     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3064*fe6060f1SDimitry Andric         !IsSPMD->isZero())
3065*fe6060f1SDimitry Andric       return ChangeStatus::UNCHANGED;
3066*fe6060f1SDimitry Andric 
3067*fe6060f1SDimitry Andric     // If not SPMD mode, indicate we use a custom state machine now.
3068*fe6060f1SDimitry Andric     auto &Ctx = getAnchorValue().getContext();
3069*fe6060f1SDimitry Andric     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3070*fe6060f1SDimitry Andric     A.changeUseAfterManifest(
3071*fe6060f1SDimitry Andric         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3072*fe6060f1SDimitry Andric 
3073*fe6060f1SDimitry Andric     // If we don't actually need a state machine we are done here. This can
3074*fe6060f1SDimitry Andric     // happen if there simply are no parallel regions. In the resulting kernel
3075*fe6060f1SDimitry Andric     // all worker threads will simply exit right away, leaving the main thread
3076*fe6060f1SDimitry Andric     // to do the work alone.
3077*fe6060f1SDimitry Andric     if (ReachedKnownParallelRegions.empty() &&
3078*fe6060f1SDimitry Andric         ReachedUnknownParallelRegions.empty()) {
3079*fe6060f1SDimitry Andric       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3080*fe6060f1SDimitry Andric 
3081*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
3082*fe6060f1SDimitry Andric         return OR << "Removing unused state machine from generic-mode kernel.";
3083*fe6060f1SDimitry Andric       };
3084*fe6060f1SDimitry Andric       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3085*fe6060f1SDimitry Andric 
3086*fe6060f1SDimitry Andric       return ChangeStatus::CHANGED;
3087*fe6060f1SDimitry Andric     }
3088*fe6060f1SDimitry Andric 
3089*fe6060f1SDimitry Andric     // Keep track in the statistics of our new shiny custom state machine.
3090*fe6060f1SDimitry Andric     if (ReachedUnknownParallelRegions.empty()) {
3091*fe6060f1SDimitry Andric       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3092*fe6060f1SDimitry Andric 
3093*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemark OR) {
3094*fe6060f1SDimitry Andric         return OR << "Rewriting generic-mode kernel with a customized state "
3095*fe6060f1SDimitry Andric                      "machine.";
3096*fe6060f1SDimitry Andric       };
3097*fe6060f1SDimitry Andric       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3098*fe6060f1SDimitry Andric     } else {
3099*fe6060f1SDimitry Andric       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3100*fe6060f1SDimitry Andric 
3101*fe6060f1SDimitry Andric       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3102*fe6060f1SDimitry Andric         return OR << "Generic-mode kernel is executed with a customized state "
3103*fe6060f1SDimitry Andric                      "machine that requires a fallback.";
3104*fe6060f1SDimitry Andric       };
3105*fe6060f1SDimitry Andric       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3106*fe6060f1SDimitry Andric 
3107*fe6060f1SDimitry Andric       // Tell the user why we ended up with a fallback.
3108*fe6060f1SDimitry Andric       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3109*fe6060f1SDimitry Andric         if (!UnknownParallelRegionCB)
3110*fe6060f1SDimitry Andric           continue;
3111*fe6060f1SDimitry Andric         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3112*fe6060f1SDimitry Andric           return ORA << "Call may contain unknown parallel regions. Use "
3113*fe6060f1SDimitry Andric                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3114*fe6060f1SDimitry Andric                         "override.";
3115*fe6060f1SDimitry Andric         };
3116*fe6060f1SDimitry Andric         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3117*fe6060f1SDimitry Andric                                                  "OMP133", Remark);
3118*fe6060f1SDimitry Andric       }
3119*fe6060f1SDimitry Andric     }
3120*fe6060f1SDimitry Andric 
3121*fe6060f1SDimitry Andric     // Create all the blocks:
3122*fe6060f1SDimitry Andric     //
3123*fe6060f1SDimitry Andric     //                       InitCB = __kmpc_target_init(...)
3124*fe6060f1SDimitry Andric     //                       bool IsWorker = InitCB >= 0;
3125*fe6060f1SDimitry Andric     //                       if (IsWorker) {
3126*fe6060f1SDimitry Andric     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3127*fe6060f1SDimitry Andric     //                         void *WorkFn;
3128*fe6060f1SDimitry Andric     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3129*fe6060f1SDimitry Andric     //                         if (!WorkFn) return;
3130*fe6060f1SDimitry Andric     // SMIsActiveCheckBB:       if (Active) {
3131*fe6060f1SDimitry Andric     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3132*fe6060f1SDimitry Andric     //                              ParFn0(...);
3133*fe6060f1SDimitry Andric     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3134*fe6060f1SDimitry Andric     //                              ParFn1(...);
3135*fe6060f1SDimitry Andric     //                            ...
3136*fe6060f1SDimitry Andric     // SMIfCascadeCurrentBB:      else
3137*fe6060f1SDimitry Andric     //                              ((WorkFnTy*)WorkFn)(...);
3138*fe6060f1SDimitry Andric     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3139*fe6060f1SDimitry Andric     //                          }
3140*fe6060f1SDimitry Andric     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3141*fe6060f1SDimitry Andric     //                          goto SMBeginBB;
3142*fe6060f1SDimitry Andric     //                       }
3143*fe6060f1SDimitry Andric     // UserCodeEntryBB:      // user code
3144*fe6060f1SDimitry Andric     //                       __kmpc_target_deinit(...)
3145*fe6060f1SDimitry Andric     //
3146*fe6060f1SDimitry Andric     Function *Kernel = getAssociatedFunction();
3147*fe6060f1SDimitry Andric     assert(Kernel && "Expected an associated function!");
3148*fe6060f1SDimitry Andric 
3149*fe6060f1SDimitry Andric     BasicBlock *InitBB = KernelInitCB->getParent();
3150*fe6060f1SDimitry Andric     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3151*fe6060f1SDimitry Andric         KernelInitCB->getNextNode(), "thread.user_code.check");
3152*fe6060f1SDimitry Andric     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3153*fe6060f1SDimitry Andric         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3154*fe6060f1SDimitry Andric     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3155*fe6060f1SDimitry Andric         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3156*fe6060f1SDimitry Andric     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3157*fe6060f1SDimitry Andric         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3158*fe6060f1SDimitry Andric     BasicBlock *StateMachineIfCascadeCurrentBB =
3159*fe6060f1SDimitry Andric         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3160*fe6060f1SDimitry Andric                            Kernel, UserCodeEntryBB);
3161*fe6060f1SDimitry Andric     BasicBlock *StateMachineEndParallelBB =
3162*fe6060f1SDimitry Andric         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3163*fe6060f1SDimitry Andric                            Kernel, UserCodeEntryBB);
3164*fe6060f1SDimitry Andric     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3165*fe6060f1SDimitry Andric         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3166*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*InitBB);
3167*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3168*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3169*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3170*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3171*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3172*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3173*fe6060f1SDimitry Andric     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3174*fe6060f1SDimitry Andric 
3175*fe6060f1SDimitry Andric     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3176*fe6060f1SDimitry Andric     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3177*fe6060f1SDimitry Andric 
3178*fe6060f1SDimitry Andric     InitBB->getTerminator()->eraseFromParent();
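    // __kmpc_target_init is assumed to return -1 for the main thread and a
    // non-negative thread id for workers (cf. the sketch above), so any value
    // != -1 identifies a worker thread.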
3179*fe6060f1SDimitry Andric     Instruction *IsWorker =
3180*fe6060f1SDimitry Andric         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3181*fe6060f1SDimitry Andric                          ConstantInt::get(KernelInitCB->getType(), -1),
3182*fe6060f1SDimitry Andric                          "thread.is_worker", InitBB);
3183*fe6060f1SDimitry Andric     IsWorker->setDebugLoc(DLoc);
3184*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3185*fe6060f1SDimitry Andric 
3186*fe6060f1SDimitry Andric     // Create local storage for the work function pointer.
3187*fe6060f1SDimitry Andric     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3188*fe6060f1SDimitry Andric     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3189*fe6060f1SDimitry Andric                                           &Kernel->getEntryBlock().front());
3190*fe6060f1SDimitry Andric     WorkFnAI->setDebugLoc(DLoc);
3191*fe6060f1SDimitry Andric 
3192*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3193*fe6060f1SDimitry Andric     OMPInfoCache.OMPBuilder.updateToLocation(
3194*fe6060f1SDimitry Andric         OpenMPIRBuilder::LocationDescription(
3195*fe6060f1SDimitry Andric             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3196*fe6060f1SDimitry Andric                                      StateMachineBeginBB->end()),
3197*fe6060f1SDimitry Andric             DLoc));
3198*fe6060f1SDimitry Andric 
3199*fe6060f1SDimitry Andric     Value *Ident = KernelInitCB->getArgOperand(0);
3200*fe6060f1SDimitry Andric     Value *GTid = KernelInitCB;
3201*fe6060f1SDimitry Andric 
3202*fe6060f1SDimitry Andric     Module &M = *Kernel->getParent();
3203*fe6060f1SDimitry Andric     FunctionCallee BarrierFn =
3204*fe6060f1SDimitry Andric         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3205*fe6060f1SDimitry Andric             M, OMPRTL___kmpc_barrier_simple_spmd);
3206*fe6060f1SDimitry Andric     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3207*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3208*fe6060f1SDimitry Andric 
3209*fe6060f1SDimitry Andric     FunctionCallee KernelParallelFn =
3210*fe6060f1SDimitry Andric         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3211*fe6060f1SDimitry Andric             M, OMPRTL___kmpc_kernel_parallel);
3212*fe6060f1SDimitry Andric     Instruction *IsActiveWorker = CallInst::Create(
3213*fe6060f1SDimitry Andric         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3214*fe6060f1SDimitry Andric     IsActiveWorker->setDebugLoc(DLoc);
3215*fe6060f1SDimitry Andric     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3216*fe6060f1SDimitry Andric                                        StateMachineBeginBB);
3217*fe6060f1SDimitry Andric     WorkFn->setDebugLoc(DLoc);
3218*fe6060f1SDimitry Andric 
3219*fe6060f1SDimitry Andric     FunctionType *ParallelRegionFnTy = FunctionType::get(
3220*fe6060f1SDimitry Andric         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3221*fe6060f1SDimitry Andric         false);
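    // Outlined parallel region functions are assumed to take (i16, i32)
    // arguments, hence the cast of the opaque work-function pointer below.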
3222*fe6060f1SDimitry Andric     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3223*fe6060f1SDimitry Andric         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3224*fe6060f1SDimitry Andric         StateMachineBeginBB);
3225*fe6060f1SDimitry Andric 
3226*fe6060f1SDimitry Andric     Instruction *IsDone =
3227*fe6060f1SDimitry Andric         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3228*fe6060f1SDimitry Andric                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3229*fe6060f1SDimitry Andric                          StateMachineBeginBB);
3230*fe6060f1SDimitry Andric     IsDone->setDebugLoc(DLoc);
3231*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3232*fe6060f1SDimitry Andric                        IsDone, StateMachineBeginBB)
3233*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3234*fe6060f1SDimitry Andric 
3235*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3236*fe6060f1SDimitry Andric                        StateMachineDoneBarrierBB, IsActiveWorker,
3237*fe6060f1SDimitry Andric                        StateMachineIsActiveCheckBB)
3238*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3239*fe6060f1SDimitry Andric 
3240*fe6060f1SDimitry Andric     Value *ZeroArg =
3241*fe6060f1SDimitry Andric         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3242*fe6060f1SDimitry Andric 
3243*fe6060f1SDimitry Andric     // Now that we have most of the CFG skeleton it is time for the if-cascade
3244*fe6060f1SDimitry Andric     // that checks the function pointer we got from the runtime against the
3245*fe6060f1SDimitry Andric     // parallel regions we expect, if there are any.
3246*fe6060f1SDimitry Andric     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3247*fe6060f1SDimitry Andric       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3248*fe6060f1SDimitry Andric       BasicBlock *PRExecuteBB = BasicBlock::Create(
3249*fe6060f1SDimitry Andric           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3250*fe6060f1SDimitry Andric           StateMachineEndParallelBB);
3251*fe6060f1SDimitry Andric       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3252*fe6060f1SDimitry Andric           ->setDebugLoc(DLoc);
3253*fe6060f1SDimitry Andric       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3254*fe6060f1SDimitry Andric           ->setDebugLoc(DLoc);
3255*fe6060f1SDimitry Andric 
3256*fe6060f1SDimitry Andric       BasicBlock *PRNextBB =
3257*fe6060f1SDimitry Andric           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3258*fe6060f1SDimitry Andric                              Kernel, StateMachineEndParallelBB);
3259*fe6060f1SDimitry Andric 
3260*fe6060f1SDimitry Andric       // Check if we need to compare the pointer at all or if we can just
3261*fe6060f1SDimitry Andric       // call the parallel region function.
3262*fe6060f1SDimitry Andric       Value *IsPR;
3263*fe6060f1SDimitry Andric       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3264*fe6060f1SDimitry Andric         Instruction *CmpI = ICmpInst::Create(
3265*fe6060f1SDimitry Andric             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3266*fe6060f1SDimitry Andric             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3267*fe6060f1SDimitry Andric         CmpI->setDebugLoc(DLoc);
3268*fe6060f1SDimitry Andric         IsPR = CmpI;
3269*fe6060f1SDimitry Andric       } else {
3270*fe6060f1SDimitry Andric         IsPR = ConstantInt::getTrue(Ctx);
3271*fe6060f1SDimitry Andric       }
3272*fe6060f1SDimitry Andric 
3273*fe6060f1SDimitry Andric       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3274*fe6060f1SDimitry Andric                          StateMachineIfCascadeCurrentBB)
3275*fe6060f1SDimitry Andric           ->setDebugLoc(DLoc);
3276*fe6060f1SDimitry Andric       StateMachineIfCascadeCurrentBB = PRNextBB;
3277*fe6060f1SDimitry Andric     }
3278*fe6060f1SDimitry Andric 
3279*fe6060f1SDimitry Andric     // At the end of the if-cascade we place the indirect function pointer call
3280*fe6060f1SDimitry Andric     // in case we might need it, that is if there can be parallel regions we
3281*fe6060f1SDimitry Andric     // have not handled in the if-cascade above.
3282*fe6060f1SDimitry Andric     if (!ReachedUnknownParallelRegions.empty()) {
3283*fe6060f1SDimitry Andric       StateMachineIfCascadeCurrentBB->setName(
3284*fe6060f1SDimitry Andric           "worker_state_machine.parallel_region.fallback.execute");
3285*fe6060f1SDimitry Andric       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3286*fe6060f1SDimitry Andric                        StateMachineIfCascadeCurrentBB)
3287*fe6060f1SDimitry Andric           ->setDebugLoc(DLoc);
3288*fe6060f1SDimitry Andric     }
3289*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineEndParallelBB,
3290*fe6060f1SDimitry Andric                        StateMachineIfCascadeCurrentBB)
3291*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3292*fe6060f1SDimitry Andric 
3293*fe6060f1SDimitry Andric     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3294*fe6060f1SDimitry Andric                          M, OMPRTL___kmpc_kernel_end_parallel),
3295*fe6060f1SDimitry Andric                      {}, "", StateMachineEndParallelBB)
3296*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3297*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3298*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3299*fe6060f1SDimitry Andric 
3300*fe6060f1SDimitry Andric     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3301*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3302*fe6060f1SDimitry Andric     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3303*fe6060f1SDimitry Andric         ->setDebugLoc(DLoc);
3304*fe6060f1SDimitry Andric 
3305*fe6060f1SDimitry Andric     return ChangeStatus::CHANGED;
3306*fe6060f1SDimitry Andric   }
3307*fe6060f1SDimitry Andric 
3308*fe6060f1SDimitry Andric   /// Fixpoint iteration update function. Will be called every time a dependence
3309*fe6060f1SDimitry Andric   /// changed its state (and in the beginning).
3310*fe6060f1SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
3311*fe6060f1SDimitry Andric     KernelInfoState StateBefore = getState();
3312*fe6060f1SDimitry Andric 
3313*fe6060f1SDimitry Andric     // Callback to check a read/write instruction.
3314*fe6060f1SDimitry Andric     auto CheckRWInst = [&](Instruction &I) {
3315*fe6060f1SDimitry Andric       // We handle calls later.
3316*fe6060f1SDimitry Andric       if (isa<CallBase>(I))
3317*fe6060f1SDimitry Andric         return true;
3318*fe6060f1SDimitry Andric       // We only care about write effects.
3319*fe6060f1SDimitry Andric       if (!I.mayWriteToMemory())
3320*fe6060f1SDimitry Andric         return true;
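      // Stores whose underlying objects are all allocas write only to
      // thread-private stack memory and are therefore assumed to be harmless
      // for SPMD execution; anything else is recorded as potentially
      // incompatible below.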
3321*fe6060f1SDimitry Andric       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3322*fe6060f1SDimitry Andric         SmallVector<const Value *> Objects;
3323*fe6060f1SDimitry Andric         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3324*fe6060f1SDimitry Andric         if (llvm::all_of(Objects,
3325*fe6060f1SDimitry Andric                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3326*fe6060f1SDimitry Andric           return true;
3327*fe6060f1SDimitry Andric       }
3328*fe6060f1SDimitry Andric       // For now we give up on everything but stores.
3329*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.insert(&I);
3330*fe6060f1SDimitry Andric       return true;
3331*fe6060f1SDimitry Andric     };
3332*fe6060f1SDimitry Andric 
3333*fe6060f1SDimitry Andric     bool UsedAssumedInformationInCheckRWInst = false;
3334*fe6060f1SDimitry Andric     if (!SPMDCompatibilityTracker.isAtFixpoint())
3335*fe6060f1SDimitry Andric       if (!A.checkForAllReadWriteInstructions(
3336*fe6060f1SDimitry Andric               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3337*fe6060f1SDimitry Andric         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3338*fe6060f1SDimitry Andric 
3339*fe6060f1SDimitry Andric     if (!IsKernelEntry) {
3340*fe6060f1SDimitry Andric       updateReachingKernelEntries(A);
3341*fe6060f1SDimitry Andric       updateParallelLevels(A);
3342*fe6060f1SDimitry Andric     }
3343*fe6060f1SDimitry Andric 
3344*fe6060f1SDimitry Andric     // Callback to check a call instruction.
3345*fe6060f1SDimitry Andric     bool AllSPMDStatesWereFixed = true;
3346*fe6060f1SDimitry Andric     auto CheckCallInst = [&](Instruction &I) {
3347*fe6060f1SDimitry Andric       auto &CB = cast<CallBase>(I);
3348*fe6060f1SDimitry Andric       auto &CBAA = A.getAAFor<AAKernelInfo>(
3349*fe6060f1SDimitry Andric           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
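      // operator^= merges the call site's KernelInfoState into ours, joining
      // the reached parallel regions and the SPMD compatibility information.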
3350*fe6060f1SDimitry Andric       getState() ^= CBAA.getState();
3351*fe6060f1SDimitry Andric       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3352*fe6060f1SDimitry Andric       return true;
3353*fe6060f1SDimitry Andric     };
3354*fe6060f1SDimitry Andric 
3355*fe6060f1SDimitry Andric     bool UsedAssumedInformationInCheckCallInst = false;
3356*fe6060f1SDimitry Andric     if (!A.checkForAllCallLikeInstructions(
3357*fe6060f1SDimitry Andric             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3358*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3359*fe6060f1SDimitry Andric 
3360*fe6060f1SDimitry Andric     // If we haven't used any assumed information for the SPMD state we can fix
3361*fe6060f1SDimitry Andric     // it.
3362*fe6060f1SDimitry Andric     if (!UsedAssumedInformationInCheckRWInst &&
3363*fe6060f1SDimitry Andric         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3364*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3365*fe6060f1SDimitry Andric 
3366*fe6060f1SDimitry Andric     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3367*fe6060f1SDimitry Andric                                      : ChangeStatus::CHANGED;
3368*fe6060f1SDimitry Andric   }
3369*fe6060f1SDimitry Andric 
3370*fe6060f1SDimitry Andric private:
3371*fe6060f1SDimitry Andric   /// Update info regarding reaching kernels.
3372*fe6060f1SDimitry Andric   void updateReachingKernelEntries(Attributor &A) {
3373*fe6060f1SDimitry Andric     auto PredCallSite = [&](AbstractCallSite ACS) {
3374*fe6060f1SDimitry Andric       Function *Caller = ACS.getInstruction()->getFunction();
3375*fe6060f1SDimitry Andric 
3376*fe6060f1SDimitry Andric       assert(Caller && "Caller is nullptr");
3377*fe6060f1SDimitry Andric 
3378*fe6060f1SDimitry Andric       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3379*fe6060f1SDimitry Andric           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3380*fe6060f1SDimitry Andric       if (CAA.ReachingKernelEntries.isValidState()) {
3381*fe6060f1SDimitry Andric         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3382*fe6060f1SDimitry Andric         return true;
3383*fe6060f1SDimitry Andric       }
3384*fe6060f1SDimitry Andric 
3385*fe6060f1SDimitry Andric       // We lost track of the callers of the associated function; any kernel
3386*fe6060f1SDimitry Andric       // could reach it now.
3387*fe6060f1SDimitry Andric       ReachingKernelEntries.indicatePessimisticFixpoint();
3388*fe6060f1SDimitry Andric 
3389*fe6060f1SDimitry Andric       return true;
3390*fe6060f1SDimitry Andric     };
3391*fe6060f1SDimitry Andric 
3392*fe6060f1SDimitry Andric     bool AllCallSitesKnown;
3393*fe6060f1SDimitry Andric     if (!A.checkForAllCallSites(PredCallSite, *this,
3394*fe6060f1SDimitry Andric                                 true /* RequireAllCallSites */,
3395*fe6060f1SDimitry Andric                                 AllCallSitesKnown))
3396*fe6060f1SDimitry Andric       ReachingKernelEntries.indicatePessimisticFixpoint();
3397*fe6060f1SDimitry Andric   }
3398*fe6060f1SDimitry Andric 
3399*fe6060f1SDimitry Andric   /// Update info regarding parallel levels.
3400*fe6060f1SDimitry Andric   void updateParallelLevels(Attributor &A) {
3401*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3402*fe6060f1SDimitry Andric     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3403*fe6060f1SDimitry Andric         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3404*fe6060f1SDimitry Andric 
3405*fe6060f1SDimitry Andric     auto PredCallSite = [&](AbstractCallSite ACS) {
3406*fe6060f1SDimitry Andric       Function *Caller = ACS.getInstruction()->getFunction();
3407*fe6060f1SDimitry Andric 
3408*fe6060f1SDimitry Andric       assert(Caller && "Caller is nullptr");
3409*fe6060f1SDimitry Andric 
3410*fe6060f1SDimitry Andric       auto &CAA =
3411*fe6060f1SDimitry Andric           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3412*fe6060f1SDimitry Andric       if (CAA.ParallelLevels.isValidState()) {
3413*fe6060f1SDimitry Andric         // Any function that is called by `__kmpc_parallel_51` will not be folded
3414*fe6060f1SDimitry Andric         // because the parallel level in that function is updated at runtime.
3415*fe6060f1SDimitry Andric         // Getting this right would tie the analysis to the runtime implementation,
3416*fe6060f1SDimitry Andric         // and any future change to that implementation could invalidate the result.
3417*fe6060f1SDimitry Andric         // As a consequence, we are just conservative here.
3418*fe6060f1SDimitry Andric         if (Caller == Parallel51RFI.Declaration) {
3419*fe6060f1SDimitry Andric           ParallelLevels.indicatePessimisticFixpoint();
3420*fe6060f1SDimitry Andric           return true;
3421*fe6060f1SDimitry Andric         }
3422*fe6060f1SDimitry Andric 
3423*fe6060f1SDimitry Andric         ParallelLevels ^= CAA.ParallelLevels;
3424*fe6060f1SDimitry Andric 
3425*fe6060f1SDimitry Andric         return true;
3426*fe6060f1SDimitry Andric       }
3427*fe6060f1SDimitry Andric 
3428*fe6060f1SDimitry Andric       // We lost track of the callers of the associated function; any kernel
3429*fe6060f1SDimitry Andric       // could reach it now.
3430*fe6060f1SDimitry Andric       ParallelLevels.indicatePessimisticFixpoint();
3431*fe6060f1SDimitry Andric 
3432*fe6060f1SDimitry Andric       return true;
3433*fe6060f1SDimitry Andric     };
3434*fe6060f1SDimitry Andric 
3435*fe6060f1SDimitry Andric     bool AllCallSitesKnown = true;
3436*fe6060f1SDimitry Andric     if (!A.checkForAllCallSites(PredCallSite, *this,
3437*fe6060f1SDimitry Andric                                 true /* RequireAllCallSites */,
3438*fe6060f1SDimitry Andric                                 AllCallSitesKnown))
3439*fe6060f1SDimitry Andric       ParallelLevels.indicatePessimisticFixpoint();
3440*fe6060f1SDimitry Andric   }
3441*fe6060f1SDimitry Andric };
3442*fe6060f1SDimitry Andric 
3443*fe6060f1SDimitry Andric /// The call site kernel info abstract attribute, basically, what can we say
3444*fe6060f1SDimitry Andric /// about a call site with regards to the KernelInfoState. For now this simply
3445*fe6060f1SDimitry Andric /// forwards the information from the callee.
3446*fe6060f1SDimitry Andric struct AAKernelInfoCallSite : AAKernelInfo {
3447*fe6060f1SDimitry Andric   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3448*fe6060f1SDimitry Andric       : AAKernelInfo(IRP, A) {}
3449*fe6060f1SDimitry Andric 
3450*fe6060f1SDimitry Andric   /// See AbstractAttribute::initialize(...).
3451*fe6060f1SDimitry Andric   void initialize(Attributor &A) override {
3452*fe6060f1SDimitry Andric     AAKernelInfo::initialize(A);
3453*fe6060f1SDimitry Andric 
3454*fe6060f1SDimitry Andric     CallBase &CB = cast<CallBase>(getAssociatedValue());
3455*fe6060f1SDimitry Andric     Function *Callee = getAssociatedFunction();
3456*fe6060f1SDimitry Andric 
3457*fe6060f1SDimitry Andric     // Helper to lookup an assumption string.
3458*fe6060f1SDimitry Andric     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3459*fe6060f1SDimitry Andric       return Fn && hasAssumption(*Fn, AssumptionStr);
3460*fe6060f1SDimitry Andric     };
3461*fe6060f1SDimitry Andric 
3462*fe6060f1SDimitry Andric     // Check for SPMD-mode assumptions.
3463*fe6060f1SDimitry Andric     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3464*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
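    // Illustrative sketch only: such an assumption is typically attached to
    // the callee as a string attribute, e.g.
    //   declare void @work() "llvm.assume"="ompx_spmd_amenable"
    // and hasAssumption() looks the string up in that attribute list.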
3465*fe6060f1SDimitry Andric 
3466*fe6060f1SDimitry Andric     // First weed out calls we do not care about, that is, readonly/readnone
3467*fe6060f1SDimitry Andric     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3468*fe6060f1SDimitry Andric     // parallel region or anything else we are looking for.
3469*fe6060f1SDimitry Andric     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3470*fe6060f1SDimitry Andric       indicateOptimisticFixpoint();
3471*fe6060f1SDimitry Andric       return;
3472*fe6060f1SDimitry Andric     }
3473*fe6060f1SDimitry Andric 
3474*fe6060f1SDimitry Andric     // Next we check if we know the callee. If it is a known OpenMP function
3475*fe6060f1SDimitry Andric     // we will handle them explicitly in the switch below. If it is not, we
3476*fe6060f1SDimitry Andric     // will use an AAKernelInfo object on the callee to gather information and
3477*fe6060f1SDimitry Andric     // merge that into the current state. The latter happens in the updateImpl.
3478*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3479*fe6060f1SDimitry Andric     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3480*fe6060f1SDimitry Andric     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3481*fe6060f1SDimitry Andric       // Unknown callees or declarations are not analyzable; we give up.
3482*fe6060f1SDimitry Andric       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3483*fe6060f1SDimitry Andric 
3484*fe6060f1SDimitry Andric         // Unknown callees might contain parallel regions, except if they have
3485*fe6060f1SDimitry Andric         // an appropriate assumption attached.
3486*fe6060f1SDimitry Andric         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3487*fe6060f1SDimitry Andric               HasAssumption(Callee, "omp_no_parallelism")))
3488*fe6060f1SDimitry Andric           ReachedUnknownParallelRegions.insert(&CB);
3489*fe6060f1SDimitry Andric 
3490*fe6060f1SDimitry Andric         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3491*fe6060f1SDimitry Andric         // idea we can run something unknown in SPMD-mode.
3492*fe6060f1SDimitry Andric         if (!SPMDCompatibilityTracker.isAtFixpoint())
3493*fe6060f1SDimitry Andric           SPMDCompatibilityTracker.insert(&CB);
3494*fe6060f1SDimitry Andric 
3495*fe6060f1SDimitry Andric         // We have updated the state for this unknown call properly; there won't
3496*fe6060f1SDimitry Andric         // be any further change, so we indicate a fixpoint.
3497*fe6060f1SDimitry Andric         indicateOptimisticFixpoint();
3498*fe6060f1SDimitry Andric       }
3499*fe6060f1SDimitry Andric       // If the callee is known and can be used in IPO, we will update the state
3500*fe6060f1SDimitry Andric       // based on the callee state in updateImpl.
3501*fe6060f1SDimitry Andric       return;
3502*fe6060f1SDimitry Andric     }
3503*fe6060f1SDimitry Andric 
3504*fe6060f1SDimitry Andric     const unsigned int WrapperFunctionArgNo = 6;
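    // Illustrative sketch only (pointer types and casts abbreviated): a call
    //   call void @__kmpc_parallel_51(%struct.ident_t* %loc, i32 %gtid,
    //       i32 %if_expr, i32 %num_threads, i32 %proc_bind, i8* %fn,
    //       i8* %wrapper_fn, i8** %args, i64 %nargs)
    // carries the outlined parallel region behind argument 6 (%wrapper_fn),
    // which the OMPRTL___kmpc_parallel_51 case below tries to record.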
3505*fe6060f1SDimitry Andric     RuntimeFunction RF = It->getSecond();
3506*fe6060f1SDimitry Andric     switch (RF) {
3507*fe6060f1SDimitry Andric     // All the functions we know are compatible with SPMD mode.
3508*fe6060f1SDimitry Andric     case OMPRTL___kmpc_is_spmd_exec_mode:
3509*fe6060f1SDimitry Andric     case OMPRTL___kmpc_for_static_fini:
3510*fe6060f1SDimitry Andric     case OMPRTL___kmpc_global_thread_num:
3511*fe6060f1SDimitry Andric     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3512*fe6060f1SDimitry Andric     case OMPRTL___kmpc_get_hardware_num_blocks:
3513*fe6060f1SDimitry Andric     case OMPRTL___kmpc_single:
3514*fe6060f1SDimitry Andric     case OMPRTL___kmpc_end_single:
3515*fe6060f1SDimitry Andric     case OMPRTL___kmpc_master:
3516*fe6060f1SDimitry Andric     case OMPRTL___kmpc_end_master:
3517*fe6060f1SDimitry Andric     case OMPRTL___kmpc_barrier:
3518*fe6060f1SDimitry Andric       break;
3519*fe6060f1SDimitry Andric     case OMPRTL___kmpc_for_static_init_4:
3520*fe6060f1SDimitry Andric     case OMPRTL___kmpc_for_static_init_4u:
3521*fe6060f1SDimitry Andric     case OMPRTL___kmpc_for_static_init_8:
3522*fe6060f1SDimitry Andric     case OMPRTL___kmpc_for_static_init_8u: {
3523*fe6060f1SDimitry Andric       // Check the schedule and allow static schedule in SPMD mode.
3524*fe6060f1SDimitry Andric       unsigned ScheduleArgOpNo = 2;
3525*fe6060f1SDimitry Andric       auto *ScheduleTypeCI =
3526*fe6060f1SDimitry Andric           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3527*fe6060f1SDimitry Andric       unsigned ScheduleTypeVal =
3528*fe6060f1SDimitry Andric           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3529*fe6060f1SDimitry Andric       switch (OMPScheduleType(ScheduleTypeVal)) {
3530*fe6060f1SDimitry Andric       case OMPScheduleType::Static:
3531*fe6060f1SDimitry Andric       case OMPScheduleType::StaticChunked:
3532*fe6060f1SDimitry Andric       case OMPScheduleType::Distribute:
3533*fe6060f1SDimitry Andric       case OMPScheduleType::DistributeChunked:
3534*fe6060f1SDimitry Andric         break;
3535*fe6060f1SDimitry Andric       default:
3536*fe6060f1SDimitry Andric         SPMDCompatibilityTracker.insert(&CB);
3537*fe6060f1SDimitry Andric         break;
3538*fe6060f1SDimitry Andric       }
3539*fe6060f1SDimitry Andric     } break;
3540*fe6060f1SDimitry Andric     case OMPRTL___kmpc_target_init:
3541*fe6060f1SDimitry Andric       KernelInitCB = &CB;
3542*fe6060f1SDimitry Andric       break;
3543*fe6060f1SDimitry Andric     case OMPRTL___kmpc_target_deinit:
3544*fe6060f1SDimitry Andric       KernelDeinitCB = &CB;
3545*fe6060f1SDimitry Andric       break;
3546*fe6060f1SDimitry Andric     case OMPRTL___kmpc_parallel_51:
3547*fe6060f1SDimitry Andric       if (auto *ParallelRegion = dyn_cast<Function>(
3548*fe6060f1SDimitry Andric               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3549*fe6060f1SDimitry Andric         ReachedKnownParallelRegions.insert(ParallelRegion);
3550*fe6060f1SDimitry Andric         break;
3551*fe6060f1SDimitry Andric       }
3552*fe6060f1SDimitry Andric       // The condition above should usually get the parallel region function
3553*fe6060f1SDimitry Andric       // pointer and record it. On the off chance it doesn't, we assume the
3554*fe6060f1SDimitry Andric       // worst.
3555*fe6060f1SDimitry Andric       ReachedUnknownParallelRegions.insert(&CB);
3556*fe6060f1SDimitry Andric       break;
3557*fe6060f1SDimitry Andric     case OMPRTL___kmpc_omp_task:
3558*fe6060f1SDimitry Andric       // We do not look into tasks right now, just give up.
3559*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.insert(&CB);
3560*fe6060f1SDimitry Andric       ReachedUnknownParallelRegions.insert(&CB);
3561*fe6060f1SDimitry Andric       break;
3562*fe6060f1SDimitry Andric     case OMPRTL___kmpc_alloc_shared:
3563*fe6060f1SDimitry Andric     case OMPRTL___kmpc_free_shared:
3564*fe6060f1SDimitry Andric       // Return without setting a fixpoint, to be resolved in updateImpl.
3565*fe6060f1SDimitry Andric       return;
3566*fe6060f1SDimitry Andric     default:
3567*fe6060f1SDimitry Andric       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3568*fe6060f1SDimitry Andric       // generally.
3569*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.insert(&CB);
3570*fe6060f1SDimitry Andric       break;
3571*fe6060f1SDimitry Andric     }
3572*fe6060f1SDimitry Andric     // All other OpenMP runtime calls will not reach parallel regions, so they
3573*fe6060f1SDimitry Andric     // can be safely ignored for now. Since it is a known OpenMP runtime call,
3574*fe6060f1SDimitry Andric     // we have now modeled all effects and there is no need for any update.
3575*fe6060f1SDimitry Andric     indicateOptimisticFixpoint();
3576*fe6060f1SDimitry Andric   }
3577*fe6060f1SDimitry Andric 
3578*fe6060f1SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
3579*fe6060f1SDimitry Andric     // TODO: Once we have call site specific value information we can provide
3580*fe6060f1SDimitry Andric     //       call site specific liveness information and then it makes
3581*fe6060f1SDimitry Andric     //       sense to specialize attributes for call sites arguments instead of
3582*fe6060f1SDimitry Andric     //       redirecting requests to the callee argument.
3583*fe6060f1SDimitry Andric     Function *F = getAssociatedFunction();
3584*fe6060f1SDimitry Andric 
3585*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3586*fe6060f1SDimitry Andric     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3587*fe6060f1SDimitry Andric 
3588*fe6060f1SDimitry Andric     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3589*fe6060f1SDimitry Andric     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3590*fe6060f1SDimitry Andric       const IRPosition &FnPos = IRPosition::function(*F);
3591*fe6060f1SDimitry Andric       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3592*fe6060f1SDimitry Andric       if (getState() == FnAA.getState())
3593*fe6060f1SDimitry Andric         return ChangeStatus::UNCHANGED;
3594*fe6060f1SDimitry Andric       getState() = FnAA.getState();
3595*fe6060f1SDimitry Andric       return ChangeStatus::CHANGED;
3596*fe6060f1SDimitry Andric     }
3597*fe6060f1SDimitry Andric 
3598*fe6060f1SDimitry Andric     // F is a runtime function that allocates or frees memory; check
3599*fe6060f1SDimitry Andric     // AAHeapToStack and AAHeapToShared.
3600*fe6060f1SDimitry Andric     KernelInfoState StateBefore = getState();
3601*fe6060f1SDimitry Andric     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3602*fe6060f1SDimitry Andric             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3603*fe6060f1SDimitry Andric            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3604*fe6060f1SDimitry Andric 
3605*fe6060f1SDimitry Andric     CallBase &CB = cast<CallBase>(getAssociatedValue());
3606*fe6060f1SDimitry Andric 
3607*fe6060f1SDimitry Andric     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3608*fe6060f1SDimitry Andric         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3609*fe6060f1SDimitry Andric     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3610*fe6060f1SDimitry Andric         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3611*fe6060f1SDimitry Andric 
3612*fe6060f1SDimitry Andric     RuntimeFunction RF = It->getSecond();
3613*fe6060f1SDimitry Andric 
3614*fe6060f1SDimitry Andric     switch (RF) {
3615*fe6060f1SDimitry Andric     // If neither HeapToStack nor HeapToShared assumes the call is removed,
3616*fe6060f1SDimitry Andric     // assume SPMD incompatibility.
3617*fe6060f1SDimitry Andric     case OMPRTL___kmpc_alloc_shared:
3618*fe6060f1SDimitry Andric       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3619*fe6060f1SDimitry Andric           !HeapToSharedAA.isAssumedHeapToShared(CB))
3620*fe6060f1SDimitry Andric         SPMDCompatibilityTracker.insert(&CB);
3621*fe6060f1SDimitry Andric       break;
3622*fe6060f1SDimitry Andric     case OMPRTL___kmpc_free_shared:
3623*fe6060f1SDimitry Andric       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3624*fe6060f1SDimitry Andric           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3625*fe6060f1SDimitry Andric         SPMDCompatibilityTracker.insert(&CB);
3626*fe6060f1SDimitry Andric       break;
3627*fe6060f1SDimitry Andric     default:
3628*fe6060f1SDimitry Andric       SPMDCompatibilityTracker.insert(&CB);
3629*fe6060f1SDimitry Andric     }
3630*fe6060f1SDimitry Andric 
3631*fe6060f1SDimitry Andric     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3632*fe6060f1SDimitry Andric                                      : ChangeStatus::CHANGED;
3633*fe6060f1SDimitry Andric   }
3634*fe6060f1SDimitry Andric };
3635*fe6060f1SDimitry Andric 
3636*fe6060f1SDimitry Andric struct AAFoldRuntimeCall
3637*fe6060f1SDimitry Andric     : public StateWrapper<BooleanState, AbstractAttribute> {
3638*fe6060f1SDimitry Andric   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3639*fe6060f1SDimitry Andric 
3640*fe6060f1SDimitry Andric   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3641*fe6060f1SDimitry Andric 
3642*fe6060f1SDimitry Andric   /// Statistics are tracked as part of manifest for now.
3643*fe6060f1SDimitry Andric   void trackStatistics() const override {}
3644*fe6060f1SDimitry Andric 
3645*fe6060f1SDimitry Andric   /// Create an abstract attribute view for the position \p IRP.
3646*fe6060f1SDimitry Andric   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3647*fe6060f1SDimitry Andric                                               Attributor &A);
3648*fe6060f1SDimitry Andric 
3649*fe6060f1SDimitry Andric   /// See AbstractAttribute::getName()
3650*fe6060f1SDimitry Andric   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3651*fe6060f1SDimitry Andric 
3652*fe6060f1SDimitry Andric   /// See AbstractAttribute::getIdAddr()
3653*fe6060f1SDimitry Andric   const char *getIdAddr() const override { return &ID; }
3654*fe6060f1SDimitry Andric 
3655*fe6060f1SDimitry Andric   /// This function should return true if the type of the \p AA is
3656*fe6060f1SDimitry Andric   /// AAFoldRuntimeCall.
3657*fe6060f1SDimitry Andric   static bool classof(const AbstractAttribute *AA) {
3658*fe6060f1SDimitry Andric     return (AA->getIdAddr() == &ID);
3659*fe6060f1SDimitry Andric   }
3660*fe6060f1SDimitry Andric 
3661*fe6060f1SDimitry Andric   static const char ID;
3662*fe6060f1SDimitry Andric };
3663*fe6060f1SDimitry Andric 
3664*fe6060f1SDimitry Andric struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3665*fe6060f1SDimitry Andric   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3666*fe6060f1SDimitry Andric       : AAFoldRuntimeCall(IRP, A) {}
3667*fe6060f1SDimitry Andric 
3668*fe6060f1SDimitry Andric   /// See AbstractAttribute::getAsStr()
3669*fe6060f1SDimitry Andric   const std::string getAsStr() const override {
3670*fe6060f1SDimitry Andric     if (!isValidState())
3671*fe6060f1SDimitry Andric       return "<invalid>";
3672*fe6060f1SDimitry Andric 
3673*fe6060f1SDimitry Andric     std::string Str("simplified value: ");
3674*fe6060f1SDimitry Andric 
3675*fe6060f1SDimitry Andric     if (!SimplifiedValue.hasValue())
3676*fe6060f1SDimitry Andric       return Str + std::string("none");
3677*fe6060f1SDimitry Andric 
3678*fe6060f1SDimitry Andric     if (!SimplifiedValue.getValue())
3679*fe6060f1SDimitry Andric       return Str + std::string("nullptr");
3680*fe6060f1SDimitry Andric 
3681*fe6060f1SDimitry Andric     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3682*fe6060f1SDimitry Andric       return Str + std::to_string(CI->getSExtValue());
3683*fe6060f1SDimitry Andric 
3684*fe6060f1SDimitry Andric     return Str + std::string("unknown");
3685*fe6060f1SDimitry Andric   }
3686*fe6060f1SDimitry Andric 
3687*fe6060f1SDimitry Andric   void initialize(Attributor &A) override {
3688*fe6060f1SDimitry Andric     Function *Callee = getAssociatedFunction();
3689*fe6060f1SDimitry Andric 
3690*fe6060f1SDimitry Andric     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3691*fe6060f1SDimitry Andric     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3692*fe6060f1SDimitry Andric     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3693*fe6060f1SDimitry Andric            "Expected a known OpenMP runtime function");
3694*fe6060f1SDimitry Andric 
3695*fe6060f1SDimitry Andric     RFKind = It->getSecond();
3696*fe6060f1SDimitry Andric 
3697*fe6060f1SDimitry Andric     CallBase &CB = cast<CallBase>(getAssociatedValue());
3698*fe6060f1SDimitry Andric     A.registerSimplificationCallback(
3699*fe6060f1SDimitry Andric         IRPosition::callsite_returned(CB),
3700*fe6060f1SDimitry Andric         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3701*fe6060f1SDimitry Andric             bool &UsedAssumedInformation) -> Optional<Value *> {
3702*fe6060f1SDimitry Andric           assert((isValidState() || (SimplifiedValue.hasValue() &&
3703*fe6060f1SDimitry Andric                                      SimplifiedValue.getValue() == nullptr)) &&
3704*fe6060f1SDimitry Andric                  "Unexpected invalid state!");
3705*fe6060f1SDimitry Andric 
3706*fe6060f1SDimitry Andric           if (!isAtFixpoint()) {
3707*fe6060f1SDimitry Andric             UsedAssumedInformation = true;
3708*fe6060f1SDimitry Andric             if (AA)
3709*fe6060f1SDimitry Andric               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3710*fe6060f1SDimitry Andric           }
3711*fe6060f1SDimitry Andric           return SimplifiedValue;
3712*fe6060f1SDimitry Andric         });
3713*fe6060f1SDimitry Andric   }
3714*fe6060f1SDimitry Andric 
3715*fe6060f1SDimitry Andric   ChangeStatus updateImpl(Attributor &A) override {
3716*fe6060f1SDimitry Andric     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3717*fe6060f1SDimitry Andric     switch (RFKind) {
3718*fe6060f1SDimitry Andric     case OMPRTL___kmpc_is_spmd_exec_mode:
3719*fe6060f1SDimitry Andric       Changed |= foldIsSPMDExecMode(A);
3720*fe6060f1SDimitry Andric       break;
3721*fe6060f1SDimitry Andric     case OMPRTL___kmpc_is_generic_main_thread_id:
3722*fe6060f1SDimitry Andric       Changed |= foldIsGenericMainThread(A);
3723*fe6060f1SDimitry Andric       break;
3724*fe6060f1SDimitry Andric     case OMPRTL___kmpc_parallel_level:
3725*fe6060f1SDimitry Andric       Changed |= foldParallelLevel(A);
3726*fe6060f1SDimitry Andric       break;
3727*fe6060f1SDimitry Andric     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3728*fe6060f1SDimitry Andric       Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
3729*fe6060f1SDimitry Andric       break;
3730*fe6060f1SDimitry Andric     case OMPRTL___kmpc_get_hardware_num_blocks:
3731*fe6060f1SDimitry Andric       Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
3732*fe6060f1SDimitry Andric       break;
3733*fe6060f1SDimitry Andric     default:
3734*fe6060f1SDimitry Andric       llvm_unreachable("Unhandled OpenMP runtime function!");
3735*fe6060f1SDimitry Andric     }
3736*fe6060f1SDimitry Andric 
3737*fe6060f1SDimitry Andric     return Changed;
3738*fe6060f1SDimitry Andric   }
3739*fe6060f1SDimitry Andric 
3740*fe6060f1SDimitry Andric   ChangeStatus manifest(Attributor &A) override {
3741*fe6060f1SDimitry Andric     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3742*fe6060f1SDimitry Andric 
3743*fe6060f1SDimitry Andric     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
3744*fe6060f1SDimitry Andric       Instruction &CB = *getCtxI();
3745*fe6060f1SDimitry Andric       A.changeValueAfterManifest(CB, **SimplifiedValue);
3746*fe6060f1SDimitry Andric       A.deleteAfterManifest(CB);
3747*fe6060f1SDimitry Andric 
3748*fe6060f1SDimitry Andric       LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
3749*fe6060f1SDimitry Andric                         << **SimplifiedValue << "\n");
3750*fe6060f1SDimitry Andric 
3751*fe6060f1SDimitry Andric       Changed = ChangeStatus::CHANGED;
3752*fe6060f1SDimitry Andric     }
3753*fe6060f1SDimitry Andric 
3754*fe6060f1SDimitry Andric     return Changed;
3755*fe6060f1SDimitry Andric   }
3756*fe6060f1SDimitry Andric 
3757*fe6060f1SDimitry Andric   ChangeStatus indicatePessimisticFixpoint() override {
3758*fe6060f1SDimitry Andric     SimplifiedValue = nullptr;
3759*fe6060f1SDimitry Andric     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
3760*fe6060f1SDimitry Andric   }
3761*fe6060f1SDimitry Andric 
3762*fe6060f1SDimitry Andric private:
3763*fe6060f1SDimitry Andric   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
3764*fe6060f1SDimitry Andric   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
3765*fe6060f1SDimitry Andric     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3766*fe6060f1SDimitry Andric 
3767*fe6060f1SDimitry Andric     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
3768*fe6060f1SDimitry Andric     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
3769*fe6060f1SDimitry Andric     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
3770*fe6060f1SDimitry Andric         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
3771*fe6060f1SDimitry Andric 
3772*fe6060f1SDimitry Andric     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
3773*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3774*fe6060f1SDimitry Andric 
3775*fe6060f1SDimitry Andric     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
3776*fe6060f1SDimitry Andric       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
3777*fe6060f1SDimitry Andric                                           DepClassTy::REQUIRED);
3778*fe6060f1SDimitry Andric 
3779*fe6060f1SDimitry Andric       if (!AA.isValidState()) {
3780*fe6060f1SDimitry Andric         SimplifiedValue = nullptr;
3781*fe6060f1SDimitry Andric         return indicatePessimisticFixpoint();
3782*fe6060f1SDimitry Andric       }
3783*fe6060f1SDimitry Andric 
3784*fe6060f1SDimitry Andric       if (AA.SPMDCompatibilityTracker.isAssumed()) {
3785*fe6060f1SDimitry Andric         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3786*fe6060f1SDimitry Andric           ++KnownSPMDCount;
3787*fe6060f1SDimitry Andric         else
3788*fe6060f1SDimitry Andric           ++AssumedSPMDCount;
3789*fe6060f1SDimitry Andric       } else {
3790*fe6060f1SDimitry Andric         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3791*fe6060f1SDimitry Andric           ++KnownNonSPMDCount;
3792*fe6060f1SDimitry Andric         else
3793*fe6060f1SDimitry Andric           ++AssumedNonSPMDCount;
3794*fe6060f1SDimitry Andric       }
3795*fe6060f1SDimitry Andric     }
3796*fe6060f1SDimitry Andric 
3797*fe6060f1SDimitry Andric     if ((AssumedSPMDCount + KnownSPMDCount) &&
3798*fe6060f1SDimitry Andric         (AssumedNonSPMDCount + KnownNonSPMDCount))
3799*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3800*fe6060f1SDimitry Andric 
3801*fe6060f1SDimitry Andric     auto &Ctx = getAnchorValue().getContext();
3802*fe6060f1SDimitry Andric     if (KnownSPMDCount || AssumedSPMDCount) {
3803*fe6060f1SDimitry Andric       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
3804*fe6060f1SDimitry Andric              "Expected only SPMD kernels!");
3805*fe6060f1SDimitry Andric       // All reaching kernels are in SPMD mode. Update all function calls to
3806*fe6060f1SDimitry Andric       // __kmpc_is_spmd_exec_mode to 1.
3807*fe6060f1SDimitry Andric       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
3808*fe6060f1SDimitry Andric     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
3809*fe6060f1SDimitry Andric       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
3810*fe6060f1SDimitry Andric              "Expected only non-SPMD kernels!");
3811*fe6060f1SDimitry Andric       // All reaching kernels are in non-SPMD mode. Update all function
3812*fe6060f1SDimitry Andric       // calls to __kmpc_is_spmd_exec_mode to 0.
3813*fe6060f1SDimitry Andric       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
3814*fe6060f1SDimitry Andric     } else {
3815*fe6060f1SDimitry Andric       // The set of reaching kernels is empty, so we cannot tell whether the
3816*fe6060f1SDimitry Andric       // associated call site can be folded. At this moment, SimplifiedValue
3817*fe6060f1SDimitry Andric       // must be none.
3818*fe6060f1SDimitry Andric       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
3819*fe6060f1SDimitry Andric     }
3820*fe6060f1SDimitry Andric 
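    // Illustrative sketch only: once manifested, a call such as
    //   %spmd = call i8 @__kmpc_is_spmd_exec_mode()
    // is replaced by the constant i8 1 (all reaching kernels SPMD) or i8 0
    // (none of them SPMD), and the call itself is deleted.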
3821*fe6060f1SDimitry Andric     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3822*fe6060f1SDimitry Andric                                                     : ChangeStatus::CHANGED;
3823*fe6060f1SDimitry Andric   }
3824*fe6060f1SDimitry Andric 
3825*fe6060f1SDimitry Andric   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
3826*fe6060f1SDimitry Andric   ChangeStatus foldIsGenericMainThread(Attributor &A) {
3827*fe6060f1SDimitry Andric     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3828*fe6060f1SDimitry Andric 
3829*fe6060f1SDimitry Andric     CallBase &CB = cast<CallBase>(getAssociatedValue());
3830*fe6060f1SDimitry Andric     Function *F = CB.getFunction();
3831*fe6060f1SDimitry Andric     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
3832*fe6060f1SDimitry Andric         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3833*fe6060f1SDimitry Andric 
3834*fe6060f1SDimitry Andric     if (!ExecutionDomainAA.isValidState())
3835*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3836*fe6060f1SDimitry Andric 
3837*fe6060f1SDimitry Andric     auto &Ctx = getAnchorValue().getContext();
3838*fe6060f1SDimitry Andric     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
3839*fe6060f1SDimitry Andric       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
3840*fe6060f1SDimitry Andric     else
3841*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3842*fe6060f1SDimitry Andric 
3843*fe6060f1SDimitry Andric     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3844*fe6060f1SDimitry Andric                                                     : ChangeStatus::CHANGED;
3845*fe6060f1SDimitry Andric   }
3846*fe6060f1SDimitry Andric 
3847*fe6060f1SDimitry Andric   /// Fold __kmpc_parallel_level into a constant if possible.
3848*fe6060f1SDimitry Andric   ChangeStatus foldParallelLevel(Attributor &A) {
3849*fe6060f1SDimitry Andric     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3850*fe6060f1SDimitry Andric 
3851*fe6060f1SDimitry Andric     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
3852*fe6060f1SDimitry Andric         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
3853*fe6060f1SDimitry Andric 
3854*fe6060f1SDimitry Andric     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
3855*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3856*fe6060f1SDimitry Andric 
3857*fe6060f1SDimitry Andric     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
3858*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3859*fe6060f1SDimitry Andric 
3860*fe6060f1SDimitry Andric     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
3861*fe6060f1SDimitry Andric       assert(!SimplifiedValue.hasValue() &&
3862*fe6060f1SDimitry Andric              "SimplifiedValue should keep none at this point");
3863*fe6060f1SDimitry Andric       return ChangeStatus::UNCHANGED;
3864*fe6060f1SDimitry Andric     }
3865*fe6060f1SDimitry Andric 
3866*fe6060f1SDimitry Andric     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
3867*fe6060f1SDimitry Andric     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
3868*fe6060f1SDimitry Andric     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
3869*fe6060f1SDimitry Andric       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
3870*fe6060f1SDimitry Andric                                           DepClassTy::REQUIRED);
3871*fe6060f1SDimitry Andric       if (!AA.SPMDCompatibilityTracker.isValidState())
3872*fe6060f1SDimitry Andric         return indicatePessimisticFixpoint();
3873*fe6060f1SDimitry Andric 
3874*fe6060f1SDimitry Andric       if (AA.SPMDCompatibilityTracker.isAssumed()) {
3875*fe6060f1SDimitry Andric         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3876*fe6060f1SDimitry Andric           ++KnownSPMDCount;
3877*fe6060f1SDimitry Andric         else
3878*fe6060f1SDimitry Andric           ++AssumedSPMDCount;
3879*fe6060f1SDimitry Andric       } else {
3880*fe6060f1SDimitry Andric         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3881*fe6060f1SDimitry Andric           ++KnownNonSPMDCount;
3882*fe6060f1SDimitry Andric         else
3883*fe6060f1SDimitry Andric           ++AssumedNonSPMDCount;
3884*fe6060f1SDimitry Andric       }
3885*fe6060f1SDimitry Andric     }
3886*fe6060f1SDimitry Andric 
3887*fe6060f1SDimitry Andric     if ((AssumedSPMDCount + KnownSPMDCount) &&
3888*fe6060f1SDimitry Andric         (AssumedNonSPMDCount + KnownNonSPMDCount))
3889*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3890*fe6060f1SDimitry Andric 
3891*fe6060f1SDimitry Andric     auto &Ctx = getAnchorValue().getContext();
3892*fe6060f1SDimitry Andric     // If the caller can only be reached by SPMD kernel entries, the parallel
3893*fe6060f1SDimitry Andric     // level is 1. Similarly, if the caller can only be reached by non-SPMD
3894*fe6060f1SDimitry Andric     // kernel entries, it is 0.
3895*fe6060f1SDimitry Andric     if (AssumedSPMDCount || KnownSPMDCount) {
3896*fe6060f1SDimitry Andric       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
3897*fe6060f1SDimitry Andric              "Expected only SPMD kernels!");
3898*fe6060f1SDimitry Andric       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
3899*fe6060f1SDimitry Andric     } else {
3900*fe6060f1SDimitry Andric       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
3901*fe6060f1SDimitry Andric              "Expected only non-SPMD kernels!");
3902*fe6060f1SDimitry Andric       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
3903*fe6060f1SDimitry Andric     }
3904*fe6060f1SDimitry Andric     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3905*fe6060f1SDimitry Andric                                                     : ChangeStatus::CHANGED;
3906*fe6060f1SDimitry Andric   }
3907*fe6060f1SDimitry Andric 
3908*fe6060f1SDimitry Andric   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
3909*fe6060f1SDimitry Andric     // Specialize only if all reaching kernels agree on the attribute value.
3910*fe6060f1SDimitry Andric     int32_t CurrentAttrValue = -1;
3911*fe6060f1SDimitry Andric     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3912*fe6060f1SDimitry Andric 
3913*fe6060f1SDimitry Andric     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
3914*fe6060f1SDimitry Andric         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
3915*fe6060f1SDimitry Andric 
3916*fe6060f1SDimitry Andric     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
3917*fe6060f1SDimitry Andric       return indicatePessimisticFixpoint();
3918*fe6060f1SDimitry Andric 
3919*fe6060f1SDimitry Andric     // Iterate over the kernels that reach this function.
3920*fe6060f1SDimitry Andric     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
3921*fe6060f1SDimitry Andric       int32_t NextAttrVal = -1;
3922*fe6060f1SDimitry Andric       if (K->hasFnAttribute(Attr))
3923*fe6060f1SDimitry Andric         NextAttrVal =
3924*fe6060f1SDimitry Andric             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
3925*fe6060f1SDimitry Andric 
3926*fe6060f1SDimitry Andric       if (NextAttrVal == -1 ||
3927*fe6060f1SDimitry Andric           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
3928*fe6060f1SDimitry Andric         return indicatePessimisticFixpoint();
3929*fe6060f1SDimitry Andric       CurrentAttrValue = NextAttrVal;
3930*fe6060f1SDimitry Andric     }
3931*fe6060f1SDimitry Andric 
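    // Illustrative sketch only: if every reaching kernel carries, e.g.,
    //   attributes #0 = { "omp_target_thread_limit"="128" ... }
    // the matching __kmpc_get_hardware_num_threads_in_block() call folds to
    // the i32 constant 128; disagreement or a missing attribute gives up.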
3932*fe6060f1SDimitry Andric     if (CurrentAttrValue != -1) {
3933*fe6060f1SDimitry Andric       auto &Ctx = getAnchorValue().getContext();
3934*fe6060f1SDimitry Andric       SimplifiedValue =
3935*fe6060f1SDimitry Andric           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
3936*fe6060f1SDimitry Andric     }
3937*fe6060f1SDimitry Andric     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3938*fe6060f1SDimitry Andric                                                     : ChangeStatus::CHANGED;
3939*fe6060f1SDimitry Andric   }
3940*fe6060f1SDimitry Andric 
3941*fe6060f1SDimitry Andric   /// An optional value the associated value is assumed to fold to. That is, we
3942*fe6060f1SDimitry Andric   /// assume the associated value (which is a call) can be replaced by this
3943*fe6060f1SDimitry Andric   /// simplified value.
3944*fe6060f1SDimitry Andric   Optional<Value *> SimplifiedValue;
3945*fe6060f1SDimitry Andric 
3946*fe6060f1SDimitry Andric   /// The runtime function kind of the callee of the associated call site.
3947*fe6060f1SDimitry Andric   RuntimeFunction RFKind;
3948*fe6060f1SDimitry Andric };
3949*fe6060f1SDimitry Andric 
39505ffd83dbSDimitry Andric } // namespace
39515ffd83dbSDimitry Andric 
3952*fe6060f1SDimitry Andric /// Register an AAFoldRuntimeCall attribute for every call site of \p RF.
3953*fe6060f1SDimitry Andric void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
3954*fe6060f1SDimitry Andric   auto &RFI = OMPInfoCache.RFIs[RF];
3955*fe6060f1SDimitry Andric   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
3956*fe6060f1SDimitry Andric     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3957*fe6060f1SDimitry Andric     if (!CI)
3958*fe6060f1SDimitry Andric       return false;
3959*fe6060f1SDimitry Andric     A.getOrCreateAAFor<AAFoldRuntimeCall>(
3960*fe6060f1SDimitry Andric         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3961*fe6060f1SDimitry Andric         DepClassTy::NONE, /* ForceUpdate */ false,
3962*fe6060f1SDimitry Andric         /* UpdateAfterInit */ false);
3963*fe6060f1SDimitry Andric     return false;
3964*fe6060f1SDimitry Andric   });
3965*fe6060f1SDimitry Andric }
3966*fe6060f1SDimitry Andric 
3967*fe6060f1SDimitry Andric void OpenMPOpt::registerAAs(bool IsModulePass) {
3968*fe6060f1SDimitry Andric   if (SCC.empty())
3969*fe6060f1SDimitry Andric     return;
3971*fe6060f1SDimitry Andric   if (IsModulePass) {
3972*fe6060f1SDimitry Andric     // Ensure we create the AAKernelInfo AAs first and without triggering an
3973*fe6060f1SDimitry Andric     // update. This will make sure we register all value simplification
3974*fe6060f1SDimitry Andric     // callbacks before any other AA has the chance to create an AAValueSimplify
3975*fe6060f1SDimitry Andric     // or similar.
3976*fe6060f1SDimitry Andric     for (Function *Kernel : OMPInfoCache.Kernels)
3977*fe6060f1SDimitry Andric       A.getOrCreateAAFor<AAKernelInfo>(
3978*fe6060f1SDimitry Andric           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
3979*fe6060f1SDimitry Andric           DepClassTy::NONE, /* ForceUpdate */ false,
3980*fe6060f1SDimitry Andric           /* UpdateAfterInit */ false);
3981*fe6060f1SDimitry Andric 
3983*fe6060f1SDimitry Andric     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
3984*fe6060f1SDimitry Andric     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
3985*fe6060f1SDimitry Andric     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
3986*fe6060f1SDimitry Andric     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
3987*fe6060f1SDimitry Andric     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
3988*fe6060f1SDimitry Andric   }
3989*fe6060f1SDimitry Andric 
3990*fe6060f1SDimitry Andric   // Create CallSite AA for all Getters.
3991*fe6060f1SDimitry Andric   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
3992*fe6060f1SDimitry Andric     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
3993*fe6060f1SDimitry Andric 
3994*fe6060f1SDimitry Andric     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
3995*fe6060f1SDimitry Andric 
3996*fe6060f1SDimitry Andric     auto CreateAA = [&](Use &U, Function &Caller) {
3997*fe6060f1SDimitry Andric       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
3998*fe6060f1SDimitry Andric       if (!CI)
3999*fe6060f1SDimitry Andric         return false;
4000*fe6060f1SDimitry Andric 
4001*fe6060f1SDimitry Andric       auto &CB = cast<CallBase>(*CI);
4002*fe6060f1SDimitry Andric 
4003*fe6060f1SDimitry Andric       IRPosition CBPos = IRPosition::callsite_function(CB);
4004*fe6060f1SDimitry Andric       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4005*fe6060f1SDimitry Andric       return false;
4006*fe6060f1SDimitry Andric     };
4007*fe6060f1SDimitry Andric 
4008*fe6060f1SDimitry Andric     GetterRFI.foreachUse(SCC, CreateAA);
4009*fe6060f1SDimitry Andric   }
4010*fe6060f1SDimitry Andric   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4011*fe6060f1SDimitry Andric   auto CreateAA = [&](Use &U, Function &F) {
4012*fe6060f1SDimitry Andric     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4013*fe6060f1SDimitry Andric     return false;
4014*fe6060f1SDimitry Andric   };
4015*fe6060f1SDimitry Andric   GlobalizationRFI.foreachUse(SCC, CreateAA);
4016*fe6060f1SDimitry Andric 
4017*fe6060f1SDimitry Andric   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4018*fe6060f1SDimitry Andric   // Create an ExecutionDomain AA and a HeapToStack AA for every function in
4019*fe6060f1SDimitry Andric   // the SCC, but only if this module is compiled for an OpenMP device.
4020*fe6060f1SDimitry Andric     return;
4021*fe6060f1SDimitry Andric 
4022*fe6060f1SDimitry Andric   for (auto *F : SCC) {
4023*fe6060f1SDimitry Andric     if (F->isDeclaration())
4024*fe6060f1SDimitry Andric       continue;
4025*fe6060f1SDimitry Andric 
4026*fe6060f1SDimitry Andric     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4027*fe6060f1SDimitry Andric     A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4028*fe6060f1SDimitry Andric 
4029*fe6060f1SDimitry Andric     for (auto &I : instructions(*F)) {
4030*fe6060f1SDimitry Andric       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4031*fe6060f1SDimitry Andric         bool UsedAssumedInformation = false;
4032*fe6060f1SDimitry Andric         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4033*fe6060f1SDimitry Andric                                UsedAssumedInformation);
4034*fe6060f1SDimitry Andric       }
4035*fe6060f1SDimitry Andric     }
4036*fe6060f1SDimitry Andric   }
4037*fe6060f1SDimitry Andric }
4038*fe6060f1SDimitry Andric 
40395ffd83dbSDimitry Andric const char AAICVTracker::ID = 0;
4040*fe6060f1SDimitry Andric const char AAKernelInfo::ID = 0;
4041*fe6060f1SDimitry Andric const char AAExecutionDomain::ID = 0;
4042*fe6060f1SDimitry Andric const char AAHeapToShared::ID = 0;
4043*fe6060f1SDimitry Andric const char AAFoldRuntimeCall::ID = 0;
40445ffd83dbSDimitry Andric 
40455ffd83dbSDimitry Andric AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
40465ffd83dbSDimitry Andric                                               Attributor &A) {
40475ffd83dbSDimitry Andric   AAICVTracker *AA = nullptr;
40485ffd83dbSDimitry Andric   switch (IRP.getPositionKind()) {
40495ffd83dbSDimitry Andric   case IRPosition::IRP_INVALID:
40505ffd83dbSDimitry Andric   case IRPosition::IRP_FLOAT:
40515ffd83dbSDimitry Andric   case IRPosition::IRP_ARGUMENT:
40525ffd83dbSDimitry Andric   case IRPosition::IRP_CALL_SITE_ARGUMENT:
40535ffd83dbSDimitry Andric     llvm_unreachable("ICVTracker can only be created for function position!");
4054e8d8bef9SDimitry Andric   case IRPosition::IRP_RETURNED:
4055e8d8bef9SDimitry Andric     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4056e8d8bef9SDimitry Andric     break;
4057e8d8bef9SDimitry Andric   case IRPosition::IRP_CALL_SITE_RETURNED:
4058e8d8bef9SDimitry Andric     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4059e8d8bef9SDimitry Andric     break;
4060e8d8bef9SDimitry Andric   case IRPosition::IRP_CALL_SITE:
4061e8d8bef9SDimitry Andric     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4062e8d8bef9SDimitry Andric     break;
40635ffd83dbSDimitry Andric   case IRPosition::IRP_FUNCTION:
40645ffd83dbSDimitry Andric     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
40655ffd83dbSDimitry Andric     break;
40665ffd83dbSDimitry Andric   }
40675ffd83dbSDimitry Andric 
40685ffd83dbSDimitry Andric   return *AA;
40695ffd83dbSDimitry Andric }
40705ffd83dbSDimitry Andric 
4071*fe6060f1SDimitry Andric AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4072*fe6060f1SDimitry Andric                                                         Attributor &A) {
4073*fe6060f1SDimitry Andric   AAExecutionDomainFunction *AA = nullptr;
4074*fe6060f1SDimitry Andric   switch (IRP.getPositionKind()) {
4075*fe6060f1SDimitry Andric   case IRPosition::IRP_INVALID:
4076*fe6060f1SDimitry Andric   case IRPosition::IRP_FLOAT:
4077*fe6060f1SDimitry Andric   case IRPosition::IRP_ARGUMENT:
4078*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4079*fe6060f1SDimitry Andric   case IRPosition::IRP_RETURNED:
4080*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_RETURNED:
4081*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE:
4082*fe6060f1SDimitry Andric     llvm_unreachable(
4083*fe6060f1SDimitry Andric         "AAExecutionDomain can only be created for function position!");
4084*fe6060f1SDimitry Andric   case IRPosition::IRP_FUNCTION:
4085*fe6060f1SDimitry Andric     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4086*fe6060f1SDimitry Andric     break;
4087*fe6060f1SDimitry Andric   }
4088*fe6060f1SDimitry Andric 
4089*fe6060f1SDimitry Andric   return *AA;
4090*fe6060f1SDimitry Andric }
4091*fe6060f1SDimitry Andric 
4092*fe6060f1SDimitry Andric AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4093*fe6060f1SDimitry Andric                                                   Attributor &A) {
4094*fe6060f1SDimitry Andric   AAHeapToSharedFunction *AA = nullptr;
4095*fe6060f1SDimitry Andric   switch (IRP.getPositionKind()) {
4096*fe6060f1SDimitry Andric   case IRPosition::IRP_INVALID:
4097*fe6060f1SDimitry Andric   case IRPosition::IRP_FLOAT:
4098*fe6060f1SDimitry Andric   case IRPosition::IRP_ARGUMENT:
4099*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4100*fe6060f1SDimitry Andric   case IRPosition::IRP_RETURNED:
4101*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_RETURNED:
4102*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE:
4103*fe6060f1SDimitry Andric     llvm_unreachable(
4104*fe6060f1SDimitry Andric         "AAHeapToShared can only be created for function position!");
4105*fe6060f1SDimitry Andric   case IRPosition::IRP_FUNCTION:
4106*fe6060f1SDimitry Andric     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4107*fe6060f1SDimitry Andric     break;
4108*fe6060f1SDimitry Andric   }
4109*fe6060f1SDimitry Andric 
4110*fe6060f1SDimitry Andric   return *AA;
4111*fe6060f1SDimitry Andric }
4112*fe6060f1SDimitry Andric 
4113*fe6060f1SDimitry Andric AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4114*fe6060f1SDimitry Andric                                               Attributor &A) {
4115*fe6060f1SDimitry Andric   AAKernelInfo *AA = nullptr;
4116*fe6060f1SDimitry Andric   switch (IRP.getPositionKind()) {
4117*fe6060f1SDimitry Andric   case IRPosition::IRP_INVALID:
4118*fe6060f1SDimitry Andric   case IRPosition::IRP_FLOAT:
4119*fe6060f1SDimitry Andric   case IRPosition::IRP_ARGUMENT:
4120*fe6060f1SDimitry Andric   case IRPosition::IRP_RETURNED:
4121*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_RETURNED:
4122*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4123*fe6060f1SDimitry Andric     llvm_unreachable("KernelInfo can only be created for function position!");
4124*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE:
4125*fe6060f1SDimitry Andric     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4126*fe6060f1SDimitry Andric     break;
4127*fe6060f1SDimitry Andric   case IRPosition::IRP_FUNCTION:
4128*fe6060f1SDimitry Andric     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4129*fe6060f1SDimitry Andric     break;
4130*fe6060f1SDimitry Andric   }
4131*fe6060f1SDimitry Andric 
4132*fe6060f1SDimitry Andric   return *AA;
4133*fe6060f1SDimitry Andric }
4134*fe6060f1SDimitry Andric 
4135*fe6060f1SDimitry Andric AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4136*fe6060f1SDimitry Andric                                                         Attributor &A) {
4137*fe6060f1SDimitry Andric   AAFoldRuntimeCall *AA = nullptr;
4138*fe6060f1SDimitry Andric   switch (IRP.getPositionKind()) {
4139*fe6060f1SDimitry Andric   case IRPosition::IRP_INVALID:
4140*fe6060f1SDimitry Andric   case IRPosition::IRP_FLOAT:
4141*fe6060f1SDimitry Andric   case IRPosition::IRP_ARGUMENT:
4142*fe6060f1SDimitry Andric   case IRPosition::IRP_RETURNED:
4143*fe6060f1SDimitry Andric   case IRPosition::IRP_FUNCTION:
4144*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE:
4145*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4146*fe6060f1SDimitry Andric     llvm_unreachable("KernelInfo can only be created for call site position!");
4147*fe6060f1SDimitry Andric   case IRPosition::IRP_CALL_SITE_RETURNED:
4148*fe6060f1SDimitry Andric     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4149*fe6060f1SDimitry Andric     break;
4150*fe6060f1SDimitry Andric   }
4151*fe6060f1SDimitry Andric 
4152*fe6060f1SDimitry Andric   return *AA;
4153*fe6060f1SDimitry Andric }
4154*fe6060f1SDimitry Andric 
4155*fe6060f1SDimitry Andric PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4156*fe6060f1SDimitry Andric   if (!containsOpenMP(M))
4157*fe6060f1SDimitry Andric     return PreservedAnalyses::all();
4158*fe6060f1SDimitry Andric   if (DisableOpenMPOptimizations)
41595ffd83dbSDimitry Andric     return PreservedAnalyses::all();
41605ffd83dbSDimitry Andric 
4161*fe6060f1SDimitry Andric   FunctionAnalysisManager &FAM =
4162*fe6060f1SDimitry Andric       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4163*fe6060f1SDimitry Andric   KernelSet Kernels = getDeviceKernels(M);
4164*fe6060f1SDimitry Andric 
4165*fe6060f1SDimitry Andric   auto IsCalled = [&](Function &F) {
4166*fe6060f1SDimitry Andric     if (Kernels.contains(&F))
4167*fe6060f1SDimitry Andric       return true;
4168*fe6060f1SDimitry Andric     for (const User *U : F.users())
4169*fe6060f1SDimitry Andric       if (!isa<BlockAddress>(U))
4170*fe6060f1SDimitry Andric         return true;
4171*fe6060f1SDimitry Andric     return false;
4172*fe6060f1SDimitry Andric   };
4173*fe6060f1SDimitry Andric 
4174*fe6060f1SDimitry Andric   auto EmitRemark = [&](Function &F) {
4175*fe6060f1SDimitry Andric     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4176*fe6060f1SDimitry Andric     ORE.emit([&]() {
4177*fe6060f1SDimitry Andric       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4178*fe6060f1SDimitry Andric       return ORA << "Could not internalize function. "
4179*fe6060f1SDimitry Andric                  << "Some optimizations may not be possible.";
4180*fe6060f1SDimitry Andric     });
4181*fe6060f1SDimitry Andric   };
4182*fe6060f1SDimitry Andric 
4183*fe6060f1SDimitry Andric   // Create internal copies of each function if this is a device (kernel)
4184*fe6060f1SDimitry Andric   // module. This allows interprocedural passes to see every call edge.
4185*fe6060f1SDimitry Andric   DenseSet<const Function *> InternalizedFuncs;
4186*fe6060f1SDimitry Andric   if (isOpenMPDevice(M))
4187*fe6060f1SDimitry Andric     for (Function &F : M)
4188*fe6060f1SDimitry Andric       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4189*fe6060f1SDimitry Andric           !DisableInternalization) {
4190*fe6060f1SDimitry Andric         if (Attributor::internalizeFunction(F, /* Force */ true)) {
4191*fe6060f1SDimitry Andric           InternalizedFuncs.insert(&F);
4192*fe6060f1SDimitry Andric         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4193*fe6060f1SDimitry Andric           EmitRemark(F);
4194*fe6060f1SDimitry Andric         }
4195*fe6060f1SDimitry Andric       }
4196*fe6060f1SDimitry Andric 
4197*fe6060f1SDimitry Andric   // Look at every function in the Module unless it was internalized.
4198*fe6060f1SDimitry Andric   SmallVector<Function *, 16> SCC;
4199*fe6060f1SDimitry Andric   for (Function &F : M)
4200*fe6060f1SDimitry Andric     if (!F.isDeclaration() && !InternalizedFuncs.contains(&F))
4201*fe6060f1SDimitry Andric       SCC.push_back(&F);
4202*fe6060f1SDimitry Andric 
4203*fe6060f1SDimitry Andric   if (SCC.empty())
4204*fe6060f1SDimitry Andric     return PreservedAnalyses::all();
4205*fe6060f1SDimitry Andric 
4206*fe6060f1SDimitry Andric   AnalysisGetter AG(FAM);
4207*fe6060f1SDimitry Andric 
4208*fe6060f1SDimitry Andric   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4209*fe6060f1SDimitry Andric     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4210*fe6060f1SDimitry Andric   };
4211*fe6060f1SDimitry Andric 
4212*fe6060f1SDimitry Andric   BumpPtrAllocator Allocator;
4213*fe6060f1SDimitry Andric   CallGraphUpdater CGUpdater;
4214*fe6060f1SDimitry Andric 
4215*fe6060f1SDimitry Andric   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4216*fe6060f1SDimitry Andric   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4217*fe6060f1SDimitry Andric 
4218*fe6060f1SDimitry Andric   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4219*fe6060f1SDimitry Andric   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4220*fe6060f1SDimitry Andric                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4221*fe6060f1SDimitry Andric 
4222*fe6060f1SDimitry Andric   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4223*fe6060f1SDimitry Andric   bool Changed = OMPOpt.run(true);
4224*fe6060f1SDimitry Andric   if (Changed)
4225*fe6060f1SDimitry Andric     return PreservedAnalyses::none();
4226*fe6060f1SDimitry Andric 
4227*fe6060f1SDimitry Andric   return PreservedAnalyses::all();
4228*fe6060f1SDimitry Andric }
4229*fe6060f1SDimitry Andric 
4230*fe6060f1SDimitry Andric PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4231*fe6060f1SDimitry Andric                                           CGSCCAnalysisManager &AM,
4232*fe6060f1SDimitry Andric                                           LazyCallGraph &CG,
4233*fe6060f1SDimitry Andric                                           CGSCCUpdateResult &UR) {
4234*fe6060f1SDimitry Andric   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4235*fe6060f1SDimitry Andric     return PreservedAnalyses::all();
42365ffd83dbSDimitry Andric   if (DisableOpenMPOptimizations)
42375ffd83dbSDimitry Andric     return PreservedAnalyses::all();
42385ffd83dbSDimitry Andric 
42395ffd83dbSDimitry Andric   SmallVector<Function *, 16> SCC;
4240e8d8bef9SDimitry Andric   // If there are kernels in the module, we have to run on all SCCs.
4241e8d8bef9SDimitry Andric   for (LazyCallGraph::Node &N : C) {
4242e8d8bef9SDimitry Andric     Function *Fn = &N.getFunction();
4243e8d8bef9SDimitry Andric     SCC.push_back(Fn);
4244e8d8bef9SDimitry Andric   }
4245e8d8bef9SDimitry Andric 
4246*fe6060f1SDimitry Andric   if (SCC.empty())
42475ffd83dbSDimitry Andric     return PreservedAnalyses::all();
42485ffd83dbSDimitry Andric 
4249*fe6060f1SDimitry Andric   Module &M = *C.begin()->getFunction().getParent();
4250*fe6060f1SDimitry Andric 
4251*fe6060f1SDimitry Andric   KernelSet Kernels = getDeviceKernels(M);
4252*fe6060f1SDimitry Andric 
42535ffd83dbSDimitry Andric   FunctionAnalysisManager &FAM =
42545ffd83dbSDimitry Andric       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
42555ffd83dbSDimitry Andric 
42565ffd83dbSDimitry Andric   AnalysisGetter AG(FAM);
42575ffd83dbSDimitry Andric 
42585ffd83dbSDimitry Andric   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
42595ffd83dbSDimitry Andric     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
42605ffd83dbSDimitry Andric   };
42615ffd83dbSDimitry Andric 
4262*fe6060f1SDimitry Andric   BumpPtrAllocator Allocator;
42635ffd83dbSDimitry Andric   CallGraphUpdater CGUpdater;
42645ffd83dbSDimitry Andric   CGUpdater.initialize(CG, C, AM, UR);
42655ffd83dbSDimitry Andric 
42665ffd83dbSDimitry Andric   SetVector<Function *> Functions(SCC.begin(), SCC.end());
42675ffd83dbSDimitry Andric   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4268*fe6060f1SDimitry Andric                                 /*CGSCC*/ Functions, Kernels);
42695ffd83dbSDimitry Andric 
4270*fe6060f1SDimitry Andric   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4271*fe6060f1SDimitry Andric   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4272*fe6060f1SDimitry Andric                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
42735ffd83dbSDimitry Andric 
42745ffd83dbSDimitry Andric   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4275*fe6060f1SDimitry Andric   bool Changed = OMPOpt.run(false);
42765ffd83dbSDimitry Andric   if (Changed)
42775ffd83dbSDimitry Andric     return PreservedAnalyses::none();
42785ffd83dbSDimitry Andric 
42795ffd83dbSDimitry Andric   return PreservedAnalyses::all();
42805ffd83dbSDimitry Andric }
42815ffd83dbSDimitry Andric 
42825ffd83dbSDimitry Andric namespace {
42835ffd83dbSDimitry Andric 
4284*fe6060f1SDimitry Andric struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
42855ffd83dbSDimitry Andric   CallGraphUpdater CGUpdater;
42865ffd83dbSDimitry Andric   static char ID;
42875ffd83dbSDimitry Andric 
4288*fe6060f1SDimitry Andric   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4289*fe6060f1SDimitry Andric     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
42905ffd83dbSDimitry Andric   }
42915ffd83dbSDimitry Andric 
42925ffd83dbSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
42935ffd83dbSDimitry Andric     CallGraphSCCPass::getAnalysisUsage(AU);
42945ffd83dbSDimitry Andric   }
42955ffd83dbSDimitry Andric 
42965ffd83dbSDimitry Andric   bool runOnSCC(CallGraphSCC &CGSCC) override {
4297*fe6060f1SDimitry Andric     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
42985ffd83dbSDimitry Andric       return false;
42995ffd83dbSDimitry Andric     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
43005ffd83dbSDimitry Andric       return false;
43015ffd83dbSDimitry Andric 
43025ffd83dbSDimitry Andric     SmallVector<Function *, 16> SCC;
4303e8d8bef9SDimitry Andric     // If there are kernels in the module, we have to run on all SCCs.
4304e8d8bef9SDimitry Andric     for (CallGraphNode *CGN : CGSCC) {
4305e8d8bef9SDimitry Andric       Function *Fn = CGN->getFunction();
4306e8d8bef9SDimitry Andric       if (!Fn || Fn->isDeclaration())
4307e8d8bef9SDimitry Andric         continue;
43085ffd83dbSDimitry Andric       SCC.push_back(Fn);
4309e8d8bef9SDimitry Andric     }
4310e8d8bef9SDimitry Andric 
4311*fe6060f1SDimitry Andric     if (SCC.empty())
43125ffd83dbSDimitry Andric       return false;
43135ffd83dbSDimitry Andric 
4314*fe6060f1SDimitry Andric     Module &M = CGSCC.getCallGraph().getModule();
4315*fe6060f1SDimitry Andric     KernelSet Kernels = getDeviceKernels(M);
4316*fe6060f1SDimitry Andric 
43175ffd83dbSDimitry Andric     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
43185ffd83dbSDimitry Andric     CGUpdater.initialize(CG, CGSCC);
43195ffd83dbSDimitry Andric 
43205ffd83dbSDimitry Andric     // Maintain a map of functions to avoid rebuilding the ORE.
43215ffd83dbSDimitry Andric     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
43225ffd83dbSDimitry Andric     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
43235ffd83dbSDimitry Andric       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
43245ffd83dbSDimitry Andric       if (!ORE)
43255ffd83dbSDimitry Andric         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
43265ffd83dbSDimitry Andric       return *ORE;
43275ffd83dbSDimitry Andric     };
43285ffd83dbSDimitry Andric 
43295ffd83dbSDimitry Andric     AnalysisGetter AG;
43305ffd83dbSDimitry Andric     SetVector<Function *> Functions(SCC.begin(), SCC.end());
43315ffd83dbSDimitry Andric     BumpPtrAllocator Allocator;
4332*fe6060f1SDimitry Andric     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4333*fe6060f1SDimitry Andric                                   Allocator,
4334*fe6060f1SDimitry Andric                                   /*CGSCC*/ Functions, Kernels);
43355ffd83dbSDimitry Andric 
4336*fe6060f1SDimitry Andric     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4337*fe6060f1SDimitry Andric     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4338*fe6060f1SDimitry Andric                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
43395ffd83dbSDimitry Andric 
43405ffd83dbSDimitry Andric     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4341*fe6060f1SDimitry Andric     return OMPOpt.run(/*IsModulePass=*/false);
43425ffd83dbSDimitry Andric   }
43435ffd83dbSDimitry Andric 
43445ffd83dbSDimitry Andric   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
43455ffd83dbSDimitry Andric };
43465ffd83dbSDimitry Andric 
43475ffd83dbSDimitry Andric } // end anonymous namespace
43485ffd83dbSDimitry Andric 
4349*fe6060f1SDimitry Andric KernelSet llvm::omp::getDeviceKernels(Module &M) {
4350*fe6060f1SDimitry Andric   // TODO: Create a more cross-platform way of determining device kernels.
43515ffd83dbSDimitry Andric   NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
4352*fe6060f1SDimitry Andric   KernelSet Kernels;
4353*fe6060f1SDimitry Andric 
43545ffd83dbSDimitry Andric   if (!MD)
4355*fe6060f1SDimitry Andric     return Kernels;
43565ffd83dbSDimitry Andric 
43575ffd83dbSDimitry Andric   for (auto *Op : MD->operands()) {
43585ffd83dbSDimitry Andric     if (Op->getNumOperands() < 2)
43595ffd83dbSDimitry Andric       continue;
43605ffd83dbSDimitry Andric     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
43615ffd83dbSDimitry Andric     if (!KindID || KindID->getString() != "kernel")
43625ffd83dbSDimitry Andric       continue;
43635ffd83dbSDimitry Andric 
43645ffd83dbSDimitry Andric     Function *KernelFn =
43655ffd83dbSDimitry Andric         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
43665ffd83dbSDimitry Andric     if (!KernelFn)
43675ffd83dbSDimitry Andric       continue;
43685ffd83dbSDimitry Andric 
43695ffd83dbSDimitry Andric     ++NumOpenMPTargetRegionKernels;
43705ffd83dbSDimitry Andric 
43715ffd83dbSDimitry Andric     Kernels.insert(KernelFn);
43725ffd83dbSDimitry Andric   }
4373*fe6060f1SDimitry Andric 
4374*fe6060f1SDimitry Andric   return Kernels;
43755ffd83dbSDimitry Andric }
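
// For illustration, getDeviceKernels() matches "nvvm.annotations" entries of
// roughly the following shape; the kernel name here is only an example:
//
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @__omp_offloading_example_kernel, !"kernel", i32 1}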
43765ffd83dbSDimitry Andric 
4377*fe6060f1SDimitry Andric bool llvm::omp::containsOpenMP(Module &M) {
4378*fe6060f1SDimitry Andric   Metadata *MD = M.getModuleFlag("openmp");
4379*fe6060f1SDimitry Andric   if (!MD)
4380*fe6060f1SDimitry Andric     return false;
43815ffd83dbSDimitry Andric 
43825ffd83dbSDimitry Andric   return true;
43835ffd83dbSDimitry Andric }
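
// A module compiled from OpenMP sources typically carries an "openmp" module
// flag, e.g. (the version value shown is only an example):
//
//   !llvm.module.flags = !{!0}
//   !0 = !{i32 7, !"openmp", i32 50}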
43845ffd83dbSDimitry Andric 
4385*fe6060f1SDimitry Andric bool llvm::omp::isOpenMPDevice(Module &M) {
4386*fe6060f1SDimitry Andric   Metadata *MD = M.getModuleFlag("openmp-device");
4387*fe6060f1SDimitry Andric   if (!MD)
4388*fe6060f1SDimitry Andric     return false;
4389*fe6060f1SDimitry Andric 
4390*fe6060f1SDimitry Andric   return true;
43915ffd83dbSDimitry Andric }
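
// Device-side compilations additionally carry an "openmp-device" module flag
// of the same form, e.g.:
//
//   !1 = !{i32 7, !"openmp-device", i32 50}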
43925ffd83dbSDimitry Andric 
4393*fe6060f1SDimitry Andric char OpenMPOptCGSCCLegacyPass::ID = 0;
43945ffd83dbSDimitry Andric 
4395*fe6060f1SDimitry Andric INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
43965ffd83dbSDimitry Andric                       "OpenMP specific optimizations", false, false)
43975ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4398*fe6060f1SDimitry Andric INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
43995ffd83dbSDimitry Andric                     "OpenMP specific optimizations", false, false)
44005ffd83dbSDimitry Andric 
4401*fe6060f1SDimitry Andric Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4402*fe6060f1SDimitry Andric   return new OpenMPOptCGSCCLegacyPass();
4403*fe6060f1SDimitry Andric }
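
// For illustration: with the legacy pass manager enabled, the pass registered
// above can be requested by its argument (exact opt flags vary by release):
//
//   opt -enable-new-pm=0 -openmp-opt-cgscc -S input.ll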
4404