1 //===- LoopLikeInterface.cpp - Loop-like operations in MLIR ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/LoopLikeInterface.h" 10 11 #include "mlir/Interfaces/FunctionInterfaces.h" 12 #include "llvm/ADT/DenseSet.h" 13 14 using namespace mlir; 15 16 /// Include the definitions of the loop-like interfaces. 17 #include "mlir/Interfaces/LoopLikeInterface.cpp.inc" 18 blockIsInLoop(Block * block)19bool LoopLikeOpInterface::blockIsInLoop(Block *block) { 20 Operation *parent = block->getParentOp(); 21 22 // The block could be inside a loop-like operation 23 if (isa<LoopLikeOpInterface>(parent) || 24 parent->getParentOfType<LoopLikeOpInterface>()) 25 return true; 26 27 // This block might be nested inside another block, which is in a loop 28 if (!isa<FunctionOpInterface>(parent)) 29 if (mlir::Block *parentBlock = parent->getBlock()) 30 if (blockIsInLoop(parentBlock)) 31 return true; 32 33 // Or the block could be inside a control flow graph loop: 34 // A block is in a control flow graph loop if it can reach itself in a graph 35 // traversal 36 DenseSet<Block *> visited; 37 SmallVector<Block *> stack; 38 stack.push_back(block); 39 while (!stack.empty()) { 40 Block *current = stack.pop_back_val(); 41 auto [it, inserted] = visited.insert(current); 42 if (!inserted) { 43 // loop detected 44 if (current == block) 45 return true; 46 continue; 47 } 48 49 stack.reserve(stack.size() + current->getNumSuccessors()); 50 for (Block *successor : current->getSuccessors()) 51 stack.push_back(successor); 52 } 53 return false; 54 } 55 verifyLoopLikeOpInterface(Operation * op)56LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { 57 // Note: These invariants are also verified by the RegionBranchOpInterface, 58 // but the LoopLikeOpInterface provides better error messages. 59 auto loopLikeOp = cast<LoopLikeOpInterface>(op); 60 61 // Verify number of inits/iter_args/yielded values/loop results. 62 if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size()) 63 return op->emitOpError("different number of inits and region iter_args: ") 64 << loopLikeOp.getInits().size() 65 << " != " << loopLikeOp.getRegionIterArgs().size(); 66 if (!loopLikeOp.getYieldedValues().empty() && 67 loopLikeOp.getRegionIterArgs().size() != 68 loopLikeOp.getYieldedValues().size()) 69 return op->emitOpError( 70 "different number of region iter_args and yielded values: ") 71 << loopLikeOp.getRegionIterArgs().size() 72 << " != " << loopLikeOp.getYieldedValues().size(); 73 if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() != 74 loopLikeOp.getRegionIterArgs().size()) 75 return op->emitOpError( 76 "different number of loop results and region iter_args: ") 77 << loopLikeOp.getLoopResults()->size() 78 << " != " << loopLikeOp.getRegionIterArgs().size(); 79 80 // Verify types of inits/iter_args/yielded values/loop results. 81 int64_t i = 0; 82 auto yieldedValues = loopLikeOp.getYieldedValues(); 83 for (const auto [index, init, regionIterArg] : 84 llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) { 85 if (init.getType() != regionIterArg.getType()) 86 return op->emitOpError(std::to_string(index)) 87 << "-th init and " << index 88 << "-th region iter_arg have different type: " << init.getType() 89 << " != " << regionIterArg.getType(); 90 if (!yieldedValues.empty()) { 91 if (regionIterArg.getType() != yieldedValues[index].getType()) 92 return op->emitOpError(std::to_string(index)) 93 << "-th region iter_arg and " << index 94 << "-th yielded value have different type: " 95 << regionIterArg.getType() 96 << " != " << yieldedValues[index].getType(); 97 } 98 ++i; 99 } 100 i = 0; 101 if (loopLikeOp.getLoopResults()) { 102 for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(), 103 *loopLikeOp.getLoopResults())) { 104 if (std::get<0>(it).getType() != std::get<1>(it).getType()) 105 return op->emitOpError(std::to_string(i)) 106 << "-th region iter_arg and " << i 107 << "-th loop result have different type: " 108 << std::get<0>(it).getType() 109 << " != " << std::get<1>(it).getType(); 110 } 111 ++i; 112 } 113 114 return success(); 115 } 116