1cfc9ddaaSMatthias Springer //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===// 2cfc9ddaaSMatthias Springer // 3cfc9ddaaSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cfc9ddaaSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5cfc9ddaaSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cfc9ddaaSMatthias Springer // 7cfc9ddaaSMatthias Springer //===----------------------------------------------------------------------===// 8cfc9ddaaSMatthias Springer 9cfc9ddaaSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h" 10cfc9ddaaSMatthias Springer 11cfc9ddaaSMatthias Springer using namespace mlir; 12cfc9ddaaSMatthias Springer 13cfc9ddaaSMatthias Springer namespace mlir { 14cfc9ddaaSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc" 15cfc9ddaaSMatthias Springer } // namespace mlir 16cfc9ddaaSMatthias Springer 172bd6077dSBenoit Jacob namespace { getNumTensorResults(Operation * op)182bd6077dSBenoit Jacobsize_t getNumTensorResults(Operation *op) { 192bd6077dSBenoit Jacob size_t numTensorResults = 0; 202bd6077dSBenoit Jacob for (auto t : op->getResultTypes()) { 212bd6077dSBenoit Jacob if (isa<TensorType>(t)) { 222bd6077dSBenoit Jacob ++numTensorResults; 232bd6077dSBenoit Jacob } 242bd6077dSBenoit Jacob } 252bd6077dSBenoit Jacob return numTensorResults; 262bd6077dSBenoit Jacob } 272bd6077dSBenoit Jacob } // namespace 282bd6077dSBenoit Jacob verifyDestinationStyleOpInterface(Operation * op)29cfc9ddaaSMatthias SpringerLogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { 30cfc9ddaaSMatthias Springer DestinationStyleOpInterface dstStyleOp = 31cfc9ddaaSMatthias Springer cast<DestinationStyleOpInterface>(op); 32cfc9ddaaSMatthias Springer 332bd6077dSBenoit Jacob SmallVector<OpOperand *> outputTensorOperands; 340b2197b0SMatthias Springer for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) { 350b2197b0SMatthias Springer Type type = operand.get().getType(); 36*0a8e3dd4SMatthias Springer if (isa<TensorType>(type)) { 370b2197b0SMatthias Springer outputTensorOperands.push_back(&operand); 38*0a8e3dd4SMatthias Springer } else if (!isa<BaseMemRefType>(type)) { 39ec2ea410SMatthias Springer return op->emitOpError("expected that operand #") 40*0a8e3dd4SMatthias Springer << operand.getOperandNumber() << " is a tensor or a memref"; 41ec2ea410SMatthias Springer } 42cfc9ddaaSMatthias Springer } 43cfc9ddaaSMatthias Springer 442bd6077dSBenoit Jacob // Verify the number of tensor results matches the number of output tensors. 452bd6077dSBenoit Jacob if (getNumTensorResults(op) != outputTensorOperands.size()) 462bd6077dSBenoit Jacob return op->emitOpError("expected the number of tensor results (") 472bd6077dSBenoit Jacob << getNumTensorResults(op) 48cfc9ddaaSMatthias Springer << ") to be equal to the number of output tensors (" 49cfc9ddaaSMatthias Springer << outputTensorOperands.size() << ")"; 50cfc9ddaaSMatthias Springer 51cfc9ddaaSMatthias Springer for (OpOperand *opOperand : outputTensorOperands) { 52cfc9ddaaSMatthias Springer OpResult result = dstStyleOp.getTiedOpResult(opOperand); 53cfc9ddaaSMatthias Springer if (result.getType() != opOperand->get().getType()) 54cfc9ddaaSMatthias Springer return op->emitOpError("expected type of operand #") 55cfc9ddaaSMatthias Springer << opOperand->getOperandNumber() << " (" 56cfc9ddaaSMatthias Springer << opOperand->get().getType() << ")" 57cfc9ddaaSMatthias Springer << " to match type of corresponding result (" << result.getType() 58cfc9ddaaSMatthias Springer << ")"; 59cfc9ddaaSMatthias Springer } 60*0a8e3dd4SMatthias Springer 61cfc9ddaaSMatthias Springer return success(); 62cfc9ddaaSMatthias Springer } 63