143959a25SRiver Riddle //===- LoopLikeInterface.cpp - Loop-like operations in MLIR ---------------===// 243959a25SRiver Riddle // 343959a25SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 443959a25SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 543959a25SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 643959a25SRiver Riddle // 743959a25SRiver Riddle //===----------------------------------------------------------------------===// 843959a25SRiver Riddle 943959a25SRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.h" 103bd7a9b4SMatthias Springer 1134a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h" 1281a79ee4STom Eccles #include "llvm/ADT/DenseSet.h" 1343959a25SRiver Riddle 1443959a25SRiver Riddle using namespace mlir; 1543959a25SRiver Riddle 1643959a25SRiver Riddle /// Include the definitions of the loop-like interfaces. 1743959a25SRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.cpp.inc" 1881a79ee4STom Eccles blockIsInLoop(Block * block)1981a79ee4STom Ecclesbool LoopLikeOpInterface::blockIsInLoop(Block *block) { 2081a79ee4STom Eccles Operation *parent = block->getParentOp(); 2181a79ee4STom Eccles 2281a79ee4STom Eccles // The block could be inside a loop-like operation 2381a79ee4STom Eccles if (isa<LoopLikeOpInterface>(parent) || 2481a79ee4STom Eccles parent->getParentOfType<LoopLikeOpInterface>()) 2581a79ee4STom Eccles return true; 2681a79ee4STom Eccles 2781a79ee4STom Eccles // This block might be nested inside another block, which is in a loop 2881a79ee4STom Eccles if (!isa<FunctionOpInterface>(parent)) 2981a79ee4STom Eccles if (mlir::Block *parentBlock = parent->getBlock()) 3081a79ee4STom Eccles if (blockIsInLoop(parentBlock)) 3181a79ee4STom Eccles return true; 3281a79ee4STom Eccles 3381a79ee4STom Eccles // Or the block could be inside a control flow graph loop: 3481a79ee4STom Eccles // A block is in a control flow graph loop if it can reach itself in a graph 3581a79ee4STom Eccles // traversal 3681a79ee4STom Eccles DenseSet<Block *> visited; 3781a79ee4STom Eccles SmallVector<Block *> stack; 3881a79ee4STom Eccles stack.push_back(block); 3981a79ee4STom Eccles while (!stack.empty()) { 4081a79ee4STom Eccles Block *current = stack.pop_back_val(); 4181a79ee4STom Eccles auto [it, inserted] = visited.insert(current); 4281a79ee4STom Eccles if (!inserted) { 4381a79ee4STom Eccles // loop detected 4481a79ee4STom Eccles if (current == block) 4581a79ee4STom Eccles return true; 4681a79ee4STom Eccles continue; 4781a79ee4STom Eccles } 4881a79ee4STom Eccles 4981a79ee4STom Eccles stack.reserve(stack.size() + current->getNumSuccessors()); 5081a79ee4STom Eccles for (Block *successor : current->getSuccessors()) 5181a79ee4STom Eccles stack.push_back(successor); 5281a79ee4STom Eccles } 5381a79ee4STom Eccles return false; 5481a79ee4STom Eccles } 55ab737a86SMatthias Springer verifyLoopLikeOpInterface(Operation * op)56ab737a86SMatthias SpringerLogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { 57ab737a86SMatthias Springer // Note: These invariants are also verified by the RegionBranchOpInterface, 58ab737a86SMatthias Springer // but the LoopLikeOpInterface provides better error messages. 59ab737a86SMatthias Springer auto loopLikeOp = cast<LoopLikeOpInterface>(op); 60ab737a86SMatthias Springer 6198a6edd3SMatthias Springer // Verify number of inits/iter_args/yielded values/loop results. 62ab737a86SMatthias Springer if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size()) 63ab737a86SMatthias Springer return op->emitOpError("different number of inits and region iter_args: ") 64ab737a86SMatthias Springer << loopLikeOp.getInits().size() 65ab737a86SMatthias Springer << " != " << loopLikeOp.getRegionIterArgs().size(); 66*76ead96cSMaheshRavishankar if (!loopLikeOp.getYieldedValues().empty() && 67*76ead96cSMaheshRavishankar loopLikeOp.getRegionIterArgs().size() != 68ab737a86SMatthias Springer loopLikeOp.getYieldedValues().size()) 69ab737a86SMatthias Springer return op->emitOpError( 70ab737a86SMatthias Springer "different number of region iter_args and yielded values: ") 71ab737a86SMatthias Springer << loopLikeOp.getRegionIterArgs().size() 72ab737a86SMatthias Springer << " != " << loopLikeOp.getYieldedValues().size(); 7398a6edd3SMatthias Springer if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() != 7498a6edd3SMatthias Springer loopLikeOp.getRegionIterArgs().size()) 7598a6edd3SMatthias Springer return op->emitOpError( 7698a6edd3SMatthias Springer "different number of loop results and region iter_args: ") 7798a6edd3SMatthias Springer << loopLikeOp.getLoopResults()->size() 7898a6edd3SMatthias Springer << " != " << loopLikeOp.getRegionIterArgs().size(); 79ab737a86SMatthias Springer 8098a6edd3SMatthias Springer // Verify types of inits/iter_args/yielded values/loop results. 81ab737a86SMatthias Springer int64_t i = 0; 82*76ead96cSMaheshRavishankar auto yieldedValues = loopLikeOp.getYieldedValues(); 83*76ead96cSMaheshRavishankar for (const auto [index, init, regionIterArg] : 84*76ead96cSMaheshRavishankar llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) { 85*76ead96cSMaheshRavishankar if (init.getType() != regionIterArg.getType()) 86*76ead96cSMaheshRavishankar return op->emitOpError(std::to_string(index)) 87*76ead96cSMaheshRavishankar << "-th init and " << index 88*76ead96cSMaheshRavishankar << "-th region iter_arg have different type: " << init.getType() 89*76ead96cSMaheshRavishankar << " != " << regionIterArg.getType(); 90*76ead96cSMaheshRavishankar if (!yieldedValues.empty()) { 91*76ead96cSMaheshRavishankar if (regionIterArg.getType() != yieldedValues[index].getType()) 92*76ead96cSMaheshRavishankar return op->emitOpError(std::to_string(index)) 93*76ead96cSMaheshRavishankar << "-th region iter_arg and " << index 94ab737a86SMatthias Springer << "-th yielded value have different type: " 95*76ead96cSMaheshRavishankar << regionIterArg.getType() 96*76ead96cSMaheshRavishankar << " != " << yieldedValues[index].getType(); 97*76ead96cSMaheshRavishankar } 9898a6edd3SMatthias Springer ++i; 9998a6edd3SMatthias Springer } 10098a6edd3SMatthias Springer i = 0; 10198a6edd3SMatthias Springer if (loopLikeOp.getLoopResults()) { 10298a6edd3SMatthias Springer for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(), 10398a6edd3SMatthias Springer *loopLikeOp.getLoopResults())) { 10498a6edd3SMatthias Springer if (std::get<0>(it).getType() != std::get<1>(it).getType()) 10598a6edd3SMatthias Springer return op->emitOpError(std::to_string(i)) 10698a6edd3SMatthias Springer << "-th region iter_arg and " << i 10798a6edd3SMatthias Springer << "-th loop result have different type: " 10898a6edd3SMatthias Springer << std::get<0>(it).getType() 10998a6edd3SMatthias Springer << " != " << std::get<1>(it).getType(); 11098a6edd3SMatthias Springer } 111ab737a86SMatthias Springer ++i; 112ab737a86SMatthias Springer } 113ab737a86SMatthias Springer 114ab737a86SMatthias Springer return success(); 115ab737a86SMatthias Springer } 116