1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 11 12 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 14 #include "mlir/Interfaces/FunctionInterfaces.h" 15 16 namespace mlir { 17 class MLIRContext; 18 class VectorTransferOpInterface; 19 class RewritePatternSet; 20 class RewriterBase; 21 22 namespace scf { 23 class IfOp; 24 } // namespace scf 25 26 namespace vector { 27 28 //===----------------------------------------------------------------------===// 29 // Vector transformation options exposed as auxiliary structs. 30 //===----------------------------------------------------------------------===// 31 32 /// Structure to control the behavior of vector transform patterns. 33 struct VectorTransformsOptions { 34 /// Option to control the lowering of vector.contract. 35 VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; 36 VectorTransformsOptions & 37 setVectorTransformsOptions(VectorContractLowering opt) { 38 vectorContractLowering = opt; 39 return *this; 40 } 41 /// Option to control the lowering of vector.multi_reduction. 42 VectorMultiReductionLowering vectorMultiReductionLowering = 43 VectorMultiReductionLowering::InnerParallel; 44 VectorTransformsOptions & 45 setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { 46 vectorMultiReductionLowering = opt; 47 return *this; 48 } 49 /// Option to control the lowering of vector.transpose. 50 VectorTransposeLowering vectorTransposeLowering = 51 VectorTransposeLowering::EltWise; 52 VectorTransformsOptions & 53 setVectorTransposeLowering(VectorTransposeLowering opt) { 54 vectorTransposeLowering = opt; 55 return *this; 56 } 57 /// Option to control the splitting of vector transfers. 58 VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; 59 VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { 60 vectorTransferSplit = opt; 61 return *this; 62 } 63 }; 64 65 //===----------------------------------------------------------------------===// 66 // Standalone transformations and helpers. 67 //===----------------------------------------------------------------------===// 68 69 /// Split a vector.transfer operation into an in-bounds (i.e., no 70 /// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and 71 /// the result is `success, the `ifOp` points to the newly created conditional 72 /// upon function return. To accomodate for the fact that the original 73 /// vector.transfer indexing may be arbitrary and the slow path indexes 74 /// @[0...0] in the temporary buffer, the scf.if op returns a view and values 75 /// of type index. At this time, only vector.transfer_read case is 76 /// implemented. 77 /// 78 /// Example (a 2-D vector.transfer_read): 79 /// ``` 80 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...> 81 /// ``` 82 /// is transformed into: 83 /// ``` 84 /// %1:3 = scf.if (%inBounds) { 85 /// // fastpath, direct cast 86 /// memref.cast %A: memref<A...> to compatibleMemRefType 87 /// scf.yield %view : compatibleMemRefType, index, index 88 /// } else { 89 /// // slowpath, not in-bounds vector.transfer or linalg.copy. 90 /// memref.cast %alloc: memref<B...> to compatibleMemRefType 91 /// scf.yield %4 : compatibleMemRefType, index, index 92 // } 93 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... 94 /// true]} 95 /// ``` 96 /// where `alloc` is a top of the function alloca'ed buffer of one vector. 97 /// 98 /// Preconditions: 99 /// 1. `xferOp.permutation_map()` must be a minor identity map 100 /// 2. the rank of the `xferOp.memref()` and the rank of the 101 /// `xferOp.vector()` must be equal. This will be relaxed in the future but 102 /// requires rank-reducing subviews. 103 LogicalResult splitFullAndPartialTransfer( 104 RewriterBase &b, VectorTransferOpInterface xferOp, 105 VectorTransformsOptions options = VectorTransformsOptions(), 106 scf::IfOp *ifOp = nullptr); 107 108 /// Implements transfer op write to read forwarding and dead transfer write 109 /// optimizations. 110 void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp); 111 112 /// Cast away the leading unit dim, if exists, for the given contract op. 113 /// Return success if the transformation applies; return failure otherwise. 114 FailureOr<Value> 115 castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, 116 MaskingOpInterface maskingOp, 117 RewriterBase &rewriter); 118 119 // Structure to hold the range of `vector.vscale`. 120 struct VscaleRange { 121 unsigned vscaleMin; 122 unsigned vscaleMax; 123 }; 124 125 /// Attempts to eliminate redundant vector masks by replacing them with all-true 126 /// constants at the top of the function (which results in the masks folding 127 /// away). Note: Currently, this only runs for vector.create_mask ops and 128 /// requires `vscaleRange`. If `vscaleRange` is not provided this transform does 129 /// nothing. This is because these redundant masks are much more likely for 130 /// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size 131 /// code has static sizes, so simpler folds remove the masks. 132 void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, 133 std::optional<VscaleRange> vscaleRange = {}); 134 135 } // namespace vector 136 } // namespace mlir 137 138 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORTRANSFORMS_H 139