161278191SMatteo Franciolini //===- TestPrintDefUse.cpp - Passes to illustrate the IR def-use chains ---===//
261278191SMatteo Franciolini //
361278191SMatteo Franciolini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
461278191SMatteo Franciolini // See https://llvm.org/LICENSE.txt for license information.
561278191SMatteo Franciolini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
661278191SMatteo Franciolini //
761278191SMatteo Franciolini //===----------------------------------------------------------------------===//
861278191SMatteo Franciolini
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <numeric>
#include <random>
#include <utility>
1861278191SMatteo Franciolini
1961278191SMatteo Franciolini using namespace mlir;
2061278191SMatteo Franciolini
2161278191SMatteo Franciolini namespace {
2261278191SMatteo Franciolini /// This pass tests that:
2361278191SMatteo Franciolini /// 1) we can shuffle use-lists correctly;
2461278191SMatteo Franciolini /// 2) use-list orders are preserved after a roundtrip to bytecode.
2561278191SMatteo Franciolini class TestPreserveUseListOrders
2661278191SMatteo Franciolini : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
2761278191SMatteo Franciolini public:
2861278191SMatteo Franciolini MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
2961278191SMatteo Franciolini
3061278191SMatteo Franciolini TestPreserveUseListOrders() = default;
TestPreserveUseListOrders(const TestPreserveUseListOrders & pass)3161278191SMatteo Franciolini TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
3261278191SMatteo Franciolini : PassWrapper(pass) {}
getArgument() const3361278191SMatteo Franciolini StringRef getArgument() const final { return "test-verify-uselistorder"; }
getDescription() const3461278191SMatteo Franciolini StringRef getDescription() const final {
3561278191SMatteo Franciolini return "Verify that roundtripping the IR to bytecode preserves the order "
3661278191SMatteo Franciolini "of the uselists";
3761278191SMatteo Franciolini }
3861278191SMatteo Franciolini Option<unsigned> rngSeed{*this, "rng-seed",
3961278191SMatteo Franciolini llvm::cl::desc("Specify an input random seed"),
4061278191SMatteo Franciolini llvm::cl::init(1)};
41f58d2cc5SMehdi Amini
initialize(MLIRContext * context)42f58d2cc5SMehdi Amini LogicalResult initialize(MLIRContext *context) override {
43f58d2cc5SMehdi Amini rng.seed(static_cast<unsigned>(rngSeed));
44f58d2cc5SMehdi Amini return success();
45f58d2cc5SMehdi Amini }
46f58d2cc5SMehdi Amini
runOnOperation()4761278191SMatteo Franciolini void runOnOperation() override {
4861278191SMatteo Franciolini // Clone the module so that we can plug in this pass to any other
4961278191SMatteo Franciolini // independently.
50*f6fb639cSMehdi Amini OwningOpRef<ModuleOp> cloneModule = getOperation().clone();
5161278191SMatteo Franciolini
5261278191SMatteo Franciolini // 1. Compute the op numbering of the module.
53*f6fb639cSMehdi Amini computeOpNumbering(*cloneModule);
5461278191SMatteo Franciolini
5561278191SMatteo Franciolini // 2. Loop over all the values and shuffle the uses. While doing so, check
5661278191SMatteo Franciolini // that each shuffle is correct.
57*f6fb639cSMehdi Amini if (failed(shuffleUses(*cloneModule)))
5861278191SMatteo Franciolini return signalPassFailure();
5961278191SMatteo Franciolini
6061278191SMatteo Franciolini // 3. Do a bytecode roundtrip to version 3, which supports use-list order
6161278191SMatteo Franciolini // preservation.
62*f6fb639cSMehdi Amini auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3);
6361278191SMatteo Franciolini // If the bytecode roundtrip failed, try to roundtrip the original module
6461278191SMatteo Franciolini // to version 2, which does not support use-list. If this also fails, the
6561278191SMatteo Franciolini // original module had an issue unrelated to uselists.
6661278191SMatteo Franciolini if (failed(roundtripModuleOr)) {
6761278191SMatteo Franciolini auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
6861278191SMatteo Franciolini if (failed(testModuleOr))
6961278191SMatteo Franciolini return;
7061278191SMatteo Franciolini
7161278191SMatteo Franciolini return signalPassFailure();
7261278191SMatteo Franciolini }
7361278191SMatteo Franciolini
7461278191SMatteo Franciolini // 4. Recompute the op numbering on the new module. The numbering should be
7561278191SMatteo Franciolini // the same as (1), but on the new operation pointers.
7661278191SMatteo Franciolini computeOpNumbering(roundtripModuleOr->get());
7761278191SMatteo Franciolini
7861278191SMatteo Franciolini // 5. Loop over all the values and verify that the use-list is consistent
7961278191SMatteo Franciolini // with the post-shuffle order of step (2).
8061278191SMatteo Franciolini if (failed(verifyUseListOrders(roundtripModuleOr->get())))
8161278191SMatteo Franciolini return signalPassFailure();
8261278191SMatteo Franciolini }
8361278191SMatteo Franciolini
8461278191SMatteo Franciolini private:
doRoundtripToBytecode(Operation * module,uint32_t version)8561278191SMatteo Franciolini FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
8661278191SMatteo Franciolini uint32_t version) {
8761278191SMatteo Franciolini std::string str;
8861278191SMatteo Franciolini llvm::raw_string_ostream m(str);
8961278191SMatteo Franciolini BytecodeWriterConfig config;
9061278191SMatteo Franciolini config.setDesiredBytecodeVersion(version);
9161278191SMatteo Franciolini if (failed(writeBytecodeToFile(module, m, config)))
9261278191SMatteo Franciolini return failure();
9361278191SMatteo Franciolini
9461278191SMatteo Franciolini ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
9561278191SMatteo Franciolini auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
9661278191SMatteo Franciolini if (!newModuleOp.get())
9761278191SMatteo Franciolini return failure();
9861278191SMatteo Franciolini return newModuleOp;
9961278191SMatteo Franciolini }
10061278191SMatteo Franciolini
10161278191SMatteo Franciolini /// Compute an ordered numbering for all the operations in the IR.
computeOpNumbering(Operation * topLevelOp)10261278191SMatteo Franciolini void computeOpNumbering(Operation *topLevelOp) {
10361278191SMatteo Franciolini uint32_t operationID = 0;
10461278191SMatteo Franciolini opNumbering.clear();
10561278191SMatteo Franciolini topLevelOp->walk<mlir::WalkOrder::PreOrder>(
10661278191SMatteo Franciolini [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
10761278191SMatteo Franciolini }
10861278191SMatteo Franciolini
10961278191SMatteo Franciolini template <typename ValueT>
getUseIDs(ValueT val)11061278191SMatteo Franciolini SmallVector<uint64_t> getUseIDs(ValueT val) {
11161278191SMatteo Franciolini return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
11261278191SMatteo Franciolini return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
11361278191SMatteo Franciolini }));
11461278191SMatteo Franciolini }
11561278191SMatteo Franciolini
shuffleUses(Operation * topLevelOp)11661278191SMatteo Franciolini LogicalResult shuffleUses(Operation *topLevelOp) {
11761278191SMatteo Franciolini uint32_t valueID = 0;
11861278191SMatteo Franciolini /// Permute randomly the use-list of each value. It is guaranteed that at
11961278191SMatteo Franciolini /// least one pair of the use list is permuted.
12061278191SMatteo Franciolini auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
12161278191SMatteo Franciolini for (auto val : range) {
12261278191SMatteo Franciolini if (val.use_empty() || val.hasOneUse())
12361278191SMatteo Franciolini continue;
12461278191SMatteo Franciolini
12561278191SMatteo Franciolini /// Get a valid index permutation for the uses of value.
12661278191SMatteo Franciolini SmallVector<unsigned> permutation = getRandomPermutation(val);
12761278191SMatteo Franciolini
12861278191SMatteo Franciolini /// Store original order and verify that the shuffle was applied
12961278191SMatteo Franciolini /// correctly.
13061278191SMatteo Franciolini auto useIDs = getUseIDs(val);
13161278191SMatteo Franciolini
13261278191SMatteo Franciolini /// Apply shuffle to the uselist.
13361278191SMatteo Franciolini val.shuffleUseList(permutation);
13461278191SMatteo Franciolini
13561278191SMatteo Franciolini /// Get the new order and verify the shuffle happened correctly.
13661278191SMatteo Franciolini auto permutedIDs = getUseIDs(val);
13761278191SMatteo Franciolini if (permutedIDs.size() != useIDs.size())
13861278191SMatteo Franciolini return failure();
13961278191SMatteo Franciolini for (size_t idx = 0; idx < permutation.size(); idx++)
14061278191SMatteo Franciolini if (useIDs[idx] != permutedIDs[permutation[idx]])
14161278191SMatteo Franciolini return failure();
14261278191SMatteo Franciolini
14361278191SMatteo Franciolini referenceUseListOrder.try_emplace(
14461278191SMatteo Franciolini valueID++, llvm::map_range(val.getUses(), [&](auto &use) {
14561278191SMatteo Franciolini return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
14661278191SMatteo Franciolini }));
14761278191SMatteo Franciolini }
14861278191SMatteo Franciolini return success();
14961278191SMatteo Franciolini };
15061278191SMatteo Franciolini
15161278191SMatteo Franciolini return walkOverValues(topLevelOp, doShuffleForRange);
15261278191SMatteo Franciolini }
15361278191SMatteo Franciolini
verifyUseListOrders(Operation * topLevelOp)15461278191SMatteo Franciolini LogicalResult verifyUseListOrders(Operation *topLevelOp) {
15561278191SMatteo Franciolini uint32_t valueID = 0;
15661278191SMatteo Franciolini /// Check that the use-list for the value range matches the one stored in
15761278191SMatteo Franciolini /// the reference.
15861278191SMatteo Franciolini auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
15961278191SMatteo Franciolini for (auto val : range) {
16061278191SMatteo Franciolini if (val.use_empty() || val.hasOneUse())
16161278191SMatteo Franciolini continue;
16261278191SMatteo Franciolini auto referenceOrder = referenceUseListOrder.at(valueID++);
16361278191SMatteo Franciolini for (auto [use, referenceID] :
16461278191SMatteo Franciolini llvm::zip(val.getUses(), referenceOrder)) {
16561278191SMatteo Franciolini uint64_t uniqueID =
16661278191SMatteo Franciolini bytecode::getUseID(use, opNumbering.at(use.getOwner()));
16761278191SMatteo Franciolini if (uniqueID != referenceID) {
16861278191SMatteo Franciolini use.getOwner()->emitError()
16961278191SMatteo Franciolini << "found use-list order mismatch for value: " << val;
17061278191SMatteo Franciolini return failure();
17161278191SMatteo Franciolini }
17261278191SMatteo Franciolini }
17361278191SMatteo Franciolini }
17461278191SMatteo Franciolini return success();
17561278191SMatteo Franciolini };
17661278191SMatteo Franciolini
17761278191SMatteo Franciolini return walkOverValues(topLevelOp, doValidationForRange);
17861278191SMatteo Franciolini }
17961278191SMatteo Franciolini
18061278191SMatteo Franciolini /// Walk over blocks and operations and execute a callable over the ranges of
18161278191SMatteo Franciolini /// operands/results respectively.
18261278191SMatteo Franciolini template <typename FuncT>
walkOverValues(Operation * topLevelOp,FuncT callable)18361278191SMatteo Franciolini LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
18461278191SMatteo Franciolini auto blockWalk = topLevelOp->walk([&](Block *block) {
18561278191SMatteo Franciolini if (failed(callable(block->getArguments())))
18661278191SMatteo Franciolini return WalkResult::interrupt();
18761278191SMatteo Franciolini return WalkResult::advance();
18861278191SMatteo Franciolini });
18961278191SMatteo Franciolini
19061278191SMatteo Franciolini if (blockWalk.wasInterrupted())
19161278191SMatteo Franciolini return failure();
19261278191SMatteo Franciolini
19361278191SMatteo Franciolini auto resultsWalk = topLevelOp->walk([&](Operation *op) {
19461278191SMatteo Franciolini if (failed(callable(op->getResults())))
19561278191SMatteo Franciolini return WalkResult::interrupt();
19661278191SMatteo Franciolini return WalkResult::advance();
19761278191SMatteo Franciolini });
19861278191SMatteo Franciolini
19961278191SMatteo Franciolini return failure(resultsWalk.wasInterrupted());
20061278191SMatteo Franciolini }
20161278191SMatteo Franciolini
20261278191SMatteo Franciolini /// Creates a random permutation of the uselist order chain of the provided
20361278191SMatteo Franciolini /// value.
getRandomPermutation(Value value)20461278191SMatteo Franciolini SmallVector<unsigned> getRandomPermutation(Value value) {
20561278191SMatteo Franciolini size_t numUses = std::distance(value.use_begin(), value.use_end());
20661278191SMatteo Franciolini SmallVector<unsigned> permutation(numUses);
20761278191SMatteo Franciolini unsigned zero = 0;
20861278191SMatteo Franciolini std::iota(permutation.begin(), permutation.end(), zero);
20961278191SMatteo Franciolini std::shuffle(permutation.begin(), permutation.end(), rng);
21061278191SMatteo Franciolini return permutation;
21161278191SMatteo Franciolini }
21261278191SMatteo Franciolini
21361278191SMatteo Franciolini /// Map each value to its use-list order encoded with unique use IDs.
21461278191SMatteo Franciolini DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
21561278191SMatteo Franciolini
21661278191SMatteo Franciolini /// Map each operation to its global ID.
21761278191SMatteo Franciolini DenseMap<Operation *, uint32_t> opNumbering;
218f58d2cc5SMehdi Amini
219f58d2cc5SMehdi Amini std::default_random_engine rng;
22061278191SMatteo Franciolini };
22161278191SMatteo Franciolini } // namespace
22261278191SMatteo Franciolini
22361278191SMatteo Franciolini namespace mlir {
registerTestPreserveUseListOrders()22461278191SMatteo Franciolini void registerTestPreserveUseListOrders() {
22561278191SMatteo Franciolini PassRegistration<TestPreserveUseListOrders>();
22661278191SMatteo Franciolini }
22761278191SMatteo Franciolini } // namespace mlir
228