xref: /llvm-project/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (revision e553ac4d8148291914526f4f66f09e362ce0a63f)
1 //===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
11 
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 
15 namespace mlir {
16 
17 namespace LLVM {
18 namespace detail {
19 // Helper struct to "unroll" operations on n-D vectors in terms of operations on
20 // 1-D LLVM vectors.
21 struct NDVectorTypeInfo {
22   // LLVM array struct which encodes n-D vectors.
23   Type llvmNDVectorTy;
24   // LLVM vector type which encodes the inner 1-D vector type.
25   Type llvm1DVectorTy;
26   // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
27   SmallVector<int64_t, 4> arraySizes;
28 };
29 
30 // For >1-D vector types, extracts the necessary information to iterate over all
31 // 1-D subvectors in the underlying llrepresentation of the n-D vector
32 // Iterates on the llvm array type until we hit a non-array type (which is
33 // asserted to be an llvm vector type).
34 NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
35                                          const LLVMTypeConverter &converter);
36 
37 // Express `linearIndex` in terms of coordinates of `basis`.
38 // Returns the empty vector when linearIndex is out of the range [0, P] where
39 // P is the product of all the basis coordinates.
40 //
41 // Prerequisites:
42 //   Basis is an array of nonnegative integers (signed type inherited from
43 //   vector shape type).
44 SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
45                                        unsigned linearIndex);
46 
47 // Iterate of linear index, convert to coords space and insert splatted 1-D
48 // vector in each position.
49 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
50                      function_ref<void(ArrayRef<int64_t>)> fun);
51 
52 LogicalResult handleMultidimensionalVectors(
53     Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
54     std::function<Value(Type, ValueRange)> createOperand,
55     ConversionPatternRewriter &rewriter);
56 
57 LogicalResult vectorOneToOneRewrite(
58     Operation *op, StringRef targetOp, ValueRange operands,
59     ArrayRef<NamedAttribute> targetAttrs,
60     const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
61     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
62 } // namespace detail
63 } // namespace LLVM
64 
65 // Default attribute conversion class, which passes all source attributes
66 // through to the target op, unmodified.
67 template <typename SourceOp, typename TargetOp>
68 class AttrConvertPassThrough {
69 public:
AttrConvertPassThrough(SourceOp srcOp)70   AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
71 
getAttrs()72   ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
getOverflowFlags()73   LLVM::IntegerOverflowFlags getOverflowFlags() const {
74     return LLVM::IntegerOverflowFlags::none;
75   }
76 
77 private:
78   ArrayRef<NamedAttribute> srcAttrs;
79 };
80 
81 /// Basic lowering implementation to rewrite Ops with just one result to the
82 /// LLVM Dialect. This supports higher-dimensional vector types.
83 /// The AttrConvert template template parameter should be a template class
84 /// with SourceOp and TargetOp type parameters, a constructor that takes
85 /// a SourceOp instance, and a getAttrs() method that returns
86 /// ArrayRef<NamedAttribute>.
87 template <typename SourceOp, typename TargetOp,
88           template <typename, typename> typename AttrConvert =
89               AttrConvertPassThrough>
90 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
91 public:
92   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
93   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
94 
95   LogicalResult
matchAndRewrite(SourceOp op,typename SourceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter)96   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
97                   ConversionPatternRewriter &rewriter) const override {
98     static_assert(
99         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
100         "expected single result op");
101     // Determine attributes for the target op
102     AttrConvert<SourceOp, TargetOp> attrConvert(op);
103 
104     return LLVM::detail::vectorOneToOneRewrite(
105         op, TargetOp::getOperationName(), adaptor.getOperands(),
106         attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
107         attrConvert.getOverflowFlags());
108   }
109 };
110 } // namespace mlir
111 
112 #endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
113