xref: /llvm-project/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (revision 1335a11176f99cc54f423fe173708bd2373b59f7)
1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
15 
16 #include "mlir/Bytecode/BytecodeOpInterface.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
19 #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/ControlFlowInterfaces.h"
27 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
28 #include "mlir/Interfaces/InferTypeOpInterface.h"
29 #include "mlir/Interfaces/SideEffectInterfaces.h"
30 #include "mlir/Interfaces/VectorInterfaces.h"
31 #include "mlir/Interfaces/ViewLikeInterface.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/ADT/StringExtras.h"
34 
35 // Pull in all enum type definitions and utility function declarations.
36 #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
37 
38 #define GET_ATTRDEF_CLASSES
39 #include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"
40 
41 namespace mlir {
42 class MLIRContext;
43 class RewritePatternSet;
44 
45 namespace arith {
46 enum class AtomicRMWKind : uint64_t;
47 } // namespace arith
48 
49 namespace vector {
50 class ContractionOp;
51 class TransferReadOp;
52 class TransferWriteOp;
53 class VectorDialect;
54 
55 namespace detail {
56 struct BitmaskEnumStorage;
57 } // namespace detail
58 
/// The predefined kinds of constant mask: `AllFalse` (value 0) and `AllTrue`
/// (value 1).
enum class ConstantMaskKind {
  AllFalse = 0,
  AllTrue = 1,
};
61 
62 /// Default callback to build a region with a 'vector.yield' terminator with no
63 /// arguments.
64 void buildTerminatedBody(OpBuilder &builder, Location loc);
65 
66 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
67 /// semantics of the `vector.broadcast` op.
68 enum class BroadcastableToResult {
69   Success = 0,
70   SourceRankHigher = 1,
71   DimensionMismatch = 2,
72   SourceTypeNotAVector = 3
73 };
74 
75 struct VectorDim {
76   int64_t dim;
77   bool isScalable;
78 };
79 BroadcastableToResult
80 isBroadcastableTo(Type srcType, VectorType dstVectorType,
81                   std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
82 
83 /// Collect a set of vector-to-vector canonicalization patterns.
84 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
85                                                     PatternBenefit benefit = 1);
86 
87 /// Collect a set of patterns that fold arithmetic extension on floating point
88 /// into vector contract for the backends with native support.
89 void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
90 
91 /// Collect a set of patterns that fold elementwise op on vectors to the vector
92 /// dialect.
93 void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
94 
95 /// Returns the integer type required for subscripts in the vector dialect.
96 IntegerType getVectorSubscriptType(Builder &builder);
97 
98 /// Returns an integer array attribute containing the given values using
99 /// the integer type required for subscripts in the vector dialect.
100 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
101 
102 /// Returns the value obtained by reducing the vector into a scalar using the
103 /// operation kind associated with a binary AtomicRMWKind op.
104 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
105                            Location loc, Value vector);
106 
107 /// Build the default minor identity map suitable for a vector transfer. This
108 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
109 /// rank of the identity map must take the vector element type into account.
110 AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
111                                       VectorType vectorType);
112 
113 /// Return true if the transfer_write fully writes the data accessed by the
114 /// transfer_read.
115 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
116 
117 /// Return true if the write op fully over-write the priorWrite transfer_write
118 /// op.
119 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
120 
121 /// Return true if we can prove that the transfer operations access disjoint
122 /// memory, without requring the accessed tensor/memref to be the same.
123 ///
124 /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
125 /// via ValueBoundsOpInterface.
126 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
127                                VectorTransferOpInterface transferB,
128                                bool testDynamicValueUsingBounds = false);
129 
130 /// Return true if we can prove that the transfer operations access disjoint
131 /// memory, requiring the operations to access the same tensor/memref.
132 ///
133 /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
134 /// via ValueBoundsOpInterface.
135 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
136                            VectorTransferOpInterface transferB,
137                            bool testDynamicValueUsingBounds = false);
138 
139 /// Returns the result value of reducing two scalar/vector values with the
140 /// corresponding arith operation.
141 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
142                          Value v1, Value acc,
143                          arith::FastMathFlagsAttr fastmath = nullptr,
144                          Value mask = nullptr);
145 
146 /// Returns true if `attr` has "parallel" iterator type semantics.
147 inline bool isParallelIterator(Attribute attr) {
148   return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel;
149 }
150 
151 /// Returns true if `attr` has "reduction" iterator type semantics.
152 inline bool isReductionIterator(Attribute attr) {
153   return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
154 }
155 
156 /// Returns the integer numbers in `values`. `values` are expected to be
157 /// constant operations.
158 SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
159 
160 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
161 /// be constant operations.
162 SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
163 
164 /// Convert `foldResults` into Values. Integer attributes are converted to
165 /// constant op.
166 SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
167                                ArrayRef<OpFoldResult> foldResults);
168 
169 /// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
170 /// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
171 /// `std::nullopt`.
172 std::optional<int64_t> getConstantVscaleMultiplier(Value value);
173 
174 //===----------------------------------------------------------------------===//
175 // Vector Masking Utilities
176 //===----------------------------------------------------------------------===//
177 
178 /// Infers the mask type for a transfer op given its vector type and
179 /// permutation map. The mask in a transfer op operation applies to the
180 /// tensor/buffer part of it and its type should match the vector shape
181 /// *before* any permutation or broadcasting. For example,
182 ///
183 /// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
184 ///
185 /// Has inferred mask type:
186 ///
187 /// maskType = vector<2x1xi1>
188 VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
189 
190 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
191 /// as masked operation.
192 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
193 
194 /// Creates a vector.mask operation around a maskable operation. Returns the
195 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
196 /// maskable operation itself.
197 Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
198                          Value passthru = Value());
199 
200 /// Creates a vector select operation that picks values from `newValue` or
201 /// `passthru` for each result vector lane based on `mask`. This utility is used
202 /// to propagate the pass-thru value for masked-out or expeculatively executed
203 /// lanes. VP intrinsics do not support pass-thru values and every mask-out lane
204 /// is set to poison. LLVM backends are usually able to match op + select
205 /// patterns and fold them into a native target instructions.
206 Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
207                      Value passthru);
208 
209 } // namespace vector
210 } // namespace mlir
211 
212 #define GET_OP_CLASSES
213 #include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
214 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
215 
216 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
217