xref: /llvm-project/mlir/lib/Transforms/SCCP.cpp (revision 34a65980d7d2e1b05e3fc88535cafe606ee55e04)
1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Transforms/Passes.h"
18 
19 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
20 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/Interfaces/SideEffectInterfaces.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/FoldUtils.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_SCCP
29 #include "mlir/Transforms/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::dataflow;
34 
35 //===----------------------------------------------------------------------===//
36 // SCCP Rewrites
37 //===----------------------------------------------------------------------===//
38 
39 /// Replace the given value with a constant if the corresponding lattice
40 /// represents a constant. Returns success if the value was replaced, failure
41 /// otherwise.
replaceWithConstant(DataFlowSolver & solver,OpBuilder & builder,OperationFolder & folder,Value value)42 static LogicalResult replaceWithConstant(DataFlowSolver &solver,
43                                          OpBuilder &builder,
44                                          OperationFolder &folder, Value value) {
45   auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
46   if (!lattice || lattice->getValue().isUninitialized())
47     return failure();
48   const ConstantValue &latticeValue = lattice->getValue();
49   if (!latticeValue.getConstantValue())
50     return failure();
51 
52   // Attempt to materialize a constant for the given value.
53   Dialect *dialect = latticeValue.getConstantDialect();
54   Value constant = folder.getOrCreateConstant(
55       builder.getInsertionBlock(), dialect, latticeValue.getConstantValue(),
56       value.getType());
57   if (!constant)
58     return failure();
59 
60   value.replaceAllUsesWith(constant);
61   return success();
62 }
63 
64 /// Rewrite the given regions using the computing analysis. This replaces the
65 /// uses of all values that have been computed to be constant, and erases as
66 /// many newly dead operations.
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)67 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
68                     MutableArrayRef<Region> initialRegions) {
69   SmallVector<Block *> worklist;
70   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
71     for (Region &region : regions)
72       for (Block &block : llvm::reverse(region))
73         worklist.push_back(&block);
74   };
75 
76   // An operation folder used to create and unique constants.
77   OperationFolder folder(context);
78   OpBuilder builder(context);
79 
80   addToWorklist(initialRegions);
81   while (!worklist.empty()) {
82     Block *block = worklist.pop_back_val();
83 
84     for (Operation &op : llvm::make_early_inc_range(*block)) {
85       builder.setInsertionPoint(&op);
86 
87       // Replace any result with constants.
88       bool replacedAll = op.getNumResults() != 0;
89       for (Value res : op.getResults())
90         replacedAll &=
91             succeeded(replaceWithConstant(solver, builder, folder, res));
92 
93       // If all of the results of the operation were replaced, try to erase
94       // the operation completely.
95       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
96         assert(op.use_empty() && "expected all uses to be replaced");
97         op.erase();
98         continue;
99       }
100 
101       // Add any the regions of this operation to the worklist.
102       addToWorklist(op.getRegions());
103     }
104 
105     // Replace any block arguments with constants.
106     builder.setInsertionPointToStart(block);
107     for (BlockArgument arg : block->getArguments())
108       (void)replaceWithConstant(solver, builder, folder, arg);
109   }
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // SCCP Pass
114 //===----------------------------------------------------------------------===//
115 
116 namespace {
117 struct SCCP : public impl::SCCPBase<SCCP> {
118   void runOnOperation() override;
119 };
120 } // namespace
121 
runOnOperation()122 void SCCP::runOnOperation() {
123   Operation *op = getOperation();
124 
125   DataFlowSolver solver;
126   solver.load<DeadCodeAnalysis>();
127   solver.load<SparseConstantPropagation>();
128   if (failed(solver.initializeAndRun(op)))
129     return signalPassFailure();
130   rewrite(solver, op->getContext(), op->getRegions());
131 }
132 
createSCCPPass()133 std::unique_ptr<Pass> mlir::createSCCPPass() {
134   return std::make_unique<SCCP>();
135 }
136