xref: /llvm-project/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (revision 2dcb3b9f377de428f7d9d103c80226b9007c72a9)
1b739badaSJavier Setoain //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2b739badaSJavier Setoain //
3b739badaSJavier Setoain // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b739badaSJavier Setoain // See https://llvm.org/LICENSE.txt for license information.
5b739badaSJavier Setoain // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b739badaSJavier Setoain //
7b739badaSJavier Setoain //===----------------------------------------------------------------------===//
8b739badaSJavier Setoain 
975e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1075e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
117bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
127bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
14b739badaSJavier Setoain #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15b833bcb5SBenjamin Maxwell #include "mlir/Dialect/Utils/IndexingUtils.h"
16b833bcb5SBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h"
17b739badaSJavier Setoain #include "mlir/IR/BuiltinOps.h"
18b739badaSJavier Setoain #include "mlir/IR/PatternMatch.h"
19b739badaSJavier Setoain 
20b739badaSJavier Setoain using namespace mlir;
21b739badaSJavier Setoain using namespace mlir::arm_sve;
22b739badaSJavier Setoain 
23b739badaSJavier Setoain using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
24b739badaSJavier Setoain using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
25b739badaSJavier Setoain using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
26b739badaSJavier Setoain using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
2795861216SJavier Setoain using ScalableMaskedAddIOpLowering =
2895861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
2995861216SJavier Setoain                                  ScalableMaskedAddIIntrOp>;
3095861216SJavier Setoain using ScalableMaskedAddFOpLowering =
3195861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
3295861216SJavier Setoain                                  ScalableMaskedAddFIntrOp>;
3395861216SJavier Setoain using ScalableMaskedSubIOpLowering =
3495861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
3595861216SJavier Setoain                                  ScalableMaskedSubIIntrOp>;
3695861216SJavier Setoain using ScalableMaskedSubFOpLowering =
3795861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
3895861216SJavier Setoain                                  ScalableMaskedSubFIntrOp>;
3995861216SJavier Setoain using ScalableMaskedMulIOpLowering =
4095861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
4195861216SJavier Setoain                                  ScalableMaskedMulIIntrOp>;
4295861216SJavier Setoain using ScalableMaskedMulFOpLowering =
4395861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
4495861216SJavier Setoain                                  ScalableMaskedMulFIntrOp>;
4595861216SJavier Setoain using ScalableMaskedSDivIOpLowering =
4695861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
4795861216SJavier Setoain                                  ScalableMaskedSDivIIntrOp>;
4895861216SJavier Setoain using ScalableMaskedUDivIOpLowering =
4995861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
5095861216SJavier Setoain                                  ScalableMaskedUDivIIntrOp>;
5195861216SJavier Setoain using ScalableMaskedDivFOpLowering =
5295861216SJavier Setoain     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
5395861216SJavier Setoain                                  ScalableMaskedDivFIntrOp>;
54b739badaSJavier Setoain 
55b833bcb5SBenjamin Maxwell namespace {
56b833bcb5SBenjamin Maxwell 
57b833bcb5SBenjamin Maxwell /// Unrolls a conversion to/from equivalent vector types, to allow using a
58b833bcb5SBenjamin Maxwell /// conversion intrinsic that only supports 1-D vector types.
59b833bcb5SBenjamin Maxwell ///
60b833bcb5SBenjamin Maxwell /// Example:
61b833bcb5SBenjamin Maxwell /// ```
62b833bcb5SBenjamin Maxwell /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
63b833bcb5SBenjamin Maxwell /// ```
64b833bcb5SBenjamin Maxwell /// is rewritten into:
65b833bcb5SBenjamin Maxwell /// ```
66b833bcb5SBenjamin Maxwell /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
67b833bcb5SBenjamin Maxwell /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
68b833bcb5SBenjamin Maxwell /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
69b833bcb5SBenjamin Maxwell ///                : (vector<[4]xi1>) -> vector<[16]xi1>
70b833bcb5SBenjamin Maxwell /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
71b833bcb5SBenjamin Maxwell /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
72b833bcb5SBenjamin Maxwell /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
73b833bcb5SBenjamin Maxwell ///                : (vector<[4]xi1>) -> vector<[16]xi1>
74b833bcb5SBenjamin Maxwell /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
75b833bcb5SBenjamin Maxwell /// ```
76b833bcb5SBenjamin Maxwell template <typename Op, typename IntrOp>
77b833bcb5SBenjamin Maxwell struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
78b833bcb5SBenjamin Maxwell   using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
79b833bcb5SBenjamin Maxwell 
80b833bcb5SBenjamin Maxwell   LogicalResult
81b833bcb5SBenjamin Maxwell   matchAndRewrite(Op convertOp, typename Op::Adaptor,
82b833bcb5SBenjamin Maxwell                   ConversionPatternRewriter &rewriter) const override {
83b833bcb5SBenjamin Maxwell     auto loc = convertOp.getLoc();
84b833bcb5SBenjamin Maxwell 
85b833bcb5SBenjamin Maxwell     auto source = convertOp.getSource();
86b833bcb5SBenjamin Maxwell     VectorType sourceType = source.getType();
87b833bcb5SBenjamin Maxwell     VectorType resultType = convertOp.getResult().getType();
88b833bcb5SBenjamin Maxwell 
89b833bcb5SBenjamin Maxwell     Value result = rewriter.create<arith::ConstantOp>(
90b833bcb5SBenjamin Maxwell         loc, resultType, rewriter.getZeroAttr(resultType));
91b833bcb5SBenjamin Maxwell 
92b833bcb5SBenjamin Maxwell     // We want to iterate over the input vector in steps of the trailing
93b833bcb5SBenjamin Maxwell     // dimension. So this creates tile shape where all leading dimensions are 1,
94b833bcb5SBenjamin Maxwell     // and the trailing dimension step is the size of the dimension.
95b833bcb5SBenjamin Maxwell     SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
96b833bcb5SBenjamin Maxwell     tileShape.back() = sourceType.getShape().back();
97b833bcb5SBenjamin Maxwell 
98b833bcb5SBenjamin Maxwell     // Iterate over all scalable mask/predicate slices of the source vector.
99b833bcb5SBenjamin Maxwell     for (SmallVector<int64_t> index :
100b833bcb5SBenjamin Maxwell          StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
101b833bcb5SBenjamin Maxwell       auto extractOrInsertPosition = ArrayRef(index).drop_back();
102b833bcb5SBenjamin Maxwell       auto sourceVector = rewriter.create<vector::ExtractOp>(
103b833bcb5SBenjamin Maxwell           loc, source, extractOrInsertPosition);
104b44b3494SBenjamin Maxwell       VectorType convertedType =
105b833bcb5SBenjamin Maxwell           VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
106b833bcb5SBenjamin Maxwell               .setDim(0, resultType.getShape().back());
107b833bcb5SBenjamin Maxwell       auto convertedVector =
108b833bcb5SBenjamin Maxwell           rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
109b833bcb5SBenjamin Maxwell       result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
110b833bcb5SBenjamin Maxwell                                                  extractOrInsertPosition);
111b833bcb5SBenjamin Maxwell     }
112b833bcb5SBenjamin Maxwell 
113b833bcb5SBenjamin Maxwell     rewriter.replaceOp(convertOp, result);
114b833bcb5SBenjamin Maxwell     return success();
115b833bcb5SBenjamin Maxwell   }
116b833bcb5SBenjamin Maxwell };
117b833bcb5SBenjamin Maxwell 
118b833bcb5SBenjamin Maxwell using ConvertToSvboolOpLowering =
119b833bcb5SBenjamin Maxwell     SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
120b833bcb5SBenjamin Maxwell 
121b833bcb5SBenjamin Maxwell using ConvertFromSvboolOpLowering =
122b833bcb5SBenjamin Maxwell     SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
123b833bcb5SBenjamin Maxwell 
1247dcca621SBenjamin Maxwell using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
1257dcca621SBenjamin Maxwell using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
1267dcca621SBenjamin Maxwell 
12778113303SBenjamin Maxwell /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
12878113303SBenjamin Maxwell /// but first input (P1) and result predicates need conversion to/from svbool.
12978113303SBenjamin Maxwell struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
13078113303SBenjamin Maxwell   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
13178113303SBenjamin Maxwell 
13278113303SBenjamin Maxwell   LogicalResult
13378113303SBenjamin Maxwell   matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
13478113303SBenjamin Maxwell                   ConversionPatternRewriter &rewriter) const override {
13578113303SBenjamin Maxwell     auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
13678113303SBenjamin Maxwell     auto loc = pselOp.getLoc();
13778113303SBenjamin Maxwell     auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
13878113303SBenjamin Maxwell                                                            adaptor.getP1());
13978113303SBenjamin Maxwell     auto indexI32 = rewriter.create<arith::IndexCastOp>(
14078113303SBenjamin Maxwell         loc, rewriter.getI32Type(), pselOp.getIndex());
14178113303SBenjamin Maxwell     auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
14278113303SBenjamin Maxwell                                                 pselOp.getP2(), indexI32);
14378113303SBenjamin Maxwell     rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
14478113303SBenjamin Maxwell         pselOp, adaptor.getP1().getType(), pselIntr);
14578113303SBenjamin Maxwell     return success();
14678113303SBenjamin Maxwell   }
14778113303SBenjamin Maxwell };
14878113303SBenjamin Maxwell 
149657ec732SBenjamin Maxwell /// Converts `vector.create_mask` ops that match the size of an SVE predicate
150657ec732SBenjamin Maxwell /// to the `whilelt` intrinsic. This produces more canonical codegen than the
151657ec732SBenjamin Maxwell /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
152657ec732SBenjamin Maxwell /// for more details. Note that we can't use (the more general) active.lane.mask
153657ec732SBenjamin Maxwell /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
154657ec732SBenjamin Maxwell /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
155657ec732SBenjamin Maxwell /// `n` is zero (whereas `create_mask` just returns an all-false mask).
156657ec732SBenjamin Maxwell struct CreateMaskOpLowering
157657ec732SBenjamin Maxwell     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
158657ec732SBenjamin Maxwell   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
159657ec732SBenjamin Maxwell 
160657ec732SBenjamin Maxwell   LogicalResult
161657ec732SBenjamin Maxwell   matchAndRewrite(vector::CreateMaskOp createMaskOp,
162657ec732SBenjamin Maxwell                   vector::CreateMaskOp::Adaptor adaptor,
163657ec732SBenjamin Maxwell                   ConversionPatternRewriter &rewriter) const override {
164657ec732SBenjamin Maxwell     auto maskType = createMaskOp.getVectorType();
165657ec732SBenjamin Maxwell     if (maskType.getRank() != 1 || !maskType.isScalable())
166657ec732SBenjamin Maxwell       return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
167657ec732SBenjamin Maxwell 
168657ec732SBenjamin Maxwell     // TODO: Support masks which are multiples of SVE predicates.
169657ec732SBenjamin Maxwell     auto maskBaseSize = maskType.getDimSize(0);
170657ec732SBenjamin Maxwell     if (maskBaseSize < 2 || maskBaseSize > 16 ||
171657ec732SBenjamin Maxwell         !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
172657ec732SBenjamin Maxwell       return rewriter.notifyMatchFailure(createMaskOp,
173657ec732SBenjamin Maxwell                                          "not SVE predicate-sized");
174657ec732SBenjamin Maxwell 
175657ec732SBenjamin Maxwell     auto loc = createMaskOp.getLoc();
176657ec732SBenjamin Maxwell     auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
177657ec732SBenjamin Maxwell     rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
178657ec732SBenjamin Maxwell                                                adaptor.getOperands()[0]);
179657ec732SBenjamin Maxwell     return success();
180657ec732SBenjamin Maxwell   }
181657ec732SBenjamin Maxwell };
182657ec732SBenjamin Maxwell 
183b833bcb5SBenjamin Maxwell } // namespace
184b833bcb5SBenjamin Maxwell 
185b739badaSJavier Setoain /// Populate the given list with patterns that convert from ArmSVE to LLVM.
186b739badaSJavier Setoain void mlir::populateArmSVELegalizeForLLVMExportPatterns(
187*206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
188b739badaSJavier Setoain   // Populate conversion patterns
189b739badaSJavier Setoain 
190b739badaSJavier Setoain   // clang-format off
191b739badaSJavier Setoain   patterns.add<SdotOpLowering,
192b739badaSJavier Setoain                SmmlaOpLowering,
193b739badaSJavier Setoain                UdotOpLowering,
194b739badaSJavier Setoain                UmmlaOpLowering,
19595861216SJavier Setoain                ScalableMaskedAddIOpLowering,
19695861216SJavier Setoain                ScalableMaskedAddFOpLowering,
19795861216SJavier Setoain                ScalableMaskedSubIOpLowering,
19895861216SJavier Setoain                ScalableMaskedSubFOpLowering,
19995861216SJavier Setoain                ScalableMaskedMulIOpLowering,
20095861216SJavier Setoain                ScalableMaskedMulFOpLowering,
20195861216SJavier Setoain                ScalableMaskedSDivIOpLowering,
20295861216SJavier Setoain                ScalableMaskedUDivIOpLowering,
203b833bcb5SBenjamin Maxwell                ScalableMaskedDivFOpLowering,
204b833bcb5SBenjamin Maxwell                ConvertToSvboolOpLowering,
2057dcca621SBenjamin Maxwell                ConvertFromSvboolOpLowering,
2067dcca621SBenjamin Maxwell                ZipX2OpLowering,
20778113303SBenjamin Maxwell                ZipX4OpLowering,
20878113303SBenjamin Maxwell                PselOpLowering>(converter);
209657ec732SBenjamin Maxwell   // Add vector.create_mask conversion with a high benefit as it produces much
210657ec732SBenjamin Maxwell   // nicer code than the generic lowering.
211657ec732SBenjamin Maxwell   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
212b739badaSJavier Setoain   // clang-format on
213b739badaSJavier Setoain }
214b739badaSJavier Setoain 
215b739badaSJavier Setoain void mlir::configureArmSVELegalizeForExportTarget(
216b739badaSJavier Setoain     LLVMConversionTarget &target) {
21795861216SJavier Setoain   // clang-format off
21895861216SJavier Setoain   target.addLegalOp<SdotIntrOp,
21995861216SJavier Setoain                     SmmlaIntrOp,
22095861216SJavier Setoain                     UdotIntrOp,
22195861216SJavier Setoain                     UmmlaIntrOp,
22295861216SJavier Setoain                     ScalableMaskedAddIIntrOp,
22395861216SJavier Setoain                     ScalableMaskedAddFIntrOp,
22495861216SJavier Setoain                     ScalableMaskedSubIIntrOp,
22595861216SJavier Setoain                     ScalableMaskedSubFIntrOp,
22695861216SJavier Setoain                     ScalableMaskedMulIIntrOp,
22795861216SJavier Setoain                     ScalableMaskedMulFIntrOp,
22895861216SJavier Setoain                     ScalableMaskedSDivIIntrOp,
22995861216SJavier Setoain                     ScalableMaskedUDivIIntrOp,
230b833bcb5SBenjamin Maxwell                     ScalableMaskedDivFIntrOp,
231b833bcb5SBenjamin Maxwell                     ConvertToSvboolIntrOp,
2327dcca621SBenjamin Maxwell                     ConvertFromSvboolIntrOp,
2337dcca621SBenjamin Maxwell                     ZipX2IntrOp,
234657ec732SBenjamin Maxwell                     ZipX4IntrOp,
23578113303SBenjamin Maxwell                     PselIntrOp,
236657ec732SBenjamin Maxwell                     WhileLTIntrOp>();
23795861216SJavier Setoain   target.addIllegalOp<SdotOp,
23895861216SJavier Setoain                       SmmlaOp,
23995861216SJavier Setoain                       UdotOp,
24095861216SJavier Setoain                       UmmlaOp,
24195861216SJavier Setoain                       ScalableMaskedAddIOp,
24295861216SJavier Setoain                       ScalableMaskedAddFOp,
24395861216SJavier Setoain                       ScalableMaskedSubIOp,
24495861216SJavier Setoain                       ScalableMaskedSubFOp,
24595861216SJavier Setoain                       ScalableMaskedMulIOp,
24695861216SJavier Setoain                       ScalableMaskedMulFOp,
24795861216SJavier Setoain                       ScalableMaskedSDivIOp,
24895861216SJavier Setoain                       ScalableMaskedUDivIOp,
249b833bcb5SBenjamin Maxwell                       ScalableMaskedDivFOp,
250b833bcb5SBenjamin Maxwell                       ConvertToSvboolOp,
2517dcca621SBenjamin Maxwell                       ConvertFromSvboolOp,
2527dcca621SBenjamin Maxwell                       ZipX2Op,
2537dcca621SBenjamin Maxwell                       ZipX4Op>();
25495861216SJavier Setoain   // clang-format on
255b739badaSJavier Setoain }
256