xref: /llvm-project/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp (revision 43fd4c49bd8d54b9058620f0a885c7a5672fd602)
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Interfaces/FunctionInterfaces.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.h"
13 #include "llvm/ADT/STLForwardCompat.h"
14 #include "llvm/Support/ErrorHandling.h"
15 #include "llvm/Support/MathExtras.h"
16 #include <optional>
17 
18 using namespace mlir;
19 using namespace mlir::gpu;
20 
// No known GPU allows a grid or block dimension of 2^32 or more, so the
// largest 32-bit unsigned value serves as the default size limit.
static constexpr uint64_t kMaxDim =
    uint64_t{std::numeric_limits<uint32_t>::max()};
// Default limit on a cluster dimension when an op gives no explicit bound.
static constexpr uint64_t kMaxClusterDim = uint64_t{8};
// Default limit on the subgroup size: subgroups are assumed to hold at most
// 128 lanes.
static constexpr uint64_t kMaxSubgroupSize = uint64_t{128};
27 
getIndexRange(uint64_t umin,uint64_t umax)28 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
29   unsigned width = IndexType::kInternalStorageBitWidth;
30   return ConstantIntRanges::fromUnsigned(APInt(width, umin),
31                                          APInt(width, umax));
32 }
33 
namespace {
/// Selects which of a kernel launch's size triples a query refers to: the
/// per-block thread counts (`Block`) or the per-grid block counts (`Grid`).
enum class LaunchDims : uint32_t {
  Block = 0,
  Grid = 1,
};
} // namespace
37 
/// Note: this comment describes `getKnownLaunchDim`, defined further below.
/// If the operation `op` is in a context that is annotated with maximum
/// launch dimensions (an enclosing launch op with constant block or grid
/// sizes, or an enclosing function carrying known block/grid size
/// attributes), return the bound on the maximum size of the dimension that
/// the op is querying. IDs will be one less than this bound.
43 
valueByDim(KernelDim3 dims,Dimension dim)44 static Value valueByDim(KernelDim3 dims, Dimension dim) {
45   switch (dim) {
46   case Dimension::x:
47     return dims.x;
48   case Dimension::y:
49     return dims.y;
50   case Dimension::z:
51     return dims.z;
52   }
53   llvm_unreachable("All dimension enum cases handled above");
54 }
55 
/// Widens a 32-bit unsigned value to 64 bits; the upper 32 bits of the result
/// are always zero.
static uint64_t zext(uint32_t arg) {
  const uint64_t widened{arg};
  return widened;
}
57 
58 static std::optional<uint64_t>
getKnownLaunchAttr(GPUFuncOp func,LaunchDims dims,Dimension dim)59 getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
60   DenseI32ArrayAttr bounds;
61   switch (dims) {
62   case LaunchDims::Block:
63     bounds = func.getKnownBlockSizeAttr();
64     break;
65   case LaunchDims::Grid:
66     bounds = func.getKnownGridSizeAttr();
67     break;
68   }
69   if (!bounds)
70     return std::nullopt;
71   if (bounds.size() < static_cast<uint32_t>(dim))
72     return std::nullopt;
73   return zext(bounds[static_cast<uint32_t>(dim)]);
74 }
75 
getKnownLaunchAttr(FunctionOpInterface func,StringRef attrName,Dimension dim)76 static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
77                                                   StringRef attrName,
78                                                   Dimension dim) {
79   auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
80   if (!bounds)
81     return std::nullopt;
82   if (bounds.size() < static_cast<uint32_t>(dim))
83     return std::nullopt;
84   return zext(bounds[static_cast<uint32_t>(dim)]);
85 }
86 
87 template <typename Op>
getKnownLaunchDim(Op op,LaunchDims type)88 static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
89   Dimension dim = op.getDimension();
90   if (auto launch = op->template getParentOfType<LaunchOp>()) {
91     KernelDim3 bounds;
92     switch (type) {
93     case LaunchDims::Block:
94       bounds = launch.getBlockSizeOperandValues();
95       break;
96     case LaunchDims::Grid:
97       bounds = launch.getGridSizeOperandValues();
98       break;
99     }
100     Value maybeBound = valueByDim(bounds, dim);
101     APInt value;
102     if (matchPattern(maybeBound, m_ConstantInt(&value)))
103       return value.getZExtValue();
104   }
105 
106   if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
107     auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
108     if (inherentAttr)
109       return inherentAttr;
110   }
111   if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
112     StringRef attrName;
113     switch (type) {
114     case LaunchDims::Block:
115       attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
116       break;
117     case LaunchDims::Grid:
118       attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
119       break;
120     }
121     auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
122     if (discardableAttr)
123       return discardableAttr;
124   }
125   return std::nullopt;
126 }
127 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)128 void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
129                                      SetIntRangeFn setResultRange) {
130   uint64_t max = kMaxDim;
131   if (auto specified = getUpperBound())
132     max = specified->getZExtValue();
133   setResultRange(getResult(), getIndexRange(1, max));
134 }
135 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)136 void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
137                                            SetIntRangeFn setResultRange) {
138   uint64_t max = kMaxClusterDim;
139   if (auto specified = getUpperBound())
140     max = specified->getZExtValue();
141   setResultRange(getResult(), getIndexRange(1, max));
142 }
143 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)144 void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
145                                     SetIntRangeFn setResultRange) {
146   uint64_t max = kMaxDim;
147   if (auto specified = getUpperBound())
148     max = specified->getZExtValue();
149   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
150 }
151 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)152 void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
153                                          SetIntRangeFn setResultRange) {
154   uint64_t max = kMaxClusterDim;
155   if (auto specified = getUpperBound())
156     max = specified->getZExtValue();
157   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
158 }
159 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)160 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
161                                    SetIntRangeFn setResultRange) {
162   std::optional<uint64_t> knownVal =
163       getKnownLaunchDim(*this, LaunchDims::Block);
164   if (knownVal)
165     return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
166   ;
167   uint64_t max = kMaxDim;
168   if (auto specified = getUpperBound())
169     max = specified->getZExtValue();
170   setResultRange(getResult(), getIndexRange(1, max));
171 }
172 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)173 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
174                                   SetIntRangeFn setResultRange) {
175   uint64_t max = kMaxDim;
176   if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
177     max = fromContext.value();
178   if (auto specified = getUpperBound())
179     max = specified->getZExtValue();
180   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
181 }
182 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)183 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
184                                   SetIntRangeFn setResultRange) {
185   std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
186   if (knownVal)
187     return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
188   uint64_t max = kMaxDim;
189   if (auto specified = getUpperBound())
190     max = specified->getZExtValue();
191   setResultRange(getResult(), getIndexRange(1, max));
192 }
193 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)194 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
195                                    SetIntRangeFn setResultRange) {
196   uint64_t max = kMaxDim;
197   if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
198     max = fromContext.value();
199   if (auto specified = getUpperBound())
200     max = specified->getZExtValue();
201   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
202 }
203 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)204 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
205                                  SetIntRangeFn setResultRange) {
206   uint64_t max = kMaxSubgroupSize;
207   if (auto specified = getUpperBound())
208     max = specified->getZExtValue();
209   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
210 }
211 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)212 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
213                                      SetIntRangeFn setResultRange) {
214   uint64_t max = kMaxDim;
215   if (auto specified = getUpperBound())
216     max = specified->getZExtValue();
217   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
218 }
219 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)220 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
221                                    SetIntRangeFn setResultRange) {
222   if (auto specified = getUpperBound())
223     return setResultRange(getResult(),
224                           getIndexRange(0, specified->getZExtValue() - 1ULL));
225 
226   uint64_t blockDimMax =
227       getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
228   uint64_t gridDimMax =
229       getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
230   setResultRange(getResult(),
231                  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
232 }
233 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)234 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
235                                        SetIntRangeFn setResultRange) {
236   uint64_t max = kMaxDim;
237   if (auto specified = getUpperBound())
238     max = specified->getZExtValue();
239   setResultRange(getResult(), getIndexRange(1, max));
240 }
241 
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)242 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
243                                        SetIntRangeFn setResultRange) {
244   uint64_t max = kMaxSubgroupSize;
245   if (auto specified = getUpperBound())
246     max = specified->getZExtValue();
247   setResultRange(getResult(), getIndexRange(1, max));
248 }
249 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)250 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
251                                  SetIntRangeFn setResultRange) {
252   auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
253                       Value idxResult) {
254     if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
255       return;
256     ConstantIntRanges dimRange =
257         argRange.intersection(getIndexRange(1, kMaxDim));
258     setResultRange(dimResult, dimRange);
259     ConstantIntRanges idxRange =
260         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
261     setResultRange(idxResult, idxRange);
262   };
263 
264   argRanges = argRanges.drop_front(getAsyncDependencies().size());
265   KernelDim3 gridDims = getGridSize();
266   KernelDim3 blockIds = getBlockIds();
267   setRange(argRanges[0], gridDims.x, blockIds.x);
268   setRange(argRanges[1], gridDims.y, blockIds.y);
269   setRange(argRanges[2], gridDims.z, blockIds.z);
270   KernelDim3 blockDims = getBlockSize();
271   KernelDim3 threadIds = getThreadIds();
272   setRange(argRanges[3], blockDims.x, threadIds.x);
273   setRange(argRanges[4], blockDims.y, threadIds.y);
274   setRange(argRanges[5], blockDims.z, threadIds.z);
275 }
276