1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 14 using namespace mlir; 15 using namespace mlir::bufferization; 16 using namespace mlir::tensor; 17 18 namespace { 19 20 /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent 21 /// to the subset defined by `candidate`. `equivalenceFn` is used to determine 22 /// equivalence of tensors. 23 template <typename OpTy> 24 bool isSubsetEquivalentToInsertSliceLikeOp( 25 OpTy insertSliceOp, Value candidate, 26 function_ref<bool(Value, Value)> equivalenceFn) { 27 // Look for a matching tensor.extract_slice op. 28 auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>(); 29 if (!extractSliceOp) 30 return false; 31 if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest())) 32 return false; 33 return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp, 34 isEqualConstantIntOrValue); 35 } 36 37 struct InsertSliceOpInterface 38 : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface, 39 tensor::InsertSliceOp> { 40 OpOperand &getSourceOperand(Operation *op) const { 41 return op->getOpOperand(0); 42 } 43 44 bool 45 isEquivalentSubset(Operation *op, Value candidate, 46 function_ref<bool(Value, Value)> equivalenceFn) const { 47 auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 48 return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, 49 equivalenceFn); 50 } 51 }; 52 53 struct ParallelInsertSliceOpInterface 54 : public SubsetInsertionOpInterface::ExternalModel< 55 ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> { 56 OpOperand &getSourceOperand(Operation *op) const { 57 return op->getOpOperand(0); 58 } 59 60 OpOperand &getDestinationOperand(Operation *op) const { 61 return op->getOpOperand(1); 62 } 63 64 bool 65 isEquivalentSubset(Operation *op, Value candidate, 66 function_ref<bool(Value, Value)> equivalenceFn) const { 67 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op); 68 return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, 69 equivalenceFn); 70 } 71 }; 72 73 } // namespace 74 75 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels( 76 DialectRegistry ®istry) { 77 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 78 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx); 79 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>( 80 *ctx); 81 }); 82 } 83