//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; template static LLVM::LLVMType getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { return lowering.convertType(containerType.getElementType()) .template cast() .getPointerTo(); } // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); return VectorType::get(tp.getShape().take_back(), tp.getElementType()); } class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast(op); VectorType dstVectorType = broadcastOp.getVectorType(); if (lowering.convertType(dstVectorType) == nullptr) return matchFailure(); // Rewrite when the full vector type can be lowered (which // implies all 'reduced' types can be lowered too). VectorType srcVectorType = broadcastOp.getSourceType().dyn_cast(); rewriter.replaceOp( op, expandRanks(operands[0], // source value to be expanded op->getLoc(), // location of original broadcast srcVectorType, dstVectorType, rewriter)); return matchSuccess(); } private: // Expands the given source value over all the ranks, as defined // by the source and destination type (a null source type denotes // expansion from a scalar value into a vector). // // TODO(ajcbik): consider replacing this one-pattern lowering // with a two-pattern lowering using other vector // ops once all insert/extract/shuffle operations // are available with lowering implemention. // Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, VectorType dstVectorType, ConversionPatternRewriter &rewriter) const { assert((dstVectorType != nullptr) && "invalid result type in broadcast"); // Determine rank of source and destination. int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; int64_t dstRank = dstVectorType.getRank(); int64_t curDim = dstVectorType.getDimSize(0); if (srcRank < dstRank) // Duplicate this rank. return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, curDim, rewriter); // If all trailing dimensions are the same, the broadcast consists of // simply passing through the source value and we are done. Otherwise, // any non-matching dimension forces a stretch along this rank. assert((srcVectorType != nullptr) && (srcRank > 0) && (srcRank == dstRank) && "invalid rank in broadcast"); for (int64_t r = 0; r < dstRank; r++) { if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, curDim, rewriter); } } return value; } // Picks the best way to duplicate a single rank. For the 1-D case, a // single insert-elt/shuffle is the most efficient expansion. For higher // dimensions, however, we need dim x insert-values on a new broadcast // with one less leading dimension, which will be lowered "recursively" // to matching LLVM IR. // For example: // v = broadcast s : f32 to vector<4x2xf32> // becomes: // x = broadcast s : f32 to vector<2xf32> // v = [x,x,x,x] // becomes: // x = [s,s] // v = [x,x,x,x] Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, VectorType dstVectorType, int64_t rank, int64_t dim, ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { Value *undef = rewriter.create(loc, llvmType); Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } Value *expand = expandRanks(value, loc, srcVectorType, reducedVectorTypeFront(dstVectorType), rewriter); Value *result = rewriter.create(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); } return result; } // Picks the best way to stretch a single rank. For the 1-D case, a // single insert-elt/shuffle is the most efficient expansion when at // a stretch. Otherwise, every dimension needs to be expanded // individually and individually inserted in the resulting vector. // For example: // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> // becomes: // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> // v = [a,b,c,d] // becomes: // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> // a = [x, y] // etc. Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, VectorType dstVectorType, int64_t rank, int64_t dim, ConversionPatternRewriter &rewriter) const { Type llvmType = lowering.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); Value *result = rewriter.create(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); if (atStretch) { Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); Value *expand = insertOne(result, one, loc, llvmType, rank, 0, rewriter); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); } for (int64_t d = 0; d < dim; ++d) { Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); result = insertOne(result, one, loc, llvmType, rank, d, rewriter); } } else { VectorType redSrcType = reducedVectorTypeFront(srcVectorType); VectorType redDstType = reducedVectorTypeFront(dstVectorType); Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); } } return result; } // Picks the proper sequence for inserting. Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, int64_t rank, int64_t pos, ConversionPatternRewriter &rewriter) const { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, lowering.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val1, val2, constant); } return rewriter.create(loc, llvmType, val1, val2, rewriter.getI64ArrayAttr(pos)); } // Picks the proper sequence for extracting. Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, int64_t pos, ConversionPatternRewriter &rewriter) const { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, lowering.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, value, constant); } return rewriter.create(loc, llvmType, value, rewriter.getI64ArrayAttr(pos)); } }; class VectorExtractOpConversion : public LLVMOpLowering { public: explicit VectorExtractOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult()->getType(); auto llvmResultType = lowering.convertType(resultType); auto positionArrayAttr = extractOp.position(); // Bail if result type cannot be lowered. if (!llvmResultType) return matchFailure(); // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value *extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return matchSuccess(); } // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value *extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( loc, lowering.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); auto constant = rewriter.create(loc, i32Type, position); extracted = rewriter.create(loc, extracted, constant); rewriter.replaceOp(op, extracted); return matchSuccess(); } }; class VectorInsertOpConversion : public LLVMOpLowering { public: explicit VectorInsertOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::InsertOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpOperandAdaptor(operands); auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = lowering.convertType(destVectorType); auto positionArrayAttr = insertOp.position(); // Bail if result type cannot be lowered. if (!llvmResultType) return matchFailure(); // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { Value *inserted = rewriter.create( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); return matchSuccess(); } // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value *extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); auto oneDVectorType = destVectorType; if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( loc, lowering.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Insertion of an element into a 1-D LLVM vector. auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); auto constant = rewriter.create(loc, i32Type, position); Value *inserted = rewriter.create( loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), constant); // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); inserted = rewriter.create(loc, llvmResultType, adaptor.dest(), inserted, nMinusOnePositionAttrs); } rewriter.replaceOp(op, inserted); return matchSuccess(); } }; class VectorOuterProductOpConversion : public LLVMOpLowering { public: explicit VectorOuterProductOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::OuterProductOpOperandAdaptor(operands); auto *ctx = op->getContext(); auto vLHS = adaptor.lhs()->getType().cast(); auto vRHS = adaptor.rhs()->getType().cast(); auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( cast(op).getResult()->getType()); Value *desc = rewriter.create(loc, llvmArrayOfVectType); Value *a = adaptor.lhs(), *b = adaptor.rhs(); Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); SmallVector lhs, accs; lhs.reserve(rankLHS); accs.reserve(rankLHS); for (unsigned d = 0, e = rankLHS; d < e; ++d) { // shufflevector explicitly requires i32. auto attr = rewriter.getI32IntegerAttr(d); SmallVector bcastAttr(rankRHS, attr); auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); Value *aD = nullptr, *accD = nullptr; // 1. Broadcast the element a[d] into vector aD. aD = rewriter.create(loc, a, a, bcastArrayAttr); // 2. If acc is present, extract 1-d vector acc[d] into accD. if (acc) accD = rewriter.create( loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). Value *aOuterbD = accD ? rewriter.create(loc, vRHS, aD, b, accD) .getResult() : rewriter.create(loc, aD, b).getResult(); // 4. Insert as value `d` in the descriptor. desc = rewriter.create(loc, llvmArrayOfVectType, desc, aOuterbD, rewriter.getI64ArrayAttr(d)); } rewriter.replaceOp(op, desc); return matchSuccess(); } }; class VectorTypeCastOpConversion : public LLVMOpLowering { public: explicit VectorTypeCastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); vector::TypeCastOp castOp = cast(op); MemRefType sourceMemRefType = castOp.getOperand()->getType().cast(); MemRefType targetMemRefType = castOp.getResult()->getType().cast(); // Only static shape casts supported atm. if (!sourceMemRefType.hasStaticShape() || !targetMemRefType.hasStaticShape()) return matchFailure(); auto llvmSourceDescriptorTy = operands[0]->getType().dyn_cast(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) return matchFailure(); MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return matchFailure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(sourceMemRefType, strides, offset); bool isContiguous = (strides.back() == 1); if (isContiguous) { auto sizes = sourceMemRefType.getShape(); for (int index = 0, e = strides.size() - 2; index < e; ++index) { if (strides[index] != strides[index + 1] * sizes[index + 1]) { isContiguous = false; break; } } } // Only contiguous source tensors supported atm. if (failed(successStrides) || !isContiguous) return matchFailure(); auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create(loc, llvmTargetElementTy, allocated); desc.setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create(loc, llvmTargetElementTy, ptr); desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create(loc, int64Ty, attr); desc.setOffset(rewriter, loc, zero); // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); auto size = rewriter.create(loc, int64Ty, sizeAttr); desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); auto stride = rewriter.create(loc, int64Ty, strideAttr); desc.setStride(rewriter, loc, index, stride); } rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert( converter.getDialect()->getContext(), converter); } namespace { struct LowerVectorToLLVMPass : public ModulePass { void runOnModule() override; }; } // namespace void LowerVectorToLLVMPass::runOnModule() { // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed( applyPartialConversion(getModule(), target, patterns, &converter))) { signalPassFailure(); } } OpPassBase *mlir::createLowerVectorToLLVMPass() { return new LowerVectorToLLVMPass(); } static PassRegistration pass("convert-vector-to-llvm", "Lower the operations from the vector dialect into the LLVM dialect");