xref: /llvm-project/mlir/unittests/Rewrite/PatternBenefit.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/OwningOpRef.h"
10 #include "mlir/IR/PatternMatch.h"
11 #include "mlir/Rewrite/PatternApplicator.h"
12 #include "gtest/gtest.h"
13 
14 using namespace mlir;
15 
16 namespace {
TEST(PatternBenefitTest,BenefitOrder)17 TEST(PatternBenefitTest, BenefitOrder) {
18   // There was a bug which caused low-benefit op-specific patterns to never be
19   // called in presence of high-benefit op-agnostic pattern
20 
21   MLIRContext context;
22 
23   OpBuilder builder(&context);
24   OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
25 
26   struct Pattern1 : public OpRewritePattern<ModuleOp> {
27     Pattern1(mlir::MLIRContext *context, bool *called)
28         : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
29 
30     llvm::LogicalResult
31     matchAndRewrite(ModuleOp /*op*/,
32                     mlir::PatternRewriter & /*rewriter*/) const override {
33       *called = true;
34       return failure();
35     }
36 
37   private:
38     bool *called;
39   };
40 
41   struct Pattern2 : public RewritePattern {
42     Pattern2(MLIRContext *context, bool *called)
43         : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
44           called(called) {}
45 
46     llvm::LogicalResult
47     matchAndRewrite(Operation * /*op*/,
48                     mlir::PatternRewriter & /*rewriter*/) const override {
49       *called = true;
50       return failure();
51     }
52 
53   private:
54     bool *called;
55   };
56 
57   RewritePatternSet patterns(&context);
58 
59   bool called1 = false;
60   bool called2 = false;
61 
62   patterns.add<Pattern1>(&context, &called1);
63   patterns.add<Pattern2>(&context, &called2);
64 
65   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
66   PatternApplicator pa(frozenPatterns);
67   pa.applyDefaultCostModel();
68 
69   class MyPatternRewriter : public PatternRewriter {
70   public:
71     MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
72   };
73 
74   MyPatternRewriter rewriter(&context);
75   (void)pa.matchAndRewrite(*module, rewriter);
76 
77   EXPECT_TRUE(called1);
78   EXPECT_TRUE(called2);
79 }
80 } // namespace
81