1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Vector dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 15 16 #include "mlir/Bytecode/BytecodeOpInterface.h" 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" 19 #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/Attributes.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/Dialect.h" 24 #include "mlir/IR/OpDefinition.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/Interfaces/ControlFlowInterfaces.h" 27 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 28 #include "mlir/Interfaces/InferTypeOpInterface.h" 29 #include "mlir/Interfaces/SideEffectInterfaces.h" 30 #include "mlir/Interfaces/VectorInterfaces.h" 31 #include "mlir/Interfaces/ViewLikeInterface.h" 32 #include "llvm/ADT/SetVector.h" 33 #include "llvm/ADT/StringExtras.h" 34 35 // Pull in all enum type definitions and utility function declarations. 36 #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" 37 38 #define GET_ATTRDEF_CLASSES 39 #include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" 40 41 namespace mlir { 42 class MLIRContext; 43 class RewritePatternSet; 44 45 namespace arith { 46 enum class AtomicRMWKind : uint64_t; 47 } // namespace arith 48 49 namespace vector { 50 class ContractionOp; 51 class TransferReadOp; 52 class TransferWriteOp; 53 class VectorDialect; 54 55 namespace detail { 56 struct BitmaskEnumStorage; 57 } // namespace detail 58 59 /// Predefined constant_mask kinds. 60 enum class ConstantMaskKind { AllFalse = 0, AllTrue }; 61 62 /// Default callback to build a region with a 'vector.yield' terminator with no 63 /// arguments. 64 void buildTerminatedBody(OpBuilder &builder, Location loc); 65 66 /// Return whether `srcType` can be broadcast to `dstVectorType` under the 67 /// semantics of the `vector.broadcast` op. 68 enum class BroadcastableToResult { 69 Success = 0, 70 SourceRankHigher = 1, 71 DimensionMismatch = 2, 72 SourceTypeNotAVector = 3 73 }; 74 75 struct VectorDim { 76 int64_t dim; 77 bool isScalable; 78 }; 79 BroadcastableToResult 80 isBroadcastableTo(Type srcType, VectorType dstVectorType, 81 std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr); 82 83 /// Collect a set of vector-to-vector canonicalization patterns. 84 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, 85 PatternBenefit benefit = 1); 86 87 /// Collect a set of patterns that fold arithmetic extension on floating point 88 /// into vector contract for the backends with native support. 89 void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); 90 91 /// Collect a set of patterns that fold elementwise op on vectors to the vector 92 /// dialect. 93 void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); 94 95 /// Returns the integer type required for subscripts in the vector dialect. 96 IntegerType getVectorSubscriptType(Builder &builder); 97 98 /// Returns an integer array attribute containing the given values using 99 /// the integer type required for subscripts in the vector dialect. 100 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); 101 102 /// Returns the value obtained by reducing the vector into a scalar using the 103 /// operation kind associated with a binary AtomicRMWKind op. 104 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, 105 Location loc, Value vector); 106 107 /// Build the default minor identity map suitable for a vector transfer. This 108 /// also handles the case memref<... x vector<...>> -> vector<...> in which the 109 /// rank of the identity map must take the vector element type into account. 110 AffineMap getTransferMinorIdentityMap(ShapedType shapedType, 111 VectorType vectorType); 112 113 /// Return true if the transfer_write fully writes the data accessed by the 114 /// transfer_read. 115 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); 116 117 /// Return true if the write op fully over-write the priorWrite transfer_write 118 /// op. 119 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); 120 121 /// Return true if we can prove that the transfer operations access disjoint 122 /// memory, without requring the accessed tensor/memref to be the same. 123 /// 124 /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values 125 /// via ValueBoundsOpInterface. 126 bool isDisjointTransferIndices(VectorTransferOpInterface transferA, 127 VectorTransferOpInterface transferB, 128 bool testDynamicValueUsingBounds = false); 129 130 /// Return true if we can prove that the transfer operations access disjoint 131 /// memory, requiring the operations to access the same tensor/memref. 132 /// 133 /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values 134 /// via ValueBoundsOpInterface. 135 bool isDisjointTransferSet(VectorTransferOpInterface transferA, 136 VectorTransferOpInterface transferB, 137 bool testDynamicValueUsingBounds = false); 138 139 /// Returns the result value of reducing two scalar/vector values with the 140 /// corresponding arith operation. 141 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, 142 Value v1, Value acc, 143 arith::FastMathFlagsAttr fastmath = nullptr, 144 Value mask = nullptr); 145 146 /// Returns true if `attr` has "parallel" iterator type semantics. 147 inline bool isParallelIterator(Attribute attr) { 148 return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel; 149 } 150 151 /// Returns true if `attr` has "reduction" iterator type semantics. 152 inline bool isReductionIterator(Attribute attr) { 153 return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction; 154 } 155 156 /// Returns the integer numbers in `values`. `values` are expected to be 157 /// constant operations. 158 SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values); 159 160 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to 161 /// be constant operations. 162 SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults); 163 164 /// Convert `foldResults` into Values. Integer attributes are converted to 165 /// constant op. 166 SmallVector<Value> getAsValues(OpBuilder &builder, Location loc, 167 ArrayRef<OpFoldResult> foldResults); 168 169 /// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst * 170 /// vector.vscale`), return the multiplier (`%cst`). Otherwise, return 171 /// `std::nullopt`. 172 std::optional<int64_t> getConstantVscaleMultiplier(Value value); 173 174 //===----------------------------------------------------------------------===// 175 // Vector Masking Utilities 176 //===----------------------------------------------------------------------===// 177 178 /// Infers the mask type for a transfer op given its vector type and 179 /// permutation map. The mask in a transfer op operation applies to the 180 /// tensor/buffer part of it and its type should match the vector shape 181 /// *before* any permutation or broadcasting. For example, 182 /// 183 /// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)> 184 /// 185 /// Has inferred mask type: 186 /// 187 /// maskType = vector<2x1xi1> 188 VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap); 189 190 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp` 191 /// as masked operation. 192 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp); 193 194 /// Creates a vector.mask operation around a maskable operation. Returns the 195 /// vector.mask operation if the mask provided is valid. Otherwise, returns the 196 /// maskable operation itself. 197 Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, 198 Value passthru = Value()); 199 200 /// Creates a vector select operation that picks values from `newValue` or 201 /// `passthru` for each result vector lane based on `mask`. This utility is used 202 /// to propagate the pass-thru value for masked-out or expeculatively executed 203 /// lanes. VP intrinsics do not support pass-thru values and every mask-out lane 204 /// is set to poison. LLVM backends are usually able to match op + select 205 /// patterns and fold them into a native target instructions. 206 Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, 207 Value passthru); 208 209 } // namespace vector 210 } // namespace mlir 211 212 #define GET_OP_CLASSES 213 #include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" 214 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" 215 216 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 217