xref: /llvm-project/mlir/include/mlir/Interfaces/LoopLikeInterface.h (revision 97a2bd8415dc6792b99ec0f091ad7570673c3f37)
1 //===- LoopLikeInterface.h - Loop-like operations interface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares the operation interface for loop-like operations, along
// with related traits and helpers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
14 #define MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
15 
16 #include "mlir/IR/OpDefinition.h"
17 
18 namespace mlir {
19 class RewriterBase;
20 
21 /// A function that returns the additional yielded values during
22 /// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
23 /// iter_args. This function should return as many values as there are block
24 /// arguments in `newBbArgs`.
25 using NewYieldValuesFn = std::function<SmallVector<Value>(
26     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
27 
28 namespace detail {
29 /// Verify invariants of the LoopLikeOpInterface.
30 LogicalResult verifyLoopLikeOpInterface(Operation *op);
31 } // namespace detail
32 
33 //===----------------------------------------------------------------------===//
34 // Traits
35 //===----------------------------------------------------------------------===//
36 
37 namespace OpTrait {
38 // A trait indicating that the single region contained in the operation has
39 // parallel execution semantics. This may have implications in a certain pass.
40 // For example, buffer hoisting is illegal in parallel loops, and local buffers
41 // may be accessed by parallel threads simultaneously.
42 template <typename ConcreteType>
43 class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
44 public:
verifyTrait(Operation * op)45   static LogicalResult verifyTrait(Operation *op) {
46     return impl::verifyOneRegion(op);
47   }
48 };
49 
50 } // namespace OpTrait
51 
52 // Gathers all maximal sub-blocks of operations that do not themselves
53 // include a `OpTy` (an operation could have a descendant `OpTy` though
54 // in its tree). Ignores the block terminators.
55 template <typename OpTy>
56 struct JamBlockGatherer {
57   // Store iterators to the first and last op of each sub-block found.
58   SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
59 
60   // This is a linear time walk.
walkJamBlockGatherer61   void walk(Operation *op) {
62     for (Region &region : op->getRegions())
63       for (Block &block : region)
64         walk(block);
65   }
66 
walkJamBlockGatherer67   void walk(Block &block) {
68     assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
69            "expected block to have a terminator");
70     for (Block::iterator it = block.begin(), e = std::prev(block.end());
71          it != e;) {
72       Block::iterator subBlockStart = it;
73       while (it != e && !isa<OpTy>(&*it))
74         ++it;
75       if (it != subBlockStart)
76         subBlocks.emplace_back(subBlockStart, std::prev(it));
77       // Process all for ops that appear next.
78       while (it != e && isa<OpTy>(&*it))
79         walk(&*it++);
80     }
81   }
82 };
83 
84 } // namespace mlir
85 
86 //===----------------------------------------------------------------------===//
87 // Interfaces
88 //===----------------------------------------------------------------------===//
89 
90 /// Include the generated interface declarations.
91 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
92 
93 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
94