1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Interfaces/FunctionInterfaces.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.h"
13 #include "llvm/ADT/STLForwardCompat.h"
14 #include "llvm/Support/ErrorHandling.h"
15 #include "llvm/Support/MathExtras.h"
16 #include <optional>
17
18 using namespace mlir;
19 using namespace mlir::gpu;
20
// Default upper limit for grid and block dimensions: the maximum grid and
// block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim{std::numeric_limits<uint32_t>::max()};
// Default upper limit for the size of a cluster along one dimension.
static constexpr uint64_t kMaxClusterDim{8};
// Default upper limit for the subgroup size: subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize{128};
27
getIndexRange(uint64_t umin,uint64_t umax)28 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
29 unsigned width = IndexType::kInternalStorageBitWidth;
30 return ConstantIntRanges::fromUnsigned(APInt(width, umin),
31 APInt(width, umax));
32 }
33
namespace {
/// Selects which set of launch dimensions a query refers to: the block
/// (workgroup) size or the grid size.
enum class LaunchDims : uint32_t {
  Block, // == 0
  Grid,  // == 1
};
} // namespace
37
// Note: the following describes `getKnownLaunchDim`, defined further below.
// If the operation `op` is in a context that is annotated with maximum
// launch dimensions (an enclosing launch op with constant block or grid
// sizes, or an enclosing function carrying known block/grid size
// attributes), return the bound on the maximum size of the dimension that
// the op is querying. IDs will be one less than this bound.
43
valueByDim(KernelDim3 dims,Dimension dim)44 static Value valueByDim(KernelDim3 dims, Dimension dim) {
45 switch (dim) {
46 case Dimension::x:
47 return dims.x;
48 case Dimension::y:
49 return dims.y;
50 case Dimension::z:
51 return dims.z;
52 }
53 llvm_unreachable("All dimension enum cases handled above");
54 }
55
/// Zero-extends a 32-bit unsigned value to 64 bits.
static uint64_t zext(uint32_t arg) {
  const uint64_t widened{arg};
  return widened;
}
57
58 static std::optional<uint64_t>
getKnownLaunchAttr(GPUFuncOp func,LaunchDims dims,Dimension dim)59 getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
60 DenseI32ArrayAttr bounds;
61 switch (dims) {
62 case LaunchDims::Block:
63 bounds = func.getKnownBlockSizeAttr();
64 break;
65 case LaunchDims::Grid:
66 bounds = func.getKnownGridSizeAttr();
67 break;
68 }
69 if (!bounds)
70 return std::nullopt;
71 if (bounds.size() < static_cast<uint32_t>(dim))
72 return std::nullopt;
73 return zext(bounds[static_cast<uint32_t>(dim)]);
74 }
75
getKnownLaunchAttr(FunctionOpInterface func,StringRef attrName,Dimension dim)76 static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
77 StringRef attrName,
78 Dimension dim) {
79 auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
80 if (!bounds)
81 return std::nullopt;
82 if (bounds.size() < static_cast<uint32_t>(dim))
83 return std::nullopt;
84 return zext(bounds[static_cast<uint32_t>(dim)]);
85 }
86
87 template <typename Op>
getKnownLaunchDim(Op op,LaunchDims type)88 static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
89 Dimension dim = op.getDimension();
90 if (auto launch = op->template getParentOfType<LaunchOp>()) {
91 KernelDim3 bounds;
92 switch (type) {
93 case LaunchDims::Block:
94 bounds = launch.getBlockSizeOperandValues();
95 break;
96 case LaunchDims::Grid:
97 bounds = launch.getGridSizeOperandValues();
98 break;
99 }
100 Value maybeBound = valueByDim(bounds, dim);
101 APInt value;
102 if (matchPattern(maybeBound, m_ConstantInt(&value)))
103 return value.getZExtValue();
104 }
105
106 if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
107 auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
108 if (inherentAttr)
109 return inherentAttr;
110 }
111 if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
112 StringRef attrName;
113 switch (type) {
114 case LaunchDims::Block:
115 attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
116 break;
117 case LaunchDims::Grid:
118 attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
119 break;
120 }
121 auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
122 if (discardableAttr)
123 return discardableAttr;
124 }
125 return std::nullopt;
126 }
127
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)128 void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
129 SetIntRangeFn setResultRange) {
130 uint64_t max = kMaxDim;
131 if (auto specified = getUpperBound())
132 max = specified->getZExtValue();
133 setResultRange(getResult(), getIndexRange(1, max));
134 }
135
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)136 void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
137 SetIntRangeFn setResultRange) {
138 uint64_t max = kMaxClusterDim;
139 if (auto specified = getUpperBound())
140 max = specified->getZExtValue();
141 setResultRange(getResult(), getIndexRange(1, max));
142 }
143
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)144 void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
145 SetIntRangeFn setResultRange) {
146 uint64_t max = kMaxDim;
147 if (auto specified = getUpperBound())
148 max = specified->getZExtValue();
149 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
150 }
151
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)152 void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
153 SetIntRangeFn setResultRange) {
154 uint64_t max = kMaxClusterDim;
155 if (auto specified = getUpperBound())
156 max = specified->getZExtValue();
157 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
158 }
159
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)160 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
161 SetIntRangeFn setResultRange) {
162 std::optional<uint64_t> knownVal =
163 getKnownLaunchDim(*this, LaunchDims::Block);
164 if (knownVal)
165 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
166 ;
167 uint64_t max = kMaxDim;
168 if (auto specified = getUpperBound())
169 max = specified->getZExtValue();
170 setResultRange(getResult(), getIndexRange(1, max));
171 }
172
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)173 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
174 SetIntRangeFn setResultRange) {
175 uint64_t max = kMaxDim;
176 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
177 max = fromContext.value();
178 if (auto specified = getUpperBound())
179 max = specified->getZExtValue();
180 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
181 }
182
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)183 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
184 SetIntRangeFn setResultRange) {
185 std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
186 if (knownVal)
187 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
188 uint64_t max = kMaxDim;
189 if (auto specified = getUpperBound())
190 max = specified->getZExtValue();
191 setResultRange(getResult(), getIndexRange(1, max));
192 }
193
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)194 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
195 SetIntRangeFn setResultRange) {
196 uint64_t max = kMaxDim;
197 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
198 max = fromContext.value();
199 if (auto specified = getUpperBound())
200 max = specified->getZExtValue();
201 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
202 }
203
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)204 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
205 SetIntRangeFn setResultRange) {
206 uint64_t max = kMaxSubgroupSize;
207 if (auto specified = getUpperBound())
208 max = specified->getZExtValue();
209 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
210 }
211
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)212 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
213 SetIntRangeFn setResultRange) {
214 uint64_t max = kMaxDim;
215 if (auto specified = getUpperBound())
216 max = specified->getZExtValue();
217 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
218 }
219
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)220 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
221 SetIntRangeFn setResultRange) {
222 if (auto specified = getUpperBound())
223 return setResultRange(getResult(),
224 getIndexRange(0, specified->getZExtValue() - 1ULL));
225
226 uint64_t blockDimMax =
227 getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
228 uint64_t gridDimMax =
229 getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
230 setResultRange(getResult(),
231 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
232 }
233
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)234 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
235 SetIntRangeFn setResultRange) {
236 uint64_t max = kMaxDim;
237 if (auto specified = getUpperBound())
238 max = specified->getZExtValue();
239 setResultRange(getResult(), getIndexRange(1, max));
240 }
241
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)242 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
243 SetIntRangeFn setResultRange) {
244 uint64_t max = kMaxSubgroupSize;
245 if (auto specified = getUpperBound())
246 max = specified->getZExtValue();
247 setResultRange(getResult(), getIndexRange(1, max));
248 }
249
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)250 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
251 SetIntRangeFn setResultRange) {
252 auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
253 Value idxResult) {
254 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
255 return;
256 ConstantIntRanges dimRange =
257 argRange.intersection(getIndexRange(1, kMaxDim));
258 setResultRange(dimResult, dimRange);
259 ConstantIntRanges idxRange =
260 getIndexRange(0, dimRange.umax().getZExtValue() - 1);
261 setResultRange(idxResult, idxRange);
262 };
263
264 argRanges = argRanges.drop_front(getAsyncDependencies().size());
265 KernelDim3 gridDims = getGridSize();
266 KernelDim3 blockIds = getBlockIds();
267 setRange(argRanges[0], gridDims.x, blockIds.x);
268 setRange(argRanges[1], gridDims.y, blockIds.y);
269 setRange(argRanges[2], gridDims.z, blockIds.z);
270 KernelDim3 blockDims = getBlockSize();
271 KernelDim3 threadIds = getThreadIds();
272 setRange(argRanges[3], blockDims.x, threadIds.x);
273 setRange(argRanges[4], blockDims.y, threadIds.y);
274 setRange(argRanges[5], blockDims.z, threadIds.z);
275 }
276