xref: /llvm-project/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (revision 87782b216fd3e7a8f8b2de04d4af467b390e9a34)
18508a63bSEmilio Cota //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
28508a63bSEmilio Cota //
38508a63bSEmilio Cota // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48508a63bSEmilio Cota // See https://llvm.org/LICENSE.txt for license information.
58508a63bSEmilio Cota // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68508a63bSEmilio Cota //
78508a63bSEmilio Cota //===----------------------------------------------------------------------===//
88508a63bSEmilio Cota 
98508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/Transforms.h"
108508a63bSEmilio Cota 
1175e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
148508a63bSEmilio Cota #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
158508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
168508a63bSEmilio Cota #include "mlir/IR/BuiltinOps.h"
178508a63bSEmilio Cota #include "mlir/IR/PatternMatch.h"
188508a63bSEmilio Cota 
198508a63bSEmilio Cota using namespace mlir;
208508a63bSEmilio Cota using namespace mlir::x86vector;
218508a63bSEmilio Cota 
228508a63bSEmilio Cota /// Extracts the "main" vector element type from the given X86Vector operation.
238508a63bSEmilio Cota template <typename OpTy>
248508a63bSEmilio Cota static Type getSrcVectorElementType(OpTy op) {
255550c821STres Popp   return cast<VectorType>(op.getSrc().getType()).getElementType();
268508a63bSEmilio Cota }
278508a63bSEmilio Cota template <>
288508a63bSEmilio Cota Type getSrcVectorElementType(Vp2IntersectOp op) {
295550c821STres Popp   return cast<VectorType>(op.getA().getType()).getElementType();
308508a63bSEmilio Cota }
318508a63bSEmilio Cota 
328508a63bSEmilio Cota namespace {
338508a63bSEmilio Cota 
348508a63bSEmilio Cota /// Base conversion for AVX512 ops that can be lowered to one of the two
358508a63bSEmilio Cota /// intrinsics based on the bitwidth of their "main" vector element type. This
368508a63bSEmilio Cota /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
378508a63bSEmilio Cota /// results of multi-result intrinsic ops.
388508a63bSEmilio Cota template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
398508a63bSEmilio Cota struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40206fad0eSMatthias Springer   explicit LowerToIntrinsic(const LLVMTypeConverter &converter)
418508a63bSEmilio Cota       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
428508a63bSEmilio Cota 
43ce254598SMatthias Springer   const LLVMTypeConverter &getTypeConverter() const {
44ce254598SMatthias Springer     return *static_cast<const LLVMTypeConverter *>(
458508a63bSEmilio Cota         OpConversionPattern<OpTy>::getTypeConverter());
468508a63bSEmilio Cota   }
478508a63bSEmilio Cota 
488508a63bSEmilio Cota   LogicalResult
49b54c724bSRiver Riddle   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
508508a63bSEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
518508a63bSEmilio Cota     Type elementType = getSrcVectorElementType<OpTy>(op);
528508a63bSEmilio Cota     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
538508a63bSEmilio Cota     if (bitwidth == 32)
54b56e65d3SJeremy Furtek       return LLVM::detail::oneToOneRewrite(
55b56e65d3SJeremy Furtek           op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
56b56e65d3SJeremy Furtek           op->getAttrs(), getTypeConverter(), rewriter);
578508a63bSEmilio Cota     if (bitwidth == 64)
58b56e65d3SJeremy Furtek       return LLVM::detail::oneToOneRewrite(
59b56e65d3SJeremy Furtek           op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
60b56e65d3SJeremy Furtek           op->getAttrs(), getTypeConverter(), rewriter);
618508a63bSEmilio Cota     return rewriter.notifyMatchFailure(
628508a63bSEmilio Cota         op, "expected 'src' to be either f32 or f64");
638508a63bSEmilio Cota   }
648508a63bSEmilio Cota };
658508a63bSEmilio Cota 
668508a63bSEmilio Cota struct MaskCompressOpConversion
678508a63bSEmilio Cota     : public ConvertOpToLLVMPattern<MaskCompressOp> {
688508a63bSEmilio Cota   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
698508a63bSEmilio Cota 
708508a63bSEmilio Cota   LogicalResult
71b54c724bSRiver Riddle   matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
728508a63bSEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
738df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
748508a63bSEmilio Cota 
758508a63bSEmilio Cota     Value src;
768df54a6aSJacques Pienaar     if (op.getSrc()) {
778df54a6aSJacques Pienaar       src = adaptor.getSrc();
788df54a6aSJacques Pienaar     } else if (op.getConstantSrc()) {
79a54f4eaeSMogball       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
808df54a6aSJacques Pienaar                                                op.getConstantSrcAttr());
818508a63bSEmilio Cota     } else {
826089d612SRahul Kayaith       auto zeroAttr = rewriter.getZeroAttr(opType);
83a54f4eaeSMogball       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
848508a63bSEmilio Cota     }
858508a63bSEmilio Cota 
868df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
878df54a6aSJacques Pienaar                                                     src, adaptor.getK());
888508a63bSEmilio Cota 
898508a63bSEmilio Cota     return success();
908508a63bSEmilio Cota   }
918508a63bSEmilio Cota };
928508a63bSEmilio Cota 
93*87782b21SAdam Siemieniuk struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
94*87782b21SAdam Siemieniuk   using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
95*87782b21SAdam Siemieniuk 
96*87782b21SAdam Siemieniuk   LogicalResult
97*87782b21SAdam Siemieniuk   matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
98*87782b21SAdam Siemieniuk                   ConversionPatternRewriter &rewriter) const override {
99*87782b21SAdam Siemieniuk     auto typeA = dyn_cast<VectorType>(op.getA().getType());
100*87782b21SAdam Siemieniuk     unsigned elemBitWidth = typeA.getElementTypeBitWidth();
101*87782b21SAdam Siemieniuk     unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
102*87782b21SAdam Siemieniuk 
103*87782b21SAdam Siemieniuk     auto opType = adaptor.getSrc().getType();
104*87782b21SAdam Siemieniuk     auto opSrc = adaptor.getSrc();
105*87782b21SAdam Siemieniuk     auto opA = adaptor.getA();
106*87782b21SAdam Siemieniuk     auto opB = adaptor.getB();
107*87782b21SAdam Siemieniuk 
108*87782b21SAdam Siemieniuk     switch (opBitWidth) {
109*87782b21SAdam Siemieniuk     case 128: {
110*87782b21SAdam Siemieniuk       rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
111*87782b21SAdam Siemieniuk                                                       opB);
112*87782b21SAdam Siemieniuk       break;
113*87782b21SAdam Siemieniuk     }
114*87782b21SAdam Siemieniuk     case 256: {
115*87782b21SAdam Siemieniuk       rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
116*87782b21SAdam Siemieniuk                                                       opB);
117*87782b21SAdam Siemieniuk       break;
118*87782b21SAdam Siemieniuk     }
119*87782b21SAdam Siemieniuk     case 512: {
120*87782b21SAdam Siemieniuk       rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
121*87782b21SAdam Siemieniuk                                                       opB);
122*87782b21SAdam Siemieniuk       break;
123*87782b21SAdam Siemieniuk     }
124*87782b21SAdam Siemieniuk     default: {
125*87782b21SAdam Siemieniuk       return rewriter.notifyMatchFailure(op,
126*87782b21SAdam Siemieniuk                                          "unsupported AVX512-BF16 dot variant");
127*87782b21SAdam Siemieniuk     }
128*87782b21SAdam Siemieniuk     }
129*87782b21SAdam Siemieniuk 
130*87782b21SAdam Siemieniuk     return success();
131*87782b21SAdam Siemieniuk   }
132*87782b21SAdam Siemieniuk };
133*87782b21SAdam Siemieniuk 
1340b63e322SEmilio Cota struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
1350b63e322SEmilio Cota   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
1360b63e322SEmilio Cota 
1370b63e322SEmilio Cota   LogicalResult
138b54c724bSRiver Riddle   matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
1390b63e322SEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
1408df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
1418df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
1420b63e322SEmilio Cota     return success();
1430b63e322SEmilio Cota   }
1440b63e322SEmilio Cota };
1450b63e322SEmilio Cota 
146916f3e16SAart Bik struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
147916f3e16SAart Bik   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
148916f3e16SAart Bik 
149916f3e16SAart Bik   LogicalResult
150b54c724bSRiver Riddle   matchAndRewrite(DotOp op, OpAdaptor adaptor,
151916f3e16SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1528df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
153916f3e16SAart Bik     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
154916f3e16SAart Bik     // Dot product of all elements, broadcasted to all elements.
15566b1e629SMatthias Springer     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
156916f3e16SAart Bik     Value scale =
157916f3e16SAart Bik         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
1588df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
1598df54a6aSJacques Pienaar                                            adaptor.getB(), scale);
160916f3e16SAart Bik     return success();
161916f3e16SAart Bik   }
162916f3e16SAart Bik };
163916f3e16SAart Bik 
1648508a63bSEmilio Cota /// An entry associating the "main" AVX512 op with its instantiations for
1658508a63bSEmilio Cota /// vectors of 32-bit and 64-bit elements.
1668508a63bSEmilio Cota template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
1678508a63bSEmilio Cota struct RegEntry {
1688508a63bSEmilio Cota   using MainOp = OpTy;
1698508a63bSEmilio Cota   using Intr32Op = Intr32OpTy;
1708508a63bSEmilio Cota   using Intr64Op = Intr64OpTy;
1718508a63bSEmilio Cota };
1728508a63bSEmilio Cota 
1738508a63bSEmilio Cota /// A container for op association entries facilitating the configuration of
1748508a63bSEmilio Cota /// dialect conversion.
1758508a63bSEmilio Cota template <typename... Args>
1768508a63bSEmilio Cota struct RegistryImpl {
1778508a63bSEmilio Cota   /// Registers the patterns specializing the "main" op to one of the
1788508a63bSEmilio Cota   /// "intrinsic" ops depending on elemental type.
179206fad0eSMatthias Springer   static void registerPatterns(const LLVMTypeConverter &converter,
1808508a63bSEmilio Cota                                RewritePatternSet &patterns) {
1818508a63bSEmilio Cota     patterns
1828508a63bSEmilio Cota         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
1838508a63bSEmilio Cota                               typename Args::Intr64Op>...>(converter);
1848508a63bSEmilio Cota   }
1858508a63bSEmilio Cota 
1868508a63bSEmilio Cota   /// Configures the conversion target to lower out "main" ops.
1878508a63bSEmilio Cota   static void configureTarget(LLVMConversionTarget &target) {
1888508a63bSEmilio Cota     target.addIllegalOp<typename Args::MainOp...>();
1898508a63bSEmilio Cota     target.addLegalOp<typename Args::Intr32Op...>();
1908508a63bSEmilio Cota     target.addLegalOp<typename Args::Intr64Op...>();
1918508a63bSEmilio Cota   }
1928508a63bSEmilio Cota };
1938508a63bSEmilio Cota 
1948508a63bSEmilio Cota using Registry = RegistryImpl<
1958508a63bSEmilio Cota     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
1968508a63bSEmilio Cota     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
1978508a63bSEmilio Cota     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
1988508a63bSEmilio Cota 
1998508a63bSEmilio Cota } // namespace
2008508a63bSEmilio Cota 
2018508a63bSEmilio Cota /// Populate the given list with patterns that convert from X86Vector to LLVM.
2028508a63bSEmilio Cota void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
203206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2048508a63bSEmilio Cota   Registry::registerPatterns(converter, patterns);
205*87782b21SAdam Siemieniuk   patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
206*87782b21SAdam Siemieniuk                DotOpConversion>(converter);
2078508a63bSEmilio Cota }
2088508a63bSEmilio Cota 
2098508a63bSEmilio Cota void mlir::configureX86VectorLegalizeForExportTarget(
2108508a63bSEmilio Cota     LLVMConversionTarget &target) {
2118508a63bSEmilio Cota   Registry::configureTarget(target);
2128508a63bSEmilio Cota   target.addLegalOp<MaskCompressIntrOp>();
2138508a63bSEmilio Cota   target.addIllegalOp<MaskCompressOp>();
214*87782b21SAdam Siemieniuk   target.addLegalOp<DotBF16Ps128IntrOp>();
215*87782b21SAdam Siemieniuk   target.addLegalOp<DotBF16Ps256IntrOp>();
216*87782b21SAdam Siemieniuk   target.addLegalOp<DotBF16Ps512IntrOp>();
217*87782b21SAdam Siemieniuk   target.addIllegalOp<DotBF16Op>();
2180b63e322SEmilio Cota   target.addLegalOp<RsqrtIntrOp>();
2190b63e322SEmilio Cota   target.addIllegalOp<RsqrtOp>();
220916f3e16SAart Bik   target.addLegalOp<DotIntrOp>();
221916f3e16SAart Bik   target.addIllegalOp<DotOp>();
2228508a63bSEmilio Cota }
223