//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
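  // Illustrative sketch of the effect of this mapping (values taken from the
  // NVVM enum used below): a memref in `#gpu.address_space<workgroup>` ends
  // up in NVVM address space 3 (shared memory), a
  // `#gpu.address_space<global>` memref ends up in the NVVM global address
  // space, and `#gpu.address_space<private>` memrefs fall through to the
  // default address space 0.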
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that will allow tensor core operations to reuse
/// the LHS registers.
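/// As a concrete (illustrative) example: for a plain matmul contraction with
/// iterator types [parallel, parallel, reduction] and an LHS indexing map of
/// (m, k), the returned order is [2, 0, 1], i.e. the reduction dimension k
/// first, then the LHS parallel dimension m, then the remaining parallel
/// dimension n.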
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First, make the reduction dimensions the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then the parallel dimensions that are part of the LHS, as we want to
  // reuse the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
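/// For instance, with (hypothetical) native sizes m = 16, n = 16, k = 8, a
/// rank-3 vector.contract is unrolled to [16, 16, 8] tiles, a rank-2
/// vector.transfer_write to [16, 16] tiles, and transfer_read/elementwise ops
/// follow the shape of their consuming vector.extract_strided_slice when one
/// exists.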
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity, just match the shape used by the extract strided
    // slice op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Checks whether the given mapping attributes are of the desired kind.
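/// For example (attribute spellings shown for illustration), a block-mapped
/// loop typically carries `mapping = [#gpu.block<x>, #gpu.block<y>]` while a
/// thread-mapped loop carries `mapping = [#gpu.thread<x>, #gpu.thread<y>]`;
/// mixing block and thread (or warp/warpgroup) attributes on a single
/// scf.forall is rejected below, as is mixing linear and non-linear variants.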
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
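/// For instance, if the final mapping sizes are [4, 1, 1], the ids along the
/// y and z dimensions can only ever be 0, so all uses of the corresponding id
/// ops are replaced by the provided `replacement` value (typically a constant
/// 0).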
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

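/// Common rewrite of a single normalized, statically-shaped scf.forall into
/// id-based GPU code. As a hedged sketch of the overall effect: an
///   scf.forall (%i, %j) in (8, 32) { ... }
/// mapped to [#gpu.thread<y>, #gpu.thread<x>] is replaced by its body with
/// %i and %j substituted by the y and x ids produced by `gpuIdBuilder`,
/// wrapped in an scf.if predicate whenever the surrounding launch provides
/// more ids along a dimension than there are loop iterations.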
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If the element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps; this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first; it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move the body to the beginning of the scf.if.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move the body inline just at the rewriter
    // insertion point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase the old op.
  rewriter.eraseOp(forallOp);

  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong; use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple foralls if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

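// Illustrative transform-dialect usage of this op (sketch only; see
// GPUTransformOps.td and the GPU transform op tests for the exact syntax):
//
//   %gpu_launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [128, 1, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// With `generate_gpu_launch` set, a gpu.launch region is created around the
// top-level scf.forall before the block mapping is applied.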
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure(
        "transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu.launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late; this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

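/// Checks that the requested mapping fits within `blockOrGridSizes`. As a
/// worked example (numbers are illustrative): mapping an 8x8 scf.forall onto
/// warps with warpSize = 32 gives factor = 32, so threadIdx.x must be a
/// multiple of 32 in the 3-D (non-linear) case and the block must provide at
/// least 8 * 8 * 32 = 2048 threads overall.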
static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder; if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

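// Illustrative transform-dialect usage of this op (sketch only; see
// GPUTransformOps.td and the GPU transform op tests for the exact syntax):
//
//   transform.gpu.map_nested_forall_to_threads %gpu_launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op
//
// The block dims are written into the gpu.launch configuration and every
// nested, thread/warp/warpgroup-mapped scf.forall inside it is distributed.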
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects that the GPU
/// mapping transformations may generate as dependent dialects.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
947