xref: /llvm-project/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
189bb0caeSGuray Ozen //===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
289bb0caeSGuray Ozen //
389bb0caeSGuray Ozen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
489bb0caeSGuray Ozen // See https://llvm.org/LICENSE.txt for license information.
589bb0caeSGuray Ozen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
689bb0caeSGuray Ozen //
789bb0caeSGuray Ozen //===----------------------------------------------------------------------===//
889bb0caeSGuray Ozen 
989bb0caeSGuray Ozen #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
1089bb0caeSGuray Ozen 
11888717e8SNicolas Vasilache #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12888717e8SNicolas Vasilache #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
13888717e8SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14c59465e1SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
1589bb0caeSGuray Ozen #include "mlir/Dialect/Arith/IR/Arith.h"
16c59465e1SNicolas Vasilache #include "mlir/Dialect/Func/IR/FuncOps.h"
1789bb0caeSGuray Ozen #include "mlir/Dialect/GPU/IR/GPUDialect.h"
1890ecfa2aSNicolas Vasilache #include "mlir/Dialect/GPU/TransformOps/Utils.h"
19888717e8SNicolas Vasilache #include "mlir/Dialect/GPU/Transforms/Passes.h"
20888717e8SNicolas Vasilache #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
219ab34689SAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h"
22beaffb04SGuray Ozen #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
2389bb0caeSGuray Ozen #include "mlir/Dialect/SCF/IR/SCF.h"
2489bb0caeSGuray Ozen #include "mlir/Dialect/Transform/IR/TransformDialect.h"
255a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
26c59465e1SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
279ab34689SAlex Zinenko #include "mlir/Dialect/Vector/IR/VectorOps.h"
28ff8775f3SQuinn Dawkins #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
29c59465e1SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
30c59465e1SNicolas Vasilache #include "mlir/IR/Builders.h"
31768615bbSNicolas Vasilache #include "mlir/IR/BuiltinAttributes.h"
324d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
33c59465e1SNicolas Vasilache #include "mlir/IR/MLIRContext.h"
34aafb52d7SNicolas Vasilache #include "mlir/IR/OpDefinition.h"
35c59465e1SNicolas Vasilache #include "mlir/IR/Visitors.h"
36aafb52d7SNicolas Vasilache #include "mlir/Support/LLVM.h"
37888717e8SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
38768615bbSNicolas Vasilache #include "llvm/ADT/STLExtras.h"
39768615bbSNicolas Vasilache #include "llvm/ADT/SmallVector.h"
409ab34689SAlex Zinenko #include "llvm/ADT/TypeSwitch.h"
41768615bbSNicolas Vasilache #include "llvm/Support/Debug.h"
4244e6318cSNicolas Vasilache #include "llvm/Support/ErrorHandling.h"
4392f088d3SNicolas Vasilache #include <type_traits>
4489bb0caeSGuray Ozen 
4589bb0caeSGuray Ozen using namespace mlir;
4689bb0caeSGuray Ozen using namespace mlir::gpu;
4789bb0caeSGuray Ozen using namespace mlir::transform;
48c59465e1SNicolas Vasilache using namespace mlir::transform::gpu;
4989bb0caeSGuray Ozen 
50768615bbSNicolas Vasilache #define DEBUG_TYPE "gpu-transforms"
519ab34689SAlex Zinenko #define DEBUG_TYPE_ALIAS "gpu-transforms-alias"
52768615bbSNicolas Vasilache 
53768615bbSNicolas Vasilache #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
54768615bbSNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
559ab34689SAlex Zinenko #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
569ab34689SAlex Zinenko 
579ab34689SAlex Zinenko //===----------------------------------------------------------------------===//
58888717e8SNicolas Vasilache // Apply...ConversionPatternsOp
59888717e8SNicolas Vasilache //===----------------------------------------------------------------------===//
60888717e8SNicolas Vasilache 
/// Install the GPU-to-NVVM dialect conversion patterns, plus the type
/// conversions they rely on (GPU memory-space attributes and MMA matrix
/// types), into `typeConverter`/`patterns`.
/// Precondition: `typeConverter` is an LLVMTypeConverter, enforced by the
/// companion verifyTypeConverter below.
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          // Private memory maps to the default (0) address space.
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
91888717e8SNicolas Vasilache 
92888717e8SNicolas Vasilache LogicalResult
93888717e8SNicolas Vasilache transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
94888717e8SNicolas Vasilache     transform::TypeConverterBuilderOpInterface builder) {
95888717e8SNicolas Vasilache   if (builder.getTypeConverterType() != "LLVMTypeConverter")
96888717e8SNicolas Vasilache     return emitOpError("expected LLVMTypeConverter");
97888717e8SNicolas Vasilache   return success();
98888717e8SNicolas Vasilache }
99888717e8SNicolas Vasilache 
100888717e8SNicolas Vasilache void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
101888717e8SNicolas Vasilache     TypeConverter &typeConverter, RewritePatternSet &patterns) {
102888717e8SNicolas Vasilache   auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
103888717e8SNicolas Vasilache   populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
104888717e8SNicolas Vasilache }
105888717e8SNicolas Vasilache 
106888717e8SNicolas Vasilache LogicalResult
107888717e8SNicolas Vasilache transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
108888717e8SNicolas Vasilache     transform::TypeConverterBuilderOpInterface builder) {
109888717e8SNicolas Vasilache   if (builder.getTypeConverterType() != "LLVMTypeConverter")
110888717e8SNicolas Vasilache     return emitOpError("expected LLVMTypeConverter");
111888717e8SNicolas Vasilache   return success();
112888717e8SNicolas Vasilache }
113888717e8SNicolas Vasilache 
114888717e8SNicolas Vasilache void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
115888717e8SNicolas Vasilache     populatePatterns(TypeConverter &typeConverter,
116888717e8SNicolas Vasilache                      RewritePatternSet &patterns) {
117888717e8SNicolas Vasilache   auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
118888717e8SNicolas Vasilache   populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
119888717e8SNicolas Vasilache }
120888717e8SNicolas Vasilache 
121888717e8SNicolas Vasilache LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
122888717e8SNicolas Vasilache     verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
123888717e8SNicolas Vasilache   if (builder.getTypeConverterType() != "LLVMTypeConverter")
124888717e8SNicolas Vasilache     return emitOpError("expected LLVMTypeConverter");
125888717e8SNicolas Vasilache   return success();
126888717e8SNicolas Vasilache }
127888717e8SNicolas Vasilache 
128888717e8SNicolas Vasilache //===----------------------------------------------------------------------===//
129888717e8SNicolas Vasilache // Apply...PatternsOp
//===----------------------------------------------------------------------===//
131888717e8SNicolas Vasilache 
/// Install the generic GPU rewrite patterns (see populateGpuRewritePatterns).
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}
135888717e8SNicolas Vasilache 
136888717e8SNicolas Vasilache //===----------------------------------------------------------------------===//
137ff8775f3SQuinn Dawkins // ApplyUnrollVectorsSubgroupMmaOp
138ff8775f3SQuinn Dawkins //===----------------------------------------------------------------------===//
139ff8775f3SQuinn Dawkins 
140ff8775f3SQuinn Dawkins /// Pick an unrolling order that will allow tensorcore operation to reuse LHS
141ff8775f3SQuinn Dawkins /// register.
142ff8775f3SQuinn Dawkins static std::optional<SmallVector<int64_t>>
143ff8775f3SQuinn Dawkins gpuMmaUnrollOrder(vector::ContractionOp contract) {
144ff8775f3SQuinn Dawkins   SmallVector<int64_t> order;
145ff8775f3SQuinn Dawkins   // First make reduction the outer dimensions.
146ff8775f3SQuinn Dawkins   for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
147ff8775f3SQuinn Dawkins     if (vector::isReductionIterator(iter)) {
148ff8775f3SQuinn Dawkins       order.push_back(index);
149ff8775f3SQuinn Dawkins     }
150ff8775f3SQuinn Dawkins   }
151ff8775f3SQuinn Dawkins 
152ff8775f3SQuinn Dawkins   llvm::SmallDenseSet<int64_t> dims;
153ff8775f3SQuinn Dawkins   for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
1541609f1c2Slong.chen     dims.insert(cast<AffineDimExpr>(expr).getPosition());
155ff8775f3SQuinn Dawkins   }
156ff8775f3SQuinn Dawkins   // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
157ff8775f3SQuinn Dawkins   for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
158ff8775f3SQuinn Dawkins     if (vector::isParallelIterator(iter) && dims.count(index)) {
159ff8775f3SQuinn Dawkins       order.push_back(index);
160ff8775f3SQuinn Dawkins     }
161ff8775f3SQuinn Dawkins   }
162ff8775f3SQuinn Dawkins   // Then the remaining parallel loops.
163ff8775f3SQuinn Dawkins   for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
164ff8775f3SQuinn Dawkins     if (vector::isParallelIterator(iter) && !dims.count(index)) {
165ff8775f3SQuinn Dawkins       order.push_back(index);
166ff8775f3SQuinn Dawkins     }
167ff8775f3SQuinn Dawkins   }
168ff8775f3SQuinn Dawkins   return order;
169ff8775f3SQuinn Dawkins }
170ff8775f3SQuinn Dawkins 
171ff8775f3SQuinn Dawkins /// Returns the target vector size for the target operation based on the native
172ff8775f3SQuinn Dawkins /// vector size specified with `m`, `n`, and `k`.
173ff8775f3SQuinn Dawkins static std::optional<SmallVector<int64_t>>
174ff8775f3SQuinn Dawkins getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
175ff8775f3SQuinn Dawkins   if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
176ff8775f3SQuinn Dawkins     int64_t contractRank = contract.getIteratorTypes().size();
177ff8775f3SQuinn Dawkins     if (contractRank < 3)
178ff8775f3SQuinn Dawkins       return std::nullopt;
179ff8775f3SQuinn Dawkins     SmallVector<int64_t> nativeSize(contractRank - 3, 1);
180ff8775f3SQuinn Dawkins     nativeSize.append({m, n, k});
181ff8775f3SQuinn Dawkins     return nativeSize;
182ff8775f3SQuinn Dawkins   }
183ff8775f3SQuinn Dawkins   if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
184ff8775f3SQuinn Dawkins     int64_t writeRank = writeOp.getVectorType().getRank();
185ff8775f3SQuinn Dawkins     if (writeRank < 2)
186ff8775f3SQuinn Dawkins       return std::nullopt;
187ff8775f3SQuinn Dawkins     SmallVector<int64_t> nativeSize(writeRank - 2, 1);
188ff8775f3SQuinn Dawkins     nativeSize.append({m, n});
189ff8775f3SQuinn Dawkins     return nativeSize;
190ff8775f3SQuinn Dawkins   }
191ff8775f3SQuinn Dawkins   if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
192ff8775f3SQuinn Dawkins     // Transfer read ops may need different shapes based on how they are being
193ff8775f3SQuinn Dawkins     // used. For simplicity just match the shape used by the extract strided op.
194ff8775f3SQuinn Dawkins     VectorType sliceType;
195ff8775f3SQuinn Dawkins     for (Operation *users : op->getUsers()) {
196ff8775f3SQuinn Dawkins       auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
197ff8775f3SQuinn Dawkins       if (!extract)
198ff8775f3SQuinn Dawkins         return std::nullopt;
199a5757c5bSChristian Sigg       auto vecType = cast<VectorType>(extract.getResult().getType());
200ff8775f3SQuinn Dawkins       if (sliceType && sliceType != vecType)
201ff8775f3SQuinn Dawkins         return std::nullopt;
202ff8775f3SQuinn Dawkins       sliceType = vecType;
203ff8775f3SQuinn Dawkins     }
204ff8775f3SQuinn Dawkins     return llvm::to_vector(sliceType.getShape());
205ff8775f3SQuinn Dawkins   }
206ff8775f3SQuinn Dawkins   if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
207a5757c5bSChristian Sigg     if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
208ff8775f3SQuinn Dawkins       // TODO: The condition for unrolling elementwise should be restricted
209ff8775f3SQuinn Dawkins       // only to operations that need unrolling (connected to the contract).
210ff8775f3SQuinn Dawkins       if (vecType.getRank() < 2)
211ff8775f3SQuinn Dawkins         return std::nullopt;
212ff8775f3SQuinn Dawkins 
213ff8775f3SQuinn Dawkins       // First check whether there is a slice to infer the shape from. This is
214ff8775f3SQuinn Dawkins       // required for cases where the accumulator type differs from the input
215ff8775f3SQuinn Dawkins       // types, in which case we will see an `arith.ext_` between the contract
216ff8775f3SQuinn Dawkins       // and transfer_read which needs to be unrolled.
217ff8775f3SQuinn Dawkins       VectorType sliceType;
218ff8775f3SQuinn Dawkins       for (Operation *users : op->getUsers()) {
219ff8775f3SQuinn Dawkins         auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
220ff8775f3SQuinn Dawkins         if (!extract)
221ff8775f3SQuinn Dawkins           return std::nullopt;
222a5757c5bSChristian Sigg         auto vecType = cast<VectorType>(extract.getResult().getType());
223ff8775f3SQuinn Dawkins         if (sliceType && sliceType != vecType)
224ff8775f3SQuinn Dawkins           return std::nullopt;
225ff8775f3SQuinn Dawkins         sliceType = vecType;
226ff8775f3SQuinn Dawkins       }
227ff8775f3SQuinn Dawkins       if (sliceType)
228ff8775f3SQuinn Dawkins         return llvm::to_vector(sliceType.getShape());
229ff8775f3SQuinn Dawkins 
230ff8775f3SQuinn Dawkins       // Else unroll for trailing elementwise.
231ff8775f3SQuinn Dawkins       SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
232ff8775f3SQuinn Dawkins       // Map elementwise ops to the output shape.
233ff8775f3SQuinn Dawkins       nativeSize.append({m, n});
234ff8775f3SQuinn Dawkins       return nativeSize;
235ff8775f3SQuinn Dawkins     }
236ff8775f3SQuinn Dawkins   }
237ff8775f3SQuinn Dawkins   return std::nullopt;
238ff8775f3SQuinn Dawkins }
239ff8775f3SQuinn Dawkins 
240ff8775f3SQuinn Dawkins void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
241ff8775f3SQuinn Dawkins     RewritePatternSet &patterns) {
242ff8775f3SQuinn Dawkins   auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
243ff8775f3SQuinn Dawkins     auto contract = dyn_cast<vector::ContractionOp>(op);
244ff8775f3SQuinn Dawkins     if (!contract)
245ff8775f3SQuinn Dawkins       return std::nullopt;
246ff8775f3SQuinn Dawkins     return gpuMmaUnrollOrder(contract);
247ff8775f3SQuinn Dawkins   };
248ff8775f3SQuinn Dawkins 
249ff8775f3SQuinn Dawkins   int64_t m = getM();
250ff8775f3SQuinn Dawkins   int64_t n = getN();
251ff8775f3SQuinn Dawkins   int64_t k = getK();
252ff8775f3SQuinn Dawkins   auto nativeShapeFn =
253ff8775f3SQuinn Dawkins       [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
254ff8775f3SQuinn Dawkins     return getSubgroupMmaNativeVectorSize(op, m, n, k);
255ff8775f3SQuinn Dawkins   };
256ff8775f3SQuinn Dawkins   vector::populateVectorUnrollPatterns(
257ff8775f3SQuinn Dawkins       patterns, vector::UnrollVectorOptions()
258ff8775f3SQuinn Dawkins                     .setNativeShapeFn(nativeShapeFn)
259ff8775f3SQuinn Dawkins                     .setUnrollTraversalOrderFn(unrollOrder));
260ff8775f3SQuinn Dawkins }
261ff8775f3SQuinn Dawkins 
262ff8775f3SQuinn Dawkins //===----------------------------------------------------------------------===//
2639ab34689SAlex Zinenko // EliminateBarriersOp
2649ab34689SAlex Zinenko //===----------------------------------------------------------------------===//
2659ab34689SAlex Zinenko 
/// Install the barrier-elimination patterns
/// (see populateGpuEliminateBarriersPatterns).
void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}
2699ab34689SAlex Zinenko 
2709ab34689SAlex Zinenko //===----------------------------------------------------------------------===//
2719ab34689SAlex Zinenko // Block and thread mapping utilities.
2729ab34689SAlex Zinenko //===----------------------------------------------------------------------===//
273768615bbSNicolas Vasilache 
namespace {
/// Local types used for mapping verification.
/// Empty tag types selecting, at compile time, which family of mapping
/// attributes checkMappingAttributeTypes/verifyGpuMapping should accept.
struct MappingKind {};
/// Tag for block-level (grid) mappings.
struct BlockMappingKind : MappingKind {};
/// Tag for thread/warp/warpgroup-level mappings.
struct ThreadMappingKind : MappingKind {};
} // namespace
28092f088d3SNicolas Vasilache 
281aafb52d7SNicolas Vasilache static DiagnosedSilenceableFailure
282c59465e1SNicolas Vasilache definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
283c59465e1SNicolas Vasilache                       Operation *target, const Twine &message) {
284aafb52d7SNicolas Vasilache   if (transformOp.has_value())
285c59465e1SNicolas Vasilache     return transformOp->emitDefiniteFailure() << message;
286c59465e1SNicolas Vasilache   return emitDefiniteFailure(target, message);
287aafb52d7SNicolas Vasilache }
288aafb52d7SNicolas Vasilache 
/// Check if given mapping attributes are one of the desired attributes
/// for the requested MappingKindType (BlockMappingKind or ThreadMappingKind).
/// Emits a definite failure (attached to `transformOp` when present, else to
/// `forallOp`) when: the mapping attribute is missing, mapping families are
/// mixed, the family does not match MappingKindType, a mapping id is used
/// twice, or linear and non-linear mapping modes are mixed.
/// NOTE: diagnostic strings here are matched by lit tests; do not reword.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  // Detect which mapping families are present among the attributes.
  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  // At most one family may be used on a single forall; nesting is the way to
  // combine levels.
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  // The family present must match the requested MappingKindType.
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  // ThreadMappingKind accepts thread, warp, and warpgroup mappings.
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  // Each mapping id (e.g. DimX) may be used by at most one loop.
  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  // Linear and non-linear (3-D grid) mapping modes cannot be combined.
  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}
