1913286baSMatthias Springer //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2913286baSMatthias Springer //
3913286baSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4913286baSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5913286baSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6913286baSMatthias Springer //
7913286baSMatthias Springer //===----------------------------------------------------------------------===//
8913286baSMatthias Springer
9913286baSMatthias Springer #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10913286baSMatthias Springer
11913286baSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
12*1abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
13913286baSMatthias Springer
14913286baSMatthias Springer using namespace mlir;
15913286baSMatthias Springer using namespace mlir::linalg;
16913286baSMatthias Springer
17913286baSMatthias Springer namespace {
18*1abd8d1aSMatthias Springer struct LinalgCopyOpSubsetOpInterface
19*1abd8d1aSMatthias Springer : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
20*1abd8d1aSMatthias Springer linalg::CopyOp> {
operatesOnEquivalentSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface21*1abd8d1aSMatthias Springer bool operatesOnEquivalentSubset(
22*1abd8d1aSMatthias Springer Operation *op, SubsetOpInterface candidate,
23*1abd8d1aSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) const {
24*1abd8d1aSMatthias Springer // linalg.copy operates on the entire destination tensor.
25*1abd8d1aSMatthias Springer if (auto otherCopyOp = dyn_cast<linalg::CopyOp>(candidate.getOperation()))
26*1abd8d1aSMatthias Springer return equivalenceFn(cast<linalg::CopyOp>(op).getOutputs()[0],
27*1abd8d1aSMatthias Springer otherCopyOp.getOutputs()[0]);
28*1abd8d1aSMatthias Springer // In the absence of an analysis, "false" is a conservative way to implement
29*1abd8d1aSMatthias Springer // this interface.
30*1abd8d1aSMatthias Springer return false;
31*1abd8d1aSMatthias Springer }
32*1abd8d1aSMatthias Springer
operatesOnDisjointSubset__anonc4a13e580111::LinalgCopyOpSubsetOpInterface33*1abd8d1aSMatthias Springer bool operatesOnDisjointSubset(
34*1abd8d1aSMatthias Springer Operation *op, SubsetOpInterface candidate,
35*1abd8d1aSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) const {
36*1abd8d1aSMatthias Springer // In the absence of an analysis, "false" is a conservative way to implement
37*1abd8d1aSMatthias Springer // this interface.
38*1abd8d1aSMatthias Springer return false;
39*1abd8d1aSMatthias Springer }
40*1abd8d1aSMatthias Springer };
41*1abd8d1aSMatthias Springer
42913286baSMatthias Springer struct LinalgCopyOpInterface
43913286baSMatthias Springer : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
44913286baSMatthias Springer linalg::CopyOp> {
getSourceOperand__anonc4a13e580111::LinalgCopyOpInterface45913286baSMatthias Springer OpOperand &getSourceOperand(Operation *op) const {
46913286baSMatthias Springer auto copyOp = cast<CopyOp>(op);
47913286baSMatthias Springer assert(copyOp.getInputs().size() == 1 && "expected single input");
48913286baSMatthias Springer return copyOp.getInputsMutable()[0];
49913286baSMatthias Springer }
50913286baSMatthias Springer
51913286baSMatthias Springer bool
isEquivalentSubset__anonc4a13e580111::LinalgCopyOpInterface52913286baSMatthias Springer isEquivalentSubset(Operation *op, Value candidate,
53913286baSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) const {
54913286baSMatthias Springer auto copyOp = cast<CopyOp>(op);
55913286baSMatthias Springer assert(copyOp.getOutputs().size() == 1 && "expected single output");
56913286baSMatthias Springer return equivalenceFn(candidate, copyOp.getOutputs()[0]);
57913286baSMatthias Springer }
58913286baSMatthias Springer
buildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface59913286baSMatthias Springer Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60913286baSMatthias Springer Location loc) const {
61913286baSMatthias Springer auto copyOp = cast<CopyOp>(op);
62913286baSMatthias Springer assert(copyOp.getOutputs().size() == 1 && "expected single output");
63913286baSMatthias Springer return copyOp.getOutputs()[0];
64913286baSMatthias Springer }
65913286baSMatthias Springer
66913286baSMatthias Springer SmallVector<Value>
getValuesNeededToBuildSubsetExtraction__anonc4a13e580111::LinalgCopyOpInterface67913286baSMatthias Springer getValuesNeededToBuildSubsetExtraction(Operation *op) const {
68913286baSMatthias Springer auto copyOp = cast<CopyOp>(op);
69913286baSMatthias Springer assert(copyOp.getOutputs().size() == 1 && "expected single output");
70913286baSMatthias Springer return {copyOp.getOutputs()[0]};
71913286baSMatthias Springer }
72913286baSMatthias Springer };
73913286baSMatthias Springer } // namespace
74913286baSMatthias Springer
registerSubsetOpInterfaceExternalModels(DialectRegistry & registry)75*1abd8d1aSMatthias Springer void mlir::linalg::registerSubsetOpInterfaceExternalModels(
76913286baSMatthias Springer DialectRegistry ®istry) {
77913286baSMatthias Springer registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
78*1abd8d1aSMatthias Springer linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
79913286baSMatthias Springer linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
80913286baSMatthias Springer });
81913286baSMatthias Springer }
82