//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"

#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::gpu;

// Maximum grid and block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum cluster size.
static constexpr uint64_t kMaxClusterDim = 8;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;

static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
  unsigned width = IndexType::kInternalStorageBitWidth;
  return ConstantIntRanges::fromUnsigned(APInt(width, umin),
                                         APInt(width, umax));
}

namespace {
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
} // end namespace

static Value valueByDim(KernelDim3 dims, Dimension dim) {
  switch (dim) {
  case Dimension::x:
    return dims.x;
  case Dimension::y:
    return dims.y;
  case Dimension::z:
    return dims.z;
  }
  llvm_unreachable("All dimension enum cases handled above");
}

static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }

static std::optional<uint64_t> getKnownLaunchAttr(GPUFuncOp func,
                                                  LaunchDims dims,
                                                  Dimension dim) {
  DenseI32ArrayAttr bounds;
  switch (dims) {
  case LaunchDims::Block:
    bounds = func.getKnownBlockSizeAttr();
    break;
  case LaunchDims::Grid:
    bounds = func.getKnownGridSizeAttr();
    break;
  }
  if (!bounds)
    return std::nullopt;
  if (bounds.size() < static_cast<uint32_t>(dim))
    return std::nullopt;
  return zext(bounds[static_cast<uint32_t>(dim)]);
}

static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
                                                  StringRef attrName,
                                                  Dimension dim) {
  auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
  if (!bounds)
    return std::nullopt;
  if (bounds.size() < static_cast<uint32_t>(dim))
    return std::nullopt;
  return zext(bounds[static_cast<uint32_t>(dim)]);
}

/// If the operation `op` is in a context that is annotated with maximum
/// launch dimensions (a launch op with constant block or grid sizes or a
/// launch_func op with the appropriate dimensions), return the bound on the
/// maximum size of the dimension that the op is querying. IDs will be one
/// less than this bound.
template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
  Dimension dim = op.getDimension();
  if (auto launch = op->template getParentOfType<LaunchOp>()) {
    KernelDim3 bounds;
    switch (type) {
    case LaunchDims::Block:
      bounds = launch.getBlockSizeOperandValues();
      break;
    case LaunchDims::Grid:
      bounds = launch.getGridSizeOperandValues();
      break;
    }
    Value maybeBound = valueByDim(bounds, dim);
    APInt value;
    if (matchPattern(maybeBound, m_ConstantInt(&value)))
      return value.getZExtValue();
  }

  if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
    auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
    if (inherentAttr)
      return inherentAttr;
  }
  if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
    StringRef attrName;
    switch (type) {
    case LaunchDims::Block:
      attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
      break;
    case LaunchDims::Grid:
      attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
      break;
    }
    auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
    if (discardableAttr)
      return discardableAttr;
  }
  return std::nullopt;
}

void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                     SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                           SetIntRangeFn setResultRange) {
  uint64_t max = kMaxClusterDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                         SetIntRangeFn setResultRange) {
  uint64_t max = kMaxClusterDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  std::optional<uint64_t> knownVal =
      getKnownLaunchDim(*this, LaunchDims::Block);
  if (knownVal)
    return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
    max = fromContext.value();
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
  std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
  if (knownVal)
    return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
    max = fromContext.value();
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                 SetIntRangeFn setResultRange) {
  uint64_t max = kMaxSubgroupSize;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                     SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  if (auto specified = getUpperBound())
    return setResultRange(getResult(),
                          getIndexRange(0, specified->getZExtValue() - 1ULL));

  uint64_t blockDimMax =
      getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
  uint64_t gridDimMax =
      getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
  setResultRange(getResult(),
                 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
}

void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                       SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                       SetIntRangeFn setResultRange) {
  uint64_t max = kMaxSubgroupSize;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
                      Value idxResult) {
    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
      return;
    ConstantIntRanges dimRange =
        argRange.intersection(getIndexRange(1, kMaxDim));
    setResultRange(dimResult, dimRange);
    ConstantIntRanges idxRange =
        getIndexRange(0, dimRange.umax().getZExtValue() - 1);
    setResultRange(idxResult, idxRange);
  };

  argRanges = argRanges.drop_front(getAsyncDependencies().size());

  KernelDim3 gridDims = getGridSize();
  KernelDim3 blockIds = getBlockIds();
  setRange(argRanges[0], gridDims.x, blockIds.x);
  setRange(argRanges[1], gridDims.y, blockIds.y);
  setRange(argRanges[2], gridDims.z, blockIds.z);
  KernelDim3 blockDims = getBlockSize();
  KernelDim3 threadIds = getThreadIds();
  setRange(argRanges[3], blockDims.x, threadIds.x);
  setRange(argRanges[4], blockDims.y, threadIds.y);
  setRange(argRanges[5], blockDims.z, threadIds.z);
}
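
// Illustrative sketch of how the inference above behaves: for a gpu.func
// annotated with a known block size (syntax approximate, e.g. something like)
//
//   gpu.func @kernel() kernel
//       attributes {known_block_size = array<i32: 32, 4, 2>} {
//     %tid = gpu.thread_id x   // inferred range [0, 31]
//     %dim = gpu.block_dim x   // inferred range [32, 32]
//     gpu.return
//   }
//
// getKnownLaunchDim resolves the per-dimension bound from the known block
// size, so ThreadIdOp and BlockDimOp infer the tight ranges shown; without
// such an annotation they fall back to kMaxDim, or to an explicit upper bound
// when one is specified on the op.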