353beaffb04SGuray Ozen 
/// Verify that `forallOp` can be mapped to GPU ids of kind MappingKindType:
/// mapping attributes are well-formed (checkMappingAttributeTypes), the loop
/// is normalized and bufferized (no results), its rank does not exceed what
/// the mapping mode supports, and its upper bounds are static.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  // The linear mode supports more dimensions than the 3-D (x, y, z) mode.
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  // All upper bounds must be static constants.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  // NOTE(review): the isNormalized() re-check is redundant — a non-normalized
  // loop already bailed above, so only the static-size half can trigger here.
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}
395aafb52d7SNicolas Vasilache 
/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  // Sizes (one per mapped dimension) used for the mapping.
  SmallVector<int64_t> mappingSizes;
  // The id values (one per mapped dimension) the loop ivs were mapped to.
  SmallVector<Value> mappingIds;
};
40189bb0caeSGuray Ozen 
402c59465e1SNicolas Vasilache /// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
403c59465e1SNicolas Vasilache template <typename OpTy, typename OperationOrBlock>
404c59465e1SNicolas Vasilache static void
405c59465e1SNicolas Vasilache replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
406c59465e1SNicolas Vasilache                             OperationOrBlock *parent, Value replacement,
407c59465e1SNicolas Vasilache                             ArrayRef<int64_t> availableMappingSizes) {
408c59465e1SNicolas Vasilache   parent->walk([&](OpTy idOp) {
409c59465e1SNicolas Vasilache     if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
410c59465e1SNicolas Vasilache       rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
411c59465e1SNicolas Vasilache   });
412c59465e1SNicolas Vasilache }
413c59465e1SNicolas Vasilache 
414c59465e1SNicolas Vasilache static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
415768615bbSNicolas Vasilache     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
41644e6318cSNicolas Vasilache     scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
41744e6318cSNicolas Vasilache     ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
418c59465e1SNicolas Vasilache   LDBG("--start rewriteOneForallCommonImpl");
419beaffb04SGuray Ozen 
420768615bbSNicolas Vasilache   // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
42144e6318cSNicolas Vasilache   auto numParallelIterations =
42244e6318cSNicolas Vasilache       getConstantIntValues(forallOp.getMixedUpperBound());
42344e6318cSNicolas Vasilache   assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
42444e6318cSNicolas Vasilache          "requires statically sized, normalized forall op");
42544e6318cSNicolas Vasilache   SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
42644e6318cSNicolas Vasilache   SetVector<Attribute> forallMappingAttrs;
42744e6318cSNicolas Vasilache   forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
42844e6318cSNicolas Vasilache                             forallOp.getMapping()->getValue().end());
42944e6318cSNicolas Vasilache   auto comparator = [](Attribute a, Attribute b) -> bool {
43044e6318cSNicolas Vasilache     return cast<DeviceMappingAttrInterface>(a).getMappingId() <
43144e6318cSNicolas Vasilache            cast<DeviceMappingAttrInterface>(b).getMappingId();
43244e6318cSNicolas Vasilache   };
43344e6318cSNicolas Vasilache 
43444e6318cSNicolas Vasilache   // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
43544e6318cSNicolas Vasilache   // mapping all dimensions. In the 3-D mapping case we need to map all
43644e6318cSNicolas Vasilache   // dimensions.
437fab2bb8bSJustin Lebar   DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
438fab2bb8bSJustin Lebar       *llvm::max_element(forallMappingAttrs, comparator));
43944e6318cSNicolas Vasilache   DeviceMappingAttrInterface maxLinearMapping;
44044e6318cSNicolas Vasilache   if (maxMapping.isLinearMapping())
44144e6318cSNicolas Vasilache     maxLinearMapping = maxMapping;
442c59465e1SNicolas Vasilache   for (auto attr : gpuIdBuilder.mappingAttributes) {
44344e6318cSNicolas Vasilache     // If attr overflows, just skip.
44444e6318cSNicolas Vasilache     if (maxLinearMapping && comparator(maxLinearMapping, attr))
445768615bbSNicolas Vasilache       continue;
44644e6318cSNicolas Vasilache     // Try to insert. If element was already present, just continue.
44744e6318cSNicolas Vasilache     if (!forallMappingAttrs.insert(attr))
44844e6318cSNicolas Vasilache       continue;
44944e6318cSNicolas Vasilache     // Otherwise, we have a new insertion without a size -> use size 1.
450768615bbSNicolas Vasilache     tmpMappingSizes.push_back(1);
4516663f347SGuray Ozen   }
452c59465e1SNicolas Vasilache   LLVM_DEBUG(
453c59465e1SNicolas Vasilache       llvm::interleaveComma(
454c59465e1SNicolas Vasilache           tmpMappingSizes,
455c59465e1SNicolas Vasilache           DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
456c59465e1SNicolas Vasilache       llvm::dbgs() << "\n");
4576663f347SGuray Ozen 
458beaffb04SGuray Ozen   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
45944e6318cSNicolas Vasilache   SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
46044e6318cSNicolas Vasilache       forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
461c59465e1SNicolas Vasilache   LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
462c59465e1SNicolas Vasilache                                    DBGS() << "----forallMappingSizes: ");
463c59465e1SNicolas Vasilache              llvm::dbgs() << "\n"; llvm::interleaveComma(
46444e6318cSNicolas Vasilache                  forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
465768615bbSNicolas Vasilache              llvm::dbgs() << "\n");
46689bb0caeSGuray Ozen 
467c59465e1SNicolas Vasilache   // Step 3. Generate the mappingIdOps using the provided generator.
468768615bbSNicolas Vasilache   Location loc = forallOp.getLoc();
469c59465e1SNicolas Vasilache   OpBuilder::InsertionGuard guard(rewriter);
470c59465e1SNicolas Vasilache   rewriter.setInsertionPoint(forallOp);
47144e6318cSNicolas Vasilache   SmallVector<int64_t> originalBasis(availableMappingSizes);
47244e6318cSNicolas Vasilache   bool originalBasisWasProvided = !originalBasis.empty();
47344e6318cSNicolas Vasilache   if (!originalBasisWasProvided) {
47444e6318cSNicolas Vasilache     originalBasis = forallMappingSizes;
47544e6318cSNicolas Vasilache     while (originalBasis.size() < 3)
47644e6318cSNicolas Vasilache       originalBasis.push_back(1);
47744e6318cSNicolas Vasilache   }
47889bb0caeSGuray Ozen 
47944e6318cSNicolas Vasilache   IdBuilderResult builderResult =
48044e6318cSNicolas Vasilache       gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
48144e6318cSNicolas Vasilache 
48244e6318cSNicolas Vasilache   // Step 4. Map the induction variables to the mappingIdOps, this may involve
48344e6318cSNicolas Vasilache   // a permutation.
484c59465e1SNicolas Vasilache   SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
485768615bbSNicolas Vasilache   IRMapping bvm;
48644e6318cSNicolas Vasilache   for (auto [iv, dim] : llvm::zip_equal(
48744e6318cSNicolas Vasilache            forallOp.getInductionVars(),
48844e6318cSNicolas Vasilache            forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
48944e6318cSNicolas Vasilache     auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
49044e6318cSNicolas Vasilache     Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
491768615bbSNicolas Vasilache     bvm.map(iv, peIdOp);
492768615bbSNicolas Vasilache   }
493768615bbSNicolas Vasilache 
49444e6318cSNicolas Vasilache   // Step 5. If the originalBasis is already known, create conditionals to
49544e6318cSNicolas Vasilache   // predicate the region. Otherwise, the current forall determines the
49644e6318cSNicolas Vasilache   // originalBasis and no predication occurs.
497768615bbSNicolas Vasilache   Value predicate;
49844e6318cSNicolas Vasilache   if (originalBasisWasProvided) {
49944e6318cSNicolas Vasilache     SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
50044e6318cSNicolas Vasilache     SmallVector<int64_t> availableMappingSizes =
50144e6318cSNicolas Vasilache         builderResult.availableMappingSizes;
50244e6318cSNicolas Vasilache     SmallVector<Value> activeIdOps = builderResult.activeIdOps;
503c59465e1SNicolas Vasilache     // clang-format off
504c59465e1SNicolas Vasilache     LLVM_DEBUG(
505c59465e1SNicolas Vasilache         llvm::interleaveComma(
50644e6318cSNicolas Vasilache           activeMappingSizes, DBGS() << "----activeMappingSizes: ");
507c59465e1SNicolas Vasilache         llvm::dbgs() << "\n";
508c59465e1SNicolas Vasilache         llvm::interleaveComma(
509c59465e1SNicolas Vasilache           availableMappingSizes, DBGS() << "----availableMappingSizes: ");
510c59465e1SNicolas Vasilache         llvm::dbgs() << "\n";
51144e6318cSNicolas Vasilache         llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
512768615bbSNicolas Vasilache         llvm::dbgs() << "\n");
513c59465e1SNicolas Vasilache     // clang-format on
51444e6318cSNicolas Vasilache     for (auto [activeId, activeMappingSize, availableMappingSize] :
51544e6318cSNicolas Vasilache          llvm::zip_equal(activeIdOps, activeMappingSizes,
51644e6318cSNicolas Vasilache                          availableMappingSizes)) {
51744e6318cSNicolas Vasilache       if (activeMappingSize > availableMappingSize) {
518c59465e1SNicolas Vasilache         return definiteFailureHelper(
519768615bbSNicolas Vasilache             transformOp, forallOp,
520768615bbSNicolas Vasilache             "Trying to map to fewer GPU threads than loop iterations but "
521768615bbSNicolas Vasilache             "overprovisioning is not yet supported. "
522768615bbSNicolas Vasilache             "Try additional tiling of the before mapping or map to more "
523768615bbSNicolas Vasilache             "threads.");
524768615bbSNicolas Vasilache       }
52544e6318cSNicolas Vasilache       if (activeMappingSize == availableMappingSize)
526768615bbSNicolas Vasilache         continue;
52744e6318cSNicolas Vasilache       Value idx =
52844e6318cSNicolas Vasilache           rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
529768615bbSNicolas Vasilache       Value tmpPredicate = rewriter.create<arith::CmpIOp>(
53044e6318cSNicolas Vasilache           loc, arith::CmpIPredicate::ult, activeId, idx);
531c59465e1SNicolas Vasilache       LDBG("----predicate: " << tmpPredicate);
532768615bbSNicolas Vasilache       predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
533768615bbSNicolas Vasilache                                                              tmpPredicate)
534768615bbSNicolas Vasilache                             : tmpPredicate;
535768615bbSNicolas Vasilache     }
536768615bbSNicolas Vasilache   }
537768615bbSNicolas Vasilache 
538c59465e1SNicolas Vasilache   // Step 6. Move the body of forallOp.
539768615bbSNicolas Vasilache   // Erase the terminator first, it will not be used.
540eb2f946eSAlexander Belyaev   rewriter.eraseOp(forallOp.getTerminator());
541768615bbSNicolas Vasilache   Block *targetBlock;
542768615bbSNicolas Vasilache   Block::iterator insertionPoint;
543768615bbSNicolas Vasilache   if (predicate) {
544c59465e1SNicolas Vasilache     // Step 6.a. If predicated, move at the beginning.
545c59465e1SNicolas Vasilache     auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
546c59465e1SNicolas Vasilache                                            /*withElseRegion=*/false);
547768615bbSNicolas Vasilache     targetBlock = ifOp.thenBlock();
548768615bbSNicolas Vasilache     insertionPoint = ifOp.thenBlock()->begin();
549768615bbSNicolas Vasilache   } else {
550c59465e1SNicolas Vasilache     // Step 6.b. Otherwise, move inline just at the rewriter insertion
551c59465e1SNicolas Vasilache     // point.
552768615bbSNicolas Vasilache     targetBlock = forallOp->getBlock();
553768615bbSNicolas Vasilache     insertionPoint = rewriter.getInsertionPoint();
554768615bbSNicolas Vasilache   }
555eb2f946eSAlexander Belyaev   Block &sourceBlock = forallOp.getRegion().front();
55689bb0caeSGuray Ozen   targetBlock->getOperations().splice(insertionPoint,
55789bb0caeSGuray Ozen                                       sourceBlock.getOperations());
55889bb0caeSGuray Ozen 
559c59465e1SNicolas Vasilache   // Step 7. RAUW indices.
560eb2f946eSAlexander Belyaev   for (Value loopIndex : forallOp.getInductionVars()) {
561768615bbSNicolas Vasilache     Value threadIdx = bvm.lookup(loopIndex);
562768615bbSNicolas Vasilache     rewriter.replaceAllUsesWith(loopIndex, threadIdx);
56389bb0caeSGuray Ozen   }
56489bb0caeSGuray Ozen 
565c59465e1SNicolas Vasilache   // Step 8. Erase old op.
566eb2f946eSAlexander Belyaev   rewriter.eraseOp(forallOp);
56789bb0caeSGuray Ozen 
56844e6318cSNicolas Vasilache   LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
56944e6318cSNicolas Vasilache                                    DBGS() << "----result forallMappingSizes: ");
57044e6318cSNicolas Vasilache              llvm::dbgs() << "\n"; llvm::interleaveComma(
57144e6318cSNicolas Vasilache                  mappingIdOps, DBGS() << "----result mappingIdOps: ");
57244e6318cSNicolas Vasilache              llvm::dbgs() << "\n");
57344e6318cSNicolas Vasilache 
574c59465e1SNicolas Vasilache   result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
575c59465e1SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
576768615bbSNicolas Vasilache }
577768615bbSNicolas Vasilache 
578c59465e1SNicolas Vasilache //===----------------------------------------------------------------------===//
579c59465e1SNicolas Vasilache // MapForallToBlocks
580c59465e1SNicolas Vasilache //===----------------------------------------------------------------------===//
581c59465e1SNicolas Vasilache 
582768615bbSNicolas Vasilache DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
583768615bbSNicolas Vasilache     RewriterBase &rewriter, TransformOpInterface transformOp,
584768615bbSNicolas Vasilache     scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
585c59465e1SNicolas Vasilache     const GpuIdBuilder &gpuIdBuilder) {
586c59465e1SNicolas Vasilache   LDBG("Start mapForallToBlocksImpl");
587c59465e1SNicolas Vasilache 
58892f088d3SNicolas Vasilache   {
58992f088d3SNicolas Vasilache     // GPU-specific verifications. There is no better place to anchor
59092f088d3SNicolas Vasilache     // those right now: the ForallOp is target-independent and the transform
59192f088d3SNicolas Vasilache     // op does not apply to individual ForallOp.
59292f088d3SNicolas Vasilache     DiagnosedSilenceableFailure diag =
59392f088d3SNicolas Vasilache         verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
59492f088d3SNicolas Vasilache     if (!diag.succeeded())
59592f088d3SNicolas Vasilache       return diag;
59692f088d3SNicolas Vasilache   }
59792f088d3SNicolas Vasilache 
598c59465e1SNicolas Vasilache   Location loc = forallOp.getLoc();
599c59465e1SNicolas Vasilache   Block *parentBlock = forallOp->getBlock();
600c59465e1SNicolas Vasilache   Value zero;
601c59465e1SNicolas Vasilache   {
602c59465e1SNicolas Vasilache     // Create an early zero index value for replacements and immediately reset
603c59465e1SNicolas Vasilache     // the insertion point.
604c59465e1SNicolas Vasilache     OpBuilder::InsertionGuard guard(rewriter);
605c59465e1SNicolas Vasilache     rewriter.setInsertionPointToStart(parentBlock);
606c59465e1SNicolas Vasilache     zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
607c59465e1SNicolas Vasilache   }
608c59465e1SNicolas Vasilache 
609c59465e1SNicolas Vasilache   ForallRewriteResult rewriteResult;
61044e6318cSNicolas Vasilache   DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
61144e6318cSNicolas Vasilache       rewriter, transformOp, forallOp,
61244e6318cSNicolas Vasilache       /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);
613c59465e1SNicolas Vasilache 
61444e6318cSNicolas Vasilache   // Return if anything goes wrong, use silenceable failure as a match
61544e6318cSNicolas Vasilache   // failure.
616c59465e1SNicolas Vasilache   if (!diag.succeeded())
617c59465e1SNicolas Vasilache     return diag;
618c59465e1SNicolas Vasilache 
61944e6318cSNicolas Vasilache   // If gridDims was not provided already, set it from the return.
62044e6318cSNicolas Vasilache   if (gridDims.empty()) {
621c59465e1SNicolas Vasilache     gridDims = rewriteResult.mappingSizes;
62244e6318cSNicolas Vasilache     while (gridDims.size() < 3)
62344e6318cSNicolas Vasilache       gridDims.push_back(1);
62444e6318cSNicolas Vasilache   }
62544e6318cSNicolas Vasilache   assert(gridDims.size() == 3 && "Need 3-D gridDims");
626c59465e1SNicolas Vasilache 
627c59465e1SNicolas Vasilache   // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
628c59465e1SNicolas Vasilache   // Here, the result of mapping determines the available mapping sizes.
629c59465e1SNicolas Vasilache   replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
63044e6318cSNicolas Vasilache                                           rewriteResult.mappingSizes);
631c59465e1SNicolas Vasilache 
63289bb0caeSGuray Ozen   return DiagnosedSilenceableFailure::success();
63389bb0caeSGuray Ozen }
63489bb0caeSGuray Ozen 
63544e6318cSNicolas Vasilache DiagnosedSilenceableFailure
63644e6318cSNicolas Vasilache mlir::transform::gpu::findTopLevelForallOp(Operation *target,
63744e6318cSNicolas Vasilache                                            scf::ForallOp &topLevelForallOp,
63844e6318cSNicolas Vasilache                                            TransformOpInterface transformOp) {
63944e6318cSNicolas Vasilache   auto walkResult = target->walk([&](scf::ForallOp forallOp) {
64044e6318cSNicolas Vasilache     if (forallOp->getParentOfType<scf::ForallOp>())
64144e6318cSNicolas Vasilache       return WalkResult::advance();
64244e6318cSNicolas Vasilache     if (topLevelForallOp)
64344e6318cSNicolas Vasilache       // TODO: Handle multiple forall if they are independent.
64444e6318cSNicolas Vasilache       return WalkResult::interrupt();
64544e6318cSNicolas Vasilache     topLevelForallOp = forallOp;
64644e6318cSNicolas Vasilache     return WalkResult::advance();
64744e6318cSNicolas Vasilache   });
64844e6318cSNicolas Vasilache 
64992f088d3SNicolas Vasilache   if (walkResult.wasInterrupted() || !topLevelForallOp)
65044e6318cSNicolas Vasilache     return transformOp.emitSilenceableError()
65144e6318cSNicolas Vasilache            << "could not find a unique topLevel scf.forall";
65244e6318cSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
65344e6318cSNicolas Vasilache }
65444e6318cSNicolas Vasilache 
655c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
656c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
657c63d2b2cSMatthias Springer     ApplyToEachResultList &results, transform::TransformState &state) {
65889bb0caeSGuray Ozen   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
65989bb0caeSGuray Ozen   auto transformOp = cast<TransformOpInterface>(getOperation());
66089bb0caeSGuray Ozen 
66189bb0caeSGuray Ozen   if (!getGenerateGpuLaunch() && !gpuLaunch) {
66289bb0caeSGuray Ozen     DiagnosedSilenceableFailure diag =
66389bb0caeSGuray Ozen         emitSilenceableError()
66489bb0caeSGuray Ozen         << "Given target is not gpu.launch, set `generate_gpu_launch` "
66589bb0caeSGuray Ozen            "attribute";
66689bb0caeSGuray Ozen     diag.attachNote(target->getLoc()) << "when applied to this payload op";
66789bb0caeSGuray Ozen     return diag;
66889bb0caeSGuray Ozen   }
66989bb0caeSGuray Ozen 
670eb2f946eSAlexander Belyaev   scf::ForallOp topLevelForallOp;
671eb2f946eSAlexander Belyaev   DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
672eb2f946eSAlexander Belyaev       target, topLevelForallOp, transformOp);
67389bb0caeSGuray Ozen   if (!diag.succeeded()) {
67489bb0caeSGuray Ozen     diag.attachNote(target->getLoc()) << "when applied to this payload op";
67589bb0caeSGuray Ozen     return diag;
67689bb0caeSGuray Ozen   }
67792f088d3SNicolas Vasilache   assert(topLevelForallOp && "expect an scf.forall");
67889bb0caeSGuray Ozen 
679c59465e1SNicolas Vasilache   SmallVector<int64_t> gridDims{getGridDims()};
680768615bbSNicolas Vasilache   if (!getGenerateGpuLaunch() && gridDims.size() != 3)
681aafb52d7SNicolas Vasilache     return transformOp.emitDefiniteFailure("transform require size-3 mapping");
682aafb52d7SNicolas Vasilache 
68389bb0caeSGuray Ozen   OpBuilder::InsertionGuard guard(rewriter);
684eb2f946eSAlexander Belyaev   rewriter.setInsertionPoint(topLevelForallOp);
68589bb0caeSGuray Ozen 
686eb2f946eSAlexander Belyaev   // Generate gpu launch here and move the forall inside
68789bb0caeSGuray Ozen   if (getGenerateGpuLaunch()) {
68889bb0caeSGuray Ozen     DiagnosedSilenceableFailure diag =
68989bb0caeSGuray Ozen         createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
69044e6318cSNicolas Vasilache     if (!diag.succeeded())
69189bb0caeSGuray Ozen       return diag;
69244e6318cSNicolas Vasilache 
69389bb0caeSGuray Ozen     rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
694eb2f946eSAlexander Belyaev     Operation *newForallOp = rewriter.clone(*topLevelForallOp);
695eb2f946eSAlexander Belyaev     rewriter.eraseOp(topLevelForallOp);
696eb2f946eSAlexander Belyaev     topLevelForallOp = cast<scf::ForallOp>(newForallOp);
69789bb0caeSGuray Ozen   }
69889bb0caeSGuray Ozen 
69944e6318cSNicolas Vasilache   // The BlockIdBuilder adapts to whatever is thrown at it.
70092f088d3SNicolas Vasilache   bool useLinearMapping = false;
70192f088d3SNicolas Vasilache   if (topLevelForallOp.getMapping()) {
70244e6318cSNicolas Vasilache     auto mappingAttr = cast<DeviceMappingAttrInterface>(
70344e6318cSNicolas Vasilache         topLevelForallOp.getMapping()->getValue().front());
70492f088d3SNicolas Vasilache     useLinearMapping = mappingAttr.isLinearMapping();
70592f088d3SNicolas Vasilache   }
70644e6318cSNicolas Vasilache   GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
70744e6318cSNicolas Vasilache 
7081cff4cbdSNicolas Vasilache   diag = mlir::transform::gpu::mapForallToBlocksImpl(
709c59465e1SNicolas Vasilache       rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
710aafb52d7SNicolas Vasilache   if (!diag.succeeded())
711aafb52d7SNicolas Vasilache     return diag;
712aafb52d7SNicolas Vasilache 
71344e6318cSNicolas Vasilache   // Set the GPU launch configuration for the grid dims late, this is
71444e6318cSNicolas Vasilache   // subject to IR inspection.
71589bb0caeSGuray Ozen   diag = alterGpuLaunch(rewriter, gpuLaunch,
716768615bbSNicolas Vasilache                         cast<TransformOpInterface>(getOperation()), gridDims[0],
717768615bbSNicolas Vasilache                         gridDims[1], gridDims[2]);
71889bb0caeSGuray Ozen 
7194b455a71SAlex Zinenko   results.push_back(gpuLaunch);
72089bb0caeSGuray Ozen   return diag;
72189bb0caeSGuray Ozen }
72289bb0caeSGuray Ozen 
72392f088d3SNicolas Vasilache LogicalResult transform::MapForallToBlocks::verify() {
72492f088d3SNicolas Vasilache   if (!getGridDims().empty() && getGridDims().size() != 3) {
72592f088d3SNicolas Vasilache     return emitOpError() << "transform requires empty or size-3 grid_dims";
72692f088d3SNicolas Vasilache   }
72792f088d3SNicolas Vasilache   return success();
72892f088d3SNicolas Vasilache }
72992f088d3SNicolas Vasilache 
73089bb0caeSGuray Ozen //===----------------------------------------------------------------------===//
7311cff4cbdSNicolas Vasilache // MapNestedForallToThreads
73289bb0caeSGuray Ozen //===----------------------------------------------------------------------===//
73389bb0caeSGuray Ozen 
73444e6318cSNicolas Vasilache static DiagnosedSilenceableFailure checkMappingSpec(
73544e6318cSNicolas Vasilache     std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
73644e6318cSNicolas Vasilache     ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
73744e6318cSNicolas Vasilache     int factor, bool useLinearMapping = false) {
73844e6318cSNicolas Vasilache   if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
73944e6318cSNicolas Vasilache     auto diag = definiteFailureHelper(
74044e6318cSNicolas Vasilache         transformOp, forallOp,
74144e6318cSNicolas Vasilache         Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
74244e6318cSNicolas Vasilache             std::to_string(factor));
74344e6318cSNicolas Vasilache     return diag;
74444e6318cSNicolas Vasilache   }
74544e6318cSNicolas Vasilache   if (computeProduct(numParallelIterations) * factor >
74644e6318cSNicolas Vasilache       computeProduct(blockOrGridSizes)) {
74744e6318cSNicolas Vasilache     auto diag = definiteFailureHelper(
74844e6318cSNicolas Vasilache         transformOp, forallOp,
74992f088d3SNicolas Vasilache         Twine("the number of required parallel resources (blocks or "
75092f088d3SNicolas Vasilache               "threads) ") +
75144e6318cSNicolas Vasilache             std::to_string(computeProduct(numParallelIterations) * factor) +
75244e6318cSNicolas Vasilache             std::string(" overflows the number of available resources ") +
75344e6318cSNicolas Vasilache             std::to_string(computeProduct(blockOrGridSizes)));
75444e6318cSNicolas Vasilache     return diag;
75544e6318cSNicolas Vasilache   }
75644e6318cSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
75744e6318cSNicolas Vasilache }
75844e6318cSNicolas Vasilache 
75944e6318cSNicolas Vasilache static DiagnosedSilenceableFailure
76044e6318cSNicolas Vasilache getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
76144e6318cSNicolas Vasilache                    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
76244e6318cSNicolas Vasilache                    int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
76344e6318cSNicolas Vasilache   auto mappingAttr = cast<DeviceMappingAttrInterface>(
76444e6318cSNicolas Vasilache       forallOp.getMapping()->getValue().front());
76544e6318cSNicolas Vasilache   bool useLinearMapping = mappingAttr.isLinearMapping();
76644e6318cSNicolas Vasilache 
76744e6318cSNicolas Vasilache   // Sanity checks that may result in runtime verification errors.
76844e6318cSNicolas Vasilache   auto numParallelIterations =
76944e6318cSNicolas Vasilache       getConstantIntValues((forallOp.getMixedUpperBound()));
77044e6318cSNicolas Vasilache   if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
77144e6318cSNicolas Vasilache     return definiteFailureHelper(
77244e6318cSNicolas Vasilache         transformOp, forallOp,
77344e6318cSNicolas Vasilache         "requires statically sized, normalized forall op");
77444e6318cSNicolas Vasilache   }
77544e6318cSNicolas Vasilache   int64_t factor = 1;
77644e6318cSNicolas Vasilache   if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
77744e6318cSNicolas Vasilache     factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
77844e6318cSNicolas Vasilache   } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
77944e6318cSNicolas Vasilache     factor = warpSize;
78044e6318cSNicolas Vasilache   }
78144e6318cSNicolas Vasilache   DiagnosedSilenceableFailure diag =
78244e6318cSNicolas Vasilache       checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
78344e6318cSNicolas Vasilache                        blockSizes, factor, useLinearMapping);
78444e6318cSNicolas Vasilache   if (!diag.succeeded())
78544e6318cSNicolas Vasilache     return diag;
78644e6318cSNicolas Vasilache 
78744e6318cSNicolas Vasilache   // Start mapping.
78844e6318cSNicolas Vasilache   MLIRContext *ctx = forallOp.getContext();
78944e6318cSNicolas Vasilache   gpuIdBuilder =
79044e6318cSNicolas Vasilache       TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
79144e6318cSNicolas Vasilache           .Case([&](GPUWarpgroupMappingAttr) {
79244e6318cSNicolas Vasilache             return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
79344e6318cSNicolas Vasilache           })
79444e6318cSNicolas Vasilache           .Case([&](GPUWarpMappingAttr) {
79544e6318cSNicolas Vasilache             return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
79644e6318cSNicolas Vasilache           })
79744e6318cSNicolas Vasilache           .Case([&](GPUThreadMappingAttr) {
79844e6318cSNicolas Vasilache             return GpuThreadIdBuilder(ctx, useLinearMapping);
79944e6318cSNicolas Vasilache           })
80044e6318cSNicolas Vasilache           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
80144e6318cSNicolas Vasilache             llvm_unreachable("unknown mapping attribute");
80244e6318cSNicolas Vasilache           });
80344e6318cSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
80444e6318cSNicolas Vasilache }
80544e6318cSNicolas Vasilache 
806c59465e1SNicolas Vasilache DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
807768615bbSNicolas Vasilache     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
80844e6318cSNicolas Vasilache     scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
80944e6318cSNicolas Vasilache     bool syncAfterDistribute) {
81044e6318cSNicolas Vasilache 
81192f088d3SNicolas Vasilache   {
81292f088d3SNicolas Vasilache     // GPU-specific verifications. There is no better place to anchor
81392f088d3SNicolas Vasilache     // those right now: the ForallOp is target-independent and the transform
81492f088d3SNicolas Vasilache     // op does not apply to individual ForallOp.
81592f088d3SNicolas Vasilache     DiagnosedSilenceableFailure diag =
81692f088d3SNicolas Vasilache         verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
81792f088d3SNicolas Vasilache     if (!diag.succeeded())
81892f088d3SNicolas Vasilache       return diag;
81992f088d3SNicolas Vasilache   }
82092f088d3SNicolas Vasilache 
82144e6318cSNicolas Vasilache   GpuIdBuilder gpuIdBuilder;
82244e6318cSNicolas Vasilache   {
82344e6318cSNicolas Vasilache     // Try to construct the id builder, if it fails, return.
82444e6318cSNicolas Vasilache     DiagnosedSilenceableFailure diag = getThreadIdBuilder(
82544e6318cSNicolas Vasilache         transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
82644e6318cSNicolas Vasilache     if (!diag.succeeded())
82744e6318cSNicolas Vasilache       return diag;
828a7686db8SThomas Raoux   }
829c59465e1SNicolas Vasilache 
830768615bbSNicolas Vasilache   Location loc = forallOp.getLoc();
831768615bbSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
832c59465e1SNicolas Vasilache   // Insert after to allow for syncthreads after `forall` is erased.
833768615bbSNicolas Vasilache   rewriter.setInsertionPointAfter(forallOp);
834c59465e1SNicolas Vasilache   ForallRewriteResult rewriteResult;
83544e6318cSNicolas Vasilache   DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
83644e6318cSNicolas Vasilache       rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
837c59465e1SNicolas Vasilache   if (!diag.succeeded())
838c59465e1SNicolas Vasilache     return diag;
839768615bbSNicolas Vasilache   // Add a syncthreads if needed. TODO: warpsync
840768615bbSNicolas Vasilache   if (syncAfterDistribute)
841768615bbSNicolas Vasilache     rewriter.create<BarrierOp>(loc);
842c59465e1SNicolas Vasilache 
843c59465e1SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
844beaffb04SGuray Ozen }
845c59465e1SNicolas Vasilache 
846c59465e1SNicolas Vasilache DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
847c59465e1SNicolas Vasilache     RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
84844e6318cSNicolas Vasilache     Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
849c59465e1SNicolas Vasilache     bool syncAfterDistribute) {
850c59465e1SNicolas Vasilache   LDBG("Start mapNestedForallToThreadsImpl");
85144e6318cSNicolas Vasilache   if (blockDims.size() != 3) {
852c59465e1SNicolas Vasilache     return definiteFailureHelper(transformOp, target,
853c59465e1SNicolas Vasilache                                  "requires size-3 thread mapping");
854c59465e1SNicolas Vasilache   }
855c59465e1SNicolas Vasilache 
856c59465e1SNicolas Vasilache   // Create an early zero index value for replacements.
857c59465e1SNicolas Vasilache   Location loc = target->getLoc();
858c59465e1SNicolas Vasilache   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
859c59465e1SNicolas Vasilache   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
860c59465e1SNicolas Vasilache   WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
861c59465e1SNicolas Vasilache     diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
86244e6318cSNicolas Vasilache         rewriter, transformOp, forallOp, blockDims, warpSize,
86344e6318cSNicolas Vasilache         syncAfterDistribute);
864c59465e1SNicolas Vasilache     if (diag.isDefiniteFailure())
865c59465e1SNicolas Vasilache       return WalkResult::interrupt();
866c59465e1SNicolas Vasilache     if (diag.succeeded())
867c59465e1SNicolas Vasilache       return WalkResult::skip();
868c59465e1SNicolas Vasilache     return WalkResult::advance();
86989bb0caeSGuray Ozen   });
870c59465e1SNicolas Vasilache   if (walkResult.wasInterrupted())
87189bb0caeSGuray Ozen     return diag;
872c59465e1SNicolas Vasilache 
873c59465e1SNicolas Vasilache   // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
874c59465e1SNicolas Vasilache   // Here, the result of mapping determines the available mapping sizes.
875c59465e1SNicolas Vasilache   replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
876c59465e1SNicolas Vasilache                                           blockDims);
877c59465e1SNicolas Vasilache 
878c59465e1SNicolas Vasilache   return DiagnosedSilenceableFailure::success();
87989bb0caeSGuray Ozen }
88089bb0caeSGuray Ozen 
8811cff4cbdSNicolas Vasilache DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
882c63d2b2cSMatthias Springer     transform::TransformRewriter &rewriter, Operation *target,
883c63d2b2cSMatthias Springer     ApplyToEachResultList &results, TransformState &state) {
88489bb0caeSGuray Ozen   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
88589bb0caeSGuray Ozen   auto transformOp = cast<TransformOpInterface>(getOperation());
88689bb0caeSGuray Ozen 
887aafb52d7SNicolas Vasilache   // Basic high-level verifications.
888aafb52d7SNicolas Vasilache   if (!gpuLaunch)
889aafb52d7SNicolas Vasilache     return emitSilenceableError() << "Given target is not a gpu.launch";
89089bb0caeSGuray Ozen 
891c59465e1SNicolas Vasilache   // Mapping to block ids.
892c59465e1SNicolas Vasilache   SmallVector<int64_t> blockDims{getBlockDims()};
89389bb0caeSGuray Ozen   DiagnosedSilenceableFailure diag =
8941a36588eSKazu Hirata       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
895768615bbSNicolas Vasilache                      blockDims[0], blockDims[1], blockDims[2]);
89689bb0caeSGuray Ozen   if (diag.isSilenceableFailure()) {
897c59465e1SNicolas Vasilache     diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
89889bb0caeSGuray Ozen     return diag;
89989bb0caeSGuray Ozen   }
90089bb0caeSGuray Ozen 
901c59465e1SNicolas Vasilache   // Set the GPU launch configuration for the block dims early, this is not
902c59465e1SNicolas Vasilache   // subject to IR inspection.
9031a36588eSKazu Hirata   diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
904768615bbSNicolas Vasilache                         std::nullopt, std::nullopt, blockDims[0], blockDims[1],
905768615bbSNicolas Vasilache                         blockDims[2]);
90689bb0caeSGuray Ozen 
907c59465e1SNicolas Vasilache   rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
908c59465e1SNicolas Vasilache   diag =
909c59465e1SNicolas Vasilache       mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
91044e6318cSNicolas Vasilache                                    getWarpSize(), getSyncAfterDistribute());
911c59465e1SNicolas Vasilache 
912015cd84dSNicolas Vasilache   results.push_back(gpuLaunch.getOperation());
91389bb0caeSGuray Ozen   return diag;
91489bb0caeSGuray Ozen }
91589bb0caeSGuray Ozen 
91689bb0caeSGuray Ozen //===----------------------------------------------------------------------===//
91789bb0caeSGuray Ozen // Transform op registration
91889bb0caeSGuray Ozen //===----------------------------------------------------------------------===//
91989bb0caeSGuray Ozen 
92089bb0caeSGuray Ozen namespace {
/// Transform dialect extension that registers the GPU transform ops and
/// declares the dialects their application may generate.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    // Dialects whose ops may be created when the registered transforms are
    // applied; declaring them ensures they are loaded in the payload context.
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    // Register all transform ops generated from GPUTransformOps.td.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
93989bb0caeSGuray Ozen } // namespace
94089bb0caeSGuray Ozen 
94189bb0caeSGuray Ozen #define GET_OP_CLASSES
94289bb0caeSGuray Ozen #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
94389bb0caeSGuray Ozen 
// Makes the GPU transform ops available in any MLIRContext constructed from
// `registry`.
void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
947