xref: /llvm-project/mlir/test/lib/IR/TestFunc.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Interfaces/FunctionInterfaces.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This is a test pass for verifying FunctionOpInterface's insertArgument
17 /// method.
18 struct TestFuncInsertArg
19     : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertArg20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg)
21 
22   StringRef getArgument() const final { return "test-func-insert-arg"; }
getDescription__anon913d9fc10111::TestFuncInsertArg23   StringRef getDescription() const final { return "Test inserting func args."; }
runOnOperation__anon913d9fc10111::TestFuncInsertArg24   void runOnOperation() override {
25     auto module = getOperation();
26 
27     UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
28     for (auto func : module.getOps<FunctionOpInterface>()) {
29       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
30       if (!inserts || inserts.empty())
31         continue;
32       SmallVector<unsigned, 4> indicesToInsert;
33       SmallVector<Type, 4> typesToInsert;
34       SmallVector<DictionaryAttr, 4> attrsToInsert;
35       SmallVector<Location, 4> locsToInsert;
36       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
37         indicesToInsert.push_back(
38             cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
39         typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
40         attrsToInsert.push_back(insert.size() > 2
41                                     ? cast<DictionaryAttr>(insert[2])
42                                     : DictionaryAttr::get(&getContext()));
43         locsToInsert.push_back(insert.size() > 3
44                                    ? Location(cast<LocationAttr>(insert[3]))
45                                    : unknownLoc);
46       }
47       func->removeAttr("test.insert_args");
48       func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
49                            locsToInsert);
50     }
51   }
52 };
53 
54 /// This is a test pass for verifying FunctionOpInterface's insertResult method.
55 struct TestFuncInsertResult
56     : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertResult57   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult)
58 
59   StringRef getArgument() const final { return "test-func-insert-result"; }
getDescription__anon913d9fc10111::TestFuncInsertResult60   StringRef getDescription() const final {
61     return "Test inserting func results.";
62   }
runOnOperation__anon913d9fc10111::TestFuncInsertResult63   void runOnOperation() override {
64     auto module = getOperation();
65 
66     for (auto func : module.getOps<FunctionOpInterface>()) {
67       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
68       if (!inserts || inserts.empty())
69         continue;
70       SmallVector<unsigned, 4> indicesToInsert;
71       SmallVector<Type, 4> typesToInsert;
72       SmallVector<DictionaryAttr, 4> attrsToInsert;
73       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
74         indicesToInsert.push_back(
75             cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
76         typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
77         attrsToInsert.push_back(insert.size() > 2
78                                     ? cast<DictionaryAttr>(insert[2])
79                                     : DictionaryAttr::get(&getContext()));
80       }
81       func->removeAttr("test.insert_results");
82       func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
83     }
84   }
85 };
86 
87 /// This is a test pass for verifying FunctionOpInterface's eraseArgument
88 /// method.
89 struct TestFuncEraseArg
90     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseArg91   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg)
92 
93   StringRef getArgument() const final { return "test-func-erase-arg"; }
getDescription__anon913d9fc10111::TestFuncEraseArg94   StringRef getDescription() const final { return "Test erasing func args."; }
runOnOperation__anon913d9fc10111::TestFuncEraseArg95   void runOnOperation() override {
96     auto module = getOperation();
97 
98     for (auto func : module.getOps<FunctionOpInterface>()) {
99       BitVector indicesToErase(func.getNumArguments());
100       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101         if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102           indicesToErase.set(argIndex);
103       func.eraseArguments(indicesToErase);
104     }
105   }
106 };
107 
108 /// This is a test pass for verifying FunctionOpInterface's eraseResult method.
109 struct TestFuncEraseResult
110     : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseResult111   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult)
112 
113   StringRef getArgument() const final { return "test-func-erase-result"; }
getDescription__anon913d9fc10111::TestFuncEraseResult114   StringRef getDescription() const final {
115     return "Test erasing func results.";
116   }
runOnOperation__anon913d9fc10111::TestFuncEraseResult117   void runOnOperation() override {
118     auto module = getOperation();
119 
120     for (auto func : module.getOps<FunctionOpInterface>()) {
121       BitVector indicesToErase(func.getNumResults());
122       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123         if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124           indicesToErase.set(resultIndex);
125       func.eraseResults(indicesToErase);
126     }
127   }
128 };
129 
130 /// This is a test pass for verifying FunctionOpInterface's setType method.
131 struct TestFuncSetType
132     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncSetType133   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType)
134 
135   StringRef getArgument() const final { return "test-func-set-type"; }
getDescription__anon913d9fc10111::TestFuncSetType136   StringRef getDescription() const final {
137     return "Test FunctionOpInterface::setType.";
138   }
runOnOperation__anon913d9fc10111::TestFuncSetType139   void runOnOperation() override {
140     auto module = getOperation();
141     SymbolTable symbolTable(module);
142 
143     for (auto func : module.getOps<FunctionOpInterface>()) {
144       auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
145       if (!sym)
146         continue;
147       func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue())
148                        .getFunctionType());
149     }
150   }
151 };
152 } // namespace
153 
154 namespace mlir {
registerTestFunc()155 void registerTestFunc() {
156   PassRegistration<TestFuncInsertArg>();
157 
158   PassRegistration<TestFuncInsertResult>();
159 
160   PassRegistration<TestFuncEraseArg>();
161 
162   PassRegistration<TestFuncEraseResult>();
163 
164   PassRegistration<TestFuncSetType>();
165 }
166 } // namespace mlir
167