1*b00e0c16SChristian Ulmann //===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===// 2*b00e0c16SChristian Ulmann // 3*b00e0c16SChristian Ulmann // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*b00e0c16SChristian Ulmann // See https://llvm.org/LICENSE.txt for license information. 5*b00e0c16SChristian Ulmann // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*b00e0c16SChristian Ulmann // 7*b00e0c16SChristian Ulmann //===----------------------------------------------------------------------===// 8*b00e0c16SChristian Ulmann 9*b00e0c16SChristian Ulmann #include "mlir/Analysis/TopologicalSortUtils.h" 10*b00e0c16SChristian Ulmann #include "mlir/IR/Builders.h" 11*b00e0c16SChristian Ulmann #include "mlir/IR/BuiltinOps.h" 12*b00e0c16SChristian Ulmann #include "mlir/Pass/Pass.h" 13*b00e0c16SChristian Ulmann 14*b00e0c16SChristian Ulmann using namespace mlir; 15*b00e0c16SChristian Ulmann 16*b00e0c16SChristian Ulmann namespace { 17*b00e0c16SChristian Ulmann struct TestTopologicalSortAnalysisPass 18*b00e0c16SChristian Ulmann : public PassWrapper<TestTopologicalSortAnalysisPass, 19*b00e0c16SChristian Ulmann OperationPass<ModuleOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon54dfb7230111::TestTopologicalSortAnalysisPass20*b00e0c16SChristian Ulmann MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass) 21*b00e0c16SChristian Ulmann 22*b00e0c16SChristian Ulmann StringRef getArgument() const final { 23*b00e0c16SChristian Ulmann return "test-topological-sort-analysis"; 24*b00e0c16SChristian Ulmann } getDescription__anon54dfb7230111::TestTopologicalSortAnalysisPass25*b00e0c16SChristian Ulmann StringRef getDescription() const final { 26*b00e0c16SChristian Ulmann return "Test topological sorting of ops"; 27*b00e0c16SChristian Ulmann } 28*b00e0c16SChristian Ulmann runOnOperation__anon54dfb7230111::TestTopologicalSortAnalysisPass29*b00e0c16SChristian Ulmann void runOnOperation() override { 30*b00e0c16SChristian Ulmann Operation *op = getOperation(); 31*b00e0c16SChristian Ulmann OpBuilder builder(op->getContext()); 32*b00e0c16SChristian Ulmann 33*b00e0c16SChristian Ulmann WalkResult result = op->walk([&](Operation *root) { 34*b00e0c16SChristian Ulmann if (!root->hasAttr("root")) 35*b00e0c16SChristian Ulmann return WalkResult::advance(); 36*b00e0c16SChristian Ulmann 37*b00e0c16SChristian Ulmann SmallVector<Operation *> selectedOps; 38*b00e0c16SChristian Ulmann root->walk([&](Operation *selected) { 39*b00e0c16SChristian Ulmann if (!selected->hasAttr("selected")) 40*b00e0c16SChristian Ulmann return WalkResult::advance(); 41*b00e0c16SChristian Ulmann if (root->hasAttr("ordered")) { 42*b00e0c16SChristian Ulmann // If the root has an "ordered" attribute, we fill the selectedOps 43*b00e0c16SChristian Ulmann // vector in a certain order. 44*b00e0c16SChristian Ulmann int64_t pos = 45*b00e0c16SChristian Ulmann cast<IntegerAttr>(selected->getDiscardableAttr("selected")) 46*b00e0c16SChristian Ulmann .getInt(); 47*b00e0c16SChristian Ulmann if (pos >= static_cast<int64_t>(selectedOps.size())) 48*b00e0c16SChristian Ulmann selectedOps.append(pos + 1 - selectedOps.size(), nullptr); 49*b00e0c16SChristian Ulmann selectedOps[pos] = selected; 50*b00e0c16SChristian Ulmann } else { 51*b00e0c16SChristian Ulmann selectedOps.push_back(selected); 52*b00e0c16SChristian Ulmann } 53*b00e0c16SChristian Ulmann return WalkResult::advance(); 54*b00e0c16SChristian Ulmann }); 55*b00e0c16SChristian Ulmann 56*b00e0c16SChristian Ulmann if (llvm::find(selectedOps, nullptr) != selectedOps.end()) { 57*b00e0c16SChristian Ulmann root->emitError("invalid test case: some indices are missing among the " 58*b00e0c16SChristian Ulmann "selected ops"); 59*b00e0c16SChristian Ulmann return WalkResult::skip(); 60*b00e0c16SChristian Ulmann } 61*b00e0c16SChristian Ulmann 62*b00e0c16SChristian Ulmann if (!computeTopologicalSorting(selectedOps)) { 63*b00e0c16SChristian Ulmann root->emitError("could not schedule all ops"); 64*b00e0c16SChristian Ulmann return WalkResult::skip(); 65*b00e0c16SChristian Ulmann } 66*b00e0c16SChristian Ulmann 67*b00e0c16SChristian Ulmann for (const auto &it : llvm::enumerate(selectedOps)) 68*b00e0c16SChristian Ulmann it.value()->setAttr("pos", builder.getIndexAttr(it.index())); 69*b00e0c16SChristian Ulmann 70*b00e0c16SChristian Ulmann return WalkResult::advance(); 71*b00e0c16SChristian Ulmann }); 72*b00e0c16SChristian Ulmann 73*b00e0c16SChristian Ulmann if (result.wasSkipped()) 74*b00e0c16SChristian Ulmann signalPassFailure(); 75*b00e0c16SChristian Ulmann } 76*b00e0c16SChristian Ulmann }; 77*b00e0c16SChristian Ulmann } // namespace 78*b00e0c16SChristian Ulmann 79*b00e0c16SChristian Ulmann namespace mlir { 80*b00e0c16SChristian Ulmann namespace test { registerTestTopologicalSortAnalysisPass()81*b00e0c16SChristian Ulmannvoid registerTestTopologicalSortAnalysisPass() { 82*b00e0c16SChristian Ulmann PassRegistration<TestTopologicalSortAnalysisPass>(); 83*b00e0c16SChristian Ulmann } 84*b00e0c16SChristian Ulmann } // namespace test 85*b00e0c16SChristian Ulmann } // namespace mlir 86