1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 10 11 using namespace mlir; 12 13 namespace mlir { 14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc" 15 } // namespace mlir 16 17 namespace { getNumTensorResults(Operation * op)18size_t getNumTensorResults(Operation *op) { 19 size_t numTensorResults = 0; 20 for (auto t : op->getResultTypes()) { 21 if (isa<TensorType>(t)) { 22 ++numTensorResults; 23 } 24 } 25 return numTensorResults; 26 } 27 } // namespace 28 verifyDestinationStyleOpInterface(Operation * op)29LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { 30 DestinationStyleOpInterface dstStyleOp = 31 cast<DestinationStyleOpInterface>(op); 32 33 SmallVector<OpOperand *> outputTensorOperands; 34 for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) { 35 Type type = operand.get().getType(); 36 if (isa<TensorType>(type)) { 37 outputTensorOperands.push_back(&operand); 38 } else if (!isa<BaseMemRefType>(type)) { 39 return op->emitOpError("expected that operand #") 40 << operand.getOperandNumber() << " is a tensor or a memref"; 41 } 42 } 43 44 // Verify the number of tensor results matches the number of output tensors. 45 if (getNumTensorResults(op) != outputTensorOperands.size()) 46 return op->emitOpError("expected the number of tensor results (") 47 << getNumTensorResults(op) 48 << ") to be equal to the number of output tensors (" 49 << outputTensorOperands.size() << ")"; 50 51 for (OpOperand *opOperand : outputTensorOperands) { 52 OpResult result = dstStyleOp.getTiedOpResult(opOperand); 53 if (result.getType() != opOperand->get().getType()) 54 return op->emitOpError("expected type of operand #") 55 << opOperand->getOperandNumber() << " (" 56 << opOperand->get().getType() << ")" 57 << " to match type of corresponding result (" << result.getType() 58 << ")"; 59 } 60 61 return success(); 62 } 63