xref: /llvm-project/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h (revision 5a9bdd85ee4d8527e2cedf44f3ce26ff414f9b6a)
1 //===- Utils.h - Utils for GPU transform ops --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_GPU_TRANSFORMOPS_UTILS_H
10 #define MLIR_DIALECT_GPU_TRANSFORMOPS_UTILS_H
11 
12 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
13 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
14 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 namespace mlir {
namespace gpu {
// Forward declarations of GPU dialect entities named by the declarations in
// this header: LaunchOp appears in createGpuLaunch / alterGpuLaunch and
// MappingId in GpuIdBuilder::MappingIdBuilderFnType. GPUOp is not referenced
// anywhere else in this header.
class GPUOp;
class LaunchOp;
// Opaque declaration; the enumerators are defined elsewhere.
enum class MappingId : uint64_t;
} // namespace gpu
namespace scf {
// Forward declaration for the out-parameter of findTopLevelForallOp.
class ForallOp;
} // namespace scf
27 namespace transform {
28 namespace gpu {
29 
30 /// Helper type for functions that generate ids for the mapping of a scf.forall.
31 /// Operates on both 1) an "original" basis that represents the individual
32 /// thread and block ids and 2) a "scaled" basis that represents grouped ids
33 /// (e.g. block clusters, warpgroups and warps).
34 /// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
35 /// a division by 32 occurs).
36 /// The predication is in the "original" basis using the "active" quantities
37 /// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
38 struct IdBuilderResult {
39   // Ops used to replace the forall induction variables.
40   SmallVector<Value> mappingIdOps;
41   // Available mapping sizes used to predicate the forall body when they are
42   // larger than the predicate mapping sizes.
43   SmallVector<int64_t> availableMappingSizes;
44   // Actual mapping sizes used to predicate the forall body when they are
45   // smaller than the available mapping sizes.
46   SmallVector<int64_t> activeMappingSizes;
47   // Ops used to predicate the forall body when activeMappingSizes is smaller
48   // than the available mapping sizes.
49   SmallVector<Value> activeIdOps;
50 };
51 
52 /// Common gpu id builder type, allows the configuration of lowering for various
53 /// mapping schemes. Takes:
54 ///   - A rewriter with insertion point set before the forall op to rewrite.
55 ///   - The loc of the forall op to rewrite.
56 ///   - A list of positive integers carrying the mapping sizes for the current
57 ///     forall op to rewrite.
58 using GpuIdBuilderFnType = std::function<IdBuilderResult(
59     RewriterBase &, Location, ArrayRef<int64_t>, ArrayRef<int64_t>)>;
60 
61 /// Helper struct for configuring the rewrite of mapped scf.forall ops to
62 /// various gpu id configurations.
63 struct GpuIdBuilder {
64   using MappingIdBuilderFnType = std::function<DeviceMappingAttrInterface(
65       MLIRContext *, mlir::gpu::MappingId)>;
66 
67   GpuIdBuilder() = default;
68   GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
69                const MappingIdBuilderFnType &builder);
70 
71   /// The mapping attributes targeted by this generator.
72   SmallVector<DeviceMappingAttrInterface> mappingAttributes;
73 
74   /// The constructor that builds the concrete IR for mapping ids.
75   GpuIdBuilderFnType idBuilder;
76 };
77 
78 /// Builder for gpu::BlockIdOps used to map scf.forall to blocks.
79 /// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
80 /// used for indexing rewrites as well as 3D sizes for predicate generation.
81 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
82 /// used for indexing rewrites as well as 1D sizes for predicate generation.
83 struct GpuBlockIdBuilder : public GpuIdBuilder {
84   GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
85 };
86 
87 /// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
88 /// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
89 /// used for indexing rewrites as well as 3D sizes for predicate generation.
90 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
91 /// used for indexing rewrites as well as 1D sizes for predicate generation.
92 struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
93   GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
94                         bool useLinearMapping = false);
95   int64_t warpSize = 32;
96   /// In the future this may be configured by the transformation.
97   static constexpr int64_t kNumWarpsPerGroup = 4;
98 };
99 
100 /// Builder for warp ids used to map scf.forall to reindexed warps.
101 /// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
102 /// used for indexing rewrites as well as 3D sizes for predicate generation.
103 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
104 /// used for indexing rewrites as well as 1D sizes for predicate generation.
105 struct GpuWarpIdBuilder : public GpuIdBuilder {
106   GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
107                    bool useLinearMapping = false);
108   int64_t warpSize = 32;
109 };
110 
111 /// Builder for warp ids used to map scf.forall to reindexed threads.
112 /// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
113 /// used for indexing rewrites as well as 3D sizes for predicate generation.
114 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
115 /// used for indexing rewrites as well as 1D sizes for predicate generation.
116 struct GpuThreadIdBuilder : public GpuIdBuilder {
117   GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
118 };
119 
120 /// Determine if the size of the kernel configuration is supported by the
121 /// GPU architecture being used.
122 /// TODO this is currently hardwired to CUDA, parameterize and generalize.
123 DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
124                                            std::optional<int64_t> gridDimX,
125                                            std::optional<int64_t> gridDimY,
126                                            std::optional<int64_t> gridDimZ,
127                                            std::optional<int64_t> blockDimX,
128                                            std::optional<int64_t> blockDimY,
129                                            std::optional<int64_t> blockDimZ);
130 
131 /// Create an empty-body gpu::LaunchOp using the provided kernel settings
132 /// and put a terminator within.
133 DiagnosedSilenceableFailure
134 createGpuLaunch(RewriterBase &rewriter, Location loc,
135                 TransformOpInterface transformOp, mlir::gpu::LaunchOp &launchOp,
136                 std::optional<int64_t> gridDimX = std::nullopt,
137                 std::optional<int64_t> gridDimY = std::nullopt,
138                 std::optional<int64_t> gridDimZ = std::nullopt,
139                 std::optional<int64_t> blockDimX = std::nullopt,
140                 std::optional<int64_t> blockDimY = std::nullopt,
141                 std::optional<int64_t> blockDimZ = std::nullopt);
142 
143 /// Alter kernel configuration of the given kernel.
144 DiagnosedSilenceableFailure
145 alterGpuLaunch(RewriterBase &rewriter, mlir::gpu::LaunchOp gpuLaunch,
146                TransformOpInterface transformOp,
147                std::optional<int64_t> gridDimX = std::nullopt,
148                std::optional<int64_t> gridDimY = std::nullopt,
149                std::optional<int64_t> gridDimZ = std::nullopt,
150                std::optional<int64_t> blockDimX = std::nullopt,
151                std::optional<int64_t> blockDimY = std::nullopt,
152                std::optional<int64_t> blockDimZ = std::nullopt);
153 
154 /// Find the unique top level scf::ForallOp within a given target op.
155 DiagnosedSilenceableFailure
156 findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
157                      TransformOpInterface transformOp);
158 
159 } // namespace gpu
160 } // namespace transform
161 } // namespace mlir
162 
163 #endif // MLIR_DIALECT_GPU_TRANSFORMOPS_UTILS_H
164