189bb0caeSGuray Ozen //===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===// 289bb0caeSGuray Ozen // 389bb0caeSGuray Ozen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 489bb0caeSGuray Ozen // See https://llvm.org/LICENSE.txt for license information. 589bb0caeSGuray Ozen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 689bb0caeSGuray Ozen // 789bb0caeSGuray Ozen //===----------------------------------------------------------------------===// 889bb0caeSGuray Ozen 989bb0caeSGuray Ozen #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" 1089bb0caeSGuray Ozen 11888717e8SNicolas Vasilache #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 12888717e8SNicolas Vasilache #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" 13888717e8SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14c59465e1SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h" 1589bb0caeSGuray Ozen #include "mlir/Dialect/Arith/IR/Arith.h" 16c59465e1SNicolas Vasilache #include "mlir/Dialect/Func/IR/FuncOps.h" 1789bb0caeSGuray Ozen #include "mlir/Dialect/GPU/IR/GPUDialect.h" 1890ecfa2aSNicolas Vasilache #include "mlir/Dialect/GPU/TransformOps/Utils.h" 19888717e8SNicolas Vasilache #include "mlir/Dialect/GPU/Transforms/Passes.h" 20888717e8SNicolas Vasilache #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 219ab34689SAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h" 22beaffb04SGuray Ozen #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" 2389bb0caeSGuray Ozen #include "mlir/Dialect/SCF/IR/SCF.h" 2489bb0caeSGuray Ozen #include "mlir/Dialect/Transform/IR/TransformDialect.h" 255a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 26c59465e1SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h" 279ab34689SAlex Zinenko #include "mlir/Dialect/Vector/IR/VectorOps.h" 28ff8775f3SQuinn Dawkins #include 
"mlir/Dialect/Vector/Transforms/VectorTransforms.h" 29c59465e1SNicolas Vasilache #include "mlir/IR/AffineExpr.h" 30c59465e1SNicolas Vasilache #include "mlir/IR/Builders.h" 31768615bbSNicolas Vasilache #include "mlir/IR/BuiltinAttributes.h" 324d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 33c59465e1SNicolas Vasilache #include "mlir/IR/MLIRContext.h" 34aafb52d7SNicolas Vasilache #include "mlir/IR/OpDefinition.h" 35c59465e1SNicolas Vasilache #include "mlir/IR/Visitors.h" 36aafb52d7SNicolas Vasilache #include "mlir/Support/LLVM.h" 37888717e8SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 38768615bbSNicolas Vasilache #include "llvm/ADT/STLExtras.h" 39768615bbSNicolas Vasilache #include "llvm/ADT/SmallVector.h" 409ab34689SAlex Zinenko #include "llvm/ADT/TypeSwitch.h" 41768615bbSNicolas Vasilache #include "llvm/Support/Debug.h" 4244e6318cSNicolas Vasilache #include "llvm/Support/ErrorHandling.h" 4392f088d3SNicolas Vasilache #include <type_traits> 4489bb0caeSGuray Ozen 4589bb0caeSGuray Ozen using namespace mlir; 4689bb0caeSGuray Ozen using namespace mlir::gpu; 4789bb0caeSGuray Ozen using namespace mlir::transform; 48c59465e1SNicolas Vasilache using namespace mlir::transform::gpu; 4989bb0caeSGuray Ozen 50768615bbSNicolas Vasilache #define DEBUG_TYPE "gpu-transforms" 519ab34689SAlex Zinenko #define DEBUG_TYPE_ALIAS "gpu-transforms-alias" 52768615bbSNicolas Vasilache 53768615bbSNicolas Vasilache #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 54768615bbSNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 559ab34689SAlex Zinenko #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") 569ab34689SAlex Zinenko 579ab34689SAlex Zinenko //===----------------------------------------------------------------------===// 58888717e8SNicolas Vasilache // Apply...ConversionPatternsOp 59888717e8SNicolas Vasilache //===----------------------------------------------------------------------===// 60888717e8SNicolas Vasilache 
61888717e8SNicolas Vasilache void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns( 62888717e8SNicolas Vasilache TypeConverter &typeConverter, RewritePatternSet &patterns) { 63888717e8SNicolas Vasilache auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); 64888717e8SNicolas Vasilache // NVVM uses alloca in the default address space to represent private 65888717e8SNicolas Vasilache // memory allocations, so drop private annotations. NVVM uses address 66888717e8SNicolas Vasilache // space 3 for shared memory. NVVM uses the default address space to 67888717e8SNicolas Vasilache // represent global memory. 68888717e8SNicolas Vasilache // Used in populateGpuToNVVMConversionPatternsso attaching here for now. 69888717e8SNicolas Vasilache // TODO: We should have a single to_nvvm_type_converter. 70888717e8SNicolas Vasilache populateGpuMemorySpaceAttributeConversions( 71888717e8SNicolas Vasilache llvmTypeConverter, [](AddressSpace space) -> unsigned { 72888717e8SNicolas Vasilache switch (space) { 73888717e8SNicolas Vasilache case AddressSpace::Global: 74888717e8SNicolas Vasilache return static_cast<unsigned>( 75888717e8SNicolas Vasilache NVVM::NVVMMemorySpace::kGlobalMemorySpace); 76888717e8SNicolas Vasilache case AddressSpace::Workgroup: 77888717e8SNicolas Vasilache return static_cast<unsigned>( 78888717e8SNicolas Vasilache NVVM::NVVMMemorySpace::kSharedMemorySpace); 79888717e8SNicolas Vasilache case AddressSpace::Private: 80888717e8SNicolas Vasilache return 0; 81888717e8SNicolas Vasilache } 82888717e8SNicolas Vasilache llvm_unreachable("unknown address space enum value"); 83888717e8SNicolas Vasilache return 0; 84888717e8SNicolas Vasilache }); 85888717e8SNicolas Vasilache // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now. 86888717e8SNicolas Vasilache // TODO: We should have a single to_nvvm_type_converter. 
87888717e8SNicolas Vasilache llvmTypeConverter.addConversion( 88888717e8SNicolas Vasilache [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); }); 89888717e8SNicolas Vasilache populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns); 90888717e8SNicolas Vasilache } 91888717e8SNicolas Vasilache 92888717e8SNicolas Vasilache LogicalResult 93888717e8SNicolas Vasilache transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter( 94888717e8SNicolas Vasilache transform::TypeConverterBuilderOpInterface builder) { 95888717e8SNicolas Vasilache if (builder.getTypeConverterType() != "LLVMTypeConverter") 96888717e8SNicolas Vasilache return emitOpError("expected LLVMTypeConverter"); 97888717e8SNicolas Vasilache return success(); 98888717e8SNicolas Vasilache } 99888717e8SNicolas Vasilache 100888717e8SNicolas Vasilache void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns( 101888717e8SNicolas Vasilache TypeConverter &typeConverter, RewritePatternSet &patterns) { 102888717e8SNicolas Vasilache auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); 103888717e8SNicolas Vasilache populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns); 104888717e8SNicolas Vasilache } 105888717e8SNicolas Vasilache 106888717e8SNicolas Vasilache LogicalResult 107888717e8SNicolas Vasilache transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter( 108888717e8SNicolas Vasilache transform::TypeConverterBuilderOpInterface builder) { 109888717e8SNicolas Vasilache if (builder.getTypeConverterType() != "LLVMTypeConverter") 110888717e8SNicolas Vasilache return emitOpError("expected LLVMTypeConverter"); 111888717e8SNicolas Vasilache return success(); 112888717e8SNicolas Vasilache } 113888717e8SNicolas Vasilache 114888717e8SNicolas Vasilache void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: 115888717e8SNicolas Vasilache populatePatterns(TypeConverter &typeConverter, 116888717e8SNicolas Vasilache 
RewritePatternSet &patterns) { 117888717e8SNicolas Vasilache auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); 118888717e8SNicolas Vasilache populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns); 119888717e8SNicolas Vasilache } 120888717e8SNicolas Vasilache 121888717e8SNicolas Vasilache LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: 122888717e8SNicolas Vasilache verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) { 123888717e8SNicolas Vasilache if (builder.getTypeConverterType() != "LLVMTypeConverter") 124888717e8SNicolas Vasilache return emitOpError("expected LLVMTypeConverter"); 125888717e8SNicolas Vasilache return success(); 126888717e8SNicolas Vasilache } 127888717e8SNicolas Vasilache 128888717e8SNicolas Vasilache //===----------------------------------------------------------------------===// 129888717e8SNicolas Vasilache // Apply...PatternsOp 130888717e8SNicolas Vasilache //===----------------------------------------------------------------------===//s 131888717e8SNicolas Vasilache 132888717e8SNicolas Vasilache void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) { 133888717e8SNicolas Vasilache populateGpuRewritePatterns(patterns); 134888717e8SNicolas Vasilache } 135888717e8SNicolas Vasilache 136888717e8SNicolas Vasilache //===----------------------------------------------------------------------===// 137ff8775f3SQuinn Dawkins // ApplyUnrollVectorsSubgroupMmaOp 138ff8775f3SQuinn Dawkins //===----------------------------------------------------------------------===// 139ff8775f3SQuinn Dawkins 140ff8775f3SQuinn Dawkins /// Pick an unrolling order that will allow tensorcore operation to reuse LHS 141ff8775f3SQuinn Dawkins /// register. 
142ff8775f3SQuinn Dawkins static std::optional<SmallVector<int64_t>> 143ff8775f3SQuinn Dawkins gpuMmaUnrollOrder(vector::ContractionOp contract) { 144ff8775f3SQuinn Dawkins SmallVector<int64_t> order; 145ff8775f3SQuinn Dawkins // First make reduction the outer dimensions. 146ff8775f3SQuinn Dawkins for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { 147ff8775f3SQuinn Dawkins if (vector::isReductionIterator(iter)) { 148ff8775f3SQuinn Dawkins order.push_back(index); 149ff8775f3SQuinn Dawkins } 150ff8775f3SQuinn Dawkins } 151ff8775f3SQuinn Dawkins 152ff8775f3SQuinn Dawkins llvm::SmallDenseSet<int64_t> dims; 153ff8775f3SQuinn Dawkins for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { 1541609f1c2Slong.chen dims.insert(cast<AffineDimExpr>(expr).getPosition()); 155ff8775f3SQuinn Dawkins } 156ff8775f3SQuinn Dawkins // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. 157ff8775f3SQuinn Dawkins for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { 158ff8775f3SQuinn Dawkins if (vector::isParallelIterator(iter) && dims.count(index)) { 159ff8775f3SQuinn Dawkins order.push_back(index); 160ff8775f3SQuinn Dawkins } 161ff8775f3SQuinn Dawkins } 162ff8775f3SQuinn Dawkins // Then the remaining parallel loops. 163ff8775f3SQuinn Dawkins for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { 164ff8775f3SQuinn Dawkins if (vector::isParallelIterator(iter) && !dims.count(index)) { 165ff8775f3SQuinn Dawkins order.push_back(index); 166ff8775f3SQuinn Dawkins } 167ff8775f3SQuinn Dawkins } 168ff8775f3SQuinn Dawkins return order; 169ff8775f3SQuinn Dawkins } 170ff8775f3SQuinn Dawkins 171ff8775f3SQuinn Dawkins /// Returns the target vector size for the target operation based on the native 172ff8775f3SQuinn Dawkins /// vector size specified with `m`, `n`, and `k`. 
173ff8775f3SQuinn Dawkins static std::optional<SmallVector<int64_t>> 174ff8775f3SQuinn Dawkins getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { 175ff8775f3SQuinn Dawkins if (auto contract = dyn_cast<vector::ContractionOp>(op)) { 176ff8775f3SQuinn Dawkins int64_t contractRank = contract.getIteratorTypes().size(); 177ff8775f3SQuinn Dawkins if (contractRank < 3) 178ff8775f3SQuinn Dawkins return std::nullopt; 179ff8775f3SQuinn Dawkins SmallVector<int64_t> nativeSize(contractRank - 3, 1); 180ff8775f3SQuinn Dawkins nativeSize.append({m, n, k}); 181ff8775f3SQuinn Dawkins return nativeSize; 182ff8775f3SQuinn Dawkins } 183ff8775f3SQuinn Dawkins if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 184ff8775f3SQuinn Dawkins int64_t writeRank = writeOp.getVectorType().getRank(); 185ff8775f3SQuinn Dawkins if (writeRank < 2) 186ff8775f3SQuinn Dawkins return std::nullopt; 187ff8775f3SQuinn Dawkins SmallVector<int64_t> nativeSize(writeRank - 2, 1); 188ff8775f3SQuinn Dawkins nativeSize.append({m, n}); 189ff8775f3SQuinn Dawkins return nativeSize; 190ff8775f3SQuinn Dawkins } 191ff8775f3SQuinn Dawkins if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 192ff8775f3SQuinn Dawkins // Transfer read ops may need different shapes based on how they are being 193ff8775f3SQuinn Dawkins // used. For simplicity just match the shape used by the extract strided op. 
194ff8775f3SQuinn Dawkins VectorType sliceType; 195ff8775f3SQuinn Dawkins for (Operation *users : op->getUsers()) { 196ff8775f3SQuinn Dawkins auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users); 197ff8775f3SQuinn Dawkins if (!extract) 198ff8775f3SQuinn Dawkins return std::nullopt; 199a5757c5bSChristian Sigg auto vecType = cast<VectorType>(extract.getResult().getType()); 200ff8775f3SQuinn Dawkins if (sliceType && sliceType != vecType) 201ff8775f3SQuinn Dawkins return std::nullopt; 202ff8775f3SQuinn Dawkins sliceType = vecType; 203ff8775f3SQuinn Dawkins } 204ff8775f3SQuinn Dawkins return llvm::to_vector(sliceType.getShape()); 205ff8775f3SQuinn Dawkins } 206ff8775f3SQuinn Dawkins if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) { 207a5757c5bSChristian Sigg if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) { 208ff8775f3SQuinn Dawkins // TODO: The condition for unrolling elementwise should be restricted 209ff8775f3SQuinn Dawkins // only to operations that need unrolling (connected to the contract). 210ff8775f3SQuinn Dawkins if (vecType.getRank() < 2) 211ff8775f3SQuinn Dawkins return std::nullopt; 212ff8775f3SQuinn Dawkins 213ff8775f3SQuinn Dawkins // First check whether there is a slice to infer the shape from. This is 214ff8775f3SQuinn Dawkins // required for cases where the accumulator type differs from the input 215ff8775f3SQuinn Dawkins // types, in which case we will see an `arith.ext_` between the contract 216ff8775f3SQuinn Dawkins // and transfer_read which needs to be unrolled. 
217ff8775f3SQuinn Dawkins VectorType sliceType; 218ff8775f3SQuinn Dawkins for (Operation *users : op->getUsers()) { 219ff8775f3SQuinn Dawkins auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users); 220ff8775f3SQuinn Dawkins if (!extract) 221ff8775f3SQuinn Dawkins return std::nullopt; 222a5757c5bSChristian Sigg auto vecType = cast<VectorType>(extract.getResult().getType()); 223ff8775f3SQuinn Dawkins if (sliceType && sliceType != vecType) 224ff8775f3SQuinn Dawkins return std::nullopt; 225ff8775f3SQuinn Dawkins sliceType = vecType; 226ff8775f3SQuinn Dawkins } 227ff8775f3SQuinn Dawkins if (sliceType) 228ff8775f3SQuinn Dawkins return llvm::to_vector(sliceType.getShape()); 229ff8775f3SQuinn Dawkins 230ff8775f3SQuinn Dawkins // Else unroll for trailing elementwise. 231ff8775f3SQuinn Dawkins SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1); 232ff8775f3SQuinn Dawkins // Map elementwise ops to the output shape. 233ff8775f3SQuinn Dawkins nativeSize.append({m, n}); 234ff8775f3SQuinn Dawkins return nativeSize; 235ff8775f3SQuinn Dawkins } 236ff8775f3SQuinn Dawkins } 237ff8775f3SQuinn Dawkins return std::nullopt; 238ff8775f3SQuinn Dawkins } 239ff8775f3SQuinn Dawkins 240ff8775f3SQuinn Dawkins void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns( 241ff8775f3SQuinn Dawkins RewritePatternSet &patterns) { 242ff8775f3SQuinn Dawkins auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> { 243ff8775f3SQuinn Dawkins auto contract = dyn_cast<vector::ContractionOp>(op); 244ff8775f3SQuinn Dawkins if (!contract) 245ff8775f3SQuinn Dawkins return std::nullopt; 246ff8775f3SQuinn Dawkins return gpuMmaUnrollOrder(contract); 247ff8775f3SQuinn Dawkins }; 248ff8775f3SQuinn Dawkins 249ff8775f3SQuinn Dawkins int64_t m = getM(); 250ff8775f3SQuinn Dawkins int64_t n = getN(); 251ff8775f3SQuinn Dawkins int64_t k = getK(); 252ff8775f3SQuinn Dawkins auto nativeShapeFn = 253ff8775f3SQuinn Dawkins [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> 
{ 254ff8775f3SQuinn Dawkins return getSubgroupMmaNativeVectorSize(op, m, n, k); 255ff8775f3SQuinn Dawkins }; 256ff8775f3SQuinn Dawkins vector::populateVectorUnrollPatterns( 257ff8775f3SQuinn Dawkins patterns, vector::UnrollVectorOptions() 258ff8775f3SQuinn Dawkins .setNativeShapeFn(nativeShapeFn) 259ff8775f3SQuinn Dawkins .setUnrollTraversalOrderFn(unrollOrder)); 260ff8775f3SQuinn Dawkins } 261ff8775f3SQuinn Dawkins 262ff8775f3SQuinn Dawkins //===----------------------------------------------------------------------===// 2639ab34689SAlex Zinenko // EliminateBarriersOp 2649ab34689SAlex Zinenko //===----------------------------------------------------------------------===// 2659ab34689SAlex Zinenko 2669ab34689SAlex Zinenko void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) { 26700c3c731Sspaceotter populateGpuEliminateBarriersPatterns(patterns); 2689ab34689SAlex Zinenko } 2699ab34689SAlex Zinenko 2709ab34689SAlex Zinenko //===----------------------------------------------------------------------===// 2719ab34689SAlex Zinenko // Block and thread mapping utilities. 2729ab34689SAlex Zinenko //===----------------------------------------------------------------------===// 273768615bbSNicolas Vasilache 27492f088d3SNicolas Vasilache namespace { 27592f088d3SNicolas Vasilache /// Local types used for mapping verification. 
// Tag selecting verification of a block-level mapping.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
// Tag selecting verification of a thread/warp/warpgroup-level mapping.
struct ThreadMappingKind : MappingKind {};
} // namespace

