1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Tensor/IR/Tensor.h" 12 #include "mlir/Interfaces/SubsetOpInterface.h" 13 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 14 15 using namespace mlir; 16 using namespace mlir::tensor; 17 18 namespace { 19 20 /// Return the tensor that the given subset op operates on. 21 Value getContainerOperand(SubsetOpInterface op) { 22 if (auto extractionOp = 23 dyn_cast<SubsetExtractionOpInterface>(op.getOperation())) 24 return extractionOp.getSourceOperand().get(); 25 if (auto insertionOp = 26 dyn_cast<SubsetInsertionOpInterface>(op.getOperation())) 27 return insertionOp.getDestinationOperand().get(); 28 llvm_unreachable("expected SubsetExtraction/InsertionOpInterface"); 29 } 30 31 /// Return "true" if the two ops operate on an equivalent subset. 32 /// `equivalenceFn` is used to determine equivalence of tensors. Return "false" 33 /// if the two ops operate non-equivalent subsets, if equivalence cannot be 34 /// determined or if `op1` is not a subset op. 35 template <typename OpTy> 36 bool operateOnEquivalentSubsets( 37 OpTy op1, SubsetOpInterface op2, 38 function_ref<bool(Value, Value)> equivalenceFn) { 39 auto offsetsSizesAndStrides2 = 40 dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation()); 41 if (!offsetsSizesAndStrides2) 42 return false; 43 if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2, 44 isEqualConstantIntOrValue)) 45 return false; 46 return equivalenceFn( 47 getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())), 48 getContainerOperand(op2)); 49 } 50 51 /// Return "true" if the two ops operate on a disjoint subsets. 52 /// `equivalenceFn` is used to determine equivalence of tensors. Return "false" 53 /// if the two ops operate non-disjoint subsets, if disjointness cannot be 54 /// determined or if `op1` is not a subset op. 55 template <typename OpTy> 56 bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2, 57 function_ref<bool(Value, Value)> equivalenceFn) { 58 auto offsetsSizesAndStrides2 = 59 dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation()); 60 if (!offsetsSizesAndStrides2) 61 return false; 62 FailureOr<bool> overlappingSlices = 63 ValueBoundsConstraintSet::areOverlappingSlices(op1, 64 offsetsSizesAndStrides2); 65 if (failed(overlappingSlices) || *overlappingSlices) 66 return false; 67 return equivalenceFn( 68 getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())), 69 getContainerOperand(op2)); 70 } 71 72 struct ExtractSliceOpSubsetOpInterface 73 : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface, 74 tensor::ExtractSliceOp> { 75 bool operatesOnEquivalentSubset( 76 Operation *op, SubsetOpInterface candidate, 77 function_ref<bool(Value, Value)> equivalenceFn) const { 78 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 79 return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn); 80 } 81 82 bool operatesOnDisjointSubset( 83 Operation *op, SubsetOpInterface candidate, 84 function_ref<bool(Value, Value)> equivalenceFn) const { 85 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 86 return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn); 87 } 88 }; 89 90 struct ExtractSliceOpSubsetExtractionOpInterface 91 : public SubsetExtractionOpInterface::ExternalModel< 92 ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> { 93 OpOperand &getSourceOperand(Operation *op) const { 94 return cast<tensor::ExtractSliceOp>(op).getSourceMutable(); 95 } 96 }; 97 98 template <typename OpTy> 99 struct InsertSliceLikeOpSubsetOpInterface 100 : public SubsetOpInterface::ExternalModel< 101 InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> { 102 bool operatesOnEquivalentSubset( 103 Operation *op, SubsetOpInterface candidate, 104 function_ref<bool(Value, Value)> equivalenceFn) const { 105 auto insertSliceOp = cast<OpTy>(op); 106 return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn); 107 } 108 109 bool operatesOnDisjointSubset( 110 Operation *op, SubsetOpInterface candidate, 111 function_ref<bool(Value, Value)> equivalenceFn) const { 112 auto insertSliceOp = cast<OpTy>(op); 113 return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn); 114 } 115 }; 116 117 template <typename OpTy> 118 struct InsertSliceLikeOpSubsetInsertionOpInterface 119 : public SubsetInsertionOpInterface::ExternalModel< 120 InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> { 121 OpOperand &getSourceOperand(Operation *op) const { 122 return cast<OpTy>(op).getSourceMutable(); 123 } 124 125 OpOperand &getDestinationOperand(Operation *op) const { 126 return cast<OpTy>(op).getDestMutable(); 127 } 128 129 Value buildSubsetExtraction(Operation *op, OpBuilder &builder, 130 Location loc) const { 131 auto insertSliceOp = cast<OpTy>(op); 132 auto extractOp = builder.create<tensor::ExtractSliceOp>( 133 loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(), 134 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 135 insertSliceOp.getMixedStrides()); 136 return extractOp.getResult(); 137 } 138 139 SmallVector<Value> 140 getValuesNeededToBuildSubsetExtraction(Operation *op) const { 141 auto insertSliceOp = cast<OpTy>(op); 142 SmallVector<Value> neededValues; 143 // Collect all values that are needed to construct the replacement op. 144 neededValues.append(insertSliceOp.getOffsets().begin(), 145 insertSliceOp.getOffsets().end()); 146 neededValues.append(insertSliceOp.getSizes().begin(), 147 insertSliceOp.getSizes().end()); 148 neededValues.append(insertSliceOp.getStrides().begin(), 149 insertSliceOp.getStrides().end()); 150 neededValues.push_back(insertSliceOp.getDest()); 151 return neededValues; 152 } 153 }; 154 155 } // namespace 156 157 void mlir::tensor::registerSubsetOpInterfaceExternalModels( 158 DialectRegistry ®istry) { 159 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 160 // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface` 161 // require `SubsetOpInterface`. 162 ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx); 163 ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>( 164 *ctx); 165 InsertSliceOp::attachInterface< 166 InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx); 167 InsertSliceOp::attachInterface< 168 InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx); 169 ParallelInsertSliceOp::attachInterface< 170 InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx); 171 ParallelInsertSliceOp::attachInterface< 172 InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>( 173 *ctx); 174 }); 175 } 176