xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp (revision 25b90c4ef67a01de6eba4f9e160d33772eb53454)
//===- SeedCollector.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
10 #include "llvm/Analysis/LoopAccessAnalysis.h"
11 #include "llvm/Analysis/ValueTracking.h"
12 #include "llvm/IR/Type.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Utils.h"
15 #include "llvm/Support/Debug.h"
16 
17 using namespace llvm;
18 namespace llvm::sandboxir {
19 
20 cl::opt<unsigned> SeedBundleSizeLimit(
21     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
22     cl::desc("Limit the size of the seed bundle to cap compilation time."));
23 #define LoadSeedsDef "loads"
24 #define StoreSeedsDef "stores"
25 cl::opt<std::string> CollectSeeds(
26     "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
27     cl::desc("Collect these seeds. Use empty for none or a comma-separated "
28              "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
29 cl::opt<unsigned> SeedGroupsLimit(
30     "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
31     cl::desc("Limit the number of collected seeds groups in a BB to "
32              "cap compilation time."));
33 
34 ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
35                                              unsigned MaxVecRegBits,
36                                              bool ForcePowerOf2) {
37   // Use uint32_t here for compatibility with IsPowerOf2_32
38 
39   // BitCount tracks the size of the working slice. From that we can tell
40   // when the working slice's size is a power-of-two and when it exceeds
41   // the legal size in MaxVecBits.
42   uint32_t BitCount = 0;
43   uint32_t NumElements = 0;
44   // Tracks the most recent slice where NumElements gave a power-of-2 BitCount
45   uint32_t NumElementsPowerOfTwo = 0;
46   uint32_t BitCountPowerOfTwo = 0;
47   // Can't start a slice with a used instruction.
48   assert(!isUsed(StartIdx) && "Expected unused at StartIdx");
49   for (auto S : make_range(Seeds.begin() + StartIdx, Seeds.end())) {
50     // Stop if this instruction is used. This needs to be done before
51     // getNumBits() because a "used" instruction may have been erased.
52     if (isUsed(StartIdx + NumElements))
53       break;
54     uint32_t InstBits = Utils::getNumBits(S);
55     // Stop if adding it puts the slice over the limit.
56     if (BitCount + InstBits > MaxVecRegBits)
57       break;
58     NumElements++;
59     BitCount += InstBits;
60     if (ForcePowerOf2 && isPowerOf2_32(BitCount)) {
61       NumElementsPowerOfTwo = NumElements;
62       BitCountPowerOfTwo = BitCount;
63     }
64   }
65   if (ForcePowerOf2) {
66     NumElements = NumElementsPowerOfTwo;
67     BitCount = BitCountPowerOfTwo;
68   }
69 
70   // Return any non-empty slice
71   if (NumElements > 1) {
72     assert((!ForcePowerOf2 || isPowerOf2_32(BitCount)) &&
73            "Must be a power of two");
74     return ArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
75   }
76   return {};
77 }
78 
79 template <typename LoadOrStoreT>
80 SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI) const {
81   assert((isa<LoadInst>(LSI) || isa<StoreInst>(LSI)) &&
82          "Expected Load or Store!");
83   Value *Ptr = Utils::getMemInstructionBase(LSI);
84   Instruction::Opcode Op = LSI->getOpcode();
85   Type *Ty = Utils::getExpectedType(LSI);
86   if (auto *VTy = dyn_cast<VectorType>(Ty))
87     Ty = VTy->getElementType();
88   return {Ptr, Ty, Op};
89 }
90 
91 // Explicit instantiations
92 template SeedContainer::KeyT
93 SeedContainer::getKey<LoadInst>(LoadInst *LSI) const;
94 template SeedContainer::KeyT
95 SeedContainer::getKey<StoreInst>(StoreInst *LSI) const;
96 
97 bool SeedContainer::erase(Instruction *I) {
98   assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Expected Load or Store!");
99   auto It = SeedLookupMap.find(I);
100   if (It == SeedLookupMap.end())
101     return false;
102   SeedBundle *Bndl = It->second;
103   Bndl->setUsed(I);
104   return true;
105 }
106 
107 template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
108   // Find the bundle containing seeds for this symbol and type-of-access.
109   auto &BundleVec = Bundles[getKey(LSI)];
110   // Fill this vector of bundles front to back so that only the last bundle in
111   // the vector may have available space. This avoids iteration to find one with
112   // space.
113   if (BundleVec.empty() || BundleVec.back()->size() == SeedBundleSizeLimit)
114     BundleVec.emplace_back(std::make_unique<MemSeedBundle<LoadOrStoreT>>(LSI));
115   else
116     BundleVec.back()->insert(LSI, SE);
117 
118   SeedLookupMap[LSI] = BundleVec.back().get();
119 }
120 
121 // Explicit instantiations
122 template void SeedContainer::insert<LoadInst>(LoadInst *);
123 template void SeedContainer::insert<StoreInst>(StoreInst *);
124 
125 #ifndef NDEBUG
126 void SeedContainer::print(raw_ostream &OS) const {
127   for (const auto &Pair : Bundles) {
128     auto [I, Ty, Opc] = Pair.first;
129     const auto &SeedsVec = Pair.second;
130     std::string RefType = dyn_cast<LoadInst>(I)    ? "Load"
131                           : dyn_cast<StoreInst>(I) ? "Store"
132                                                    : "Other";
133     OS << "[Inst=" << *I << " Ty=" << Ty << " " << RefType << "]\n";
134     for (const auto &SeedPtr : SeedsVec) {
135       SeedPtr->dump(OS);
136       OS << "\n";
137     }
138   }
139   OS << "\n";
140 }
141 
142 LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
143 #endif // NDEBUG
144 
145 template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
146   if (!LSI->isSimple())
147     return false;
148   auto *Ty = Utils::getExpectedType(LSI);
149   // Omit types that are architecturally unvectorizable
150   if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
151     return false;
152   // Omit vector types without compile-time-known lane counts
153   if (isa<ScalableVectorType>(Ty))
154     return false;
155   if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
156     return VectorType::isValidElementType(VTy->getElementType());
157   return VectorType::isValidElementType(Ty);
158 }
159 
160 template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
161 template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
162 
163 SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
164     : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
165 
166   bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
167   bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
168   if (!CollectStores && !CollectLoads)
169     return;
170 
171   EraseCallbackID = Ctx.registerEraseInstrCallback([this](Instruction *I) {
172     if (auto SI = dyn_cast<StoreInst>(I))
173       StoreSeeds.erase(SI);
174     else if (auto LI = dyn_cast<LoadInst>(I))
175       LoadSeeds.erase(LI);
176   });
177 
178   // Actually collect the seeds.
179   for (auto &I : *BB) {
180     if (StoreInst *SI = dyn_cast<StoreInst>(&I))
181       if (CollectStores && isValidMemSeed(SI))
182         StoreSeeds.insert(SI);
183     if (LoadInst *LI = dyn_cast<LoadInst>(&I))
184       if (CollectLoads && isValidMemSeed(LI))
185         LoadSeeds.insert(LI);
186     // Cap compilation time.
187     if (totalNumSeedGroups() > SeedGroupsLimit)
188       break;
189   }
190 }
191 
192 SeedCollector::~SeedCollector() {
193   Ctx.unregisterEraseInstrCallback(EraseCallbackID);
194 }
195 
196 #ifndef NDEBUG
197 void SeedCollector::print(raw_ostream &OS) const {
198   OS << "=== StoreSeeds ===\n";
199   StoreSeeds.print(OS);
200   OS << "=== LoadSeeds ===\n";
201   LoadSeeds.print(OS);
202 }
203 
204 void SeedCollector::dump() const { print(dbgs()); }
205 #endif
206 
207 } // namespace llvm::sandboxir
208