xref: /llvm-project/mlir/test/lib/IR/TestTypes.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
//===- TestTypes.cpp - Test passes for MLIR types -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "TestTypes.h"
10 #include "TestDialect.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 using namespace test;
16 
17 namespace {
18 struct TestRecursiveTypesPass
19     : public PassWrapper<TestRecursiveTypesPass, OperationPass<func::FuncOp>> {
20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRecursiveTypesPass)
21 
22   LogicalResult createIRWithTypes();
23 
getArgument__anon6c29380a0111::TestRecursiveTypesPass24   StringRef getArgument() const final { return "test-recursive-types"; }
getDescription__anon6c29380a0111::TestRecursiveTypesPass25   StringRef getDescription() const final {
26     return "Test support for recursive types";
27   }
runOnOperation__anon6c29380a0111::TestRecursiveTypesPass28   void runOnOperation() override {
29     func::FuncOp func = getOperation();
30 
31     // Just make sure recursive types are printed and parsed.
32     if (func.getName() == "roundtrip")
33       return;
34 
35     // Create a recursive type and print it as a part of a dummy op.
36     if (func.getName() == "create") {
37       if (failed(createIRWithTypes()))
38         signalPassFailure();
39       return;
40     }
41 
42     // Unknown key.
43     func.emitOpError() << "unexpected function name";
44     signalPassFailure();
45   }
46 };
47 } // namespace
48 
createIRWithTypes()49 LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
50   MLIRContext *ctx = &getContext();
51   func::FuncOp func = getOperation();
52   auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
53   if (failed(type.setBody(type)))
54     return func.emitError("expected to be able to set the type body");
55 
56   // Setting the same body is fine.
57   if (failed(type.setBody(type)))
58     return func.emitError(
59         "expected to be able to set the type body to the same value");
60 
61   // Setting a different body is not.
62   if (succeeded(type.setBody(IndexType::get(ctx))))
63     return func.emitError(
64         "not expected to be able to change function body more than once");
65 
66   // Expecting to get the same type for the same name.
67   auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
68   if (type != other)
69     return func.emitError("expected type name to be the uniquing key");
70 
71   // Create the op to check how the type is printed.
72   OperationState state(func.getLoc(), "test.dummy_type_test_op");
73   state.addTypes(type);
74   func.getBody().front().push_front(Operation::create(state));
75 
76   return success();
77 }
78 
79 namespace mlir {
80 namespace test {
81 
registerTestRecursiveTypesPass()82 void registerTestRecursiveTypesPass() {
83   PassRegistration<TestRecursiveTypesPass>();
84 }
85 
86 } // namespace test
87 } // namespace mlir
88