xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp (revision 1df6504ac21acbdaeee4eed52e82af69a302024c)
1*1df6504aSMatthias Springer //===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
2*1df6504aSMatthias Springer //
3*1df6504aSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*1df6504aSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5*1df6504aSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*1df6504aSMatthias Springer //
7*1df6504aSMatthias Springer //===----------------------------------------------------------------------===//
8*1df6504aSMatthias Springer 
9*1df6504aSMatthias Springer #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
10*1df6504aSMatthias Springer 
11*1df6504aSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
12*1df6504aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
13*1df6504aSMatthias Springer 
14*1df6504aSMatthias Springer using namespace mlir;
15*1df6504aSMatthias Springer using namespace mlir::vector;
16*1df6504aSMatthias Springer 
17*1df6504aSMatthias Springer namespace {
18*1df6504aSMatthias Springer 
19*1df6504aSMatthias Springer template <typename OpTy>
20*1df6504aSMatthias Springer struct XferOpSubsetOpInterface
21*1df6504aSMatthias Springer     : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
22*1df6504aSMatthias Springer                                               OpTy> {
23*1df6504aSMatthias Springer   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon39b34ad90111::XferOpSubsetOpInterface24*1df6504aSMatthias Springer   getAccessedHyperrectangularSlice(Operation *op) const {
25*1df6504aSMatthias Springer     auto xferOp = cast<OpTy>(op);
26*1df6504aSMatthias Springer     Builder b(xferOp->getContext());
27*1df6504aSMatthias Springer     SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
28*1df6504aSMatthias Springer         xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
29*1df6504aSMatthias Springer     SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
30*1df6504aSMatthias Springer         xferOp.getTransferChunkAccessed(),
31*1df6504aSMatthias Springer         [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
32*1df6504aSMatthias Springer     return HyperrectangularSlice(offsets, sizes);
33*1df6504aSMatthias Springer   }
34*1df6504aSMatthias Springer };
35*1df6504aSMatthias Springer 
36*1df6504aSMatthias Springer struct TransferReadOpSubsetExtractionOpInterface
37*1df6504aSMatthias Springer     : public SubsetExtractionOpInterface::ExternalModel<
38*1df6504aSMatthias Springer           TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
getSourceOperand__anon39b34ad90111::TransferReadOpSubsetExtractionOpInterface39*1df6504aSMatthias Springer   OpOperand &getSourceOperand(Operation *op) const {
40*1df6504aSMatthias Springer     return cast<vector::TransferReadOp>(op).getSourceMutable();
41*1df6504aSMatthias Springer   }
42*1df6504aSMatthias Springer };
43*1df6504aSMatthias Springer 
44*1df6504aSMatthias Springer struct TransferWriteOpSubsetInsertionOpInterface
45*1df6504aSMatthias Springer     : public SubsetInsertionOpInterface::ExternalModel<
46*1df6504aSMatthias Springer           TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
getSourceOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface47*1df6504aSMatthias Springer   OpOperand &getSourceOperand(Operation *op) const {
48*1df6504aSMatthias Springer     return cast<vector::TransferWriteOp>(op).getVectorMutable();
49*1df6504aSMatthias Springer   }
50*1df6504aSMatthias Springer 
getDestinationOperand__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface51*1df6504aSMatthias Springer   OpOperand &getDestinationOperand(Operation *op) const {
52*1df6504aSMatthias Springer     return cast<vector::TransferWriteOp>(op).getSourceMutable();
53*1df6504aSMatthias Springer   }
54*1df6504aSMatthias Springer 
buildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface55*1df6504aSMatthias Springer   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
56*1df6504aSMatthias Springer                               Location loc) const {
57*1df6504aSMatthias Springer     // TODO: Implement when needed.
58*1df6504aSMatthias Springer     return Value();
59*1df6504aSMatthias Springer   }
60*1df6504aSMatthias Springer 
61*1df6504aSMatthias Springer   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon39b34ad90111::TransferWriteOpSubsetInsertionOpInterface62*1df6504aSMatthias Springer   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
63*1df6504aSMatthias Springer     // TODO: Implement when needed.
64*1df6504aSMatthias Springer     return {};
65*1df6504aSMatthias Springer   }
66*1df6504aSMatthias Springer };
67*1df6504aSMatthias Springer 
68*1df6504aSMatthias Springer } // namespace
69*1df6504aSMatthias Springer 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)70*1df6504aSMatthias Springer void mlir::vector::registerSubsetOpInterfaceExternalModels(
71*1df6504aSMatthias Springer     DialectRegistry &registry) {
72*1df6504aSMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
73*1df6504aSMatthias Springer     TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
74*1df6504aSMatthias Springer         *ctx);
75*1df6504aSMatthias Springer     TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
76*1df6504aSMatthias Springer         *ctx);
77*1df6504aSMatthias Springer     TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
78*1df6504aSMatthias Springer         *ctx);
79*1df6504aSMatthias Springer     TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
80*1df6504aSMatthias Springer         *ctx);
81*1df6504aSMatthias Springer   });
82*1df6504aSMatthias Springer }
83