xref: /llvm-project/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp (revision b5c5c2b26fd4bd0d0d237aaf77a01ca528810707)
1c095afcbSMogball //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball 
9c095afcbSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
108a515180SMehdi Amini #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
118a515180SMehdi Amini #include "mlir/IR/BuiltinAttributes.h"
129432fbfeSMogball #include "mlir/IR/OpDefinition.h"
138a515180SMehdi Amini #include "mlir/IR/Operation.h"
148a515180SMehdi Amini #include "mlir/IR/Value.h"
158a515180SMehdi Amini #include "mlir/Support/LLVM.h"
168a515180SMehdi Amini #include "llvm/ADT/STLExtras.h"
178a515180SMehdi Amini #include "llvm/Support/Casting.h"
189432fbfeSMogball #include "llvm/Support/Debug.h"
198a515180SMehdi Amini #include <cassert>
209432fbfeSMogball 
219432fbfeSMogball #define DEBUG_TYPE "constant-propagation"
22c095afcbSMogball 
23c095afcbSMogball using namespace mlir;
24c095afcbSMogball using namespace mlir::dataflow;
25c095afcbSMogball 
26c095afcbSMogball //===----------------------------------------------------------------------===//
27c095afcbSMogball // ConstantValue
28c095afcbSMogball //===----------------------------------------------------------------------===//
29c095afcbSMogball 
30c095afcbSMogball void ConstantValue::print(raw_ostream &os) const {
3147bf3e38SZhixun Tan   if (isUninitialized()) {
3247bf3e38SZhixun Tan     os << "<UNINITIALIZED>";
3347bf3e38SZhixun Tan     return;
3447bf3e38SZhixun Tan   }
3547bf3e38SZhixun Tan   if (getConstantValue() == nullptr) {
3647bf3e38SZhixun Tan     os << "<UNKNOWN>";
3747bf3e38SZhixun Tan     return;
3847bf3e38SZhixun Tan   }
3947bf3e38SZhixun Tan   return getConstantValue().print(os);
40c095afcbSMogball }
419432fbfeSMogball 
429432fbfeSMogball //===----------------------------------------------------------------------===//
439432fbfeSMogball // SparseConstantPropagation
449432fbfeSMogball //===----------------------------------------------------------------------===//
459432fbfeSMogball 
4615e915a4SIvan Butygin LogicalResult SparseConstantPropagation::visitOperation(
479432fbfeSMogball     Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
489432fbfeSMogball     ArrayRef<Lattice<ConstantValue> *> results) {
499432fbfeSMogball   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
509432fbfeSMogball 
519432fbfeSMogball   // Don't try to simulate the results of a region operation as we can't
529432fbfeSMogball   // guarantee that folding will be out-of-place. We don't allow in-place
539432fbfeSMogball   // folds as the desire here is for simulated execution, and not general
549432fbfeSMogball   // folding.
55052669e7SJacques Pienaar   if (op->getNumRegions()) {
56de0ebc52SZhixun Tan     setAllToEntryStates(results);
5715e915a4SIvan Butygin     return success();
58052669e7SJacques Pienaar   }
599432fbfeSMogball 
609432fbfeSMogball   SmallVector<Attribute, 8> constantOperands;
619432fbfeSMogball   constantOperands.reserve(op->getNumOperands());
6247bf3e38SZhixun Tan   for (auto *operandLattice : operands) {
6347bf3e38SZhixun Tan     if (operandLattice->getValue().isUninitialized())
6415e915a4SIvan Butygin       return success();
659432fbfeSMogball     constantOperands.push_back(operandLattice->getValue().getConstantValue());
6647bf3e38SZhixun Tan   }
679432fbfeSMogball 
689432fbfeSMogball   // Save the original operands and attributes just in case the operation
699432fbfeSMogball   // folds in-place. The constant passed in may not correspond to the real
709432fbfeSMogball   // runtime value, so in-place updates are not allowed.
719432fbfeSMogball   SmallVector<Value, 8> originalOperands(op->getOperands());
729432fbfeSMogball   DictionaryAttr originalAttrs = op->getAttrDictionary();
739432fbfeSMogball 
749432fbfeSMogball   // Simulate the result of folding this operation to a constant. If folding
759432fbfeSMogball   // fails or was an in-place fold, mark the results as overdefined.
769432fbfeSMogball   SmallVector<OpFoldResult, 8> foldResults;
779432fbfeSMogball   foldResults.reserve(op->getNumResults());
789432fbfeSMogball   if (failed(op->fold(constantOperands, foldResults))) {
79de0ebc52SZhixun Tan     setAllToEntryStates(results);
8015e915a4SIvan Butygin     return success();
819432fbfeSMogball   }
829432fbfeSMogball 
839432fbfeSMogball   // If the folding was in-place, mark the results as overdefined and reset
849432fbfeSMogball   // the operation. We don't allow in-place folds as the desire here is for
859432fbfeSMogball   // simulated execution, and not general folding.
869432fbfeSMogball   if (foldResults.empty()) {
879432fbfeSMogball     op->setOperands(originalOperands);
889432fbfeSMogball     op->setAttrs(originalAttrs);
89de0ebc52SZhixun Tan     setAllToEntryStates(results);
9015e915a4SIvan Butygin     return success();
919432fbfeSMogball   }
929432fbfeSMogball 
939432fbfeSMogball   // Merge the fold results into the lattice for this operation.
949432fbfeSMogball   assert(foldResults.size() == op->getNumResults() && "invalid result size");
959432fbfeSMogball   for (const auto it : llvm::zip(results, foldResults)) {
969432fbfeSMogball     Lattice<ConstantValue> *lattice = std::get<0>(it);
979432fbfeSMogball 
989432fbfeSMogball     // Merge in the result of the fold, either a constant or a value.
999432fbfeSMogball     OpFoldResult foldResult = std::get<1>(it);
10068f58812STres Popp     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
1019432fbfeSMogball       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
1029432fbfeSMogball       propagateIfChanged(lattice,
1039432fbfeSMogball                          lattice->join(ConstantValue(attr, op->getDialect())));
1049432fbfeSMogball     } else {
1059432fbfeSMogball       LLVM_DEBUG(llvm::dbgs()
106*b5c5c2b2SKazu Hirata                  << "Folded to value: " << cast<Value>(foldResult) << "\n");
107b2b7efb9SAlex Zinenko       AbstractSparseForwardDataFlowAnalysis::join(
108*b5c5c2b2SKazu Hirata           lattice, *getLatticeElement(cast<Value>(foldResult)));
1099432fbfeSMogball     }
1109432fbfeSMogball   }
11115e915a4SIvan Butygin   return success();
1129432fbfeSMogball }
113de0ebc52SZhixun Tan 
114de0ebc52SZhixun Tan void SparseConstantPropagation::setToEntryState(
115de0ebc52SZhixun Tan     Lattice<ConstantValue> *lattice) {
116de0ebc52SZhixun Tan   propagateIfChanged(lattice,
117de0ebc52SZhixun Tan                      lattice->join(ConstantValue::getUnknownConstant()));
118de0ebc52SZhixun Tan }
119