xref: /llvm-project/mlir/test/lib/Analysis/TestSlice.cpp (revision b00e0c167186d69e1e6bceda57c09b272bd6acfc)
1 //===- TestSlice.cpp - Test slice related analysis ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/TopologicalSortUtils.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/SymbolTable.h"
12 #include "mlir/Pass/Pass.h"
13 
14 using namespace mlir;
15 
16 static const StringLiteral kToSortMark = "test_to_sort";
17 static const StringLiteral kOrderIndex = "test_sort_index";
18 
19 namespace {
20 
21 struct TestTopologicalSortPass
22     : public PassWrapper<TestTopologicalSortPass,
23                          InterfacePass<SymbolOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anona1b0e00e0111::TestTopologicalSortPass24   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortPass)
25 
26   StringRef getArgument() const final { return "test-print-topological-sort"; }
getDescription__anona1b0e00e0111::TestTopologicalSortPass27   StringRef getDescription() const final {
28     return "Sorts operations topologically and attaches attributes with their "
29            "corresponding index in the ordering to them";
30   }
runOnOperation__anona1b0e00e0111::TestTopologicalSortPass31   void runOnOperation() override {
32     SetVector<Operation *> toSort;
33     getOperation().walk([&](Operation *op) {
34       if (op->hasAttrOfType<UnitAttr>(kToSortMark))
35         toSort.insert(op);
36     });
37 
38     auto i32Type = IntegerType::get(&getContext(), 32);
39     SetVector<Operation *> sortedOps = topologicalSort(toSort);
40     for (auto [index, op] : llvm::enumerate(sortedOps))
41       op->setAttr(kOrderIndex, IntegerAttr::get(i32Type, index));
42   }
43 };
44 
45 } // namespace
46 
47 namespace mlir {
48 namespace test {
registerTestSliceAnalysisPass()49 void registerTestSliceAnalysisPass() {
50   PassRegistration<TestTopologicalSortPass>();
51 }
52 } // namespace test
53 } // namespace mlir
54