xref: /llvm-project/mlir/unittests/Dialect/Transform/Preload.cpp (revision 2798b72ae7e5caad793169b77cbac47fe2362d0f)
//===- Preload.cpp - Test transform dialect library preloading ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
10 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11 #include "mlir/Dialect/Transform/IR/Utils.h"
12 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
13 #include "mlir/IR/AsmState.h"
14 #include "mlir/IR/DialectRegistry.h"
15 #include "mlir/IR/Verifier.h"
16 #include "mlir/Parser/Parser.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/TypeID.h"
21 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
22 #include "llvm/Support/MemoryBuffer.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "gtest/gtest.h"
25 
26 using namespace mlir;
27 
28 namespace mlir {
29 namespace test {
30 std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
31 } // namespace test
32 } // namespace mlir
33 
34 const static llvm::StringLiteral library = R"MLIR(
35 module attributes {transform.with_named_sequence} {
36   transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
37     transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op
38     transform.yield
39   }
40 })MLIR";
41 
42 const static llvm::StringLiteral input = R"MLIR(
43 module attributes {transform.with_named_sequence} {
44   transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
45 
46   transform.sequence failures(propagate) {
47   ^bb0(%arg0: !transform.any_op):
48     include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
49   }
50 })MLIR";
51 
/// End-to-end check of transform library preloading. The test:
///   1. parses `library` and loads it into the transform dialect's library
///      module;
///   2. retrieves it back from the context;
///   3. merges a clone into `input`, which only declares `@__transform_main`;
///   4. resolves the entry point and applies it to the input module.
TEST(Preload, ContextPreloadConstructedLibrary) {
  registerPassManagerCLOptions();

  MLIRContext context;
  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
  DialectRegistry registry;
  mlir::transform::registerDebugExtension(registry);
  // The transform dialect is already loaded above, so apply the extension to
  // the context directly. This is presumably what makes
  // `transform.debug.emit_remark_at` in `library` parseable.
  registry.applyExtensions(&context);
  ParserConfig parserConfig(&context);

  // Every precondition below uses ASSERT_* (fatal) rather than EXPECT_*.
  // Each checked value is dereferenced or consumed by a later step, so
  // continuing after a failure would crash the test binary instead of
  // reporting a clean test failure.
  OwningOpRef<ModuleOp> inputModule =
      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
  ASSERT_TRUE(inputModule) << "failed to parse input module";

  OwningOpRef<ModuleOp> transformLibrary =
      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
  // Ownership of the parsed library moves into the dialect.
  LogicalResult loadResult =
      dialect->loadIntoLibraryModule(std::move(transformLibrary));
  ASSERT_TRUE(succeeded(loadResult)) << "failed to preload transform library";

  ModuleOp retrievedTransformLibrary =
      transform::detail::getPreloadedTransformModule(&context);
  ASSERT_TRUE(retrievedTransformLibrary)
      << "failed to retrieve transform module";

  // Merge a clone, not the preloaded module itself. The original must stay
  // intact because it is passed again as the library module further down.
  OwningOpRef<Operation *> clonedTransformModule(
      retrievedTransformLibrary->clone());

  LogicalResult res = transform::detail::mergeSymbolsInto(
      inputModule->getOperation(), std::move(clonedTransformModule));
  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";

  transform::TransformOpInterface entryPoint =
      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
                                                 retrievedTransformLibrary);
  ASSERT_TRUE(entryPoint) << "failed to find entry point";

  transform::TransformOptions options;
  res = transform::applyTransformNamedSequence(
      inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
      options);
  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
}
96