xref: /llvm-project/mlir/include/mlir/Transforms/FoldUtils.h (revision 34a65980d7d2e1b05e3fc88535cafe606ee55e04)
1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H
15 #define MLIR_TRANSFORMS_FOLDUTILS_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/DialectInterface.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Interfaces/FoldInterfaces.h"
22 
23 namespace mlir {
24 class Operation;
25 class Value;
26 
27 //===--------------------------------------------------------------------===//
28 // OperationFolder
29 //===--------------------------------------------------------------------===//
30 
31 /// A utility class for folding operations, and unifying duplicated constants
32 /// generated along the way.
33 class OperationFolder {
34 public:
35   OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
erasedFoldedLocation(UnknownLoc::get (ctx))36       : erasedFoldedLocation(UnknownLoc::get(ctx)), interfaces(ctx),
37         rewriter(ctx, listener) {}
38 
39   /// Tries to perform folding on the given `op`, including unifying
40   /// deduplicated constants. If successful, replaces `op`'s uses with
41   /// folded results, and returns success. If the op was completely folded it is
42   /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
43   LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
44 
45   /// Tries to fold a pre-existing constant operation. `constValue` represents
46   /// the value of the constant, and can be optionally passed if the value is
47   /// already known (e.g. if the constant was discovered by m_Constant). This is
48   /// purely an optimization opportunity for callers that already know the value
49   /// of the constant. Returns false if an existing constant for `op` already
50   /// exists in the folder, in which case `op` is replaced and erased.
51   /// Otherwise, returns true and `op` is inserted into the folder (and
52   /// hoisted if necessary).
53   bool insertKnownConstant(Operation *op, Attribute constValue = {});
54 
55   /// Notifies that the given constant `op` should be remove from this
56   /// OperationFolder's internal bookkeeping.
57   ///
58   /// Note: this method must be called if a constant op is to be deleted
59   /// externally to this OperationFolder. `op` must be a constant op.
60   void notifyRemoval(Operation *op);
61 
62   /// Clear out any constants cached inside of the folder.
63   void clear();
64 
65   /// Get or create a constant for use in the specified block. The constant may
66   /// be created in a parent block. On success this returns the constant
67   /// operation, nullptr otherwise.
68   Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value,
69                             Type type);
70 
71 private:
72   /// This map keeps track of uniqued constants by dialect, attribute, and type.
73   /// A constant operation materializes an attribute with a type. Dialects may
74   /// generate different constants with the same input attribute and type, so we
75   /// also need to track per-dialect.
76   using ConstantMap =
77       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
78 
79   /// Returns true if the given operation is an already folded constant that is
80   /// owned by this folder.
81   bool isFolderOwnedConstant(Operation *op) const;
82 
83   /// Tries to perform folding on the given `op`. If successful, populates
84   /// `results` with the results of the folding.
85   LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
86 
87   /// Try to process a set of fold results. Populates `results` on success,
88   /// otherwise leaves it unchanged.
89   LogicalResult processFoldResults(Operation *op,
90                                    SmallVectorImpl<Value> &results,
91                                    ArrayRef<OpFoldResult> foldResults);
92 
93   /// Try to get or create a new constant entry. On success this returns the
94   /// constant operation, nullptr otherwise.
95   Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
96                                     Dialect *dialect, Attribute value,
97                                     Type type, Location loc);
98 
99   /// The location to overwrite with for folder-owned constants.
100   UnknownLoc erasedFoldedLocation;
101 
102   /// A mapping between an insertion region and the constants that have been
103   /// created within it.
104   DenseMap<Region *, ConstantMap> foldScopes;
105 
106   /// This map tracks all of the dialects that an operation is referenced by;
107   /// given that many dialects may generate the same constant.
108   DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
109 
110   /// A collection of dialect folder interfaces.
111   DialectInterfaceCollection<DialectFoldInterface> interfaces;
112 
113   /// A rewriter that performs all IR modifications.
114   IRRewriter rewriter;
115 };
116 
117 } // namespace mlir
118 
119 #endif // MLIR_TRANSFORMS_FOLDUTILS_H
120