1 //===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Vector/IR/VectorOps.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13
14 using namespace mlir;
15 using namespace mlir::vector;
16
17 namespace {
18
19 template <typename OpTy>
20 struct XferOpSubsetOpInterface
21 : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
22 OpTy> {
23 FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon39b34ad90111::XferOpSubsetOpInterface24 getAccessedHyperrectangularSlice(Operation *op) const {
25 auto xferOp = cast<OpTy>(op);
26 Builder b(xferOp->getContext());
27 SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
28 xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
29 SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
30 xferOp.getTransferChunkAccessed(),
31 [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
32 return HyperrectangularSlice(offsets, sizes);
33 }
34 };
35
36 struct TransferReadOpSubsetExtractionOpInterface
37 : public SubsetExtractionOpInterface::ExternalModel<
38 TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
getSourceOperand__anon39b34ad90111::TransferReadOpSubsetExtractionOpInterface39 OpOperand &getSourceOperand(Operation *op) const {
40 return cast<vector::TransferReadOp>(op).getSourceMutable();
41 }
42 };
43
44 struct TransferWriteOpSubsetInsertionOpInterface
45 : public SubsetInsertionOpInterface::ExternalModel<
46 TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
getSourceOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface47 OpOperand &getSourceOperand(Operation *op) const {
48 return cast<vector::TransferWriteOp>(op).getVectorMutable();
49 }
50
getDestinationOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface51 OpOperand &getDestinationOperand(Operation *op) const {
52 return cast<vector::TransferWriteOp>(op).getSourceMutable();
53 }
54
buildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface55 Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
56 Location loc) const {
57 // TODO: Implement when needed.
58 return Value();
59 }
60
61 SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface62 getValuesNeededToBuildSubsetExtraction(Operation *op) const {
63 // TODO: Implement when needed.
64 return {};
65 }
66 };
67
68 } // namespace
69
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)70 void mlir::vector::registerSubsetOpInterfaceExternalModels(
71 DialectRegistry ®istry) {
72 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
73 TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
74 *ctx);
75 TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
76 *ctx);
77 TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
78 *ctx);
79 TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
80 *ctx);
81 });
82 }
83