xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision ff614a5729e9a4fc32465ad5ff3b87e044429c2d)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 struct ExtractSliceOpSubsetOpInterface
21     : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
22                                               tensor::ExtractSliceOp> {
23   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::ExtractSliceOpSubsetOpInterface24   getAccessedHyperrectangularSlice(Operation *op) const {
25     return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
26   }
27 };
28 
29 struct ExtractSliceOpSubsetExtractionOpInterface
30     : public SubsetExtractionOpInterface::ExternalModel<
31           ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
getSourceOperand__anon3d2f355c0111::ExtractSliceOpSubsetExtractionOpInterface32   OpOperand &getSourceOperand(Operation *op) const {
33     return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
34   }
35 };
36 
37 template <typename OpTy>
38 struct InsertSliceLikeOpSubsetOpInterface
39     : public SubsetOpInterface::ExternalModel<
40           InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
41   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::InsertSliceLikeOpSubsetOpInterface42   getAccessedHyperrectangularSlice(Operation *op) const {
43     return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
44   }
45 };
46 
47 template <typename OpTy>
48 struct InsertSliceLikeOpSubsetInsertionOpInterface
49     : public SubsetInsertionOpInterface::ExternalModel<
50           InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
getSourceOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface51   OpOperand &getSourceOperand(Operation *op) const {
52     return cast<OpTy>(op).getSourceMutable();
53   }
54 
getDestinationOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface55   OpOperand &getDestinationOperand(Operation *op) const {
56     return cast<OpTy>(op).getDestMutable();
57   }
58 
buildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface59   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60                               Location loc) const {
61     auto insertSliceOp = cast<OpTy>(op);
62     auto extractOp = builder.create<tensor::ExtractSliceOp>(
63         loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
64         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
65         insertSliceOp.getMixedStrides());
66     return extractOp.getResult();
67   }
68 
69   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface70   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
71     auto insertSliceOp = cast<OpTy>(op);
72     SmallVector<Value> neededValues;
73     // Collect all values that are needed to construct the replacement op.
74     neededValues.append(insertSliceOp.getOffsets().begin(),
75                         insertSliceOp.getOffsets().end());
76     neededValues.append(insertSliceOp.getSizes().begin(),
77                         insertSliceOp.getSizes().end());
78     neededValues.append(insertSliceOp.getStrides().begin(),
79                         insertSliceOp.getStrides().end());
80     neededValues.push_back(insertSliceOp.getDest());
81     return neededValues;
82   }
83 };
84 
85 } // namespace
86 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)87 void mlir::tensor::registerSubsetOpInterfaceExternalModels(
88     DialectRegistry &registry) {
89   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
90     // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
91     // require `SubsetOpInterface`.
92     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
93     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
94         *ctx);
95     InsertSliceOp::attachInterface<
96         InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
97     InsertSliceOp::attachInterface<
98         InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
99     ParallelInsertSliceOp::attachInterface<
100         InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
101     ParallelInsertSliceOp::attachInterface<
102         InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
103         *ctx);
104   });
105 }
106