xref: /llvm-project/mlir/test/lib/IR/TestUseListOrders.cpp (revision f6fb639c76ce255e2f5b3e2c8550270243e7e7ab)
//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <numeric>
#include <random>
18 
19 using namespace mlir;
20 
21 namespace {
22 /// This pass tests that:
23 /// 1) we can shuffle use-lists correctly;
24 /// 2) use-list orders are preserved after a roundtrip to bytecode.
25 class TestPreserveUseListOrders
26     : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
27 public:
28   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
29 
30   TestPreserveUseListOrders() = default;
TestPreserveUseListOrders(const TestPreserveUseListOrders & pass)31   TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
32       : PassWrapper(pass) {}
getArgument() const33   StringRef getArgument() const final { return "test-verify-uselistorder"; }
getDescription() const34   StringRef getDescription() const final {
35     return "Verify that roundtripping the IR to bytecode preserves the order "
36            "of the uselists";
37   }
38   Option<unsigned> rngSeed{*this, "rng-seed",
39                            llvm::cl::desc("Specify an input random seed"),
40                            llvm::cl::init(1)};
41 
initialize(MLIRContext * context)42   LogicalResult initialize(MLIRContext *context) override {
43     rng.seed(static_cast<unsigned>(rngSeed));
44     return success();
45   }
46 
runOnOperation()47   void runOnOperation() override {
48     // Clone the module so that we can plug in this pass to any other
49     // independently.
50     OwningOpRef<ModuleOp> cloneModule = getOperation().clone();
51 
52     // 1. Compute the op numbering of the module.
53     computeOpNumbering(*cloneModule);
54 
55     // 2. Loop over all the values and shuffle the uses. While doing so, check
56     // that each shuffle is correct.
57     if (failed(shuffleUses(*cloneModule)))
58       return signalPassFailure();
59 
60     // 3. Do a bytecode roundtrip to version 3, which supports use-list order
61     // preservation.
62     auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3);
63     // If the bytecode roundtrip failed, try to roundtrip the original module
64     // to version 2, which does not support use-list. If this also fails, the
65     // original module had an issue unrelated to uselists.
66     if (failed(roundtripModuleOr)) {
67       auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
68       if (failed(testModuleOr))
69         return;
70 
71       return signalPassFailure();
72     }
73 
74     // 4. Recompute the op numbering on the new module. The numbering should be
75     // the same as (1), but on the new operation pointers.
76     computeOpNumbering(roundtripModuleOr->get());
77 
78     // 5. Loop over all the values and verify that the use-list is consistent
79     // with the post-shuffle order of step (2).
80     if (failed(verifyUseListOrders(roundtripModuleOr->get())))
81       return signalPassFailure();
82   }
83 
84 private:
doRoundtripToBytecode(Operation * module,uint32_t version)85   FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
86                                                             uint32_t version) {
87     std::string str;
88     llvm::raw_string_ostream m(str);
89     BytecodeWriterConfig config;
90     config.setDesiredBytecodeVersion(version);
91     if (failed(writeBytecodeToFile(module, m, config)))
92       return failure();
93 
94     ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
95     auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
96     if (!newModuleOp.get())
97       return failure();
98     return newModuleOp;
99   }
100 
101   /// Compute an ordered numbering for all the operations in the IR.
computeOpNumbering(Operation * topLevelOp)102   void computeOpNumbering(Operation *topLevelOp) {
103     uint32_t operationID = 0;
104     opNumbering.clear();
105     topLevelOp->walk<mlir::WalkOrder::PreOrder>(
106         [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
107   }
108 
109   template <typename ValueT>
getUseIDs(ValueT val)110   SmallVector<uint64_t> getUseIDs(ValueT val) {
111     return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
112       return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
113     }));
114   }
115 
shuffleUses(Operation * topLevelOp)116   LogicalResult shuffleUses(Operation *topLevelOp) {
117     uint32_t valueID = 0;
118     /// Permute randomly the use-list of each value. It is guaranteed that at
119     /// least one pair of the use list is permuted.
120     auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
121       for (auto val : range) {
122         if (val.use_empty() || val.hasOneUse())
123           continue;
124 
125         /// Get a valid index permutation for the uses of value.
126         SmallVector<unsigned> permutation = getRandomPermutation(val);
127 
128         /// Store original order and verify that the shuffle was applied
129         /// correctly.
130         auto useIDs = getUseIDs(val);
131 
132         /// Apply shuffle to the uselist.
133         val.shuffleUseList(permutation);
134 
135         /// Get the new order and verify the shuffle happened correctly.
136         auto permutedIDs = getUseIDs(val);
137         if (permutedIDs.size() != useIDs.size())
138           return failure();
139         for (size_t idx = 0; idx < permutation.size(); idx++)
140           if (useIDs[idx] != permutedIDs[permutation[idx]])
141             return failure();
142 
143         referenceUseListOrder.try_emplace(
144             valueID++, llvm::map_range(val.getUses(), [&](auto &use) {
145               return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
146             }));
147       }
148       return success();
149     };
150 
151     return walkOverValues(topLevelOp, doShuffleForRange);
152   }
153 
verifyUseListOrders(Operation * topLevelOp)154   LogicalResult verifyUseListOrders(Operation *topLevelOp) {
155     uint32_t valueID = 0;
156     /// Check that the use-list for the value range matches the one stored in
157     /// the reference.
158     auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
159       for (auto val : range) {
160         if (val.use_empty() || val.hasOneUse())
161           continue;
162         auto referenceOrder = referenceUseListOrder.at(valueID++);
163         for (auto [use, referenceID] :
164              llvm::zip(val.getUses(), referenceOrder)) {
165           uint64_t uniqueID =
166               bytecode::getUseID(use, opNumbering.at(use.getOwner()));
167           if (uniqueID != referenceID) {
168             use.getOwner()->emitError()
169                 << "found use-list order mismatch for value: " << val;
170             return failure();
171           }
172         }
173       }
174       return success();
175     };
176 
177     return walkOverValues(topLevelOp, doValidationForRange);
178   }
179 
180   /// Walk over blocks and operations and execute a callable over the ranges of
181   /// operands/results respectively.
182   template <typename FuncT>
walkOverValues(Operation * topLevelOp,FuncT callable)183   LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
184     auto blockWalk = topLevelOp->walk([&](Block *block) {
185       if (failed(callable(block->getArguments())))
186         return WalkResult::interrupt();
187       return WalkResult::advance();
188     });
189 
190     if (blockWalk.wasInterrupted())
191       return failure();
192 
193     auto resultsWalk = topLevelOp->walk([&](Operation *op) {
194       if (failed(callable(op->getResults())))
195         return WalkResult::interrupt();
196       return WalkResult::advance();
197     });
198 
199     return failure(resultsWalk.wasInterrupted());
200   }
201 
202   /// Creates a random permutation of the uselist order chain of the provided
203   /// value.
getRandomPermutation(Value value)204   SmallVector<unsigned> getRandomPermutation(Value value) {
205     size_t numUses = std::distance(value.use_begin(), value.use_end());
206     SmallVector<unsigned> permutation(numUses);
207     unsigned zero = 0;
208     std::iota(permutation.begin(), permutation.end(), zero);
209     std::shuffle(permutation.begin(), permutation.end(), rng);
210     return permutation;
211   }
212 
213   /// Map each value to its use-list order encoded with unique use IDs.
214   DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
215 
216   /// Map each operation to its global ID.
217   DenseMap<Operation *, uint32_t> opNumbering;
218 
219   std::default_random_engine rng;
220 };
221 } // namespace
222 
223 namespace mlir {
registerTestPreserveUseListOrders()224 void registerTestPreserveUseListOrders() {
225   PassRegistration<TestPreserveUseListOrders>();
226 }
227 } // namespace mlir
228