xref: /llvm-project/mlir/test/lib/Analysis/TestTopologicalSort.cpp (revision b00e0c167186d69e1e6bceda57c09b272bd6acfc)
1 //===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/TopologicalSortUtils.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 struct TestTopologicalSortAnalysisPass
18     : public PassWrapper<TestTopologicalSortAnalysisPass,
19                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon54dfb7230111::TestTopologicalSortAnalysisPass20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass)
21 
22   StringRef getArgument() const final {
23     return "test-topological-sort-analysis";
24   }
getDescription__anon54dfb7230111::TestTopologicalSortAnalysisPass25   StringRef getDescription() const final {
26     return "Test topological sorting of ops";
27   }
28 
runOnOperation__anon54dfb7230111::TestTopologicalSortAnalysisPass29   void runOnOperation() override {
30     Operation *op = getOperation();
31     OpBuilder builder(op->getContext());
32 
33     WalkResult result = op->walk([&](Operation *root) {
34       if (!root->hasAttr("root"))
35         return WalkResult::advance();
36 
37       SmallVector<Operation *> selectedOps;
38       root->walk([&](Operation *selected) {
39         if (!selected->hasAttr("selected"))
40           return WalkResult::advance();
41         if (root->hasAttr("ordered")) {
42           // If the root has an "ordered" attribute, we fill the selectedOps
43           // vector in a certain order.
44           int64_t pos =
45               cast<IntegerAttr>(selected->getDiscardableAttr("selected"))
46                   .getInt();
47           if (pos >= static_cast<int64_t>(selectedOps.size()))
48             selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
49           selectedOps[pos] = selected;
50         } else {
51           selectedOps.push_back(selected);
52         }
53         return WalkResult::advance();
54       });
55 
56       if (llvm::find(selectedOps, nullptr) != selectedOps.end()) {
57         root->emitError("invalid test case: some indices are missing among the "
58                         "selected ops");
59         return WalkResult::skip();
60       }
61 
62       if (!computeTopologicalSorting(selectedOps)) {
63         root->emitError("could not schedule all ops");
64         return WalkResult::skip();
65       }
66 
67       for (const auto &it : llvm::enumerate(selectedOps))
68         it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
69 
70       return WalkResult::advance();
71     });
72 
73     if (result.wasSkipped())
74       signalPassFailure();
75   }
76 };
77 } // namespace
78 
79 namespace mlir {
80 namespace test {
registerTestTopologicalSortAnalysisPass()81 void registerTestTopologicalSortAnalysisPass() {
82   PassRegistration<TestTopologicalSortAnalysisPass>();
83 }
84 } // namespace test
85 } // namespace mlir
86