xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 5c0c51a9979ffaf1eb57be061bed56a05ae6ddd4)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/VectorOps/VectorOps.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/StandardTypes.h"
30 #include "mlir/IR/Types.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Pass/PassManager.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "mlir/Transforms/Passes.h"
35 
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/Module.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/Support/Allocator.h"
40 #include "llvm/Support/ErrorHandling.h"
41 
42 using namespace mlir;
43 
44 template <typename T>
45 static LLVM::LLVMType getPtrToElementType(T containerType,
46                                           LLVMTypeConverter &lowering) {
47   return lowering.convertType(containerType.getElementType())
48       .template cast<LLVM::LLVMType>()
49       .getPointerTo();
50 }
51 
52 class VectorExtractElementOpConversion : public LLVMOpLowering {
53 public:
54   explicit VectorExtractElementOpConversion(MLIRContext *context,
55                                             LLVMTypeConverter &typeConverter)
56       : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
57                        typeConverter) {}
58 
59   PatternMatchResult
60   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
61                   ConversionPatternRewriter &rewriter) const override {
62     auto loc = op->getLoc();
63     auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
64     auto extractOp = cast<vector::ExtractElementOp>(op);
65     auto vectorType = extractOp.vector()->getType().cast<VectorType>();
66     auto resultType = extractOp.getResult()->getType();
67     auto llvmResultType = lowering.convertType(resultType);
68 
69     auto positionArrayAttr = extractOp.position();
70     // One-shot extraction of vector from array (only requires extractvalue).
71     if (resultType.isa<VectorType>()) {
72       Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
73           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
74       rewriter.replaceOp(op, extracted);
75       return matchSuccess();
76     }
77 
78     // Potential extraction of 1-D vector from struct.
79     auto *context = op->getContext();
80     Value *extracted = adaptor.vector();
81     auto positionAttrs = positionArrayAttr.getValue();
82     auto i32Type = rewriter.getIntegerType(32);
83     if (positionAttrs.size() > 1) {
84       auto nDVectorType = vectorType;
85       auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
86                                             nDVectorType.getElementType());
87       auto nMinusOnePositionAttrs =
88           ArrayAttr::get(positionAttrs.drop_back(), context);
89       extracted = rewriter.create<LLVM::ExtractValueOp>(
90           loc, lowering.convertType(oneDVectorType), extracted,
91           nMinusOnePositionAttrs);
92     }
93 
94     // Remaining extraction of element from 1-D LLVM vector
95     auto position = positionAttrs.back().cast<IntegerAttr>();
96     auto constant = rewriter.create<LLVM::ConstantOp>(
97         loc, lowering.convertType(i32Type), position);
98     extracted =
99         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
100     rewriter.replaceOp(op, extracted);
101 
102     return matchSuccess();
103   }
104 };
105 
106 class VectorOuterProductOpConversion : public LLVMOpLowering {
107 public:
108   explicit VectorOuterProductOpConversion(MLIRContext *context,
109                                           LLVMTypeConverter &typeConverter)
110       : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
111                        typeConverter) {}
112 
113   PatternMatchResult
114   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
115                   ConversionPatternRewriter &rewriter) const override {
116     auto loc = op->getLoc();
117     auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
118     auto *ctx = op->getContext();
119     auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
120     auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
121     auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
122     auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
123     auto llvmArrayOfVectType = lowering.convertType(
124         cast<vector::OuterProductOp>(op).getResult()->getType());
125     Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
126     Value *a = adaptor.lhs(), *b = adaptor.rhs();
127     Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
128     SmallVector<Value *, 8> lhs, accs;
129     lhs.reserve(rankLHS);
130     accs.reserve(rankLHS);
131     for (unsigned d = 0, e = rankLHS; d < e; ++d) {
132       // shufflevector explicitly requires i32.
133       auto attr = rewriter.getI32IntegerAttr(d);
134       SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
135       auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
136       Value *aD = nullptr, *accD = nullptr;
137       // 1. Broadcast the element a[d] into vector aD.
138       aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
139       // 2. If acc is present, extract 1-d vector acc[d] into accD.
140       if (acc)
141         accD = rewriter.create<LLVM::ExtractValueOp>(
142             loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
143       // 3. Compute aD outer b (plus accD, if relevant).
144       Value *aOuterbD =
145           accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
146                      .getResult()
147                : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
148       // 4. Insert as value `d` in the descriptor.
149       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
150                                                   desc, aOuterbD,
151                                                   rewriter.getI64ArrayAttr(d));
152     }
153     rewriter.replaceOp(op, desc);
154     return matchSuccess();
155   }
156 };
157 
158 class VectorTypeCastOpConversion : public LLVMOpLowering {
159 public:
160   explicit VectorTypeCastOpConversion(MLIRContext *context,
161                                       LLVMTypeConverter &typeConverter)
162       : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
163                        typeConverter) {}
164 
165   PatternMatchResult
166   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
167                   ConversionPatternRewriter &rewriter) const override {
168     auto loc = op->getLoc();
169     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
170     MemRefType sourceMemRefType =
171         castOp.getOperand()->getType().cast<MemRefType>();
172     MemRefType targetMemRefType =
173         castOp.getResult()->getType().cast<MemRefType>();
174 
175     // Only static shape casts supported atm.
176     if (!sourceMemRefType.hasStaticShape() ||
177         !targetMemRefType.hasStaticShape())
178       return matchFailure();
179 
180     auto llvmSourceDescriptorTy =
181         operands[0]->getType().dyn_cast<LLVM::LLVMType>();
182     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
183       return matchFailure();
184     MemRefDescriptor sourceMemRef(operands[0]);
185 
186     auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
187                                       .dyn_cast_or_null<LLVM::LLVMType>();
188     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
189       return matchFailure();
190 
191     int64_t offset;
192     SmallVector<int64_t, 4> strides;
193     auto successStrides =
194         getStridesAndOffset(sourceMemRefType, strides, offset);
195     bool isContiguous = (strides.back() == 1);
196     if (isContiguous) {
197       auto sizes = sourceMemRefType.getShape();
198       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
199         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
200           isContiguous = false;
201           break;
202         }
203       }
204     }
205     // Only contiguous source tensors supported atm.
206     if (failed(successStrides) || !isContiguous)
207       return matchFailure();
208 
209     auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
210 
211     // Create descriptor.
212     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
213     Type llvmTargetElementTy = desc.getElementType();
214     // Set allocated ptr.
215     Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
216     allocated =
217         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
218     desc.setAllocatedPtr(rewriter, loc, allocated);
219     // Set aligned ptr.
220     Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
221     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
222     desc.setAlignedPtr(rewriter, loc, ptr);
223     // Fill offset 0.
224     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
225     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
226     desc.setOffset(rewriter, loc, zero);
227 
228     // Fill size and stride descriptors in memref.
229     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
230       int64_t index = indexedSize.index();
231       auto sizeAttr =
232           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
233       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
234       desc.setSize(rewriter, loc, index, size);
235       auto strideAttr =
236           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
237       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
238       desc.setStride(rewriter, loc, index, stride);
239     }
240 
241     rewriter.replaceOp(op, {desc});
242     return matchSuccess();
243   }
244 };
245 
246 /// Populate the given list with patterns that convert from Vector to LLVM.
247 void mlir::populateVectorToLLVMConversionPatterns(
248     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
249   patterns.insert<VectorExtractElementOpConversion,
250                   VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
251       converter.getDialect()->getContext(), converter);
252 }
253 
254 namespace {
255 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
256   void runOnModule() override;
257 };
258 } // namespace
259 
260 void LowerVectorToLLVMPass::runOnModule() {
261   // Convert to the LLVM IR dialect using the converter defined above.
262   OwningRewritePatternList patterns;
263   LLVMTypeConverter converter(&getContext());
264   populateVectorToLLVMConversionPatterns(converter, patterns);
265   populateStdToLLVMConversionPatterns(converter, patterns);
266 
267   ConversionTarget target(getContext());
268   target.addLegalDialect<LLVM::LLVMDialect>();
269   target.addDynamicallyLegalOp<FuncOp>(
270       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
271   if (failed(
272           applyPartialConversion(getModule(), target, patterns, &converter))) {
273     signalPassFailure();
274   }
275 }
276 
277 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
278   return new LowerVectorToLLVMPass();
279 }
280 
281 static PassRegistration<LowerVectorToLLVMPass>
282     pass("convert-vector-to-llvm",
283          "Lower the operations from the vector dialect into the LLVM dialect");
284