1 //===- ViewLikeInterfaceUtils.h ---------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
10 #define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
11
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/Interfaces/ViewLikeInterface.h"
15
16 namespace mlir {
17 class RewriterBase;
18
19 namespace affine {
20
21 /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
22 /// when combining a producer slice **into** a consumer slice.
23 ///
24 /// This function performs the following computation:
25 /// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
26 /// - Combined sizes = consumer_sizes
27 /// - Combined strides = producer_strides * consumer_strides
28 // TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
29 // deprecate.
30 LogicalResult
31 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
32 ArrayRef<OpFoldResult> producerOffsets,
33 ArrayRef<OpFoldResult> producerSizes,
34 ArrayRef<OpFoldResult> producerStrides,
35 const llvm::SmallBitVector &droppedProducerDims,
36 ArrayRef<OpFoldResult> consumerOffsets,
37 ArrayRef<OpFoldResult> consumerSizes,
38 ArrayRef<OpFoldResult> consumerStrides,
39 SmallVector<OpFoldResult> &combinedOffsets,
40 SmallVector<OpFoldResult> &combinedSizes,
41 SmallVector<OpFoldResult> &combinedStrides);
42
43 /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
44 /// when combining a `producer` slice op **into** a `consumer` slice op.
45 // TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
46 // deprecate.
47 LogicalResult
48 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
49 OffsetSizeAndStrideOpInterface producer,
50 OffsetSizeAndStrideOpInterface consumer,
51 const llvm::SmallBitVector &droppedProducerDims,
52 SmallVector<OpFoldResult> &combinedOffsets,
53 SmallVector<OpFoldResult> &combinedSizes,
54 SmallVector<OpFoldResult> &combinedStrides);
55
56 /// Given the 'consumerIndices' of a load/store operation operating on an op
57 /// with offsets and strides, return the combined indices.
58 ///
59 /// For example, using `memref.load` and `memref.subview` as an illustration:
60 ///
61 /// ```
62 /// %0 = ... : memref<12x42xf32>
63 /// %1 = memref.subview %0[%arg0, %arg1][...][%stride1, %stride2] :
64 /// memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
65 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
66 /// ```
67 ///
68 /// could be folded into:
69 ///
70 /// ```
71 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
72 /// memref<12x42xf32>å
73 /// ```
74 void resolveIndicesIntoOpWithOffsetsAndStrides(
75 RewriterBase &rewriter, Location loc,
76 ArrayRef<OpFoldResult> mixedSourceOffsets,
77 ArrayRef<OpFoldResult> mixedSourceStrides,
78 const llvm::SmallBitVector &rankReducedDims,
79 ArrayRef<OpFoldResult> consumerIndices,
80 SmallVectorImpl<Value> &resolvedIndices);
81
resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase & rewriter,Location loc,ArrayRef<OpFoldResult> mixedSourceOffsets,ArrayRef<OpFoldResult> mixedSourceStrides,const llvm::SmallBitVector & rankReducedDims,ValueRange consumerIndices,SmallVectorImpl<Value> & resolvedIndices)82 inline void resolveIndicesIntoOpWithOffsetsAndStrides(
83 RewriterBase &rewriter, Location loc,
84 ArrayRef<OpFoldResult> mixedSourceOffsets,
85 ArrayRef<OpFoldResult> mixedSourceStrides,
86 const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices,
87 SmallVectorImpl<Value> &resolvedIndices) {
88 return resolveIndicesIntoOpWithOffsetsAndStrides(
89 rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims,
90 getAsOpFoldResult(consumerIndices), resolvedIndices);
91 }
92
93 /// Given `sourceSizes`, `destSizes` and information about which dimensions are
94 /// dropped by the source: `rankReducedSourceDims`, compute the resolved sizes
95 /// that correspond to dest_op(source_op).
96 /// In practice, this amounts to filtering by `rankReducedSourceDims` and taking
97 /// from `sourceSizes` if a dimension is dropped, otherwise taking from
98 /// `destSizes`.
99 void resolveSizesIntoOpWithSizes(
100 ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
101 const llvm::SmallBitVector &rankReducedSourceDims,
102 SmallVectorImpl<OpFoldResult> &resolvedSizes);
103
104 } // namespace affine
105 } // namespace mlir
106
107 #endif // MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
108