//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8486f2122SSean Silva
965fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
10*34a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
11486f2122SSean Silva #include "mlir/Pass/Pass.h"
12486f2122SSean Silva
13486f2122SSean Silva using namespace mlir;
14486f2122SSean Silva
15486f2122SSean Silva namespace {
1687d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's insertArgument
1787d6bf37SRiver Riddle /// method.
188066f22cSFabian Schuiki struct TestFuncInsertArg
198066f22cSFabian Schuiki : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertArg205e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg)
215e50dd04SRiver Riddle
228066f22cSFabian Schuiki StringRef getArgument() const final { return "test-func-insert-arg"; }
getDescription__anon913d9fc10111::TestFuncInsertArg238066f22cSFabian Schuiki StringRef getDescription() const final { return "Test inserting func args."; }
runOnOperation__anon913d9fc10111::TestFuncInsertArg248066f22cSFabian Schuiki void runOnOperation() override {
258066f22cSFabian Schuiki auto module = getOperation();
268066f22cSFabian Schuiki
27e084679fSRiver Riddle UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
2887d6bf37SRiver Riddle for (auto func : module.getOps<FunctionOpInterface>()) {
298066f22cSFabian Schuiki auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
308066f22cSFabian Schuiki if (!inserts || inserts.empty())
318066f22cSFabian Schuiki continue;
328066f22cSFabian Schuiki SmallVector<unsigned, 4> indicesToInsert;
338066f22cSFabian Schuiki SmallVector<Type, 4> typesToInsert;
348066f22cSFabian Schuiki SmallVector<DictionaryAttr, 4> attrsToInsert;
35e084679fSRiver Riddle SmallVector<Location, 4> locsToInsert;
368066f22cSFabian Schuiki for (auto insert : inserts.getAsRange<ArrayAttr>()) {
378066f22cSFabian Schuiki indicesToInsert.push_back(
385550c821STres Popp cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
395550c821STres Popp typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
408066f22cSFabian Schuiki attrsToInsert.push_back(insert.size() > 2
415550c821STres Popp ? cast<DictionaryAttr>(insert[2])
428066f22cSFabian Schuiki : DictionaryAttr::get(&getContext()));
43e084679fSRiver Riddle locsToInsert.push_back(insert.size() > 3
445550c821STres Popp ? Location(cast<LocationAttr>(insert[3]))
45e084679fSRiver Riddle : unknownLoc);
468066f22cSFabian Schuiki }
478066f22cSFabian Schuiki func->removeAttr("test.insert_args");
488066f22cSFabian Schuiki func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
498066f22cSFabian Schuiki locsToInsert);
508066f22cSFabian Schuiki }
518066f22cSFabian Schuiki }
528066f22cSFabian Schuiki };
538066f22cSFabian Schuiki
5487d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's insertResult method.
558066f22cSFabian Schuiki struct TestFuncInsertResult
568066f22cSFabian Schuiki : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncInsertResult575e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult)
585e50dd04SRiver Riddle
598066f22cSFabian Schuiki StringRef getArgument() const final { return "test-func-insert-result"; }
getDescription__anon913d9fc10111::TestFuncInsertResult608066f22cSFabian Schuiki StringRef getDescription() const final {
618066f22cSFabian Schuiki return "Test inserting func results.";
628066f22cSFabian Schuiki }
runOnOperation__anon913d9fc10111::TestFuncInsertResult638066f22cSFabian Schuiki void runOnOperation() override {
648066f22cSFabian Schuiki auto module = getOperation();
658066f22cSFabian Schuiki
6687d6bf37SRiver Riddle for (auto func : module.getOps<FunctionOpInterface>()) {
678066f22cSFabian Schuiki auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
688066f22cSFabian Schuiki if (!inserts || inserts.empty())
698066f22cSFabian Schuiki continue;
708066f22cSFabian Schuiki SmallVector<unsigned, 4> indicesToInsert;
718066f22cSFabian Schuiki SmallVector<Type, 4> typesToInsert;
728066f22cSFabian Schuiki SmallVector<DictionaryAttr, 4> attrsToInsert;
738066f22cSFabian Schuiki for (auto insert : inserts.getAsRange<ArrayAttr>()) {
748066f22cSFabian Schuiki indicesToInsert.push_back(
755550c821STres Popp cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
765550c821STres Popp typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
778066f22cSFabian Schuiki attrsToInsert.push_back(insert.size() > 2
785550c821STres Popp ? cast<DictionaryAttr>(insert[2])
798066f22cSFabian Schuiki : DictionaryAttr::get(&getContext()));
808066f22cSFabian Schuiki }
818066f22cSFabian Schuiki func->removeAttr("test.insert_results");
828066f22cSFabian Schuiki func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
838066f22cSFabian Schuiki }
848066f22cSFabian Schuiki }
858066f22cSFabian Schuiki };
868066f22cSFabian Schuiki
8787d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's eraseArgument
8887d6bf37SRiver Riddle /// method.
8980aca1eaSRiver Riddle struct TestFuncEraseArg
9080aca1eaSRiver Riddle : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseArg915e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg)
925e50dd04SRiver Riddle
93b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-func-erase-arg"; }
getDescription__anon913d9fc10111::TestFuncEraseArg94b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Test erasing func args."; }
runOnOperation__anon913d9fc10111::TestFuncEraseArg95722f909fSRiver Riddle void runOnOperation() override {
96722f909fSRiver Riddle auto module = getOperation();
97486f2122SSean Silva
9887d6bf37SRiver Riddle for (auto func : module.getOps<FunctionOpInterface>()) {
99d10d49dcSRiver Riddle BitVector indicesToErase(func.getNumArguments());
100e3cd80eaSRiver Riddle for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101e3cd80eaSRiver Riddle if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102e3cd80eaSRiver Riddle indicesToErase.set(argIndex);
103486f2122SSean Silva func.eraseArguments(indicesToErase);
104486f2122SSean Silva }
105486f2122SSean Silva }
106486f2122SSean Silva };
107486f2122SSean Silva
10887d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's eraseResult method.
1091253c407SSean Silva struct TestFuncEraseResult
1101253c407SSean Silva : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncEraseResult1115e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult)
1125e50dd04SRiver Riddle
113b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-func-erase-result"; }
getDescription__anon913d9fc10111::TestFuncEraseResult114b5e22e6dSMehdi Amini StringRef getDescription() const final {
115b5e22e6dSMehdi Amini return "Test erasing func results.";
116b5e22e6dSMehdi Amini }
runOnOperation__anon913d9fc10111::TestFuncEraseResult1171253c407SSean Silva void runOnOperation() override {
1181253c407SSean Silva auto module = getOperation();
1191253c407SSean Silva
12087d6bf37SRiver Riddle for (auto func : module.getOps<FunctionOpInterface>()) {
121d10d49dcSRiver Riddle BitVector indicesToErase(func.getNumResults());
122e3cd80eaSRiver Riddle for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123e3cd80eaSRiver Riddle if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124e3cd80eaSRiver Riddle indicesToErase.set(resultIndex);
1251253c407SSean Silva func.eraseResults(indicesToErase);
1261253c407SSean Silva }
1271253c407SSean Silva }
1281253c407SSean Silva };
1291253c407SSean Silva
13087d6bf37SRiver Riddle /// This is a test pass for verifying FunctionOpInterface's setType method.
13180aca1eaSRiver Riddle struct TestFuncSetType
13280aca1eaSRiver Riddle : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon913d9fc10111::TestFuncSetType1335e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType)
1345e50dd04SRiver Riddle
135b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-func-set-type"; }
getDescription__anon913d9fc10111::TestFuncSetType13687d6bf37SRiver Riddle StringRef getDescription() const final {
13787d6bf37SRiver Riddle return "Test FunctionOpInterface::setType.";
13887d6bf37SRiver Riddle }
runOnOperation__anon913d9fc10111::TestFuncSetType139722f909fSRiver Riddle void runOnOperation() override {
140722f909fSRiver Riddle auto module = getOperation();
141486f2122SSean Silva SymbolTable symbolTable(module);
142486f2122SSean Silva
14387d6bf37SRiver Riddle for (auto func : module.getOps<FunctionOpInterface>()) {
1440bf4a82aSChristian Sigg auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
145486f2122SSean Silva if (!sym)
146486f2122SSean Silva continue;
1474a3460a7SRiver Riddle func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue())
1484a3460a7SRiver Riddle .getFunctionType());
149486f2122SSean Silva }
150486f2122SSean Silva }
151486f2122SSean Silva };
152be0a7e9fSMehdi Amini } // namespace
153486f2122SSean Silva
154c6477050SMehdi Amini namespace mlir {
registerTestFunc()155c6477050SMehdi Amini void registerTestFunc() {
1568066f22cSFabian Schuiki PassRegistration<TestFuncInsertArg>();
1578066f22cSFabian Schuiki
1588066f22cSFabian Schuiki PassRegistration<TestFuncInsertResult>();
1598066f22cSFabian Schuiki
160b5e22e6dSMehdi Amini PassRegistration<TestFuncEraseArg>();
161486f2122SSean Silva
162b5e22e6dSMehdi Amini PassRegistration<TestFuncEraseResult>();
1631253c407SSean Silva
164b5e22e6dSMehdi Amini PassRegistration<TestFuncSetType>();
165c6477050SMehdi Amini }
166c6477050SMehdi Amini } // namespace mlir
167