//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir {
namespace gpu {
namespace index_lowering {
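// `IndexKind` selects which known-size attribute of the surrounding function
// (known block size or known grid size) may supply an upper bound for the
// lowered intrinsic; `IntrType` records whether the intrinsic is an id-style
// query (0-based, strictly below the bound) or a dim-style query (ranging
// from 1 up to and including the bound), which shapes the range attached to
// the result below.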
enum class IndexKind : uint32_t { Other = 0, Block = 1, Grid = 2 };
enum class IntrType : uint32_t {
  None = 0,
  Id = 1,
  Dim = 2,
};

// Rewriting pattern that replaces Op with XOp, YOp, or ZOp depending on the
// dimension that Op operates on. Op is assumed to return an `index` value,
// while XOp, YOp, and ZOp are assumed to return an `llvm.i32` value. Depending
// on `indexBitwidth`, the result is sign-extended or truncated to match the
// bitwidth expected by the consumers of the value.
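//
// As an illustrative sketch only (the exact op names and the pattern set are
// up to the individual lowering pass), a pass targeting NVVM might
// instantiate this pattern along these lines:
//
//   patterns.add<mlir::gpu::index_lowering::OpLowering<
//       gpu::ThreadIdOp, NVVM::ThreadIdXOp, NVVM::ThreadIdYOp,
//       NVVM::ThreadIdZOp>>(converter, IndexKind::Block, IntrType::Id);
//   patterns.add<mlir::gpu::index_lowering::OpLowering<
//       gpu::BlockDimOp, NVVM::BlockDimXOp, NVVM::BlockDimYOp,
//       NVVM::BlockDimZOp>>(converter, IndexKind::Block, IntrType::Dim);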
template <typename Op, typename XOp, typename YOp, typename ZOp>
struct OpLowering : public ConvertOpToLLVMPattern<Op> {
private:
  unsigned indexBitwidth;
  IndexKind indexKind;
  IntrType intrType;

public:
  // Construct a lowering that does not consult any known-size attributes and
  // attaches no range information.
  explicit OpLowering(const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<Op>(typeConverter),
        indexBitwidth(typeConverter.getIndexTypeBitwidth()),
        indexKind(IndexKind::Other), intrType(IntrType::None) {}

  // Construct a lowering that uses `indexKind` to look up known block/grid
  // sizes and `intrType` to encode the resulting bound as a range attribute.
  explicit OpLowering(const LLVMTypeConverter &typeConverter,
                      IndexKind indexKind, IntrType intrType)
      : ConvertOpToLLVMPattern<Op>(typeConverter),
        indexBitwidth(typeConverter.getIndexTypeBitwidth()),
        indexKind(indexKind), intrType(intrType) {}

  // Lower the index op to the x/y/z intrinsic, attach any known range, and
  // sign-extend or truncate the result to the configured index bitwidth.
  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    Operation *newOp;
    switch (op.getDimension()) {
    case gpu::Dimension::x:
      newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
      break;
    case gpu::Dimension::y:
      newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
      break;
    case gpu::Dimension::z:
      newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
      break;
    }

    // Order of priority for bounds:
    // 1. The upper_bound attribute
    // 2. Inherent attributes on a surrounding gpu.func
    // 3. Discardable attributes on a surrounding function of any kind
    // The code below handles these in reverse order so that more important
    // sources overwrite less important ones.
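    // For example (values chosen for illustration only): inside a gpu.func
    // with known_block_size = [64, 1, 1], a thread-id query along x is
    // bounded by 64 unless the op itself carries an upper_bound attribute,
    // which then takes precedence.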
    DenseI32ArrayAttr funcBounds = nullptr;
    if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
      switch (indexKind) {
      case IndexKind::Block: {
        auto blockHelper =
            gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
        if (blockHelper.isAttrPresent(funcOp))
          funcBounds = blockHelper.getAttr(funcOp);
        break;
      }
      case IndexKind::Grid: {
        auto gridHelper =
            gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
        if (gridHelper.isAttrPresent(funcOp))
          funcBounds = gridHelper.getAttr(funcOp);
        break;
      }
      case IndexKind::Other:
        break;
      }
    }
    if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>()) {
      switch (indexKind) {
      case IndexKind::Block:
        funcBounds = gpuFunc.getKnownBlockSizeAttr();
        break;
      case IndexKind::Grid:
        funcBounds = gpuFunc.getKnownGridSizeAttr();
        break;
      case IndexKind::Other:
        break;
      }
    }
    std::optional<int32_t> upperBound;
    if (funcBounds)
      upperBound =
          funcBounds.asArrayRef()[static_cast<uint32_t>(op.getDimension())];
    if (auto opBound = op.getUpperBound())
      upperBound = opBound->getZExtValue();

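    // The `range` attribute describes a half-open interval [min, max): id
    // intrinsics produce values in [0, upperBound), while dim intrinsics
    // produce values in [1, upperBound], hence the +1 on the exclusive upper
    // end.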
    if (upperBound && intrType != IntrType::None) {
      int32_t min = (intrType == IntrType::Dim ? 1 : 0);
      int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1);
      newOp->setAttr("range", LLVM::ConstantRangeAttr::get(
                                  rewriter.getContext(), 32, min, max));
    }
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace index_lowering
} // namespace gpu
} // namespace mlir

#endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_