xref: /llvm-project/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
190ecfa2aSNicolas Vasilache //===- Utils.cpp - Utils for GPU transform ops ----------------------------===//
290ecfa2aSNicolas Vasilache //
390ecfa2aSNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
490ecfa2aSNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
590ecfa2aSNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
690ecfa2aSNicolas Vasilache //
790ecfa2aSNicolas Vasilache //===----------------------------------------------------------------------===//
890ecfa2aSNicolas Vasilache 
#include "mlir/Dialect/GPU/TransformOps/Utils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
3690ecfa2aSNicolas Vasilache 
3790ecfa2aSNicolas Vasilache using namespace mlir;
3890ecfa2aSNicolas Vasilache using namespace mlir::gpu;
3990ecfa2aSNicolas Vasilache using namespace mlir::transform;
4090ecfa2aSNicolas Vasilache using namespace mlir::transform::gpu;
4190ecfa2aSNicolas Vasilache 
#define DEBUG_TYPE "gpu-transforms"

// DBGS() starts a debug-stream line tagged with DEBUG_TYPE; LDBG(X) prints X
// followed by a newline when LLVM_DEBUG output is enabled for DEBUG_TYPE.
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
4790ecfa2aSNicolas Vasilache 
4890ecfa2aSNicolas Vasilache /// Return a flattened thread id for the workgroup with given sizes.
4944e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp>
5044e6318cSNicolas Vasilache static Value buildLinearId(RewriterBase &rewriter, Location loc,
5144e6318cSNicolas Vasilache                            ArrayRef<OpFoldResult> originalBasisOfr) {
5290ecfa2aSNicolas Vasilache   LLVM_DEBUG(llvm::interleaveComma(
5344e6318cSNicolas Vasilache                  originalBasisOfr,
5444e6318cSNicolas Vasilache                  DBGS() << "----buildLinearId with originalBasisOfr:  ");
5590ecfa2aSNicolas Vasilache              llvm::dbgs() << "\n");
5644e6318cSNicolas Vasilache   assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
5744e6318cSNicolas Vasilache   IndexType indexType = rewriter.getIndexType();
589a2a6a72SMehdi Amini   AffineExpr tx, ty, tz, bdx, bdy;
5990ecfa2aSNicolas Vasilache   bindDims(rewriter.getContext(), tx, ty, tz);
609a2a6a72SMehdi Amini   bindSymbols(rewriter.getContext(), bdx, bdy);
6144e6318cSNicolas Vasilache   SmallVector<OpFoldResult> vals{
6244e6318cSNicolas Vasilache       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x)
6344e6318cSNicolas Vasilache           .getResult(),
6444e6318cSNicolas Vasilache       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y)
6544e6318cSNicolas Vasilache           .getResult(),
6644e6318cSNicolas Vasilache       rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)
6744e6318cSNicolas Vasilache           .getResult(),
6844e6318cSNicolas Vasilache       originalBasisOfr[0], originalBasisOfr[1]};
6990ecfa2aSNicolas Vasilache   OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
709a2a6a72SMehdi Amini       rewriter, loc, tx + ty * bdx + tz * bdx * bdy, vals);
7190ecfa2aSNicolas Vasilache   return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
7290ecfa2aSNicolas Vasilache }
7390ecfa2aSNicolas Vasilache 
7444e6318cSNicolas Vasilache /// Create a linear id builder that takes the `originalBasisOfr` and decompose
7544e6318cSNicolas Vasilache /// it in the basis of `forallMappingSizes`. The linear id builder returns an
7644e6318cSNicolas Vasilache /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
7744e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp>
7844e6318cSNicolas Vasilache static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
7944e6318cSNicolas Vasilache   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
8044e6318cSNicolas Vasilache                             ArrayRef<int64_t> forallMappingSizes,
8144e6318cSNicolas Vasilache                             ArrayRef<int64_t> originalBasis) {
8244e6318cSNicolas Vasilache     SmallVector<OpFoldResult> originalBasisOfr =
8344e6318cSNicolas Vasilache         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
8444e6318cSNicolas Vasilache     OpFoldResult linearId =
8544e6318cSNicolas Vasilache         buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
8644e6318cSNicolas Vasilache     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
8744e6318cSNicolas Vasilache     // "row-major" order.
8844e6318cSNicolas Vasilache     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
8944e6318cSNicolas Vasilache     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
9044e6318cSNicolas Vasilache     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
9144e6318cSNicolas Vasilache     OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
9244e6318cSNicolas Vasilache         rewriter, loc, d0.floorDiv(multiplicity), {linearId});
9344e6318cSNicolas Vasilache     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
9444e6318cSNicolas Vasilache     SmallVector<Value> ids;
9544e6318cSNicolas Vasilache     // Reverse back to be in [0 .. n] order.
9644e6318cSNicolas Vasilache     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
9744e6318cSNicolas Vasilache       ids.push_back(
9844e6318cSNicolas Vasilache           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
9944e6318cSNicolas Vasilache     }
10044e6318cSNicolas Vasilache 
10144e6318cSNicolas Vasilache     // clang-format off
10244e6318cSNicolas Vasilache       LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
10344e6318cSNicolas Vasilache                                        DBGS() << "--delinearization basis: ");
10444e6318cSNicolas Vasilache                  llvm::dbgs() << "\n";
10544e6318cSNicolas Vasilache                  llvm::interleaveComma(strides,
10644e6318cSNicolas Vasilache                                        DBGS() << "--delinearization strides: ");
10744e6318cSNicolas Vasilache                  llvm::dbgs() << "\n";
10844e6318cSNicolas Vasilache                  llvm::interleaveComma(delinearizingExprs,
10944e6318cSNicolas Vasilache                                        DBGS() << "--delinearization exprs: ");
11044e6318cSNicolas Vasilache                  llvm::dbgs() << "\n";
11144e6318cSNicolas Vasilache                  llvm::interleaveComma(ids, DBGS() << "--ids: ");
11244e6318cSNicolas Vasilache                  llvm::dbgs() << "\n";);
11344e6318cSNicolas Vasilache     // clang-format on
11444e6318cSNicolas Vasilache 
11544e6318cSNicolas Vasilache     // Return n-D ids for indexing and 1-D size + id for predicate generation.
11644e6318cSNicolas Vasilache       return IdBuilderResult{
11744e6318cSNicolas Vasilache           /*mappingIdOps=*/ids,
11844e6318cSNicolas Vasilache           /*availableMappingSizes=*/
11944e6318cSNicolas Vasilache           SmallVector<int64_t>{computeProduct(originalBasis)},
12044e6318cSNicolas Vasilache           // `forallMappingSizes` iterate in the scaled basis, they need to be
12144e6318cSNicolas Vasilache           // scaled back into the original basis to provide tight
12244e6318cSNicolas Vasilache           // activeMappingSizes quantities for predication.
12344e6318cSNicolas Vasilache           /*activeMappingSizes=*/
124*129f1001SKazu Hirata           SmallVector<int64_t>{computeProduct(forallMappingSizes) *
125*129f1001SKazu Hirata                                multiplicity},
126*129f1001SKazu Hirata           /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
12744e6318cSNicolas Vasilache   };
12844e6318cSNicolas Vasilache 
12944e6318cSNicolas Vasilache   return res;
13044e6318cSNicolas Vasilache }
13144e6318cSNicolas Vasilache 
13244e6318cSNicolas Vasilache /// Create a simple 3-D id builder that takes the `originalBasisOfr`
13344e6318cSNicolas Vasilache /// The 3-D id builder returns a 3-D vector of ids for indexing and 3-D sizes
13444e6318cSNicolas Vasilache /// + ids for predicate generation.
13544e6318cSNicolas Vasilache template <typename ThreadOrBlockIdOp>
13644e6318cSNicolas Vasilache static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
13744e6318cSNicolas Vasilache   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
13844e6318cSNicolas Vasilache                             ArrayRef<int64_t> forallMappingSizes,
13944e6318cSNicolas Vasilache                             ArrayRef<int64_t> originalBasis) {
14044e6318cSNicolas Vasilache     IndexType indexType = rewriter.getIndexType();
14144e6318cSNicolas Vasilache     SmallVector<Value> ids{
14244e6318cSNicolas Vasilache         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::x),
14344e6318cSNicolas Vasilache         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::y),
14444e6318cSNicolas Vasilache         rewriter.create<ThreadOrBlockIdOp>(loc, indexType, Dimension::z)};
14544e6318cSNicolas Vasilache     // In the 3-D mapping case, scale the first dimension by the multiplicity.
14644e6318cSNicolas Vasilache     SmallVector<Value> scaledIds = ids;
14744e6318cSNicolas Vasilache     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
148*129f1001SKazu Hirata     scaledIds[0] = cast<Value>(affine::makeComposedFoldedAffineApply(
149*129f1001SKazu Hirata         rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]}));
15044e6318cSNicolas Vasilache     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
1515262865aSKazu Hirata     SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
15244e6318cSNicolas Vasilache     forallMappingSizeInOriginalBasis[0] *= multiplicity;
15344e6318cSNicolas Vasilache     return IdBuilderResult{
15444e6318cSNicolas Vasilache         /*mappingIdOps=*/scaledIds,
15544e6318cSNicolas Vasilache         /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
15644e6318cSNicolas Vasilache         // `forallMappingSizes` iterate in the scaled basis, they need to be
15744e6318cSNicolas Vasilache         // scaled back into the original basis to provide tight
15844e6318cSNicolas Vasilache         // activeMappingSizes quantities for predication.
15944e6318cSNicolas Vasilache         /*activeMappingSizes=*/
16044e6318cSNicolas Vasilache         SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
16144e6318cSNicolas Vasilache         /*activeIdOps=*/ids};
16244e6318cSNicolas Vasilache   };
16344e6318cSNicolas Vasilache   return res;
16444e6318cSNicolas Vasilache }
16544e6318cSNicolas Vasilache 
16690ecfa2aSNicolas Vasilache namespace mlir {
16790ecfa2aSNicolas Vasilache namespace transform {
16890ecfa2aSNicolas Vasilache namespace gpu {
16990ecfa2aSNicolas Vasilache 
17044e6318cSNicolas Vasilache GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
17174cf9bcfSMehdi Amini                            const MappingIdBuilderFnType &fn)
17244e6318cSNicolas Vasilache     : mappingAttributes(), idBuilder() {
17344e6318cSNicolas Vasilache   if (useLinearMapping) {
17444e6318cSNicolas Vasilache     for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0),
17544e6318cSNicolas Vasilache                   e = getMaxEnumValForMappingId();
17644e6318cSNicolas Vasilache          d <= e; ++d)
17744e6318cSNicolas Vasilache       mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
17844e6318cSNicolas Vasilache   } else {
17944e6318cSNicolas Vasilache     for (uint64_t d = static_cast<uint64_t>(MappingId::DimX),
18044e6318cSNicolas Vasilache                   e = static_cast<uint64_t>(MappingId::DimZ);
18144e6318cSNicolas Vasilache          d <= e; ++d)
18244e6318cSNicolas Vasilache       mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
18344e6318cSNicolas Vasilache   }
18490ecfa2aSNicolas Vasilache }
18590ecfa2aSNicolas Vasilache 
18644e6318cSNicolas Vasilache GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
18744e6318cSNicolas Vasilache     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
18844e6318cSNicolas Vasilache         return GPUBlockMappingAttr::get(ctx, id);
18944e6318cSNicolas Vasilache       }) {
19044e6318cSNicolas Vasilache   idBuilder = useLinearMapping
19144e6318cSNicolas Vasilache                   ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1)
19244e6318cSNicolas Vasilache                   : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1);
19390ecfa2aSNicolas Vasilache }
19490ecfa2aSNicolas Vasilache 
19544e6318cSNicolas Vasilache GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
19644e6318cSNicolas Vasilache                                              bool useLinearMapping)
19744e6318cSNicolas Vasilache     : GpuIdBuilder(ctx, useLinearMapping,
19844e6318cSNicolas Vasilache                    [](MLIRContext *ctx, MappingId id) {
19944e6318cSNicolas Vasilache                      return GPUWarpgroupMappingAttr::get(ctx, id);
20044e6318cSNicolas Vasilache                    }),
20144e6318cSNicolas Vasilache       warpSize(warpSize) {
20244e6318cSNicolas Vasilache   idBuilder = useLinearMapping
20344e6318cSNicolas Vasilache                   ? commonLinearIdBuilderFn<ThreadIdOp>(
20444e6318cSNicolas Vasilache                         /*multiplicity=*/kNumWarpsPerGroup * warpSize)
20544e6318cSNicolas Vasilache                   : common3DIdBuilderFn<ThreadIdOp>(
20644e6318cSNicolas Vasilache                         /*multiplicity=*/kNumWarpsPerGroup * warpSize);
20790ecfa2aSNicolas Vasilache }
20890ecfa2aSNicolas Vasilache 
20944e6318cSNicolas Vasilache GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
21044e6318cSNicolas Vasilache                                    bool useLinearMapping)
21144e6318cSNicolas Vasilache     : GpuIdBuilder(ctx, useLinearMapping,
21244e6318cSNicolas Vasilache                    [](MLIRContext *ctx, MappingId id) {
21344e6318cSNicolas Vasilache                      return GPUWarpMappingAttr::get(ctx, id);
21444e6318cSNicolas Vasilache                    }),
21544e6318cSNicolas Vasilache       warpSize(warpSize) {
21644e6318cSNicolas Vasilache   idBuilder =
21744e6318cSNicolas Vasilache       useLinearMapping
21844e6318cSNicolas Vasilache           ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize)
21944e6318cSNicolas Vasilache           : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize);
22044e6318cSNicolas Vasilache }
22190ecfa2aSNicolas Vasilache 
22244e6318cSNicolas Vasilache GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
22344e6318cSNicolas Vasilache     : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
22444e6318cSNicolas Vasilache         return GPUThreadMappingAttr::get(ctx, id);
22544e6318cSNicolas Vasilache       }) {
22644e6318cSNicolas Vasilache   idBuilder = useLinearMapping
22744e6318cSNicolas Vasilache                   ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1)
22844e6318cSNicolas Vasilache                   : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1);
22990ecfa2aSNicolas Vasilache }
23090ecfa2aSNicolas Vasilache 
23190ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
23290ecfa2aSNicolas Vasilache                                            std::optional<int64_t> gridDimX,
23390ecfa2aSNicolas Vasilache                                            std::optional<int64_t> gridDimY,
23490ecfa2aSNicolas Vasilache                                            std::optional<int64_t> gridDimZ,
23590ecfa2aSNicolas Vasilache                                            std::optional<int64_t> blockDimX,
23690ecfa2aSNicolas Vasilache                                            std::optional<int64_t> blockDimY,
23790ecfa2aSNicolas Vasilache                                            std::optional<int64_t> blockDimZ) {
23890ecfa2aSNicolas Vasilache 
23990ecfa2aSNicolas Vasilache   // TODO: pass a configuration object to set the limits properly.
24090ecfa2aSNicolas Vasilache   static constexpr int maxTotalBlockdim = 1024;
24190ecfa2aSNicolas Vasilache   static constexpr int maxBlockdimx = 1024;
24290ecfa2aSNicolas Vasilache   static constexpr int maxBlockdimy = 1024;
24390ecfa2aSNicolas Vasilache   static constexpr int maxBlockdimz = 64;
24490ecfa2aSNicolas Vasilache   static constexpr int maxTotalGriddim = 2147483647;
24590ecfa2aSNicolas Vasilache   static constexpr int maxGriddimx = 2147483647;
24690ecfa2aSNicolas Vasilache   static constexpr int maxGriddimy = 65535;
24790ecfa2aSNicolas Vasilache   static constexpr int maxGriddimz = 65535;
24890ecfa2aSNicolas Vasilache 
24990ecfa2aSNicolas Vasilache   if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
25090ecfa2aSNicolas Vasilache           maxTotalBlockdim ||
25190ecfa2aSNicolas Vasilache       (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
25290ecfa2aSNicolas Vasilache           maxTotalGriddim ||
25390ecfa2aSNicolas Vasilache       blockDimX.value_or(1) > maxBlockdimx ||
25490ecfa2aSNicolas Vasilache       blockDimY.value_or(1) > maxBlockdimy ||
25590ecfa2aSNicolas Vasilache       blockDimZ.value_or(1) > maxBlockdimz ||
25690ecfa2aSNicolas Vasilache       gridDimY.value_or(1) > maxGriddimy ||
25790ecfa2aSNicolas Vasilache       gridDimZ.value_or(1) > maxGriddimz ||
25890ecfa2aSNicolas Vasilache       gridDimX.value_or(1) > maxGriddimx) {
25990ecfa2aSNicolas Vasilache     return transformOp.emitSilenceableError()
26090ecfa2aSNicolas Vasilache            << "Trying to launch a GPU kernel with grid_dims = ("
26190ecfa2aSNicolas Vasilache            << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
26290ecfa2aSNicolas Vasilache            << gridDimZ.value_or(1) << ") block_dims = ("
26390ecfa2aSNicolas Vasilache            << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", "
26490ecfa2aSNicolas Vasilache            << blockDimZ.value_or(1) << "). It is larger than the limits.";
26590ecfa2aSNicolas Vasilache   }
26690ecfa2aSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
26790ecfa2aSNicolas Vasilache }
26890ecfa2aSNicolas Vasilache 
26990ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure createGpuLaunch(
27090ecfa2aSNicolas Vasilache     RewriterBase &rewriter, Location loc, TransformOpInterface transformOp,
27190ecfa2aSNicolas Vasilache     LaunchOp &launchOp, std::optional<int64_t> gridDimX,
27290ecfa2aSNicolas Vasilache     std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
27390ecfa2aSNicolas Vasilache     std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
27490ecfa2aSNicolas Vasilache     std::optional<int64_t> blockDimZ) {
27590ecfa2aSNicolas Vasilache   DiagnosedSilenceableFailure diag =
27690ecfa2aSNicolas Vasilache       checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
27790ecfa2aSNicolas Vasilache                      blockDimY, blockDimZ);
27890ecfa2aSNicolas Vasilache   if (!diag.succeeded())
27990ecfa2aSNicolas Vasilache     return diag;
28090ecfa2aSNicolas Vasilache 
28190ecfa2aSNicolas Vasilache   auto createConst = [&](int dim) {
28290ecfa2aSNicolas Vasilache     return rewriter.create<arith::ConstantIndexOp>(loc, dim);
28390ecfa2aSNicolas Vasilache   };
28490ecfa2aSNicolas Vasilache   OpBuilder::InsertionGuard guard(rewriter);
28590ecfa2aSNicolas Vasilache   Value one = createConst(1);
28690ecfa2aSNicolas Vasilache   Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
28790ecfa2aSNicolas Vasilache   Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
28890ecfa2aSNicolas Vasilache   Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
28990ecfa2aSNicolas Vasilache   Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
29090ecfa2aSNicolas Vasilache   Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
29190ecfa2aSNicolas Vasilache   Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
29290ecfa2aSNicolas Vasilache   launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
29390ecfa2aSNicolas Vasilache                                        blkSizeX, blkSizeY, blkSizeZ);
29490ecfa2aSNicolas Vasilache   rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
29590ecfa2aSNicolas Vasilache   rewriter.create<TerminatorOp>(loc);
29690ecfa2aSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
29790ecfa2aSNicolas Vasilache }
29890ecfa2aSNicolas Vasilache 
29990ecfa2aSNicolas Vasilache /// Alter kernel configuration of the given kernel.
30090ecfa2aSNicolas Vasilache DiagnosedSilenceableFailure alterGpuLaunch(
30190ecfa2aSNicolas Vasilache     RewriterBase &rewriter, LaunchOp gpuLaunch,
30290ecfa2aSNicolas Vasilache     TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
30390ecfa2aSNicolas Vasilache     std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
30490ecfa2aSNicolas Vasilache     std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
30590ecfa2aSNicolas Vasilache     std::optional<int64_t> blockDimZ) {
30690ecfa2aSNicolas Vasilache   DiagnosedSilenceableFailure diag =
30790ecfa2aSNicolas Vasilache       checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
30890ecfa2aSNicolas Vasilache                      blockDimY, blockDimZ);
30990ecfa2aSNicolas Vasilache   if (!diag.succeeded())
31090ecfa2aSNicolas Vasilache     return diag;
31190ecfa2aSNicolas Vasilache 
31290ecfa2aSNicolas Vasilache   KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
31390ecfa2aSNicolas Vasilache   OpBuilder::InsertionGuard guard(rewriter);
31490ecfa2aSNicolas Vasilache   rewriter.setInsertionPointAfterValue(currentBlockdim.x);
31590ecfa2aSNicolas Vasilache   auto createConstValue = [&](int dim) {
31690ecfa2aSNicolas Vasilache     return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
31790ecfa2aSNicolas Vasilache                                                    dim);
31890ecfa2aSNicolas Vasilache   };
31990ecfa2aSNicolas Vasilache 
32090ecfa2aSNicolas Vasilache   if (gridDimX.has_value())
32190ecfa2aSNicolas Vasilache     gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
32290ecfa2aSNicolas Vasilache   if (gridDimY.has_value())
32390ecfa2aSNicolas Vasilache     gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
32490ecfa2aSNicolas Vasilache   if (gridDimZ.has_value())
32590ecfa2aSNicolas Vasilache     gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
32690ecfa2aSNicolas Vasilache   if (blockDimX.has_value())
32790ecfa2aSNicolas Vasilache     gpuLaunch.getBlockSizeXMutable().assign(
32890ecfa2aSNicolas Vasilache         createConstValue(blockDimX.value()));
32990ecfa2aSNicolas Vasilache   if (blockDimY.has_value())
33090ecfa2aSNicolas Vasilache     gpuLaunch.getBlockSizeYMutable().assign(
33190ecfa2aSNicolas Vasilache         createConstValue(blockDimY.value()));
33290ecfa2aSNicolas Vasilache   if (blockDimZ.has_value())
33390ecfa2aSNicolas Vasilache     gpuLaunch.getBlockSizeZMutable().assign(
33490ecfa2aSNicolas Vasilache         createConstValue(blockDimZ.value()));
33590ecfa2aSNicolas Vasilache   return DiagnosedSilenceableFailure::success();
33690ecfa2aSNicolas Vasilache }
33790ecfa2aSNicolas Vasilache 
33890ecfa2aSNicolas Vasilache } // namespace gpu
33990ecfa2aSNicolas Vasilache } // namespace transform
34090ecfa2aSNicolas Vasilache } // namespace mlir
341