1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 10 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 11 #include "mlir/IR/BuiltinAttributes.h" 12 #include "mlir/IR/OpDefinition.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/IR/Value.h" 15 #include "mlir/Support/LLVM.h" 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/Support/Casting.h" 18 #include "llvm/Support/Debug.h" 19 #include <cassert> 20 21 #define DEBUG_TYPE "constant-propagation" 22 23 using namespace mlir; 24 using namespace mlir::dataflow; 25 26 //===----------------------------------------------------------------------===// 27 // ConstantValue 28 //===----------------------------------------------------------------------===// 29 30 void ConstantValue::print(raw_ostream &os) const { 31 if (isUninitialized()) { 32 os << "<UNINITIALIZED>"; 33 return; 34 } 35 if (getConstantValue() == nullptr) { 36 os << "<UNKNOWN>"; 37 return; 38 } 39 return getConstantValue().print(os); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // SparseConstantPropagation 44 //===----------------------------------------------------------------------===// 45 46 LogicalResult SparseConstantPropagation::visitOperation( 47 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, 48 ArrayRef<Lattice<ConstantValue> *> results) { 49 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); 50 51 // Don't try to simulate the results of a region operation as we can't 52 // guarantee that folding will be out-of-place. We don't allow in-place 53 // folds as the desire here is for simulated execution, and not general 54 // folding. 55 if (op->getNumRegions()) { 56 setAllToEntryStates(results); 57 return success(); 58 } 59 60 SmallVector<Attribute, 8> constantOperands; 61 constantOperands.reserve(op->getNumOperands()); 62 for (auto *operandLattice : operands) { 63 if (operandLattice->getValue().isUninitialized()) 64 return success(); 65 constantOperands.push_back(operandLattice->getValue().getConstantValue()); 66 } 67 68 // Save the original operands and attributes just in case the operation 69 // folds in-place. The constant passed in may not correspond to the real 70 // runtime value, so in-place updates are not allowed. 71 SmallVector<Value, 8> originalOperands(op->getOperands()); 72 DictionaryAttr originalAttrs = op->getAttrDictionary(); 73 74 // Simulate the result of folding this operation to a constant. If folding 75 // fails or was an in-place fold, mark the results as overdefined. 76 SmallVector<OpFoldResult, 8> foldResults; 77 foldResults.reserve(op->getNumResults()); 78 if (failed(op->fold(constantOperands, foldResults))) { 79 setAllToEntryStates(results); 80 return success(); 81 } 82 83 // If the folding was in-place, mark the results as overdefined and reset 84 // the operation. We don't allow in-place folds as the desire here is for 85 // simulated execution, and not general folding. 86 if (foldResults.empty()) { 87 op->setOperands(originalOperands); 88 op->setAttrs(originalAttrs); 89 setAllToEntryStates(results); 90 return success(); 91 } 92 93 // Merge the fold results into the lattice for this operation. 94 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 95 for (const auto it : llvm::zip(results, foldResults)) { 96 Lattice<ConstantValue> *lattice = std::get<0>(it); 97 98 // Merge in the result of the fold, either a constant or a value. 99 OpFoldResult foldResult = std::get<1>(it); 100 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { 101 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); 102 propagateIfChanged(lattice, 103 lattice->join(ConstantValue(attr, op->getDialect()))); 104 } else { 105 LLVM_DEBUG(llvm::dbgs() 106 << "Folded to value: " << cast<Value>(foldResult) << "\n"); 107 AbstractSparseForwardDataFlowAnalysis::join( 108 lattice, *getLatticeElement(cast<Value>(foldResult))); 109 } 110 } 111 return success(); 112 } 113 114 void SparseConstantPropagation::setToEntryState( 115 Lattice<ConstantValue> *lattice) { 116 propagateIfChanged(lattice, 117 lattice->join(ConstantValue::getUnknownConstant())); 118 } 119