xref: /llvm-project/mlir/unittests/Transforms/DialectConversion.cpp (revision 5e118f933b6590cecd7f1afb30845a1594bc4a5d)
1 //===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "gtest/gtest.h"
11 
12 using namespace mlir;
13 
createOp(MLIRContext * context)14 static Operation *createOp(MLIRContext *context) {
15   context->allowUnregisteredDialects();
16   return Operation::create(
17       UnknownLoc::get(context), OperationName("foo.bar", context), std::nullopt,
18       std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
19 }
20 
21 namespace {
22 struct DummyOp {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonf30c54920111::DummyOp23   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp)
24 
25   static StringRef getOperationName() { return "foo.bar"; }
26 };
27 
/// Registers two dynamic-legality callbacks for the same op and checks the
/// order in which the target consults them. The expectations encode that the
/// most recently registered callback runs first and that, when it returns
/// `std::nullopt`, the earlier callback is consulted next.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // Shared invocation counter. Each callback stores the counter value it
  // observed, which lets the test recover the relative call order.
  int invocationCount = 0;
  int firstCallbackSeenAt = 0;
  int secondCallbackSeenAt = 0;

  // Registered first: always answers "legal".
  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
    ++invocationCount;
    firstCallbackSeenAt = invocationCount;
    return true;
  });

  // Registered second: gives no answer.
  target.addDynamicallyLegalOp<DummyOp>(
      [&](Operation *) -> std::optional<bool> {
        ++invocationCount;
        secondCallbackSeenAt = invocationCount;
        return std::nullopt;
      });

  Operation *op = createOp(&context);

  // Legality query: second callback runs at count 1, first at count 2.
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, firstCallbackSeenAt);
  EXPECT_EQ(1, secondCallbackSeenAt);

  // Illegality query walks the callbacks again in the same order.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, firstCallbackSeenAt);
  EXPECT_EQ(3, secondCallbackSeenAt);

  op->destroy();
}
55 
/// Registers a single dynamic-legality callback that never gives an answer
/// (`std::nullopt`) and checks that the op is then reported as neither legal
/// nor illegal, with the callback invoked exactly once per query.
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
  MLIRContext context;
  ConversionTarget target(context);

  // Running invocation count and the value the callback last observed.
  int invocationCount = 0;
  int callbackSeenAt = 0;

  target.addDynamicallyLegalOp<DummyOp>(
      [&](Operation *) -> std::optional<bool> {
        ++invocationCount;
        callbackSeenAt = invocationCount;
        return std::nullopt;
      });

  Operation *op = createOp(&context);

  // First query: not legal, callback ran once.
  EXPECT_FALSE(target.isLegal(op));
  EXPECT_EQ(1, callbackSeenAt);

  // Second query: not illegal either, callback ran a second time.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(2, callbackSeenAt);

  op->destroy();
}
75 
/// Same ordering check as the per-op variant, but for callbacks registered
/// through `markUnknownOpDynamicallyLegal`. The expectations encode that the
/// most recently registered callback runs first and that the earlier one is
/// consulted when the later one returns `std::nullopt`.
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
  MLIRContext context;
  ConversionTarget target(context);

  // Shared invocation counter plus the counter value each callback last saw.
  int invocationCount = 0;
  int firstCallbackSeenAt = 0;
  int secondCallbackSeenAt = 0;

  // Registered first: always answers "legal".
  target.markUnknownOpDynamicallyLegal([&](Operation *) {
    ++invocationCount;
    firstCallbackSeenAt = invocationCount;
    return true;
  });

  // Registered second: gives no answer.
  target.markUnknownOpDynamicallyLegal([&](Operation *) -> std::optional<bool> {
    ++invocationCount;
    secondCallbackSeenAt = invocationCount;
    return std::nullopt;
  });

  Operation *op = createOp(&context);

  // Legality query: second callback runs at count 1, first at count 2.
  EXPECT_TRUE(target.isLegal(op));
  EXPECT_EQ(2, firstCallbackSeenAt);
  EXPECT_EQ(1, secondCallbackSeenAt);

  // Illegality query walks the callbacks again in the same order.
  EXPECT_FALSE(target.isIllegal(op));
  EXPECT_EQ(4, firstCallbackSeenAt);
  EXPECT_EQ(3, secondCallbackSeenAt);

  op->destroy();
}
102 
/// Checks how an op whose only per-op legality callback returns
/// `std::nullopt` is treated: it is neither legal nor illegal, a partial
/// conversion with no patterns succeeds, and a full conversion fails.
TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  // The callback never decides legality.
  target.addDynamicallyLegalOp<DummyOp>(
      [&](Operation *) -> std::optional<bool> { return std::nullopt; });

  Operation *op = createOp(&context);

  EXPECT_FALSE(target.isLegal(op));
  EXPECT_FALSE(target.isIllegal(op));

  // Both conversions run with an empty pattern set.
  LogicalResult partialResult = applyPartialConversion(op, target, {});
  EXPECT_TRUE(succeeded(partialResult));
  LogicalResult fullResult = applyFullConversion(op, target, {});
  EXPECT_TRUE(failed(fullResult));

  op->destroy();
}
119 
/// Checks how an op is treated when the only legality callback, registered
/// via `markUnknownOpDynamicallyLegal`, returns `std::nullopt`: it is neither
/// legal nor illegal, a partial conversion with no patterns succeeds, and a
/// full conversion fails.
TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
  MLIRContext context;
  ConversionTarget target(context);

  // The callback never decides legality.
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *) -> std::optional<bool> { return std::nullopt; });

  Operation *op = createOp(&context);

  EXPECT_FALSE(target.isLegal(op));
  EXPECT_FALSE(target.isIllegal(op));

  // Both conversions run with an empty pattern set.
  LogicalResult partialResult = applyPartialConversion(op, target, {});
  EXPECT_TRUE(succeeded(partialResult));
  LogicalResult fullResult = applyFullConversion(op, target, {});
  EXPECT_TRUE(failed(fullResult));

  op->destroy();
}
136 } // namespace
137