//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace {

/// Return a flattened thread id for the workgroup with given sizes.
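/// Given workgroup sizes (BDX, BDY, BDZ) and thread ids (tx, ty, tz), the
/// flattened id is tx + ty * BDX + tz * BDX * BDY. As an illustrative example
/// (not taken from a test), thread (1, 2, 1) in a (4, 8, 2) workgroup gets
/// linear id 1 + 2 * 4 + 1 * 4 * 8 = 41.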
static Value buildLinearThreadId(RewriterBase &rewriter, Location loc,
                                 ArrayRef<OpFoldResult> blockDimsOfr) {
  LLVM_DEBUG(llvm::interleaveComma(
                 blockDimsOfr,
                 DBGS() << "----buildLinearThreadId with blockDimsOfr:  ");
             llvm::dbgs() << "\n");
  assert(blockDimsOfr.size() == 3 && "expected 3 workgroup sizes");
  AffineExpr tx, ty, tz, BDX, BDY;
  bindDims(rewriter.getContext(), tx, ty, tz);
  bindSymbols(rewriter.getContext(), BDX, BDY);
  IndexType indexType = rewriter.getIndexType();
  SmallVector<OpFoldResult> threadsAndWorkGroups{
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x).getResult(),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y).getResult(),
      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z).getResult()};
  threadsAndWorkGroups.push_back(blockDimsOfr[0]);
  threadsAndWorkGroups.push_back(blockDimsOfr[1]);
  OpFoldResult ofr = makeComposedFoldedAffineApply(
      rewriter, loc, tx + ty * BDX + tz * BDX * BDY, threadsAndWorkGroups);
  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
}

/// Builder for gpu::BlockIdOps used in mapping scf.forall to blocks.
/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
/// as 3-D sizes for predicate generation.
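/// A hypothetical example, with the attribute list matching the induction
/// variables in order:
///   scf.forall (%i, %j) in (8, 16) { ... }
///       { mapping = [#gpu.block<x>, #gpu.block<y>] }
/// rewrites uses of %i and %j into gpu.block_id x and gpu.block_id y,
/// respectively.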
struct GpuBlockIdBuilder : public GpuIdBuilder {

  GpuBlockIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
                    ArrayRef<int64_t> mappingSizes)
      : GpuIdBuilder(blockDims, mappingSizes) {
    mappingAttributes = {GPUBlockMappingAttr::get(ctx, Blocks::DimX),
                         GPUBlockMappingAttr::get(ctx, Blocks::DimY),
                         GPUBlockMappingAttr::get(ctx, Blocks::DimZ)};
    idBuilder = [](RewriterBase &rewriter, Location loc,
                   ArrayRef<int64_t> forallMappingSizes) {
      IndexType indexType = rewriter.getIndexType();
      SmallVector<Value> ids{
          rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
          rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
          rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
      // predicate generation.
      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
                             ids};
    };
  }
};

/// Builder for gpu::ThreadIdOp used in mapping scf.forall to thread ids without
/// any reindexing.
/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
/// as 3-D sizes for predicate generation.
struct GpuThreadIdBuilder : public GpuIdBuilder {
  GpuThreadIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
                     ArrayRef<int64_t> mappingSizes)
      : GpuIdBuilder(blockDims, mappingSizes) {
    mappingAttributes = {GPUThreadMappingAttr::get(ctx, Threads::DimX),
                         GPUThreadMappingAttr::get(ctx, Threads::DimY),
                         GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
    idBuilder = [](RewriterBase &rewriter, Location loc,
                   ArrayRef<int64_t> forallMappingSizes) {
      IndexType indexType = rewriter.getIndexType();
      SmallVector<Value> ids{
          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
          rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
      // predicate generation.
      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
                             ids};
    };
  }
};

/// Builder for warp ids used in mapping scf.forall to warps.
/// This builder requires a specification of the number of warps along each
/// dimension to control the mapping to warps, as well as predication, more
/// finely than is possible by analyzing the IR alone.
/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
/// as 3-D sizes for predicate generation.
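/// As an illustrative example: with blockDims = (64, 2, 1), i.e. 128 threads,
/// the flattened thread id floordiv 32 yields warp ids 0..3, which are then
/// delinearized in the basis given by the requested number of warps per
/// dimension, e.g. warpDims = (2, 2, 1).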
struct GpuWarpIdBuilder : public GpuIdBuilder {
  GpuWarpIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
                   ArrayRef<int64_t> mappingSizes)
      : GpuIdBuilder(blockDims, mappingSizes) {
    mappingAttributes = {GPUWarpMappingAttr::get(ctx, Warps::DimX),
                         GPUWarpMappingAttr::get(ctx, Warps::DimY),
                         GPUWarpMappingAttr::get(ctx, Warps::DimZ)};
    idBuilder = [this](RewriterBase &rewriter, Location loc,
                       ArrayRef<int64_t> forallMappingSizes) {
      // Build the linear warp id and decompose it in the basis of
      // `forallMappingSizes`.
      Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
      AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
      OpFoldResult warpIdOfr = makeComposedFoldedAffineApply(
          rewriter, loc, d0.floorDiv(kWarpSize), {linearId});
      Value warpId = getValueOrCreateConstantIndexOp(rewriter, loc, warpIdOfr);
      SmallVector<int64_t> reverseBasisSizes(
          llvm::reverse(this->availableMappingSizes));
      SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
      SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
      SmallVector<Value> ids;
      for (AffineExpr e : delinearizingExprs)
        ids.push_back(makeComposedAffineApply(rewriter, loc, e, warpId));

      // clang-format off
      LDBG("----linearId: " << linearId);
      LDBG("----warpId: " << warpId);
      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
                                       DBGS() << "--delinearization basis: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(strides,
                                       DBGS() << "--delinearization strides: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(delinearizingExprs,
                                       DBGS() << "--delinearization exprs: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
                 llvm::dbgs() << "\n";);
      // clang-format on

      // Return 3-D ids for indexing rewrite and 3-D sizes and ids for
      // predicate generation.
      return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes},
                             ids};
    };
  }

  /// Static specification of the warp size.
  /// In the future this may be configured by the transformation.
  static constexpr int64_t kWarpSize = 32;
};

/// Builder for linear ids used in mapping scf.forall to reindexed threads.
/// The `idBuilder` method returns 3-D values used for indexing rewrites as well
/// as 1-D sizes for predicate generation.
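/// As an illustrative example: with blockDims = (4, 8, 2), i.e. 64 threads,
/// and a forall of shape (2, 16), the flattened thread id is delinearized
/// into 2x16 ids for indexing, while predication compares the flattened id
/// against 2 * 16 = 32 out of the 64 available threads.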
struct GpuLinearIdBuilder : public GpuIdBuilder {
  GpuLinearIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims,
                     ArrayRef<int64_t> mappingSizes)
      : GpuIdBuilder(blockDims, mappingSizes) {
    mappingAttributes = {GPULinearIdMappingAttr::get(ctx, LinearId::DimX),
                         GPULinearIdMappingAttr::get(ctx, LinearId::DimY),
                         GPULinearIdMappingAttr::get(ctx, LinearId::DimZ)};
    idBuilder = [this](RewriterBase &rewriter, Location loc,
                       ArrayRef<int64_t> forallMappingSizes) {
      // Build the linear thread id and decompose it in the basis of
      // `forallMappingSizes`.
      Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
      SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
      SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
      AffineExpr d0;
      bindDims(rewriter.getContext(), d0);
      SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
      SmallVector<Value> ids;
      for (AffineExpr e : delinearizingExprs)
        ids.push_back(makeComposedAffineApply(rewriter, loc, e, linearId));

      // clang-format off
      LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
                                       DBGS() << "--delinearization basis: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(strides,
                                       DBGS() << "--delinearization strides: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(delinearizingExprs,
                                       DBGS() << "--delinearization exprs: ");
                 llvm::dbgs() << "\n";
                 llvm::interleaveComma(ids, DBGS() << "--ids: ");
                 llvm::dbgs() << "\n";);
      // clang-format on

      // Compute and return the 1-D actual mapping size spanned by the
      // linearId; it will be used to predicate against the linearized total
      // number of threads.
      int64_t actualMappingSize = 1;
      for (int64_t s : forallMappingSizes)
        actualMappingSize *= s;

      // Return 3-D ids for indexing rewrite and 1-D size and id for
      // predicate generation.
      return IdBuilderResult{ids, SmallVector<int64_t>{actualMappingSize},
                             SmallVector<Value>{linearId}};
    };
  }
};

} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check that the given mapping attributes are all of a single supported kind
/// and contain no duplicates.
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value())
    return definiteFailureHelper(transformOp, forallOp,
                                 "mapping must be present");

  bool hasBlockMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return attr.isa<GPUBlockMappingAttr>();
      });
  bool hasThreadMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return attr.isa<GPUThreadMappingAttr>();
      });
  bool hasWarpMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return attr.isa<GPUWarpMappingAttr>();
      });
  bool hasLinearMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return attr.isa<GPULinearIdMappingAttr>();
      });
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasLinearMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicated attribute, cannot map different loops "
          "to the same processor");
    }
    seen.insert(map);
  }

  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  if (forallOp.getRank() > 3)
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > 3 does not lower");
  if (llvm::any_of(forallOp.getMixedUpperBound(), [&](OpFoldResult ofr) {
        return !getConstantIntValue(ofr).has_value();
      })) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported dynamic sizes in forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Determines whether the kernel configuration is supported by the target GPU
/// architecture. The check presently uses CUDA limits; it may be extended to
/// other GPUs in the future.
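/// For instance, blockDims = (32, 8, 4) requests exactly 1024 threads and
/// passes, whereas blockDims = (32, 8, 8) would request 2048 threads and be
/// rejected.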
static DiagnosedSilenceableFailure checkGpuLimits(
    TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {

  static constexpr int maxTotalBlockdim = 1024;
  static constexpr int maxBlockdimx = 1024;
  static constexpr int maxBlockdimy = 1024;
  static constexpr int maxBlockdimz = 64;
  static constexpr int maxTotalGriddim = 2147483647;
  static constexpr int maxGriddimx = 2147483647;
  static constexpr int maxGriddimy = 65535;
  static constexpr int maxGriddimz = 65535;

  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
          maxTotalBlockdim ||
      (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
          maxTotalGriddim ||
      blockDimX.value_or(1) > maxBlockdimx ||
      blockDimY.value_or(1) > maxBlockdimy ||
      blockDimZ.value_or(1) > maxBlockdimz ||
      gridDimY.value_or(1) > maxGriddimy ||
      gridDimZ.value_or(1) > maxGriddimz ||
      gridDimX.value_or(1) > maxGriddimx) {
    return transformOp.emitSilenceableError()
           << "Trying to launch a GPU kernel with grid_dims = ("
           << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
           << gridDimZ.value_or(1) << ") block_dims = ("
           << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", "
           << blockDimZ.value_or(1) << "). This exceeds the hardware limits.";
  }
  return DiagnosedSilenceableFailure::success();
}

/// Creates an empty-body gpu::LaunchOp using the provided kernel settings
/// and puts a terminator within.
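/// A sketch of the produced IR, with names and syntax abbreviated; omitted
/// dimensions default to the constant 1:
///   %c1 = arith.constant 1 : index
///   gpu.launch blocks(%bx, %by, %bz) in (%g0 = %c1, %g1 = %c1, %g2 = %c1)
///              threads(%tx, %ty, %tz) in (%b0 = %c1, %b1 = %c1, %b2 = %c1) {
///     gpu.terminator
///   }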
static DiagnosedSilenceableFailure
createGpuLaunch(RewriterBase &rewriter, Location loc,
                TransformOpInterface transformOp, LaunchOp &launchOp,
                std::optional<int64_t> gridDimX = std::nullopt,
                std::optional<int64_t> gridDimY = std::nullopt,
                std::optional<int64_t> gridDimZ = std::nullopt,
                std::optional<int64_t> blockDimX = std::nullopt,
                std::optional<int64_t> blockDimY = std::nullopt,
                std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  auto createConst = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
  };
  OpBuilder::InsertionGuard guard(rewriter);
  Value one = createConst(1);
  Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
  Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
  Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
  Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
  Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
  Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
  launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
                                       blkSizeX, blkSizeY, blkSizeZ);
  rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
  rewriter.create<TerminatorOp>(loc);
  return DiagnosedSilenceableFailure::success();
}

/// Alters the kernel configuration of the given kernel launch op.
static DiagnosedSilenceableFailure
alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
               TransformOpInterface transformOp,
               std::optional<int64_t> gridDimX = std::nullopt,
               std::optional<int64_t> gridDimY = std::nullopt,
               std::optional<int64_t> gridDimZ = std::nullopt,
               std::optional<int64_t> blockDimX = std::nullopt,
               std::optional<int64_t> blockDimY = std::nullopt,
               std::optional<int64_t> blockDimZ = std::nullopt) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
  auto createConstValue = [&](int dim) {
    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                   dim);
  };

  if (gridDimX.has_value())
    gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
  if (gridDimY.has_value())
    gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
  if (gridDimZ.has_value())
    gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
  if (blockDimX.has_value())
    gpuLaunch.getBlockSizeXMutable().assign(
        createConstValue(blockDimX.value()));
  if (blockDimY.has_value())
    gpuLaunch.getBlockSizeYMutable().assign(
        createConstValue(blockDimY.value()));
  if (blockDimZ.has_value())
    gpuLaunch.getBlockSizeZMutable().assign(
        createConstValue(blockDimZ.value()));
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
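/// For example, with availableMappingSizes = {4, 2, 1}, every use of the id
/// op along dimension z (known to take the single value 0) is replaced by
/// `replacement`, typically a constant 0.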
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

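/// Shared implementation to rewrite one normalized scf.forall into GPU ids:
/// verify the GPU mapping, pad the mapping to 3-D with unit sizes, sort the
/// sizes by mapping id, materialize the ids via `gpuIdBuilder`, predicate the
/// body when fewer iterations than available ids exist, then inline the body
/// and erase the forall.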
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ForallRewriteResult &result,
    ArrayRef<int64_t> availableMappingSizes, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 0. GPU-specific verifications. There is no better place to anchor
  // those right now: the ForallOp is target-independent and the transform
  // op does not apply to individual ForallOp.
  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
  if (!diag.succeeded())
    return diag;

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
      llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
        auto maybeStaticValue = getConstantIntValue(ofr);
        assert(maybeStaticValue && "expected static value");
        return maybeStaticValue.value();
      }));
  SmallVector<Attribute> forallMappingAttrs =
      llvm::to_vector(forallOp.getMapping()->getValue());
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    if (llvm::is_contained(forallMappingAttrs, attr))
      continue;
    forallMappingAttrs.push_back(attr);
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  auto comparator = [&](DeviceMappingAttrInterface a,
                        DeviceMappingAttrInterface b) -> bool {
    return a.getMappingId() < b.getMappingId();
  };
  SmallVector<int64_t> forallMappingSizes =
      getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----mappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes);

  // Step 4. Map the induction variables to the mappingIdOps; this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] :
       llvm::zip_equal(forallOp.getInductionVars(),
                       ArrayRef<Attribute>{forallMappingAttrs}.take_front(
                           forallOp.getInductionVars().size()))) {
    Value peIdOp = mappingIdOps[static_cast<int64_t>(
        dim.cast<DeviceMappingAttrInterface>().getMappingId())];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the availableMappingSizes are already known, create
  // conditionals to predicate the region. Otherwise, the current forall
  // determines the availableMappingSizes and no predication occurs.
  Value predicate;
  if (!availableMappingSizes.empty()) {
    SmallVector<int64_t> predicateMappingSizes =
        builderResult.predicateMappingSizes;
    SmallVector<Value> predicateIdOps = builderResult.predicateIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
          predicateMappingSizes, DBGS() << "----predicateMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(predicateIdOps, DBGS() << "----predicateIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [id, mappingSize, availableMappingSize] : llvm::zip_equal(
             predicateIdOps, predicateMappingSizes, availableMappingSizes)) {
      if (mappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (mappingSize == availableMappingSize)
        continue;
      Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, id, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }
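
  // As an illustration, mapping 12 iterations onto 16 available ids emits:
  //   %c12 = arith.constant 12 : index
  //   %p   = arith.cmpi ult, %id, %c12 : index
  // and the moved body is guarded by `scf.if %p` in Step 6.a below.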

  // Step 6. Move the body of forallOp.
  // Erase the terminator first; it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move the body to the beginning of the scf.if.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move the body inline at the current rewriter
    // insertion point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase the old op.
  rewriter.eraseOp(forallOp);

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  SmallVector<int64_t> anyAvailableMappingSizes;
  ForallRewriteResult rewriteResult;
  // Pass an empty anyAvailableMappingSizes.
  DiagnosedSilenceableFailure diag =
      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
                                 anyAvailableMappingSizes, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match failure.
  if (!diag.succeeded())
    return diag;

  // Set the gridDims that act as a return.
  gridDims = rewriteResult.mappingSizes;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          gridDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
transform::MapForallToBlocks::applyToOne(Operation *target,
                                         ApplyToEachResultList &results,
                                         transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  IRRewriter rewriter(getContext());
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure(
        "transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu.launch here and move the forall inside it.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded()) {
      return diag;
    }
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), {}, {});
  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late; this is subject
  // to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder) {
  // Ignore cases with attributes other than those this builder supports.
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (!llvm::is_contained(gpuIdBuilder.mappingAttributes, map)) {
      LDBG("--skip " << map);
      LLVM_DEBUG(llvm::interleaveComma(gpuIdBuilder.mappingAttributes,
                                       DBGS() << "----not in: ");
                 llvm::dbgs() << "\n";);
      return emitSilenceableFailure(forallOp);
    }
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag =
      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult,
                                 availableMappingSizes, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match failure.
  if (!diag.succeeded())
    return diag;

  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

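/// Walks `target` and rewrites each nested scf.forall it finds, trying the id
/// builders in order: warp ids (only when `warpDims` is provided), then linear
/// ids, then 3-D thread ids. A silenceable failure from one builder encodes
/// "no match" and the next builder is tried.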
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, ArrayRef<int64_t> warpDims,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<OpFoldResult> blockDimsOfr =
      getAsIndexOpFoldResult(ctx, blockDims);

  if (blockDims.size() != 3)
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  if (!warpDims.empty()) {
    if (warpDims.size() != 3)
      return definiteFailureHelper(transformOp, target,
                                   "requires empty or size-3 warp mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    //===--------------------------------------------------------------------===//
    // Mapping to warp ids.
    //===--------------------------------------------------------------------===//
    if (!warpDims.empty()) {
      LLVM_DEBUG(
          llvm::interleaveComma(
              warpDims, DBGS() << "+mapNestedForallToThreadsImpl warpDims: ");
          llvm::dbgs() << "\n");
      LLVM_DEBUG(llvm::interleaveComma(
                     blockDimsOfr, DBGS() << "--warpDims with blockDimsOfr:  ");
                 llvm::dbgs() << "\n");
      GpuWarpIdBuilder gpuWarpIdBuilder(ctx, blockDimsOfr, warpDims);
      diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
          rewriter, transformOp, forallOp, warpDims, syncAfterDistribute,
          gpuWarpIdBuilder);
      // Use silenceable failure to encode "failure to match" and pass
      // through.
      if (diag.isDefiniteFailure())
        return WalkResult::interrupt();
      if (diag.succeeded())
        return WalkResult::skip();
    }

    //===--------------------------------------------------------------------===//
    // Mapping to linear ids.
    //===--------------------------------------------------------------------===//
    LDBG("+mapNestedForallToThreadsImpl linearDims");
    LLVM_DEBUG(llvm::interleaveComma(
                   blockDimsOfr, DBGS() << "--linearDims with blockDimsOfr:  ");
               llvm::dbgs() << "\n");
    int64_t numThreads = 1;
    for (int64_t b : blockDims)
      numThreads *= b;
    GpuLinearIdBuilder gpuLinearIdBuilder(ctx, blockDimsOfr, numThreads);
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, numThreads, syncAfterDistribute,
        gpuLinearIdBuilder);
    // Use silenceable failure to encode "failure to match" and pass through.
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();

    //===--------------------------------------------------------------------===//
    // Mapping to 3-D thread ids (happens last so we can replay ThreadIdOp).
    //===--------------------------------------------------------------------===//
    LLVM_DEBUG(
        llvm::interleaveComma(
            blockDimsOfr, DBGS() << "mapNestedForallToThreadsImpl blockDims: ");
        llvm::dbgs() << "\n");
    GpuThreadIdBuilder gpuThreadIdBuilder(ctx, blockDimsOfr, blockDims);
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, syncAfterDistribute,
        gpuThreadIdBuilder);
    // Use silenceable failure to encode "failure to match" and pass through.
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

void transform::MapNestedForallToThreads::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    Operation *target, ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};

  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
  IRRewriter rewriter(getContext());
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpDims(), getSyncAfterDistribute());

  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as a dependent dialect since the
/// additional ops use PDL types for operands and results.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}