1 //===- ViewLikeInterface.h - View-like operations interface ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the operation interface for view-like operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ 14 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ 15 16 #include "mlir/Dialect/Utils/StaticValueUtils.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 23 namespace mlir { 24 25 class OffsetSizeAndStrideOpInterface; 26 27 namespace detail { 28 29 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); 30 31 bool sameOffsetsSizesAndStrides( 32 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, 33 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp); 34 35 /// Helper method to compute the number of dynamic entries of `staticVals`, 36 /// up to `idx`. 37 unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, 38 unsigned idx); 39 40 } // namespace detail 41 } // namespace mlir 42 43 /// Include the generated interface declarations. 44 #include "mlir/Interfaces/ViewLikeInterface.h.inc" 45 46 namespace mlir { 47 48 /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as 49 /// constant arguments. This pattern assumes that the op has a suitable builder 50 /// that takes a result type, a "source" operand and mixed offsets, sizes and 51 /// strides. 52 /// 53 /// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn` 54 /// returns the new result type of the op, based on the new offsets, sizes and 55 /// strides. `CastOpFunc` is used to generate a cast op if the result type of 56 /// the op has changed. 57 template <typename OpType, typename ResultTypeFn, typename CastOpFunc> 58 class OpWithOffsetSizesAndStridesConstantArgumentFolder final 59 : public OpRewritePattern<OpType> { 60 public: 61 using OpRewritePattern<OpType>::OpRewritePattern; 62 63 LogicalResult matchAndRewrite(OpType op, 64 PatternRewriter &rewriter) const override { 65 SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets()); 66 SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes()); 67 SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides()); 68 69 // No constant operands were folded, just return; 70 if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) && 71 failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) && 72 failed(foldDynamicIndexList(mixedStrides))) 73 return failure(); 74 75 // Create the new op in canonical form. 76 auto resultType = 77 ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides); 78 if (!resultType) 79 return failure(); 80 auto newOp = 81 rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(), 82 mixedOffsets, mixedSizes, mixedStrides); 83 CastOpFunc()(rewriter, op, newOp); 84 85 return success(); 86 } 87 }; 88 89 /// Printer hooks for custom directive in assemblyFormat. 90 /// 91 /// custom<DynamicIndexList>($values, $integers) 92 /// custom<DynamicIndexList>($values, $integers, type($values)) 93 /// 94 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type 95 /// `I64ArrayAttr`. Print a list where each element is either: 96 /// 1. the static integer value in `integers`, if it's not `kDynamic` or, 97 /// 2. the next value in `values`, otherwise. 98 /// 99 /// If `valueTypes` is provided, the corresponding type of each dynamic value is 100 /// printed. Otherwise, the type is not printed. Each type must match the type 101 /// of the corresponding value in `values`. `valueTypes` is redundant for 102 /// printing as we can retrieve the types from the actual `values`. However, 103 /// `valueTypes` is needed for parsing and we must keep the API symmetric for 104 /// parsing and printing. The type for integer elements is `i64` by default and 105 /// never printed. 106 /// 107 /// Integer indices can also be scalable in the context of scalable vectors, 108 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in 109 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's 110 /// a scalable index. If `scalableFlags` is empty then assume that all indices 111 /// are non-scalable. 112 /// 113 /// Examples: 114 /// 115 /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, 116 /// `values = [%arg0, %arg42]` and 117 /// `valueTypes = [index, index]` 118 /// prints: 119 /// `[%arg0 : index, 7, 42, %arg42 : i32]` 120 /// 121 /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, 122 /// `values = [%arg0, %arg42]` and 123 /// `valueTypes = []` 124 /// prints: 125 /// `[%arg0, 7, 42, %arg42]` 126 /// 127 /// * Input: `integers = [2, 4, 8]`, 128 /// `values = []` and 129 /// `scalableFlags = [false, true, false]` 130 /// prints: 131 /// `[2, [4], 8]` 132 /// 133 void printDynamicIndexList( 134 OpAsmPrinter &printer, Operation *op, OperandRange values, 135 ArrayRef<int64_t> integers, ArrayRef<bool> scalableFlags, 136 TypeRange valueTypes = TypeRange(), 137 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); 138 inline void printDynamicIndexList( 139 OpAsmPrinter &printer, Operation *op, OperandRange values, 140 ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(), 141 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { 142 return printDynamicIndexList(printer, op, values, integers, 143 /*scalableFlags=*/{}, valueTypes, delimiter); 144 } 145 146 /// Parser hooks for custom directive in assemblyFormat. 147 /// 148 /// custom<DynamicIndexList>($values, $integers) 149 /// custom<DynamicIndexList>($values, $integers, type($values)) 150 /// 151 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS 152 /// type `I64ArrayAttr`. Parse a mixed list where each element is either a 153 /// static integer or an SSA value. Fill `integers` with the integer ArrayAttr, 154 /// where `kDynamic` encodes the position of SSA values. Add the parsed SSA 155 /// values to `values` in-order. 156 /// 157 /// If `valueTypes` is provided, fill it with the types corresponding to each 158 /// value in `values`. Otherwise, the caller must handle the types and parsing 159 /// will fail if the type of the value is found (e.g., `[%arg0 : index, 3, %arg1 160 /// : index]`). 161 /// 162 /// Integer indices can also be scalable in the context of scalable vectors, 163 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in 164 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's 165 /// a scalable index. 166 /// 167 /// Examples: 168 /// 169 /// * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]": 170 /// 1. `result` is filled with `[kDynamic, 7, 42, kDynamic]` 171 /// 2. `values` is filled with "[%arg0, %arg1]". 172 /// 3. `scalableFlags` is filled with `[false, true, false]`. 173 /// 174 /// * After parsing `[2, [4], 8]`: 175 /// 1. `result` is filled with `[2, 4, 8]` 176 /// 2. `values` is empty. 177 /// 3. `scalableFlags` is filled with `[false, true, false]`. 178 /// 179 ParseResult parseDynamicIndexList( 180 OpAsmParser &parser, 181 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 182 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, 183 SmallVectorImpl<Type> *valueTypes = nullptr, 184 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); 185 inline ParseResult parseDynamicIndexList( 186 OpAsmParser &parser, 187 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, 188 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, 189 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { 190 DenseBoolArrayAttr scalableFlags; 191 return parseDynamicIndexList(parser, values, integers, scalableFlags, 192 valueTypes, delimiter); 193 } 194 195 /// Verify that a the `values` has as many elements as the number of entries in 196 /// `attr` for which `isDynamic` evaluates to true. 197 LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, 198 unsigned expectedNumElements, 199 ArrayRef<int64_t> attr, 200 ValueRange values); 201 202 } // namespace mlir 203 204 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ 205