xref: /llvm-project/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (revision 2dcb3b9f377de428f7d9d103c80226b9007c72a9)
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Utils/IndexingUtils.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
23 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
24 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
25 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
26 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
27 using ScalableMaskedAddIOpLowering =
28     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
29                                  ScalableMaskedAddIIntrOp>;
30 using ScalableMaskedAddFOpLowering =
31     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
32                                  ScalableMaskedAddFIntrOp>;
33 using ScalableMaskedSubIOpLowering =
34     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
35                                  ScalableMaskedSubIIntrOp>;
36 using ScalableMaskedSubFOpLowering =
37     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
38                                  ScalableMaskedSubFIntrOp>;
39 using ScalableMaskedMulIOpLowering =
40     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
41                                  ScalableMaskedMulIIntrOp>;
42 using ScalableMaskedMulFOpLowering =
43     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
44                                  ScalableMaskedMulFIntrOp>;
45 using ScalableMaskedSDivIOpLowering =
46     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
47                                  ScalableMaskedSDivIIntrOp>;
48 using ScalableMaskedUDivIOpLowering =
49     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
50                                  ScalableMaskedUDivIIntrOp>;
51 using ScalableMaskedDivFOpLowering =
52     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
53                                  ScalableMaskedDivFIntrOp>;
54 
55 namespace {
56 
57 /// Unrolls a conversion to/from equivalent vector types, to allow using a
58 /// conversion intrinsic that only supports 1-D vector types.
59 ///
60 /// Example:
61 /// ```
62 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
63 /// ```
64 /// is rewritten into:
65 /// ```
66 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
67 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
68 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
69 ///                : (vector<[4]xi1>) -> vector<[16]xi1>
70 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
71 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
72 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
73 ///                : (vector<[4]xi1>) -> vector<[16]xi1>
74 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
75 /// ```
76 template <typename Op, typename IntrOp>
77 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
78   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
79 
80   LogicalResult
81   matchAndRewrite(Op convertOp, typename Op::Adaptor,
82                   ConversionPatternRewriter &rewriter) const override {
83     auto loc = convertOp.getLoc();
84 
85     auto source = convertOp.getSource();
86     VectorType sourceType = source.getType();
87     VectorType resultType = convertOp.getResult().getType();
88 
89     Value result = rewriter.create<arith::ConstantOp>(
90         loc, resultType, rewriter.getZeroAttr(resultType));
91 
92     // We want to iterate over the input vector in steps of the trailing
93     // dimension. So this creates tile shape where all leading dimensions are 1,
94     // and the trailing dimension step is the size of the dimension.
95     SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
96     tileShape.back() = sourceType.getShape().back();
97 
98     // Iterate over all scalable mask/predicate slices of the source vector.
99     for (SmallVector<int64_t> index :
100          StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
101       auto extractOrInsertPosition = ArrayRef(index).drop_back();
102       auto sourceVector = rewriter.create<vector::ExtractOp>(
103           loc, source, extractOrInsertPosition);
104       VectorType convertedType =
105           VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
106               .setDim(0, resultType.getShape().back());
107       auto convertedVector =
108           rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
109       result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
110                                                  extractOrInsertPosition);
111     }
112 
113     rewriter.replaceOp(convertOp, result);
114     return success();
115   }
116 };
117 
118 using ConvertToSvboolOpLowering =
119     SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
120 
121 using ConvertFromSvboolOpLowering =
122     SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
123 
124 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
125 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
126 
127 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
128 /// but first input (P1) and result predicates need conversion to/from svbool.
129 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
130   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
131 
132   LogicalResult
133   matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
134                   ConversionPatternRewriter &rewriter) const override {
135     auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
136     auto loc = pselOp.getLoc();
137     auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
138                                                            adaptor.getP1());
139     auto indexI32 = rewriter.create<arith::IndexCastOp>(
140         loc, rewriter.getI32Type(), pselOp.getIndex());
141     auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
142                                                 pselOp.getP2(), indexI32);
143     rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
144         pselOp, adaptor.getP1().getType(), pselIntr);
145     return success();
146   }
147 };
148 
149 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
150 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
151 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
152 /// for more details. Note that we can't use (the more general) active.lane.mask
153 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
154 /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
155 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
156 struct CreateMaskOpLowering
157     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
158   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
159 
160   LogicalResult
161   matchAndRewrite(vector::CreateMaskOp createMaskOp,
162                   vector::CreateMaskOp::Adaptor adaptor,
163                   ConversionPatternRewriter &rewriter) const override {
164     auto maskType = createMaskOp.getVectorType();
165     if (maskType.getRank() != 1 || !maskType.isScalable())
166       return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
167 
168     // TODO: Support masks which are multiples of SVE predicates.
169     auto maskBaseSize = maskType.getDimSize(0);
170     if (maskBaseSize < 2 || maskBaseSize > 16 ||
171         !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
172       return rewriter.notifyMatchFailure(createMaskOp,
173                                          "not SVE predicate-sized");
174 
175     auto loc = createMaskOp.getLoc();
176     auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
177     rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
178                                                adaptor.getOperands()[0]);
179     return success();
180   }
181 };
182 
183 } // namespace
184 
185 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
186 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
187     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
188   // Populate conversion patterns
189 
190   // clang-format off
191   patterns.add<SdotOpLowering,
192                SmmlaOpLowering,
193                UdotOpLowering,
194                UmmlaOpLowering,
195                ScalableMaskedAddIOpLowering,
196                ScalableMaskedAddFOpLowering,
197                ScalableMaskedSubIOpLowering,
198                ScalableMaskedSubFOpLowering,
199                ScalableMaskedMulIOpLowering,
200                ScalableMaskedMulFOpLowering,
201                ScalableMaskedSDivIOpLowering,
202                ScalableMaskedUDivIOpLowering,
203                ScalableMaskedDivFOpLowering,
204                ConvertToSvboolOpLowering,
205                ConvertFromSvboolOpLowering,
206                ZipX2OpLowering,
207                ZipX4OpLowering,
208                PselOpLowering>(converter);
209   // Add vector.create_mask conversion with a high benefit as it produces much
210   // nicer code than the generic lowering.
211   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
212   // clang-format on
213 }
214 
215 void mlir::configureArmSVELegalizeForExportTarget(
216     LLVMConversionTarget &target) {
217   // clang-format off
218   target.addLegalOp<SdotIntrOp,
219                     SmmlaIntrOp,
220                     UdotIntrOp,
221                     UmmlaIntrOp,
222                     ScalableMaskedAddIIntrOp,
223                     ScalableMaskedAddFIntrOp,
224                     ScalableMaskedSubIIntrOp,
225                     ScalableMaskedSubFIntrOp,
226                     ScalableMaskedMulIIntrOp,
227                     ScalableMaskedMulFIntrOp,
228                     ScalableMaskedSDivIIntrOp,
229                     ScalableMaskedUDivIIntrOp,
230                     ScalableMaskedDivFIntrOp,
231                     ConvertToSvboolIntrOp,
232                     ConvertFromSvboolIntrOp,
233                     ZipX2IntrOp,
234                     ZipX4IntrOp,
235                     PselIntrOp,
236                     WhileLTIntrOp>();
237   target.addIllegalOp<SdotOp,
238                       SmmlaOp,
239                       UdotOp,
240                       UmmlaOp,
241                       ScalableMaskedAddIOp,
242                       ScalableMaskedAddFOp,
243                       ScalableMaskedSubIOp,
244                       ScalableMaskedSubFOp,
245                       ScalableMaskedMulIOp,
246                       ScalableMaskedMulFOp,
247                       ScalableMaskedSDivIOp,
248                       ScalableMaskedUDivIOp,
249                       ScalableMaskedDivFOp,
250                       ConvertToSvboolOp,
251                       ConvertFromSvboolOp,
252                       ZipX2Op,
253                       ZipX4Op>();
254   // clang-format on
255 }
256