1 //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Swap a `tensor.extract_slice` with the producer of the source if the producer 10 // implements the `TilingInterface`. When used in conjunction with tiling this 11 // effectively tiles + fuses the producer with its consumer. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 18 #include "mlir/Dialect/Utils/StaticValueUtils.h" 19 #include "mlir/Interfaces/TilingInterface.h" 20 21 using namespace mlir; 22 23 FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer( 24 OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { 25 auto producerOp = dyn_cast<TilingInterface>(producer.getOwner()); 26 if (!producerOp) 27 return failure(); 28 29 // `TilingInterface` currently only supports strides being 1. 30 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { 31 return !isConstantIntValue(ofr, 1); 32 })) 33 return failure(); 34 35 FailureOr<Value> tiledResult = producerOp.generateResultTileValue( 36 builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), 37 sliceOp.getMixedSizes()); 38 if (failed(tiledResult)) 39 return failure(); 40 41 return tiledResult.value(); 42 } 43