xref: /llvm-project/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp (revision 0a8e3dd432ff15ce871e4b9df0645e8a7e011fb3)
1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
10 
11 using namespace mlir;
12 
13 namespace mlir {
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15 } // namespace mlir
16 
17 namespace {
getNumTensorResults(Operation * op)18 size_t getNumTensorResults(Operation *op) {
19   size_t numTensorResults = 0;
20   for (auto t : op->getResultTypes()) {
21     if (isa<TensorType>(t)) {
22       ++numTensorResults;
23     }
24   }
25   return numTensorResults;
26 }
27 } // namespace
28 
verifyDestinationStyleOpInterface(Operation * op)29 LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
30   DestinationStyleOpInterface dstStyleOp =
31       cast<DestinationStyleOpInterface>(op);
32 
33   SmallVector<OpOperand *> outputTensorOperands;
34   for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35     Type type = operand.get().getType();
36     if (isa<TensorType>(type)) {
37       outputTensorOperands.push_back(&operand);
38     } else if (!isa<BaseMemRefType>(type)) {
39       return op->emitOpError("expected that operand #")
40              << operand.getOperandNumber() << " is a tensor or a memref";
41     }
42   }
43 
44   // Verify the number of tensor results matches the number of output tensors.
45   if (getNumTensorResults(op) != outputTensorOperands.size())
46     return op->emitOpError("expected the number of tensor results (")
47            << getNumTensorResults(op)
48            << ") to be equal to the number of output tensors ("
49            << outputTensorOperands.size() << ")";
50 
51   for (OpOperand *opOperand : outputTensorOperands) {
52     OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53     if (result.getType() != opOperand->get().getType())
54       return op->emitOpError("expected type of operand #")
55              << opOperand->getOperandNumber() << " ("
56              << opOperand->get().getType() << ")"
57              << " to match type of corresponding result (" << result.getType()
58              << ")";
59   }
60 
61   return success();
62 }
63