//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVM::LLVMType>()
      .getPointerTo();
}

class VectorExtractElementOpConversion : public LLVMOpLowering {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
    auto extractOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractOp.vector()->getType().cast<VectorType>();
    auto resultType = extractOp.getResult()->getType();
    auto llvmResultType = lowering.convertType(resultType);

    auto positionArrayAttr = extractOp.position();
    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return matchSuccess();
    }

    // Potential extraction of 1-D vector from struct.
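    // The remaining case peels off the leading array dimensions with
    // llvm.extractvalue (when the position has more than one index) and then
    // extracts the scalar from the innermost 1-D vector with
    // llvm.extractelement, using an i32 constant for the position.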
    auto *context = op->getContext();
    Value *extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    auto i32Type = rewriter.getIntegerType(32);
    if (positionAttrs.size() > 1) {
      auto nDVectorType = vectorType;
      auto oneDVectorType =
          VectorType::get(nDVectorType.getShape().take_back(),
                          nDVectorType.getElementType());
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, lowering.convertType(i32Type), position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return matchSuccess();
  }
};

class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
  explicit VectorOuterProductOpConversion(MLIRContext *context,
                                          LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
    auto *ctx = op->getContext();
    auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
    auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
    auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
    auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
    auto llvmArrayOfVectType = lowering.convertType(
        cast<vector::OuterProductOp>(op).getResult()->getType());
    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
    Value *a = adaptor.lhs(), *b = adaptor.rhs();
    Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
    SmallVector<Value *, 8> lhs, accs;
    lhs.reserve(rankLHS);
    accs.reserve(rankLHS);
    for (unsigned d = 0, e = rankLHS; d < e; ++d) {
      // shufflevector explicitly requires i32.
      auto attr = rewriter.getI32IntegerAttr(d);
      SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
      auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
      Value *aD = nullptr, *accD = nullptr;
      // 1. Broadcast the element a[d] into vector aD.
      aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
      // 2. If acc is present, extract 1-d vector acc[d] into accD.
      if (acc)
        accD = rewriter.create<LLVM::ExtractValueOp>(
            loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
      // 3. Compute aD outer b (plus accD, if relevant).
      Value *aOuterbD =
          accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
                     .getResult()
               : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
      // 4. Insert as value `d` in the descriptor.
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
                                                  desc, aOuterbD,
                                                  rewriter.getI64ArrayAttr(d));
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

class VectorTypeCastOpConversion : public LLVMOpLowering {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand()->getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult()->getType().cast<MemRefType>();

    // Only static shape casts supported atm.
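    // The source must additionally be contiguous (innermost stride 1, each
    // outer stride equal to the product of the inner sizes). The conversion
    // below then rebuilds the target descriptor: it bitcasts the allocated
    // and aligned pointers to the target element type, resets the offset to
    // 0, and materializes the static sizes and strides of the target memref
    // type.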
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return matchFailure();

    auto llvmSourceDescriptorTy =
        operands[0]->getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return matchFailure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return matchFailure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    // Only contiguous source tensors supported atm.
    if (failed(successStrides) || !isContiguous)
      return matchFailure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return matchSuccess();
  }
};

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorExtractElementOpConversion,
                  VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
      converter.getDialect()->getContext(), converter);
}

namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
  void runOnModule() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnModule() {
  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LLVMTypeConverter converter(&getContext());
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  if (failed(
          applyPartialConversion(getModule(), target, patterns, &converter))) {
    signalPassFailure();
  }
}

OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
  return new LowerVectorToLLVMPass();
}

static PassRegistration<LowerVectorToLLVMPass>
    pass("convert-vector-to-llvm",
         "Lower the operations from the vector dialect into the LLVM dialect");
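// Usage sketch (assuming an mlir-opt-style driver built against this revision;
// the flag name is the one registered above):
//
//   mlir-opt -convert-vector-to-llvm input.mlir
//
// Programmatically, the pass returned by createLowerVectorToLLVMPass() can be
// added to a module-level PassManager before running it on the module.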