1b739badaSJavier Setoain //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2b739badaSJavier Setoain // 3b739badaSJavier Setoain // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b739badaSJavier Setoain // See https://llvm.org/LICENSE.txt for license information. 5b739badaSJavier Setoain // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b739badaSJavier Setoain // 7b739badaSJavier Setoain //===----------------------------------------------------------------------===// 8b739badaSJavier Setoain 975e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1075e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 117bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 127bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" 131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 14b739badaSJavier Setoain #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15b833bcb5SBenjamin Maxwell #include "mlir/Dialect/Utils/IndexingUtils.h" 16b833bcb5SBenjamin Maxwell #include "mlir/Dialect/Vector/IR/VectorOps.h" 17b739badaSJavier Setoain #include "mlir/IR/BuiltinOps.h" 18b739badaSJavier Setoain #include "mlir/IR/PatternMatch.h" 19b739badaSJavier Setoain 20b739badaSJavier Setoain using namespace mlir; 21b739badaSJavier Setoain using namespace mlir::arm_sve; 22b739badaSJavier Setoain 23b739badaSJavier Setoain using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 24b739badaSJavier Setoain using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 25b739badaSJavier Setoain using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 26b739badaSJavier Setoain using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 2795861216SJavier Setoain using ScalableMaskedAddIOpLowering = 2895861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 2995861216SJavier Setoain ScalableMaskedAddIIntrOp>; 3095861216SJavier Setoain using ScalableMaskedAddFOpLowering = 3195861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 3295861216SJavier Setoain ScalableMaskedAddFIntrOp>; 3395861216SJavier Setoain using ScalableMaskedSubIOpLowering = 3495861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 3595861216SJavier Setoain ScalableMaskedSubIIntrOp>; 3695861216SJavier Setoain using ScalableMaskedSubFOpLowering = 3795861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 3895861216SJavier Setoain ScalableMaskedSubFIntrOp>; 3995861216SJavier Setoain using ScalableMaskedMulIOpLowering = 4095861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 4195861216SJavier Setoain ScalableMaskedMulIIntrOp>; 4295861216SJavier Setoain using ScalableMaskedMulFOpLowering = 4395861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 4495861216SJavier Setoain ScalableMaskedMulFIntrOp>; 4595861216SJavier Setoain using ScalableMaskedSDivIOpLowering = 4695861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 4795861216SJavier Setoain ScalableMaskedSDivIIntrOp>; 4895861216SJavier Setoain using ScalableMaskedUDivIOpLowering = 4995861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 5095861216SJavier Setoain ScalableMaskedUDivIIntrOp>; 5195861216SJavier Setoain using ScalableMaskedDivFOpLowering = 5295861216SJavier Setoain OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 5395861216SJavier Setoain ScalableMaskedDivFIntrOp>; 54b739badaSJavier Setoain 55b833bcb5SBenjamin Maxwell namespace { 56b833bcb5SBenjamin Maxwell 57b833bcb5SBenjamin Maxwell /// Unrolls a conversion to/from equivalent vector types, to allow using a 58b833bcb5SBenjamin Maxwell /// conversion intrinsic that only supports 1-D vector types. 59b833bcb5SBenjamin Maxwell /// 60b833bcb5SBenjamin Maxwell /// Example: 61b833bcb5SBenjamin Maxwell /// ``` 62b833bcb5SBenjamin Maxwell /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> 63b833bcb5SBenjamin Maxwell /// ``` 64b833bcb5SBenjamin Maxwell /// is rewritten into: 65b833bcb5SBenjamin Maxwell /// ``` 66b833bcb5SBenjamin Maxwell /// %cst = arith.constant dense<false> : vector<2x[16]xi1> 67b833bcb5SBenjamin Maxwell /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> 68b833bcb5SBenjamin Maxwell /// %2 = "arm_sve.intr.convert.to.svbool"(%1) 69b833bcb5SBenjamin Maxwell /// : (vector<[4]xi1>) -> vector<[16]xi1> 70b833bcb5SBenjamin Maxwell /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> 71b833bcb5SBenjamin Maxwell /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> 72b833bcb5SBenjamin Maxwell /// %5 = "arm_sve.intr.convert.to.svbool"(%4) 73b833bcb5SBenjamin Maxwell /// : (vector<[4]xi1>) -> vector<[16]xi1> 74b833bcb5SBenjamin Maxwell /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> 75b833bcb5SBenjamin Maxwell /// ``` 76b833bcb5SBenjamin Maxwell template <typename Op, typename IntrOp> 77b833bcb5SBenjamin Maxwell struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { 78b833bcb5SBenjamin Maxwell using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; 79b833bcb5SBenjamin Maxwell 80b833bcb5SBenjamin Maxwell LogicalResult 81b833bcb5SBenjamin Maxwell matchAndRewrite(Op convertOp, typename Op::Adaptor, 82b833bcb5SBenjamin Maxwell ConversionPatternRewriter &rewriter) const override { 83b833bcb5SBenjamin Maxwell auto loc = convertOp.getLoc(); 84b833bcb5SBenjamin Maxwell 85b833bcb5SBenjamin Maxwell auto source = convertOp.getSource(); 86b833bcb5SBenjamin Maxwell VectorType sourceType = source.getType(); 87b833bcb5SBenjamin Maxwell VectorType resultType = convertOp.getResult().getType(); 88b833bcb5SBenjamin Maxwell 89b833bcb5SBenjamin Maxwell Value result = rewriter.create<arith::ConstantOp>( 90b833bcb5SBenjamin Maxwell loc, resultType, rewriter.getZeroAttr(resultType)); 91b833bcb5SBenjamin Maxwell 92b833bcb5SBenjamin Maxwell // We want to iterate over the input vector in steps of the trailing 93b833bcb5SBenjamin Maxwell // dimension. So this creates tile shape where all leading dimensions are 1, 94b833bcb5SBenjamin Maxwell // and the trailing dimension step is the size of the dimension. 95b833bcb5SBenjamin Maxwell SmallVector<int64_t> tileShape(sourceType.getRank(), 1); 96b833bcb5SBenjamin Maxwell tileShape.back() = sourceType.getShape().back(); 97b833bcb5SBenjamin Maxwell 98b833bcb5SBenjamin Maxwell // Iterate over all scalable mask/predicate slices of the source vector. 99b833bcb5SBenjamin Maxwell for (SmallVector<int64_t> index : 100b833bcb5SBenjamin Maxwell StaticTileOffsetRange(sourceType.getShape(), tileShape)) { 101b833bcb5SBenjamin Maxwell auto extractOrInsertPosition = ArrayRef(index).drop_back(); 102b833bcb5SBenjamin Maxwell auto sourceVector = rewriter.create<vector::ExtractOp>( 103b833bcb5SBenjamin Maxwell loc, source, extractOrInsertPosition); 104b44b3494SBenjamin Maxwell VectorType convertedType = 105b833bcb5SBenjamin Maxwell VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) 106b833bcb5SBenjamin Maxwell .setDim(0, resultType.getShape().back()); 107b833bcb5SBenjamin Maxwell auto convertedVector = 108b833bcb5SBenjamin Maxwell rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); 109b833bcb5SBenjamin Maxwell result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, 110b833bcb5SBenjamin Maxwell extractOrInsertPosition); 111b833bcb5SBenjamin Maxwell } 112b833bcb5SBenjamin Maxwell 113b833bcb5SBenjamin Maxwell rewriter.replaceOp(convertOp, result); 114b833bcb5SBenjamin Maxwell return success(); 115b833bcb5SBenjamin Maxwell } 116b833bcb5SBenjamin Maxwell }; 117b833bcb5SBenjamin Maxwell 118b833bcb5SBenjamin Maxwell using ConvertToSvboolOpLowering = 119b833bcb5SBenjamin Maxwell SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; 120b833bcb5SBenjamin Maxwell 121b833bcb5SBenjamin Maxwell using ConvertFromSvboolOpLowering = 122b833bcb5SBenjamin Maxwell SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; 123b833bcb5SBenjamin Maxwell 1247dcca621SBenjamin Maxwell using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>; 1257dcca621SBenjamin Maxwell using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>; 1267dcca621SBenjamin Maxwell 12778113303SBenjamin Maxwell /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion 12878113303SBenjamin Maxwell /// but first input (P1) and result predicates need conversion to/from svbool. 12978113303SBenjamin Maxwell struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { 13078113303SBenjamin Maxwell using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 13178113303SBenjamin Maxwell 13278113303SBenjamin Maxwell LogicalResult 13378113303SBenjamin Maxwell matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor, 13478113303SBenjamin Maxwell ConversionPatternRewriter &rewriter) const override { 13578113303SBenjamin Maxwell auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); 13678113303SBenjamin Maxwell auto loc = pselOp.getLoc(); 13778113303SBenjamin Maxwell auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, 13878113303SBenjamin Maxwell adaptor.getP1()); 13978113303SBenjamin Maxwell auto indexI32 = rewriter.create<arith::IndexCastOp>( 14078113303SBenjamin Maxwell loc, rewriter.getI32Type(), pselOp.getIndex()); 14178113303SBenjamin Maxwell auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, 14278113303SBenjamin Maxwell pselOp.getP2(), indexI32); 14378113303SBenjamin Maxwell rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( 14478113303SBenjamin Maxwell pselOp, adaptor.getP1().getType(), pselIntr); 14578113303SBenjamin Maxwell return success(); 14678113303SBenjamin Maxwell } 14778113303SBenjamin Maxwell }; 14878113303SBenjamin Maxwell 149657ec732SBenjamin Maxwell /// Converts `vector.create_mask` ops that match the size of an SVE predicate 150657ec732SBenjamin Maxwell /// to the `whilelt` intrinsic. This produces more canonical codegen than the 151657ec732SBenjamin Maxwell /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840 152657ec732SBenjamin Maxwell /// for more details. Note that we can't use (the more general) active.lane.mask 153657ec732SBenjamin Maxwell /// as its semantics don't neatly map on to `vector.create_mask`, as it does an 154657ec732SBenjamin Maxwell /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if 155657ec732SBenjamin Maxwell /// `n` is zero (whereas `create_mask` just returns an all-false mask). 156657ec732SBenjamin Maxwell struct CreateMaskOpLowering 157657ec732SBenjamin Maxwell : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 158657ec732SBenjamin Maxwell using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 159657ec732SBenjamin Maxwell 160657ec732SBenjamin Maxwell LogicalResult 161657ec732SBenjamin Maxwell matchAndRewrite(vector::CreateMaskOp createMaskOp, 162657ec732SBenjamin Maxwell vector::CreateMaskOp::Adaptor adaptor, 163657ec732SBenjamin Maxwell ConversionPatternRewriter &rewriter) const override { 164657ec732SBenjamin Maxwell auto maskType = createMaskOp.getVectorType(); 165657ec732SBenjamin Maxwell if (maskType.getRank() != 1 || !maskType.isScalable()) 166657ec732SBenjamin Maxwell return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable"); 167657ec732SBenjamin Maxwell 168657ec732SBenjamin Maxwell // TODO: Support masks which are multiples of SVE predicates. 169657ec732SBenjamin Maxwell auto maskBaseSize = maskType.getDimSize(0); 170657ec732SBenjamin Maxwell if (maskBaseSize < 2 || maskBaseSize > 16 || 171657ec732SBenjamin Maxwell !llvm::isPowerOf2_32(uint32_t(maskBaseSize))) 172657ec732SBenjamin Maxwell return rewriter.notifyMatchFailure(createMaskOp, 173657ec732SBenjamin Maxwell "not SVE predicate-sized"); 174657ec732SBenjamin Maxwell 175657ec732SBenjamin Maxwell auto loc = createMaskOp.getLoc(); 176657ec732SBenjamin Maxwell auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type()); 177657ec732SBenjamin Maxwell rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, 178657ec732SBenjamin Maxwell adaptor.getOperands()[0]); 179657ec732SBenjamin Maxwell return success(); 180657ec732SBenjamin Maxwell } 181657ec732SBenjamin Maxwell }; 182657ec732SBenjamin Maxwell 183b833bcb5SBenjamin Maxwell } // namespace 184b833bcb5SBenjamin Maxwell 185b739badaSJavier Setoain /// Populate the given list with patterns that convert from ArmSVE to LLVM. 186b739badaSJavier Setoain void mlir::populateArmSVELegalizeForLLVMExportPatterns( 187*206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 188b739badaSJavier Setoain // Populate conversion patterns 189b739badaSJavier Setoain 190b739badaSJavier Setoain // clang-format off 191b739badaSJavier Setoain patterns.add<SdotOpLowering, 192b739badaSJavier Setoain SmmlaOpLowering, 193b739badaSJavier Setoain UdotOpLowering, 194b739badaSJavier Setoain UmmlaOpLowering, 19595861216SJavier Setoain ScalableMaskedAddIOpLowering, 19695861216SJavier Setoain ScalableMaskedAddFOpLowering, 19795861216SJavier Setoain ScalableMaskedSubIOpLowering, 19895861216SJavier Setoain ScalableMaskedSubFOpLowering, 19995861216SJavier Setoain ScalableMaskedMulIOpLowering, 20095861216SJavier Setoain ScalableMaskedMulFOpLowering, 20195861216SJavier Setoain ScalableMaskedSDivIOpLowering, 20295861216SJavier Setoain ScalableMaskedUDivIOpLowering, 203b833bcb5SBenjamin Maxwell ScalableMaskedDivFOpLowering, 204b833bcb5SBenjamin Maxwell ConvertToSvboolOpLowering, 2057dcca621SBenjamin Maxwell ConvertFromSvboolOpLowering, 2067dcca621SBenjamin Maxwell ZipX2OpLowering, 20778113303SBenjamin Maxwell ZipX4OpLowering, 20878113303SBenjamin Maxwell PselOpLowering>(converter); 209657ec732SBenjamin Maxwell // Add vector.create_mask conversion with a high benefit as it produces much 210657ec732SBenjamin Maxwell // nicer code than the generic lowering. 211657ec732SBenjamin Maxwell patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096); 212b739badaSJavier Setoain // clang-format on 213b739badaSJavier Setoain } 214b739badaSJavier Setoain 215b739badaSJavier Setoain void mlir::configureArmSVELegalizeForExportTarget( 216b739badaSJavier Setoain LLVMConversionTarget &target) { 21795861216SJavier Setoain // clang-format off 21895861216SJavier Setoain target.addLegalOp<SdotIntrOp, 21995861216SJavier Setoain SmmlaIntrOp, 22095861216SJavier Setoain UdotIntrOp, 22195861216SJavier Setoain UmmlaIntrOp, 22295861216SJavier Setoain ScalableMaskedAddIIntrOp, 22395861216SJavier Setoain ScalableMaskedAddFIntrOp, 22495861216SJavier Setoain ScalableMaskedSubIIntrOp, 22595861216SJavier Setoain ScalableMaskedSubFIntrOp, 22695861216SJavier Setoain ScalableMaskedMulIIntrOp, 22795861216SJavier Setoain ScalableMaskedMulFIntrOp, 22895861216SJavier Setoain ScalableMaskedSDivIIntrOp, 22995861216SJavier Setoain ScalableMaskedUDivIIntrOp, 230b833bcb5SBenjamin Maxwell ScalableMaskedDivFIntrOp, 231b833bcb5SBenjamin Maxwell ConvertToSvboolIntrOp, 2327dcca621SBenjamin Maxwell ConvertFromSvboolIntrOp, 2337dcca621SBenjamin Maxwell ZipX2IntrOp, 234657ec732SBenjamin Maxwell ZipX4IntrOp, 23578113303SBenjamin Maxwell PselIntrOp, 236657ec732SBenjamin Maxwell WhileLTIntrOp>(); 23795861216SJavier Setoain target.addIllegalOp<SdotOp, 23895861216SJavier Setoain SmmlaOp, 23995861216SJavier Setoain UdotOp, 24095861216SJavier Setoain UmmlaOp, 24195861216SJavier Setoain ScalableMaskedAddIOp, 24295861216SJavier Setoain ScalableMaskedAddFOp, 24395861216SJavier Setoain ScalableMaskedSubIOp, 24495861216SJavier Setoain ScalableMaskedSubFOp, 24595861216SJavier Setoain ScalableMaskedMulIOp, 24695861216SJavier Setoain ScalableMaskedMulFOp, 24795861216SJavier Setoain ScalableMaskedSDivIOp, 24895861216SJavier Setoain ScalableMaskedUDivIOp, 249b833bcb5SBenjamin Maxwell ScalableMaskedDivFOp, 250b833bcb5SBenjamin Maxwell ConvertToSvboolOp, 2517dcca621SBenjamin Maxwell ConvertFromSvboolOp, 2527dcca621SBenjamin Maxwell ZipX2Op, 2537dcca621SBenjamin Maxwell ZipX4Op>(); 25495861216SJavier Setoain // clang-format on 255b739badaSJavier Setoain } 256