xref: /llvm-project/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (revision 87782b216fd3e7a8f8b2de04d4af467b390e9a34)
1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::x86vector;
21 
22 /// Extracts the "main" vector element type from the given X86Vector operation.
23 template <typename OpTy>
24 static Type getSrcVectorElementType(OpTy op) {
25   return cast<VectorType>(op.getSrc().getType()).getElementType();
26 }
27 template <>
28 Type getSrcVectorElementType(Vp2IntersectOp op) {
29   return cast<VectorType>(op.getA().getType()).getElementType();
30 }
31 
32 namespace {
33 
34 /// Base conversion for AVX512 ops that can be lowered to one of the two
35 /// intrinsics based on the bitwidth of their "main" vector element type. This
36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37 /// results of multi-result intrinsic ops.
38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40   explicit LowerToIntrinsic(const LLVMTypeConverter &converter)
41       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42 
43   const LLVMTypeConverter &getTypeConverter() const {
44     return *static_cast<const LLVMTypeConverter *>(
45         OpConversionPattern<OpTy>::getTypeConverter());
46   }
47 
48   LogicalResult
49   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
50                   ConversionPatternRewriter &rewriter) const override {
51     Type elementType = getSrcVectorElementType<OpTy>(op);
52     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53     if (bitwidth == 32)
54       return LLVM::detail::oneToOneRewrite(
55           op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
56           op->getAttrs(), getTypeConverter(), rewriter);
57     if (bitwidth == 64)
58       return LLVM::detail::oneToOneRewrite(
59           op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
60           op->getAttrs(), getTypeConverter(), rewriter);
61     return rewriter.notifyMatchFailure(
62         op, "expected 'src' to be either f32 or f64");
63   }
64 };
65 
66 struct MaskCompressOpConversion
67     : public ConvertOpToLLVMPattern<MaskCompressOp> {
68   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
69 
70   LogicalResult
71   matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override {
73     auto opType = adaptor.getA().getType();
74 
75     Value src;
76     if (op.getSrc()) {
77       src = adaptor.getSrc();
78     } else if (op.getConstantSrc()) {
79       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80                                                op.getConstantSrcAttr());
81     } else {
82       auto zeroAttr = rewriter.getZeroAttr(opType);
83       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
84     }
85 
86     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87                                                     src, adaptor.getK());
88 
89     return success();
90   }
91 };
92 
93 struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
94   using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
95 
96   LogicalResult
97   matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override {
99     auto typeA = dyn_cast<VectorType>(op.getA().getType());
100     unsigned elemBitWidth = typeA.getElementTypeBitWidth();
101     unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
102 
103     auto opType = adaptor.getSrc().getType();
104     auto opSrc = adaptor.getSrc();
105     auto opA = adaptor.getA();
106     auto opB = adaptor.getB();
107 
108     switch (opBitWidth) {
109     case 128: {
110       rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
111                                                       opB);
112       break;
113     }
114     case 256: {
115       rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
116                                                       opB);
117       break;
118     }
119     case 512: {
120       rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
121                                                       opB);
122       break;
123     }
124     default: {
125       return rewriter.notifyMatchFailure(op,
126                                          "unsupported AVX512-BF16 dot variant");
127     }
128     }
129 
130     return success();
131   }
132 };
133 
134 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
135   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
136 
137   LogicalResult
138   matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
139                   ConversionPatternRewriter &rewriter) const override {
140     auto opType = adaptor.getA().getType();
141     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
142     return success();
143   }
144 };
145 
146 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
147   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
148 
149   LogicalResult
150   matchAndRewrite(DotOp op, OpAdaptor adaptor,
151                   ConversionPatternRewriter &rewriter) const override {
152     auto opType = adaptor.getA().getType();
153     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
154     // Dot product of all elements, broadcasted to all elements.
155     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
156     Value scale =
157         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
158     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
159                                            adaptor.getB(), scale);
160     return success();
161   }
162 };
163 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops it
/// can be lowered to: one for vectors of 32-bit elements and one for vectors
/// of 64-bit elements. The three types are exposed as member aliases.
template <typename MainOpT, typename Intr32OpT, typename Intr64OpT>
struct RegEntry {
  /// The op being lowered.
  using MainOp = MainOpT;
  /// The intrinsic op for 32-bit element vectors.
  using Intr32Op = Intr32OpT;
  /// The intrinsic op for 64-bit element vectors.
  using Intr64Op = Intr64OpT;
};
172 
173 /// A container for op association entries facilitating the configuration of
174 /// dialect conversion.
175 template <typename... Args>
176 struct RegistryImpl {
177   /// Registers the patterns specializing the "main" op to one of the
178   /// "intrinsic" ops depending on elemental type.
179   static void registerPatterns(const LLVMTypeConverter &converter,
180                                RewritePatternSet &patterns) {
181     patterns
182         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
183                               typename Args::Intr64Op>...>(converter);
184   }
185 
186   /// Configures the conversion target to lower out "main" ops.
187   static void configureTarget(LLVMConversionTarget &target) {
188     target.addIllegalOp<typename Args::MainOp...>();
189     target.addLegalOp<typename Args::Intr32Op...>();
190     target.addLegalOp<typename Args::Intr64Op...>();
191   }
192 };
193 
194 using Registry = RegistryImpl<
195     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
196     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
197     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
198 
199 } // namespace
200 
201 /// Populate the given list with patterns that convert from X86Vector to LLVM.
202 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
203     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
204   Registry::registerPatterns(converter, patterns);
205   patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
206                DotOpConversion>(converter);
207 }
208 
209 void mlir::configureX86VectorLegalizeForExportTarget(
210     LLVMConversionTarget &target) {
211   Registry::configureTarget(target);
212   target.addLegalOp<MaskCompressIntrOp>();
213   target.addIllegalOp<MaskCompressOp>();
214   target.addLegalOp<DotBF16Ps128IntrOp>();
215   target.addLegalOp<DotBF16Ps256IntrOp>();
216   target.addLegalOp<DotBF16Ps512IntrOp>();
217   target.addIllegalOp<DotBF16Op>();
218   target.addLegalOp<RsqrtIntrOp>();
219   target.addIllegalOp<RsqrtOp>();
220   target.addLegalOp<DotIntrOp>();
221   target.addIllegalOp<DotOp>();
222 }
223