//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Interfaces/SubsetOpInterface.h"
13
14 using namespace mlir;
15 using namespace mlir::linalg;
16
17 namespace {
18 struct LinalgCopyOpSubsetOpInterface
19 : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
20 linalg::CopyOp> {
operatesOnEquivalentSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface21 bool operatesOnEquivalentSubset(
22 Operation *op, SubsetOpInterface candidate,
23 function_ref<bool(Value, Value)> equivalenceFn) const {
24 // linalg.copy operates on the entire destination tensor.
25 if (auto otherCopyOp = dyn_cast<linalg::CopyOp>(candidate.getOperation()))
26 return equivalenceFn(cast<linalg::CopyOp>(op).getOutputs()[0],
27 otherCopyOp.getOutputs()[0]);
28 // In the absence of an analysis, "false" is a conservative way to implement
29 // this interface.
30 return false;
31 }
32
operatesOnDisjointSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface33 bool operatesOnDisjointSubset(
34 Operation *op, SubsetOpInterface candidate,
35 function_ref<bool(Value, Value)> equivalenceFn) const {
36 // In the absence of an analysis, "false" is a conservative way to implement
37 // this interface.
38 return false;
39 }
40 };
41
42 struct LinalgCopyOpInterface
43 : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
44 linalg::CopyOp> {
getSourceOperand__anonc4a13e580111::LinalgCopyOpInterface45 OpOperand &getSourceOperand(Operation *op) const {
46 auto copyOp = cast<CopyOp>(op);
47 assert(copyOp.getInputs().size() == 1 && "expected single input");
48 return copyOp.getInputsMutable()[0];
49 }
50
51 bool
isEquivalentSubset__anonc4a13e580111::LinalgCopyOpInterface52 isEquivalentSubset(Operation *op, Value candidate,
53 function_ref<bool(Value, Value)> equivalenceFn) const {
54 auto copyOp = cast<CopyOp>(op);
55 assert(copyOp.getOutputs().size() == 1 && "expected single output");
56 return equivalenceFn(candidate, copyOp.getOutputs()[0]);
57 }
58
buildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface59 Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60 Location loc) const {
61 auto copyOp = cast<CopyOp>(op);
62 assert(copyOp.getOutputs().size() == 1 && "expected single output");
63 return copyOp.getOutputs()[0];
64 }
65
66 SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface67 getValuesNeededToBuildSubsetExtraction(Operation *op) const {
68 auto copyOp = cast<CopyOp>(op);
69 assert(copyOp.getOutputs().size() == 1 && "expected single output");
70 return {copyOp.getOutputs()[0]};
71 }
72 };
73 } // namespace
74
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)75 void mlir::linalg::registerSubsetOpInterfaceExternalModels(
76 DialectRegistry ®istry) {
77 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
78 linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
79 linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
80 });
81 }
82