1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 14 using namespace mlir; 15 using namespace mlir::bufferization; 16 using namespace mlir::tensor; 17 18 namespace { 19 20 template <typename OpTy> 21 struct InsertSliceLikeOpInterface 22 : public SubsetInsertionOpInterface::ExternalModel< 23 InsertSliceLikeOpInterface<OpTy>, OpTy> { 24 OpOperand &getSourceOperand(Operation *op) const { 25 return cast<OpTy>(op).getSourceMutable(); 26 } 27 28 OpOperand &getDestinationOperand(Operation *op) const { 29 return cast<OpTy>(op).getDestMutable(); 30 } 31 32 /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent 33 /// to the subset defined by `candidate`. `equivalenceFn` is used to determine 34 /// equivalence of tensors. 35 bool 36 isEquivalentSubset(Operation *op, Value candidate, 37 function_ref<bool(Value, Value)> equivalenceFn) const { 38 auto insertSliceOp = cast<OpTy>(op); 39 // Look for a matching tensor.extract_slice op. 40 auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>(); 41 if (!extractSliceOp) 42 return false; 43 if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest())) 44 return false; 45 return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp, 46 isEqualConstantIntOrValue); 47 } 48 49 Value buildSubsetExtraction(Operation *op, OpBuilder &builder, 50 Location loc) const { 51 auto insertSliceOp = cast<OpTy>(op); 52 auto extractOp = builder.create<tensor::ExtractSliceOp>( 53 loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(), 54 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 55 insertSliceOp.getMixedStrides()); 56 return extractOp.getResult(); 57 } 58 59 SmallVector<Value> 60 getValuesNeededToBuildSubsetExtraction(Operation *op) const { 61 auto insertSliceOp = cast<OpTy>(op); 62 SmallVector<Value> neededValues; 63 // Collect all values that are needed to construct the replacement op. 64 neededValues.append(insertSliceOp.getOffsets().begin(), 65 insertSliceOp.getOffsets().end()); 66 neededValues.append(insertSliceOp.getSizes().begin(), 67 insertSliceOp.getSizes().end()); 68 neededValues.append(insertSliceOp.getStrides().begin(), 69 insertSliceOp.getStrides().end()); 70 neededValues.push_back(insertSliceOp.getDest()); 71 return neededValues; 72 } 73 }; 74 75 } // namespace 76 77 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels( 78 DialectRegistry ®istry) { 79 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 80 InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>( 81 *ctx); 82 ParallelInsertSliceOp::attachInterface< 83 InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx); 84 }); 85 } 86