xref: /llvm-project/mlir/lib/Interfaces/LoopLikeInterface.cpp (revision 97a2bd8415dc6792b99ec0f091ad7570673c3f37)
143959a25SRiver Riddle //===- LoopLikeInterface.cpp - Loop-like operations in MLIR ---------------===//
243959a25SRiver Riddle //
343959a25SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
443959a25SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
543959a25SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
643959a25SRiver Riddle //
743959a25SRiver Riddle //===----------------------------------------------------------------------===//
843959a25SRiver Riddle 
943959a25SRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.h"
103bd7a9b4SMatthias Springer 
1134a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
1281a79ee4STom Eccles #include "llvm/ADT/DenseSet.h"
1343959a25SRiver Riddle 
1443959a25SRiver Riddle using namespace mlir;
1543959a25SRiver Riddle 
1643959a25SRiver Riddle /// Include the definitions of the loop-like interfaces.
1743959a25SRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.cpp.inc"
1881a79ee4STom Eccles 
blockIsInLoop(Block * block)1981a79ee4STom Eccles bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
2081a79ee4STom Eccles   Operation *parent = block->getParentOp();
2181a79ee4STom Eccles 
2281a79ee4STom Eccles   // The block could be inside a loop-like operation
2381a79ee4STom Eccles   if (isa<LoopLikeOpInterface>(parent) ||
2481a79ee4STom Eccles       parent->getParentOfType<LoopLikeOpInterface>())
2581a79ee4STom Eccles     return true;
2681a79ee4STom Eccles 
2781a79ee4STom Eccles   // This block might be nested inside another block, which is in a loop
2881a79ee4STom Eccles   if (!isa<FunctionOpInterface>(parent))
2981a79ee4STom Eccles     if (mlir::Block *parentBlock = parent->getBlock())
3081a79ee4STom Eccles       if (blockIsInLoop(parentBlock))
3181a79ee4STom Eccles         return true;
3281a79ee4STom Eccles 
3381a79ee4STom Eccles   // Or the block could be inside a control flow graph loop:
3481a79ee4STom Eccles   // A block is in a control flow graph loop if it can reach itself in a graph
3581a79ee4STom Eccles   // traversal
3681a79ee4STom Eccles   DenseSet<Block *> visited;
3781a79ee4STom Eccles   SmallVector<Block *> stack;
3881a79ee4STom Eccles   stack.push_back(block);
3981a79ee4STom Eccles   while (!stack.empty()) {
4081a79ee4STom Eccles     Block *current = stack.pop_back_val();
4181a79ee4STom Eccles     auto [it, inserted] = visited.insert(current);
4281a79ee4STom Eccles     if (!inserted) {
4381a79ee4STom Eccles       // loop detected
4481a79ee4STom Eccles       if (current == block)
4581a79ee4STom Eccles         return true;
4681a79ee4STom Eccles       continue;
4781a79ee4STom Eccles     }
4881a79ee4STom Eccles 
4981a79ee4STom Eccles     stack.reserve(stack.size() + current->getNumSuccessors());
5081a79ee4STom Eccles     for (Block *successor : current->getSuccessors())
5181a79ee4STom Eccles       stack.push_back(successor);
5281a79ee4STom Eccles   }
5381a79ee4STom Eccles   return false;
5481a79ee4STom Eccles }
55ab737a86SMatthias Springer 
verifyLoopLikeOpInterface(Operation * op)56ab737a86SMatthias Springer LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
57ab737a86SMatthias Springer   // Note: These invariants are also verified by the RegionBranchOpInterface,
58ab737a86SMatthias Springer   // but the LoopLikeOpInterface provides better error messages.
59ab737a86SMatthias Springer   auto loopLikeOp = cast<LoopLikeOpInterface>(op);
60ab737a86SMatthias Springer 
6198a6edd3SMatthias Springer   // Verify number of inits/iter_args/yielded values/loop results.
62ab737a86SMatthias Springer   if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
63ab737a86SMatthias Springer     return op->emitOpError("different number of inits and region iter_args: ")
64ab737a86SMatthias Springer            << loopLikeOp.getInits().size()
65ab737a86SMatthias Springer            << " != " << loopLikeOp.getRegionIterArgs().size();
66*76ead96cSMaheshRavishankar   if (!loopLikeOp.getYieldedValues().empty() &&
67*76ead96cSMaheshRavishankar       loopLikeOp.getRegionIterArgs().size() !=
68ab737a86SMatthias Springer           loopLikeOp.getYieldedValues().size())
69ab737a86SMatthias Springer     return op->emitOpError(
70ab737a86SMatthias Springer                "different number of region iter_args and yielded values: ")
71ab737a86SMatthias Springer            << loopLikeOp.getRegionIterArgs().size()
72ab737a86SMatthias Springer            << " != " << loopLikeOp.getYieldedValues().size();
7398a6edd3SMatthias Springer   if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
7498a6edd3SMatthias Springer                                          loopLikeOp.getRegionIterArgs().size())
7598a6edd3SMatthias Springer     return op->emitOpError(
7698a6edd3SMatthias Springer                "different number of loop results and region iter_args: ")
7798a6edd3SMatthias Springer            << loopLikeOp.getLoopResults()->size()
7898a6edd3SMatthias Springer            << " != " << loopLikeOp.getRegionIterArgs().size();
79ab737a86SMatthias Springer 
8098a6edd3SMatthias Springer   // Verify types of inits/iter_args/yielded values/loop results.
81ab737a86SMatthias Springer   int64_t i = 0;
82*76ead96cSMaheshRavishankar   auto yieldedValues = loopLikeOp.getYieldedValues();
83*76ead96cSMaheshRavishankar   for (const auto [index, init, regionIterArg] :
84*76ead96cSMaheshRavishankar        llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
85*76ead96cSMaheshRavishankar     if (init.getType() != regionIterArg.getType())
86*76ead96cSMaheshRavishankar       return op->emitOpError(std::to_string(index))
87*76ead96cSMaheshRavishankar              << "-th init and " << index
88*76ead96cSMaheshRavishankar              << "-th region iter_arg have different type: " << init.getType()
89*76ead96cSMaheshRavishankar              << " != " << regionIterArg.getType();
90*76ead96cSMaheshRavishankar     if (!yieldedValues.empty()) {
91*76ead96cSMaheshRavishankar       if (regionIterArg.getType() != yieldedValues[index].getType())
92*76ead96cSMaheshRavishankar         return op->emitOpError(std::to_string(index))
93*76ead96cSMaheshRavishankar                << "-th region iter_arg and " << index
94ab737a86SMatthias Springer                << "-th yielded value have different type: "
95*76ead96cSMaheshRavishankar                << regionIterArg.getType()
96*76ead96cSMaheshRavishankar                << " != " << yieldedValues[index].getType();
97*76ead96cSMaheshRavishankar     }
9898a6edd3SMatthias Springer     ++i;
9998a6edd3SMatthias Springer   }
10098a6edd3SMatthias Springer   i = 0;
10198a6edd3SMatthias Springer   if (loopLikeOp.getLoopResults()) {
10298a6edd3SMatthias Springer     for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
10398a6edd3SMatthias Springer                                          *loopLikeOp.getLoopResults())) {
10498a6edd3SMatthias Springer       if (std::get<0>(it).getType() != std::get<1>(it).getType())
10598a6edd3SMatthias Springer         return op->emitOpError(std::to_string(i))
10698a6edd3SMatthias Springer                << "-th region iter_arg and " << i
10798a6edd3SMatthias Springer                << "-th loop result have different type: "
10898a6edd3SMatthias Springer                << std::get<0>(it).getType()
10998a6edd3SMatthias Springer                << " != " << std::get<1>(it).getType();
11098a6edd3SMatthias Springer     }
111ab737a86SMatthias Springer     ++i;
112ab737a86SMatthias Springer   }
113ab737a86SMatthias Springer 
114ab737a86SMatthias Springer   return success();
115ab737a86SMatthias Springer }
116