//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "gtest/gtest.h"
#include <utility>

using namespace mlir;

16a531bbd9SButygin namespace {
TEST(PatternBenefitTest,BenefitOrder)17a531bbd9SButygin TEST(PatternBenefitTest, BenefitOrder) {
18a531bbd9SButygin // There was a bug which caused low-benefit op-specific patterns to never be
19a531bbd9SButygin // called in presence of high-benefit op-agnostic pattern
20a531bbd9SButygin
21a531bbd9SButygin MLIRContext context;
22a531bbd9SButygin
23a531bbd9SButygin OpBuilder builder(&context);
2457d9adefSMehdi Amini OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
25a531bbd9SButygin
26a531bbd9SButygin struct Pattern1 : public OpRewritePattern<ModuleOp> {
27a531bbd9SButygin Pattern1(mlir::MLIRContext *context, bool *called)
28a531bbd9SButygin : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
29a531bbd9SButygin
30*db791b27SRamkumar Ramachandra llvm::LogicalResult
31a531bbd9SButygin matchAndRewrite(ModuleOp /*op*/,
32a531bbd9SButygin mlir::PatternRewriter & /*rewriter*/) const override {
33a531bbd9SButygin *called = true;
34a531bbd9SButygin return failure();
35a531bbd9SButygin }
36a531bbd9SButygin
37a531bbd9SButygin private:
38a531bbd9SButygin bool *called;
39a531bbd9SButygin };
40a531bbd9SButygin
41a531bbd9SButygin struct Pattern2 : public RewritePattern {
4276f3c2f3SRiver Riddle Pattern2(MLIRContext *context, bool *called)
4376f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
4476f3c2f3SRiver Riddle called(called) {}
45a531bbd9SButygin
46*db791b27SRamkumar Ramachandra llvm::LogicalResult
47a531bbd9SButygin matchAndRewrite(Operation * /*op*/,
48a531bbd9SButygin mlir::PatternRewriter & /*rewriter*/) const override {
49a531bbd9SButygin *called = true;
50a531bbd9SButygin return failure();
51a531bbd9SButygin }
52a531bbd9SButygin
53a531bbd9SButygin private:
54a531bbd9SButygin bool *called;
55a531bbd9SButygin };
56a531bbd9SButygin
57dc4e913bSChris Lattner RewritePatternSet patterns(&context);
58a531bbd9SButygin
59a531bbd9SButygin bool called1 = false;
60a531bbd9SButygin bool called2 = false;
61a531bbd9SButygin
62dc4e913bSChris Lattner patterns.add<Pattern1>(&context, &called1);
6376f3c2f3SRiver Riddle patterns.add<Pattern2>(&context, &called2);
64a531bbd9SButygin
6579d7f618SChris Lattner FrozenRewritePatternSet frozenPatterns(std::move(patterns));
66a531bbd9SButygin PatternApplicator pa(frozenPatterns);
67a531bbd9SButygin pa.applyDefaultCostModel();
68a531bbd9SButygin
69a531bbd9SButygin class MyPatternRewriter : public PatternRewriter {
70a531bbd9SButygin public:
71a531bbd9SButygin MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
72a531bbd9SButygin };
73a531bbd9SButygin
74a531bbd9SButygin MyPatternRewriter rewriter(&context);
7557d9adefSMehdi Amini (void)pa.matchAndRewrite(*module, rewriter);
76a531bbd9SButygin
77a531bbd9SButygin EXPECT_TRUE(called1);
78a531bbd9SButygin EXPECT_TRUE(called2);
79a531bbd9SButygin }
80a531bbd9SButygin } // namespace
81