xref: /llvm-project/mlir/unittests/Rewrite/PatternBenefit.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8a531bbd9SButygin 
957d9adefSMehdi Amini #include "mlir/IR/OwningOpRef.h"
10a531bbd9SButygin #include "mlir/IR/PatternMatch.h"
11a531bbd9SButygin #include "mlir/Rewrite/PatternApplicator.h"
12a531bbd9SButygin #include "gtest/gtest.h"
13a531bbd9SButygin 
14a531bbd9SButygin using namespace mlir;
15a531bbd9SButygin 
16a531bbd9SButygin namespace {
TEST(PatternBenefitTest,BenefitOrder)17a531bbd9SButygin TEST(PatternBenefitTest, BenefitOrder) {
18a531bbd9SButygin   // There was a bug which caused low-benefit op-specific patterns to never be
19a531bbd9SButygin   // called in presence of high-benefit op-agnostic pattern
20a531bbd9SButygin 
21a531bbd9SButygin   MLIRContext context;
22a531bbd9SButygin 
23a531bbd9SButygin   OpBuilder builder(&context);
2457d9adefSMehdi Amini   OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
25a531bbd9SButygin 
26a531bbd9SButygin   struct Pattern1 : public OpRewritePattern<ModuleOp> {
27a531bbd9SButygin     Pattern1(mlir::MLIRContext *context, bool *called)
28a531bbd9SButygin         : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
29a531bbd9SButygin 
30*db791b27SRamkumar Ramachandra     llvm::LogicalResult
31a531bbd9SButygin     matchAndRewrite(ModuleOp /*op*/,
32a531bbd9SButygin                     mlir::PatternRewriter & /*rewriter*/) const override {
33a531bbd9SButygin       *called = true;
34a531bbd9SButygin       return failure();
35a531bbd9SButygin     }
36a531bbd9SButygin 
37a531bbd9SButygin   private:
38a531bbd9SButygin     bool *called;
39a531bbd9SButygin   };
40a531bbd9SButygin 
41a531bbd9SButygin   struct Pattern2 : public RewritePattern {
4276f3c2f3SRiver Riddle     Pattern2(MLIRContext *context, bool *called)
4376f3c2f3SRiver Riddle         : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
4476f3c2f3SRiver Riddle           called(called) {}
45a531bbd9SButygin 
46*db791b27SRamkumar Ramachandra     llvm::LogicalResult
47a531bbd9SButygin     matchAndRewrite(Operation * /*op*/,
48a531bbd9SButygin                     mlir::PatternRewriter & /*rewriter*/) const override {
49a531bbd9SButygin       *called = true;
50a531bbd9SButygin       return failure();
51a531bbd9SButygin     }
52a531bbd9SButygin 
53a531bbd9SButygin   private:
54a531bbd9SButygin     bool *called;
55a531bbd9SButygin   };
56a531bbd9SButygin 
57dc4e913bSChris Lattner   RewritePatternSet patterns(&context);
58a531bbd9SButygin 
59a531bbd9SButygin   bool called1 = false;
60a531bbd9SButygin   bool called2 = false;
61a531bbd9SButygin 
62dc4e913bSChris Lattner   patterns.add<Pattern1>(&context, &called1);
6376f3c2f3SRiver Riddle   patterns.add<Pattern2>(&context, &called2);
64a531bbd9SButygin 
6579d7f618SChris Lattner   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
66a531bbd9SButygin   PatternApplicator pa(frozenPatterns);
67a531bbd9SButygin   pa.applyDefaultCostModel();
68a531bbd9SButygin 
69a531bbd9SButygin   class MyPatternRewriter : public PatternRewriter {
70a531bbd9SButygin   public:
71a531bbd9SButygin     MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
72a531bbd9SButygin   };
73a531bbd9SButygin 
74a531bbd9SButygin   MyPatternRewriter rewriter(&context);
7557d9adefSMehdi Amini   (void)pa.matchAndRewrite(*module, rewriter);
76a531bbd9SButygin 
77a531bbd9SButygin   EXPECT_TRUE(called1);
78a531bbd9SButygin   EXPECT_TRUE(called2);
79a531bbd9SButygin }
80a531bbd9SButygin } // namespace
81