xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision 2e3c62b15d49fbe11967de7719f4b9d70c5493e4)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 
14 using namespace mlir;
15 using namespace mlir::bufferization;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 template <typename OpTy>
21 struct InsertSliceLikeOpInterface
22     : public SubsetInsertionOpInterface::ExternalModel<
23           InsertSliceLikeOpInterface<OpTy>, OpTy> {
24   OpOperand &getSourceOperand(Operation *op) const {
25     return cast<OpTy>(op).getSourceMutable();
26   }
27 
28   OpOperand &getDestinationOperand(Operation *op) const {
29     return cast<OpTy>(op).getDestMutable();
30   }
31 
32   /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
33   /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
34   /// equivalence of tensors.
35   bool
36   isEquivalentSubset(Operation *op, Value candidate,
37                      function_ref<bool(Value, Value)> equivalenceFn) const {
38     auto insertSliceOp = cast<OpTy>(op);
39     // Look for a matching tensor.extract_slice op.
40     auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
41     if (!extractSliceOp)
42       return false;
43     if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
44       return false;
45     return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
46                                       isEqualConstantIntOrValue);
47   }
48 
49   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
50                               Location loc) const {
51     auto insertSliceOp = cast<OpTy>(op);
52     auto extractOp = builder.create<tensor::ExtractSliceOp>(
53         loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
54         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
55         insertSliceOp.getMixedStrides());
56     return extractOp.getResult();
57   }
58 
59   SmallVector<Value>
60   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
61     auto insertSliceOp = cast<OpTy>(op);
62     SmallVector<Value> neededValues;
63     // Collect all values that are needed to construct the replacement op.
64     neededValues.append(insertSliceOp.getOffsets().begin(),
65                         insertSliceOp.getOffsets().end());
66     neededValues.append(insertSliceOp.getSizes().begin(),
67                         insertSliceOp.getSizes().end());
68     neededValues.append(insertSliceOp.getStrides().begin(),
69                         insertSliceOp.getStrides().end());
70     neededValues.push_back(insertSliceOp.getDest());
71     return neededValues;
72   }
73 };
74 
75 } // namespace
76 
77 void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
78     DialectRegistry &registry) {
79   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
80     InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
81         *ctx);
82     ParallelInsertSliceOp::attachInterface<
83         InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
84   });
85 }
86