18508a63bSEmilio Cota //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 28508a63bSEmilio Cota // 38508a63bSEmilio Cota // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48508a63bSEmilio Cota // See https://llvm.org/LICENSE.txt for license information. 58508a63bSEmilio Cota // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68508a63bSEmilio Cota // 78508a63bSEmilio Cota //===----------------------------------------------------------------------===// 88508a63bSEmilio Cota 98508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/Transforms.h" 108508a63bSEmilio Cota 1175e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 148508a63bSEmilio Cota #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 158508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 168508a63bSEmilio Cota #include "mlir/IR/BuiltinOps.h" 178508a63bSEmilio Cota #include "mlir/IR/PatternMatch.h" 188508a63bSEmilio Cota 198508a63bSEmilio Cota using namespace mlir; 208508a63bSEmilio Cota using namespace mlir::x86vector; 218508a63bSEmilio Cota 228508a63bSEmilio Cota /// Extracts the "main" vector element type from the given X86Vector operation. 238508a63bSEmilio Cota template <typename OpTy> 248508a63bSEmilio Cota static Type getSrcVectorElementType(OpTy op) { 255550c821STres Popp return cast<VectorType>(op.getSrc().getType()).getElementType(); 268508a63bSEmilio Cota } 278508a63bSEmilio Cota template <> 288508a63bSEmilio Cota Type getSrcVectorElementType(Vp2IntersectOp op) { 295550c821STres Popp return cast<VectorType>(op.getA().getType()).getElementType(); 308508a63bSEmilio Cota } 318508a63bSEmilio Cota 328508a63bSEmilio Cota namespace { 338508a63bSEmilio Cota 348508a63bSEmilio Cota /// Base conversion for AVX512 ops that can be lowered to one of the two 358508a63bSEmilio Cota /// intrinsics based on the bitwidth of their "main" vector element type. This 368508a63bSEmilio Cota /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 378508a63bSEmilio Cota /// results of multi-result intrinsic ops. 388508a63bSEmilio Cota template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 398508a63bSEmilio Cota struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 40206fad0eSMatthias Springer explicit LowerToIntrinsic(const LLVMTypeConverter &converter) 418508a63bSEmilio Cota : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 428508a63bSEmilio Cota 43ce254598SMatthias Springer const LLVMTypeConverter &getTypeConverter() const { 44ce254598SMatthias Springer return *static_cast<const LLVMTypeConverter *>( 458508a63bSEmilio Cota OpConversionPattern<OpTy>::getTypeConverter()); 468508a63bSEmilio Cota } 478508a63bSEmilio Cota 488508a63bSEmilio Cota LogicalResult 49b54c724bSRiver Riddle matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 508508a63bSEmilio Cota ConversionPatternRewriter &rewriter) const override { 518508a63bSEmilio Cota Type elementType = getSrcVectorElementType<OpTy>(op); 528508a63bSEmilio Cota unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 538508a63bSEmilio Cota if (bitwidth == 32) 54b56e65d3SJeremy Furtek return LLVM::detail::oneToOneRewrite( 55b56e65d3SJeremy Furtek op, Intr32OpTy::getOperationName(), adaptor.getOperands(), 56b56e65d3SJeremy Furtek op->getAttrs(), getTypeConverter(), rewriter); 578508a63bSEmilio Cota if (bitwidth == 64) 58b56e65d3SJeremy Furtek return LLVM::detail::oneToOneRewrite( 59b56e65d3SJeremy Furtek op, Intr64OpTy::getOperationName(), adaptor.getOperands(), 60b56e65d3SJeremy Furtek op->getAttrs(), getTypeConverter(), rewriter); 618508a63bSEmilio Cota return rewriter.notifyMatchFailure( 628508a63bSEmilio Cota op, "expected 'src' to be either f32 or f64"); 638508a63bSEmilio Cota } 648508a63bSEmilio Cota }; 658508a63bSEmilio Cota 668508a63bSEmilio Cota struct MaskCompressOpConversion 678508a63bSEmilio Cota : public ConvertOpToLLVMPattern<MaskCompressOp> { 688508a63bSEmilio Cota using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 698508a63bSEmilio Cota 708508a63bSEmilio Cota LogicalResult 71b54c724bSRiver Riddle matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, 728508a63bSEmilio Cota ConversionPatternRewriter &rewriter) const override { 738df54a6aSJacques Pienaar auto opType = adaptor.getA().getType(); 748508a63bSEmilio Cota 758508a63bSEmilio Cota Value src; 768df54a6aSJacques Pienaar if (op.getSrc()) { 778df54a6aSJacques Pienaar src = adaptor.getSrc(); 788df54a6aSJacques Pienaar } else if (op.getConstantSrc()) { 79a54f4eaeSMogball src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType, 808df54a6aSJacques Pienaar op.getConstantSrcAttr()); 818508a63bSEmilio Cota } else { 826089d612SRahul Kayaith auto zeroAttr = rewriter.getZeroAttr(opType); 83a54f4eaeSMogball src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr); 848508a63bSEmilio Cota } 858508a63bSEmilio Cota 868df54a6aSJacques Pienaar rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(), 878df54a6aSJacques Pienaar src, adaptor.getK()); 888508a63bSEmilio Cota 898508a63bSEmilio Cota return success(); 908508a63bSEmilio Cota } 918508a63bSEmilio Cota }; 928508a63bSEmilio Cota 93*87782b21SAdam Siemieniuk struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> { 94*87782b21SAdam Siemieniuk using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern; 95*87782b21SAdam Siemieniuk 96*87782b21SAdam Siemieniuk LogicalResult 97*87782b21SAdam Siemieniuk matchAndRewrite(DotBF16Op op, OpAdaptor adaptor, 98*87782b21SAdam Siemieniuk ConversionPatternRewriter &rewriter) const override { 99*87782b21SAdam Siemieniuk auto typeA = dyn_cast<VectorType>(op.getA().getType()); 100*87782b21SAdam Siemieniuk unsigned elemBitWidth = typeA.getElementTypeBitWidth(); 101*87782b21SAdam Siemieniuk unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth; 102*87782b21SAdam Siemieniuk 103*87782b21SAdam Siemieniuk auto opType = adaptor.getSrc().getType(); 104*87782b21SAdam Siemieniuk auto opSrc = adaptor.getSrc(); 105*87782b21SAdam Siemieniuk auto opA = adaptor.getA(); 106*87782b21SAdam Siemieniuk auto opB = adaptor.getB(); 107*87782b21SAdam Siemieniuk 108*87782b21SAdam Siemieniuk switch (opBitWidth) { 109*87782b21SAdam Siemieniuk case 128: { 110*87782b21SAdam Siemieniuk rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA, 111*87782b21SAdam Siemieniuk opB); 112*87782b21SAdam Siemieniuk break; 113*87782b21SAdam Siemieniuk } 114*87782b21SAdam Siemieniuk case 256: { 115*87782b21SAdam Siemieniuk rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA, 116*87782b21SAdam Siemieniuk opB); 117*87782b21SAdam Siemieniuk break; 118*87782b21SAdam Siemieniuk } 119*87782b21SAdam Siemieniuk case 512: { 120*87782b21SAdam Siemieniuk rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA, 121*87782b21SAdam Siemieniuk opB); 122*87782b21SAdam Siemieniuk break; 123*87782b21SAdam Siemieniuk } 124*87782b21SAdam Siemieniuk default: { 125*87782b21SAdam Siemieniuk return rewriter.notifyMatchFailure(op, 126*87782b21SAdam Siemieniuk "unsupported AVX512-BF16 dot variant"); 127*87782b21SAdam Siemieniuk } 128*87782b21SAdam Siemieniuk } 129*87782b21SAdam Siemieniuk 130*87782b21SAdam Siemieniuk return success(); 131*87782b21SAdam Siemieniuk } 132*87782b21SAdam Siemieniuk }; 133*87782b21SAdam Siemieniuk 1340b63e322SEmilio Cota struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 1350b63e322SEmilio Cota using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 1360b63e322SEmilio Cota 1370b63e322SEmilio Cota LogicalResult 138b54c724bSRiver Riddle matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, 1390b63e322SEmilio Cota ConversionPatternRewriter &rewriter) const override { 1408df54a6aSJacques Pienaar auto opType = adaptor.getA().getType(); 1418df54a6aSJacques Pienaar rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA()); 1420b63e322SEmilio Cota return success(); 1430b63e322SEmilio Cota } 1440b63e322SEmilio Cota }; 1450b63e322SEmilio Cota 146916f3e16SAart Bik struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 147916f3e16SAart Bik using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 148916f3e16SAart Bik 149916f3e16SAart Bik LogicalResult 150b54c724bSRiver Riddle matchAndRewrite(DotOp op, OpAdaptor adaptor, 151916f3e16SAart Bik ConversionPatternRewriter &rewriter) const override { 1528df54a6aSJacques Pienaar auto opType = adaptor.getA().getType(); 153916f3e16SAart Bik Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 154916f3e16SAart Bik // Dot product of all elements, broadcasted to all elements. 15566b1e629SMatthias Springer auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); 156916f3e16SAart Bik Value scale = 157916f3e16SAart Bik rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 1588df54a6aSJacques Pienaar rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(), 1598df54a6aSJacques Pienaar adaptor.getB(), scale); 160916f3e16SAart Bik return success(); 161916f3e16SAart Bik } 162916f3e16SAart Bik }; 163916f3e16SAart Bik 1648508a63bSEmilio Cota /// An entry associating the "main" AVX512 op with its instantiations for 1658508a63bSEmilio Cota /// vectors of 32-bit and 64-bit elements. 1668508a63bSEmilio Cota template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 1678508a63bSEmilio Cota struct RegEntry { 1688508a63bSEmilio Cota using MainOp = OpTy; 1698508a63bSEmilio Cota using Intr32Op = Intr32OpTy; 1708508a63bSEmilio Cota using Intr64Op = Intr64OpTy; 1718508a63bSEmilio Cota }; 1728508a63bSEmilio Cota 1738508a63bSEmilio Cota /// A container for op association entries facilitating the configuration of 1748508a63bSEmilio Cota /// dialect conversion. 1758508a63bSEmilio Cota template <typename... Args> 1768508a63bSEmilio Cota struct RegistryImpl { 1778508a63bSEmilio Cota /// Registers the patterns specializing the "main" op to one of the 1788508a63bSEmilio Cota /// "intrinsic" ops depending on elemental type. 179206fad0eSMatthias Springer static void registerPatterns(const LLVMTypeConverter &converter, 1808508a63bSEmilio Cota RewritePatternSet &patterns) { 1818508a63bSEmilio Cota patterns 1828508a63bSEmilio Cota .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 1838508a63bSEmilio Cota typename Args::Intr64Op>...>(converter); 1848508a63bSEmilio Cota } 1858508a63bSEmilio Cota 1868508a63bSEmilio Cota /// Configures the conversion target to lower out "main" ops. 1878508a63bSEmilio Cota static void configureTarget(LLVMConversionTarget &target) { 1888508a63bSEmilio Cota target.addIllegalOp<typename Args::MainOp...>(); 1898508a63bSEmilio Cota target.addLegalOp<typename Args::Intr32Op...>(); 1908508a63bSEmilio Cota target.addLegalOp<typename Args::Intr64Op...>(); 1918508a63bSEmilio Cota } 1928508a63bSEmilio Cota }; 1938508a63bSEmilio Cota 1948508a63bSEmilio Cota using Registry = RegistryImpl< 1958508a63bSEmilio Cota RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 1968508a63bSEmilio Cota RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 1978508a63bSEmilio Cota RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 1988508a63bSEmilio Cota 1998508a63bSEmilio Cota } // namespace 2008508a63bSEmilio Cota 2018508a63bSEmilio Cota /// Populate the given list with patterns that convert from X86Vector to LLVM. 2028508a63bSEmilio Cota void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 203206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 2048508a63bSEmilio Cota Registry::registerPatterns(converter, patterns); 205*87782b21SAdam Siemieniuk patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion, 206*87782b21SAdam Siemieniuk DotOpConversion>(converter); 2078508a63bSEmilio Cota } 2088508a63bSEmilio Cota 2098508a63bSEmilio Cota void mlir::configureX86VectorLegalizeForExportTarget( 2108508a63bSEmilio Cota LLVMConversionTarget &target) { 2118508a63bSEmilio Cota Registry::configureTarget(target); 2128508a63bSEmilio Cota target.addLegalOp<MaskCompressIntrOp>(); 2138508a63bSEmilio Cota target.addIllegalOp<MaskCompressOp>(); 214*87782b21SAdam Siemieniuk target.addLegalOp<DotBF16Ps128IntrOp>(); 215*87782b21SAdam Siemieniuk target.addLegalOp<DotBF16Ps256IntrOp>(); 216*87782b21SAdam Siemieniuk target.addLegalOp<DotBF16Ps512IntrOp>(); 217*87782b21SAdam Siemieniuk target.addIllegalOp<DotBF16Op>(); 2180b63e322SEmilio Cota target.addLegalOp<RsqrtIntrOp>(); 2190b63e322SEmilio Cota target.addIllegalOp<RsqrtOp>(); 220916f3e16SAart Bik target.addLegalOp<DotIntrOp>(); 221916f3e16SAart Bik target.addIllegalOp<DotOp>(); 2228508a63bSEmilio Cota } 223