xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp (revision 1df6504ac21acbdaeee4eed52e82af69a302024c)
1 //===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Vector/IR/VectorOps.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13 
14 using namespace mlir;
15 using namespace mlir::vector;
16 
17 namespace {
18 
19 template <typename OpTy>
20 struct XferOpSubsetOpInterface
21     : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
22                                               OpTy> {
23   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon39b34ad90111::XferOpSubsetOpInterface24   getAccessedHyperrectangularSlice(Operation *op) const {
25     auto xferOp = cast<OpTy>(op);
26     Builder b(xferOp->getContext());
27     SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
28         xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
29     SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
30         xferOp.getTransferChunkAccessed(),
31         [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
32     return HyperrectangularSlice(offsets, sizes);
33   }
34 };
35 
36 struct TransferReadOpSubsetExtractionOpInterface
37     : public SubsetExtractionOpInterface::ExternalModel<
38           TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
getSourceOperand__anon39b34ad90111::TransferReadOpSubsetExtractionOpInterface39   OpOperand &getSourceOperand(Operation *op) const {
40     return cast<vector::TransferReadOp>(op).getSourceMutable();
41   }
42 };
43 
44 struct TransferWriteOpSubsetInsertionOpInterface
45     : public SubsetInsertionOpInterface::ExternalModel<
46           TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
getSourceOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface47   OpOperand &getSourceOperand(Operation *op) const {
48     return cast<vector::TransferWriteOp>(op).getVectorMutable();
49   }
50 
getDestinationOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface51   OpOperand &getDestinationOperand(Operation *op) const {
52     return cast<vector::TransferWriteOp>(op).getSourceMutable();
53   }
54 
buildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface55   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
56                               Location loc) const {
57     // TODO: Implement when needed.
58     return Value();
59   }
60 
61   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface62   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
63     // TODO: Implement when needed.
64     return {};
65   }
66 };
67 
68 } // namespace
69 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)70 void mlir::vector::registerSubsetOpInterfaceExternalModels(
71     DialectRegistry &registry) {
72   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
73     TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
74         *ctx);
75     TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
76         *ctx);
77     TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
78         *ctx);
79     TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
80         *ctx);
81   });
82 }
83