//===- GPUHeuristics.cpp - Heuristics Implementation for Transforms -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <numeric>

using namespace mlir;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

static Attribute linearId0(MLIRContext *ctx) {
  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
}
static Attribute linearId1(MLIRContext *ctx) {
  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim1);
}
static Attribute linearId2(MLIRContext *ctx) {
  return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim2);
}

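// Illustrative walk-through of the heuristic below; the concrete values are
// examples only and kMaxVectorLoadBitWidth (defined in GPUHeuristics.h) is
// assumed to be 128 bits here. With totalNumThreads = 128,
// desiredBitAlignment = 128, copySizes = {16, 32} and elementalBitwidth = 32:
//   - the vector size is gcd(128/32, 32, 128/32) = 4 elements (128 bits);
//   - the scaled copy shape {16, 32/4 = 8} is covered by numThreads = {16, 8},
//     i.e. exactly 128 threads, so the status is Success;
//   - smallestBoundingTileSizes = {1, 4} and the thread mapping is
//     {LinearDim1, LinearDim0}.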
transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
                                                 int totalNumThreads,
                                                 int64_t desiredBitAlignment,
                                                 ArrayRef<int64_t> copySizes,
                                                 bool favorPredication,
                                                 int64_t elementalBitwidth) {
  assert(!copySizes.empty() && copySizes.size() <= 3 &&
         "only 1,2,3-D copies are supported for now");

  LDBG("START CopyMappingInfo, favorPredication: " << favorPredication);
  LLVM_DEBUG(llvm::interleaveComma(copySizes, DBGS() << "--copy shape: ");
             llvm::dbgs() << "\n";);

  // Greedily find the largest vector size that can be used to copy the most
  // minor dimension: the goal is to fill kMaxVectorLoadBitWidth-wide
  // contiguous memory transactions with as few threads as possible.
  int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
      desiredBitAlignment, copySizes.back(), elementalBitwidth);

  LDBG("--greedily determined vectorSize: "
       << desiredVectorSize << " elements of " << elementalBitwidth
       << "b each -> " << (desiredVectorSize * elementalBitwidth)
       << "b total out of a max of " << kMaxVectorLoadBitWidth << "b");

  status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
                           favorPredication);
  if (status == Status::Invalid)
    return;

  LLVM_DEBUG(llvm::interleaveComma(copySizes, DBGS() << "--copy: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 this->numThreads, DBGS() << "--numThreads: ");
             llvm::dbgs() << "\n";);
  LDBG("--vectorSize: " << this->vectorSize);
  assert(this->numThreads.size() == copySizes.size() &&
         "compute copy mapping expected same number of threads and copy sizes");

  // Compute the smallest bounding box.
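  // For illustration, copySizes = {16, 32} with numThreads = {16, 8} gives
  // per-thread tile sizes {ceil(16 / 16), ceil(32 / 8)} = {1, 4}.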
  this->smallestBoundingTileSizes = llvm::to_vector(
      llvm::map_range(llvm::zip(copySizes, this->numThreads), [](auto &&pair) {
        int64_t size, numThreads;
        std::tie(size, numThreads) = pair;
        return llvm::divideCeilSigned(size, numThreads);
      }));
  SmallVector<Attribute> allThreadMappings{linearId2(ctx), linearId1(ctx),
                                           linearId0(ctx)};

  // Set the thread mapping.
  this->threadMapping =
      llvm::to_vector(ArrayRef(allThreadMappings)
                          .take_back(this->smallestBoundingTileSizes.size()));
  LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
}

int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
    int64_t desiredBitAlignment, int64_t numContiguousElements,
    int64_t elementalBitwidth) {
  assert(kMaxVectorLoadBitWidth % elementalBitwidth == 0 &&
         "elemental bitwidth does not divide kMaxVectorLoadBitWidth");
  assert(desiredBitAlignment % elementalBitwidth == 0 &&
         "elemental bitwidth does not divide desired bit alignment");
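  // The returned value divides each of the three quantities below, so it
  // (i) respects the desired alignment, (ii) evenly tiles the most minor
  // dimension, and (iii) fits within a single kMaxVectorLoadBitWidth-bit
  // transaction; the gcd is the largest such value.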
  return std::gcd(
      std::gcd(desiredBitAlignment / elementalBitwidth, numContiguousElements),
      kMaxVectorLoadBitWidth / elementalBitwidth);
}

/// Get the list of all factors that divide `val`, not just the prime factors.
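/// For illustration, getFactors(12) returns {1, 2, 3, 4, 6, 12}.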
static SmallVector<int64_t> getFactors(int64_t val) {
  SmallVector<int64_t> factors;
  factors.reserve(val);
  for (int64_t factor = 1; factor <= val; ++factor) {
    if (val % factor != 0)
      continue;
    factors.push_back(factor);
  }
  return factors;
}

static int64_t product(ArrayRef<int64_t> vals) {
  int64_t res = 1;
  for (auto val : vals)
    res *= val;
  return res;
}

/// Extract `result` from `sizes` with the following constraints:
///   1. sizes[i] % result[i] == 0 for all i
///   2. product(result) <= maxNumThreads
///   3. if `currentIndex` is sizes.size() - 1, then result[currentIndex]
///      must be sizes[currentIndex].
/// This is used to greedily extract the maximum number of threads usable for
/// mapping a copy of size `sizes`, while being bounded by `maxNumThreads` and
/// ensuring coalesced access along the most minor dimension.
/// Return the number of threads to use for each dimension in the range:
///   [currentIndex .. sizes.size())
// The implementation recursively enumerates the factors of each size and
// greedily keeps the best combination under the constraints.
// TODO: Implementation details can be improved but putting effort there is a
// tradeoff: `sizes` is expected to be of small rank and contain small values.
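// For illustration, maximizeNumThreads({6, 4}, /*currentIndex=*/0,
// /*maxNumThreads=*/16) mandates 4 threads along the most minor dimension and
// then picks the largest factor of 6 that fits the budget: 3 * 4 = 12 <= 16
// while 6 * 4 = 24 > 16, so the result is {3, 4}.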
static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes,
                                               int64_t currentIndex,
                                               int64_t maxNumThreads) {
  assert(static_cast<size_t>(currentIndex) < sizes.size() &&
         "currentIndex out of bounds");
  std::string indent(2 * currentIndex, '-');
  if (static_cast<size_t>(currentIndex) == sizes.size() - 1) {
    LDBG(indent << "mandated globalBest: " << sizes[currentIndex]);
    return SmallVector<int64_t>{sizes[currentIndex]};
  }

  int64_t best = 0;
  int64_t s = sizes[currentIndex];
  SmallVector<int64_t> factors = getFactors(s);
  SmallVector<int64_t> localThreadsPerDim;
  localThreadsPerDim.reserve(sizes.size());
  LDBG(indent << "maximizeNumThreads in " << s
              << " with limit: " << maxNumThreads);
  for (auto factor : factors) {
    auto nestedThreadsPerDim =
        maximizeNumThreads(sizes, currentIndex + 1, maxNumThreads / factor);
    int64_t localBest = factor * product(nestedThreadsPerDim);
    if (localBest > best && localBest <= maxNumThreads) {
      LDBG(indent << "new localBest: " << localBest);
      LLVM_DEBUG(
          llvm::interleaveComma(nestedThreadsPerDim,
                                DBGS() << indent << "nestedThreadsPerDim: ");
          llvm::dbgs() << "\n";);
      localThreadsPerDim.clear();
      localThreadsPerDim.push_back(factor);
      llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
      best = localBest;
    }
  }

  LDBG(indent << "found globalBest: " << best);
  LLVM_DEBUG(llvm::interleaveComma(localThreadsPerDim,
                                   DBGS() << indent << "numThreads: ");
             llvm::dbgs() << "\n";);

  return localThreadsPerDim;
}

transform::gpu::CopyMappingInfo::Status
transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
                                                 ArrayRef<int64_t> sizes,
                                                 int64_t desiredVectorSize,
                                                 bool favorPredication) {

  if (!favorPredication) {
    int64_t localVectorSize = desiredVectorSize;
    for (; localVectorSize >= 1; localVectorSize /= 2) {
      // Attempt to map the copy with the current fixed vector size:
      //   1. if the status is Success, we are done.
      //   2. if the status is Invalid, we fail immediately, no amount of
      //   vector size reduction can offset the bad tile size selection from
      //   higher-level tiling.
      //   3. if the status is RequiresPredication, we try again with a smaller
      //   vector size.
      Status status =
          inferNumThreadsImpl(totalNumThreads, sizes, localVectorSize);
      if (status == Status::Success || status == Status::Invalid)
        return status;

      LDBG("requires predication, try reducing vector size to "
           << (localVectorSize / 2));
    }
  }

  // If we have not yet returned, it means that we have tried all vector sizes
  // and we still require predication. Fall back to the original vector size
  // and accept that the mapping requires predication.
  return inferNumThreadsImpl(totalNumThreads, sizes, desiredVectorSize);
}

transform::gpu::CopyMappingInfo::Status
transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
    int64_t totalNumThreads, ArrayRef<int64_t> sizes,
    int64_t desiredVectorSize) {
  assert(sizes.back() % desiredVectorSize == 0 &&
         "most-minor size not divisible by desiredVectorSize");

  LDBG("inferNumThreadsImpl with totalNumThreads: "
       << totalNumThreads << " and vectorSize: " << desiredVectorSize);

  // Scale the most minor size to account for the chosen vector size and
  // maximize the number of threads without exceeding the total number of
  // threads.
  SmallVector<int64_t> scaledSizes(sizes);
  scaledSizes.back() /= desiredVectorSize;
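  // For illustration, sizes = {16, 32} with desiredVectorSize = 4 gives
  // scaledSizes = {16, 8}: each thread along the most minor dimension moves 4
  // contiguous elements per transfer, so 8 threads suffice to cover it.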
  if (scaledSizes.back() > totalNumThreads) {
    LDBG("--Too few threads given the required vector size -> FAIL");
    return Status::Invalid;
  }
  SmallVector<int64_t> inferredNumThreads =
      maximizeNumThreads(scaledSizes, 0, totalNumThreads);

  LLVM_DEBUG(llvm::interleaveComma(inferredNumThreads,
                                   DBGS() << "inferred numThreads: ");
             llvm::dbgs() << "\n";
             LDBG("computed actualVectorSize: " << desiredVectorSize););

  // Corner case: we cannot use more threads than available. If the copy sizes
  // are that badly matched to the thread budget, it is because higher-level
  // tiling did not do its job; we do not try to recover from it here.
  int64_t totalNumThreadsUsed = product(inferredNumThreads);
  LDBG("--totalNumThreadsUsed: " << totalNumThreadsUsed);
  if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
    LDBG("--Too few threads given the required vector size -> FAIL");
    return Status::Invalid;
  }

  this->vectorSize = desiredVectorSize;
  this->numThreads = inferredNumThreads;
  if (totalNumThreadsUsed == totalNumThreads)
    return Status::Success;

  return Status::RequiresPredication;
}

void transform::gpu::CopyMappingInfo::print(llvm::raw_ostream &os) const {
  os << "MappingInfo{";
  os << "CopyMappingInfo: ";
  os << "valid: " << (status != Status::Invalid) << ", ";
  os << "vectorSize: " << vectorSize << ", ";
  llvm::interleaveComma(numThreads, os << "numThreads: {");
  llvm::interleaveComma(smallestBoundingTileSizes,
                        os << "}, smallestBoundingTileSizes: {");
  llvm::interleaveComma(threadMapping, os << "}, threadMapping: {");
  os << "}}";
}