xref: /llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h (revision ef4800c9168ee45ced8295d13ac68f58b4358759)
1 //===- ViewLikeInterface.h - View-like operations interface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the operation interface for view-like operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
14 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
15 
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 
23 namespace mlir {
24 
25 class OffsetSizeAndStrideOpInterface;
26 
27 namespace detail {
28 
29 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
30 
31 bool sameOffsetsSizesAndStrides(
32     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
33     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
34 
35 /// Helper method to compute the number of dynamic entries of `staticVals`,
36 /// up to `idx`.
37 unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
38                                      unsigned idx);
39 
40 } // namespace detail
41 } // namespace mlir
42 
43 /// Include the generated interface declarations.
44 #include "mlir/Interfaces/ViewLikeInterface.h.inc"
45 
46 namespace mlir {
47 
48 /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
49 /// constant arguments. This pattern assumes that the op has a suitable builder
50 /// that takes a result type, a "source" operand and mixed offsets, sizes and
51 /// strides.
52 ///
53 /// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
54 /// returns the new result type of the op, based on the new offsets, sizes and
55 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
56 /// the op has changed.
57 template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
58 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
59     : public OpRewritePattern<OpType> {
60 public:
61   using OpRewritePattern<OpType>::OpRewritePattern;
62 
63   LogicalResult matchAndRewrite(OpType op,
64                                 PatternRewriter &rewriter) const override {
65     SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
66     SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
67     SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
68 
69     // No constant operands were folded, just return;
70     if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
71         failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
72         failed(foldDynamicIndexList(mixedStrides)))
73       return failure();
74 
75     // Create the new op in canonical form.
76     auto resultType =
77         ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
78     if (!resultType)
79       return failure();
80     auto newOp =
81         rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
82                                 mixedOffsets, mixedSizes, mixedStrides);
83     CastOpFunc()(rewriter, op, newOp);
84 
85     return success();
86   }
87 };
88 
89 /// Printer hooks for custom directive in assemblyFormat.
90 ///
91 ///   custom<DynamicIndexList>($values, $integers)
92 ///   custom<DynamicIndexList>($values, $integers, type($values))
93 ///
94 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type
95 /// `I64ArrayAttr`. Print a list where each element is either:
96 ///    1. the static integer value in `integers`, if it's not `kDynamic` or,
97 ///    2. the next value in `values`, otherwise.
98 ///
99 /// If `valueTypes` is provided, the corresponding type of each dynamic value is
100 /// printed. Otherwise, the type is not printed. Each type must match the type
101 /// of the corresponding value in `values`. `valueTypes` is redundant for
102 /// printing as we can retrieve the types from the actual `values`. However,
103 /// `valueTypes` is needed for parsing and we must keep the API symmetric for
104 /// parsing and printing. The type for integer elements is `i64` by default and
105 /// never printed.
106 ///
107 /// Integer indices can also be scalable in the context of scalable vectors,
108 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
109 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
110 /// a scalable index. If `scalableFlags` is empty then assume that all indices
111 /// are non-scalable.
112 ///
113 /// Examples:
114 ///
115 ///   * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
116 ///            `values = [%arg0, %arg42]` and
117 ///            `valueTypes = [index, index]`
118 ///     prints:
119 ///       `[%arg0 : index, 7, 42, %arg42 : i32]`
120 ///
121 ///   * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
122 ///            `values = [%arg0, %arg42]` and
123 ///            `valueTypes = []`
124 ///     prints:
125 ///       `[%arg0, 7, 42, %arg42]`
126 ///
127 ///   * Input: `integers = [2, 4, 8]`,
128 ///            `values = []` and
129 ///            `scalableFlags = [false, true, false]`
130 ///     prints:
131 ///       `[2, [4], 8]`
132 ///
133 void printDynamicIndexList(
134     OpAsmPrinter &printer, Operation *op, OperandRange values,
135     ArrayRef<int64_t> integers, ArrayRef<bool> scalableFlags,
136     TypeRange valueTypes = TypeRange(),
137     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
138 inline void printDynamicIndexList(
139     OpAsmPrinter &printer, Operation *op, OperandRange values,
140     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
141     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
142   return printDynamicIndexList(printer, op, values, integers,
143                                /*scalableFlags=*/{}, valueTypes, delimiter);
144 }
145 
146 /// Parser hooks for custom directive in assemblyFormat.
147 ///
148 ///   custom<DynamicIndexList>($values, $integers)
149 ///   custom<DynamicIndexList>($values, $integers, type($values))
150 ///
151 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
152 /// type `I64ArrayAttr`. Parse a mixed list where each element is either a
153 /// static integer or an SSA value. Fill `integers` with the integer ArrayAttr,
154 /// where `kDynamic` encodes the position of SSA values. Add the parsed SSA
155 /// values to `values` in-order.
156 ///
157 /// If `valueTypes` is provided, fill it with the types corresponding to each
158 /// value in `values`. Otherwise, the caller must handle the types and parsing
159 /// will fail if the type of the value is found (e.g., `[%arg0 : index, 3, %arg1
160 /// : index]`).
161 ///
162 /// Integer indices can also be scalable in the context of scalable vectors,
163 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
164 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
165 /// a scalable index.
166 ///
167 /// Examples:
168 ///
169 ///   * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
170 ///       1. `result` is filled with `[kDynamic, 7, 42, kDynamic]`
171 ///       2. `values` is filled with "[%arg0, %arg1]".
172 ///       3. `scalableFlags` is filled with `[false, true, false]`.
173 ///
174 ///   * After parsing `[2, [4], 8]`:
175 ///       1. `result` is filled with `[2, 4, 8]`
176 ///       2. `values` is empty.
177 ///       3. `scalableFlags` is filled with `[false, true, false]`.
178 ///
179 ParseResult parseDynamicIndexList(
180     OpAsmParser &parser,
181     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
182     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
183     SmallVectorImpl<Type> *valueTypes = nullptr,
184     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
185 inline ParseResult parseDynamicIndexList(
186     OpAsmParser &parser,
187     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
188     DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
189     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
190   DenseBoolArrayAttr scalableFlags;
191   return parseDynamicIndexList(parser, values, integers, scalableFlags,
192                                valueTypes, delimiter);
193 }
194 
195 /// Verify that a the `values` has as many elements as the number of entries in
196 /// `attr` for which `isDynamic` evaluates to true.
197 LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
198                                              unsigned expectedNumElements,
199                                              ArrayRef<int64_t> attr,
200                                              ValueRange values);
201 
202 } // namespace mlir
203 
204 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
205