1 //===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ 10 #define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ 11 12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 13 #include "mlir/Interfaces/ControlFlowInterfaces.h" 14 15 //===----------------------------------------------------------------------===// 16 // Helpers for Unstructured Control Flow 17 //===----------------------------------------------------------------------===// 18 19 namespace mlir { 20 namespace bufferization { 21 22 namespace detail { 23 /// Return a list of operands that are forwarded to the given block argument. 24 /// I.e., find all predecessors of the block argument's owner and gather the 25 /// operands that are equivalent to the block argument. 26 SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg); 27 } // namespace detail 28 29 /// A template that provides a default implementation of `getAliasingOpOperands` 30 /// for ops that support unstructured control flow within their regions. 31 template <typename ConcreteModel, typename ConcreteOp> 32 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel 33 : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> { 34 35 FailureOr<BaseMemRefType> getBufferTypeOpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel36 getBufferType(Operation *op, Value value, const BufferizationOptions &options, 37 SmallVector<Value> &invocationStack) const { 38 // Note: The user may want to override this function for OpResults in 39 // case the bufferized result type is different from the bufferized type of 40 // the aliasing OpOperand (if any). 41 if (isa<OpResult>(value)) 42 return bufferization::detail::defaultGetBufferType(value, options, 43 invocationStack); 44 45 // Compute the buffer type of the block argument by computing the bufferized 46 // operand types of all forwarded values. If these are all the same type, 47 // take that type. Otherwise, take only the memory space and fall back to a 48 // buffer type with a fully dynamic layout map. 49 BaseMemRefType bufferType; 50 auto tensorType = cast<TensorType>(value.getType()); 51 for (OpOperand *opOperand : 52 detail::getCallerOpOperands(cast<BlockArgument>(value))) { 53 54 // If the forwarded operand is already on the invocation stack, we ran 55 // into a loop and this operand cannot be used to compute the bufferized 56 // type. 57 if (llvm::find(invocationStack, opOperand->get()) != 58 invocationStack.end()) 59 continue; 60 61 // Compute the bufferized type of the forwarded operand. 62 BaseMemRefType callerType; 63 if (auto memrefType = 64 dyn_cast<BaseMemRefType>(opOperand->get().getType())) { 65 // The operand was already bufferized. Take its type directly. 66 callerType = memrefType; 67 } else { 68 FailureOr<BaseMemRefType> maybeCallerType = 69 bufferization::getBufferType(opOperand->get(), options, 70 invocationStack); 71 if (failed(maybeCallerType)) 72 return failure(); 73 callerType = *maybeCallerType; 74 } 75 76 if (!bufferType) { 77 // This is the first buffer type that we computed. 78 bufferType = callerType; 79 continue; 80 } 81 82 if (bufferType == callerType) 83 continue; 84 85 // If the computed buffer type does not match the computed buffer type 86 // of the earlier forwarded operands, fall back to a buffer type with a 87 // fully dynamic layout map. 88 #ifndef NDEBUG 89 if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) { 90 assert(bufferType.hasRank() && callerType.hasRank() && 91 "expected ranked memrefs"); 92 assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(), 93 rankedTensorType.getShape()}) && 94 "expected same shape"); 95 } else { 96 assert(!bufferType.hasRank() && !callerType.hasRank() && 97 "expected unranked memrefs"); 98 } 99 #endif // NDEBUG 100 101 if (bufferType.getMemorySpace() != callerType.getMemorySpace()) 102 return op->emitOpError("incoming operands of block argument have " 103 "inconsistent memory spaces"); 104 105 bufferType = getMemRefTypeWithFullyDynamicLayout( 106 tensorType, bufferType.getMemorySpace()); 107 } 108 109 if (!bufferType) 110 return op->emitOpError("could not infer buffer type of block argument"); 111 112 return bufferType; 113 } 114 115 protected: 116 /// Assuming that `bbArg` is a block argument of a block that belongs to the 117 /// given `op`, return all OpOperands of users of this block that are 118 /// aliasing with the given block argument. 119 AliasingOpOperandList getAliasingBranchOpOperandsOpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel120 getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, 121 const AnalysisState &state) const { 122 assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg"); 123 124 // Gather aliasing OpOperands of all operations (callers) that link to 125 // this block. 126 AliasingOpOperandList result; 127 for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg)) 128 result.addAlias( 129 {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false}); 130 131 return result; 132 } 133 }; 134 135 /// A template that provides a default implementation of `getAliasingValues` 136 /// for ops that implement the `BranchOpInterface`. 137 template <typename ConcreteModel, typename ConcreteOp> 138 struct BranchOpBufferizableOpInterfaceExternalModel 139 : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> { getAliasingValuesBranchOpBufferizableOpInterfaceExternalModel140 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 141 const AnalysisState &state) const { 142 AliasingValueList result; 143 auto branchOp = cast<BranchOpInterface>(op); 144 auto operandNumber = opOperand.getOperandNumber(); 145 146 // Gather aliasing block arguments of blocks to which this op may branch to. 147 for (const auto &it : llvm::enumerate(op->getSuccessors())) { 148 Block *block = it.value(); 149 SuccessorOperands operands = branchOp.getSuccessorOperands(it.index()); 150 assert(operands.getProducedOperandCount() == 0 && 151 "produced operands not supported"); 152 if (operands.getForwardedOperands().empty()) 153 continue; 154 // The first and last operands that are forwarded to this successor. 155 int64_t firstOperandIndex = 156 operands.getForwardedOperands().getBeginOperandIndex(); 157 int64_t lastOperandIndex = 158 firstOperandIndex + operands.getForwardedOperands().size(); 159 bool matchingDestination = operandNumber >= firstOperandIndex && 160 operandNumber < lastOperandIndex; 161 // A branch op may have multiple successors. Find the ones that correspond 162 // to this OpOperand. (There is usually only one.) 163 if (!matchingDestination) 164 continue; 165 // Compute the matching block argument of the destination block. 166 BlockArgument bbArg = 167 block->getArgument(operandNumber - firstOperandIndex); 168 result.addAlias( 169 {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false}); 170 } 171 172 return result; 173 } 174 }; 175 176 } // namespace bufferization 177 } // namespace mlir 178 179 #endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_ 180