1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11
12 using namespace mlir;
13
createOp(MLIRContext * context)14 static Operation *createOp(MLIRContext *context) {
15 context->allowUnregisteredDialects();
16 return Operation::create(
17 UnknownLoc::get(context), OperationName("foo.bar", context), std::nullopt,
18 std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
19 }
20
21 namespace {
22 struct DummyOp {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonf30c54920111::DummyOp23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp)
24
25 static StringRef getOperationName() { return "foo.bar"; }
26 };
27
TEST(DialectConversionTest,DynamicallyLegalOpCallbackOrder)28 TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
29 MLIRContext context;
30 ConversionTarget target(context);
31
32 int index = 0;
33 int callbackCalled1 = 0;
34 target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
35 callbackCalled1 = ++index;
36 return true;
37 });
38
39 int callbackCalled2 = 0;
40 target.addDynamicallyLegalOp<DummyOp>(
41 [&](Operation *) -> std::optional<bool> {
42 callbackCalled2 = ++index;
43 return std::nullopt;
44 });
45
46 auto *op = createOp(&context);
47 EXPECT_TRUE(target.isLegal(op));
48 EXPECT_EQ(2, callbackCalled1);
49 EXPECT_EQ(1, callbackCalled2);
50 EXPECT_FALSE(target.isIllegal(op));
51 EXPECT_EQ(4, callbackCalled1);
52 EXPECT_EQ(3, callbackCalled2);
53 op->destroy();
54 }
55
TEST(DialectConversionTest,DynamicallyLegalOpCallbackSkip)56 TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
57 MLIRContext context;
58 ConversionTarget target(context);
59
60 int index = 0;
61 int callbackCalled = 0;
62 target.addDynamicallyLegalOp<DummyOp>(
63 [&](Operation *) -> std::optional<bool> {
64 callbackCalled = ++index;
65 return std::nullopt;
66 });
67
68 auto *op = createOp(&context);
69 EXPECT_FALSE(target.isLegal(op));
70 EXPECT_EQ(1, callbackCalled);
71 EXPECT_FALSE(target.isIllegal(op));
72 EXPECT_EQ(2, callbackCalled);
73 op->destroy();
74 }
75
TEST(DialectConversionTest,DynamicallyLegalUnknownOpCallbackOrder)76 TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
77 MLIRContext context;
78 ConversionTarget target(context);
79
80 int index = 0;
81 int callbackCalled1 = 0;
82 target.markUnknownOpDynamicallyLegal([&](Operation *) {
83 callbackCalled1 = ++index;
84 return true;
85 });
86
87 int callbackCalled2 = 0;
88 target.markUnknownOpDynamicallyLegal([&](Operation *) -> std::optional<bool> {
89 callbackCalled2 = ++index;
90 return std::nullopt;
91 });
92
93 auto *op = createOp(&context);
94 EXPECT_TRUE(target.isLegal(op));
95 EXPECT_EQ(2, callbackCalled1);
96 EXPECT_EQ(1, callbackCalled2);
97 EXPECT_FALSE(target.isIllegal(op));
98 EXPECT_EQ(4, callbackCalled1);
99 EXPECT_EQ(3, callbackCalled2);
100 op->destroy();
101 }
102
TEST(DialectConversionTest,DynamicallyLegalReturnNone)103 TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
104 MLIRContext context;
105 ConversionTarget target(context);
106
107 target.addDynamicallyLegalOp<DummyOp>(
108 [&](Operation *) -> std::optional<bool> { return std::nullopt; });
109
110 auto *op = createOp(&context);
111 EXPECT_FALSE(target.isLegal(op));
112 EXPECT_FALSE(target.isIllegal(op));
113
114 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
115 EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
116
117 op->destroy();
118 }
119
TEST(DialectConversionTest,DynamicallyLegalUnknownReturnNone)120 TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
121 MLIRContext context;
122 ConversionTarget target(context);
123
124 target.markUnknownOpDynamicallyLegal(
125 [&](Operation *) -> std::optional<bool> { return std::nullopt; });
126
127 auto *op = createOp(&context);
128 EXPECT_FALSE(target.isLegal(op));
129 EXPECT_FALSE(target.isIllegal(op));
130
131 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
132 EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
133
134 op->destroy();
135 }
136 } // namespace
137