// xref: /llvm-project/mlir/test/lib/IR/TestFunc.cpp (revision 34a35a8b244243f5a4ad5d531007bccfeaa0b02e)
//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8486f2122SSean Silva 
965fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
10*34a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
11486f2122SSean Silva #include "mlir/Pass/Pass.h"
12486f2122SSean Silva 
13486f2122SSean Silva using namespace mlir;
14486f2122SSean Silva 
15486f2122SSean Silva namespace {
1687d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's insertArgument
1787d6bf37SRiver Riddle /// method.
188066f22cSFabian Schuiki struct TestFuncInsertArg
198066f22cSFabian Schuiki     : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertArg205e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg)
215e50dd04SRiver Riddle 
228066f22cSFabian Schuiki   StringRef getArgument() const final { return "test-func-insert-arg"; }
getDescription__anon913d9fc10111::TestFuncInsertArg238066f22cSFabian Schuiki   StringRef getDescription() const final { return "Test inserting func args."; }
runOnOperation__anon913d9fc10111::TestFuncInsertArg248066f22cSFabian Schuiki   void runOnOperation() override {
258066f22cSFabian Schuiki     auto module = getOperation();
268066f22cSFabian Schuiki 
27e084679fSRiver Riddle     UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
2887d6bf37SRiver Riddle     for (auto func : module.getOps<FunctionOpInterface>()) {
298066f22cSFabian Schuiki       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
308066f22cSFabian Schuiki       if (!inserts || inserts.empty())
318066f22cSFabian Schuiki         continue;
328066f22cSFabian Schuiki       SmallVector<unsigned, 4> indicesToInsert;
338066f22cSFabian Schuiki       SmallVector<Type, 4> typesToInsert;
348066f22cSFabian Schuiki       SmallVector<DictionaryAttr, 4> attrsToInsert;
35e084679fSRiver Riddle       SmallVector<Location, 4> locsToInsert;
368066f22cSFabian Schuiki       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
378066f22cSFabian Schuiki         indicesToInsert.push_back(
385550c821STres Popp             cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
395550c821STres Popp         typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
408066f22cSFabian Schuiki         attrsToInsert.push_back(insert.size() > 2
415550c821STres Popp                                     ? cast<DictionaryAttr>(insert[2])
428066f22cSFabian Schuiki                                     : DictionaryAttr::get(&getContext()));
43e084679fSRiver Riddle         locsToInsert.push_back(insert.size() > 3
445550c821STres Popp                                    ? Location(cast<LocationAttr>(insert[3]))
45e084679fSRiver Riddle                                    : unknownLoc);
468066f22cSFabian Schuiki       }
478066f22cSFabian Schuiki       func->removeAttr("test.insert_args");
488066f22cSFabian Schuiki       func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
498066f22cSFabian Schuiki                            locsToInsert);
508066f22cSFabian Schuiki     }
518066f22cSFabian Schuiki   }
528066f22cSFabian Schuiki };
538066f22cSFabian Schuiki 
5487d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's insertResult method.
558066f22cSFabian Schuiki struct TestFuncInsertResult
568066f22cSFabian Schuiki     : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertResult575e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult)
585e50dd04SRiver Riddle 
598066f22cSFabian Schuiki   StringRef getArgument() const final { return "test-func-insert-result"; }
getDescription__anon913d9fc10111::TestFuncInsertResult608066f22cSFabian Schuiki   StringRef getDescription() const final {
618066f22cSFabian Schuiki     return "Test inserting func results.";
628066f22cSFabian Schuiki   }
runOnOperation__anon913d9fc10111::TestFuncInsertResult638066f22cSFabian Schuiki   void runOnOperation() override {
648066f22cSFabian Schuiki     auto module = getOperation();
658066f22cSFabian Schuiki 
6687d6bf37SRiver Riddle     for (auto func : module.getOps<FunctionOpInterface>()) {
678066f22cSFabian Schuiki       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
688066f22cSFabian Schuiki       if (!inserts || inserts.empty())
698066f22cSFabian Schuiki         continue;
708066f22cSFabian Schuiki       SmallVector<unsigned, 4> indicesToInsert;
718066f22cSFabian Schuiki       SmallVector<Type, 4> typesToInsert;
728066f22cSFabian Schuiki       SmallVector<DictionaryAttr, 4> attrsToInsert;
738066f22cSFabian Schuiki       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
748066f22cSFabian Schuiki         indicesToInsert.push_back(
755550c821STres Popp             cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
765550c821STres Popp         typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
778066f22cSFabian Schuiki         attrsToInsert.push_back(insert.size() > 2
785550c821STres Popp                                     ? cast<DictionaryAttr>(insert[2])
798066f22cSFabian Schuiki                                     : DictionaryAttr::get(&getContext()));
808066f22cSFabian Schuiki       }
818066f22cSFabian Schuiki       func->removeAttr("test.insert_results");
828066f22cSFabian Schuiki       func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
838066f22cSFabian Schuiki     }
848066f22cSFabian Schuiki   }
858066f22cSFabian Schuiki };
868066f22cSFabian Schuiki 
8787d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's eraseArgument
8887d6bf37SRiver Riddle /// method.
8980aca1eaSRiver Riddle struct TestFuncEraseArg
9080aca1eaSRiver Riddle     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseArg915e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg)
925e50dd04SRiver Riddle 
93b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-func-erase-arg"; }
getDescription__anon913d9fc10111::TestFuncEraseArg94b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "Test erasing func args."; }
runOnOperation__anon913d9fc10111::TestFuncEraseArg95722f909fSRiver Riddle   void runOnOperation() override {
96722f909fSRiver Riddle     auto module = getOperation();
97486f2122SSean Silva 
9887d6bf37SRiver Riddle     for (auto func : module.getOps<FunctionOpInterface>()) {
99d10d49dcSRiver Riddle       BitVector indicesToErase(func.getNumArguments());
100e3cd80eaSRiver Riddle       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101e3cd80eaSRiver Riddle         if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102e3cd80eaSRiver Riddle           indicesToErase.set(argIndex);
103486f2122SSean Silva       func.eraseArguments(indicesToErase);
104486f2122SSean Silva     }
105486f2122SSean Silva   }
106486f2122SSean Silva };
107486f2122SSean Silva 
10887d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's eraseResult method.
1091253c407SSean Silva struct TestFuncEraseResult
1101253c407SSean Silva     : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseResult1115e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult)
1125e50dd04SRiver Riddle 
113b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-func-erase-result"; }
getDescription__anon913d9fc10111::TestFuncEraseResult114b5e22e6dSMehdi Amini   StringRef getDescription() const final {
115b5e22e6dSMehdi Amini     return "Test erasing func results.";
116b5e22e6dSMehdi Amini   }
runOnOperation__anon913d9fc10111::TestFuncEraseResult1171253c407SSean Silva   void runOnOperation() override {
1181253c407SSean Silva     auto module = getOperation();
1191253c407SSean Silva 
12087d6bf37SRiver Riddle     for (auto func : module.getOps<FunctionOpInterface>()) {
121d10d49dcSRiver Riddle       BitVector indicesToErase(func.getNumResults());
122e3cd80eaSRiver Riddle       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123e3cd80eaSRiver Riddle         if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124e3cd80eaSRiver Riddle           indicesToErase.set(resultIndex);
1251253c407SSean Silva       func.eraseResults(indicesToErase);
1261253c407SSean Silva     }
1271253c407SSean Silva   }
1281253c407SSean Silva };
1291253c407SSean Silva 
13087d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's setType method.
13180aca1eaSRiver Riddle struct TestFuncSetType
13280aca1eaSRiver Riddle     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncSetType1335e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType)
1345e50dd04SRiver Riddle 
135b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-func-set-type"; }
getDescription__anon913d9fc10111::TestFuncSetType13687d6bf37SRiver Riddle   StringRef getDescription() const final {
13787d6bf37SRiver Riddle     return "Test FunctionOpInterface::setType.";
13887d6bf37SRiver Riddle   }
runOnOperation__anon913d9fc10111::TestFuncSetType139722f909fSRiver Riddle   void runOnOperation() override {
140722f909fSRiver Riddle     auto module = getOperation();
141486f2122SSean Silva     SymbolTable symbolTable(module);
142486f2122SSean Silva 
14387d6bf37SRiver Riddle     for (auto func : module.getOps<FunctionOpInterface>()) {
1440bf4a82aSChristian Sigg       auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
145486f2122SSean Silva       if (!sym)
146486f2122SSean Silva         continue;
1474a3460a7SRiver Riddle       func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue())
1484a3460a7SRiver Riddle                        .getFunctionType());
149486f2122SSean Silva     }
150486f2122SSean Silva   }
151486f2122SSean Silva };
152be0a7e9fSMehdi Amini } // namespace
153486f2122SSean Silva 
154c6477050SMehdi Amini namespace mlir {
registerTestFunc()155c6477050SMehdi Amini void registerTestFunc() {
1568066f22cSFabian Schuiki   PassRegistration<TestFuncInsertArg>();
1578066f22cSFabian Schuiki 
1588066f22cSFabian Schuiki   PassRegistration<TestFuncInsertResult>();
1598066f22cSFabian Schuiki 
160b5e22e6dSMehdi Amini   PassRegistration<TestFuncEraseArg>();
161486f2122SSean Silva 
162b5e22e6dSMehdi Amini   PassRegistration<TestFuncEraseResult>();
1631253c407SSean Silva 
164b5e22e6dSMehdi Amini   PassRegistration<TestFuncSetType>();
165c6477050SMehdi Amini }
166c6477050SMehdi Amini } // namespace mlir
167