xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp (revision 1abd8d1a8d962a14ca96e19c0f6da4f9ac394d0a)
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13 
14 using namespace mlir;
15 using namespace mlir::linalg;
16 
17 namespace {
18 struct LinalgCopyOpSubsetOpInterface
19     : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
20                                               linalg::CopyOp> {
operatesOnEquivalentSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface21   bool operatesOnEquivalentSubset(
22       Operation *op, SubsetOpInterface candidate,
23       function_ref<bool(Value, Value)> equivalenceFn) const {
24     // linalg.copy operates on the entire destination tensor.
25     if (auto otherCopyOp = dyn_cast<linalg::CopyOp>(candidate.getOperation()))
26       return equivalenceFn(cast<linalg::CopyOp>(op).getOutputs()[0],
27                            otherCopyOp.getOutputs()[0]);
28     // In the absence of an analysis, "false" is a conservative way to implement
29     // this interface.
30     return false;
31   }
32 
operatesOnDisjointSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface33   bool operatesOnDisjointSubset(
34       Operation *op, SubsetOpInterface candidate,
35       function_ref<bool(Value, Value)> equivalenceFn) const {
36     // In the absence of an analysis, "false" is a conservative way to implement
37     // this interface.
38     return false;
39   }
40 };
41 
42 struct LinalgCopyOpInterface
43     : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
44                                                        linalg::CopyOp> {
getSourceOperand__anonc4a13e580111::LinalgCopyOpInterface45   OpOperand &getSourceOperand(Operation *op) const {
46     auto copyOp = cast<CopyOp>(op);
47     assert(copyOp.getInputs().size() == 1 && "expected single input");
48     return copyOp.getInputsMutable()[0];
49   }
50 
51   bool
isEquivalentSubset__anonc4a13e580111::LinalgCopyOpInterface52   isEquivalentSubset(Operation *op, Value candidate,
53                      function_ref<bool(Value, Value)> equivalenceFn) const {
54     auto copyOp = cast<CopyOp>(op);
55     assert(copyOp.getOutputs().size() == 1 && "expected single output");
56     return equivalenceFn(candidate, copyOp.getOutputs()[0]);
57   }
58 
buildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface59   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60                               Location loc) const {
61     auto copyOp = cast<CopyOp>(op);
62     assert(copyOp.getOutputs().size() == 1 && "expected single output");
63     return copyOp.getOutputs()[0];
64   }
65 
66   SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface67   getValuesNeededToBuildSubsetExtraction(Operation *op) const {
68     auto copyOp = cast<CopyOp>(op);
69     assert(copyOp.getOutputs().size() == 1 && "expected single output");
70     return {copyOp.getOutputs()[0]};
71   }
72 };
73 } // namespace
74 
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)75 void mlir::linalg::registerSubsetOpInterfaceExternalModels(
76     DialectRegistry &registry) {
77   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
78     linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
79     linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
80   });
81 }
82