//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/SubsetOpInterface.cpp.inc" using namespace mlir; OpOperand &detail::defaultGetDestinationOperand(Operation *op) { auto dstOp = dyn_cast(op); assert(dstOp && "getDestination must be implemented for non-DPS ops"); assert( dstOp.getNumDpsInits() == 1 && "getDestination must be implemented for ops with 0 or more than 1 init"); return *dstOp.getDpsInitOperand(0); } OpResult detail::defaultGetUpdatedDestination(Operation *op) { auto dstOp = dyn_cast(op); assert(dstOp && "getUpdatedDestination must be implemented for non-DPS ops"); auto insertionOp = cast(op); return dstOp.getTiedOpResult(&insertionOp.getDestinationOperand()); } bool detail::defaultIsEquivalentSubset( Operation *op, Value candidate, function_ref equivalenceFn) { assert(isa(op) && "expected SubsetInsertionOpInterface"); if (!candidate.getDefiningOp()) return false; return cast(op).operatesOnEquivalentSubset( candidate.getDefiningOp(), equivalenceFn); } bool detail::defaultOperatesOnEquivalentSubset( Operation *op, SubsetOpInterface candidate, function_ref equivalenceFn) { auto subsetOp = cast(op); FailureOr slice = subsetOp.getAccessedHyperrectangularSlice(); assert(succeeded(slice) && "operatesOnEquivalentSubset must be implemented if " "getAccessedHyperrectangularSlice is not implemented"); FailureOr otherSlice = candidate.getAccessedHyperrectangularSlice(); if (failed(otherSlice)) return false; if (!equivalenceFn(subsetOp.getTensorContainer(), candidate.getTensorContainer())) return false; FailureOr equivalent = ValueBoundsConstraintSet::areEquivalentSlices( op->getContext(), *slice, *otherSlice); return succeeded(equivalent) && *equivalent; } bool detail::defaultOperatesOnDisjointSubset( Operation *op, SubsetOpInterface candidate, function_ref equivalenceFn) { auto subsetOp = cast(op); FailureOr slice = subsetOp.getAccessedHyperrectangularSlice(); assert(succeeded(slice) && "defaultOperatesOnDisjointSubset must be implemented if " "getAccessedHyperrectangularSlice is not implemented"); FailureOr otherSlice = candidate.getAccessedHyperrectangularSlice(); if (failed(otherSlice)) return false; if (!equivalenceFn(subsetOp.getTensorContainer(), candidate.getTensorContainer())) return false; FailureOr overlapping = ValueBoundsConstraintSet::areOverlappingSlices( op->getContext(), *slice, *otherSlice); return succeeded(overlapping) && !*overlapping; } Value detail::getTensorContainer(Operation *op) { if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op)) return insertionOp.getDestinationOperand().get(); return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get(); } LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) { if (!(isa(op.getOperation()) ^ isa(op.getOperation()))) return op->emitOpError( "SubsetOpInterface ops must implement either " "SubsetExtractionOpInterface or SubsetInsertionOpInterface"); return success(); } LogicalResult detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) { if (op->getNumResults() != 1) return op->emitOpError( "SubsetExtractionOpInterface ops must have one result"); return success(); }