xref: /llvm-project/mlir/lib/Dialect/Affine/Analysis/NestedMatcher.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
1 //===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Allocator.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 using namespace mlir::affine;
21 
allocator()22 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
23   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
24   return allocator;
25 }
26 
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)27 NestedMatch NestedMatch::build(Operation *operation,
28                                ArrayRef<NestedMatch> nestedMatches) {
29   auto *result = allocator()->Allocate<NestedMatch>();
30   auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
31   std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
32   new (result) NestedMatch();
33   result->matchedOperation = operation;
34   result->matchedChildren =
35       ArrayRef<NestedMatch>(children, nestedMatches.size());
36   return *result;
37 }
38 
allocator()39 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
40   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
41   return allocator;
42 }
43 
copyNestedToThis(ArrayRef<NestedPattern> nested)44 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
45   if (nested.empty())
46     return;
47 
48   auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
49   std::uninitialized_copy(nested.begin(), nested.end(), newNested);
50   nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
51 }
52 
freeNested()53 void NestedPattern::freeNested() {
54   for (const auto &p : nestedPatterns)
55     p.~NestedPattern();
56 }
57 
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)58 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
59                              FilterFunctionType filter)
60     : filter(std::move(filter)), skip(nullptr) {
61   copyNestedToThis(nested);
62 }
63 
NestedPattern(const NestedPattern & other)64 NestedPattern::NestedPattern(const NestedPattern &other)
65     : filter(other.filter), skip(other.skip) {
66   copyNestedToThis(other.nestedPatterns);
67 }
68 
/// Copy-assigns `other` into this pattern.
///
/// The current nested patterns are destroyed before `other`'s are
/// deep-copied into fresh arena storage. Self-assignment is a no-op; without
/// the guard, the children would be destroyed and then copied from the
/// destroyed objects.
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  if (this == &other)
    return *this;

  freeNested();
  // Drop the view of the destroyed children explicitly, so this object never
  // refers to them again, even when `other` has no children to copy.
  nestedPatterns = ArrayRef<NestedPattern>();
  filter = other.filter;
  skip = other.skip;
  copyNestedToThis(other.nestedPatterns);
  return *this;
}
76 
getDepth() const77 unsigned NestedPattern::getDepth() const {
78   if (nestedPatterns.empty()) {
79     return 1;
80   }
81   unsigned depth = 0;
82   for (auto &c : nestedPatterns) {
83     depth = std::max(depth, c.getDepth());
84   }
85   return depth + 1;
86 }
87 
88 /// Matches a single operation in the following way:
89 ///   1. checks the kind of operation against the matcher, if different then
90 ///      there is no match;
91 ///   2. calls the customizable filter function to refine the single operation
92 ///      match with extra semantic constraints;
93 ///   3. if all is good, recursively matches the nested patterns;
94 ///   4. if all nested match then the single operation matches too and is
95 ///      appended to the list of matches;
96 ///   5. TODO: Optionally applies actions (lambda), in which case we will want
97 ///      to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)98 void NestedPattern::matchOne(Operation *op,
99                              SmallVectorImpl<NestedMatch> *matches) {
100   if (skip == op) {
101     return;
102   }
103   // Local custom filter function
104   if (!filter(*op)) {
105     return;
106   }
107 
108   if (nestedPatterns.empty()) {
109     SmallVector<NestedMatch, 8> nestedMatches;
110     matches->push_back(NestedMatch::build(op, nestedMatches));
111     return;
112   }
113   // Take a copy of each nested pattern so we can match it.
114   for (auto nestedPattern : nestedPatterns) {
115     SmallVector<NestedMatch, 8> nestedMatches;
116     // Skip elem in the walk immediately following. Without this we would
117     // essentially need to reimplement walk here.
118     nestedPattern.skip = op;
119     nestedPattern.match(op, &nestedMatches);
120     // If we could not match even one of the specified nestedPattern, early exit
121     // as this whole branch is not a match.
122     if (nestedMatches.empty()) {
123       return;
124     }
125     matches->push_back(NestedMatch::build(op, nestedMatches));
126   }
127 }
128 
isAffineForOp(Operation & op)129 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
130 
isAffineIfOp(Operation & op)131 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
132 
133 namespace mlir {
134 namespace affine {
135 namespace matcher {
136 
Op(FilterFunctionType filter)137 NestedPattern Op(FilterFunctionType filter) {
138   return NestedPattern({}, std::move(filter));
139 }
140 
If(const NestedPattern & child)141 NestedPattern If(const NestedPattern &child) {
142   return NestedPattern(child, isAffineIfOp);
143 }
If(const FilterFunctionType & filter,const NestedPattern & child)144 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
145   return NestedPattern(child, [filter](Operation &op) {
146     return isAffineIfOp(op) && filter(op);
147   });
148 }
If(ArrayRef<NestedPattern> nested)149 NestedPattern If(ArrayRef<NestedPattern> nested) {
150   return NestedPattern(nested, isAffineIfOp);
151 }
If(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)152 NestedPattern If(const FilterFunctionType &filter,
153                  ArrayRef<NestedPattern> nested) {
154   return NestedPattern(nested, [filter](Operation &op) {
155     return isAffineIfOp(op) && filter(op);
156   });
157 }
158 
For(const NestedPattern & child)159 NestedPattern For(const NestedPattern &child) {
160   return NestedPattern(child, isAffineForOp);
161 }
For(const FilterFunctionType & filter,const NestedPattern & child)162 NestedPattern For(const FilterFunctionType &filter,
163                   const NestedPattern &child) {
164   return NestedPattern(
165       child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
166 }
For(ArrayRef<NestedPattern> nested)167 NestedPattern For(ArrayRef<NestedPattern> nested) {
168   return NestedPattern(nested, isAffineForOp);
169 }
For(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)170 NestedPattern For(const FilterFunctionType &filter,
171                   ArrayRef<NestedPattern> nested) {
172   return NestedPattern(
173       nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
174 }
175 
isLoadOrStore(Operation & op)176 bool isLoadOrStore(Operation &op) {
177   return isa<AffineLoadOp, AffineStoreOp>(op);
178 }
179 
180 } // namespace matcher
181 } // namespace affine
182 } // namespace mlir
183