1 //===- LoopLikeInterface.h - Loop-like operations interface ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the operation interface for loop like operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ 14 #define MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ 15 16 #include "mlir/IR/OpDefinition.h" 17 18 namespace mlir { 19 class RewriterBase; 20 21 /// A function that returns the additional yielded values during 22 /// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region 23 /// iter_args. This function should return as many values as there are block 24 /// arguments in `newBbArgs`. 25 using NewYieldValuesFn = std::function<SmallVector<Value>( 26 OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>; 27 28 namespace detail { 29 /// Verify invariants of the LoopLikeOpInterface. 30 LogicalResult verifyLoopLikeOpInterface(Operation *op); 31 } // namespace detail 32 33 //===----------------------------------------------------------------------===// 34 // Traits 35 //===----------------------------------------------------------------------===// 36 37 namespace OpTrait { 38 // A trait indicating that the single region contained in the operation has 39 // parallel execution semantics. This may have implications in a certain pass. 40 // For example, buffer hoisting is illegal in parallel loops, and local buffers 41 // may be accessed by parallel threads simultaneously. 42 template <typename ConcreteType> 43 class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> { 44 public: verifyTrait(Operation * op)45 static LogicalResult verifyTrait(Operation *op) { 46 return impl::verifyOneRegion(op); 47 } 48 }; 49 50 } // namespace OpTrait 51 52 // Gathers all maximal sub-blocks of operations that do not themselves 53 // include a `OpTy` (an operation could have a descendant `OpTy` though 54 // in its tree). Ignores the block terminators. 55 template <typename OpTy> 56 struct JamBlockGatherer { 57 // Store iterators to the first and last op of each sub-block found. 58 SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks; 59 60 // This is a linear time walk. walkJamBlockGatherer61 void walk(Operation *op) { 62 for (Region ®ion : op->getRegions()) 63 for (Block &block : region) 64 walk(block); 65 } 66 walkJamBlockGatherer67 void walk(Block &block) { 68 assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() && 69 "expected block to have a terminator"); 70 for (Block::iterator it = block.begin(), e = std::prev(block.end()); 71 it != e;) { 72 Block::iterator subBlockStart = it; 73 while (it != e && !isa<OpTy>(&*it)) 74 ++it; 75 if (it != subBlockStart) 76 subBlocks.emplace_back(subBlockStart, std::prev(it)); 77 // Process all for ops that appear next. 78 while (it != e && isa<OpTy>(&*it)) 79 walk(&*it++); 80 } 81 } 82 }; 83 84 } // namespace mlir 85 86 //===----------------------------------------------------------------------===// 87 // Interfaces 88 //===----------------------------------------------------------------------===// 89 90 /// Include the generated interface declarations. 91 #include "mlir/Interfaces/LoopLikeInterface.h.inc" 92 93 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ 94