/// Emit a definite failure diagnostic, anchored on `transformOp` when one is
/// provided and on `target` otherwise.
static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if given mapping attributes are one of the desired attributes.
/// Verifies that `forallOp` (a) carries a mapping attribute, (b) does not mix
/// mapping kinds, (c) uses the kind selected by `MappingKindType`, (d) has no
/// duplicate mapping ids, and (e) does not mix linear and non-linear modes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  // Detect which of the four mutually exclusive mapping kinds is present.
  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  // Thread-kind verification accepts thread, warp, and warpgroup mappings.
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  // Reject duplicate mapping ids: two loops cannot map to the same id.
  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  // Either all or none of the mapping attributes may be linear.
  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

/// Full GPU-mapping verification for `forallOp`: attribute checks (see
/// checkMappingAttributeTypes) plus structural checks — normalized loops, no
/// results (bufferized), rank within the mapping's limit, and statically
/// known trip counts.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // Linear mappings can use every id past DimZ; 3-D mappings only x/y/z.
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  // Trip counts must be static so mapping sizes can be computed.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  // NOTE(review): isNormalized() was already checked above; the re-check here
  // is redundant but harmless.
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return
DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  // Sizes of the mapped dimensions (presumably sorted by mapping id — filled
  // by rewriteOneForallCommonImpl; TODO confirm against its later steps).
  SmallVector<int64_t> mappingSizes;
  // The id values produced for the mapping (see mappingIdOps below).
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
/// Walks `parent` for ops of type `OpTy` (expected to expose getDimension())
/// and replaces all uses of their results with `replacement` whenever the
/// corresponding entry of `availableMappingSizes` is 1.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

