xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (revision 2b2ce50fe843b5b550806a0ab15b06cd5c405d48)
1 //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Swap a `tensor.extract_slice` with the producer of the source if the producer
10 // implements the `TilingInterface`. When used in conjunction with tiling this
11 // effectively tiles + fuses the producer with its consumer.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
18 #include "mlir/Dialect/Utils/StaticValueUtils.h"
19 #include "mlir/Interfaces/TilingInterface.h"
20 
21 using namespace mlir;
22 
replaceExtractSliceWithTiledProducer(OpBuilder & builder,tensor::ExtractSliceOp sliceOp,OpResult producer)23 FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
24     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
25   auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
26   if (!producerOp)
27     return failure();
28 
29   // `TilingInterface` currently only supports strides being 1.
30   if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
31         return !isConstantIntValue(ofr, 1);
32       }))
33     return failure();
34 
35   FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
36       builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
37       sliceOp.getMixedSizes());
38   if (failed(tiledResult))
39     return failure();
40 
41   return *tiledResult;
42 }
43 
replaceInsertSliceWithTiledConsumer(OpBuilder & builder,OffsetSizeAndStrideOpInterface sliceOp,OpOperand & consumer)44 FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
45     OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
46     OpOperand &consumer) {
47   auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
48   if (!consumerOp)
49     return failure();
50 
51   // `TilingInterface` currently only supports strides being 1.
52   if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
53         return !isConstantIntValue(ofr, 1);
54       }))
55     return failure();
56 
57   FailureOr<TilingResult> tiledResult =
58       consumerOp.getTiledImplementationFromOperandTile(
59           builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
60           sliceOp.getMixedSizes());
61   if (failed(tiledResult))
62     return failure();
63 
64   return *tiledResult;
65 }
66