xref: /llvm-project/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- BuildOnlyExtensionTest.cpp - unit test for transform extensions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8333ee218SAlex Zinenko 
9333ee218SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
10333ee218SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11333ee218SAlex Zinenko #include "mlir/IR/DialectRegistry.h"
12333ee218SAlex Zinenko #include "mlir/IR/MLIRContext.h"
13333ee218SAlex Zinenko #include "gtest/gtest.h"
14333ee218SAlex Zinenko 
15333ee218SAlex Zinenko using namespace mlir;
16333ee218SAlex Zinenko using namespace mlir::transform;
17333ee218SAlex Zinenko 
18333ee218SAlex Zinenko namespace {
19333ee218SAlex Zinenko class Extension : public TransformDialectExtension<Extension> {
20333ee218SAlex Zinenko public:
21*84cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension)
22*84cc1865SNikhil Kalra 
23333ee218SAlex Zinenko   using Base::Base;
24333ee218SAlex Zinenko   void init() { declareGeneratedDialect<func::FuncDialect>(); }
25333ee218SAlex Zinenko };
26333ee218SAlex Zinenko } // end namespace
27333ee218SAlex Zinenko 
28333ee218SAlex Zinenko TEST(BuildOnlyExtensionTest, buildOnlyExtension) {
29333ee218SAlex Zinenko   // Register the build-only version of the transform dialect extension. The
30333ee218SAlex Zinenko   // func dialect is declared as generated so it should not be loaded along with
31333ee218SAlex Zinenko   // the transform dialect.
32333ee218SAlex Zinenko   DialectRegistry registry;
33333ee218SAlex Zinenko   registry.addExtensions<BuildOnly<Extension>>();
34333ee218SAlex Zinenko   MLIRContext ctx(registry);
35333ee218SAlex Zinenko   ctx.getOrLoadDialect<TransformDialect>();
36333ee218SAlex Zinenko   ASSERT_FALSE(ctx.getLoadedDialect<func::FuncDialect>());
37333ee218SAlex Zinenko }
38333ee218SAlex Zinenko 
39333ee218SAlex Zinenko TEST(BuildOnlyExtensionTest, buildAndApplyExtension) {
40333ee218SAlex Zinenko   // Register the full version of the transform dialect extension. The func
41333ee218SAlex Zinenko   // dialect should be loaded along with the transform dialect.
42333ee218SAlex Zinenko   DialectRegistry registry;
43333ee218SAlex Zinenko   registry.addExtensions<Extension>();
44333ee218SAlex Zinenko   MLIRContext ctx(registry);
45333ee218SAlex Zinenko   ctx.getOrLoadDialect<TransformDialect>();
46333ee218SAlex Zinenko   ASSERT_TRUE(ctx.getLoadedDialect<func::FuncDialect>());
47333ee218SAlex Zinenko }
48