xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision a1ef5a9437fd027da6b479426e52dc0c8a713a3f)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 
14 using namespace mlir;
15 using namespace mlir::bufferization;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
21 /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
22 /// equivalence of tensors.
23 template <typename OpTy>
24 bool isSubsetEquivalentToInsertSliceLikeOp(
25     OpTy insertSliceOp, Value candidate,
26     function_ref<bool(Value, Value)> equivalenceFn) {
27   // Look for a matching tensor.extract_slice op.
28   auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
29   if (!extractSliceOp)
30     return false;
31   if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
32     return false;
33   return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
34                                     isEqualConstantIntOrValue);
35 }
36 
37 template <typename OpTy>
38 Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
39                                                OpTy insertSliceOp) {
40   auto extractOp = b.create<tensor::ExtractSliceOp>(
41       loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
42       insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
43       insertSliceOp.getMixedStrides());
44   return extractOp.getResult();
45 }
46 
47 template <typename OpTy>
48 SmallVector<Value>
49 getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
50   SmallVector<Value> neededValues;
51   // Collect all values that are needed to construct the replacement op.
52   neededValues.append(insertSliceOp.getOffsets().begin(),
53                       insertSliceOp.getOffsets().end());
54   neededValues.append(insertSliceOp.getSizes().begin(),
55                       insertSliceOp.getSizes().end());
56   neededValues.append(insertSliceOp.getStrides().begin(),
57                       insertSliceOp.getStrides().end());
58   neededValues.push_back(insertSliceOp.getDest());
59   return neededValues;
60 }
61 
62 struct InsertSliceOpInterface
63     : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
64                                                        tensor::InsertSliceOp> {
65   OpOperand &getSourceOperand(Operation *op) const {
66     return op->getOpOperand(0);
67   }
68 
69   bool
70   isEquivalentSubset(Operation *op, Value candidate,
71                      function_ref<bool(Value, Value)> equivalenceFn) const {
72     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
73     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
74                                                  equivalenceFn);
75   }
76 
77   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
78                               Location loc) const {
79     return buildSubsetExtractionOfInsertSliceLikeOp(
80         builder, loc, cast<tensor::InsertSliceOp>(op));
81   }
82 
83   SmallVector<Value>
84   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
85     return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
86         cast<tensor::InsertSliceOp>(op));
87   }
88 };
89 
90 struct ParallelInsertSliceOpInterface
91     : public SubsetInsertionOpInterface::ExternalModel<
92           ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
93   OpOperand &getSourceOperand(Operation *op) const {
94     return op->getOpOperand(0);
95   }
96 
97   OpOperand &getDestinationOperand(Operation *op) const {
98     return op->getOpOperand(1);
99   }
100 
101   bool
102   isEquivalentSubset(Operation *op, Value candidate,
103                      function_ref<bool(Value, Value)> equivalenceFn) const {
104     auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
105     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
106                                                  equivalenceFn);
107   }
108 
109   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
110                               Location loc) const {
111     return buildSubsetExtractionOfInsertSliceLikeOp(
112         builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
113   }
114 
115   SmallVector<Value>
116   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
117     return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
118         cast<tensor::ParallelInsertSliceOp>(op));
119   }
120 };
121 
122 } // namespace
123 
124 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
125     DialectRegistry &registry) {
126   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
127     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
128     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
129         *ctx);
130   });
131 }
132