190ecfa2aSNicolas Vasilache //===- Utils.cpp - Utils for GPU transform ops ----------------------------===// 290ecfa2aSNicolas Vasilache // 390ecfa2aSNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 490ecfa2aSNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 590ecfa2aSNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 690ecfa2aSNicolas Vasilache // 790ecfa2aSNicolas Vasilache //===----------------------------------------------------------------------===// 890ecfa2aSNicolas Vasilache 990ecfa2aSNicolas Vasilache #include "mlir/Dialect/GPU/TransformOps/Utils.h" 1090ecfa2aSNicolas Vasilache 1190ecfa2aSNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h" 1290ecfa2aSNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h" 1390ecfa2aSNicolas Vasilache #include "mlir/Dialect/Func/IR/FuncOps.h" 1490ecfa2aSNicolas Vasilache #include "mlir/Dialect/GPU/IR/GPUDialect.h" 1590ecfa2aSNicolas Vasilache #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" 1690ecfa2aSNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h" 1790ecfa2aSNicolas Vasilache #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" 1890ecfa2aSNicolas Vasilache #include "mlir/Dialect/SCF/IR/SCF.h" 1990ecfa2aSNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h" 205a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 2190ecfa2aSNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h" 2290ecfa2aSNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h" 2390ecfa2aSNicolas Vasilache #include "mlir/IR/AffineExpr.h" 2490ecfa2aSNicolas Vasilache #include "mlir/IR/Builders.h" 2590ecfa2aSNicolas Vasilache #include "mlir/IR/BuiltinAttributes.h" 2690ecfa2aSNicolas Vasilache #include "mlir/IR/IRMapping.h" 2790ecfa2aSNicolas Vasilache #include "mlir/IR/MLIRContext.h" 2890ecfa2aSNicolas Vasilache #include "mlir/IR/OpDefinition.h" 2990ecfa2aSNicolas Vasilache #include "mlir/IR/Value.h" 3090ecfa2aSNicolas Vasilache #include "mlir/IR/Visitors.h" 3190ecfa2aSNicolas Vasilache #include "mlir/Support/LLVM.h" 3290ecfa2aSNicolas Vasilache #include "llvm/ADT/STLExtras.h" 3390ecfa2aSNicolas Vasilache #include "llvm/ADT/SmallVector.h" 3490ecfa2aSNicolas Vasilache #include "llvm/ADT/TypeSwitch.h" 3590ecfa2aSNicolas Vasilache #include "llvm/Support/Debug.h" 3690ecfa2aSNicolas Vasilache 3790ecfa2aSNicolas Vasilache using namespace mlir; 3890ecfa2aSNicolas Vasilache using namespace mlir::gpu; 3990ecfa2aSNicolas Vasilache using namespace mlir::transform; 4090ecfa2aSNicolas Vasilache using namespace mlir::transform::gpu; 4190ecfa2aSNicolas Vasilache 4290ecfa2aSNicolas Vasilache #define DEBUG_TYPE "gpu-transforms" 4390ecfa2aSNicolas Vasilache 4490ecfa2aSNicolas Vasilache #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 45d8ed736cSMehdi Amini #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") 4690ecfa2aSNicolas Vasilache #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") 4790ecfa2aSNicolas Vasilache 4890ecfa2aSNicolas Vasilache /// Return a flattened thread id for the workgroup with given sizes. 4944e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp> 5044e6318cSNicolas Vasilache static Value buildLinearId(RewriterBase &rewriter, Location loc, 5144e6318cSNicolas Vasilache ArrayRef<OpFoldResult> originalBasisOfr) { 5290ecfa2aSNicolas Vasilache LLVM_DEBUG(llvm::interleaveComma( 5344e6318cSNicolas Vasilache originalBasisOfr, 5444e6318cSNicolas Vasilache DBGS() << "----buildLinearId with originalBasisOfr: "); 5590ecfa2aSNicolas Vasilache llvm::dbgs() << "\n"); 5644e6318cSNicolas Vasilache assert(originalBasisOfr.size() == 3 && "expected 3 sizes"); 5744e6318cSNicolas Vasilache IndexType indexType = rewriter.getIndexType(); 589a2a6a72SMehdi Amini AffineExpr tx, ty, tz, bdx, bdy; 5990ecfa2aSNicolas Vasilache bindDims(rewriter.getContext(), tx, ty, tz); 609a2a6a72SMehdi Amini bindSymbols(rewriter.getContext(), bdx, bdy); 6144e6318cSNicolas Vasilache SmallVector<OpFoldResult> vals{ 6244e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x) 6344e6318cSNicolas Vasilache .getResult(), 6444e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y) 6544e6318cSNicolas Vasilache .getResult(), 6644e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z) 6744e6318cSNicolas Vasilache .getResult(), 6844e6318cSNicolas Vasilache originalBasisOfr[0], originalBasisOfr[1]}; 6990ecfa2aSNicolas Vasilache OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 709a2a6a72SMehdi Amini rewriter, loc, tx + ty * bdx + tz * bdx * bdy, vals); 7190ecfa2aSNicolas Vasilache return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); 7290ecfa2aSNicolas Vasilache } 7390ecfa2aSNicolas Vasilache 7444e6318cSNicolas Vasilache /// Create a linear id builder that takes the `originalBasisOfr` and decompose 7544e6318cSNicolas Vasilache /// it in the basis of `forallMappingSizes`. The linear id builder returns an 7644e6318cSNicolas Vasilache /// n-D vector of ids for indexing and 1-D size + id for predicate generation. 7744e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp> 7844e6318cSNicolas Vasilache static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { 7944e6318cSNicolas Vasilache auto res = [multiplicity](RewriterBase &rewriter, Location loc, 8044e6318cSNicolas Vasilache ArrayRef<int64_t> forallMappingSizes, 8144e6318cSNicolas Vasilache ArrayRef<int64_t> originalBasis) { 8244e6318cSNicolas Vasilache SmallVector<OpFoldResult> originalBasisOfr = 8344e6318cSNicolas Vasilache getAsIndexOpFoldResult(rewriter.getContext(), originalBasis); 8444e6318cSNicolas Vasilache OpFoldResult linearId = 8544e6318cSNicolas Vasilache buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr); 8644e6318cSNicolas Vasilache // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in 8744e6318cSNicolas Vasilache // "row-major" order. 8844e6318cSNicolas Vasilache SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes)); 8944e6318cSNicolas Vasilache SmallVector<int64_t> strides = computeStrides(reverseBasisSizes); 9044e6318cSNicolas Vasilache AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); 9144e6318cSNicolas Vasilache OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply( 9244e6318cSNicolas Vasilache rewriter, loc, d0.floorDiv(multiplicity), {linearId}); 9344e6318cSNicolas Vasilache SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides); 9444e6318cSNicolas Vasilache SmallVector<Value> ids; 9544e6318cSNicolas Vasilache // Reverse back to be in [0 .. n] order. 9644e6318cSNicolas Vasilache for (AffineExpr e : llvm::reverse(delinearizingExprs)) { 9744e6318cSNicolas Vasilache ids.push_back( 9844e6318cSNicolas Vasilache affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId})); 9944e6318cSNicolas Vasilache } 10044e6318cSNicolas Vasilache 10144e6318cSNicolas Vasilache // clang-format off 10244e6318cSNicolas Vasilache LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes, 10344e6318cSNicolas Vasilache DBGS() << "--delinearization basis: "); 10444e6318cSNicolas Vasilache llvm::dbgs() << "\n"; 10544e6318cSNicolas Vasilache llvm::interleaveComma(strides, 10644e6318cSNicolas Vasilache DBGS() << "--delinearization strides: "); 10744e6318cSNicolas Vasilache llvm::dbgs() << "\n"; 10844e6318cSNicolas Vasilache llvm::interleaveComma(delinearizingExprs, 10944e6318cSNicolas Vasilache DBGS() << "--delinearization exprs: "); 11044e6318cSNicolas Vasilache llvm::dbgs() << "\n"; 11144e6318cSNicolas Vasilache llvm::interleaveComma(ids, DBGS() << "--ids: "); 11244e6318cSNicolas Vasilache llvm::dbgs() << "\n";); 11344e6318cSNicolas Vasilache // clang-format on 11444e6318cSNicolas Vasilache 11544e6318cSNicolas Vasilache // Return n-D ids for indexing and 1-D size + id for predicate generation. 11644e6318cSNicolas Vasilache return IdBuilderResult{ 11744e6318cSNicolas Vasilache /*mappingIdOps=*/ids, 11844e6318cSNicolas Vasilache /*availableMappingSizes=*/ 11944e6318cSNicolas Vasilache SmallVector<int64_t>{computeProduct(originalBasis)}, 12044e6318cSNicolas Vasilache // `forallMappingSizes` iterate in the scaled basis, they need to be 12144e6318cSNicolas Vasilache // scaled back into the original basis to provide tight 12244e6318cSNicolas Vasilache // activeMappingSizes quantities for predication. 12344e6318cSNicolas Vasilache /*activeMappingSizes=*/ 124*129f1001SKazu Hirata SmallVector<int64_t>{computeProduct(forallMappingSizes) * 125*129f1001SKazu Hirata multiplicity}, 126*129f1001SKazu Hirata /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}}; 12744e6318cSNicolas Vasilache }; 12844e6318cSNicolas Vasilache 12944e6318cSNicolas Vasilache return res; 13044e6318cSNicolas Vasilache } 13144e6318cSNicolas Vasilache 13244e6318cSNicolas Vasilache /// Create a simple 3-D id builder that takes the `originalBasisOfr` 13344e6318cSNicolas Vasilache /// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes 13444e6318cSNicolas Vasilache /// + ids for predicate generation. 13544e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp> 13644e6318cSNicolas Vasilache static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) { 13744e6318cSNicolas Vasilache auto res = [multiplicity](RewriterBase &rewriter, Location loc, 13844e6318cSNicolas Vasilache ArrayRef<int64_t> forallMappingSizes, 13944e6318cSNicolas Vasilache ArrayRef<int64_t> originalBasis) { 14044e6318cSNicolas Vasilache IndexType indexType = rewriter.getIndexType(); 14144e6318cSNicolas Vasilache SmallVector<Value> ids{ 14244e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x), 14344e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y), 14444e6318cSNicolas Vasilache rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)}; 14544e6318cSNicolas Vasilache // In the 3-D mapping case, scale the first dimension by the multiplicity. 14644e6318cSNicolas Vasilache SmallVector<Value> scaledIds = ids; 14744e6318cSNicolas Vasilache AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); 148*129f1001SKazu Hirata scaledIds[0] = cast<Value>(affine::makeComposedFoldedAffineApply( 149*129f1001SKazu Hirata rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})); 15044e6318cSNicolas Vasilache // In the 3-D mapping case, unscale the first dimension by the multiplicity. 1515262865aSKazu Hirata SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes); 15244e6318cSNicolas Vasilache forallMappingSizeInOriginalBasis[0] *= multiplicity; 15344e6318cSNicolas Vasilache return IdBuilderResult{ 15444e6318cSNicolas Vasilache /*mappingIdOps=*/scaledIds, 15544e6318cSNicolas Vasilache /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis}, 15644e6318cSNicolas Vasilache // `forallMappingSizes` iterate in the scaled basis, they need to be 15744e6318cSNicolas Vasilache // scaled back into the original basis to provide tight 15844e6318cSNicolas Vasilache // activeMappingSizes quantities for predication. 15944e6318cSNicolas Vasilache /*activeMappingSizes=*/ 16044e6318cSNicolas Vasilache SmallVector<int64_t>{forallMappingSizeInOriginalBasis}, 16144e6318cSNicolas Vasilache /*activeIdOps=*/ids}; 16244e6318cSNicolas Vasilache }; 16344e6318cSNicolas Vasilache return res; 16444e6318cSNicolas Vasilache } 16544e6318cSNicolas Vasilache 16690ecfa2aSNicolas Vasilache namespace mlir { 16790ecfa2aSNicolas Vasilache namespace transform { 16890ecfa2aSNicolas Vasilache namespace gpu { 16990ecfa2aSNicolas Vasilache 17044e6318cSNicolas Vasilache GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping, 17174cf9bcfSMehdi Amini const MappingIdBuilderFnType &fn) 17244e6318cSNicolas Vasilache : mappingAttributes(), idBuilder() { 17344e6318cSNicolas Vasilache if (useLinearMapping) { 17444e6318cSNicolas Vasilache for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0), 17544e6318cSNicolas Vasilache e = getMaxEnumValForMappingId(); 17644e6318cSNicolas Vasilache d <= e; ++d) 17744e6318cSNicolas Vasilache mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); 17844e6318cSNicolas Vasilache } else { 17944e6318cSNicolas Vasilache for (uint64_t d = static_cast<uint64_t>(MappingId::DimX), 18044e6318cSNicolas Vasilache e = static_cast<uint64_t>(MappingId::DimZ); 18144e6318cSNicolas Vasilache d <= e; ++d) 18244e6318cSNicolas Vasilache mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value())); 18344e6318cSNicolas Vasilache } 18490ecfa2aSNicolas Vasilache } 18590ecfa2aSNicolas Vasilache 18644e6318cSNicolas Vasilache GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping) 18744e6318cSNicolas Vasilache : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { 18844e6318cSNicolas Vasilache return GPUBlockMappingAttr::get(ctx, id); 18944e6318cSNicolas Vasilache }) { 19044e6318cSNicolas Vasilache idBuilder = useLinearMapping 19144e6318cSNicolas Vasilache ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1) 19244e6318cSNicolas Vasilache : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1); 19390ecfa2aSNicolas Vasilache } 19490ecfa2aSNicolas Vasilache 19544e6318cSNicolas Vasilache GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize, 19644e6318cSNicolas Vasilache bool useLinearMapping) 19744e6318cSNicolas Vasilache : GpuIdBuilder(ctx, useLinearMapping, 19844e6318cSNicolas Vasilache [](MLIRContext *ctx, MappingId id) { 19944e6318cSNicolas Vasilache return GPUWarpgroupMappingAttr::get(ctx, id); 20044e6318cSNicolas Vasilache }), 20144e6318cSNicolas Vasilache warpSize(warpSize) { 20244e6318cSNicolas Vasilache idBuilder = useLinearMapping 20344e6318cSNicolas Vasilache ? commonLinearIdBuilderFn<ThreadIdOp>( 20444e6318cSNicolas Vasilache /*multiplicity=*/kNumWarpsPerGroup * warpSize) 20544e6318cSNicolas Vasilache : common3DIdBuilderFn<ThreadIdOp>( 20644e6318cSNicolas Vasilache /*multiplicity=*/kNumWarpsPerGroup * warpSize); 20790ecfa2aSNicolas Vasilache } 20890ecfa2aSNicolas Vasilache 20944e6318cSNicolas Vasilache GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize, 21044e6318cSNicolas Vasilache bool useLinearMapping) 21144e6318cSNicolas Vasilache : GpuIdBuilder(ctx, useLinearMapping, 21244e6318cSNicolas Vasilache [](MLIRContext *ctx, MappingId id) { 21344e6318cSNicolas Vasilache return GPUWarpMappingAttr::get(ctx, id); 21444e6318cSNicolas Vasilache }), 21544e6318cSNicolas Vasilache warpSize(warpSize) { 21644e6318cSNicolas Vasilache idBuilder = 21744e6318cSNicolas Vasilache useLinearMapping 21844e6318cSNicolas Vasilache ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize) 21944e6318cSNicolas Vasilache : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize); 22044e6318cSNicolas Vasilache } 22190ecfa2aSNicolas Vasilache 22244e6318cSNicolas Vasilache GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping) 22344e6318cSNicolas Vasilache : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { 22444e6318cSNicolas Vasilache return GPUThreadMappingAttr::get(ctx, id); 22544e6318cSNicolas Vasilache }) { 22644e6318cSNicolas Vasilache idBuilder = useLinearMapping 22744e6318cSNicolas Vasilache ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1) 22844e6318cSNicolas Vasilache : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1); 22990ecfa2aSNicolas Vasilache } 23090ecfa2aSNicolas Vasilache 23190ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp, 23290ecfa2aSNicolas Vasilache std::optional<int64_t> gridDimX, 23390ecfa2aSNicolas Vasilache std::optional<int64_t> gridDimY, 23490ecfa2aSNicolas Vasilache std::optional<int64_t> gridDimZ, 23590ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimX, 23690ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimY, 23790ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimZ) { 23890ecfa2aSNicolas Vasilache 23990ecfa2aSNicolas Vasilache // TODO: pass a configuration object to set the limits properly. 24090ecfa2aSNicolas Vasilache static constexpr int maxTotalBlockdim = 1024; 24190ecfa2aSNicolas Vasilache static constexpr int maxBlockdimx = 1024; 24290ecfa2aSNicolas Vasilache static constexpr int maxBlockdimy = 1024; 24390ecfa2aSNicolas Vasilache static constexpr int maxBlockdimz = 64; 24490ecfa2aSNicolas Vasilache static constexpr int maxTotalGriddim = 2147483647; 24590ecfa2aSNicolas Vasilache static constexpr int maxGriddimx = 2147483647; 24690ecfa2aSNicolas Vasilache static constexpr int maxGriddimy = 65535; 24790ecfa2aSNicolas Vasilache static constexpr int maxGriddimz = 65535; 24890ecfa2aSNicolas Vasilache 24990ecfa2aSNicolas Vasilache if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) > 25090ecfa2aSNicolas Vasilache maxTotalBlockdim || 25190ecfa2aSNicolas Vasilache (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) > 25290ecfa2aSNicolas Vasilache maxTotalGriddim || 25390ecfa2aSNicolas Vasilache blockDimX.value_or(1) > maxBlockdimx || 25490ecfa2aSNicolas Vasilache blockDimY.value_or(1) > maxBlockdimy || 25590ecfa2aSNicolas Vasilache blockDimZ.value_or(1) > maxBlockdimz || 25690ecfa2aSNicolas Vasilache gridDimY.value_or(1) > maxGriddimy || 25790ecfa2aSNicolas Vasilache gridDimZ.value_or(1) > maxGriddimz || 25890ecfa2aSNicolas Vasilache gridDimX.value_or(1) > maxGriddimx) { 25990ecfa2aSNicolas Vasilache return transformOp.emitSilenceableError() 26090ecfa2aSNicolas Vasilache << "Trying to launch a GPU kernel with grid_dims = (" 26190ecfa2aSNicolas Vasilache << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", " 26290ecfa2aSNicolas Vasilache << gridDimZ.value_or(1) << ") block_dims = (" 26390ecfa2aSNicolas Vasilache << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", " 26490ecfa2aSNicolas Vasilache << blockDimZ.value_or(1) << "). It is larger than the limits."; 26590ecfa2aSNicolas Vasilache } 26690ecfa2aSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 26790ecfa2aSNicolas Vasilache } 26890ecfa2aSNicolas Vasilache 26990ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure createGpuLaunch( 27090ecfa2aSNicolas Vasilache RewriterBase &rewriter, Location loc, TransformOpInterface transformOp, 27190ecfa2aSNicolas Vasilache LaunchOp &launchOp, std::optional<int64_t> gridDimX, 27290ecfa2aSNicolas Vasilache std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, 27390ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, 27490ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimZ) { 27590ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure diag = 27690ecfa2aSNicolas Vasilache checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, 27790ecfa2aSNicolas Vasilache blockDimY, blockDimZ); 27890ecfa2aSNicolas Vasilache if (!diag.succeeded()) 27990ecfa2aSNicolas Vasilache return diag; 28090ecfa2aSNicolas Vasilache 28190ecfa2aSNicolas Vasilache auto createConst = [&](int dim) { 28290ecfa2aSNicolas Vasilache return rewriter.create<arith::ConstantIndexOp>(loc, dim); 28390ecfa2aSNicolas Vasilache }; 28490ecfa2aSNicolas Vasilache OpBuilder::InsertionGuard guard(rewriter); 28590ecfa2aSNicolas Vasilache Value one = createConst(1); 28690ecfa2aSNicolas Vasilache Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one; 28790ecfa2aSNicolas Vasilache Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one; 28890ecfa2aSNicolas Vasilache Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one; 28990ecfa2aSNicolas Vasilache Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one; 29090ecfa2aSNicolas Vasilache Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one; 29190ecfa2aSNicolas Vasilache Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one; 29290ecfa2aSNicolas Vasilache launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ, 29390ecfa2aSNicolas Vasilache blkSizeX, blkSizeY, blkSizeZ); 29490ecfa2aSNicolas Vasilache rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); 29590ecfa2aSNicolas Vasilache rewriter.create<TerminatorOp>(loc); 29690ecfa2aSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 29790ecfa2aSNicolas Vasilache } 29890ecfa2aSNicolas Vasilache 29990ecfa2aSNicolas Vasilache /// Alter kernel configuration of the given kernel. 30090ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure alterGpuLaunch( 30190ecfa2aSNicolas Vasilache RewriterBase &rewriter, LaunchOp gpuLaunch, 30290ecfa2aSNicolas Vasilache TransformOpInterface transformOp, std::optional<int64_t> gridDimX, 30390ecfa2aSNicolas Vasilache std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, 30490ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, 30590ecfa2aSNicolas Vasilache std::optional<int64_t> blockDimZ) { 30690ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure diag = 30790ecfa2aSNicolas Vasilache checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, 30890ecfa2aSNicolas Vasilache blockDimY, blockDimZ); 30990ecfa2aSNicolas Vasilache if (!diag.succeeded()) 31090ecfa2aSNicolas Vasilache return diag; 31190ecfa2aSNicolas Vasilache 31290ecfa2aSNicolas Vasilache KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues(); 31390ecfa2aSNicolas Vasilache OpBuilder::InsertionGuard guard(rewriter); 31490ecfa2aSNicolas Vasilache rewriter.setInsertionPointAfterValue(currentBlockdim.x); 31590ecfa2aSNicolas Vasilache auto createConstValue = [&](int dim) { 31690ecfa2aSNicolas Vasilache return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(), 31790ecfa2aSNicolas Vasilache dim); 31890ecfa2aSNicolas Vasilache }; 31990ecfa2aSNicolas Vasilache 32090ecfa2aSNicolas Vasilache if (gridDimX.has_value()) 32190ecfa2aSNicolas Vasilache gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value())); 32290ecfa2aSNicolas Vasilache if (gridDimY.has_value()) 32390ecfa2aSNicolas Vasilache gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value())); 32490ecfa2aSNicolas Vasilache if (gridDimZ.has_value()) 32590ecfa2aSNicolas Vasilache gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value())); 32690ecfa2aSNicolas Vasilache if (blockDimX.has_value()) 32790ecfa2aSNicolas Vasilache gpuLaunch.getBlockSizeXMutable().assign( 32890ecfa2aSNicolas Vasilache createConstValue(blockDimX.value())); 32990ecfa2aSNicolas Vasilache if (blockDimY.has_value()) 33090ecfa2aSNicolas Vasilache gpuLaunch.getBlockSizeYMutable().assign( 33190ecfa2aSNicolas Vasilache createConstValue(blockDimY.value())); 33290ecfa2aSNicolas Vasilache if (blockDimZ.has_value()) 33390ecfa2aSNicolas Vasilache gpuLaunch.getBlockSizeZMutable().assign( 33490ecfa2aSNicolas Vasilache createConstValue(blockDimZ.value())); 33590ecfa2aSNicolas Vasilache return DiagnosedSilenceableFailure::success(); 33690ecfa2aSNicolas Vasilache } 33790ecfa2aSNicolas Vasilache 33890ecfa2aSNicolas Vasilache } // namespace gpu 33990ecfa2aSNicolas Vasilache } // namespace transform 34090ecfa2aSNicolas Vasilache } // namespace mlir 341