xref: /llvm-project/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (revision 6ecebb496cc6960e100a05375ab7f64e831dd933)
1 //===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
11 
12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13 #include "mlir/Interfaces/ControlFlowInterfaces.h"
14 
15 //===----------------------------------------------------------------------===//
16 // Helpers for Unstructured Control Flow
17 //===----------------------------------------------------------------------===//
18 
19 namespace mlir {
20 namespace bufferization {
21 
22 namespace detail {
23 /// Return the list of operands that are forwarded to the given block argument.
24 /// I.e., find all predecessors of the block that owns `bbArg` and gather the
25 /// operands that are equivalent to it. Declaration only; defined elsewhere.
26 SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
27 } // namespace detail
28 
29 /// A template that provides a default implementation of `getAliasingOpOperands`
30 /// for ops that support unstructured control flow within their regions.
31 template <typename ConcreteModel, typename ConcreteOp>
32 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
33     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
34 
35   FailureOr<BaseMemRefType>
getBufferTypeOpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel36   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
37                 SmallVector<Value> &invocationStack) const {
38     // Note: The user may want to override this function for OpResults in
39     // case the bufferized result type is different from the bufferized type of
40     // the aliasing OpOperand (if any).
41     if (isa<OpResult>(value))
42       return bufferization::detail::defaultGetBufferType(value, options,
43                                                          invocationStack);
44 
45     // Compute the buffer type of the block argument by computing the bufferized
46     // operand types of all forwarded values. If these are all the same type,
47     // take that type. Otherwise, take only the memory space and fall back to a
48     // buffer type with a fully dynamic layout map.
49     BaseMemRefType bufferType;
50     auto tensorType = cast<TensorType>(value.getType());
51     for (OpOperand *opOperand :
52          detail::getCallerOpOperands(cast<BlockArgument>(value))) {
53 
54       // If the forwarded operand is already on the invocation stack, we ran
55       // into a loop and this operand cannot be used to compute the bufferized
56       // type.
57       if (llvm::find(invocationStack, opOperand->get()) !=
58           invocationStack.end())
59         continue;
60 
61       // Compute the bufferized type of the forwarded operand.
62       BaseMemRefType callerType;
63       if (auto memrefType =
64               dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
65         // The operand was already bufferized. Take its type directly.
66         callerType = memrefType;
67       } else {
68         FailureOr<BaseMemRefType> maybeCallerType =
69             bufferization::getBufferType(opOperand->get(), options,
70                                          invocationStack);
71         if (failed(maybeCallerType))
72           return failure();
73         callerType = *maybeCallerType;
74       }
75 
76       if (!bufferType) {
77         // This is the first buffer type that we computed.
78         bufferType = callerType;
79         continue;
80       }
81 
82       if (bufferType == callerType)
83         continue;
84 
85         // If the computed buffer type does not match the computed buffer type
86         // of the earlier forwarded operands, fall back to a buffer type with a
87         // fully dynamic layout map.
88 #ifndef NDEBUG
89       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
90         assert(bufferType.hasRank() && callerType.hasRank() &&
91                "expected ranked memrefs");
92         assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
93                                 rankedTensorType.getShape()}) &&
94                "expected same shape");
95       } else {
96         assert(!bufferType.hasRank() && !callerType.hasRank() &&
97                "expected unranked memrefs");
98       }
99 #endif // NDEBUG
100 
101       if (bufferType.getMemorySpace() != callerType.getMemorySpace())
102         return op->emitOpError("incoming operands of block argument have "
103                                "inconsistent memory spaces");
104 
105       bufferType = getMemRefTypeWithFullyDynamicLayout(
106           tensorType, bufferType.getMemorySpace());
107     }
108 
109     if (!bufferType)
110       return op->emitOpError("could not infer buffer type of block argument");
111 
112     return bufferType;
113   }
114 
115 protected:
116   /// Assuming that `bbArg` is a block argument of a block that belongs to the
117   /// given `op`, return all OpOperands of users of this block that are
118   /// aliasing with the given block argument.
119   AliasingOpOperandList
getAliasingBranchOpOperandsOpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel120   getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg,
121                               const AnalysisState &state) const {
122     assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
123 
124     // Gather aliasing OpOperands of all operations (callers) that link to
125     // this block.
126     AliasingOpOperandList result;
127     for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
128       result.addAlias(
129           {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
130 
131     return result;
132   }
133 };
134 
135 /// A template that provides a default implementation of `getAliasingValues`
136 /// for ops that implement the `BranchOpInterface`.
137 template <typename ConcreteModel, typename ConcreteOp>
138 struct BranchOpBufferizableOpInterfaceExternalModel
139     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
getAliasingValuesBranchOpBufferizableOpInterfaceExternalModel140   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
141                                       const AnalysisState &state) const {
142     AliasingValueList result;
143     auto branchOp = cast<BranchOpInterface>(op);
144     auto operandNumber = opOperand.getOperandNumber();
145 
146     // Gather aliasing block arguments of blocks to which this op may branch to.
147     for (const auto &it : llvm::enumerate(op->getSuccessors())) {
148       Block *block = it.value();
149       SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
150       assert(operands.getProducedOperandCount() == 0 &&
151              "produced operands not supported");
152       if (operands.getForwardedOperands().empty())
153         continue;
154       // The first and last operands that are forwarded to this successor.
155       int64_t firstOperandIndex =
156           operands.getForwardedOperands().getBeginOperandIndex();
157       int64_t lastOperandIndex =
158           firstOperandIndex + operands.getForwardedOperands().size();
159       bool matchingDestination = operandNumber >= firstOperandIndex &&
160                                  operandNumber < lastOperandIndex;
161       // A branch op may have multiple successors. Find the ones that correspond
162       // to this OpOperand. (There is usually only one.)
163       if (!matchingDestination)
164         continue;
165       // Compute the matching block argument of the destination block.
166       BlockArgument bbArg =
167           block->getArgument(operandNumber - firstOperandIndex);
168       result.addAlias(
169           {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
170     }
171 
172     return result;
173   }
174 };
175 
176 } // namespace bufferization
177 } // namespace mlir
178 
179 #endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
180