xref: /llvm-project/mlir/test/lib/IR/TestUseListOrders.cpp (revision f6fb639c76ce255e2f5b3e2c8550270243e7e7ab)
//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
861278191SMatteo Franciolini 
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <utility>

using namespace mlir;
2061278191SMatteo Franciolini 
2161278191SMatteo Franciolini namespace {
2261278191SMatteo Franciolini /// This pass tests that:
2361278191SMatteo Franciolini /// 1) we can shuffle use-lists correctly;
2461278191SMatteo Franciolini /// 2) use-list orders are preserved after a roundtrip to bytecode.
2561278191SMatteo Franciolini class TestPreserveUseListOrders
2661278191SMatteo Franciolini     : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
2761278191SMatteo Franciolini public:
2861278191SMatteo Franciolini   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
2961278191SMatteo Franciolini 
3061278191SMatteo Franciolini   TestPreserveUseListOrders() = default;
TestPreserveUseListOrders(const TestPreserveUseListOrders & pass)3161278191SMatteo Franciolini   TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
3261278191SMatteo Franciolini       : PassWrapper(pass) {}
getArgument() const3361278191SMatteo Franciolini   StringRef getArgument() const final { return "test-verify-uselistorder"; }
getDescription() const3461278191SMatteo Franciolini   StringRef getDescription() const final {
3561278191SMatteo Franciolini     return "Verify that roundtripping the IR to bytecode preserves the order "
3661278191SMatteo Franciolini            "of the uselists";
3761278191SMatteo Franciolini   }
3861278191SMatteo Franciolini   Option<unsigned> rngSeed{*this, "rng-seed",
3961278191SMatteo Franciolini                            llvm::cl::desc("Specify an input random seed"),
4061278191SMatteo Franciolini                            llvm::cl::init(1)};
41f58d2cc5SMehdi Amini 
initialize(MLIRContext * context)42f58d2cc5SMehdi Amini   LogicalResult initialize(MLIRContext *context) override {
43f58d2cc5SMehdi Amini     rng.seed(static_cast<unsigned>(rngSeed));
44f58d2cc5SMehdi Amini     return success();
45f58d2cc5SMehdi Amini   }
46f58d2cc5SMehdi Amini 
runOnOperation()4761278191SMatteo Franciolini   void runOnOperation() override {
4861278191SMatteo Franciolini     // Clone the module so that we can plug in this pass to any other
4961278191SMatteo Franciolini     // independently.
50*f6fb639cSMehdi Amini     OwningOpRef<ModuleOp> cloneModule = getOperation().clone();
5161278191SMatteo Franciolini 
5261278191SMatteo Franciolini     // 1. Compute the op numbering of the module.
53*f6fb639cSMehdi Amini     computeOpNumbering(*cloneModule);
5461278191SMatteo Franciolini 
5561278191SMatteo Franciolini     // 2. Loop over all the values and shuffle the uses. While doing so, check
5661278191SMatteo Franciolini     // that each shuffle is correct.
57*f6fb639cSMehdi Amini     if (failed(shuffleUses(*cloneModule)))
5861278191SMatteo Franciolini       return signalPassFailure();
5961278191SMatteo Franciolini 
6061278191SMatteo Franciolini     // 3. Do a bytecode roundtrip to version 3, which supports use-list order
6161278191SMatteo Franciolini     // preservation.
62*f6fb639cSMehdi Amini     auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3);
6361278191SMatteo Franciolini     // If the bytecode roundtrip failed, try to roundtrip the original module
6461278191SMatteo Franciolini     // to version 2, which does not support use-list. If this also fails, the
6561278191SMatteo Franciolini     // original module had an issue unrelated to uselists.
6661278191SMatteo Franciolini     if (failed(roundtripModuleOr)) {
6761278191SMatteo Franciolini       auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
6861278191SMatteo Franciolini       if (failed(testModuleOr))
6961278191SMatteo Franciolini         return;
7061278191SMatteo Franciolini 
7161278191SMatteo Franciolini       return signalPassFailure();
7261278191SMatteo Franciolini     }
7361278191SMatteo Franciolini 
7461278191SMatteo Franciolini     // 4. Recompute the op numbering on the new module. The numbering should be
7561278191SMatteo Franciolini     // the same as (1), but on the new operation pointers.
7661278191SMatteo Franciolini     computeOpNumbering(roundtripModuleOr->get());
7761278191SMatteo Franciolini 
7861278191SMatteo Franciolini     // 5. Loop over all the values and verify that the use-list is consistent
7961278191SMatteo Franciolini     // with the post-shuffle order of step (2).
8061278191SMatteo Franciolini     if (failed(verifyUseListOrders(roundtripModuleOr->get())))
8161278191SMatteo Franciolini       return signalPassFailure();
8261278191SMatteo Franciolini   }
8361278191SMatteo Franciolini 
8461278191SMatteo Franciolini private:
doRoundtripToBytecode(Operation * module,uint32_t version)8561278191SMatteo Franciolini   FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
8661278191SMatteo Franciolini                                                             uint32_t version) {
8761278191SMatteo Franciolini     std::string str;
8861278191SMatteo Franciolini     llvm::raw_string_ostream m(str);
8961278191SMatteo Franciolini     BytecodeWriterConfig config;
9061278191SMatteo Franciolini     config.setDesiredBytecodeVersion(version);
9161278191SMatteo Franciolini     if (failed(writeBytecodeToFile(module, m, config)))
9261278191SMatteo Franciolini       return failure();
9361278191SMatteo Franciolini 
9461278191SMatteo Franciolini     ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
9561278191SMatteo Franciolini     auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
9661278191SMatteo Franciolini     if (!newModuleOp.get())
9761278191SMatteo Franciolini       return failure();
9861278191SMatteo Franciolini     return newModuleOp;
9961278191SMatteo Franciolini   }
10061278191SMatteo Franciolini 
10161278191SMatteo Franciolini   /// Compute an ordered numbering for all the operations in the IR.
computeOpNumbering(Operation * topLevelOp)10261278191SMatteo Franciolini   void computeOpNumbering(Operation *topLevelOp) {
10361278191SMatteo Franciolini     uint32_t operationID = 0;
10461278191SMatteo Franciolini     opNumbering.clear();
10561278191SMatteo Franciolini     topLevelOp->walk<mlir::WalkOrder::PreOrder>(
10661278191SMatteo Franciolini         [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
10761278191SMatteo Franciolini   }
10861278191SMatteo Franciolini 
10961278191SMatteo Franciolini   template <typename ValueT>
getUseIDs(ValueT val)11061278191SMatteo Franciolini   SmallVector<uint64_t> getUseIDs(ValueT val) {
11161278191SMatteo Franciolini     return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
11261278191SMatteo Franciolini       return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
11361278191SMatteo Franciolini     }));
11461278191SMatteo Franciolini   }
11561278191SMatteo Franciolini 
shuffleUses(Operation * topLevelOp)11661278191SMatteo Franciolini   LogicalResult shuffleUses(Operation *topLevelOp) {
11761278191SMatteo Franciolini     uint32_t valueID = 0;
11861278191SMatteo Franciolini     /// Permute randomly the use-list of each value. It is guaranteed that at
11961278191SMatteo Franciolini     /// least one pair of the use list is permuted.
12061278191SMatteo Franciolini     auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
12161278191SMatteo Franciolini       for (auto val : range) {
12261278191SMatteo Franciolini         if (val.use_empty() || val.hasOneUse())
12361278191SMatteo Franciolini           continue;
12461278191SMatteo Franciolini 
12561278191SMatteo Franciolini         /// Get a valid index permutation for the uses of value.
12661278191SMatteo Franciolini         SmallVector<unsigned> permutation = getRandomPermutation(val);
12761278191SMatteo Franciolini 
12861278191SMatteo Franciolini         /// Store original order and verify that the shuffle was applied
12961278191SMatteo Franciolini         /// correctly.
13061278191SMatteo Franciolini         auto useIDs = getUseIDs(val);
13161278191SMatteo Franciolini 
13261278191SMatteo Franciolini         /// Apply shuffle to the uselist.
13361278191SMatteo Franciolini         val.shuffleUseList(permutation);
13461278191SMatteo Franciolini 
13561278191SMatteo Franciolini         /// Get the new order and verify the shuffle happened correctly.
13661278191SMatteo Franciolini         auto permutedIDs = getUseIDs(val);
13761278191SMatteo Franciolini         if (permutedIDs.size() != useIDs.size())
13861278191SMatteo Franciolini           return failure();
13961278191SMatteo Franciolini         for (size_t idx = 0; idx < permutation.size(); idx++)
14061278191SMatteo Franciolini           if (useIDs[idx] != permutedIDs[permutation[idx]])
14161278191SMatteo Franciolini             return failure();
14261278191SMatteo Franciolini 
14361278191SMatteo Franciolini         referenceUseListOrder.try_emplace(
14461278191SMatteo Franciolini             valueID++, llvm::map_range(val.getUses(), [&](auto &use) {
14561278191SMatteo Franciolini               return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
14661278191SMatteo Franciolini             }));
14761278191SMatteo Franciolini       }
14861278191SMatteo Franciolini       return success();
14961278191SMatteo Franciolini     };
15061278191SMatteo Franciolini 
15161278191SMatteo Franciolini     return walkOverValues(topLevelOp, doShuffleForRange);
15261278191SMatteo Franciolini   }
15361278191SMatteo Franciolini 
verifyUseListOrders(Operation * topLevelOp)15461278191SMatteo Franciolini   LogicalResult verifyUseListOrders(Operation *topLevelOp) {
15561278191SMatteo Franciolini     uint32_t valueID = 0;
15661278191SMatteo Franciolini     /// Check that the use-list for the value range matches the one stored in
15761278191SMatteo Franciolini     /// the reference.
15861278191SMatteo Franciolini     auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
15961278191SMatteo Franciolini       for (auto val : range) {
16061278191SMatteo Franciolini         if (val.use_empty() || val.hasOneUse())
16161278191SMatteo Franciolini           continue;
16261278191SMatteo Franciolini         auto referenceOrder = referenceUseListOrder.at(valueID++);
16361278191SMatteo Franciolini         for (auto [use, referenceID] :
16461278191SMatteo Franciolini              llvm::zip(val.getUses(), referenceOrder)) {
16561278191SMatteo Franciolini           uint64_t uniqueID =
16661278191SMatteo Franciolini               bytecode::getUseID(use, opNumbering.at(use.getOwner()));
16761278191SMatteo Franciolini           if (uniqueID != referenceID) {
16861278191SMatteo Franciolini             use.getOwner()->emitError()
16961278191SMatteo Franciolini                 << "found use-list order mismatch for value: " << val;
17061278191SMatteo Franciolini             return failure();
17161278191SMatteo Franciolini           }
17261278191SMatteo Franciolini         }
17361278191SMatteo Franciolini       }
17461278191SMatteo Franciolini       return success();
17561278191SMatteo Franciolini     };
17661278191SMatteo Franciolini 
17761278191SMatteo Franciolini     return walkOverValues(topLevelOp, doValidationForRange);
17861278191SMatteo Franciolini   }
17961278191SMatteo Franciolini 
18061278191SMatteo Franciolini   /// Walk over blocks and operations and execute a callable over the ranges of
18161278191SMatteo Franciolini   /// operands/results respectively.
18261278191SMatteo Franciolini   template <typename FuncT>
walkOverValues(Operation * topLevelOp,FuncT callable)18361278191SMatteo Franciolini   LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
18461278191SMatteo Franciolini     auto blockWalk = topLevelOp->walk([&](Block *block) {
18561278191SMatteo Franciolini       if (failed(callable(block->getArguments())))
18661278191SMatteo Franciolini         return WalkResult::interrupt();
18761278191SMatteo Franciolini       return WalkResult::advance();
18861278191SMatteo Franciolini     });
18961278191SMatteo Franciolini 
19061278191SMatteo Franciolini     if (blockWalk.wasInterrupted())
19161278191SMatteo Franciolini       return failure();
19261278191SMatteo Franciolini 
19361278191SMatteo Franciolini     auto resultsWalk = topLevelOp->walk([&](Operation *op) {
19461278191SMatteo Franciolini       if (failed(callable(op->getResults())))
19561278191SMatteo Franciolini         return WalkResult::interrupt();
19661278191SMatteo Franciolini       return WalkResult::advance();
19761278191SMatteo Franciolini     });
19861278191SMatteo Franciolini 
19961278191SMatteo Franciolini     return failure(resultsWalk.wasInterrupted());
20061278191SMatteo Franciolini   }
20161278191SMatteo Franciolini 
20261278191SMatteo Franciolini   /// Creates a random permutation of the uselist order chain of the provided
20361278191SMatteo Franciolini   /// value.
getRandomPermutation(Value value)20461278191SMatteo Franciolini   SmallVector<unsigned> getRandomPermutation(Value value) {
20561278191SMatteo Franciolini     size_t numUses = std::distance(value.use_begin(), value.use_end());
20661278191SMatteo Franciolini     SmallVector<unsigned> permutation(numUses);
20761278191SMatteo Franciolini     unsigned zero = 0;
20861278191SMatteo Franciolini     std::iota(permutation.begin(), permutation.end(), zero);
20961278191SMatteo Franciolini     std::shuffle(permutation.begin(), permutation.end(), rng);
21061278191SMatteo Franciolini     return permutation;
21161278191SMatteo Franciolini   }
21261278191SMatteo Franciolini 
21361278191SMatteo Franciolini   /// Map each value to its use-list order encoded with unique use IDs.
21461278191SMatteo Franciolini   DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
21561278191SMatteo Franciolini 
21661278191SMatteo Franciolini   /// Map each operation to its global ID.
21761278191SMatteo Franciolini   DenseMap<Operation *, uint32_t> opNumbering;
218f58d2cc5SMehdi Amini 
219f58d2cc5SMehdi Amini   std::default_random_engine rng;
22061278191SMatteo Franciolini };
22161278191SMatteo Franciolini } // namespace
22261278191SMatteo Franciolini 
22361278191SMatteo Franciolini namespace mlir {
registerTestPreserveUseListOrders()22461278191SMatteo Franciolini void registerTestPreserveUseListOrders() {
22561278191SMatteo Franciolini   PassRegistration<TestPreserveUseListOrders>();
22661278191SMatteo Franciolini }
22761278191SMatteo Franciolini } // namespace mlir
228