1 //===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 10 #define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 11 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 namespace mlir { 16 17 namespace LLVM { 18 namespace detail { 19 // Helper struct to "unroll" operations on n-D vectors in terms of operations on 20 // 1-D LLVM vectors. 21 struct NDVectorTypeInfo { 22 // LLVM array struct which encodes n-D vectors. 23 Type llvmNDVectorTy; 24 // LLVM vector type which encodes the inner 1-D vector type. 25 Type llvm1DVectorTy; 26 // Multiplicity of llvmNDVectorTy to llvm1DVectorTy. 27 SmallVector<int64_t, 4> arraySizes; 28 }; 29 30 // For >1-D vector types, extracts the necessary information to iterate over all 31 // 1-D subvectors in the underlying llrepresentation of the n-D vector 32 // Iterates on the llvm array type until we hit a non-array type (which is 33 // asserted to be an llvm vector type). 34 NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, 35 const LLVMTypeConverter &converter); 36 37 // Express `linearIndex` in terms of coordinates of `basis`. 38 // Returns the empty vector when linearIndex is out of the range [0, P] where 39 // P is the product of all the basis coordinates. 40 // 41 // Prerequisites: 42 // Basis is an array of nonnegative integers (signed type inherited from 43 // vector shape type). 44 SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis, 45 unsigned linearIndex); 46 47 // Iterate of linear index, convert to coords space and insert splatted 1-D 48 // vector in each position. 49 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, 50 function_ref<void(ArrayRef<int64_t>)> fun); 51 52 LogicalResult handleMultidimensionalVectors( 53 Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, 54 std::function<Value(Type, ValueRange)> createOperand, 55 ConversionPatternRewriter &rewriter); 56 57 LogicalResult vectorOneToOneRewrite( 58 Operation *op, StringRef targetOp, ValueRange operands, 59 ArrayRef<NamedAttribute> targetAttrs, 60 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, 61 IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); 62 } // namespace detail 63 } // namespace LLVM 64 65 // Default attribute conversion class, which passes all source attributes 66 // through to the target op, unmodified. 67 template <typename SourceOp, typename TargetOp> 68 class AttrConvertPassThrough { 69 public: AttrConvertPassThrough(SourceOp srcOp)70 AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} 71 getAttrs()72 ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; } getOverflowFlags()73 LLVM::IntegerOverflowFlags getOverflowFlags() const { 74 return LLVM::IntegerOverflowFlags::none; 75 } 76 77 private: 78 ArrayRef<NamedAttribute> srcAttrs; 79 }; 80 81 /// Basic lowering implementation to rewrite Ops with just one result to the 82 /// LLVM Dialect. This supports higher-dimensional vector types. 83 /// The AttrConvert template template parameter should be a template class 84 /// with SourceOp and TargetOp type parameters, a constructor that takes 85 /// a SourceOp instance, and a getAttrs() method that returns 86 /// ArrayRef<NamedAttribute>. 87 template <typename SourceOp, typename TargetOp, 88 template <typename, typename> typename AttrConvert = 89 AttrConvertPassThrough> 90 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 91 public: 92 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 93 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; 94 95 LogicalResult matchAndRewrite(SourceOp op,typename SourceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter)96 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 97 ConversionPatternRewriter &rewriter) const override { 98 static_assert( 99 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 100 "expected single result op"); 101 // Determine attributes for the target op 102 AttrConvert<SourceOp, TargetOp> attrConvert(op); 103 104 return LLVM::detail::vectorOneToOneRewrite( 105 op, TargetOp::getOperationName(), adaptor.getOperands(), 106 attrConvert.getAttrs(), *this->getTypeConverter(), rewriter, 107 attrConvert.getOverflowFlags()); 108 } 109 }; 110 } // namespace mlir 111 112 #endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 113