xref: /llvm-project/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp (revision b5c5c2b26fd4bd0d0d237aaf77a01ca528810707)
1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/OpDefinition.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/Debug.h"
19 #include <cassert>
20 
21 #define DEBUG_TYPE "constant-propagation"
22 
23 using namespace mlir;
24 using namespace mlir::dataflow;
25 
26 //===----------------------------------------------------------------------===//
27 // ConstantValue
28 //===----------------------------------------------------------------------===//
29 
30 void ConstantValue::print(raw_ostream &os) const {
31   if (isUninitialized()) {
32     os << "<UNINITIALIZED>";
33     return;
34   }
35   if (getConstantValue() == nullptr) {
36     os << "<UNKNOWN>";
37     return;
38   }
39   return getConstantValue().print(os);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // SparseConstantPropagation
44 //===----------------------------------------------------------------------===//
45 
46 LogicalResult SparseConstantPropagation::visitOperation(
47     Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
48     ArrayRef<Lattice<ConstantValue> *> results) {
49   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
50 
51   // Don't try to simulate the results of a region operation as we can't
52   // guarantee that folding will be out-of-place. We don't allow in-place
53   // folds as the desire here is for simulated execution, and not general
54   // folding.
55   if (op->getNumRegions()) {
56     setAllToEntryStates(results);
57     return success();
58   }
59 
60   SmallVector<Attribute, 8> constantOperands;
61   constantOperands.reserve(op->getNumOperands());
62   for (auto *operandLattice : operands) {
63     if (operandLattice->getValue().isUninitialized())
64       return success();
65     constantOperands.push_back(operandLattice->getValue().getConstantValue());
66   }
67 
68   // Save the original operands and attributes just in case the operation
69   // folds in-place. The constant passed in may not correspond to the real
70   // runtime value, so in-place updates are not allowed.
71   SmallVector<Value, 8> originalOperands(op->getOperands());
72   DictionaryAttr originalAttrs = op->getAttrDictionary();
73 
74   // Simulate the result of folding this operation to a constant. If folding
75   // fails or was an in-place fold, mark the results as overdefined.
76   SmallVector<OpFoldResult, 8> foldResults;
77   foldResults.reserve(op->getNumResults());
78   if (failed(op->fold(constantOperands, foldResults))) {
79     setAllToEntryStates(results);
80     return success();
81   }
82 
83   // If the folding was in-place, mark the results as overdefined and reset
84   // the operation. We don't allow in-place folds as the desire here is for
85   // simulated execution, and not general folding.
86   if (foldResults.empty()) {
87     op->setOperands(originalOperands);
88     op->setAttrs(originalAttrs);
89     setAllToEntryStates(results);
90     return success();
91   }
92 
93   // Merge the fold results into the lattice for this operation.
94   assert(foldResults.size() == op->getNumResults() && "invalid result size");
95   for (const auto it : llvm::zip(results, foldResults)) {
96     Lattice<ConstantValue> *lattice = std::get<0>(it);
97 
98     // Merge in the result of the fold, either a constant or a value.
99     OpFoldResult foldResult = std::get<1>(it);
100     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
101       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
102       propagateIfChanged(lattice,
103                          lattice->join(ConstantValue(attr, op->getDialect())));
104     } else {
105       LLVM_DEBUG(llvm::dbgs()
106                  << "Folded to value: " << cast<Value>(foldResult) << "\n");
107       AbstractSparseForwardDataFlowAnalysis::join(
108           lattice, *getLatticeElement(cast<Value>(foldResult)));
109     }
110   }
111   return success();
112 }
113 
114 void SparseConstantPropagation::setToEntryState(
115     Lattice<ConstantValue> *lattice) {
116   propagateIfChanged(lattice,
117                      lattice->join(ConstantValue::getUnknownConstant()));
118 }
119