//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// Lower aggregate copies, memset, memcpy, memmove intrinsics into loops when
// the size is large or is not a compile-time constant.
//
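// For example, a variable-size memcpy such as
//
//   call void @llvm.memcpy.p0.p0.i64(ptr %dst, ptr %src, i64 %n, i1 false)
//
// is rewritten into an explicit copy loop, roughly (a sketch; the exact IR
// comes from the LowerMemIntrinsics utilities and may use wider accesses):
//
//   loop:
//     %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
//     %s.gep = getelementptr inbounds i8, ptr %src, i64 %i
//     %byte = load i8, ptr %s.gep
//     %d.gep = getelementptr inbounds i8, ptr %dst, i64 %i
//     store i8 %byte, ptr %d.gep
//     %i.next = add i64 %i, 1
//     %cont = icmp ult i64 %i.next, %n
//     br i1 %cont, label %loop, label %exit
//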
//===----------------------------------------------------------------------===//

#include "NVPTXLowerAggrCopies.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/StackProtector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

#define DEBUG_TYPE "nvptx"

using namespace llvm;

namespace {

// The actual lowering pass, implemented as a FunctionPass.
struct NVPTXLowerAggrCopies : public FunctionPass {
  static char ID;

  NVPTXLowerAggrCopies() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<StackProtector>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;

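  // Loads of at least this many bytes are lowered into loops; mem* calls
  // with a non-constant length are lowered regardless of this threshold.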
  static const unsigned MaxAggrCopySize = 128;

  StringRef getPassName() const override {
    return "Lower aggregate copies/intrinsics into loops";
  }
};

char NVPTXLowerAggrCopies::ID = 0;

bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
  SmallVector<LoadInst *, 4> AggrLoads;
  SmallVector<MemIntrinsic *, 4> MemCalls;

  const DataLayout &DL = F.getDataLayout();
  LLVMContext &Context = F.getParent()->getContext();
  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  // Collect all aggregate loads and mem* calls.
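  // An "aggregate load" here is a LoadInst whose single user is a StoreInst
  // of the loaded value, i.e. an IR-level aggregate copy such as:
  //
  //   %val = load [256 x i8], ptr %src
  //   store [256 x i8] %val, ptr %dst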
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
        if (!LI->hasOneUse())
          continue;

        if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
          continue;

        if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
          if (SI->getOperand(0) != LI)
            continue;
          AggrLoads.push_back(LI);
        }
      } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(&I)) {
        // Convert intrinsic calls with a variable size, or with a constant
        // size of at least the MaxAggrCopySize threshold.
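        // E.g. @llvm.memset.p0.i64(ptr %p, i8 0, i64 %n, i1 false) always
        // qualifies because its length %n is not a compile-time constant.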
        if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
          if (LenCI->getZExtValue() >= MaxAggrCopySize) {
            MemCalls.push_back(IntrCall);
          }
        } else {
          MemCalls.push_back(IntrCall);
        }
      }
    }
  }

  if (AggrLoads.empty() && MemCalls.empty())
    return false;

  // Lower each collected aggregate load/store pair into an explicit copy loop.
  for (LoadInst *LI : AggrLoads) {
    auto *SI = cast<StoreInst>(*LI->user_begin());
    Value *SrcAddr = LI->getOperand(0);
    Value *DstAddr = SI->getOperand(1);
    unsigned CopySize = DL.getTypeStoreSize(LI->getType());
    ConstantInt *CopyLen =
        ConstantInt::get(Type::getInt32Ty(Context), CopySize);

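    // Expand the load/store pair into an explicit copy loop. CanOverlap is
    // passed as true, the conservative choice since the source and
    // destination are not known to be disjoint here.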
    createMemCpyLoopKnownSize(/* ConvertedInst */ SI,
                              /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
                              /* CopyLen */ CopyLen,
                              /* SrcAlign */ LI->getAlign(),
                              /* DestAlign */ SI->getAlign(),
                              /* SrcIsVolatile */ LI->isVolatile(),
                              /* DstIsVolatile */ SI->isVolatile(),
                              /* CanOverlap */ true, TTI);

    SI->eraseFromParent();
    LI->eraseFromParent();
  }

  // Transform mem* intrinsic calls.
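  // Each call is expanded in place into an explicit loop by the shared
  // LowerMemIntrinsics helpers, then the original intrinsic is erased.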
  for (MemIntrinsic *MemCall : MemCalls) {
    if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
      expandMemCpyAsLoop(Memcpy, TTI);
    } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
      expandMemMoveAsLoop(Memmove, TTI);
    } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
      expandMemSetAsLoop(Memset);
    }
    MemCall->eraseFromParent();
  }

  return true;
}

} // namespace

namespace llvm {
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
}

INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
                "Lower aggregate copies and llvm.mem* intrinsics into loops",
                false, false)
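// Factory entry point used by the NVPTX backend (e.g. when building the
// target's IR pass pipeline) to create this pass.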
FunctionPass *llvm::createLowerAggrCopies() {
  return new NVPTXLowerAggrCopies();
}