xref: /llvm-project/mlir/lib/Interfaces/SubsetOpInterface.cpp (revision ff614a5729e9a4fc32465ad5ff3b87e044429c2d)
1 //===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/SubsetOpInterface.h"
10 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
11 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
12 
13 #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
14 
15 using namespace mlir;
16 
defaultGetDestinationOperand(Operation * op)17 OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
18   auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
19   assert(dstOp && "getDestination must be implemented for non-DPS ops");
20   assert(
21       dstOp.getNumDpsInits() == 1 &&
22       "getDestination must be implemented for ops with 0 or more than 1 init");
23   return *dstOp.getDpsInitOperand(0);
24 }
25 
defaultGetUpdatedDestination(Operation * op)26 OpResult detail::defaultGetUpdatedDestination(Operation *op) {
27   auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
28   assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops");
29   auto insertionOp = cast<SubsetInsertionOpInterface>(op);
30   return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand());
31 }
32 
defaultIsEquivalentSubset(Operation * op,Value candidate,function_ref<bool (Value,Value)> equivalenceFn)33 bool detail::defaultIsEquivalentSubset(
34     Operation *op, Value candidate,
35     function_ref<bool(Value, Value)> equivalenceFn) {
36   assert(isa<SubsetInsertionOpInterface>(op) &&
37          "expected SubsetInsertionOpInterface");
38   if (!candidate.getDefiningOp<SubsetExtractionOpInterface>())
39     return false;
40   return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
41       candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
42 }
43 
defaultOperatesOnEquivalentSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)44 bool detail::defaultOperatesOnEquivalentSubset(
45     Operation *op, SubsetOpInterface candidate,
46     function_ref<bool(Value, Value)> equivalenceFn) {
47   auto subsetOp = cast<SubsetOpInterface>(op);
48   FailureOr<HyperrectangularSlice> slice =
49       subsetOp.getAccessedHyperrectangularSlice();
50   assert(succeeded(slice) &&
51          "operatesOnEquivalentSubset must be implemented if "
52          "getAccessedHyperrectangularSlice is not implemented");
53   FailureOr<HyperrectangularSlice> otherSlice =
54       candidate.getAccessedHyperrectangularSlice();
55   if (failed(otherSlice))
56     return false;
57   if (!equivalenceFn(subsetOp.getTensorContainer(),
58                      candidate.getTensorContainer()))
59     return false;
60   FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
61       op->getContext(), *slice, *otherSlice);
62   return succeeded(equivalent) && *equivalent;
63 }
64 
defaultOperatesOnDisjointSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)65 bool detail::defaultOperatesOnDisjointSubset(
66     Operation *op, SubsetOpInterface candidate,
67     function_ref<bool(Value, Value)> equivalenceFn) {
68   auto subsetOp = cast<SubsetOpInterface>(op);
69   FailureOr<HyperrectangularSlice> slice =
70       subsetOp.getAccessedHyperrectangularSlice();
71   assert(succeeded(slice) &&
72          "defaultOperatesOnDisjointSubset must be implemented if "
73          "getAccessedHyperrectangularSlice is not implemented");
74   FailureOr<HyperrectangularSlice> otherSlice =
75       candidate.getAccessedHyperrectangularSlice();
76   if (failed(otherSlice))
77     return false;
78   if (!equivalenceFn(subsetOp.getTensorContainer(),
79                      candidate.getTensorContainer()))
80     return false;
81   FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
82       op->getContext(), *slice, *otherSlice);
83   return succeeded(overlapping) && !*overlapping;
84 }
85 
getTensorContainer(Operation * op)86 Value detail::getTensorContainer(Operation *op) {
87   if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
88     return insertionOp.getDestinationOperand().get();
89   return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
90 }
91 
verifySubsetOpInterface(SubsetOpInterface op)92 LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
93   if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
94         isa<SubsetInsertionOpInterface>(op.getOperation())))
95     return op->emitOpError(
96         "SubsetOpInterface ops must implement either "
97         "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
98   return success();
99 }
100 
101 LogicalResult
verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op)102 detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
103   if (op->getNumResults() != 1)
104     return op->emitOpError(
105         "SubsetExtractionOpInterface ops must have one result");
106   return success();
107 }
108