xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (revision 2b2ce50fe843b5b550806a0ab15b06cd5c405d48)
12d3b54feSLei Zhang //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
22d3b54feSLei Zhang //
32d3b54feSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42d3b54feSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52d3b54feSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62d3b54feSLei Zhang //
72d3b54feSLei Zhang //===----------------------------------------------------------------------===//
82d3b54feSLei Zhang //
// Swap a `tensor.extract_slice` with the producer of its source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling, this
// effectively tiles + fuses the producer with its consumer. Symmetrically, an
// insert-slice-like op can be used to tile a `TilingInterface` consumer of the
// inserted-into value, fusing that consumer into the tiled loop.
122d3b54feSLei Zhang //
132d3b54feSLei Zhang //===----------------------------------------------------------------------===//
142d3b54feSLei Zhang 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/SmallBitVector.h"
202d3b54feSLei Zhang 
212d3b54feSLei Zhang using namespace mlir;
222d3b54feSLei Zhang 
replaceExtractSliceWithTiledProducer(OpBuilder & builder,tensor::ExtractSliceOp sliceOp,OpResult producer)23809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
242d3b54feSLei Zhang     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
252d3b54feSLei Zhang   auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
262d3b54feSLei Zhang   if (!producerOp)
272d3b54feSLei Zhang     return failure();
282d3b54feSLei Zhang 
292d3b54feSLei Zhang   // `TilingInterface` currently only supports strides being 1.
302d3b54feSLei Zhang   if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
312d3b54feSLei Zhang         return !isConstantIntValue(ofr, 1);
322d3b54feSLei Zhang       }))
332d3b54feSLei Zhang     return failure();
342d3b54feSLei Zhang 
35809e3d8cSMahesh Ravishankar   FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
362d3b54feSLei Zhang       builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
372d3b54feSLei Zhang       sliceOp.getMixedSizes());
382d3b54feSLei Zhang   if (failed(tiledResult))
392d3b54feSLei Zhang     return failure();
402d3b54feSLei Zhang 
41cbb09813SFangrui Song   return *tiledResult;
422d3b54feSLei Zhang }
43*2b2ce50fSAbhishek Varma 
replaceInsertSliceWithTiledConsumer(OpBuilder & builder,OffsetSizeAndStrideOpInterface sliceOp,OpOperand & consumer)44*2b2ce50fSAbhishek Varma FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
45*2b2ce50fSAbhishek Varma     OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
46*2b2ce50fSAbhishek Varma     OpOperand &consumer) {
47*2b2ce50fSAbhishek Varma   auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
48*2b2ce50fSAbhishek Varma   if (!consumerOp)
49*2b2ce50fSAbhishek Varma     return failure();
50*2b2ce50fSAbhishek Varma 
51*2b2ce50fSAbhishek Varma   // `TilingInterface` currently only supports strides being 1.
52*2b2ce50fSAbhishek Varma   if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
53*2b2ce50fSAbhishek Varma         return !isConstantIntValue(ofr, 1);
54*2b2ce50fSAbhishek Varma       }))
55*2b2ce50fSAbhishek Varma     return failure();
56*2b2ce50fSAbhishek Varma 
57*2b2ce50fSAbhishek Varma   FailureOr<TilingResult> tiledResult =
58*2b2ce50fSAbhishek Varma       consumerOp.getTiledImplementationFromOperandTile(
59*2b2ce50fSAbhishek Varma           builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
60*2b2ce50fSAbhishek Varma           sliceOp.getMixedSizes());
61*2b2ce50fSAbhishek Varma   if (failed(tiledResult))
62*2b2ce50fSAbhishek Varma     return failure();
63*2b2ce50fSAbhishek Varma 
64*2b2ce50fSAbhishek Varma   return *tiledResult;
65*2b2ce50fSAbhishek Varma }
66