xref: /llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp (revision c95ffadb2474a4d8c4f598d94d35a9f31d9606cb)
1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
2 //                  Set Load/Store Alignments From Assumptions
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a ScalarEvolution-based transformation to set
11 // the alignments of loads, stores and memory intrinsics based on the truth
12 // expressions of assume intrinsics. The primary motivation is to handle
13 // complex alignment assumptions that apply to vector loads and stores that
14 // appear after vectorization and unrolling.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "llvm/InitializePasses.h"
19 #define AA_NAME "alignment-from-assumptions"
20 #define DEBUG_TYPE AA_NAME
21 #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/GlobalsModRef.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
29 #include "llvm/Analysis/ValueTracking.h"
30 #include "llvm/IR/Constant.h"
31 #include "llvm/IR/Dominators.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/Transforms/Scalar.h"
39 using namespace llvm;
40 
41 STATISTIC(NumLoadAlignChanged,
42   "Number of loads changed by alignment assumptions");
43 STATISTIC(NumStoreAlignChanged,
44   "Number of stores changed by alignment assumptions");
45 STATISTIC(NumMemIntAlignChanged,
46   "Number of memory intrinsics changed by alignment assumptions");
47 
48 namespace {
49 struct AlignmentFromAssumptions : public FunctionPass {
50   static char ID; // Pass identification, replacement for typeid
51   AlignmentFromAssumptions() : FunctionPass(ID) {
52     initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
53   }
54 
55   bool runOnFunction(Function &F) override;
56 
57   void getAnalysisUsage(AnalysisUsage &AU) const override {
58     AU.addRequired<AssumptionCacheTracker>();
59     AU.addRequired<ScalarEvolutionWrapperPass>();
60     AU.addRequired<DominatorTreeWrapperPass>();
61 
62     AU.setPreservesCFG();
63     AU.addPreserved<AAResultsWrapperPass>();
64     AU.addPreserved<GlobalsAAWrapperPass>();
65     AU.addPreserved<LoopInfoWrapperPass>();
66     AU.addPreserved<DominatorTreeWrapperPass>();
67     AU.addPreserved<ScalarEvolutionWrapperPass>();
68   }
69 
70   AlignmentFromAssumptionsPass Impl;
71 };
72 }
73 
74 char AlignmentFromAssumptions::ID = 0;
75 static const char aip_name[] = "Alignment from assumptions";
76 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
77                       aip_name, false, false)
78 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
79 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
80 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
81 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
82                     aip_name, false, false)
83 
84 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
85   return new AlignmentFromAssumptions();
86 }
87 
88 // Given an expression for the (constant) alignment, AlignSCEV, and an
89 // expression for the displacement between a pointer and the aligned address,
90 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
91 // to a constant. Using SCEV to compute alignment handles the case where
92 // DiffSCEV is a recurrence with constant start such that the aligned offset
93 // is constant. e.g. {16,+,32} % 32 -> 16.
94 static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
95                                       const SCEV *AlignSCEV,
96                                       ScalarEvolution *SE) {
97   // DiffUnits = Diff % int64_t(Alignment)
98   const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
99 
100   LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
101                     << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
102 
103   if (const SCEVConstant *ConstDUSCEV =
104       dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
105     int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
106 
107     // If the displacement is an exact multiple of the alignment, then the
108     // displaced pointer has the same alignment as the aligned pointer, so
109     // return the alignment value.
110     if (!DiffUnits)
111       return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
112 
113     // If the displacement is not an exact multiple, but the remainder is a
114     // constant, then return this remainder (but only if it is a power of 2).
115     uint64_t DiffUnitsAbs = std::abs(DiffUnits);
116     if (isPowerOf2_64(DiffUnitsAbs))
117       return Align(DiffUnitsAbs);
118   }
119 
120   return None;
121 }
122 
123 // There is an address given by an offset OffSCEV from AASCEV which has an
124 // alignment AlignSCEV. Use that information, if possible, to compute a new
125 // alignment for Ptr.
126 static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
127                              const SCEV *OffSCEV, Value *Ptr,
128                              ScalarEvolution *SE) {
129   const SCEV *PtrSCEV = SE->getSCEV(Ptr);
130   // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
131   // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
132   // may disagree. Trunc/extend so they agree.
133   PtrSCEV = SE->getTruncateOrZeroExtend(
134       PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
135   const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
136 
137   // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
138   // sign-extended OffSCEV to i64, so make sure they agree again.
139   DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
140 
141   // What we really want to know is the overall offset to the aligned
142   // address. This address is displaced by the provided offset.
143   DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
144 
145   LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
146                     << *AlignSCEV << " and offset " << *OffSCEV
147                     << " using diff " << *DiffSCEV << "\n");
148 
149   if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
150     LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
151     return *NewAlignment;
152   }
153 
154   if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
155     // The relative offset to the alignment assumption did not yield a constant,
156     // but we should try harder: if we assume that a is 32-byte aligned, then in
157     // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
158     // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
159     // As a result, the new alignment will not be a constant, but can still
160     // be improved over the default (of 4) to 16.
161 
162     const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
163     const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
164 
165     LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
166                       << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
167 
168     // Now compute the new alignment using the displacement to the value in the
169     // first iteration, and also the alignment using the per-iteration delta.
170     // If these are the same, then use that answer. Otherwise, use the smaller
171     // one, but only if it divides the larger one.
172     MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
173     MaybeAlign NewIncAlignment =
174         getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
175 
176     LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
177                       << "\n");
178     LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
179                       << "\n");
180 
181     if (!NewAlignment || !NewIncAlignment)
182       return Align(1);
183 
184     const Align NewAlign = *NewAlignment;
185     const Align NewIncAlign = *NewIncAlignment;
186     if (NewAlign > NewIncAlign) {
187       LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
188                         << DebugStr(NewIncAlign) << "\n");
189       return NewIncAlign;
190     }
191     if (NewIncAlign > NewAlign) {
192       LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
193                         << "\n");
194       return NewAlign;
195     }
196     assert(NewIncAlign == NewAlign);
197     LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
198                       << "\n");
199     return NewAlign;
200   }
201 
202   return Align(1);
203 }
204 
205 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
206                                                         Value *&AAPtr,
207                                                         const SCEV *&AlignSCEV,
208                                                         const SCEV *&OffSCEV) {
209   Type *Int64Ty = Type::getInt64Ty(I->getContext());
210   Optional<OperandBundleUse> AlignOB = I->getOperandBundle("align");
211   if (AlignOB.hasValue()) {
212     assert(AlignOB.getValue().Inputs.size() >= 2);
213     AAPtr = AlignOB.getValue().Inputs[0].get();
214     // TODO: Consider accumulating the offset to the base.
215     AAPtr = AAPtr->stripPointerCastsSameRepresentation();
216     AlignSCEV = SE->getSCEV(AlignOB.getValue().Inputs[1].get());
217     AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
218     if (AlignOB.getValue().Inputs.size() == 3)
219       OffSCEV = SE->getSCEV(AlignOB.getValue().Inputs[2].get());
220     else
221       OffSCEV = SE->getZero(Int64Ty);
222     OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
223     return true;
224   }
225   return false;
226 }
227 
228 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
229   Value *AAPtr;
230   const SCEV *AlignSCEV, *OffSCEV;
231   if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
232     return false;
233 
234   // Skip ConstantPointerNull and UndefValue.  Assumptions on these shouldn't
235   // affect other users.
236   if (isa<ConstantData>(AAPtr))
237     return false;
238 
239   const SCEV *AASCEV = SE->getSCEV(AAPtr);
240 
241   // Apply the assumption to all other users of the specified pointer.
242   SmallPtrSet<Instruction *, 32> Visited;
243   SmallVector<Instruction*, 16> WorkList;
244   for (User *J : AAPtr->users()) {
245     if (J == ACall)
246       continue;
247 
248     if (Instruction *K = dyn_cast<Instruction>(J))
249         WorkList.push_back(K);
250   }
251 
252   while (!WorkList.empty()) {
253     Instruction *J = WorkList.pop_back_val();
254     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
255       if (!isValidAssumeForContext(ACall, J, DT))
256         continue;
257       Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
258                                            LI->getPointerOperand(), SE);
259       if (NewAlignment > LI->getAlign()) {
260         LI->setAlignment(NewAlignment);
261         ++NumLoadAlignChanged;
262       }
263     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
264       if (!isValidAssumeForContext(ACall, J, DT))
265         continue;
266       Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
267                                            SI->getPointerOperand(), SE);
268       if (NewAlignment > SI->getAlign()) {
269         SI->setAlignment(NewAlignment);
270         ++NumStoreAlignChanged;
271       }
272     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
273       if (!isValidAssumeForContext(ACall, J, DT))
274         continue;
275       Align NewDestAlignment =
276           getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
277 
278       LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
279                         << "\n";);
280       if (NewDestAlignment > *MI->getDestAlign()) {
281         MI->setDestAlignment(NewDestAlignment);
282         ++NumMemIntAlignChanged;
283       }
284 
285       // For memory transfers, there is also a source alignment that
286       // can be set.
287       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
288         Align NewSrcAlignment =
289             getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
290 
291         LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
292                           << "\n";);
293 
294         if (NewSrcAlignment > *MTI->getSourceAlign()) {
295           MTI->setSourceAlignment(NewSrcAlignment);
296           ++NumMemIntAlignChanged;
297         }
298       }
299     }
300 
301     // Now that we've updated that use of the pointer, look for other uses of
302     // the pointer to update.
303     Visited.insert(J);
304     for (User *UJ : J->users()) {
305       Instruction *K = cast<Instruction>(UJ);
306       if (!Visited.count(K))
307         WorkList.push_back(K);
308     }
309   }
310 
311   return true;
312 }
313 
314 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
315   if (skipFunction(F))
316     return false;
317 
318   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
319   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
320   DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
321 
322   return Impl.runImpl(F, AC, SE, DT);
323 }
324 
325 bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
326                                            ScalarEvolution *SE_,
327                                            DominatorTree *DT_) {
328   SE = SE_;
329   DT = DT_;
330 
331   bool Changed = false;
332   for (auto &AssumeVH : AC.assumptions())
333     if (AssumeVH)
334       Changed |= processAssumption(cast<CallInst>(AssumeVH));
335 
336   return Changed;
337 }
338 
339 PreservedAnalyses
340 AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
341 
342   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
343   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
344   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
345   if (!runImpl(F, AC, &SE, &DT))
346     return PreservedAnalyses::all();
347 
348   PreservedAnalyses PA;
349   PA.preserveSet<CFGAnalyses>();
350   PA.preserve<AAManager>();
351   PA.preserve<ScalarEvolutionAnalysis>();
352   PA.preserve<GlobalsAA>();
353   return PA;
354 }
355