12d3b54feSLei Zhang //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
22d3b54feSLei Zhang //
32d3b54feSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42d3b54feSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52d3b54feSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62d3b54feSLei Zhang //
72d3b54feSLei Zhang //===----------------------------------------------------------------------===//
82d3b54feSLei Zhang //
92d3b54feSLei Zhang // Swap a `tensor.extract_slice` with the producer of the source if the producer
102d3b54feSLei Zhang // implements the `TilingInterface`. When used in conjunction with tiling this
112d3b54feSLei Zhang // effectively tiles + fuses the producer with its consumer.
122d3b54feSLei Zhang //
132d3b54feSLei Zhang //===----------------------------------------------------------------------===//
142d3b54feSLei Zhang
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
162d3b54feSLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
172d3b54feSLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
182d3b54feSLei Zhang #include "mlir/Dialect/Utils/StaticValueUtils.h"
192d3b54feSLei Zhang #include "mlir/Interfaces/TilingInterface.h"
202d3b54feSLei Zhang
212d3b54feSLei Zhang using namespace mlir;
222d3b54feSLei Zhang
replaceExtractSliceWithTiledProducer(OpBuilder & builder,tensor::ExtractSliceOp sliceOp,OpResult producer)23809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
242d3b54feSLei Zhang OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
252d3b54feSLei Zhang auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
262d3b54feSLei Zhang if (!producerOp)
272d3b54feSLei Zhang return failure();
282d3b54feSLei Zhang
292d3b54feSLei Zhang // `TilingInterface` currently only supports strides being 1.
302d3b54feSLei Zhang if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
312d3b54feSLei Zhang return !isConstantIntValue(ofr, 1);
322d3b54feSLei Zhang }))
332d3b54feSLei Zhang return failure();
342d3b54feSLei Zhang
35809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
362d3b54feSLei Zhang builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
372d3b54feSLei Zhang sliceOp.getMixedSizes());
382d3b54feSLei Zhang if (failed(tiledResult))
392d3b54feSLei Zhang return failure();
402d3b54feSLei Zhang
41cbb09813SFangrui Song return *tiledResult;
422d3b54feSLei Zhang }
43*2b2ce50fSAbhishek Varma
replaceInsertSliceWithTiledConsumer(OpBuilder & builder,OffsetSizeAndStrideOpInterface sliceOp,OpOperand & consumer)44*2b2ce50fSAbhishek Varma FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
45*2b2ce50fSAbhishek Varma OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
46*2b2ce50fSAbhishek Varma OpOperand &consumer) {
47*2b2ce50fSAbhishek Varma auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
48*2b2ce50fSAbhishek Varma if (!consumerOp)
49*2b2ce50fSAbhishek Varma return failure();
50*2b2ce50fSAbhishek Varma
51*2b2ce50fSAbhishek Varma // `TilingInterface` currently only supports strides being 1.
52*2b2ce50fSAbhishek Varma if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
53*2b2ce50fSAbhishek Varma return !isConstantIntValue(ofr, 1);
54*2b2ce50fSAbhishek Varma }))
55*2b2ce50fSAbhishek Varma return failure();
56*2b2ce50fSAbhishek Varma
57*2b2ce50fSAbhishek Varma FailureOr<TilingResult> tiledResult =
58*2b2ce50fSAbhishek Varma consumerOp.getTiledImplementationFromOperandTile(
59*2b2ce50fSAbhishek Varma builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
60*2b2ce50fSAbhishek Varma sliceOp.getMixedSizes());
61*2b2ce50fSAbhishek Varma if (failed(tiledResult))
62*2b2ce50fSAbhishek Varma return failure();
63*2b2ce50fSAbhishek Varma
64*2b2ce50fSAbhishek Varma return *tiledResult;
65*2b2ce50fSAbhishek Varma }
66