xref: /llvm-project/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp (revision 2c1ae801e1b66a09a15028ae4ba614e0911eec00)
//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
894d608d4SAlex Zinenko 
994d608d4SAlex Zinenko #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
1094d608d4SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDLOps.h"
1194d608d4SAlex Zinenko #include "mlir/IR/Builders.h"
1294d608d4SAlex Zinenko #include "mlir/IR/OpImplementation.h"
1394d608d4SAlex Zinenko #include "mlir/Rewrite/FrozenRewritePatternSet.h"
1494d608d4SAlex Zinenko #include "mlir/Rewrite/PatternApplicator.h"
1594d608d4SAlex Zinenko #include "llvm/ADT/ScopeExit.h"
1694d608d4SAlex Zinenko 
1794d608d4SAlex Zinenko using namespace mlir;
1894d608d4SAlex Zinenko 
1994d608d4SAlex Zinenko MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
2094d608d4SAlex Zinenko 
2194d608d4SAlex Zinenko #define GET_OP_CLASSES
2294d608d4SAlex Zinenko #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
2394d608d4SAlex Zinenko 
//===----------------------------------------------------------------------===//
// PatternApplicatorExtension
//===----------------------------------------------------------------------===//
2794d608d4SAlex Zinenko 
2894d608d4SAlex Zinenko namespace {
2994d608d4SAlex Zinenko /// A TransformState extension that keeps track of compiled PDL pattern sets.
3094d608d4SAlex Zinenko /// This is intended to be used along the WithPDLPatterns op. The extension
3194d608d4SAlex Zinenko /// can be constructed given an operation that has a SymbolTable trait and
3294d608d4SAlex Zinenko /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
3394d608d4SAlex Zinenko /// by one when requested; this behavior is subject to change.
3494d608d4SAlex Zinenko class PatternApplicatorExtension : public transform::TransformState::Extension {
3594d608d4SAlex Zinenko public:
3694d608d4SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
3794d608d4SAlex Zinenko 
3894d608d4SAlex Zinenko   /// Creates the extension for patterns contained in `patternContainer`.
PatternApplicatorExtension(transform::TransformState & state,Operation * patternContainer)3994d608d4SAlex Zinenko   explicit PatternApplicatorExtension(transform::TransformState &state,
4094d608d4SAlex Zinenko                                       Operation *patternContainer)
4194d608d4SAlex Zinenko       : Extension(state), patterns(patternContainer) {}
4294d608d4SAlex Zinenko 
4394d608d4SAlex Zinenko   /// Appends to `results` the operations contained in `root` that matched the
4494d608d4SAlex Zinenko   /// PDL pattern with the given name. Note that `root` may or may not be the
4594d608d4SAlex Zinenko   /// operation that contains PDL patterns. Reports an error if the pattern
4694d608d4SAlex Zinenko   /// cannot be found. Note that when no operations are matched, this still
4794d608d4SAlex Zinenko   /// succeeds as long as the pattern exists.
4894d608d4SAlex Zinenko   LogicalResult findAllMatches(StringRef patternName, Operation *root,
4994d608d4SAlex Zinenko                                SmallVectorImpl<Operation *> &results);
5094d608d4SAlex Zinenko 
5194d608d4SAlex Zinenko private:
5294d608d4SAlex Zinenko   /// Map from the pattern name to a singleton set of rewrite patterns that only
5394d608d4SAlex Zinenko   /// contains the pattern with this name. Populated when the pattern is first
5494d608d4SAlex Zinenko   /// requested.
5594d608d4SAlex Zinenko   // TODO: reconsider the efficiency of this storage when more usage data is
5694d608d4SAlex Zinenko   // available. Storing individual patterns in a set and triggering compilation
5794d608d4SAlex Zinenko   // for each of them has overhead. So does compiling a large set of patterns
5894d608d4SAlex Zinenko   // only to apply a handful of them.
5994d608d4SAlex Zinenko   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
6094d608d4SAlex Zinenko 
6194d608d4SAlex Zinenko   /// A symbol table operation containing the relevant PDL patterns.
6294d608d4SAlex Zinenko   SymbolTable patterns;
6394d608d4SAlex Zinenko };
6494d608d4SAlex Zinenko 
findAllMatches(StringRef patternName,Operation * root,SmallVectorImpl<Operation * > & results)6594d608d4SAlex Zinenko LogicalResult PatternApplicatorExtension::findAllMatches(
6694d608d4SAlex Zinenko     StringRef patternName, Operation *root,
6794d608d4SAlex Zinenko     SmallVectorImpl<Operation *> &results) {
6894d608d4SAlex Zinenko   auto it = compiledPatterns.find(patternName);
6994d608d4SAlex Zinenko   if (it == compiledPatterns.end()) {
7094d608d4SAlex Zinenko     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
7194d608d4SAlex Zinenko     if (!patternOp)
7294d608d4SAlex Zinenko       return failure();
7394d608d4SAlex Zinenko 
7494d608d4SAlex Zinenko     // Copy the pattern operation into a new module that is compiled and
7594d608d4SAlex Zinenko     // consumed by the PDL interpreter.
7694d608d4SAlex Zinenko     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
7794d608d4SAlex Zinenko     auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
7894d608d4SAlex Zinenko     builder.clone(*patternOp);
7994d608d4SAlex Zinenko     PDLPatternModule patternModule(std::move(pdlModuleOp));
8094d608d4SAlex Zinenko 
8194d608d4SAlex Zinenko     // Merge in the hooks owned by the dialect. Make a copy as they may be
8294d608d4SAlex Zinenko     // also used by the following operations.
8394d608d4SAlex Zinenko     auto *dialect =
8494d608d4SAlex Zinenko         root->getContext()->getLoadedDialect<transform::TransformDialect>();
8594d608d4SAlex Zinenko     for (const auto &[name, constraintFn] :
8694d608d4SAlex Zinenko          dialect->getExtraData<transform::PDLMatchHooks>()
8794d608d4SAlex Zinenko              .getPDLConstraintHooks()) {
8894d608d4SAlex Zinenko       patternModule.registerConstraintFunction(name, constraintFn);
8994d608d4SAlex Zinenko     }
9094d608d4SAlex Zinenko 
9194d608d4SAlex Zinenko     // Register a noop rewriter because PDL requires patterns to end with some
9294d608d4SAlex Zinenko     // rewrite call.
9394d608d4SAlex Zinenko     patternModule.registerRewriteFunction(
9494d608d4SAlex Zinenko         "transform.dialect", [](PatternRewriter &, Operation *) {});
9594d608d4SAlex Zinenko 
9694d608d4SAlex Zinenko     it = compiledPatterns
9794d608d4SAlex Zinenko              .try_emplace(patternOp.getName(), std::move(patternModule))
9894d608d4SAlex Zinenko              .first;
9994d608d4SAlex Zinenko   }
10094d608d4SAlex Zinenko 
10194d608d4SAlex Zinenko   PatternApplicator applicator(it->second);
10294d608d4SAlex Zinenko   // We want to discourage direct use of PatternRewriter in APIs but In this
10394d608d4SAlex Zinenko   // very specific case, an IRRewriter is not enough.
10494d608d4SAlex Zinenko   struct TrivialPatternRewriter : public PatternRewriter {
10594d608d4SAlex Zinenko   public:
10694d608d4SAlex Zinenko     explicit TrivialPatternRewriter(MLIRContext *context)
10794d608d4SAlex Zinenko         : PatternRewriter(context) {}
10894d608d4SAlex Zinenko   };
10994d608d4SAlex Zinenko   TrivialPatternRewriter rewriter(root->getContext());
11094d608d4SAlex Zinenko   applicator.applyDefaultCostModel();
11194d608d4SAlex Zinenko   root->walk([&](Operation *op) {
11294d608d4SAlex Zinenko     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
11394d608d4SAlex Zinenko       results.push_back(op);
11494d608d4SAlex Zinenko   });
11594d608d4SAlex Zinenko 
11694d608d4SAlex Zinenko   return success();
11794d608d4SAlex Zinenko }
11894d608d4SAlex Zinenko } // namespace
11994d608d4SAlex Zinenko 
//===----------------------------------------------------------------------===//
// PDLMatchHooks
//===----------------------------------------------------------------------===//
12394d608d4SAlex Zinenko 
mergeInPDLMatchHooks(llvm::StringMap<PDLConstraintFunction> && constraintFns)12494d608d4SAlex Zinenko void transform::PDLMatchHooks::mergeInPDLMatchHooks(
12594d608d4SAlex Zinenko     llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
12694d608d4SAlex Zinenko   // Steal the constraint functions from the given map.
12794d608d4SAlex Zinenko   for (auto &it : constraintFns)
12894d608d4SAlex Zinenko     pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
12994d608d4SAlex Zinenko }
13094d608d4SAlex Zinenko 
13194d608d4SAlex Zinenko const llvm::StringMap<PDLConstraintFunction> &
getPDLConstraintHooks() const13294d608d4SAlex Zinenko transform::PDLMatchHooks::getPDLConstraintHooks() const {
13394d608d4SAlex Zinenko   return pdlMatchHooks.getConstraintFunctions();
13494d608d4SAlex Zinenko }
13594d608d4SAlex Zinenko 
//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
13994d608d4SAlex Zinenko 
14094d608d4SAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformRewriter & rewriter,transform::TransformResults & results,transform::TransformState & state)141c63d2b2cSMatthias Springer transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
142c63d2b2cSMatthias Springer                              transform::TransformResults &results,
14394d608d4SAlex Zinenko                              transform::TransformState &state) {
14494d608d4SAlex Zinenko   auto *extension = state.getExtension<PatternApplicatorExtension>();
14594d608d4SAlex Zinenko   assert(extension &&
14694d608d4SAlex Zinenko          "expected PatternApplicatorExtension to be attached by the parent op");
14794d608d4SAlex Zinenko   SmallVector<Operation *> targets;
14894d608d4SAlex Zinenko   for (Operation *root : state.getPayloadOps(getRoot())) {
14994d608d4SAlex Zinenko     if (failed(extension->findAllMatches(
15094d608d4SAlex Zinenko             getPatternName().getLeafReference().getValue(), root, targets))) {
15194d608d4SAlex Zinenko       emitDefiniteFailure()
15294d608d4SAlex Zinenko           << "could not find pattern '" << getPatternName() << "'";
15394d608d4SAlex Zinenko     }
15494d608d4SAlex Zinenko   }
15594d608d4SAlex Zinenko   results.set(llvm::cast<OpResult>(getResult()), targets);
15694d608d4SAlex Zinenko   return DiagnosedSilenceableFailure::success();
15794d608d4SAlex Zinenko }
15894d608d4SAlex Zinenko 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)15994d608d4SAlex Zinenko void transform::PDLMatchOp::getEffects(
16094d608d4SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
161*2c1ae801Sdonald chen   onlyReadsHandle(getRootMutable(), effects);
162*2c1ae801Sdonald chen   producesHandle(getOperation()->getOpResults(), effects);
16394d608d4SAlex Zinenko   onlyReadsPayload(effects);
16494d608d4SAlex Zinenko }
16594d608d4SAlex Zinenko 
//===----------------------------------------------------------------------===//
// WithPDLPatternsOp
//===----------------------------------------------------------------------===//
16994d608d4SAlex Zinenko 
17094d608d4SAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformRewriter & rewriter,transform::TransformResults & results,transform::TransformState & state)171c63d2b2cSMatthias Springer transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
172c63d2b2cSMatthias Springer                                     transform::TransformResults &results,
17394d608d4SAlex Zinenko                                     transform::TransformState &state) {
17494d608d4SAlex Zinenko   TransformOpInterface transformOp = nullptr;
17594d608d4SAlex Zinenko   for (Operation &nested : getBody().front()) {
17694d608d4SAlex Zinenko     if (!isa<pdl::PatternOp>(nested)) {
17794d608d4SAlex Zinenko       transformOp = cast<TransformOpInterface>(nested);
17894d608d4SAlex Zinenko       break;
17994d608d4SAlex Zinenko     }
18094d608d4SAlex Zinenko   }
18194d608d4SAlex Zinenko 
18294d608d4SAlex Zinenko   state.addExtension<PatternApplicatorExtension>(getOperation());
18394d608d4SAlex Zinenko   auto guard = llvm::make_scope_exit(
18494d608d4SAlex Zinenko       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
18594d608d4SAlex Zinenko 
18694d608d4SAlex Zinenko   auto scope = state.make_region_scope(getBody());
18794d608d4SAlex Zinenko   if (failed(mapBlockArguments(state)))
18894d608d4SAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
18994d608d4SAlex Zinenko   return state.applyTransform(transformOp);
19094d608d4SAlex Zinenko }
19194d608d4SAlex Zinenko 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)19294d608d4SAlex Zinenko void transform::WithPDLPatternsOp::getEffects(
19394d608d4SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
19494d608d4SAlex Zinenko   getPotentialTopLevelEffects(effects);
19594d608d4SAlex Zinenko }
19694d608d4SAlex Zinenko 
verify()19794d608d4SAlex Zinenko LogicalResult transform::WithPDLPatternsOp::verify() {
19894d608d4SAlex Zinenko   Block *body = getBodyBlock();
19994d608d4SAlex Zinenko   Operation *topLevelOp = nullptr;
20094d608d4SAlex Zinenko   for (Operation &op : body->getOperations()) {
20194d608d4SAlex Zinenko     if (isa<pdl::PatternOp>(op))
20294d608d4SAlex Zinenko       continue;
20394d608d4SAlex Zinenko 
20494d608d4SAlex Zinenko     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
20594d608d4SAlex Zinenko       if (topLevelOp) {
20694d608d4SAlex Zinenko         InFlightDiagnostic diag =
20794d608d4SAlex Zinenko             emitOpError() << "expects only one non-pattern op in its body";
20894d608d4SAlex Zinenko         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
20994d608d4SAlex Zinenko         diag.attachNote(op.getLoc()) << "second non-pattern op";
21094d608d4SAlex Zinenko         return diag;
21194d608d4SAlex Zinenko       }
21294d608d4SAlex Zinenko       topLevelOp = &op;
21394d608d4SAlex Zinenko       continue;
21494d608d4SAlex Zinenko     }
21594d608d4SAlex Zinenko 
21694d608d4SAlex Zinenko     InFlightDiagnostic diag =
21794d608d4SAlex Zinenko         emitOpError()
21894d608d4SAlex Zinenko         << "expects only pattern and top-level transform ops in its body";
21994d608d4SAlex Zinenko     diag.attachNote(op.getLoc()) << "offending op";
22094d608d4SAlex Zinenko     return diag;
22194d608d4SAlex Zinenko   }
22294d608d4SAlex Zinenko 
22394d608d4SAlex Zinenko   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
22494d608d4SAlex Zinenko     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
22594d608d4SAlex Zinenko     diag.attachNote(parent.getLoc()) << "parent operation";
22694d608d4SAlex Zinenko     return diag;
22794d608d4SAlex Zinenko   }
22894d608d4SAlex Zinenko 
22994d608d4SAlex Zinenko   if (!topLevelOp) {
23094d608d4SAlex Zinenko     InFlightDiagnostic diag = emitOpError()
23194d608d4SAlex Zinenko                               << "expects at least one non-pattern op";
23294d608d4SAlex Zinenko     return diag;
23394d608d4SAlex Zinenko   }
23494d608d4SAlex Zinenko 
23594d608d4SAlex Zinenko   return success();
23694d608d4SAlex Zinenko }
237