13e01af09SChristian Sigg //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
23e01af09SChristian Sigg //
33e01af09SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43e01af09SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
53e01af09SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63e01af09SChristian Sigg //
73e01af09SChristian Sigg //===----------------------------------------------------------------------===//
83e01af09SChristian Sigg
93e01af09SChristian Sigg #include "mlir/Dialect/GPU/IR/GPUDialect.h"
10be575c5dSKrzysztof Drewniak #include "mlir/IR/Matchers.h"
11*43fd4c49SKrzysztof Drewniak #include "mlir/Interfaces/FunctionInterfaces.h"
123e01af09SChristian Sigg #include "mlir/Interfaces/InferIntRangeInterface.h"
13be575c5dSKrzysztof Drewniak #include "llvm/ADT/STLForwardCompat.h"
14be575c5dSKrzysztof Drewniak #include "llvm/Support/ErrorHandling.h"
15be575c5dSKrzysztof Drewniak #include "llvm/Support/MathExtras.h"
16be575c5dSKrzysztof Drewniak #include <optional>
173e01af09SChristian Sigg
183e01af09SChristian Sigg using namespace mlir;
193e01af09SChristian Sigg using namespace mlir::gpu;
203e01af09SChristian Sigg
// Grid and block dimensions of all known GPUs fit in 32 bits, so no extent
// along a single dimension exceeds 2^32 - 1.
static constexpr uint64_t kMaxDim{std::numeric_limits<uint32_t>::max()};
// Largest cluster extent assumed along a single dimension.
static constexpr uint64_t kMaxClusterDim{8};
// Subgroup sizes are never larger than 128.
static constexpr uint64_t kMaxSubgroupSize{128};
273e01af09SChristian Sigg
getIndexRange(uint64_t umin,uint64_t umax)283e01af09SChristian Sigg static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
293e01af09SChristian Sigg unsigned width = IndexType::kInternalStorageBitWidth;
303e01af09SChristian Sigg return ConstantIntRanges::fromUnsigned(APInt(width, umin),
313e01af09SChristian Sigg APInt(width, umax));
323e01af09SChristian Sigg }
333e01af09SChristian Sigg
namespace {
/// Selects which family of launch sizes a query refers to.
enum class LaunchDims : uint32_t {
  Block = 0, ///< The block size (queried by block_dim / thread_id).
  Grid = 1,  ///< The grid size (queried by grid_dim / block_id).
};
} // namespace
37be575c5dSKrzysztof Drewniak
/// Support for getKnownLaunchDim() (defined below): if the operation `op` is
/// in a context that is annotated with maximum launch dimensions (an enclosing
/// gpu.launch op with constant block or grid sizes, or an enclosing function
/// carrying known block/grid size attributes), getKnownLaunchDim() returns the
/// bound on the maximum size of the dimension that the op is querying. IDs
/// will be one less than this bound.
43be575c5dSKrzysztof Drewniak
valueByDim(KernelDim3 dims,Dimension dim)44be575c5dSKrzysztof Drewniak static Value valueByDim(KernelDim3 dims, Dimension dim) {
45be575c5dSKrzysztof Drewniak switch (dim) {
46be575c5dSKrzysztof Drewniak case Dimension::x:
47be575c5dSKrzysztof Drewniak return dims.x;
48be575c5dSKrzysztof Drewniak case Dimension::y:
49be575c5dSKrzysztof Drewniak return dims.y;
50be575c5dSKrzysztof Drewniak case Dimension::z:
51be575c5dSKrzysztof Drewniak return dims.z;
52be575c5dSKrzysztof Drewniak }
53be575c5dSKrzysztof Drewniak llvm_unreachable("All dimension enum cases handled above");
54be575c5dSKrzysztof Drewniak }
55be575c5dSKrzysztof Drewniak
/// Widens a 32-bit unsigned value to 64 bits; the upper half is always zero.
static uint64_t zext(uint32_t arg) {
  const uint64_t widened{arg};
  return widened;
}
57be575c5dSKrzysztof Drewniak
58*43fd4c49SKrzysztof Drewniak static std::optional<uint64_t>
getKnownLaunchAttr(GPUFuncOp func,LaunchDims dims,Dimension dim)59*43fd4c49SKrzysztof Drewniak getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
60*43fd4c49SKrzysztof Drewniak DenseI32ArrayAttr bounds;
61*43fd4c49SKrzysztof Drewniak switch (dims) {
62*43fd4c49SKrzysztof Drewniak case LaunchDims::Block:
63*43fd4c49SKrzysztof Drewniak bounds = func.getKnownBlockSizeAttr();
64*43fd4c49SKrzysztof Drewniak break;
65*43fd4c49SKrzysztof Drewniak case LaunchDims::Grid:
66*43fd4c49SKrzysztof Drewniak bounds = func.getKnownGridSizeAttr();
67*43fd4c49SKrzysztof Drewniak break;
68*43fd4c49SKrzysztof Drewniak }
69*43fd4c49SKrzysztof Drewniak if (!bounds)
70*43fd4c49SKrzysztof Drewniak return std::nullopt;
71*43fd4c49SKrzysztof Drewniak if (bounds.size() < static_cast<uint32_t>(dim))
72*43fd4c49SKrzysztof Drewniak return std::nullopt;
73*43fd4c49SKrzysztof Drewniak return zext(bounds[static_cast<uint32_t>(dim)]);
74*43fd4c49SKrzysztof Drewniak }
75*43fd4c49SKrzysztof Drewniak
getKnownLaunchAttr(FunctionOpInterface func,StringRef attrName,Dimension dim)76*43fd4c49SKrzysztof Drewniak static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
77*43fd4c49SKrzysztof Drewniak StringRef attrName,
78*43fd4c49SKrzysztof Drewniak Dimension dim) {
79*43fd4c49SKrzysztof Drewniak auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
80*43fd4c49SKrzysztof Drewniak if (!bounds)
81*43fd4c49SKrzysztof Drewniak return std::nullopt;
82*43fd4c49SKrzysztof Drewniak if (bounds.size() < static_cast<uint32_t>(dim))
83*43fd4c49SKrzysztof Drewniak return std::nullopt;
84*43fd4c49SKrzysztof Drewniak return zext(bounds[static_cast<uint32_t>(dim)]);
85*43fd4c49SKrzysztof Drewniak }
86*43fd4c49SKrzysztof Drewniak
87be575c5dSKrzysztof Drewniak template <typename Op>
getKnownLaunchDim(Op op,LaunchDims type)880a81ace0SKazu Hirata static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
89be575c5dSKrzysztof Drewniak Dimension dim = op.getDimension();
90be575c5dSKrzysztof Drewniak if (auto launch = op->template getParentOfType<LaunchOp>()) {
91be575c5dSKrzysztof Drewniak KernelDim3 bounds;
92be575c5dSKrzysztof Drewniak switch (type) {
93be575c5dSKrzysztof Drewniak case LaunchDims::Block:
94be575c5dSKrzysztof Drewniak bounds = launch.getBlockSizeOperandValues();
95be575c5dSKrzysztof Drewniak break;
96be575c5dSKrzysztof Drewniak case LaunchDims::Grid:
97be575c5dSKrzysztof Drewniak bounds = launch.getGridSizeOperandValues();
98be575c5dSKrzysztof Drewniak break;
99be575c5dSKrzysztof Drewniak }
100be575c5dSKrzysztof Drewniak Value maybeBound = valueByDim(bounds, dim);
101be575c5dSKrzysztof Drewniak APInt value;
102be575c5dSKrzysztof Drewniak if (matchPattern(maybeBound, m_ConstantInt(&value)))
103be575c5dSKrzysztof Drewniak return value.getZExtValue();
104be575c5dSKrzysztof Drewniak }
105be575c5dSKrzysztof Drewniak
106*43fd4c49SKrzysztof Drewniak if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
107*43fd4c49SKrzysztof Drewniak auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
108*43fd4c49SKrzysztof Drewniak if (inherentAttr)
109*43fd4c49SKrzysztof Drewniak return inherentAttr;
110*43fd4c49SKrzysztof Drewniak }
111*43fd4c49SKrzysztof Drewniak if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
112*43fd4c49SKrzysztof Drewniak StringRef attrName;
113be575c5dSKrzysztof Drewniak switch (type) {
114be575c5dSKrzysztof Drewniak case LaunchDims::Block:
115*43fd4c49SKrzysztof Drewniak attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
116*43fd4c49SKrzysztof Drewniak break;
117be575c5dSKrzysztof Drewniak case LaunchDims::Grid:
118*43fd4c49SKrzysztof Drewniak attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
119*43fd4c49SKrzysztof Drewniak break;
120be575c5dSKrzysztof Drewniak }
121*43fd4c49SKrzysztof Drewniak auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
122*43fd4c49SKrzysztof Drewniak if (discardableAttr)
123*43fd4c49SKrzysztof Drewniak return discardableAttr;
124be575c5dSKrzysztof Drewniak }
125be575c5dSKrzysztof Drewniak return std::nullopt;
126be575c5dSKrzysztof Drewniak }
127be575c5dSKrzysztof Drewniak
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)128edf5cae7SGuray Ozen void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
129edf5cae7SGuray Ozen SetIntRangeFn setResultRange) {
130*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
131*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
132*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
133bd6568c9SPradeep Kumar setResultRange(getResult(), getIndexRange(1, max));
134bd6568c9SPradeep Kumar }
135bd6568c9SPradeep Kumar
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)136bd6568c9SPradeep Kumar void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
137bd6568c9SPradeep Kumar SetIntRangeFn setResultRange) {
138*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxClusterDim;
139*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
140*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
141*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(1, max));
142edf5cae7SGuray Ozen }
143edf5cae7SGuray Ozen
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)144edf5cae7SGuray Ozen void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
145edf5cae7SGuray Ozen SetIntRangeFn setResultRange) {
146*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
147*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
148*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
149edf5cae7SGuray Ozen setResultRange(getResult(), getIndexRange(0, max - 1ULL));
150edf5cae7SGuray Ozen }
151edf5cae7SGuray Ozen
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)152bd6568c9SPradeep Kumar void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
153bd6568c9SPradeep Kumar SetIntRangeFn setResultRange) {
154bd6568c9SPradeep Kumar uint64_t max = kMaxClusterDim;
155*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
156*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
157bd6568c9SPradeep Kumar setResultRange(getResult(), getIndexRange(0, max - 1ULL));
158bd6568c9SPradeep Kumar }
159bd6568c9SPradeep Kumar
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)1603e01af09SChristian Sigg void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
1613e01af09SChristian Sigg SetIntRangeFn setResultRange) {
1620a81ace0SKazu Hirata std::optional<uint64_t> knownVal =
1630a81ace0SKazu Hirata getKnownLaunchDim(*this, LaunchDims::Block);
164be575c5dSKrzysztof Drewniak if (knownVal)
165*43fd4c49SKrzysztof Drewniak return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
166*43fd4c49SKrzysztof Drewniak ;
167*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
168*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
169*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
170*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(1, max));
1713e01af09SChristian Sigg }
1723e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)1733e01af09SChristian Sigg void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
1743e01af09SChristian Sigg SetIntRangeFn setResultRange) {
175*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
176*43fd4c49SKrzysztof Drewniak if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
177*43fd4c49SKrzysztof Drewniak max = fromContext.value();
178*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
179*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
180be575c5dSKrzysztof Drewniak setResultRange(getResult(), getIndexRange(0, max - 1ULL));
1813e01af09SChristian Sigg }
1823e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)1833e01af09SChristian Sigg void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
1843e01af09SChristian Sigg SetIntRangeFn setResultRange) {
1850a81ace0SKazu Hirata std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
186be575c5dSKrzysztof Drewniak if (knownVal)
187*43fd4c49SKrzysztof Drewniak return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
188*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
189*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
190*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
191*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(1, max));
1923e01af09SChristian Sigg }
1933e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)1943e01af09SChristian Sigg void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
1953e01af09SChristian Sigg SetIntRangeFn setResultRange) {
196*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
197*43fd4c49SKrzysztof Drewniak if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
198*43fd4c49SKrzysztof Drewniak max = fromContext.value();
199*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
200*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
201be575c5dSKrzysztof Drewniak setResultRange(getResult(), getIndexRange(0, max - 1ULL));
2023e01af09SChristian Sigg }
2033e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)2043e01af09SChristian Sigg void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
2053e01af09SChristian Sigg SetIntRangeFn setResultRange) {
206*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxSubgroupSize;
207*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
208*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
209*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(0, max - 1ULL));
2103e01af09SChristian Sigg }
2113e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)2123e01af09SChristian Sigg void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
2133e01af09SChristian Sigg SetIntRangeFn setResultRange) {
214*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
215*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
216*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
217*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(0, max - 1ULL));
2183e01af09SChristian Sigg }
2193e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)2203e01af09SChristian Sigg void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
2213e01af09SChristian Sigg SetIntRangeFn setResultRange) {
222*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
223*43fd4c49SKrzysztof Drewniak return setResultRange(getResult(),
224*43fd4c49SKrzysztof Drewniak getIndexRange(0, specified->getZExtValue() - 1ULL));
225*43fd4c49SKrzysztof Drewniak
226be575c5dSKrzysztof Drewniak uint64_t blockDimMax =
227be575c5dSKrzysztof Drewniak getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
228be575c5dSKrzysztof Drewniak uint64_t gridDimMax =
229be575c5dSKrzysztof Drewniak getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
2303e01af09SChristian Sigg setResultRange(getResult(),
231be575c5dSKrzysztof Drewniak getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
2323e01af09SChristian Sigg }
2333e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)2343e01af09SChristian Sigg void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
2353e01af09SChristian Sigg SetIntRangeFn setResultRange) {
236*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxDim;
237*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
238*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
239*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(1, max));
2403e01af09SChristian Sigg }
2413e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges>,SetIntRangeFn setResultRange)2423e01af09SChristian Sigg void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
2433e01af09SChristian Sigg SetIntRangeFn setResultRange) {
244*43fd4c49SKrzysztof Drewniak uint64_t max = kMaxSubgroupSize;
245*43fd4c49SKrzysztof Drewniak if (auto specified = getUpperBound())
246*43fd4c49SKrzysztof Drewniak max = specified->getZExtValue();
247*43fd4c49SKrzysztof Drewniak setResultRange(getResult(), getIndexRange(1, max));
2483e01af09SChristian Sigg }
2493e01af09SChristian Sigg
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)2503e01af09SChristian Sigg void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2513e01af09SChristian Sigg SetIntRangeFn setResultRange) {
25228c17a4bSMehdi Amini auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
2533e01af09SChristian Sigg Value idxResult) {
2543e01af09SChristian Sigg if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
2553e01af09SChristian Sigg return;
2563e01af09SChristian Sigg ConstantIntRanges dimRange =
2573e01af09SChristian Sigg argRange.intersection(getIndexRange(1, kMaxDim));
2583e01af09SChristian Sigg setResultRange(dimResult, dimRange);
2593e01af09SChristian Sigg ConstantIntRanges idxRange =
2603e01af09SChristian Sigg getIndexRange(0, dimRange.umax().getZExtValue() - 1);
2613e01af09SChristian Sigg setResultRange(idxResult, idxRange);
2623e01af09SChristian Sigg };
2633e01af09SChristian Sigg
26410c04f46SRiver Riddle argRanges = argRanges.drop_front(getAsyncDependencies().size());
2653e01af09SChristian Sigg KernelDim3 gridDims = getGridSize();
2663e01af09SChristian Sigg KernelDim3 blockIds = getBlockIds();
2673e01af09SChristian Sigg setRange(argRanges[0], gridDims.x, blockIds.x);
2683e01af09SChristian Sigg setRange(argRanges[1], gridDims.y, blockIds.y);
2693e01af09SChristian Sigg setRange(argRanges[2], gridDims.z, blockIds.z);
2703e01af09SChristian Sigg KernelDim3 blockDims = getBlockSize();
2713e01af09SChristian Sigg KernelDim3 threadIds = getThreadIds();
2723e01af09SChristian Sigg setRange(argRanges[3], blockDims.x, threadIds.x);
2733e01af09SChristian Sigg setRange(argRanges[4], blockDims.y, threadIds.y);
2743e01af09SChristian Sigg setRange(argRanges[5], blockDims.z, threadIds.z);
2753e01af09SChristian Sigg }
276