xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision ff614a5729e9a4fc32465ad5ff3b87e044429c2d)
18143307bSMatthias Springer //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
28143307bSMatthias Springer //
38143307bSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48143307bSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
58143307bSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68143307bSMatthias Springer //
78143307bSMatthias Springer //===----------------------------------------------------------------------===//
88143307bSMatthias Springer 
98143307bSMatthias Springer #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
108143307bSMatthias Springer 
118143307bSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
121abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
131abd8d1aSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
148143307bSMatthias Springer 
158143307bSMatthias Springer using namespace mlir;
168143307bSMatthias Springer using namespace mlir::tensor;
178143307bSMatthias Springer 
188143307bSMatthias Springer namespace {
198143307bSMatthias Springer 
201abd8d1aSMatthias Springer struct ExtractSliceOpSubsetOpInterface
211abd8d1aSMatthias Springer     : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
221abd8d1aSMatthias Springer                                               tensor::ExtractSliceOp> {
23*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::ExtractSliceOpSubsetOpInterface24*ff614a57SMatthias Springer   getAccessedHyperrectangularSlice(Operation *op) const {
25*ff614a57SMatthias Springer     return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
261abd8d1aSMatthias Springer   }
271abd8d1aSMatthias Springer };
281abd8d1aSMatthias Springer 
291abd8d1aSMatthias Springer struct ExtractSliceOpSubsetExtractionOpInterface
301abd8d1aSMatthias Springer     : public SubsetExtractionOpInterface::ExternalModel<
311abd8d1aSMatthias Springer           ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
getSourceOperand__anon3d2f355c0111::ExtractSliceOpSubsetExtractionOpInterface321abd8d1aSMatthias Springer   OpOperand &getSourceOperand(Operation *op) const {
331abd8d1aSMatthias Springer     return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
341abd8d1aSMatthias Springer   }
351abd8d1aSMatthias Springer };
361abd8d1aSMatthias Springer 
371abd8d1aSMatthias Springer template <typename OpTy>
381abd8d1aSMatthias Springer struct InsertSliceLikeOpSubsetOpInterface
391abd8d1aSMatthias Springer     : public SubsetOpInterface::ExternalModel<
401abd8d1aSMatthias Springer           InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
41*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice>
getAccessedHyperrectangularSlice__anon3d2f355c0111::InsertSliceLikeOpSubsetOpInterface42*ff614a57SMatthias Springer   getAccessedHyperrectangularSlice(Operation *op) const {
43*ff614a57SMatthias Springer     return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
441abd8d1aSMatthias Springer   }
451abd8d1aSMatthias Springer };
461abd8d1aSMatthias Springer 
471abd8d1aSMatthias Springer template <typename OpTy>
481abd8d1aSMatthias Springer struct InsertSliceLikeOpSubsetInsertionOpInterface
492e3c62b1SMatthias Springer     : public SubsetInsertionOpInterface::ExternalModel<
501abd8d1aSMatthias Springer           InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
getSourceOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface512e3c62b1SMatthias Springer   OpOperand &getSourceOperand(Operation *op) const {
522e3c62b1SMatthias Springer     return cast<OpTy>(op).getSourceMutable();
532e3c62b1SMatthias Springer   }
542e3c62b1SMatthias Springer 
getDestinationOperand__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface552e3c62b1SMatthias Springer   OpOperand &getDestinationOperand(Operation *op) const {
562e3c62b1SMatthias Springer     return cast<OpTy>(op).getDestMutable();
572e3c62b1SMatthias Springer   }
582e3c62b1SMatthias Springer 
buildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface592e3c62b1SMatthias Springer   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
602e3c62b1SMatthias Springer                               Location loc) const {
612e3c62b1SMatthias Springer     auto insertSliceOp = cast<OpTy>(op);
622e3c62b1SMatthias Springer     auto extractOp = builder.create<tensor::ExtractSliceOp>(
63a1ef5a94SMatthias Springer         loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
64a1ef5a94SMatthias Springer         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
65a1ef5a94SMatthias Springer         insertSliceOp.getMixedStrides());
66a1ef5a94SMatthias Springer     return extractOp.getResult();
67a1ef5a94SMatthias Springer   }
68a1ef5a94SMatthias Springer 
69a1ef5a94SMatthias Springer   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anon3d2f355c0111::InsertSliceLikeOpSubsetInsertionOpInterface702e3c62b1SMatthias Springer   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
712e3c62b1SMatthias Springer     auto insertSliceOp = cast<OpTy>(op);
72a1ef5a94SMatthias Springer     SmallVector<Value> neededValues;
73a1ef5a94SMatthias Springer     // Collect all values that are needed to construct the replacement op.
74a1ef5a94SMatthias Springer     neededValues.append(insertSliceOp.getOffsets().begin(),
75a1ef5a94SMatthias Springer                         insertSliceOp.getOffsets().end());
76a1ef5a94SMatthias Springer     neededValues.append(insertSliceOp.getSizes().begin(),
77a1ef5a94SMatthias Springer                         insertSliceOp.getSizes().end());
78a1ef5a94SMatthias Springer     neededValues.append(insertSliceOp.getStrides().begin(),
79a1ef5a94SMatthias Springer                         insertSliceOp.getStrides().end());
80a1ef5a94SMatthias Springer     neededValues.push_back(insertSliceOp.getDest());
81a1ef5a94SMatthias Springer     return neededValues;
82a1ef5a94SMatthias Springer   }
838143307bSMatthias Springer };
848143307bSMatthias Springer 
858143307bSMatthias Springer } // namespace
868143307bSMatthias Springer 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)871abd8d1aSMatthias Springer void mlir::tensor::registerSubsetOpInterfaceExternalModels(
888143307bSMatthias Springer     DialectRegistry &registry) {
898143307bSMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
901abd8d1aSMatthias Springer     // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
911abd8d1aSMatthias Springer     // require `SubsetOpInterface`.
921abd8d1aSMatthias Springer     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
931abd8d1aSMatthias Springer     ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
948143307bSMatthias Springer         *ctx);
951abd8d1aSMatthias Springer     InsertSliceOp::attachInterface<
961abd8d1aSMatthias Springer         InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
971abd8d1aSMatthias Springer     InsertSliceOp::attachInterface<
981abd8d1aSMatthias Springer         InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
992e3c62b1SMatthias Springer     ParallelInsertSliceOp::attachInterface<
1001abd8d1aSMatthias Springer         InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
1011abd8d1aSMatthias Springer     ParallelInsertSliceOp::attachInterface<
1021abd8d1aSMatthias Springer         InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
1031abd8d1aSMatthias Springer         *ctx);
1048143307bSMatthias Springer   });
1058143307bSMatthias Springer }
106