xref: /llvm-project/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===- BuildOnlyExtensionTest.cpp - unit test for transform extensions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11 #include "mlir/IR/DialectRegistry.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "gtest/gtest.h"
14 
15 using namespace mlir;
16 using namespace mlir::transform;
17 
18 namespace {
19 class Extension : public TransformDialectExtension<Extension> {
20 public:
21   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension)
22 
23   using Base::Base;
24   void init() { declareGeneratedDialect<func::FuncDialect>(); }
25 };
26 } // end namespace
27 
28 TEST(BuildOnlyExtensionTest, buildOnlyExtension) {
29   // Register the build-only version of the transform dialect extension. The
30   // func dialect is declared as generated so it should not be loaded along with
31   // the transform dialect.
32   DialectRegistry registry;
33   registry.addExtensions<BuildOnly<Extension>>();
34   MLIRContext ctx(registry);
35   ctx.getOrLoadDialect<TransformDialect>();
36   ASSERT_FALSE(ctx.getLoadedDialect<func::FuncDialect>());
37 }
38 
39 TEST(BuildOnlyExtensionTest, buildAndApplyExtension) {
40   // Register the full version of the transform dialect extension. The func
41   // dialect should be loaded along with the transform dialect.
42   DialectRegistry registry;
43   registry.addExtensions<Extension>();
44   MLIRContext ctx(registry);
45   ctx.getOrLoadDialect<TransformDialect>();
46   ASSERT_TRUE(ctx.getLoadedDialect<func::FuncDialect>());
47 }
48