1c095afcbSMogball //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===// 2c095afcbSMogball // 3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information. 5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c095afcbSMogball // 7c095afcbSMogball //===----------------------------------------------------------------------===// 8c095afcbSMogball 9c095afcbSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 108a515180SMehdi Amini #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 118a515180SMehdi Amini #include "mlir/IR/BuiltinAttributes.h" 129432fbfeSMogball #include "mlir/IR/OpDefinition.h" 138a515180SMehdi Amini #include "mlir/IR/Operation.h" 148a515180SMehdi Amini #include "mlir/IR/Value.h" 158a515180SMehdi Amini #include "mlir/Support/LLVM.h" 168a515180SMehdi Amini #include "llvm/ADT/STLExtras.h" 178a515180SMehdi Amini #include "llvm/Support/Casting.h" 189432fbfeSMogball #include "llvm/Support/Debug.h" 198a515180SMehdi Amini #include <cassert> 209432fbfeSMogball 219432fbfeSMogball #define DEBUG_TYPE "constant-propagation" 22c095afcbSMogball 23c095afcbSMogball using namespace mlir; 24c095afcbSMogball using namespace mlir::dataflow; 25c095afcbSMogball 26c095afcbSMogball //===----------------------------------------------------------------------===// 27c095afcbSMogball // ConstantValue 28c095afcbSMogball //===----------------------------------------------------------------------===// 29c095afcbSMogball 30c095afcbSMogball void ConstantValue::print(raw_ostream &os) const { 3147bf3e38SZhixun Tan if (isUninitialized()) { 3247bf3e38SZhixun Tan os << "<UNINITIALIZED>"; 3347bf3e38SZhixun Tan return; 3447bf3e38SZhixun Tan } 3547bf3e38SZhixun Tan if (getConstantValue() == nullptr) { 3647bf3e38SZhixun Tan os << "<UNKNOWN>"; 3747bf3e38SZhixun Tan return; 3847bf3e38SZhixun Tan } 3947bf3e38SZhixun Tan return getConstantValue().print(os); 40c095afcbSMogball } 419432fbfeSMogball 429432fbfeSMogball //===----------------------------------------------------------------------===// 439432fbfeSMogball // SparseConstantPropagation 449432fbfeSMogball //===----------------------------------------------------------------------===// 459432fbfeSMogball 4615e915a4SIvan Butygin LogicalResult SparseConstantPropagation::visitOperation( 479432fbfeSMogball Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, 489432fbfeSMogball ArrayRef<Lattice<ConstantValue> *> results) { 499432fbfeSMogball LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); 509432fbfeSMogball 519432fbfeSMogball // Don't try to simulate the results of a region operation as we can't 529432fbfeSMogball // guarantee that folding will be out-of-place. We don't allow in-place 539432fbfeSMogball // folds as the desire here is for simulated execution, and not general 549432fbfeSMogball // folding. 55052669e7SJacques Pienaar if (op->getNumRegions()) { 56de0ebc52SZhixun Tan setAllToEntryStates(results); 5715e915a4SIvan Butygin return success(); 58052669e7SJacques Pienaar } 599432fbfeSMogball 609432fbfeSMogball SmallVector<Attribute, 8> constantOperands; 619432fbfeSMogball constantOperands.reserve(op->getNumOperands()); 6247bf3e38SZhixun Tan for (auto *operandLattice : operands) { 6347bf3e38SZhixun Tan if (operandLattice->getValue().isUninitialized()) 6415e915a4SIvan Butygin return success(); 659432fbfeSMogball constantOperands.push_back(operandLattice->getValue().getConstantValue()); 6647bf3e38SZhixun Tan } 679432fbfeSMogball 689432fbfeSMogball // Save the original operands and attributes just in case the operation 699432fbfeSMogball // folds in-place. The constant passed in may not correspond to the real 709432fbfeSMogball // runtime value, so in-place updates are not allowed. 719432fbfeSMogball SmallVector<Value, 8> originalOperands(op->getOperands()); 729432fbfeSMogball DictionaryAttr originalAttrs = op->getAttrDictionary(); 739432fbfeSMogball 749432fbfeSMogball // Simulate the result of folding this operation to a constant. If folding 759432fbfeSMogball // fails or was an in-place fold, mark the results as overdefined. 769432fbfeSMogball SmallVector<OpFoldResult, 8> foldResults; 779432fbfeSMogball foldResults.reserve(op->getNumResults()); 789432fbfeSMogball if (failed(op->fold(constantOperands, foldResults))) { 79de0ebc52SZhixun Tan setAllToEntryStates(results); 8015e915a4SIvan Butygin return success(); 819432fbfeSMogball } 829432fbfeSMogball 839432fbfeSMogball // If the folding was in-place, mark the results as overdefined and reset 849432fbfeSMogball // the operation. We don't allow in-place folds as the desire here is for 859432fbfeSMogball // simulated execution, and not general folding. 869432fbfeSMogball if (foldResults.empty()) { 879432fbfeSMogball op->setOperands(originalOperands); 889432fbfeSMogball op->setAttrs(originalAttrs); 89de0ebc52SZhixun Tan setAllToEntryStates(results); 9015e915a4SIvan Butygin return success(); 919432fbfeSMogball } 929432fbfeSMogball 939432fbfeSMogball // Merge the fold results into the lattice for this operation. 949432fbfeSMogball assert(foldResults.size() == op->getNumResults() && "invalid result size"); 959432fbfeSMogball for (const auto it : llvm::zip(results, foldResults)) { 969432fbfeSMogball Lattice<ConstantValue> *lattice = std::get<0>(it); 979432fbfeSMogball 989432fbfeSMogball // Merge in the result of the fold, either a constant or a value. 999432fbfeSMogball OpFoldResult foldResult = std::get<1>(it); 10068f58812STres Popp if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { 1019432fbfeSMogball LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); 1029432fbfeSMogball propagateIfChanged(lattice, 1039432fbfeSMogball lattice->join(ConstantValue(attr, op->getDialect()))); 1049432fbfeSMogball } else { 1059432fbfeSMogball LLVM_DEBUG(llvm::dbgs() 106*b5c5c2b2SKazu Hirata << "Folded to value: " << cast<Value>(foldResult) << "\n"); 107b2b7efb9SAlex Zinenko AbstractSparseForwardDataFlowAnalysis::join( 108*b5c5c2b2SKazu Hirata lattice, *getLatticeElement(cast<Value>(foldResult))); 1099432fbfeSMogball } 1109432fbfeSMogball } 11115e915a4SIvan Butygin return success(); 1129432fbfeSMogball } 113de0ebc52SZhixun Tan 114de0ebc52SZhixun Tan void SparseConstantPropagation::setToEntryState( 115de0ebc52SZhixun Tan Lattice<ConstantValue> *lattice) { 116de0ebc52SZhixun Tan propagateIfChanged(lattice, 117de0ebc52SZhixun Tan lattice->join(ConstantValue::getUnknownConstant())); 118de0ebc52SZhixun Tan } 119