xref: /llvm-project/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp (revision 422b84a77167c43259e18cc3eff88b4b2530defc)
1 //===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
10 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
11 #include "llvm/Support/Debug.h"
12 
13 using namespace mlir;
14 using namespace LLVM;
15 
16 #define DEBUG_TYPE "llvm-di-expression-simplifier"
17 
18 //===----------------------------------------------------------------------===//
19 // DIExpressionRewriter
20 //===----------------------------------------------------------------------===//
21 
addPattern(std::unique_ptr<ExprRewritePattern> pattern)22 void DIExpressionRewriter::addPattern(
23     std::unique_ptr<ExprRewritePattern> pattern) {
24   patterns.emplace_back(std::move(pattern));
25 }
26 
27 DIExpressionAttr
simplify(DIExpressionAttr expr,std::optional<uint64_t> maxNumRewrites) const28 DIExpressionRewriter::simplify(DIExpressionAttr expr,
29                                std::optional<uint64_t> maxNumRewrites) const {
30   ArrayRef<OperatorT> operators = expr.getOperations();
31 
32   // `inputs` contains the unprocessed postfix of operators.
33   // `result` contains the already finalized prefix of operators.
34   // Invariant: concat(result, inputs) is equivalent to `operators` after some
35   // application of the rewrite patterns.
36   // Using a deque for inputs so that we have efficient front insertion and
37   // removal. Random access is not necessary for patterns.
38   std::deque<OperatorT> inputs(operators.begin(), operators.end());
39   SmallVector<OperatorT> result;
40 
41   uint64_t numRewrites = 0;
42   while (!inputs.empty() &&
43          (!maxNumRewrites || numRewrites < *maxNumRewrites)) {
44     bool foundMatch = false;
45     for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
46       ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
47       if (matchEnd == inputs.begin())
48         continue;
49 
50       foundMatch = true;
51       SmallVector<OperatorT> replacement =
52           pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
53       inputs.erase(inputs.begin(), matchEnd);
54       inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
55       ++numRewrites;
56       break;
57     }
58 
59     if (!foundMatch) {
60       // If no match, pass along the current operator.
61       result.push_back(inputs.front());
62       inputs.pop_front();
63     }
64   }
65 
66   if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
67     LLVM_DEBUG(llvm::dbgs()
68                << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
69                << maxNumRewrites << ")\n");
70     // Skip rewriting the rest.
71     result.append(inputs.begin(), inputs.end());
72   }
73 
74   return LLVM::DIExpressionAttr::get(expr.getContext(), result);
75 }
76