xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision 8143307b336db18f5db4e19a202194ee6cb614de)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 
14 using namespace mlir;
15 using namespace mlir::bufferization;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
21 /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
22 /// equivalence of tensors.
23 template <typename OpTy>
24 bool isSubsetEquivalentToInsertSliceLikeOp(
25     OpTy insertSliceOp, Value candidate,
26     function_ref<bool(Value, Value)> equivalenceFn) {
27   // Look for a matching tensor.extract_slice op.
28   auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
29   if (!extractSliceOp)
30     return false;
31   if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
32     return false;
33   return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
34                                     isEqualConstantIntOrValue);
35 }
36 
37 struct InsertSliceOpInterface
38     : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
39                                                        tensor::InsertSliceOp> {
40   OpOperand &getSourceOperand(Operation *op) const {
41     return op->getOpOperand(0);
42   }
43 
44   bool
45   isEquivalentSubset(Operation *op, Value candidate,
46                      function_ref<bool(Value, Value)> equivalenceFn) const {
47     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
48     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
49                                                  equivalenceFn);
50   }
51 };
52 
53 struct ParallelInsertSliceOpInterface
54     : public SubsetInsertionOpInterface::ExternalModel<
55           ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
56   OpOperand &getSourceOperand(Operation *op) const {
57     return op->getOpOperand(0);
58   }
59 
60   OpOperand &getDestinationOperand(Operation *op) const {
61     return op->getOpOperand(1);
62   }
63 
64   bool
65   isEquivalentSubset(Operation *op, Value candidate,
66                      function_ref<bool(Value, Value)> equivalenceFn) const {
67     auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
68     return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
69                                                  equivalenceFn);
70   }
71 };
72 
73 } // namespace
74 
75 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
76     DialectRegistry &registry) {
77   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
78     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
79     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
80         *ctx);
81   });
82 }
83