xref: /llvm-project/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp (revision 0a8e3dd432ff15ce871e4b9df0645e8a7e011fb3)
1cfc9ddaaSMatthias Springer //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2cfc9ddaaSMatthias Springer //
3cfc9ddaaSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cfc9ddaaSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5cfc9ddaaSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cfc9ddaaSMatthias Springer //
7cfc9ddaaSMatthias Springer //===----------------------------------------------------------------------===//
8cfc9ddaaSMatthias Springer 
9cfc9ddaaSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.h"
10cfc9ddaaSMatthias Springer 
11cfc9ddaaSMatthias Springer using namespace mlir;
12cfc9ddaaSMatthias Springer 
13cfc9ddaaSMatthias Springer namespace mlir {
14cfc9ddaaSMatthias Springer #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15cfc9ddaaSMatthias Springer } // namespace mlir
16cfc9ddaaSMatthias Springer 
172bd6077dSBenoit Jacob namespace {
getNumTensorResults(Operation * op)182bd6077dSBenoit Jacob size_t getNumTensorResults(Operation *op) {
192bd6077dSBenoit Jacob   size_t numTensorResults = 0;
202bd6077dSBenoit Jacob   for (auto t : op->getResultTypes()) {
212bd6077dSBenoit Jacob     if (isa<TensorType>(t)) {
222bd6077dSBenoit Jacob       ++numTensorResults;
232bd6077dSBenoit Jacob     }
242bd6077dSBenoit Jacob   }
252bd6077dSBenoit Jacob   return numTensorResults;
262bd6077dSBenoit Jacob }
272bd6077dSBenoit Jacob } // namespace
282bd6077dSBenoit Jacob 
verifyDestinationStyleOpInterface(Operation * op)29cfc9ddaaSMatthias Springer LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
30cfc9ddaaSMatthias Springer   DestinationStyleOpInterface dstStyleOp =
31cfc9ddaaSMatthias Springer       cast<DestinationStyleOpInterface>(op);
32cfc9ddaaSMatthias Springer 
332bd6077dSBenoit Jacob   SmallVector<OpOperand *> outputTensorOperands;
340b2197b0SMatthias Springer   for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
350b2197b0SMatthias Springer     Type type = operand.get().getType();
36*0a8e3dd4SMatthias Springer     if (isa<TensorType>(type)) {
370b2197b0SMatthias Springer       outputTensorOperands.push_back(&operand);
38*0a8e3dd4SMatthias Springer     } else if (!isa<BaseMemRefType>(type)) {
39ec2ea410SMatthias Springer       return op->emitOpError("expected that operand #")
40*0a8e3dd4SMatthias Springer              << operand.getOperandNumber() << " is a tensor or a memref";
41ec2ea410SMatthias Springer     }
42cfc9ddaaSMatthias Springer   }
43cfc9ddaaSMatthias Springer 
442bd6077dSBenoit Jacob   // Verify the number of tensor results matches the number of output tensors.
452bd6077dSBenoit Jacob   if (getNumTensorResults(op) != outputTensorOperands.size())
462bd6077dSBenoit Jacob     return op->emitOpError("expected the number of tensor results (")
472bd6077dSBenoit Jacob            << getNumTensorResults(op)
48cfc9ddaaSMatthias Springer            << ") to be equal to the number of output tensors ("
49cfc9ddaaSMatthias Springer            << outputTensorOperands.size() << ")";
50cfc9ddaaSMatthias Springer 
51cfc9ddaaSMatthias Springer   for (OpOperand *opOperand : outputTensorOperands) {
52cfc9ddaaSMatthias Springer     OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53cfc9ddaaSMatthias Springer     if (result.getType() != opOperand->get().getType())
54cfc9ddaaSMatthias Springer       return op->emitOpError("expected type of operand #")
55cfc9ddaaSMatthias Springer              << opOperand->getOperandNumber() << " ("
56cfc9ddaaSMatthias Springer              << opOperand->get().getType() << ")"
57cfc9ddaaSMatthias Springer              << " to match type of corresponding result (" << result.getType()
58cfc9ddaaSMatthias Springer              << ")";
59cfc9ddaaSMatthias Springer   }
60*0a8e3dd4SMatthias Springer 
61cfc9ddaaSMatthias Springer   return success();
62cfc9ddaaSMatthias Springer }
63