1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::x86vector; 21 22 /// Extracts the "main" vector element type from the given X86Vector operation. 23 template <typename OpTy> 24 static Type getSrcVectorElementType(OpTy op) { 25 return cast<VectorType>(op.getSrc().getType()).getElementType(); 26 } 27 template <> 28 Type getSrcVectorElementType(Vp2IntersectOp op) { 29 return cast<VectorType>(op.getA().getType()).getElementType(); 30 } 31 32 namespace { 33 34 /// Base conversion for AVX512 ops that can be lowered to one of the two 35 /// intrinsics based on the bitwidth of their "main" vector element type. This 36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 37 /// results of multi-result intrinsic ops. 38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 40 explicit LowerToIntrinsic(const LLVMTypeConverter &converter) 41 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 42 43 const LLVMTypeConverter &getTypeConverter() const { 44 return *static_cast<const LLVMTypeConverter *>( 45 OpConversionPattern<OpTy>::getTypeConverter()); 46 } 47 48 LogicalResult 49 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 50 ConversionPatternRewriter &rewriter) const override { 51 Type elementType = getSrcVectorElementType<OpTy>(op); 52 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 53 if (bitwidth == 32) 54 return LLVM::detail::oneToOneRewrite( 55 op, Intr32OpTy::getOperationName(), adaptor.getOperands(), 56 op->getAttrs(), getTypeConverter(), rewriter); 57 if (bitwidth == 64) 58 return LLVM::detail::oneToOneRewrite( 59 op, Intr64OpTy::getOperationName(), adaptor.getOperands(), 60 op->getAttrs(), getTypeConverter(), rewriter); 61 return rewriter.notifyMatchFailure( 62 op, "expected 'src' to be either f32 or f64"); 63 } 64 }; 65 66 struct MaskCompressOpConversion 67 : public ConvertOpToLLVMPattern<MaskCompressOp> { 68 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 69 70 LogicalResult 71 matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, 72 ConversionPatternRewriter &rewriter) const override { 73 auto opType = adaptor.getA().getType(); 74 75 Value src; 76 if (op.getSrc()) { 77 src = adaptor.getSrc(); 78 } else if (op.getConstantSrc()) { 79 src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType, 80 op.getConstantSrcAttr()); 81 } else { 82 auto zeroAttr = rewriter.getZeroAttr(opType); 83 src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr); 84 } 85 86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(), 87 src, adaptor.getK()); 88 89 return success(); 90 } 91 }; 92 93 struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> { 94 using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(DotBF16Op op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override { 99 auto typeA = dyn_cast<VectorType>(op.getA().getType()); 100 unsigned elemBitWidth = typeA.getElementTypeBitWidth(); 101 unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth; 102 103 auto opType = adaptor.getSrc().getType(); 104 auto opSrc = adaptor.getSrc(); 105 auto opA = adaptor.getA(); 106 auto opB = adaptor.getB(); 107 108 switch (opBitWidth) { 109 case 128: { 110 rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA, 111 opB); 112 break; 113 } 114 case 256: { 115 rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA, 116 opB); 117 break; 118 } 119 case 512: { 120 rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA, 121 opB); 122 break; 123 } 124 default: { 125 return rewriter.notifyMatchFailure(op, 126 "unsupported AVX512-BF16 dot variant"); 127 } 128 } 129 130 return success(); 131 } 132 }; 133 134 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 135 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 136 137 LogicalResult 138 matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, 139 ConversionPatternRewriter &rewriter) const override { 140 auto opType = adaptor.getA().getType(); 141 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA()); 142 return success(); 143 } 144 }; 145 146 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 147 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 148 149 LogicalResult 150 matchAndRewrite(DotOp op, OpAdaptor adaptor, 151 ConversionPatternRewriter &rewriter) const override { 152 auto opType = adaptor.getA().getType(); 153 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 154 // Dot product of all elements, broadcasted to all elements. 155 auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); 156 Value scale = 157 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 158 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(), 159 adaptor.getB(), scale); 160 return success(); 161 } 162 }; 163 164 /// An entry associating the "main" AVX512 op with its instantiations for 165 /// vectors of 32-bit and 64-bit elements. 166 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 167 struct RegEntry { 168 using MainOp = OpTy; 169 using Intr32Op = Intr32OpTy; 170 using Intr64Op = Intr64OpTy; 171 }; 172 173 /// A container for op association entries facilitating the configuration of 174 /// dialect conversion. 175 template <typename... Args> 176 struct RegistryImpl { 177 /// Registers the patterns specializing the "main" op to one of the 178 /// "intrinsic" ops depending on elemental type. 179 static void registerPatterns(const LLVMTypeConverter &converter, 180 RewritePatternSet &patterns) { 181 patterns 182 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 183 typename Args::Intr64Op>...>(converter); 184 } 185 186 /// Configures the conversion target to lower out "main" ops. 187 static void configureTarget(LLVMConversionTarget &target) { 188 target.addIllegalOp<typename Args::MainOp...>(); 189 target.addLegalOp<typename Args::Intr32Op...>(); 190 target.addLegalOp<typename Args::Intr64Op...>(); 191 } 192 }; 193 194 using Registry = RegistryImpl< 195 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 196 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 197 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 198 199 } // namespace 200 201 /// Populate the given list with patterns that convert from X86Vector to LLVM. 202 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 203 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 204 Registry::registerPatterns(converter, patterns); 205 patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion, 206 DotOpConversion>(converter); 207 } 208 209 void mlir::configureX86VectorLegalizeForExportTarget( 210 LLVMConversionTarget &target) { 211 Registry::configureTarget(target); 212 target.addLegalOp<MaskCompressIntrOp>(); 213 target.addIllegalOp<MaskCompressOp>(); 214 target.addLegalOp<DotBF16Ps128IntrOp>(); 215 target.addLegalOp<DotBF16Ps256IntrOp>(); 216 target.addLegalOp<DotBF16Ps512IntrOp>(); 217 target.addIllegalOp<DotBF16Op>(); 218 target.addLegalOp<RsqrtIntrOp>(); 219 target.addIllegalOp<RsqrtOp>(); 220 target.addLegalOp<DotIntrOp>(); 221 target.addIllegalOp<DotOp>(); 222 } 223