//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
14
15 using namespace mlir;
16 using namespace mlir::tensor;
17
18 namespace {
19
20 struct ExtractSliceOpSubsetOpInterface
21 : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
22 tensor::ExtractSliceOp> {
23 FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::ExtractSliceOpSubsetOpInterface24 getAccessedHyperrectangularSlice(Operation *op) const {
25 return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
26 }
27 };
28
29 struct ExtractSliceOpSubsetExtractionOpInterface
30 : public SubsetExtractionOpInterface::ExternalModel<
31 ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
getSourceOperand__anon3d2f355c0111::ExtractSliceOpSubsetExtractionOpInterface32 OpOperand &getSourceOperand(Operation *op) const {
33 return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
34 }
35 };
36
37 template <typename OpTy>
38 struct InsertSliceLikeOpSubsetOpInterface
39 : public SubsetOpInterface::ExternalModel<
40 InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
41 FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::InsertSliceLikeOpSubsetOpInterface42 getAccessedHyperrectangularSlice(Operation *op) const {
43 return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
44 }
45 };
46
47 template <typename OpTy>
48 struct InsertSliceLikeOpSubsetInsertionOpInterface
49 : public SubsetInsertionOpInterface::ExternalModel<
50 InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
getSourceOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface51 OpOperand &getSourceOperand(Operation *op) const {
52 return cast<OpTy>(op).getSourceMutable();
53 }
54
getDestinationOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface55 OpOperand &getDestinationOperand(Operation *op) const {
56 return cast<OpTy>(op).getDestMutable();
57 }
58
buildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface59 Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60 Location loc) const {
61 auto insertSliceOp = cast<OpTy>(op);
62 auto extractOp = builder.create<tensor::ExtractSliceOp>(
63 loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
64 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
65 insertSliceOp.getMixedStrides());
66 return extractOp.getResult();
67 }
68
69 SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface70 getValuesNeededToBuildSubsetExtraction(Operation *op) const {
71 auto insertSliceOp = cast<OpTy>(op);
72 SmallVector<Value> neededValues;
73 // Collect all values that are needed to construct the replacement op.
74 neededValues.append(insertSliceOp.getOffsets().begin(),
75 insertSliceOp.getOffsets().end());
76 neededValues.append(insertSliceOp.getSizes().begin(),
77 insertSliceOp.getSizes().end());
78 neededValues.append(insertSliceOp.getStrides().begin(),
79 insertSliceOp.getStrides().end());
80 neededValues.push_back(insertSliceOp.getDest());
81 return neededValues;
82 }
83 };
84
85 } // namespace
86
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)87 void mlir::tensor::registerSubsetOpInterfaceExternalModels(
88 DialectRegistry ®istry) {
89 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
90 // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
91 // require `SubsetOpInterface`.
92 ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
93 ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
94 *ctx);
95 InsertSliceOp::attachInterface<
96 InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
97 InsertSliceOp::attachInterface<
98 InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
99 ParallelInsertSliceOp::attachInterface<
100 InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
101 ParallelInsertSliceOp::attachInterface<
102 InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
103 *ctx);
104 });
105 }
106