1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 14 using namespace mlir; 15 using namespace mlir::bufferization; 16 using namespace mlir::tensor; 17 18 namespace { 19 20 /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent 21 /// to the subset defined by `candidate`. `equivalenceFn` is used to determine 22 /// equivalence of tensors. 23 template <typename OpTy> 24 bool isSubsetEquivalentToInsertSliceLikeOp( 25 OpTy insertSliceOp, Value candidate, 26 function_ref<bool(Value, Value)> equivalenceFn) { 27 // Look for a matching tensor.extract_slice op. 28 auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>(); 29 if (!extractSliceOp) 30 return false; 31 if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest())) 32 return false; 33 return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp, 34 isEqualConstantIntOrValue); 35 } 36 37 template <typename OpTy> 38 Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc, 39 OpTy insertSliceOp) { 40 auto extractOp = b.create<tensor::ExtractSliceOp>( 41 loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(), 42 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 43 insertSliceOp.getMixedStrides()); 44 return extractOp.getResult(); 45 } 46 47 template <typename OpTy> 48 SmallVector<Value> 49 getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) { 50 SmallVector<Value> neededValues; 51 // Collect all values that are needed to construct the replacement op. 52 neededValues.append(insertSliceOp.getOffsets().begin(), 53 insertSliceOp.getOffsets().end()); 54 neededValues.append(insertSliceOp.getSizes().begin(), 55 insertSliceOp.getSizes().end()); 56 neededValues.append(insertSliceOp.getStrides().begin(), 57 insertSliceOp.getStrides().end()); 58 neededValues.push_back(insertSliceOp.getDest()); 59 return neededValues; 60 } 61 62 struct InsertSliceOpInterface 63 : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface, 64 tensor::InsertSliceOp> { 65 OpOperand &getSourceOperand(Operation *op) const { 66 return op->getOpOperand(0); 67 } 68 69 bool 70 isEquivalentSubset(Operation *op, Value candidate, 71 function_ref<bool(Value, Value)> equivalenceFn) const { 72 auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 73 return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, 74 equivalenceFn); 75 } 76 77 Value buildSubsetExtraction(Operation *op, OpBuilder &builder, 78 Location loc) const { 79 return buildSubsetExtractionOfInsertSliceLikeOp( 80 builder, loc, cast<tensor::InsertSliceOp>(op)); 81 } 82 83 SmallVector<Value> 84 getValuesNeededToBuildSubsetExtraction(Operation *op) const { 85 return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp( 86 cast<tensor::InsertSliceOp>(op)); 87 } 88 }; 89 90 struct ParallelInsertSliceOpInterface 91 : public SubsetInsertionOpInterface::ExternalModel< 92 ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> { 93 OpOperand &getSourceOperand(Operation *op) const { 94 return op->getOpOperand(0); 95 } 96 97 OpOperand &getDestinationOperand(Operation *op) const { 98 return op->getOpOperand(1); 99 } 100 101 bool 102 isEquivalentSubset(Operation *op, Value candidate, 103 function_ref<bool(Value, Value)> equivalenceFn) const { 104 auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op); 105 return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate, 106 equivalenceFn); 107 } 108 109 Value buildSubsetExtraction(Operation *op, OpBuilder &builder, 110 Location loc) const { 111 return buildSubsetExtractionOfInsertSliceLikeOp( 112 builder, loc, cast<tensor::ParallelInsertSliceOp>(op)); 113 } 114 115 SmallVector<Value> 116 getValuesNeededToBuildSubsetExtraction(Operation *op) const { 117 return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp( 118 cast<tensor::ParallelInsertSliceOp>(op)); 119 } 120 }; 121 122 } // namespace 123 124 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels( 125 DialectRegistry ®istry) { 126 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 127 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx); 128 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>( 129 *ctx); 130 }); 131 } 132