11abd8d1aSMatthias Springer //===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
21abd8d1aSMatthias Springer //
31abd8d1aSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41abd8d1aSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
51abd8d1aSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61abd8d1aSMatthias Springer //
71abd8d1aSMatthias Springer //===----------------------------------------------------------------------===//
81abd8d1aSMatthias Springer
91abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
101abd8d1aSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h"
11*ff614a57SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
121abd8d1aSMatthias Springer
131abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
141abd8d1aSMatthias Springer
151abd8d1aSMatthias Springer using namespace mlir;
161abd8d1aSMatthias Springer
defaultGetDestinationOperand(Operation * op)171abd8d1aSMatthias Springer OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
181abd8d1aSMatthias Springer auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
191abd8d1aSMatthias Springer assert(dstOp && "getDestination must be implemented for non-DPS ops");
201abd8d1aSMatthias Springer assert(
211abd8d1aSMatthias Springer dstOp.getNumDpsInits() == 1 &&
221abd8d1aSMatthias Springer "getDestination must be implemented for ops with 0 or more than 1 init");
231abd8d1aSMatthias Springer return *dstOp.getDpsInitOperand(0);
241abd8d1aSMatthias Springer }
251abd8d1aSMatthias Springer
defaultGetUpdatedDestination(Operation * op)261abd8d1aSMatthias Springer OpResult detail::defaultGetUpdatedDestination(Operation *op) {
271abd8d1aSMatthias Springer auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
281abd8d1aSMatthias Springer assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops");
291abd8d1aSMatthias Springer auto insertionOp = cast<SubsetInsertionOpInterface>(op);
301abd8d1aSMatthias Springer return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand());
311abd8d1aSMatthias Springer }
321abd8d1aSMatthias Springer
defaultIsEquivalentSubset(Operation * op,Value candidate,function_ref<bool (Value,Value)> equivalenceFn)331abd8d1aSMatthias Springer bool detail::defaultIsEquivalentSubset(
341abd8d1aSMatthias Springer Operation *op, Value candidate,
351abd8d1aSMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) {
361abd8d1aSMatthias Springer assert(isa<SubsetInsertionOpInterface>(op) &&
371abd8d1aSMatthias Springer "expected SubsetInsertionOpInterface");
381abd8d1aSMatthias Springer if (!candidate.getDefiningOp<SubsetExtractionOpInterface>())
391abd8d1aSMatthias Springer return false;
401abd8d1aSMatthias Springer return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
411abd8d1aSMatthias Springer candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
421abd8d1aSMatthias Springer }
431abd8d1aSMatthias Springer
defaultOperatesOnEquivalentSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)44*ff614a57SMatthias Springer bool detail::defaultOperatesOnEquivalentSubset(
45*ff614a57SMatthias Springer Operation *op, SubsetOpInterface candidate,
46*ff614a57SMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) {
47*ff614a57SMatthias Springer auto subsetOp = cast<SubsetOpInterface>(op);
48*ff614a57SMatthias Springer FailureOr<HyperrectangularSlice> slice =
49*ff614a57SMatthias Springer subsetOp.getAccessedHyperrectangularSlice();
50*ff614a57SMatthias Springer assert(succeeded(slice) &&
51*ff614a57SMatthias Springer "operatesOnEquivalentSubset must be implemented if "
52*ff614a57SMatthias Springer "getAccessedHyperrectangularSlice is not implemented");
53*ff614a57SMatthias Springer FailureOr<HyperrectangularSlice> otherSlice =
54*ff614a57SMatthias Springer candidate.getAccessedHyperrectangularSlice();
55*ff614a57SMatthias Springer if (failed(otherSlice))
56*ff614a57SMatthias Springer return false;
57*ff614a57SMatthias Springer if (!equivalenceFn(subsetOp.getTensorContainer(),
58*ff614a57SMatthias Springer candidate.getTensorContainer()))
59*ff614a57SMatthias Springer return false;
60*ff614a57SMatthias Springer FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
61*ff614a57SMatthias Springer op->getContext(), *slice, *otherSlice);
62*ff614a57SMatthias Springer return succeeded(equivalent) && *equivalent;
63*ff614a57SMatthias Springer }
64*ff614a57SMatthias Springer
defaultOperatesOnDisjointSubset(Operation * op,SubsetOpInterface candidate,function_ref<bool (Value,Value)> equivalenceFn)65*ff614a57SMatthias Springer bool detail::defaultOperatesOnDisjointSubset(
66*ff614a57SMatthias Springer Operation *op, SubsetOpInterface candidate,
67*ff614a57SMatthias Springer function_ref<bool(Value, Value)> equivalenceFn) {
68*ff614a57SMatthias Springer auto subsetOp = cast<SubsetOpInterface>(op);
69*ff614a57SMatthias Springer FailureOr<HyperrectangularSlice> slice =
70*ff614a57SMatthias Springer subsetOp.getAccessedHyperrectangularSlice();
71*ff614a57SMatthias Springer assert(succeeded(slice) &&
72*ff614a57SMatthias Springer "defaultOperatesOnDisjointSubset must be implemented if "
73*ff614a57SMatthias Springer "getAccessedHyperrectangularSlice is not implemented");
74*ff614a57SMatthias Springer FailureOr<HyperrectangularSlice> otherSlice =
75*ff614a57SMatthias Springer candidate.getAccessedHyperrectangularSlice();
76*ff614a57SMatthias Springer if (failed(otherSlice))
77*ff614a57SMatthias Springer return false;
78*ff614a57SMatthias Springer if (!equivalenceFn(subsetOp.getTensorContainer(),
79*ff614a57SMatthias Springer candidate.getTensorContainer()))
80*ff614a57SMatthias Springer return false;
81*ff614a57SMatthias Springer FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
82*ff614a57SMatthias Springer op->getContext(), *slice, *otherSlice);
83*ff614a57SMatthias Springer return succeeded(overlapping) && !*overlapping;
84*ff614a57SMatthias Springer }
85*ff614a57SMatthias Springer
getTensorContainer(Operation * op)86*ff614a57SMatthias Springer Value detail::getTensorContainer(Operation *op) {
87*ff614a57SMatthias Springer if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
88*ff614a57SMatthias Springer return insertionOp.getDestinationOperand().get();
89*ff614a57SMatthias Springer return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
90*ff614a57SMatthias Springer }
91*ff614a57SMatthias Springer
verifySubsetOpInterface(SubsetOpInterface op)921abd8d1aSMatthias Springer LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
931abd8d1aSMatthias Springer if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
941abd8d1aSMatthias Springer isa<SubsetInsertionOpInterface>(op.getOperation())))
951abd8d1aSMatthias Springer return op->emitOpError(
961abd8d1aSMatthias Springer "SubsetOpInterface ops must implement either "
971abd8d1aSMatthias Springer "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
981abd8d1aSMatthias Springer return success();
991abd8d1aSMatthias Springer }
1001abd8d1aSMatthias Springer
1011abd8d1aSMatthias Springer LogicalResult
verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op)1021abd8d1aSMatthias Springer detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
1031abd8d1aSMatthias Springer if (op->getNumResults() != 1)
1041abd8d1aSMatthias Springer return op->emitOpError(
1051abd8d1aSMatthias Springer "SubsetExtractionOpInterface ops must have one result");
1061abd8d1aSMatthias Springer return success();
1071abd8d1aSMatthias Springer }
108