xref: /llvm-project/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp (revision 465ec4e0b48c005f5d5de8adee0c33469a7b9862)
1 //===- ViewLikeInterfaceUtils.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 
12 using namespace mlir;
13 
14 LogicalResult mlir::mergeOffsetsSizesAndStrides(
15     OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
16     ArrayRef<OpFoldResult> producerSizes,
17     ArrayRef<OpFoldResult> producerStrides,
18     const llvm::SmallBitVector &droppedProducerDims,
19     ArrayRef<OpFoldResult> consumerOffsets,
20     ArrayRef<OpFoldResult> consumerSizes,
21     ArrayRef<OpFoldResult> consumerStrides,
22     SmallVector<OpFoldResult> &combinedOffsets,
23     SmallVector<OpFoldResult> &combinedSizes,
24     SmallVector<OpFoldResult> &combinedStrides) {
25   combinedOffsets.resize(producerOffsets.size());
26   combinedSizes.resize(producerOffsets.size());
27   combinedStrides.resize(producerOffsets.size());
28 
29   AffineExpr s0, s1, s2;
30   bindSymbols(builder.getContext(), s0, s1, s2);
31 
32   unsigned consumerPos = 0;
33   for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
34     if (droppedProducerDims.test(i)) {
35       // For dropped dims, get the values from the producer.
36       combinedOffsets[i] = producerOffsets[i];
37       combinedSizes[i] = producerSizes[i];
38       combinedStrides[i] = producerStrides[i];
39       continue;
40     }
41     SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
42     // The combined offset is computed as
43     //    producer_offset + consumer_offset * producer_strides.
44     combinedOffsets[i] = makeComposedFoldedAffineApply(
45         builder, loc, s0 * s1 + s2,
46         {consumerOffsets[consumerPos], producerStrides[i], producerOffsets[i]});
47     combinedSizes[i] = consumerSizes[consumerPos];
48     // The combined stride is computed as
49     //    consumer_stride * producer_stride.
50     combinedStrides[i] = makeComposedFoldedAffineApply(
51         builder, loc, s0 * s1,
52         {consumerStrides[consumerPos], producerStrides[i]});
53 
54     consumerPos++;
55   }
56   return success();
57 }
58 
59 LogicalResult mlir::mergeOffsetsSizesAndStrides(
60     OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
61     OffsetSizeAndStrideOpInterface consumer,
62     const llvm::SmallBitVector &droppedProducerDims,
63     SmallVector<OpFoldResult> &combinedOffsets,
64     SmallVector<OpFoldResult> &combinedSizes,
65     SmallVector<OpFoldResult> &combinedStrides) {
66   SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
67   SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
68   SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
69   SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
70   SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
71   SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
72   return mergeOffsetsSizesAndStrides(
73       builder, loc, producerOffsets, producerSizes, producerStrides,
74       droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
75       combinedOffsets, combinedSizes, combinedStrides);
76 }
77