//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/MathExtras.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace mlir::vector; // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); unsigned numScalableDims = tp.getNumScalableDims(); if (tp.getShape().size() == numScalableDims) --numScalableDims; return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), numScalableDims); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); unsigned numScalableDims = tp.getNumScalableDims(); if (numScalableDims > 0) --numScalableDims; return VectorType::get(tp.getShape().take_back(), tp.getElementType(), numScalableDims); } // Helper that picks the proper sequence for inserting. 
static Value insertOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos) { assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val1, val2, constant); } return rewriter.create(loc, val1, val2, pos); } // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank <= 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val, constant); } return rewriter.create(loc, val, pos); } // Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { Type elementTy = typeConverter.convertType(memrefType.getElementType()); if (!elementTy) return failure(); // TODO: this should use the MLIR data layout when it becomes available and // stop depending on translation. llvm::LLVMContext llvmContext; align = LLVM::TypeToLLVMIRTranslator(llvmContext) .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); return success(); } // Add an index vector component to a base pointer. This almost always succeeds // unless the last stride is non-unit or the memory space is not zero. 
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, Value memref, Value base, Value index, MemRefType memRefType, VectorType vType, Value &ptrs) { int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (failed(successStrides) || strides.back() != 1 || memRefType.getMemorySpaceAsInt() != 0) return failure(); auto pType = MemRefDescriptor(memref).getElementPtrType(); auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); ptrs = rewriter.create(loc, ptrsType, base, index); return success(); } // Casts a strided element pointer to a vector pointer. The vector pointer // will be in the same address space as the incoming memref type. static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr, MemRefType memRefType, Type vt) { auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); return rewriter.create(loc, pType, ptr); } namespace { /// Trivial Vector to LLVM conversions using VectorScaleOpConversion = OneToOneConvertToLLVMPattern; /// Conversion pattern for a vector.bitcast. class VectorBitCastOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only 0-D and 1-D vectors can be lowered to LLVM. VectorType resultTy = bitCastOp.getResultVectorType(); if (resultTy.getRank() > 1) return failure(); Type newResultTy = typeConverter->convertType(resultTy); rewriter.replaceOpWithNewOp(bitCastOp, newResultTy, adaptor.getOperands()[0]); return success(); } }; /// Conversion pattern for a vector.matrix_multiply. /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); return success(); } }; /// Conversion pattern for a vector.flat_transpose. /// This is lowered directly to the proper llvm.intr.matrix.transpose. class VectorFlatTransposeOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( transOp, typeConverter->convertType(transOp.getRes().getType()), adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); return success(); } }; /// Overloaded utility that replaces a vector.load, vector.store, /// vector.maskedload and vector.maskedstore with their respective LLVM /// couterparts. 
static void replaceLoadOrStoreOp(vector::LoadOp loadOp, vector::LoadOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(loadOp, ptr, align); } static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, vector::MaskedLoadOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); } static void replaceLoadOrStoreOp(vector::StoreOp storeOp, vector::StoreOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(storeOp, adaptor.getValueToStore(), ptr, align); } static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, vector::MaskedStoreOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); } /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and /// vector.maskedstore. template class VectorLoadStoreConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(LoadOrStoreOp loadOrStoreOp, typename LoadOrStoreOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only 1-D vectors can be lowered to LLVM. VectorType vectorTy = loadOrStoreOp.getVectorType(); if (vectorTy.getRank() > 1) return failure(); auto loc = loadOrStoreOp->getLoc(); MemRefType memRefTy = loadOrStoreOp.getMemRefType(); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) return failure(); // Resolve address. 
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) .template cast(); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); return success(); } }; /// Conversion pattern for a vector.gather. class VectorGatherOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = gather->getLoc(); MemRefType memRefType = gather.getBaseType().dyn_cast(); assert(memRefType && "The base should be bufferized"); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); // Resolve address. Value ptrs; VectorType vType = gather.getVectorType(); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.scatter. class VectorScatterOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); MemRefType memRefType = scatter.getMemRefType(); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); // Resolve address. 
Value ptrs; VectorType vType = scatter.getVectorType(); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp( scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.expandload. class VectorExpandLoadOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = expand->getLoc(); MemRefType memRefType = expand.getMemRefType(); // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp( expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); return success(); } }; /// Conversion pattern for a vector.compressstore. class VectorCompressStoreOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = compress->getLoc(); MemRefType memRefType = compress.getMemRefType(); // Resolve address. Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp( compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); return success(); } }; /// Helper method to lower a `vector.reduction` op that performs an arithmetic /// operation like add,mul, etc.. 
`VectorOp` is the LLVM vector intrinsic to use /// and `ScalarOp` is the scalar operation used to add the accumulation value if /// non-null. template static Value createIntegerReductionArithmeticOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator) { Value result = rewriter.create(loc, llvmType, vectorOperand); if (accumulator) result = rewriter.create(loc, accumulator, result); return result; } /// Helper method to lower a `vector.reduction` operation that performs /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector /// intrinsic to use and `predicate` is the predicate to use to compare+combine /// the accumulator value if non-null. template static Value createIntegerReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { Value result = rewriter.create(loc, llvmType, vectorOperand); if (accumulator) { Value cmp = rewriter.create(loc, predicate, accumulator, result); result = rewriter.create(loc, cmp, accumulator, result); } return result; } /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum /// with vector types. static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); Type i1Type = builder.getI1Type(); if (auto vecType = lhs.getType().dyn_cast()) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? 
LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs); Value sel = builder.create(loc, cmp, lhs, rhs); Value isNan = builder.create( loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); Value nan = builder.create( loc, lhs.getType(), builder.getFloatAttr(floatType, APFloat::getQNaN(floatType.getFloatSemantics()))); return builder.create(loc, isNan, nan, sel); } /// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertOpToLLVMPattern { public: explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, bool reassociateFPRed) : ConvertOpToLLVMPattern(typeConv), reassociateFPReductions(reassociateFPRed) {} LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto kind = reductionOp.getKind(); Type eltType = reductionOp.getDest().getType(); Type llvmType = typeConverter->convertType(eltType); Value operand = adaptor.getVector(); Value acc = adaptor.getAcc(); Location loc = reductionOp.getLoc(); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. 
Value result; switch (kind) { case vector::CombiningKind::ADD: result = createIntegerReductionArithmeticOpLowering( rewriter, loc, llvmType, operand, acc); break; case vector::CombiningKind::MUL: result = createIntegerReductionArithmeticOpLowering( rewriter, loc, llvmType, operand, acc); break; case vector::CombiningKind::MINUI: result = createIntegerReductionComparisonOpLowering< LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc, LLVM::ICmpPredicate::ule); break; case vector::CombiningKind::MINSI: result = createIntegerReductionComparisonOpLowering< LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc, LLVM::ICmpPredicate::sle); break; case vector::CombiningKind::MAXUI: result = createIntegerReductionComparisonOpLowering< LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc, LLVM::ICmpPredicate::uge); break; case vector::CombiningKind::MAXSI: result = createIntegerReductionComparisonOpLowering< LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc, LLVM::ICmpPredicate::sge); break; case vector::CombiningKind::AND: result = createIntegerReductionArithmeticOpLowering( rewriter, loc, llvmType, operand, acc); break; case vector::CombiningKind::OR: result = createIntegerReductionArithmeticOpLowering( rewriter, loc, llvmType, operand, acc); break; case vector::CombiningKind::XOR: result = createIntegerReductionArithmeticOpLowering( rewriter, loc, llvmType, operand, acc); break; default: return failure(); } rewriter.replaceOp(reductionOp, result); return success(); } if (!eltType.isa()) return failure(); // Floating-point reductions: add/mul/min/max if (kind == vector::CombiningKind::ADD) { // Optional accumulator (or zero). Value acc = adaptor.getOperands().size() > 1 ? 
adaptor.getOperands()[1] : rewriter.create( reductionOp->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); rewriter.replaceOpWithNewOp( reductionOp, llvmType, acc, operand, rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == vector::CombiningKind::MUL) { // Optional accumulator (or one). Value acc = adaptor.getOperands().size() > 1 ? adaptor.getOperands()[1] : rewriter.create( reductionOp->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); rewriter.replaceOpWithNewOp( reductionOp, llvmType, acc, operand, rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == vector::CombiningKind::MINF) { // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle // NaNs/-0.0/+0.0 in the same way. Value result = rewriter.create(loc, llvmType, operand); if (acc) result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true); rewriter.replaceOp(reductionOp, result); } else if (kind == vector::CombiningKind::MAXF) { // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle // NaNs/-0.0/+0.0 in the same way. Value result = rewriter.create(loc, llvmType, operand); if (acc) result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false); rewriter.replaceOp(reductionOp, result); } else return failure(); return success(); } private: const bool reassociateFPReductions; }; class VectorShuffleOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = shuffleOp->getLoc(); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); Type llvmType = typeConverter->convertType(vectorType); auto maskArrayAttr = shuffleOp.getMask(); // Bail if result type cannot be lowered. if (!llvmType) return failure(); // Get rank and dimension sizes. 
int64_t rank = vectorType.getRank(); assert(v1Type.getRank() == rank); assert(v2Type.getRank() == rank); int64_t v1Dim = v1Type.getDimSize(0); // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { Value llvmShuffleOp = rewriter.create( loc, adaptor.getV1(), adaptor.getV2(), LLVM::convertArrayToIndices(maskArrayAttr)); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); } // For all other cases, insert the individual values individually. Type eltType; if (auto arrayType = llvmType.dyn_cast()) eltType = arrayType.getElementType(); else eltType = llvmType.cast().getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast().getInt(); Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.getV2(); } Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, eltType, rank, extPos); insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, llvmType, rank, insPos++); } rewriter.replaceOp(shuffleOp, insert); return success(); } }; class VectorExtractElementOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern< vector::ExtractElementOp>::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto vectorType = extractEltOp.getVectorType(); auto llvmType = typeConverter->convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. 
if (!llvmType) return failure(); if (vectorType.getRank() == 0) { Location loc = extractEltOp.getLoc(); auto idxType = rewriter.getIndexType(); auto zero = rewriter.create( loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( extractEltOp, llvmType, adaptor.getVector(), zero); return success(); } rewriter.replaceOpWithNewOp( extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); return success(); } }; class VectorExtractOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = extractOp->getLoc(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); auto positionArrayAttr = extractOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // Extract entire vector. Should be handled by folder, but just to be safe. if (positionArrayAttr.empty()) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { SmallVector indices; for (auto idx : positionArrayAttr.getAsRange()) indices.push_back(idx.getInt()); Value extracted = rewriter.create( loc, adaptor.getVector(), indices); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. 
Value extracted = adaptor.getVector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { SmallVector nMinusOnePosition; for (auto idx : positionAttrs.drop_back()) nMinusOnePosition.push_back(idx.cast().getInt()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); auto i64Type = IntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); rewriter.replaceOp(extractOp, extracted); return success(); } }; /// Conversion pattern that turns a vector.fma on a 1-D vector /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. /// This does not match vectors of n >= 2 rank. /// /// Example: /// ``` /// vector.fma %a, %a, %a : vector<8xf32> /// ``` /// is converted to: /// ``` /// llvm.intr.fmuladd %va, %va, %va: /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) /// -> !llvm."<8 x f32>"> /// ``` class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); rewriter.replaceOpWithNewOp( fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; class VectorInsertElementOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto vectorType = insertEltOp.getDestVectorType(); auto llvmType = typeConverter->convertType(vectorType); // Bail if result type cannot be lowered. 
if (!llvmType) return failure(); if (vectorType.getRank() == 0) { Location loc = insertEltOp.getLoc(); auto idxType = rewriter.getIndexType(); auto zero = rewriter.create( loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); return success(); } rewriter.replaceOpWithNewOp( insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), adaptor.getPosition()); return success(); } }; class VectorInsertOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = insertOp->getLoc(); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); auto positionArrayAttr = insertOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. if (positionArrayAttr.empty()) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), LLVM::convertArrayToIndices(positionArrayAttr)); rewriter.replaceOp(insertOp, inserted); return success(); } // Potential extraction of 1-D vector from array. 
Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create(loc, i64Type, position);
    Value inserted = rewriter.create(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %r = splat %f0: vector<2x4xf32>
///   %va = vector.extractvalue %a[0] : vector<2x4xf32>
///   %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///   %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///   %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///   %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///   %va2 = vector.extractvalue %a[1] : vector<2x4xf32>
///   %vb2 = vector.extractvalue %b[1] : vector<2x4xf32>
///   %vc2 = vector.extractvalue %c[1] : vector<2x4xf32>
///   %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///   %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///   // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Unrolls `op` along its leading dimension. Fails to match when the vector
  /// rank is below 2.
  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    // NOTE(review): several `rewriter.create(...)` / `.cast()` calls in this
    // file carry no explicit template argument as written; confirm the
    // intended op/type at each call site.
    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // `desc` starts as a value of the full vector type built from a zero of
    // the element type; every slice computed below is inserted into it.
    Value zero = rewriter.create(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create(loc, vType, zero);
    // One iteration per entry of the leading dimension: slice the three
    // operands at `i`, combine the slices, and insert the result at `i`.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create(loc, op.getRhs(), i);
      Value extrACC = rewriter.create(loc, op.getAcc(), i);
      Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
///
/// Returns `None` when the strides/offset cannot be computed, when the
/// innermost stride is not 1, or (for non-identity layouts) when a size or
/// stride involved in the check is dynamic or the strides do not satisfy
/// `strides[i] == strides[i + 1] * sizes[i + 1]`.
static llvm::Optional>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  // NOTE(review): if `strides` is empty, `strides.size() - 1` wraps around
  // before being converted to `int`; presumably this yields -1 and the loop is
  // skipped -- confirm.
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

/// Conversion pattern for `vector::TypeCastOp`. Builds a fresh memref
/// descriptor of the converted target type that reuses the allocated and
/// aligned pointers of the source descriptor, with offset 0 and the static
/// sizes and strides of the target memref type. Fails unless both memref
/// types have static shapes and contiguous layouts.
class VectorTypeCastOpConversion : public ConvertOpToLLVMPattern {
public:
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm. The source strides are
    // used purely as a contiguity check; only the target strides are read
    // below.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamicStrideOrOffset))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern : public OpRewritePattern {
public:
  /// `enableIndexOpt` selects 32-bit (true) or 64-bit (false) integers for the
  /// index vector and bound used to build the mask.
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Anything other than a rank-1 scalable result is left to other patterns.
    if (dstType.getRank() != 1 || !dstType.cast().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    Value indices = rewriter.create(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    // Bring the first operand to `idxType` and expand it to the type of
    // `indices`.
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create(loc, indices.getType(), bound);
    // The result replacing `op` is the signed comparison `indices < bounds`.
    Value comp = rewriter.create(loc, arith::CmpIPredicate::slt,
                                 indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  /// When true, `i32` is used for the index vector; otherwise `i64`.
  const bool force32BitVectorIndices;
};

/// Conversion pattern for `vector::PrintOp`: picks a runtime print function
/// based on the element type, emits one call per element (plus
/// open/close/comma/newline calls), and erases the original op.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern {
public:
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    // Every branch below either assigns `printer` or returns failure.
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType());
    } else if (eltType.isIndex()) {
      // Index values go through the unsigned 64-bit print function.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType());
    } else if (auto intTy = eltType.dyn_cast()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    // Terminate the output with a newline call, then drop the print op.
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  /// Conversion applied to a scalar before it is handed to the print
  /// function: none, or one of two extensions to a 64-bit integer.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  /// Recursively emits the print calls for `value` of type `type`.
  ///
  /// A non-vector `type` (requires `rank == 0`) is converted per `conversion`
  /// and passed to `printer`. A vector is wrapped in open/close calls, with a
  /// comma call between consecutive elements; each element is obtained with
  /// `extractOne` and printed by a recursive call.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType());

    // Rank 0 and rank 1 vectors: the elements are scalars. A rank-0 vector is
    // treated as holding exactly one element.
    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType()));
      return;
    }

    // Rank > 1: peel off the leading dimension and recurse on each slice.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType()));
  }

  // Helper to emit a call. The call has no result types, targets the symbol
  // of `ref`, and passes `params` (empty by default) as operands.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create(loc, TypeRange(), SymbolRefAttr::get(ref),
                    params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast();
    // Results of rank 2 and above are not handled by this pattern.
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create(splatOp.getLoc(), vectorType);
    // 32-bit zero constant, used as the insertion position.
    auto zero = rewriter.create(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `shufflevector`.
    auto v = rewriter.create(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast().getDimSize(0);
    // All-zero shuffle mask of length `width`.
    SmallVector zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp(splatOp, v, undef,
                                zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Results of rank 0 and 1 are not handled by this pattern.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create(loc, llvm1DVectorTy);
    // 32-bit zero constant, used as the insertion position.
    auto zero = rewriter.create(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create(loc, llvm1DVectorTy, vdesc,
                              adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements. The width is
    // the size of the innermost dimension and the mask is all zeros.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector zeroValues(width, 0);
    v = rewriter.create(loc, v, v, zeroValues);

    // Iterate over the linear index, convert to coords space and insert the
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) {
      desc = rewriter.create(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
///
/// `reassociateFPReductions` and `force32BitVectorIndices` are forwarded to
/// the constructors of the patterns that take them. Besides the patterns
/// added directly, this also calls
/// `populateVectorInsertExtractStridedSliceTransforms` and
/// `populateVectorTransferLoweringPatterns` (with a maximum transfer rank of
/// 1) on `patterns`.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add(converter, reassociateFPReductions);
  patterns.add(ctx, force32BitVectorIndices);
  patterns
      .add, VectorLoadStoreConversion, VectorLoadStoreConversion,
           VectorLoadStoreConversion, VectorGatherOpConversion,
           VectorScatterOpConversion, VectorExpandLoadOpConversion,
           VectorCompressStoreOpConversion, VectorSplatOpLowering,
           VectorSplatNdOpLowering>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

/// Adds further conversion patterns constructed from `converter` to
/// `patterns`.
/// NOTE(review): the pattern types passed to `patterns.add` are not visible
/// as written -- confirm which patterns are registered here.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add(converter);
  patterns.add(converter);
}