xref: /llvm-project/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h (revision 15e915a44f0d0bf092214586d3ec86e2bb7636d7)
1 //===- ConstantPropagationAnalysis.h - Constant propagation analysis ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements constant propagation analysis. It defines the lattice
// value class that represents constant values in the program, and a sparse
// constant propagation analysis that uses operation folders to speculate about
// constant values in the program.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H
17 #define MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H
18 
19 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
20 #include <optional>
21 
22 namespace mlir {
23 namespace dataflow {
24 
25 //===----------------------------------------------------------------------===//
26 // ConstantValue
27 //===----------------------------------------------------------------------===//
28 
29 /// This lattice value represents a known constant value of a lattice.
30 class ConstantValue {
31 public:
32   /// Construct a constant value as uninitialized.
33   explicit ConstantValue() = default;
34 
35   /// Construct a constant value with a known constant.
36   explicit ConstantValue(Attribute constant, Dialect *dialect)
37       : constant(constant), dialect(dialect) {}
38 
39   /// Get the constant value. Returns null if no value was determined.
40   Attribute getConstantValue() const {
41     assert(!isUninitialized());
42     return *constant;
43   }
44 
45   /// Get the dialect instance that can be used to materialize the constant.
46   Dialect *getConstantDialect() const {
47     assert(!isUninitialized());
48     return dialect;
49   }
50 
51   /// Compare the constant values.
52   bool operator==(const ConstantValue &rhs) const {
53     return constant == rhs.constant;
54   }
55 
56   /// Print the constant value.
57   void print(raw_ostream &os) const;
58 
59   /// The state where the constant value is uninitialized. This happens when the
60   /// state hasn't been set during the analysis.
61   static ConstantValue getUninitialized() { return ConstantValue{}; }
62 
63   /// Whether the state is uninitialized.
64   bool isUninitialized() const { return !constant.has_value(); }
65 
66   /// The state where the constant value is unknown.
67   static ConstantValue getUnknownConstant() {
68     return ConstantValue{/*constant=*/nullptr, /*dialect=*/nullptr};
69   }
70 
71   /// The union with another constant value is null if they are different, and
72   /// the same if they are the same.
73   static ConstantValue join(const ConstantValue &lhs,
74                             const ConstantValue &rhs) {
75     if (lhs.isUninitialized())
76       return rhs;
77     if (rhs.isUninitialized())
78       return lhs;
79     if (lhs == rhs)
80       return lhs;
81     return getUnknownConstant();
82   }
83 
84 private:
85   /// The constant value.
86   std::optional<Attribute> constant;
87   /// A dialect instance that can be used to materialize the constant.
88   Dialect *dialect = nullptr;
89 };
90 
91 //===----------------------------------------------------------------------===//
92 // SparseConstantPropagation
93 //===----------------------------------------------------------------------===//
94 
95 /// This analysis implements sparse constant propagation, which attempts to
96 /// determine constant-valued results for operations using constant-valued
97 /// operands, by speculatively folding operations. When combined with dead-code
98 /// analysis, this becomes sparse conditional constant propagation (SCCP).
99 class SparseConstantPropagation
100     : public SparseForwardDataFlowAnalysis<Lattice<ConstantValue>> {
101 public:
102   using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
103 
104   LogicalResult
105   visitOperation(Operation *op,
106                  ArrayRef<const Lattice<ConstantValue> *> operands,
107                  ArrayRef<Lattice<ConstantValue> *> results) override;
108 
109   void setToEntryState(Lattice<ConstantValue> *lattice) override;
110 };
111 
112 } // end namespace dataflow
113 } // end namespace mlir
114 
115 #endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTPROPAGATIONANALYSIS_H
116