xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision 1abd8d1a8d962a14ca96e19c0f6da4f9ac394d0a)
1913286baSMatthias Springer //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2913286baSMatthias Springer //
3913286baSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4913286baSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5913286baSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6913286baSMatthias Springer //
7913286baSMatthias Springer //===----------------------------------------------------------------------===//
8913286baSMatthias Springer 
9913286baSMatthias Springer #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10913286baSMatthias Springer 
11913286baSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
12*1abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
13913286baSMatthias Springer 
14913286baSMatthias Springer using namespace mlir;
15913286baSMatthias Springer using namespace mlir::linalg;
16913286baSMatthias Springer 
17913286baSMatthias Springer namespace {
18*1abd8d1aSMatthias Springer struct LinalgCopyOpSubsetOpInterface
19*1abd8d1aSMatthias Springer     : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
20*1abd8d1aSMatthias Springer                                               linalg::CopyOp> {
operatesOnEquivalentSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface21*1abd8d1aSMatthias Springer   bool operatesOnEquivalentSubset(
22*1abd8d1aSMatthias Springer       Operation *op, SubsetOpInterface candidate,
23*1abd8d1aSMatthias Springer       function_ref<bool(Value, Value)> equivalenceFn) const {
24*1abd8d1aSMatthias Springer     // linalg.copy operates on the entire destination tensor.
25*1abd8d1aSMatthias Springer     if (auto otherCopyOp = dyn_cast<linalg::CopyOp>(candidate.getOperation()))
26*1abd8d1aSMatthias Springer       return equivalenceFn(cast<linalg::CopyOp>(op).getOutputs()[0],
27*1abd8d1aSMatthias Springer                            otherCopyOp.getOutputs()[0]);
28*1abd8d1aSMatthias Springer     // In the absence of an analysis, "false" is a conservative way to implement
29*1abd8d1aSMatthias Springer     // this interface.
30*1abd8d1aSMatthias Springer     return false;
31*1abd8d1aSMatthias Springer   }
32*1abd8d1aSMatthias Springer 
operatesOnDisjointSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface33*1abd8d1aSMatthias Springer   bool operatesOnDisjointSubset(
34*1abd8d1aSMatthias Springer       Operation *op, SubsetOpInterface candidate,
35*1abd8d1aSMatthias Springer       function_ref<bool(Value, Value)> equivalenceFn) const {
36*1abd8d1aSMatthias Springer     // In the absence of an analysis, "false" is a conservative way to implement
37*1abd8d1aSMatthias Springer     // this interface.
38*1abd8d1aSMatthias Springer     return false;
39*1abd8d1aSMatthias Springer   }
40*1abd8d1aSMatthias Springer };
41*1abd8d1aSMatthias Springer 
42913286baSMatthias Springer struct LinalgCopyOpInterface
43913286baSMatthias Springer     : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
44913286baSMatthias Springer                                                        linalg::CopyOp> {
getSourceOperand__anonc4a13e580111::LinalgCopyOpInterface45913286baSMatthias Springer   OpOperand &getSourceOperand(Operation *op) const {
46913286baSMatthias Springer     auto copyOp = cast<CopyOp>(op);
47913286baSMatthias Springer     assert(copyOp.getInputs().size() == 1 && "expected single input");
48913286baSMatthias Springer     return copyOp.getInputsMutable()[0];
49913286baSMatthias Springer   }
50913286baSMatthias Springer 
51913286baSMatthias Springer   bool
isEquivalentSubset__anonc4a13e580111::LinalgCopyOpInterface52913286baSMatthias Springer   isEquivalentSubset(Operation *op, Value candidate,
53913286baSMatthias Springer                      function_ref<bool(Value, Value)> equivalenceFn) const {
54913286baSMatthias Springer     auto copyOp = cast<CopyOp>(op);
55913286baSMatthias Springer     assert(copyOp.getOutputs().size() == 1 && "expected single output");
56913286baSMatthias Springer     return equivalenceFn(candidate, copyOp.getOutputs()[0]);
57913286baSMatthias Springer   }
58913286baSMatthias Springer 
buildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface59913286baSMatthias Springer   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60913286baSMatthias Springer                               Location loc) const {
61913286baSMatthias Springer     auto copyOp = cast<CopyOp>(op);
62913286baSMatthias Springer     assert(copyOp.getOutputs().size() == 1 && "expected single output");
63913286baSMatthias Springer     return copyOp.getOutputs()[0];
64913286baSMatthias Springer   }
65913286baSMatthias Springer 
66913286baSMatthias Springer   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface67913286baSMatthias Springer   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
68913286baSMatthias Springer     auto copyOp = cast<CopyOp>(op);
69913286baSMatthias Springer     assert(copyOp.getOutputs().size() == 1 && "expected single output");
70913286baSMatthias Springer     return {copyOp.getOutputs()[0]};
71913286baSMatthias Springer   }
72913286baSMatthias Springer };
73913286baSMatthias Springer } // namespace
74913286baSMatthias Springer 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)75*1abd8d1aSMatthias Springer void mlir::linalg::registerSubsetOpInterfaceExternalModels(
76913286baSMatthias Springer     DialectRegistry &registry) {
77913286baSMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
78*1abd8d1aSMatthias Springer     linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
79913286baSMatthias Springer     linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
80913286baSMatthias Springer   });
81913286baSMatthias Springer }
82