xref: /llvm-project/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp (revision 2c1ae801e1b66a09a15028ae4ba614e0911eec00)
1 //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "llvm/ADT/ScopeExit.h"
16 
17 using namespace mlir;
18 
// Explicit TypeID definition for transform::PDLMatchHooks. This file retrieves
// the hooks by type through
// TransformDialect::getExtraData<transform::PDLMatchHooks>().
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)

// Pull in the TableGen-generated definitions of the op classes declared for
// this extension (PDLMatchOp, WithPDLPatternsOp, ...).
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // PatternApplicatorExtension
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 /// A TransformState extension that keeps track of compiled PDL pattern sets.
30 /// This is intended to be used along the WithPDLPatterns op. The extension
31 /// can be constructed given an operation that has a SymbolTable trait and
32 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
33 /// by one when requested; this behavior is subject to change.
34 class PatternApplicatorExtension : public transform::TransformState::Extension {
35 public:
36   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
37 
38   /// Creates the extension for patterns contained in `patternContainer`.
PatternApplicatorExtension(transform::TransformState & state,Operation * patternContainer)39   explicit PatternApplicatorExtension(transform::TransformState &state,
40                                       Operation *patternContainer)
41       : Extension(state), patterns(patternContainer) {}
42 
43   /// Appends to `results` the operations contained in `root` that matched the
44   /// PDL pattern with the given name. Note that `root` may or may not be the
45   /// operation that contains PDL patterns. Reports an error if the pattern
46   /// cannot be found. Note that when no operations are matched, this still
47   /// succeeds as long as the pattern exists.
48   LogicalResult findAllMatches(StringRef patternName, Operation *root,
49                                SmallVectorImpl<Operation *> &results);
50 
51 private:
52   /// Map from the pattern name to a singleton set of rewrite patterns that only
53   /// contains the pattern with this name. Populated when the pattern is first
54   /// requested.
55   // TODO: reconsider the efficiency of this storage when more usage data is
56   // available. Storing individual patterns in a set and triggering compilation
57   // for each of them has overhead. So does compiling a large set of patterns
58   // only to apply a handful of them.
59   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
60 
61   /// A symbol table operation containing the relevant PDL patterns.
62   SymbolTable patterns;
63 };
64 
findAllMatches(StringRef patternName,Operation * root,SmallVectorImpl<Operation * > & results)65 LogicalResult PatternApplicatorExtension::findAllMatches(
66     StringRef patternName, Operation *root,
67     SmallVectorImpl<Operation *> &results) {
68   auto it = compiledPatterns.find(patternName);
69   if (it == compiledPatterns.end()) {
70     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
71     if (!patternOp)
72       return failure();
73 
74     // Copy the pattern operation into a new module that is compiled and
75     // consumed by the PDL interpreter.
76     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
77     auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
78     builder.clone(*patternOp);
79     PDLPatternModule patternModule(std::move(pdlModuleOp));
80 
81     // Merge in the hooks owned by the dialect. Make a copy as they may be
82     // also used by the following operations.
83     auto *dialect =
84         root->getContext()->getLoadedDialect<transform::TransformDialect>();
85     for (const auto &[name, constraintFn] :
86          dialect->getExtraData<transform::PDLMatchHooks>()
87              .getPDLConstraintHooks()) {
88       patternModule.registerConstraintFunction(name, constraintFn);
89     }
90 
91     // Register a noop rewriter because PDL requires patterns to end with some
92     // rewrite call.
93     patternModule.registerRewriteFunction(
94         "transform.dialect", [](PatternRewriter &, Operation *) {});
95 
96     it = compiledPatterns
97              .try_emplace(patternOp.getName(), std::move(patternModule))
98              .first;
99   }
100 
101   PatternApplicator applicator(it->second);
102   // We want to discourage direct use of PatternRewriter in APIs but In this
103   // very specific case, an IRRewriter is not enough.
104   struct TrivialPatternRewriter : public PatternRewriter {
105   public:
106     explicit TrivialPatternRewriter(MLIRContext *context)
107         : PatternRewriter(context) {}
108   };
109   TrivialPatternRewriter rewriter(root->getContext());
110   applicator.applyDefaultCostModel();
111   root->walk([&](Operation *op) {
112     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
113       results.push_back(op);
114   });
115 
116   return success();
117 }
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // PDLMatchHooks
122 //===----------------------------------------------------------------------===//
123 
mergeInPDLMatchHooks(llvm::StringMap<PDLConstraintFunction> && constraintFns)124 void transform::PDLMatchHooks::mergeInPDLMatchHooks(
125     llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
126   // Steal the constraint functions from the given map.
127   for (auto &it : constraintFns)
128     pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
129 }
130 
131 const llvm::StringMap<PDLConstraintFunction> &
getPDLConstraintHooks() const132 transform::PDLMatchHooks::getPDLConstraintHooks() const {
133   return pdlMatchHooks.getConstraintFunctions();
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // PDLMatchOp
138 //===----------------------------------------------------------------------===//
139 
140 DiagnosedSilenceableFailure
apply(transform::TransformRewriter & rewriter,transform::TransformResults & results,transform::TransformState & state)141 transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
142                              transform::TransformResults &results,
143                              transform::TransformState &state) {
144   auto *extension = state.getExtension<PatternApplicatorExtension>();
145   assert(extension &&
146          "expected PatternApplicatorExtension to be attached by the parent op");
147   SmallVector<Operation *> targets;
148   for (Operation *root : state.getPayloadOps(getRoot())) {
149     if (failed(extension->findAllMatches(
150             getPatternName().getLeafReference().getValue(), root, targets))) {
151       emitDefiniteFailure()
152           << "could not find pattern '" << getPatternName() << "'";
153     }
154   }
155   results.set(llvm::cast<OpResult>(getResult()), targets);
156   return DiagnosedSilenceableFailure::success();
157 }
158 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)159 void transform::PDLMatchOp::getEffects(
160     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
161   onlyReadsHandle(getRootMutable(), effects);
162   producesHandle(getOperation()->getOpResults(), effects);
163   onlyReadsPayload(effects);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // WithPDLPatternsOp
168 //===----------------------------------------------------------------------===//
169 
170 DiagnosedSilenceableFailure
apply(transform::TransformRewriter & rewriter,transform::TransformResults & results,transform::TransformState & state)171 transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
172                                     transform::TransformResults &results,
173                                     transform::TransformState &state) {
174   TransformOpInterface transformOp = nullptr;
175   for (Operation &nested : getBody().front()) {
176     if (!isa<pdl::PatternOp>(nested)) {
177       transformOp = cast<TransformOpInterface>(nested);
178       break;
179     }
180   }
181 
182   state.addExtension<PatternApplicatorExtension>(getOperation());
183   auto guard = llvm::make_scope_exit(
184       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
185 
186   auto scope = state.make_region_scope(getBody());
187   if (failed(mapBlockArguments(state)))
188     return DiagnosedSilenceableFailure::definiteFailure();
189   return state.applyTransform(transformOp);
190 }
191 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)192 void transform::WithPDLPatternsOp::getEffects(
193     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
194   getPotentialTopLevelEffects(effects);
195 }
196 
verify()197 LogicalResult transform::WithPDLPatternsOp::verify() {
198   Block *body = getBodyBlock();
199   Operation *topLevelOp = nullptr;
200   for (Operation &op : body->getOperations()) {
201     if (isa<pdl::PatternOp>(op))
202       continue;
203 
204     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
205       if (topLevelOp) {
206         InFlightDiagnostic diag =
207             emitOpError() << "expects only one non-pattern op in its body";
208         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
209         diag.attachNote(op.getLoc()) << "second non-pattern op";
210         return diag;
211       }
212       topLevelOp = &op;
213       continue;
214     }
215 
216     InFlightDiagnostic diag =
217         emitOpError()
218         << "expects only pattern and top-level transform ops in its body";
219     diag.attachNote(op.getLoc()) << "offending op";
220     return diag;
221   }
222 
223   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
224     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
225     diag.attachNote(parent.getLoc()) << "parent operation";
226     return diag;
227   }
228 
229   if (!topLevelOp) {
230     InFlightDiagnostic diag = emitOpError()
231                               << "expects at least one non-pattern op";
232     return diag;
233   }
234 
235   return success();
236 }
237