/// Rewrites one scf.forall into hardware-id-indexed code using `gpuIdBuilder`,
/// recording sizes/ids in `result`. When `availableMappingSizes` is non-empty
/// it bounds the mapping and predication is inserted (see later steps).
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
42144e6318cSNicolas Vasilache auto numParallelIterations = 42244e6318cSNicolas Vasilache getConstantIntValues(forallOp.getMixedUpperBound()); 42344e6318cSNicolas Vasilache assert(forallOp.isNormalized() && numParallelIterations.has_value() && 42444e6318cSNicolas Vasilache "requires statically sized, normalized forall op"); 42544e6318cSNicolas Vasilache SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value(); 42644e6318cSNicolas Vasilache SetVector<Attribute> forallMappingAttrs; 42744e6318cSNicolas Vasilache forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(), 42844e6318cSNicolas Vasilache forallOp.getMapping()->getValue().end()); 42944e6318cSNicolas Vasilache auto comparator = [](Attribute a, Attribute b) -> bool { 43044e6318cSNicolas Vasilache return cast<DeviceMappingAttrInterface>(a).getMappingId() < 43144e6318cSNicolas Vasilache cast<DeviceMappingAttrInterface>(b).getMappingId(); 43244e6318cSNicolas Vasilache }; 43344e6318cSNicolas Vasilache 43444e6318cSNicolas Vasilache // Step 1.b. In the linear case, compute the max mapping to avoid needlessly 43544e6318cSNicolas Vasilache // mapping all dimensions. In the 3-D mapping case we need to map all 43644e6318cSNicolas Vasilache // dimensions. 437fab2bb8bSJustin Lebar DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>( 438fab2bb8bSJustin Lebar *llvm::max_element(forallMappingAttrs, comparator)); 43944e6318cSNicolas Vasilache DeviceMappingAttrInterface maxLinearMapping; 44044e6318cSNicolas Vasilache if (maxMapping.isLinearMapping()) 44144e6318cSNicolas Vasilache maxLinearMapping = maxMapping; 442c59465e1SNicolas Vasilache for (auto attr : gpuIdBuilder.mappingAttributes) { 44344e6318cSNicolas Vasilache // If attr overflows, just skip. 44444e6318cSNicolas Vasilache if (maxLinearMapping && comparator(maxLinearMapping, attr)) 445768615bbSNicolas Vasilache continue; 44644e6318cSNicolas Vasilache // Try to insert. If element was already present, just continue. 
44744e6318cSNicolas Vasilache if (!forallMappingAttrs.insert(attr)) 44844e6318cSNicolas Vasilache continue; 44944e6318cSNicolas Vasilache // Otherwise, we have a new insertion without a size -> use size 1. 450768615bbSNicolas Vasilache tmpMappingSizes.push_back(1); 4516663f347SGuray Ozen } 452c59465e1SNicolas Vasilache LLVM_DEBUG( 453c59465e1SNicolas Vasilache llvm::interleaveComma( 454c59465e1SNicolas Vasilache tmpMappingSizes, 455c59465e1SNicolas Vasilache DBGS() << "----tmpMappingSizes extracted from scf.forall op: "); 456c59465e1SNicolas Vasilache llvm::dbgs() << "\n"); 4576663f347SGuray Ozen 458beaffb04SGuray Ozen // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. 45944e6318cSNicolas Vasilache SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey( 46044e6318cSNicolas Vasilache forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator); 461c59465e1SNicolas Vasilache LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes, 462c59465e1SNicolas Vasilache DBGS() << "----forallMappingSizes: "); 463c59465e1SNicolas Vasilache llvm::dbgs() << "\n"; llvm::interleaveComma( 46444e6318cSNicolas Vasilache forallMappingAttrs, DBGS() << "----forallMappingAttrs: "); 465768615bbSNicolas Vasilache llvm::dbgs() << "\n"); 46689bb0caeSGuray Ozen 467c59465e1SNicolas Vasilache // Step 3. Generate the mappingIdOps using the provided generator. 
468768615bbSNicolas Vasilache Location loc = forallOp.getLoc(); 469c59465e1SNicolas Vasilache OpBuilder::InsertionGuard guard(rewriter); 470c59465e1SNicolas Vasilache rewriter.setInsertionPoint(forallOp); 47144e6318cSNicolas Vasilache SmallVector<int64_t> originalBasis(availableMappingSizes); 47244e6318cSNicolas Vasilache bool originalBasisWasProvided = !originalBasis.empty(); 47344e6318cSNicolas Vasilache if (!originalBasisWasProvided) { 47444e6318cSNicolas Vasilache originalBasis = forallMappingSizes; 47544e6318cSNicolas Vasilache while (originalBasis.size() < 3) 47644e6318cSNicolas Vasilache originalBasis.push_back(1); 47744e6318cSNicolas Vasilache } 47889bb0caeSGuray Ozen 47944e6318cSNicolas Vasilache IdBuilderResult builderResult = 48044e6318cSNicolas Vasilache gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis); 48144e6318cSNicolas Vasilache 48244e6318cSNicolas Vasilache // Step 4. Map the induction variables to the mappingIdOps, this may involve 48344e6318cSNicolas Vasilache // a permutation. 484c59465e1SNicolas Vasilache SmallVector<Value> mappingIdOps = builderResult.mappingIdOps; 485768615bbSNicolas Vasilache IRMapping bvm; 48644e6318cSNicolas Vasilache for (auto [iv, dim] : llvm::zip_equal( 48744e6318cSNicolas Vasilache forallOp.getInductionVars(), 48844e6318cSNicolas Vasilache forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) { 48944e6318cSNicolas Vasilache auto mappingAttr = cast<DeviceMappingAttrInterface>(dim); 49044e6318cSNicolas Vasilache Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()]; 491768615bbSNicolas Vasilache bvm.map(iv, peIdOp); 492768615bbSNicolas Vasilache } 493768615bbSNicolas Vasilache 49444e6318cSNicolas Vasilache // Step 5. If the originalBasis is already known, create conditionals to 49544e6318cSNicolas Vasilache // predicate the region. Otherwise, the current forall determines the 49644e6318cSNicolas Vasilache // originalBasis and no predication occurs. 
497768615bbSNicolas Vasilache Value predicate; 49844e6318cSNicolas Vasilache if (originalBasisWasProvided) { 49944e6318cSNicolas Vasilache SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes; 50044e6318cSNicolas Vasilache SmallVector<int64_t> availableMappingSizes = 50144e6318cSNicolas Vasilache builderResult.availableMappingSizes; 50244e6318cSNicolas Vasilache SmallVector<Value> activeIdOps = builderResult.activeIdOps; 503c59465e1SNicolas Vasilache // clang-format off 504c59465e1SNicolas Vasilache LLVM_DEBUG( 505c59465e1SNicolas Vasilache llvm::interleaveComma( 50644e6318cSNicolas Vasilache activeMappingSizes, DBGS() << "----activeMappingSizes: "); 507c59465e1SNicolas Vasilache llvm::dbgs() << "\n"; 508c59465e1SNicolas Vasilache llvm::interleaveComma( 509c59465e1SNicolas Vasilache availableMappingSizes, DBGS() << "----availableMappingSizes: "); 510c59465e1SNicolas Vasilache llvm::dbgs() << "\n"; 51144e6318cSNicolas Vasilache llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: "); 512768615bbSNicolas Vasilache llvm::dbgs() << "\n"); 513c59465e1SNicolas Vasilache // clang-format on 51444e6318cSNicolas Vasilache for (auto [activeId, activeMappingSize, availableMappingSize] : 51544e6318cSNicolas Vasilache llvm::zip_equal(activeIdOps, activeMappingSizes, 51644e6318cSNicolas Vasilache availableMappingSizes)) { 51744e6318cSNicolas Vasilache if (activeMappingSize > availableMappingSize) { 518c59465e1SNicolas Vasilache return definiteFailureHelper( 519768615bbSNicolas Vasilache transformOp, forallOp, 520768615bbSNicolas Vasilache "Trying to map to fewer GPU threads than loop iterations but " 521768615bbSNicolas Vasilache "overprovisioning is not yet supported. 
" 522768615bbSNicolas Vasilache "Try additional tiling of the before mapping or map to more " 523768615bbSNicolas Vasilache "threads."); 524768615bbSNicolas Vasilache } 52544e6318cSNicolas Vasilache if (activeMappingSize == availableMappingSize) 526768615bbSNicolas Vasilache continue; 52744e6318cSNicolas Vasilache Value idx = 52844e6318cSNicolas Vasilache rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize); 529768615bbSNicolas Vasilache Value tmpPredicate = rewriter.create<arith::CmpIOp>( 53044e6318cSNicolas Vasilache loc, arith::CmpIPredicate::ult, activeId, idx); 531c59465e1SNicolas Vasilache LDBG("----predicate: " << tmpPredicate); 532768615bbSNicolas Vasilache predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate, 533768615bbSNicolas Vasilache tmpPredicate) 534768615bbSNicolas Vasilache : tmpPredicate; 535768615bbSNicolas Vasilache } 536768615bbSNicolas Vasilache } 537768615bbSNicolas Vasilache 538c59465e1SNicolas Vasilache // Step 6. Move the body of forallOp. 539768615bbSNicolas Vasilache // Erase the terminator first, it will not be used. 540eb2f946eSAlexander Belyaev rewriter.eraseOp(forallOp.getTerminator()); 541768615bbSNicolas Vasilache Block *targetBlock; 542768615bbSNicolas Vasilache Block::iterator insertionPoint; 543768615bbSNicolas Vasilache if (predicate) { 544c59465e1SNicolas Vasilache // Step 6.a. If predicated, move at the beginning. 545c59465e1SNicolas Vasilache auto ifOp = rewriter.create<scf::IfOp>(loc, predicate, 546c59465e1SNicolas Vasilache /*withElseRegion=*/false); 547768615bbSNicolas Vasilache targetBlock = ifOp.thenBlock(); 548768615bbSNicolas Vasilache insertionPoint = ifOp.thenBlock()->begin(); 549768615bbSNicolas Vasilache } else { 550c59465e1SNicolas Vasilache // Step 6.b. Otherwise, move inline just at the rewriter insertion 551c59465e1SNicolas Vasilache // point. 
552768615bbSNicolas Vasilache targetBlock = forallOp->getBlock(); 553768615bbSNicolas Vasilache insertionPoint = rewriter.getInsertionPoint(); 554768615bbSNicolas Vasilache } 555eb2f946eSAlexander Belyaev Block &sourceBlock = forallOp.getRegion().front(); 55689bb0caeSGuray Ozen targetBlock->getOperations().splice(insertionPoint, 55789bb0caeSGuray Ozen sourceBlock.getOperations()); 55889bb0caeSGuray Ozen 559c59465e1SNicolas Vasilache // Step 7. RAUW indices. 560eb2f946eSAlexander Belyaev for (Value loopIndex : forallOp.getInductionVars()) { 561768615bbSNicolas Vasilache Value threadIdx = bvm.lookup(loopIndex); 562768615bbSNicolas Vasilache rewriter.replaceAllUsesWith(loopIndex, threadIdx); 56389bb0caeSGuray Ozen } 56489bb0caeSGuray Ozen 565c59465e1SNicolas Vasilache // Step 8. Erase old op. 566eb2f946eSAlexander Belyaev rewriter.eraseOp(forallOp); 56789bb0caeSGuray Ozen 56844e6318cSNicolas Vasilache LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes, 56944e6318cSNicolas Vasilache DBGS() << "----result forallMappingSizes: "); 57044e6318cSNicolas Vasilache llvm::dbgs() << "\n"; llvm::interleaveComma( 57144e6318cSNicolas Vasilache mappingIdOps, DBGS() << "----result mappingIdOps: "); 57244e6318cSNicolas Vasilache llvm::dbgs() << "\n"); 57344e6318cSNicolas Vasilache 574c59465e1SNicolas Vasilache result = ForallRewriteResult{forallMappingSizes, mappingIdOps}; 575c59465e1SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 576768615bbSNicolas Vasilache } 577768615bbSNicolas Vasilache 578c59465e1SNicolas Vasilache //===----------------------------------------------------------------------===// 579c59465e1SNicolas Vasilache // MapForallToBlocks 580c59465e1SNicolas Vasilache //===----------------------------------------------------------------------===// 581c59465e1SNicolas Vasilache 582768615bbSNicolas Vasilache DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( 583768615bbSNicolas Vasilache RewriterBase &rewriter, TransformOpInterface 
/// Rewrite `forallOp` into a GPU-block-distributed form, materializing block
/// ids through `gpuIdBuilder`.
///
/// `gridDims` is an in-out parameter: when non-empty it is forwarded to the
/// common rewrite as the available mapping sizes; when empty it is populated
/// from the rewrite result and padded with 1s up to rank 3.
/// Returns a silenceable failure when GPU mapping verification or the common
/// rewrite fails (treated as a match failure by callers).
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  // Capture the parent block before the forall is erased by the rewrite.
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  // Pad with unit sizes so the launch configuration is always 3-D.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}
629c59465e1SNicolas Vasilache replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero, 63044e6318cSNicolas Vasilache rewriteResult.mappingSizes); 631c59465e1SNicolas Vasilache 63289bb0caeSGuray Ozen return DiagnosedSilenceableFailure::success(); 63389bb0caeSGuray Ozen } 63489bb0caeSGuray Ozen 63544e6318cSNicolas Vasilache DiagnosedSilenceableFailure 63644e6318cSNicolas Vasilache mlir::transform::gpu::findTopLevelForallOp(Operation *target, 63744e6318cSNicolas Vasilache scf::ForallOp &topLevelForallOp, 63844e6318cSNicolas Vasilache TransformOpInterface transformOp) { 63944e6318cSNicolas Vasilache auto walkResult = target->walk([&](scf::ForallOp forallOp) { 64044e6318cSNicolas Vasilache if (forallOp->getParentOfType<scf::ForallOp>()) 64144e6318cSNicolas Vasilache return WalkResult::advance(); 64244e6318cSNicolas Vasilache if (topLevelForallOp) 64344e6318cSNicolas Vasilache // TODO: Handle multiple forall if they are independent. 64444e6318cSNicolas Vasilache return WalkResult::interrupt(); 64544e6318cSNicolas Vasilache topLevelForallOp = forallOp; 64644e6318cSNicolas Vasilache return WalkResult::advance(); 64744e6318cSNicolas Vasilache }); 64844e6318cSNicolas Vasilache 64992f088d3SNicolas Vasilache if (walkResult.wasInterrupted() || !topLevelForallOp) 65044e6318cSNicolas Vasilache return transformOp.emitSilenceableError() 65144e6318cSNicolas Vasilache << "could not find a unique topLevel scf.forall"; 65244e6318cSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 65344e6318cSNicolas Vasilache } 65444e6318cSNicolas Vasilache 655c63d2b2cSMatthias Springer DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( 656c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 657c63d2b2cSMatthias Springer ApplyToEachResultList &results, transform::TransformState &state) { 65889bb0caeSGuray Ozen LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); 65989bb0caeSGuray Ozen auto transformOp = 
/// Map the unique top-level scf.forall under `target` to GPU blocks.
/// `target` must either already be a gpu.launch or the `generate_gpu_launch`
/// attribute must be set, in which case a fresh gpu.launch is created and the
/// forall is cloned into it. On success, pushes the (possibly new) gpu.launch
/// into `results`.
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  // When mapping into an existing launch, explicit 3-D grid dims are
  // required; when generating the launch they may be derived by the impl.
  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform require size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate gpu launch here and move the forall inside
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    // Clone the forall into the launch body and erase the original so the
    // handle keeps tracking the op that actually gets rewritten.
    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}
71589bb0caeSGuray Ozen diag = alterGpuLaunch(rewriter, gpuLaunch, 716768615bbSNicolas Vasilache cast<TransformOpInterface>(getOperation()), gridDims[0], 717768615bbSNicolas Vasilache gridDims[1], gridDims[2]); 71889bb0caeSGuray Ozen 7194b455a71SAlex Zinenko results.push_back(gpuLaunch); 72089bb0caeSGuray Ozen return diag; 72189bb0caeSGuray Ozen } 72289bb0caeSGuray Ozen 72392f088d3SNicolas Vasilache LogicalResult transform::MapForallToBlocks::verify() { 72492f088d3SNicolas Vasilache if (!getGridDims().empty() && getGridDims().size() != 3) { 72592f088d3SNicolas Vasilache return emitOpError() << "transform requires empty or size-3 grid_dims"; 72692f088d3SNicolas Vasilache } 72792f088d3SNicolas Vasilache return success(); 72892f088d3SNicolas Vasilache } 72992f088d3SNicolas Vasilache 73089bb0caeSGuray Ozen //===----------------------------------------------------------------------===// 7311cff4cbdSNicolas Vasilache // MapNestedForallToThreads 73289bb0caeSGuray Ozen //===----------------------------------------------------------------------===// 73389bb0caeSGuray Ozen 73444e6318cSNicolas Vasilache static DiagnosedSilenceableFailure checkMappingSpec( 73544e6318cSNicolas Vasilache std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, 73644e6318cSNicolas Vasilache ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes, 73744e6318cSNicolas Vasilache int factor, bool useLinearMapping = false) { 73844e6318cSNicolas Vasilache if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) { 73944e6318cSNicolas Vasilache auto diag = definiteFailureHelper( 74044e6318cSNicolas Vasilache transformOp, forallOp, 74144e6318cSNicolas Vasilache Twine("3-D mapping: size of threadIdx.x must be a multiple of ") + 74244e6318cSNicolas Vasilache std::to_string(factor)); 74344e6318cSNicolas Vasilache return diag; 74444e6318cSNicolas Vasilache } 74544e6318cSNicolas Vasilache if (computeProduct(numParallelIterations) * factor > 74644e6318cSNicolas 
Vasilache computeProduct(blockOrGridSizes)) { 74744e6318cSNicolas Vasilache auto diag = definiteFailureHelper( 74844e6318cSNicolas Vasilache transformOp, forallOp, 74992f088d3SNicolas Vasilache Twine("the number of required parallel resources (blocks or " 75092f088d3SNicolas Vasilache "threads) ") + 75144e6318cSNicolas Vasilache std::to_string(computeProduct(numParallelIterations) * factor) + 75244e6318cSNicolas Vasilache std::string(" overflows the number of available resources ") + 75344e6318cSNicolas Vasilache std::to_string(computeProduct(blockOrGridSizes))); 75444e6318cSNicolas Vasilache return diag; 75544e6318cSNicolas Vasilache } 75644e6318cSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 75744e6318cSNicolas Vasilache } 75844e6318cSNicolas Vasilache 75944e6318cSNicolas Vasilache static DiagnosedSilenceableFailure 76044e6318cSNicolas Vasilache getThreadIdBuilder(std::optional<TransformOpInterface> transformOp, 76144e6318cSNicolas Vasilache scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, 76244e6318cSNicolas Vasilache int64_t warpSize, GpuIdBuilder &gpuIdBuilder) { 76344e6318cSNicolas Vasilache auto mappingAttr = cast<DeviceMappingAttrInterface>( 76444e6318cSNicolas Vasilache forallOp.getMapping()->getValue().front()); 76544e6318cSNicolas Vasilache bool useLinearMapping = mappingAttr.isLinearMapping(); 76644e6318cSNicolas Vasilache 76744e6318cSNicolas Vasilache // Sanity checks that may result in runtime verification errors. 
/// Select the GpuIdBuilder matching the mapping attribute of `forallOp`
/// (thread, warp, or warpgroup; linear or 3-D) and write it to `gpuIdBuilder`.
/// Fails definitively when the forall is not normalized / statically sized or
/// when the mapping does not fit `blockSizes` (see checkMappingSpec).
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  // The first mapping attribute determines the kind for the whole forall.
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  // Threads consumed per mapped iteration: a warp consumes warpSize threads,
  // a warpgroup consumes kNumWarpsPerGroup warps.
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            // All supported mapping kinds are handled above; anything else is
            // a programming error, not user input.
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
/// Map a single scf.forall to GPU threads, warps, or warpgroups within a
/// block, using the id builder selected from the forall's mapping attribute.
/// When `syncAfterDistribute` is set, a gpu.barrier is inserted after the
/// rewritten body.
DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  // Capture the location before the forall op is erased by the rewrite.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}
/// Walk `target` and map every scf.forall found to GPU threads via
/// mapOneForallToThreadsImpl. `blockDims` must be exactly 3-D.
/// A definite failure on any forall aborts the walk; a silenceable failure is
/// tolerated (the walk advances past that forall, match-failure semantics).
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      // Skip the (now rewritten) region: nested foralls were handled by the
      // one-forall rewrite.
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}
875c59465e1SNicolas Vasilache replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero, 876c59465e1SNicolas Vasilache blockDims); 877c59465e1SNicolas Vasilache 878c59465e1SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 87989bb0caeSGuray Ozen } 88089bb0caeSGuray Ozen 8811cff4cbdSNicolas Vasilache DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( 882c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, Operation *target, 883c63d2b2cSMatthias Springer ApplyToEachResultList &results, TransformState &state) { 88489bb0caeSGuray Ozen LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); 88589bb0caeSGuray Ozen auto transformOp = cast<TransformOpInterface>(getOperation()); 88689bb0caeSGuray Ozen 887aafb52d7SNicolas Vasilache // Basic high-level verifications. 888aafb52d7SNicolas Vasilache if (!gpuLaunch) 889aafb52d7SNicolas Vasilache return emitSilenceableError() << "Given target is not a gpu.launch"; 89089bb0caeSGuray Ozen 891c59465e1SNicolas Vasilache // Mapping to block ids. 892c59465e1SNicolas Vasilache SmallVector<int64_t> blockDims{getBlockDims()}; 89389bb0caeSGuray Ozen DiagnosedSilenceableFailure diag = 8941a36588eSKazu Hirata checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt, 895768615bbSNicolas Vasilache blockDims[0], blockDims[1], blockDims[2]); 89689bb0caeSGuray Ozen if (diag.isSilenceableFailure()) { 897c59465e1SNicolas Vasilache diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large"; 89889bb0caeSGuray Ozen return diag; 89989bb0caeSGuray Ozen } 90089bb0caeSGuray Ozen 901c59465e1SNicolas Vasilache // Set the GPU launch configuration for the block dims early, this is not 902c59465e1SNicolas Vasilache // subject to IR inspection. 
9031a36588eSKazu Hirata diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, 904768615bbSNicolas Vasilache std::nullopt, std::nullopt, blockDims[0], blockDims[1], 905768615bbSNicolas Vasilache blockDims[2]); 90689bb0caeSGuray Ozen 907c59465e1SNicolas Vasilache rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); 908c59465e1SNicolas Vasilache diag = 909c59465e1SNicolas Vasilache mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims, 91044e6318cSNicolas Vasilache getWarpSize(), getSyncAfterDistribute()); 911c59465e1SNicolas Vasilache 912015cd84dSNicolas Vasilache results.push_back(gpuLaunch.getOperation()); 91389bb0caeSGuray Ozen return diag; 91489bb0caeSGuray Ozen } 91589bb0caeSGuray Ozen 91689bb0caeSGuray Ozen //===----------------------------------------------------------------------===// 91789bb0caeSGuray Ozen // Transform op registration 91889bb0caeSGuray Ozen //===----------------------------------------------------------------------===// 91989bb0caeSGuray Ozen 92089bb0caeSGuray Ozen namespace { 92189bb0caeSGuray Ozen /// Registers new ops and declares PDL as dependent dialect since the 92289bb0caeSGuray Ozen /// additional ops are using PDL types for operands and results. 
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    // Dialects whose ops the registered transforms may create.
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    // Register all transform ops generated from GPUTransformOps.td.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

// Public entry point: hook the GPU transform ops into a dialect registry.
void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}