xref: /llvm-project/mlir/test/lib/Analysis/TestSlice.cpp (revision 5e50dd048e3a20cde5da5d7a754dfee775ef35d6)
//===------------- TestSlice.cpp - Test slice related analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/SliceAnalysis.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
15 
16 namespace {
17 
18 struct TestTopologicalSortPass
19     : public PassWrapper<TestTopologicalSortPass,
20                          InterfacePass<SymbolOpInterface>> {
21   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortPass)
22 
23   StringRef getArgument() const final { return "test-print-topological-sort"; }
24   StringRef getDescription() const final {
25     return "Print operations in topological order";
26   }
27   void runOnOperation() override {
28     std::map<int, Operation *> ops;
29     getOperation().walk([&ops](Operation *op) {
30       if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
31         ops[originalOrderAttr.getInt()] = op;
32     });
33     SetVector<Operation *> sortedOp;
34     for (auto op : ops)
35       sortedOp.insert(op.second);
36     sortedOp = topologicalSort(sortedOp);
37     llvm::errs() << "Testing : " << getOperation().getName() << "\n";
38     for (Operation *op : sortedOp) {
39       op->print(llvm::errs());
40       llvm::errs() << "\n";
41     }
42   }
43 };
44 
45 } // namespace
46 
47 namespace mlir {
48 namespace test {
49 void registerTestSliceAnalysisPass() {
50   PassRegistration<TestTopologicalSortPass>();
51 }
52 } // namespace test
53 } // namespace mlir
54