//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir {
class MLIRContext;
class VectorTransferOpInterface;
class RewritePatternSet;
class RewriterBase;

namespace scf {
class IfOp;
} // namespace scf

namespace vector {

//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//

/// Structure to control the behavior of vector transform patterns.
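///
/// A minimal usage sketch chaining the fluent setters (the enum values shown
/// are illustrative; valid values are those defined alongside these enums in
/// VectorRewritePatterns.h):
/// ```
///   VectorTransformsOptions options;
///   options.setVectorTransformsOptions(VectorContractLowering::OuterProduct)
///       .setVectorMultiReductionLowering(
///           VectorMultiReductionLowering::InnerReduction)
///       .setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
/// ```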
struct VectorTransformsOptions {
  /// Option to control the lowering of vector.contract.
  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
  VectorTransformsOptions &
  setVectorTransformsOptions(VectorContractLowering opt) {
    vectorContractLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.multi_reduction.
  VectorMultiReductionLowering vectorMultiReductionLowering =
      VectorMultiReductionLowering::InnerParallel;
  VectorTransformsOptions &
  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
    vectorMultiReductionLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.transpose.
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  VectorTransformsOptions &
  setVectorTransposeLowering(VectorTransposeLowering opt) {
    vectorTransposeLowering = opt;
    return *this;
  }
  /// Option to control the splitting of vector transfers.
  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
    vectorTransferSplit = opt;
    return *this;
  }
};

//===----------------------------------------------------------------------===//
// Standalone transformations and helpers.
//===----------------------------------------------------------------------===//

/// Split a vector.transfer operation into an in-bounds (i.e., no
/// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and
/// the result is `success`, `ifOp` points to the newly created conditional
/// upon function return. To accommodate the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes
/// @[0...0] in the temporary buffer, the scf.if op returns a view and values
/// of type index. At this time, only the vector.transfer_read case is
/// implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map.
///  2. The rank of the `xferOp.memref()` and the rank of the
///     `xferOp.vector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
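///
/// A hedged usage sketch; `funcOp`, the collection of transfer ops, and the
/// rewriter setup are illustrative, not mandated by this API:
/// ```
///   SmallVector<VectorTransferOpInterface> xferOps;
///   funcOp.walk([&](VectorTransferOpInterface xferOp) {
///     xferOps.push_back(xferOp);
///   });
///   IRRewriter rewriter(funcOp.getContext());
///   VectorTransformsOptions options;
///   options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
///   for (VectorTransferOpInterface xferOp : xferOps) {
///     scf::IfOp ifOp;
///     if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options,
///                                               &ifOp))) {
///       // On success, `ifOp` points to the newly created scf.if.
///     }
///   }
/// ```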
LogicalResult splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
    scf::IfOp *ifOp = nullptr);

/// Implements transfer op write-to-read forwarding and dead transfer write
/// optimizations.
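///
/// A minimal usage sketch (`funcOp` is illustrative, e.g. a func.func whose
/// body should be optimized):
/// ```
///   IRRewriter rewriter(funcOp.getContext());
///   transferOpflowOpt(rewriter, funcOp.getOperation());
/// ```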
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp);

/// Cast away the leading unit dim, if it exists, for the given contract op.
/// Return success if the transformation applies; return failure otherwise.
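///
/// A hedged sketch for the unmasked case (`contractOp` is illustrative, and a
/// default-constructed, i.e. null, MaskingOpInterface is assumed here to mean
/// "no mask"):
/// ```
///   IRRewriter rewriter(contractOp.getContext());
///   MaskingOpInterface maskingOp; // Null interface: contraction is unmasked.
///   FailureOr<Value> replacement =
///       castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
///   if (succeeded(replacement)) {
///     // `*replacement` computes the same result without the leading unit
///     // dim.
///   }
/// ```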
FailureOr<Value>
castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                 MaskingOpInterface maskingOp,
                                 RewriterBase &rewriter);

/// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
  unsigned vscaleMin;
  unsigned vscaleMax;
};

/// Attempts to eliminate redundant vector masks by replacing them with all-true
/// constants at the top of the function (which results in the masks folding
/// away). Note: currently, this only runs for vector.create_mask ops and
/// requires `vscaleRange`. If `vscaleRange` is not provided, this transform
/// does nothing. This is because these redundant masks are much more likely in
/// scalable code, which requires memref/tensor dynamic sizes, whereas
/// fixed-size code has static sizes, so simpler folds already remove the masks.
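///
/// A hedged usage sketch (`funcOp` and the vscale bounds are illustrative;
/// without a `vscaleRange` the call is a no-op, as noted above):
/// ```
///   IRRewriter rewriter(funcOp.getContext());
///   eliminateVectorMasks(rewriter, funcOp,
///                        VscaleRange{/*vscaleMin=*/1, /*vscaleMax=*/16});
/// ```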
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
                          std::optional<VscaleRange> vscaleRange = {});

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H