1 //===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/TopologicalSortUtils.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/BuiltinOps.h" 12 #include "mlir/Pass/Pass.h" 13 14 using namespace mlir; 15 16 namespace { 17 struct TestTopologicalSortAnalysisPass 18 : public PassWrapper<TestTopologicalSortAnalysisPass, 19 OperationPass<ModuleOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon54dfb7230111::TestTopologicalSortAnalysisPass20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass) 21 22 StringRef getArgument() const final { 23 return "test-topological-sort-analysis"; 24 } getDescription__anon54dfb7230111::TestTopologicalSortAnalysisPass25 StringRef getDescription() const final { 26 return "Test topological sorting of ops"; 27 } 28 runOnOperation__anon54dfb7230111::TestTopologicalSortAnalysisPass29 void runOnOperation() override { 30 Operation *op = getOperation(); 31 OpBuilder builder(op->getContext()); 32 33 WalkResult result = op->walk([&](Operation *root) { 34 if (!root->hasAttr("root")) 35 return WalkResult::advance(); 36 37 SmallVector<Operation *> selectedOps; 38 root->walk([&](Operation *selected) { 39 if (!selected->hasAttr("selected")) 40 return WalkResult::advance(); 41 if (root->hasAttr("ordered")) { 42 // If the root has an "ordered" attribute, we fill the selectedOps 43 // vector in a certain order. 44 int64_t pos = 45 cast<IntegerAttr>(selected->getDiscardableAttr("selected")) 46 .getInt(); 47 if (pos >= static_cast<int64_t>(selectedOps.size())) 48 selectedOps.append(pos + 1 - selectedOps.size(), nullptr); 49 selectedOps[pos] = selected; 50 } else { 51 selectedOps.push_back(selected); 52 } 53 return WalkResult::advance(); 54 }); 55 56 if (llvm::find(selectedOps, nullptr) != selectedOps.end()) { 57 root->emitError("invalid test case: some indices are missing among the " 58 "selected ops"); 59 return WalkResult::skip(); 60 } 61 62 if (!computeTopologicalSorting(selectedOps)) { 63 root->emitError("could not schedule all ops"); 64 return WalkResult::skip(); 65 } 66 67 for (const auto &it : llvm::enumerate(selectedOps)) 68 it.value()->setAttr("pos", builder.getIndexAttr(it.index())); 69 70 return WalkResult::advance(); 71 }); 72 73 if (result.wasSkipped()) 74 signalPassFailure(); 75 } 76 }; 77 } // namespace 78 79 namespace mlir { 80 namespace test { registerTestTopologicalSortAnalysisPass()81void registerTestTopologicalSortAnalysisPass() { 82 PassRegistration<TestTopologicalSortAnalysisPass>(); 83 } 84 } // namespace test 85 } // namespace mlir 86