1 //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <memory>
#include <utility>
18
19 using namespace mlir;
20 using namespace mlir::affine;
21
allocator()22 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
23 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
24 return allocator;
25 }
26
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)27 NestedMatch NestedMatch::build(Operation *operation,
28 ArrayRef<NestedMatch> nestedMatches) {
29 auto *result = allocator()->Allocate<NestedMatch>();
30 auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
31 std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
32 new (result) NestedMatch();
33 result->matchedOperation = operation;
34 result->matchedChildren =
35 ArrayRef<NestedMatch>(children, nestedMatches.size());
36 return *result;
37 }
38
allocator()39 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
40 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
41 return allocator;
42 }
43
copyNestedToThis(ArrayRef<NestedPattern> nested)44 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
45 if (nested.empty())
46 return;
47
48 auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
49 std::uninitialized_copy(nested.begin(), nested.end(), newNested);
50 nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
51 }
52
freeNested()53 void NestedPattern::freeNested() {
54 for (const auto &p : nestedPatterns)
55 p.~NestedPattern();
56 }
57
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)58 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
59 FilterFunctionType filter)
60 : filter(std::move(filter)), skip(nullptr) {
61 copyNestedToThis(nested);
62 }
63
/// Copy-constructs a pattern. The main constructor copies the filter and
/// deep-copies the nested patterns into the arena; the skip marker is then
/// carried over from `other`.
NestedPattern::NestedPattern(const NestedPattern &other)
    : NestedPattern(other.nestedPatterns, other.filter) {
  skip = other.skip;
}
68
/// Copy-assigns `other` into this pattern.
///
/// All state is taken from `other` before this pattern's current nested
/// patterns are destroyed. That state is the filter, the skip marker, and a
/// deep copy of its nested patterns. As a result, the following stay
/// well-defined:
///   - self-assignment;
///   - assignment from one of this pattern's own (transitively) nested
///     patterns.
/// Also, `nestedPatterns` never refers to destroyed objects afterwards, so a
/// later release of the children cannot destroy them a second time.
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  if (this == &other)
    return *this;

  ArrayRef<NestedPattern> oldNested = nestedPatterns;
  FilterFunctionType newFilter = other.filter;
  auto *newSkip = other.skip;

  // Start from an empty list so that an empty `other.nestedPatterns` cannot
  // leave us still pointing at `oldNested`.
  nestedPatterns = ArrayRef<NestedPattern>();
  copyNestedToThis(other.nestedPatterns);

  // Only now destroy the previous children: `other` may have been one of them.
  for (const NestedPattern &nested : oldNested)
    nested.~NestedPattern();

  filter = std::move(newFilter);
  skip = newSkip;
  return *this;
}
76
getDepth() const77 unsigned NestedPattern::getDepth() const {
78 if (nestedPatterns.empty()) {
79 return 1;
80 }
81 unsigned depth = 0;
82 for (auto &c : nestedPatterns) {
83 depth = std::max(depth, c.getDepth());
84 }
85 return depth + 1;
86 }
87
88 /// Matches a single operation in the following way:
89 /// 1. checks the kind of operation against the matcher, if different then
90 /// there is no match;
91 /// 2. calls the customizable filter function to refine the single operation
92 /// match with extra semantic constraints;
93 /// 3. if all is good, recursively matches the nested patterns;
94 /// 4. if all nested match then the single operation matches too and is
95 /// appended to the list of matches;
96 /// 5. TODO: Optionally applies actions (lambda), in which case we will want
97 /// to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)98 void NestedPattern::matchOne(Operation *op,
99 SmallVectorImpl<NestedMatch> *matches) {
100 if (skip == op) {
101 return;
102 }
103 // Local custom filter function
104 if (!filter(*op)) {
105 return;
106 }
107
108 if (nestedPatterns.empty()) {
109 SmallVector<NestedMatch, 8> nestedMatches;
110 matches->push_back(NestedMatch::build(op, nestedMatches));
111 return;
112 }
113 // Take a copy of each nested pattern so we can match it.
114 for (auto nestedPattern : nestedPatterns) {
115 SmallVector<NestedMatch, 8> nestedMatches;
116 // Skip elem in the walk immediately following. Without this we would
117 // essentially need to reimplement walk here.
118 nestedPattern.skip = op;
119 nestedPattern.match(op, &nestedMatches);
120 // If we could not match even one of the specified nestedPattern, early exit
121 // as this whole branch is not a match.
122 if (nestedMatches.empty()) {
123 return;
124 }
125 matches->push_back(NestedMatch::build(op, nestedMatches));
126 }
127 }
128
isAffineForOp(Operation & op)129 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
130
isAffineIfOp(Operation & op)131 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
132
133 namespace mlir {
134 namespace affine {
135 namespace matcher {
136
Op(FilterFunctionType filter)137 NestedPattern Op(FilterFunctionType filter) {
138 return NestedPattern({}, std::move(filter));
139 }
140
If(const NestedPattern & child)141 NestedPattern If(const NestedPattern &child) {
142 return NestedPattern(child, isAffineIfOp);
143 }
If(const FilterFunctionType & filter,const NestedPattern & child)144 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
145 return NestedPattern(child, [filter](Operation &op) {
146 return isAffineIfOp(op) && filter(op);
147 });
148 }
If(ArrayRef<NestedPattern> nested)149 NestedPattern If(ArrayRef<NestedPattern> nested) {
150 return NestedPattern(nested, isAffineIfOp);
151 }
If(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)152 NestedPattern If(const FilterFunctionType &filter,
153 ArrayRef<NestedPattern> nested) {
154 return NestedPattern(nested, [filter](Operation &op) {
155 return isAffineIfOp(op) && filter(op);
156 });
157 }
158
For(const NestedPattern & child)159 NestedPattern For(const NestedPattern &child) {
160 return NestedPattern(child, isAffineForOp);
161 }
For(const FilterFunctionType & filter,const NestedPattern & child)162 NestedPattern For(const FilterFunctionType &filter,
163 const NestedPattern &child) {
164 return NestedPattern(
165 child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
166 }
For(ArrayRef<NestedPattern> nested)167 NestedPattern For(ArrayRef<NestedPattern> nested) {
168 return NestedPattern(nested, isAffineForOp);
169 }
For(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)170 NestedPattern For(const FilterFunctionType &filter,
171 ArrayRef<NestedPattern> nested) {
172 return NestedPattern(
173 nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
174 }
175
isLoadOrStore(Operation & op)176 bool isLoadOrStore(Operation &op) {
177 return isa<AffineLoadOp, AffineStoreOp>(op);
178 }
179
180 } // namespace matcher
181 } // namespace affine
182 } // namespace mlir
183