1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 10 #include "mlir/Conversion/LLVMCommon/Pattern.h" 11 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 12 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Utils/IndexingUtils.h" 16 #include "mlir/Dialect/Vector/IR/VectorOps.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/PatternMatch.h" 19 20 using namespace mlir; 21 using namespace mlir::arm_sve; 22 23 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 24 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 25 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 26 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 27 using ScalableMaskedAddIOpLowering = 28 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 29 ScalableMaskedAddIIntrOp>; 30 using ScalableMaskedAddFOpLowering = 31 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 32 ScalableMaskedAddFIntrOp>; 33 using ScalableMaskedSubIOpLowering = 34 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 35 ScalableMaskedSubIIntrOp>; 36 using ScalableMaskedSubFOpLowering = 37 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 38 ScalableMaskedSubFIntrOp>; 39 using ScalableMaskedMulIOpLowering = 40 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 41 ScalableMaskedMulIIntrOp>; 42 using ScalableMaskedMulFOpLowering = 43 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 44 ScalableMaskedMulFIntrOp>; 45 using ScalableMaskedSDivIOpLowering = 46 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 47 ScalableMaskedSDivIIntrOp>; 48 using ScalableMaskedUDivIOpLowering = 49 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 50 ScalableMaskedUDivIIntrOp>; 51 using ScalableMaskedDivFOpLowering = 52 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 53 ScalableMaskedDivFIntrOp>; 54 55 namespace { 56 57 /// Unrolls a conversion to/from equivalent vector types, to allow using a 58 /// conversion intrinsic that only supports 1-D vector types. 59 /// 60 /// Example: 61 /// ``` 62 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> 63 /// ``` 64 /// is rewritten into: 65 /// ``` 66 /// %cst = arith.constant dense<false> : vector<2x[16]xi1> 67 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> 68 /// %2 = "arm_sve.intr.convert.to.svbool"(%1) 69 /// : (vector<[4]xi1>) -> vector<[16]xi1> 70 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> 71 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> 72 /// %5 = "arm_sve.intr.convert.to.svbool"(%4) 73 /// : (vector<[4]xi1>) -> vector<[16]xi1> 74 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> 75 /// ``` 76 template <typename Op, typename IntrOp> 77 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { 78 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; 79 80 LogicalResult 81 matchAndRewrite(Op convertOp, typename Op::Adaptor, 82 ConversionPatternRewriter &rewriter) const override { 83 auto loc = convertOp.getLoc(); 84 85 auto source = convertOp.getSource(); 86 VectorType sourceType = source.getType(); 87 VectorType resultType = convertOp.getResult().getType(); 88 89 Value result = rewriter.create<arith::ConstantOp>( 90 loc, resultType, rewriter.getZeroAttr(resultType)); 91 92 // We want to iterate over the input vector in steps of the trailing 93 // dimension. So this creates tile shape where all leading dimensions are 1, 94 // and the trailing dimension step is the size of the dimension. 95 SmallVector<int64_t> tileShape(sourceType.getRank(), 1); 96 tileShape.back() = sourceType.getShape().back(); 97 98 // Iterate over all scalable mask/predicate slices of the source vector. 99 for (SmallVector<int64_t> index : 100 StaticTileOffsetRange(sourceType.getShape(), tileShape)) { 101 auto extractOrInsertPosition = ArrayRef(index).drop_back(); 102 auto sourceVector = rewriter.create<vector::ExtractOp>( 103 loc, source, extractOrInsertPosition); 104 VectorType convertedType = 105 VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) 106 .setDim(0, resultType.getShape().back()); 107 auto convertedVector = 108 rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); 109 result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, 110 extractOrInsertPosition); 111 } 112 113 rewriter.replaceOp(convertOp, result); 114 return success(); 115 } 116 }; 117 118 using ConvertToSvboolOpLowering = 119 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; 120 121 using ConvertFromSvboolOpLowering = 122 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; 123 124 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>; 125 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>; 126 127 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion 128 /// but first input (P1) and result predicates need conversion to/from svbool. 129 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { 130 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 131 132 LogicalResult 133 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor, 134 ConversionPatternRewriter &rewriter) const override { 135 auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); 136 auto loc = pselOp.getLoc(); 137 auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, 138 adaptor.getP1()); 139 auto indexI32 = rewriter.create<arith::IndexCastOp>( 140 loc, rewriter.getI32Type(), pselOp.getIndex()); 141 auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, 142 pselOp.getP2(), indexI32); 143 rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( 144 pselOp, adaptor.getP1().getType(), pselIntr); 145 return success(); 146 } 147 }; 148 149 /// Converts `vector.create_mask` ops that match the size of an SVE predicate 150 /// to the `whilelt` intrinsic. This produces more canonical codegen than the 151 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840 152 /// for more details. Note that we can't use (the more general) active.lane.mask 153 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an 154 /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if 155 /// `n` is zero (whereas `create_mask` just returns an all-false mask). 156 struct CreateMaskOpLowering 157 : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 158 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(vector::CreateMaskOp createMaskOp, 162 vector::CreateMaskOp::Adaptor adaptor, 163 ConversionPatternRewriter &rewriter) const override { 164 auto maskType = createMaskOp.getVectorType(); 165 if (maskType.getRank() != 1 || !maskType.isScalable()) 166 return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable"); 167 168 // TODO: Support masks which are multiples of SVE predicates. 169 auto maskBaseSize = maskType.getDimSize(0); 170 if (maskBaseSize < 2 || maskBaseSize > 16 || 171 !llvm::isPowerOf2_32(uint32_t(maskBaseSize))) 172 return rewriter.notifyMatchFailure(createMaskOp, 173 "not SVE predicate-sized"); 174 175 auto loc = createMaskOp.getLoc(); 176 auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type()); 177 rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, 178 adaptor.getOperands()[0]); 179 return success(); 180 } 181 }; 182 183 } // namespace 184 185 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 186 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 187 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 188 // Populate conversion patterns 189 190 // clang-format off 191 patterns.add<SdotOpLowering, 192 SmmlaOpLowering, 193 UdotOpLowering, 194 UmmlaOpLowering, 195 ScalableMaskedAddIOpLowering, 196 ScalableMaskedAddFOpLowering, 197 ScalableMaskedSubIOpLowering, 198 ScalableMaskedSubFOpLowering, 199 ScalableMaskedMulIOpLowering, 200 ScalableMaskedMulFOpLowering, 201 ScalableMaskedSDivIOpLowering, 202 ScalableMaskedUDivIOpLowering, 203 ScalableMaskedDivFOpLowering, 204 ConvertToSvboolOpLowering, 205 ConvertFromSvboolOpLowering, 206 ZipX2OpLowering, 207 ZipX4OpLowering, 208 PselOpLowering>(converter); 209 // Add vector.create_mask conversion with a high benefit as it produces much 210 // nicer code than the generic lowering. 211 patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096); 212 // clang-format on 213 } 214 215 void mlir::configureArmSVELegalizeForExportTarget( 216 LLVMConversionTarget &target) { 217 // clang-format off 218 target.addLegalOp<SdotIntrOp, 219 SmmlaIntrOp, 220 UdotIntrOp, 221 UmmlaIntrOp, 222 ScalableMaskedAddIIntrOp, 223 ScalableMaskedAddFIntrOp, 224 ScalableMaskedSubIIntrOp, 225 ScalableMaskedSubFIntrOp, 226 ScalableMaskedMulIIntrOp, 227 ScalableMaskedMulFIntrOp, 228 ScalableMaskedSDivIIntrOp, 229 ScalableMaskedUDivIIntrOp, 230 ScalableMaskedDivFIntrOp, 231 ConvertToSvboolIntrOp, 232 ConvertFromSvboolIntrOp, 233 ZipX2IntrOp, 234 ZipX4IntrOp, 235 PselIntrOp, 236 WhileLTIntrOp>(); 237 target.addIllegalOp<SdotOp, 238 SmmlaOp, 239 UdotOp, 240 UmmlaOp, 241 ScalableMaskedAddIOp, 242 ScalableMaskedAddFOp, 243 ScalableMaskedSubIOp, 244 ScalableMaskedSubFOp, 245 ScalableMaskedMulIOp, 246 ScalableMaskedMulFOp, 247 ScalableMaskedSDivIOp, 248 ScalableMaskedUDivIOp, 249 ScalableMaskedDivFOp, 250 ConvertToSvboolOp, 251 ConvertFromSvboolOp, 252 ZipX2Op, 253 ZipX4Op>(); 254 // clang-format on 255 } 256