xref: /llvm-project/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (revision 74cf9bcf71d94f4df80578bccec6ed6d51dd9682)
1 //===- Utils.cpp - Utils for GPU transform ops ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/GPU/TransformOps/Utils.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
20 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinAttributes.h"
26 #include "mlir/IR/IRMapping.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Value.h"
30 #include "mlir/IR/Visitors.h"
31 #include "mlir/Support/LLVM.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 
37 using namespace mlir;
38 using namespace mlir::gpu;
39 using namespace mlir::transform;
40 using namespace mlir::transform::gpu;
41 
42 #define DEBUG_TYPE "gpu-transforms"
43 
44 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
45 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
46 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
47 
48 /// Return a flattened thread id for the workgroup with given sizes.
49 template <typename ThreadOrBlockIdOp>
50 static Value buildLinearId(RewriterBase &rewriter, Location loc,
51                            ArrayRef<OpFoldResult> originalBasisOfr) {
52   LLVM_DEBUG(llvm::interleaveComma(
53                  originalBasisOfr,
54                  DBGS() << "----buildLinearId with originalBasisOfr:  ");
55              llvm::dbgs() << "\n");
56   assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
57   IndexType indexType = rewriter.getIndexType();
58   AffineExpr tx, ty, tz, BDX, BDY;
59   bindDims(rewriter.getContext(), tx, ty, tz);
60   bindSymbols(rewriter.getContext(), BDX, BDY);
61   SmallVector<OpFoldResult> vals{
62       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x)
63           .getResult(),
64       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y)
65           .getResult(),
66       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)
67           .getResult(),
68       originalBasisOfr[0], originalBasisOfr[1]};
69   OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
70       rewriter, loc, tx + ty * BDX + tz * BDX * BDY, vals);
71   return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
72 }
73 
/// Create a linear id builder that takes the `originalBasisOfr` and decompose
/// it in the basis of `forallMappingSizes`. The linear id builder returns an
/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
///
/// `multiplicity` groups that many consecutive hardware ids into one logical
/// id (e.g. warpSize threads forming one warp) by dividing the linear id
/// before delinearization.
template <typename ThreadOrBlockIdOp>
static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                            ArrayRef<int64_t> forallMappingSizes,
                            ArrayRef<int64_t> originalBasis) {
    // Flatten the 3 hardware ids into a single linear id over `originalBasis`.
    SmallVector<OpFoldResult> originalBasisOfr =
        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
    OpFoldResult linearId =
        buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
    // "row-major" order.
    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
    // Scale the linear id down so that `multiplicity` consecutive hardware
    // ids collapse onto the same logical id.
    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
    // Delinearize the scaled id into per-dimension indices using the strides.
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
    SmallVector<Value> ids;
    // Reverse back to be in [0 .. n] order.
    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
      ids.push_back(
          affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
    }

    // clang-format off
      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
                                       DBGS() << "--delinearization basis: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(strides,
                                       DBGS() << "--delinearization strides: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(delinearizingExprs,
                                       DBGS() << "--delinearization exprs: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
                 llvm::dbgs() << "\n";);
    // clang-format on

    // Return n-D ids for indexing and 1-D size + id for predicate generation.
    return IdBuilderResult{
        /*mappingIdOps=*/ids,
        /*availableMappingSizes=*/
        SmallVector<int64_t>{computeProduct(originalBasis)},
        // `forallMappingSizes` iterate in the scaled basis, they need to be
        // scaled back into the original basis to provide tight
        // activeMappingSizes quantities for predication.
        /*activeMappingSizes=*/
        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
  };

  return res;
}
130 
131 /// Create a simple 3-D id builder that takes the `originalBasisOfr`
132 /// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes
133 /// + ids for predicate generation.
134 template <typename ThreadOrBlockIdOp>
135 static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
136   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
137                             ArrayRef<int64_t> forallMappingSizes,
138                             ArrayRef<int64_t> originalBasis) {
139     IndexType indexType = rewriter.getIndexType();
140     SmallVector<Value> ids{
141         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x),
142         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y),
143         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)};
144     // In the 3-D mapping case, scale the first dimension by the multiplicity.
145     SmallVector<Value> scaledIds = ids;
146     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
147     scaledIds[0] = affine::makeComposedFoldedAffineApply(
148                        rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})
149                        .get<Value>();
150     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
151     SmallVector<int64_t> forallMappingSizeInOriginalBasis(
152         forallMappingSizes.begin(), forallMappingSizes.end());
153     forallMappingSizeInOriginalBasis[0] *= multiplicity;
154     return IdBuilderResult{
155         /*mappingIdOps=*/scaledIds,
156         /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
157         // `forallMappingSizes` iterate in the scaled basis, they need to be
158         // scaled back into the original basis to provide tight
159         // activeMappingSizes quantities for predication.
160         /*activeMappingSizes=*/
161         SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
162         /*activeIdOps=*/ids};
163   };
164   return res;
165 }
166 
167 namespace mlir {
168 namespace transform {
169 namespace gpu {
170 
171 GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
172                            const MappingIdBuilderFnType &fn)
173     : mappingAttributes(), idBuilder() {
174   if (useLinearMapping) {
175     for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0),
176                   e = getMaxEnumValForMappingId();
177          d <= e; ++d)
178       mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
179   } else {
180     for (uint64_t d = static_cast<uint64_t>(MappingId::DimX),
181                   e = static_cast<uint64_t>(MappingId::DimZ);
182          d <= e; ++d)
183       mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
184   }
185 }
186 
187 GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
188     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
189         return GPUBlockMappingAttr::get(ctx, id);
190       }) {
191   idBuilder = useLinearMapping
192                   ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1)
193                   : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1);
194 }
195 
196 GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
197                                              bool useLinearMapping)
198     : GpuIdBuilder(ctx, useLinearMapping,
199                    [](MLIRContext *ctx, MappingId id) {
200                      return GPUWarpgroupMappingAttr::get(ctx, id);
201                    }),
202       warpSize(warpSize) {
203   idBuilder = useLinearMapping
204                   ? commonLinearIdBuilderFn<ThreadIdOp>(
205                         /*multiplicity=*/kNumWarpsPerGroup * warpSize)
206                   : common3DIdBuilderFn<ThreadIdOp>(
207                         /*multiplicity=*/kNumWarpsPerGroup * warpSize);
208 }
209 
210 GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
211                                    bool useLinearMapping)
212     : GpuIdBuilder(ctx, useLinearMapping,
213                    [](MLIRContext *ctx, MappingId id) {
214                      return GPUWarpMappingAttr::get(ctx, id);
215                    }),
216       warpSize(warpSize) {
217   idBuilder =
218       useLinearMapping
219           ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize)
220           : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
221 }
222 
223 GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
224     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
225         return GPUThreadMappingAttr::get(ctx, id);
226       }) {
227   idBuilder = useLinearMapping
228                   ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1)
229                   : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
230 }
231 
232 DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
233                                            std::optional<int64_t> gridDimX,
234                                            std::optional<int64_t> gridDimY,
235                                            std::optional<int64_t> gridDimZ,
236                                            std::optional<int64_t> blockDimX,
237                                            std::optional<int64_t> blockDimY,
238                                            std::optional<int64_t> blockDimZ) {
239 
240   // TODO: pass a configuration object to set the limits properly.
241   static constexpr int maxTotalBlockdim = 1024;
242   static constexpr int maxBlockdimx = 1024;
243   static constexpr int maxBlockdimy = 1024;
244   static constexpr int maxBlockdimz = 64;
245   static constexpr int maxTotalGriddim = 2147483647;
246   static constexpr int maxGriddimx = 2147483647;
247   static constexpr int maxGriddimy = 65535;
248   static constexpr int maxGriddimz = 65535;
249 
250   if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
251           maxTotalBlockdim ||
252       (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
253           maxTotalGriddim ||
254       blockDimX.value_or(1) > maxBlockdimx ||
255       blockDimY.value_or(1) > maxBlockdimy ||
256       blockDimZ.value_or(1) > maxBlockdimz ||
257       gridDimY.value_or(1) > maxGriddimy ||
258       gridDimZ.value_or(1) > maxGriddimz ||
259       gridDimX.value_or(1) > maxGriddimx) {
260     return transformOp.emitSilenceableError()
261            << "Trying to launch a GPU kernel with grid_dims = ("
262            << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
263            << gridDimZ.value_or(1) << ") block_dims = ("
264            << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", "
265            << blockDimZ.value_or(1) << "). It is larger than the limits.";
266   }
267   return DiagnosedSilenceableFailure::success();
268 }
269 
270 DiagnosedSilenceableFailure createGpuLaunch(
271     RewriterBase &rewriter, Location loc, TransformOpInterface transformOp,
272     LaunchOp &launchOp, std::optional<int64_t> gridDimX,
273     std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
274     std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
275     std::optional<int64_t> blockDimZ) {
276   DiagnosedSilenceableFailure diag =
277       checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
278                      blockDimY, blockDimZ);
279   if (!diag.succeeded())
280     return diag;
281 
282   auto createConst = [&](int dim) {
283     return rewriter.create<arith::ConstantIndexOp>(loc, dim);
284   };
285   OpBuilder::InsertionGuard guard(rewriter);
286   Value one = createConst(1);
287   Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
288   Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
289   Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
290   Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
291   Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
292   Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
293   launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
294                                        blkSizeX, blkSizeY, blkSizeZ);
295   rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
296   rewriter.create<TerminatorOp>(loc);
297   return DiagnosedSilenceableFailure::success();
298 }
299 
300 /// Alter kernel configuration of the given kernel.
301 DiagnosedSilenceableFailure alterGpuLaunch(
302     RewriterBase &rewriter, LaunchOp gpuLaunch,
303     TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
304     std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
305     std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
306     std::optional<int64_t> blockDimZ) {
307   DiagnosedSilenceableFailure diag =
308       checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
309                      blockDimY, blockDimZ);
310   if (!diag.succeeded())
311     return diag;
312 
313   KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
314   OpBuilder::InsertionGuard guard(rewriter);
315   rewriter.setInsertionPointAfterValue(currentBlockdim.x);
316   auto createConstValue = [&](int dim) {
317     return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
318                                                    dim);
319   };
320 
321   if (gridDimX.has_value())
322     gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
323   if (gridDimY.has_value())
324     gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
325   if (gridDimZ.has_value())
326     gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
327   if (blockDimX.has_value())
328     gpuLaunch.getBlockSizeXMutable().assign(
329         createConstValue(blockDimX.value()));
330   if (blockDimY.has_value())
331     gpuLaunch.getBlockSizeYMutable().assign(
332         createConstValue(blockDimY.value()));
333   if (blockDimZ.has_value())
334     gpuLaunch.getBlockSizeZMutable().assign(
335         createConstValue(blockDimZ.value()));
336   return DiagnosedSilenceableFailure::success();
337 }
338 
339 } // namespace gpu
340 } // namespace transform
341 } // namespace mlir
342