xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision 1abd8d1a8d962a14ca96e19c0f6da4f9ac394d0a)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 /// Return the tensor that the given subset op operates on.
21 Value getContainerOperand(SubsetOpInterface op) {
22   if (auto extractionOp =
23           dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
24     return extractionOp.getSourceOperand().get();
25   if (auto insertionOp =
26           dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
27     return insertionOp.getDestinationOperand().get();
28   llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
29 }
30 
31 /// Return "true" if the two ops operate on an equivalent subset.
32 /// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
33 /// if the two ops operate non-equivalent subsets, if equivalence cannot be
34 /// determined or if `op1` is not a subset op.
35 template <typename OpTy>
36 bool operateOnEquivalentSubsets(
37     OpTy op1, SubsetOpInterface op2,
38     function_ref<bool(Value, Value)> equivalenceFn) {
39   auto offsetsSizesAndStrides2 =
40       dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
41   if (!offsetsSizesAndStrides2)
42     return false;
43   if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
44                                   isEqualConstantIntOrValue))
45     return false;
46   return equivalenceFn(
47       getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
48       getContainerOperand(op2));
49 }
50 
51 /// Return "true" if the two ops operate on a disjoint subsets.
52 /// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
53 /// if the two ops operate non-disjoint subsets, if disjointness cannot be
54 /// determined or if `op1` is not a subset op.
55 template <typename OpTy>
56 bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
57                               function_ref<bool(Value, Value)> equivalenceFn) {
58   auto offsetsSizesAndStrides2 =
59       dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
60   if (!offsetsSizesAndStrides2)
61     return false;
62   FailureOr<bool> overlappingSlices =
63       ValueBoundsConstraintSet::areOverlappingSlices(op1,
64                                                      offsetsSizesAndStrides2);
65   if (failed(overlappingSlices) || *overlappingSlices)
66     return false;
67   return equivalenceFn(
68       getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
69       getContainerOperand(op2));
70 }
71 
72 struct ExtractSliceOpSubsetOpInterface
73     : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
74                                               tensor::ExtractSliceOp> {
75   bool operatesOnEquivalentSubset(
76       Operation *op, SubsetOpInterface candidate,
77       function_ref<bool(Value, Value)> equivalenceFn) const {
78     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
79     return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
80   }
81 
82   bool operatesOnDisjointSubset(
83       Operation *op, SubsetOpInterface candidate,
84       function_ref<bool(Value, Value)> equivalenceFn) const {
85     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
86     return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
87   }
88 };
89 
90 struct ExtractSliceOpSubsetExtractionOpInterface
91     : public SubsetExtractionOpInterface::ExternalModel<
92           ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
93   OpOperand &getSourceOperand(Operation *op) const {
94     return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
95   }
96 };
97 
98 template <typename OpTy>
99 struct InsertSliceLikeOpSubsetOpInterface
100     : public SubsetOpInterface::ExternalModel<
101           InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
102   bool operatesOnEquivalentSubset(
103       Operation *op, SubsetOpInterface candidate,
104       function_ref<bool(Value, Value)> equivalenceFn) const {
105     auto insertSliceOp = cast<OpTy>(op);
106     return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
107   }
108 
109   bool operatesOnDisjointSubset(
110       Operation *op, SubsetOpInterface candidate,
111       function_ref<bool(Value, Value)> equivalenceFn) const {
112     auto insertSliceOp = cast<OpTy>(op);
113     return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
114   }
115 };
116 
117 template <typename OpTy>
118 struct InsertSliceLikeOpSubsetInsertionOpInterface
119     : public SubsetInsertionOpInterface::ExternalModel<
120           InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
121   OpOperand &getSourceOperand(Operation *op) const {
122     return cast<OpTy>(op).getSourceMutable();
123   }
124 
125   OpOperand &getDestinationOperand(Operation *op) const {
126     return cast<OpTy>(op).getDestMutable();
127   }
128 
129   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
130                               Location loc) const {
131     auto insertSliceOp = cast<OpTy>(op);
132     auto extractOp = builder.create<tensor::ExtractSliceOp>(
133         loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
134         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
135         insertSliceOp.getMixedStrides());
136     return extractOp.getResult();
137   }
138 
139   SmallVector<Value>
140   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
141     auto insertSliceOp = cast<OpTy>(op);
142     SmallVector<Value> neededValues;
143     // Collect all values that are needed to construct the replacement op.
144     neededValues.append(insertSliceOp.getOffsets().begin(),
145                         insertSliceOp.getOffsets().end());
146     neededValues.append(insertSliceOp.getSizes().begin(),
147                         insertSliceOp.getSizes().end());
148     neededValues.append(insertSliceOp.getStrides().begin(),
149                         insertSliceOp.getStrides().end());
150     neededValues.push_back(insertSliceOp.getDest());
151     return neededValues;
152   }
153 };
154 
155 } // namespace
156 
157 void mlir::tensor::registerSubsetOpInterfaceExternalModels(
158     DialectRegistry &registry) {
159   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
160     // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
161     // require `SubsetOpInterface`.
162     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
163     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
164         *ctx);
165     InsertSliceOp::attachInterface<
166         InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
167     InsertSliceOp::attachInterface<
168         InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
169     ParallelInsertSliceOp::attachInterface<
170         InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
171     ParallelInsertSliceOp::attachInterface<
172         InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
173         *ctx);
174   });
175 }
176