xref: /llvm-project/mlir/lib/Interfaces/SubsetOpInterface.cpp (revision ff614a5729e9a4fc32465ad5ff3b87e044429c2d)
//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
81abd8d1aSMatthias Springer 
91abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
101abd8d1aSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h"
11*ff614a57SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
121abd8d1aSMatthias Springer 
131abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
141abd8d1aSMatthias Springer 
151abd8d1aSMatthias Springer using namespace mlir;
161abd8d1aSMatthias Springer 
defaultGetDestinationOperand(Operation * op)171abd8d1aSMatthias Springer OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
181abd8d1aSMatthias Springer   auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
191abd8d1aSMatthias Springer   assert(dstOp && "getDestination must be implemented for non-DPS ops");
201abd8d1aSMatthias Springer   assert(
211abd8d1aSMatthias Springer       dstOp.getNumDpsInits() == 1 &&
221abd8d1aSMatthias Springer       "getDestination must be implemented for ops with 0 or more than 1 init");
231abd8d1aSMatthias Springer   return *dstOp.getDpsInitOperand(0);
241abd8d1aSMatthias Springer }
251abd8d1aSMatthias Springer 
defaultGetUpdatedDestination(Operation * op)261abd8d1aSMatthias Springer OpResult detail::defaultGetUpdatedDestination(Operation *op) {
271abd8d1aSMatthias Springer   auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
281abd8d1aSMatthias Springer   assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops");
291abd8d1aSMatthias Springer   auto insertionOp = cast<SubsetInsertionOpInterface>(op);
301abd8d1aSMatthias Springer   return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand());
311abd8d1aSMatthias Springer }
321abd8d1aSMatthias Springer 
defaultIsEquivalentSubset(Operation * op,Value candidate,function_ref<bool (Value,Value)> equivalenceFn)331abd8d1aSMatthias Springer bool detail::defaultIsEquivalentSubset(
341abd8d1aSMatthias Springer     Operation *op, Value candidate,
351abd8d1aSMatthias Springer     function_ref<bool(Value, Value)> equivalenceFn) {
361abd8d1aSMatthias Springer   assert(isa<SubsetInsertionOpInterface>(op) &&
371abd8d1aSMatthias Springer          "expected SubsetInsertionOpInterface");
381abd8d1aSMatthias Springer   if (!candidate.getDefiningOp<SubsetExtractionOpInterface>())
391abd8d1aSMatthias Springer     return false;
401abd8d1aSMatthias Springer   return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
411abd8d1aSMatthias Springer       candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
421abd8d1aSMatthias Springer }
431abd8d1aSMatthias Springer 
defaultOperatesOnEquivalentSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)44*ff614a57SMatthias Springer bool detail::defaultOperatesOnEquivalentSubset(
45*ff614a57SMatthias Springer     Operation *op, SubsetOpInterface candidate,
46*ff614a57SMatthias Springer     function_ref<bool(Value, Value)> equivalenceFn) {
47*ff614a57SMatthias Springer   auto subsetOp = cast<SubsetOpInterface>(op);
48*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice> slice =
49*ff614a57SMatthias Springer       subsetOp.getAccessedHyperrectangularSlice();
50*ff614a57SMatthias Springer   assert(succeeded(slice) &&
51*ff614a57SMatthias Springer          "operatesOnEquivalentSubset must be implemented if "
52*ff614a57SMatthias Springer          "getAccessedHyperrectangularSlice is not implemented");
53*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice> otherSlice =
54*ff614a57SMatthias Springer       candidate.getAccessedHyperrectangularSlice();
55*ff614a57SMatthias Springer   if (failed(otherSlice))
56*ff614a57SMatthias Springer     return false;
57*ff614a57SMatthias Springer   if (!equivalenceFn(subsetOp.getTensorContainer(),
58*ff614a57SMatthias Springer                      candidate.getTensorContainer()))
59*ff614a57SMatthias Springer     return false;
60*ff614a57SMatthias Springer   FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
61*ff614a57SMatthias Springer       op->getContext(), *slice, *otherSlice);
62*ff614a57SMatthias Springer   return succeeded(equivalent) && *equivalent;
63*ff614a57SMatthias Springer }
64*ff614a57SMatthias Springer 
defaultOperatesOnDisjointSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)65*ff614a57SMatthias Springer bool detail::defaultOperatesOnDisjointSubset(
66*ff614a57SMatthias Springer     Operation *op, SubsetOpInterface candidate,
67*ff614a57SMatthias Springer     function_ref<bool(Value, Value)> equivalenceFn) {
68*ff614a57SMatthias Springer   auto subsetOp = cast<SubsetOpInterface>(op);
69*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice> slice =
70*ff614a57SMatthias Springer       subsetOp.getAccessedHyperrectangularSlice();
71*ff614a57SMatthias Springer   assert(succeeded(slice) &&
72*ff614a57SMatthias Springer          "defaultOperatesOnDisjointSubset must be implemented if "
73*ff614a57SMatthias Springer          "getAccessedHyperrectangularSlice is not implemented");
74*ff614a57SMatthias Springer   FailureOr<HyperrectangularSlice> otherSlice =
75*ff614a57SMatthias Springer       candidate.getAccessedHyperrectangularSlice();
76*ff614a57SMatthias Springer   if (failed(otherSlice))
77*ff614a57SMatthias Springer     return false;
78*ff614a57SMatthias Springer   if (!equivalenceFn(subsetOp.getTensorContainer(),
79*ff614a57SMatthias Springer                      candidate.getTensorContainer()))
80*ff614a57SMatthias Springer     return false;
81*ff614a57SMatthias Springer   FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
82*ff614a57SMatthias Springer       op->getContext(), *slice, *otherSlice);
83*ff614a57SMatthias Springer   return succeeded(overlapping) && !*overlapping;
84*ff614a57SMatthias Springer }
85*ff614a57SMatthias Springer 
getTensorContainer(Operation * op)86*ff614a57SMatthias Springer Value detail::getTensorContainer(Operation *op) {
87*ff614a57SMatthias Springer   if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
88*ff614a57SMatthias Springer     return insertionOp.getDestinationOperand().get();
89*ff614a57SMatthias Springer   return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
90*ff614a57SMatthias Springer }
91*ff614a57SMatthias Springer 
verifySubsetOpInterface(SubsetOpInterface op)921abd8d1aSMatthias Springer LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
931abd8d1aSMatthias Springer   if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
941abd8d1aSMatthias Springer         isa<SubsetInsertionOpInterface>(op.getOperation())))
951abd8d1aSMatthias Springer     return op->emitOpError(
961abd8d1aSMatthias Springer         "SubsetOpInterface ops must implement either "
971abd8d1aSMatthias Springer         "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
981abd8d1aSMatthias Springer   return success();
991abd8d1aSMatthias Springer }
1001abd8d1aSMatthias Springer 
1011abd8d1aSMatthias Springer LogicalResult
verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op)1021abd8d1aSMatthias Springer detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
1031abd8d1aSMatthias Springer   if (op->getNumResults() != 1)
1041abd8d1aSMatthias Springer     return op->emitOpError(
1051abd8d1aSMatthias Springer         "SubsetExtractionOpInterface ops must have one result");
1061abd8d1aSMatthias Springer   return success();
1071abd8d1aSMatthias Springer }
108