1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2e07a7fd5SMatthias Springer // 3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e07a7fd5SMatthias Springer // 7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===// 8e07a7fd5SMatthias Springer 9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 126ecebb49SMatthias Springer #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" 13a88732d9SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 15e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 16e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 17e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h" 18e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 19a1fe1f5fSKazu Hirata #include <optional> 20e07a7fd5SMatthias Springer 21e07a7fd5SMatthias Springer namespace mlir { 22*b0a4e958SMatthias Springer /// Return all func.return ops in the given function. 23*b0a4e958SMatthias Springer SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) { 24*b0a4e958SMatthias Springer SmallVector<func::ReturnOp> result; 25*b0a4e958SMatthias Springer for (Block &b : funcOp.getBody()) 26*b0a4e958SMatthias Springer if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator())) 27*b0a4e958SMatthias Springer result.push_back(returnOp); 28*b0a4e958SMatthias Springer return result; 29*b0a4e958SMatthias Springer } 30*b0a4e958SMatthias Springer 31e07a7fd5SMatthias Springer namespace bufferization { 32e07a7fd5SMatthias Springer namespace func_ext { 33e07a7fd5SMatthias Springer 3491c11574SAndrzej Warzyński void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { 35e07a7fd5SMatthias Springer analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; 36e07a7fd5SMatthias Springer auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping()); 37e07a7fd5SMatthias Springer auto createdAliasingResults = 38e07a7fd5SMatthias Springer aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping()); 39e07a7fd5SMatthias Springer auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet()); 40e07a7fd5SMatthias Springer auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet()); 41e07a7fd5SMatthias Springer (void)createdEquiv; 42e07a7fd5SMatthias Springer (void)createdAliasingResults; 43e07a7fd5SMatthias Springer (void)createdRead; 44e07a7fd5SMatthias Springer (void)createdWritten; 45e07a7fd5SMatthias Springer #ifndef NDEBUG 46e07a7fd5SMatthias Springer assert(createdEquiv.second && "equivalence info exists already"); 47e07a7fd5SMatthias Springer assert(createdAliasingResults.second && "aliasing info exists already"); 48e07a7fd5SMatthias Springer assert(createdRead.second && "bbarg access info exists already"); 49e07a7fd5SMatthias Springer assert(createdWritten.second && "bbarg access info exists already"); 50e07a7fd5SMatthias Springer #endif // NDEBUG 51e07a7fd5SMatthias Springer } 52e07a7fd5SMatthias Springer 53e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the 54e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be 5575ef84bfSOleg Shyshkov /// specified by the user (as per `options.functionArgTypeConverterFn`). 56e07a7fd5SMatthias Springer static BaseMemRefType 57e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, 58e07a7fd5SMatthias Springer const BufferizationOptions &options) { 59e07a7fd5SMatthias Springer auto tensorType = 605550c821STres Popp dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index)); 61e07a7fd5SMatthias Springer assert(tensorType && "expected TensorType"); 62f287da8aSMatthias Springer 6375ef84bfSOleg Shyshkov BaseMemRefType memrefType = options.functionArgTypeConverterFn( 64067d2779Sian Bearman tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); 65e07a7fd5SMatthias Springer 66e07a7fd5SMatthias Springer auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>( 67e07a7fd5SMatthias Springer index, BufferizationDialect::kBufferLayoutAttrName); 68e07a7fd5SMatthias Springer if (!layoutAttr) 69e07a7fd5SMatthias Springer return memrefType; 70e07a7fd5SMatthias Springer 715550c821STres Popp auto rankedMemrefType = dyn_cast<MemRefType>(memrefType); 72e07a7fd5SMatthias Springer assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); 73e07a7fd5SMatthias Springer return MemRefType::get( 74e07a7fd5SMatthias Springer rankedMemrefType.getShape(), rankedMemrefType.getElementType(), 759bb63374SLei Zhang layoutAttr.getValue(), rankedMemrefType.getMemorySpace()); 76e07a7fd5SMatthias Springer } 77e07a7fd5SMatthias Springer 78e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`. 79e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) { 80217700baSMatthias Springer SymbolRefAttr sym = 81217700baSMatthias Springer llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); 82e07a7fd5SMatthias Springer if (!sym) 83e07a7fd5SMatthias Springer return nullptr; 84e07a7fd5SMatthias Springer return dyn_cast_or_null<FuncOp>( 85e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 86e07a7fd5SMatthias Springer } 87e07a7fd5SMatthias Springer 88e07a7fd5SMatthias Springer /// Get FuncAnalysisState. 89e07a7fd5SMatthias Springer static const FuncAnalysisState & 90e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) { 91faa9be75SMatthias Springer assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState"); 92faa9be75SMatthias Springer auto *result = static_cast<const OneShotAnalysisState &>(state) 93faa9be75SMatthias Springer .getExtension<FuncAnalysisState>(); 94faa9be75SMatthias Springer assert(result && "FuncAnalysisState does not exist"); 95faa9be75SMatthias Springer return *result; 96e07a7fd5SMatthias Springer } 97e07a7fd5SMatthias Springer 98e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp. 99e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 100e07a7fd5SMatthias Springer FuncOp funcOp) { 101faa9be75SMatthias Springer if (!isa<OneShotAnalysisState>(state)) 102cd80617aSMatthias Springer return FuncOpAnalysisState::NotAnalyzed; 103faa9be75SMatthias Springer auto *funcState = static_cast<const OneShotAnalysisState &>(state) 104faa9be75SMatthias Springer .getExtension<FuncAnalysisState>(); 105faa9be75SMatthias Springer if (!funcState) 106faa9be75SMatthias Springer return FuncOpAnalysisState::NotAnalyzed; 107faa9be75SMatthias Springer const auto &analyzedFuncOps = funcState->analyzedFuncOps; 108cd80617aSMatthias Springer auto it = analyzedFuncOps.find(funcOp); 109cd80617aSMatthias Springer if (it == analyzedFuncOps.end()) 110e07a7fd5SMatthias Springer return FuncOpAnalysisState::NotAnalyzed; 111e07a7fd5SMatthias Springer return it->second; 112e07a7fd5SMatthias Springer } 113e07a7fd5SMatthias Springer 114e07a7fd5SMatthias Springer /// Return the index of the bbArg in the given FuncOp that is equivalent to the 115e07a7fd5SMatthias Springer /// specified return value (if any). 1160a81ace0SKazu Hirata static std::optional<int64_t> 1170a81ace0SKazu Hirata getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, 118e07a7fd5SMatthias Springer int64_t returnValIdx) { 119e07a7fd5SMatthias Springer auto funcOpIt = state.equivalentFuncArgs.find(funcOp); 120e07a7fd5SMatthias Springer if (funcOpIt == state.equivalentFuncArgs.end()) 121e07a7fd5SMatthias Springer // No equivalence info stores for funcOp. 1221a36588eSKazu Hirata return std::nullopt; 123e07a7fd5SMatthias Springer 124e07a7fd5SMatthias Springer auto retValIt = funcOpIt->getSecond().find(returnValIdx); 125e07a7fd5SMatthias Springer if (retValIt == funcOpIt->getSecond().end()) 126e07a7fd5SMatthias Springer // Return value has no equivalent bbArg. 1271a36588eSKazu Hirata return std::nullopt; 128e07a7fd5SMatthias Springer 129e07a7fd5SMatthias Springer return retValIt->getSecond(); 130e07a7fd5SMatthias Springer } 131e07a7fd5SMatthias Springer 132e07a7fd5SMatthias Springer struct CallOpInterface 133e07a7fd5SMatthias Springer : public BufferizableOpInterface::ExternalModel<CallOpInterface, 134e07a7fd5SMatthias Springer func::CallOp> { 135e07a7fd5SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 136e07a7fd5SMatthias Springer const AnalysisState &state) const { 137e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op); 138e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp); 139e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp"); 140e07a7fd5SMatthias Springer 141e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) 142e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Assume that OpOperand is read. 143e07a7fd5SMatthias Springer return true; 144e07a7fd5SMatthias Springer 145cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state); 146e07a7fd5SMatthias Springer return funcState.readBbArgs.lookup(funcOp).contains( 147e07a7fd5SMatthias Springer opOperand.getOperandNumber()); 148e07a7fd5SMatthias Springer } 149e07a7fd5SMatthias Springer 150e07a7fd5SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 151e07a7fd5SMatthias Springer const AnalysisState &state) const { 152e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op); 153e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp); 154e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp"); 155e07a7fd5SMatthias Springer 156e07a7fd5SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) 157e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Assume that OpOperand is written. 158e07a7fd5SMatthias Springer return true; 159e07a7fd5SMatthias Springer 160cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state); 161e07a7fd5SMatthias Springer return funcState.writtenBbArgs.lookup(funcOp).contains( 162e07a7fd5SMatthias Springer opOperand.getOperandNumber()); 163e07a7fd5SMatthias Springer } 164e07a7fd5SMatthias Springer 165a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 166e07a7fd5SMatthias Springer const AnalysisState &state) const { 167e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op); 168e07a7fd5SMatthias Springer FuncOp funcOp = getCalledFunction(callOp); 169e07a7fd5SMatthias Springer assert(funcOp && "expected CallOp to a FuncOp"); 170f3483c23SMatthias Springer if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) 171e07a7fd5SMatthias Springer // FuncOp not analyzed yet. Any OpResult may be aliasing. 172a02ad6c1SMatthias Springer return detail::unknownGetAliasingValues(opOperand); 173e07a7fd5SMatthias Springer 174e07a7fd5SMatthias Springer // Get aliasing results from state. 175cd80617aSMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state); 176e07a7fd5SMatthias Springer auto aliasingReturnVals = 177e07a7fd5SMatthias Springer funcState.aliasingReturnVals.lookup(funcOp).lookup( 178e07a7fd5SMatthias Springer opOperand.getOperandNumber()); 1799fa6b350SMatthias Springer 1809fa6b350SMatthias Springer // Check if the aliasing OpResult is equivalent to the OpOperand. 1819fa6b350SMatthias Springer std::optional<int64_t> equivalent = {}; 1829fa6b350SMatthias Springer if (aliasingReturnVals.size() == 1) { 1839fa6b350SMatthias Springer equivalent = getEquivalentFuncArgIdx(funcOp, funcState, 1849fa6b350SMatthias Springer aliasingReturnVals.front()); 1859fa6b350SMatthias Springer assert((!equivalent.has_value() || 1869fa6b350SMatthias Springer *equivalent == opOperand.getOperandNumber()) && 1879fa6b350SMatthias Springer "inconsistent analysis state"); 1889fa6b350SMatthias Springer } 189a02ad6c1SMatthias Springer AliasingValueList result; 190e07a7fd5SMatthias Springer for (int64_t resultIdx : aliasingReturnVals) 1919fa6b350SMatthias Springer result.addAlias({callOp->getOpResult(resultIdx), 1929fa6b350SMatthias Springer equivalent.has_value() ? BufferRelation::Equivalent 1939fa6b350SMatthias Springer : BufferRelation::Unknown, 1949fa6b350SMatthias Springer /*isDefinite=*/equivalent.has_value()}); 195e07a7fd5SMatthias Springer return result; 196e07a7fd5SMatthias Springer } 197e07a7fd5SMatthias Springer 19806dacf5eSMatthias Springer FailureOr<BaseMemRefType> 19906dacf5eSMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 200878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 20106dacf5eSMatthias Springer auto callOp = cast<func::CallOp>(op); 20206dacf5eSMatthias Springer FuncOp funcOp = getCalledFunction(callOp); 20306dacf5eSMatthias Springer assert(funcOp && "expected CallOp to a FuncOp"); 20406dacf5eSMatthias Springer 205c271ba7fSMatthias Springer // If the callee was already bufferized, we can directly take the type from 20606dacf5eSMatthias Springer // its signature. 20706dacf5eSMatthias Springer FunctionType funcType = funcOp.getFunctionType(); 208c271ba7fSMatthias Springer Type resultType = 209c271ba7fSMatthias Springer funcType.getResult(cast<OpResult>(value).getResultNumber()); 210c271ba7fSMatthias Springer if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType)) 211c271ba7fSMatthias Springer return bufferizedType; 212c271ba7fSMatthias Springer 213c271ba7fSMatthias Springer // Otherwise, call the type converter to compute the bufferized type. 214c271ba7fSMatthias Springer auto tensorType = cast<TensorType>(resultType); 215c271ba7fSMatthias Springer return options.functionArgTypeConverterFn( 216c271ba7fSMatthias Springer tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); 21706dacf5eSMatthias Springer } 21806dacf5eSMatthias Springer 219e07a7fd5SMatthias Springer /// All function arguments are writable. It is the responsibility of the 220e07a7fd5SMatthias Springer /// CallOp to insert buffer copies where necessary. 221e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 222b55d55ecSMatthias Springer const BufferizationOptions &options) const { 223e07a7fd5SMatthias Springer func::CallOp callOp = cast<func::CallOp>(op); 224e07a7fd5SMatthias Springer 22588539c5bSMatthias Springer // 1. Compute the result types of the new CallOp. 22606dacf5eSMatthias Springer SmallVector<Type> resultTypes; 22706dacf5eSMatthias Springer for (Value result : callOp.getResults()) { 22806dacf5eSMatthias Springer Type returnType = result.getType(); 2295550c821STres Popp if (!isa<TensorType>(returnType)) { 230e07a7fd5SMatthias Springer // Non-tensor values are returned. 231e07a7fd5SMatthias Springer resultTypes.push_back(returnType); 232e07a7fd5SMatthias Springer continue; 233e07a7fd5SMatthias Springer } 234e07a7fd5SMatthias Springer 23588539c5bSMatthias Springer // Returning a memref. 23606dacf5eSMatthias Springer FailureOr<BaseMemRefType> resultType = 23706dacf5eSMatthias Springer bufferization::getBufferType(result, options); 23806dacf5eSMatthias Springer if (failed(resultType)) 23906dacf5eSMatthias Springer return failure(); 24006dacf5eSMatthias Springer resultTypes.push_back(*resultType); 241e07a7fd5SMatthias Springer } 242e07a7fd5SMatthias Springer 24306dacf5eSMatthias Springer // 2. Rewrite tensor operands as memrefs based on type of the already 24406dacf5eSMatthias Springer // bufferized callee. 24506dacf5eSMatthias Springer SmallVector<Value> newOperands; 24606dacf5eSMatthias Springer FuncOp funcOp = getCalledFunction(callOp); 24706dacf5eSMatthias Springer assert(funcOp && "expected CallOp to a FuncOp"); 24806dacf5eSMatthias Springer FunctionType funcType = funcOp.getFunctionType(); 249e07a7fd5SMatthias Springer 25006dacf5eSMatthias Springer for (OpOperand &opOperand : callOp->getOpOperands()) { 251e07a7fd5SMatthias Springer // Non-tensor operands are just copied. 25206dacf5eSMatthias Springer if (!isa<TensorType>(opOperand.get().getType())) { 25306dacf5eSMatthias Springer newOperands.push_back(opOperand.get()); 254e07a7fd5SMatthias Springer continue; 255e07a7fd5SMatthias Springer } 256e07a7fd5SMatthias Springer 25788539c5bSMatthias Springer // Retrieve buffers for tensor operands. 2585d50f51cSMatthias Springer FailureOr<Value> maybeBuffer = 2595d50f51cSMatthias Springer getBuffer(rewriter, opOperand.get(), options); 2605d50f51cSMatthias Springer if (failed(maybeBuffer)) 2615d50f51cSMatthias Springer return failure(); 26206dacf5eSMatthias Springer Value buffer = *maybeBuffer; 263e07a7fd5SMatthias Springer 2647f04a8adSLongsheng Mou // Caller / callee type mismatch is handled with castOrReallocMemRefValue. 26506dacf5eSMatthias Springer auto memRefType = funcType.getInput(opOperand.getOperandNumber()); 266c271ba7fSMatthias Springer if (!isa<BaseMemRefType>(memRefType)) { 267c271ba7fSMatthias Springer // The called function was not bufferized yet. This can happen when 268c271ba7fSMatthias Springer // there cycles in the function call graph. Compute the bufferized 269c271ba7fSMatthias Springer // result type. 270c271ba7fSMatthias Springer FailureOr<BaseMemRefType> maybeMemRefType = 271c271ba7fSMatthias Springer bufferization::getBufferType( 272c271ba7fSMatthias Springer funcOp.getArgument(opOperand.getOperandNumber()), options); 273c271ba7fSMatthias Springer if (failed(maybeMemRefType)) 274c271ba7fSMatthias Springer return failure(); 275c271ba7fSMatthias Springer memRefType = *maybeMemRefType; 276c271ba7fSMatthias Springer } 277c271ba7fSMatthias Springer 278e07a7fd5SMatthias Springer // Since we don't yet have a clear layout story, to_memref may 279e07a7fd5SMatthias Springer // conservatively turn tensors into more dynamic memref than necessary. 280e07a7fd5SMatthias Springer // If the memref type of the callee fails, introduce an extra memref.cast 281e07a7fd5SMatthias Springer // that will either canonicalize away or fail compilation until we can do 2827f04a8adSLongsheng Mou // something better. Insert a reallocation + copy if it cannot be 2837f04a8adSLongsheng Mou // statically guaranteed that a direct cast would be valid. 284e07a7fd5SMatthias Springer if (buffer.getType() != memRefType) { 2857f04a8adSLongsheng Mou auto memrefDstType = dyn_cast<MemRefType>(memRefType); 2867f04a8adSLongsheng Mou assert(memrefDstType && 2877f04a8adSLongsheng Mou "buffer layout not supported on unranked tensors"); 2887f04a8adSLongsheng Mou FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue( 2897f04a8adSLongsheng Mou rewriter, buffer, memrefDstType, options); 2907f04a8adSLongsheng Mou if (failed(replacement)) 2917f04a8adSLongsheng Mou return failure(); 2927f04a8adSLongsheng Mou buffer = *replacement; 293e07a7fd5SMatthias Springer } 29406dacf5eSMatthias Springer newOperands.push_back(buffer); 295e07a7fd5SMatthias Springer } 296e07a7fd5SMatthias Springer 297e07a7fd5SMatthias Springer // 3. Create the new CallOp. 298e07a7fd5SMatthias Springer Operation *newCallOp = rewriter.create<func::CallOp>( 299e07a7fd5SMatthias Springer callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); 300e07a7fd5SMatthias Springer newCallOp->setAttrs(callOp->getAttrs()); 301e07a7fd5SMatthias Springer 302e07a7fd5SMatthias Springer // 4. Replace the old op with the new op. 30306dacf5eSMatthias Springer replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults()); 304e07a7fd5SMatthias Springer 305e07a7fd5SMatthias Springer return success(); 306e07a7fd5SMatthias Springer } 307e07a7fd5SMatthias Springer }; 308e07a7fd5SMatthias Springer 309e07a7fd5SMatthias Springer struct ReturnOpInterface 310e07a7fd5SMatthias Springer : public BufferizableOpInterface::ExternalModel<ReturnOpInterface, 311e07a7fd5SMatthias Springer func::ReturnOp> { 312e07a7fd5SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 313e07a7fd5SMatthias Springer const AnalysisState &state) const { 314e07a7fd5SMatthias Springer return true; 315e07a7fd5SMatthias Springer } 316e07a7fd5SMatthias Springer 317e07a7fd5SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 318e07a7fd5SMatthias Springer const AnalysisState &state) const { 319e07a7fd5SMatthias Springer return false; 320e07a7fd5SMatthias Springer } 321e07a7fd5SMatthias Springer 322a02ad6c1SMatthias Springer AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, 323e07a7fd5SMatthias Springer const AnalysisState &state) const { 324e07a7fd5SMatthias Springer return {}; 325e07a7fd5SMatthias Springer } 326e07a7fd5SMatthias Springer 327e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 328b55d55ecSMatthias Springer const BufferizationOptions &options) const { 329e07a7fd5SMatthias Springer #ifndef NDEBUG 330e07a7fd5SMatthias Springer auto returnOp = cast<func::ReturnOp>(op); 331e07a7fd5SMatthias Springer assert(isa<FuncOp>(returnOp->getParentOp()) && 332e07a7fd5SMatthias Springer "only support FuncOp parent for ReturnOp"); 333e07a7fd5SMatthias Springer #endif // NDEBUG 334e07a7fd5SMatthias Springer 335e07a7fd5SMatthias Springer // ReturnOps are bufferized as part of FuncOps. 3360b293bf0SMatthias Springer return success(); 337e07a7fd5SMatthias Springer } 338e07a7fd5SMatthias Springer }; 339e07a7fd5SMatthias Springer 340e07a7fd5SMatthias Springer struct FuncOpInterface 3416ecebb49SMatthias Springer : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< 3426ecebb49SMatthias Springer FuncOpInterface, FuncOp> { 3436ecebb49SMatthias Springer 3446ecebb49SMatthias Springer static bool supportsUnstructuredControlFlow() { return true; } 3456ecebb49SMatthias Springer 3468f2d83daSMatthias Springer bool hasTensorSemantics(Operation *op) const { 347971b8525SJakub Kuderski auto isaTensor = llvm::IsaPred<TensorType>; 3488f2d83daSMatthias Springer 3498f2d83daSMatthias Springer // A function has tensor semantics if it has tensor arguments/results. 3508f2d83daSMatthias Springer auto funcOp = cast<FuncOp>(op); 3518f2d83daSMatthias Springer bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); 3528f2d83daSMatthias Springer bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); 3538f2d83daSMatthias Springer if (hasTensorArg || hasTensorResult) 3548f2d83daSMatthias Springer return true; 3558f2d83daSMatthias Springer 3568f2d83daSMatthias Springer // It also has tensor semantics if it has tensor block arguments. 3578f2d83daSMatthias Springer // TODO: Decouple bufferization of unstructured control flow from 3588f2d83daSMatthias Springer // BufferizableOpInterface implementations. We should only care about 3598f2d83daSMatthias Springer // region entry block arguments here (which are already covered by the 3608f2d83daSMatthias Springer // argument types of the function). 3618f2d83daSMatthias Springer for (Block &block : funcOp.getBody()) 3628f2d83daSMatthias Springer if (any_of(block.getArgumentTypes(), isaTensor)) 3638f2d83daSMatthias Springer return true; 3648f2d83daSMatthias Springer 3658f2d83daSMatthias Springer return false; 3668f2d83daSMatthias Springer } 3678f2d83daSMatthias Springer 3686ecebb49SMatthias Springer AliasingOpOperandList 3696ecebb49SMatthias Springer getAliasingOpOperands(Operation *op, Value value, 3706ecebb49SMatthias Springer const AnalysisState &state) const { 3716ecebb49SMatthias Springer return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state); 3726ecebb49SMatthias Springer } 3736ecebb49SMatthias Springer 37406dacf5eSMatthias Springer FailureOr<BaseMemRefType> 37506dacf5eSMatthias Springer getBufferType(Operation *op, Value value, const BufferizationOptions &options, 376878950b8SMatthias Springer SmallVector<Value> &invocationStack) const { 37706dacf5eSMatthias Springer auto funcOp = cast<FuncOp>(op); 37806dacf5eSMatthias Springer auto bbArg = cast<BlockArgument>(value); 3796ecebb49SMatthias Springer 3806ecebb49SMatthias Springer // Function arguments are special. 3816ecebb49SMatthias Springer if (bbArg.getOwner() == &funcOp.getBody().front()) 3826ecebb49SMatthias Springer return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), 3836ecebb49SMatthias Springer options); 3846ecebb49SMatthias Springer 3856ecebb49SMatthias Springer return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: 3866ecebb49SMatthias Springer getBufferType(op, value, options, invocationStack); 3876ecebb49SMatthias Springer } 3886ecebb49SMatthias Springer 389f287da8aSMatthias Springer /// Rewrite function bbArgs and return values into buffer form. This function 390f287da8aSMatthias Springer /// bufferizes the function signature and the ReturnOp. When the entire 391f287da8aSMatthias Springer /// function body has been bufferized, function return types can be switched 392f287da8aSMatthias Springer /// to more concise memref types as part of `foldMemRefCasts`. 393e07a7fd5SMatthias Springer /// 394e07a7fd5SMatthias Springer /// All function bbArgs are writable unless they are explicitly marked as 395e07a7fd5SMatthias Springer /// read-only. Callers must insert copies when needed. 396e07a7fd5SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 397b55d55ecSMatthias Springer const BufferizationOptions &options) const { 398e07a7fd5SMatthias Springer auto funcOp = cast<FuncOp>(op); 399e07a7fd5SMatthias Springer FunctionType funcType = funcOp.getFunctionType(); 400e07a7fd5SMatthias Springer 401217700baSMatthias Springer // Compute the argument types. 402e07a7fd5SMatthias Springer SmallVector<Type> argTypes; 403e07a7fd5SMatthias Springer for (const auto &it : llvm::enumerate(funcType.getInputs())) { 404e07a7fd5SMatthias Springer Type argType = it.value(); 405217700baSMatthias Springer if (isa<TensorType>(argType)) { 406e07a7fd5SMatthias Springer argTypes.push_back( 407e07a7fd5SMatthias Springer getBufferizedFunctionArgType(funcOp, it.index(), options)); 408e07a7fd5SMatthias Springer continue; 409e07a7fd5SMatthias Springer } 410e07a7fd5SMatthias Springer argTypes.push_back(argType); 411e07a7fd5SMatthias Springer } 412e07a7fd5SMatthias Springer 413217700baSMatthias Springer // Compute the result types. 414e07a7fd5SMatthias Springer SmallVector<Type> retTypes; 415e07a7fd5SMatthias Springer for (Type resultType : funcType.getResults()) { 416217700baSMatthias Springer if (auto tensorType = dyn_cast<TensorType>(resultType)) { 417217700baSMatthias Springer BaseMemRefType resultType = options.functionArgTypeConverterFn( 418217700baSMatthias Springer tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, 419217700baSMatthias Springer options); 420217700baSMatthias Springer retTypes.push_back(resultType); 421217700baSMatthias Springer continue; 422217700baSMatthias Springer } 423e07a7fd5SMatthias Springer retTypes.push_back(resultType); 424e07a7fd5SMatthias Springer } 425217700baSMatthias Springer 426217700baSMatthias Springer // Compute the new function type. 427217700baSMatthias Springer auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes); 428217700baSMatthias Springer 429217700baSMatthias Springer // If the function has no body, set the new function type and we are done. 430217700baSMatthias Springer if (funcOp.isExternal()) { 431217700baSMatthias Springer funcOp.setType(newFuncType); 432e07a7fd5SMatthias Springer return success(); 433e07a7fd5SMatthias Springer } 434e07a7fd5SMatthias Springer 435a88732d9SMatthias Springer // 1. Bufferize every block. 436a88732d9SMatthias Springer for (Block &block : funcOp.getBody()) 437a88732d9SMatthias Springer if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, 438a88732d9SMatthias Springer options))) 43906dacf5eSMatthias Springer return failure(); 440e07a7fd5SMatthias Springer 441*b0a4e958SMatthias Springer // 2. Bufferize the operands of the all return op. 442*b0a4e958SMatthias Springer for (func::ReturnOp returnOp : getReturnOps(funcOp)) { 443*b0a4e958SMatthias Springer assert(returnOp->getNumOperands() == retTypes.size() && 444*b0a4e958SMatthias Springer "incorrect number of return values"); 445e07a7fd5SMatthias Springer SmallVector<Value> returnValues; 446217700baSMatthias Springer for (auto [returnVal, bufferizedType] : 447217700baSMatthias Springer llvm::zip_equal(returnOp->getOperands(), retTypes)) { 4485550c821STres Popp auto tensorType = dyn_cast<TensorType>(returnVal.getType()); 449f287da8aSMatthias Springer rewriter.setInsertionPoint(returnOp); 450e07a7fd5SMatthias Springer 451e07a7fd5SMatthias Springer // If not a tensor type just forward it. 452f287da8aSMatthias Springer if (!tensorType) { 453e07a7fd5SMatthias Springer returnValues.push_back(returnVal); 454e07a7fd5SMatthias Springer continue; 455e07a7fd5SMatthias Springer } 456e07a7fd5SMatthias Springer 457217700baSMatthias Springer // Note: If `inferFunctionResultLayout = true`, casts are later folded 45875ef84bfSOleg Shyshkov // away. 459f287da8aSMatthias Springer Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 460*b0a4e958SMatthias Springer returnOp.getLoc(), bufferizedType, returnVal); 461f287da8aSMatthias Springer returnValues.push_back(toMemrefOp); 462e07a7fd5SMatthias Springer } 463e07a7fd5SMatthias Springer 464b74192b7SRiver Riddle returnOp.getOperandsMutable().assign(returnValues); 465*b0a4e958SMatthias Springer } 466e07a7fd5SMatthias Springer 467217700baSMatthias Springer // 3. Set the new function type. 468217700baSMatthias Springer funcOp.setType(newFuncType); 469e07a7fd5SMatthias Springer return success(); 470e07a7fd5SMatthias Springer } 471e07a7fd5SMatthias Springer 472e07a7fd5SMatthias Springer /// Return `true` if the given function argument is writable. 473e07a7fd5SMatthias Springer bool isWritable(Operation *op, Value value, 474e07a7fd5SMatthias Springer const AnalysisState &state) const { 475e07a7fd5SMatthias Springer auto funcOp = cast<FuncOp>(op); 4765550c821STres Popp BlockArgument bbArg = dyn_cast<BlockArgument>(value); 477e07a7fd5SMatthias Springer assert(bbArg && "expected BlockArgument"); 478e07a7fd5SMatthias Springer 4796ecebb49SMatthias Springer // Non-entry block arguments are always writable. (They may alias with 4806ecebb49SMatthias Springer // values that are not writable, which will turn them into read-only.) 4816ecebb49SMatthias Springer if (bbArg.getOwner() != &funcOp.getBody().front()) 4826ecebb49SMatthias Springer return true; 4836ecebb49SMatthias Springer 484e07a7fd5SMatthias Springer // "bufferization.writable" overrides other writability decisions. This is 485e07a7fd5SMatthias Springer // currently used for testing only. 486e07a7fd5SMatthias Springer if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>( 487e07a7fd5SMatthias Springer bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName)) 488e07a7fd5SMatthias Springer return writable.getValue(); 489e07a7fd5SMatthias Springer 490e07a7fd5SMatthias Springer // All function arguments are writable by default. 491e07a7fd5SMatthias Springer return true; 492e07a7fd5SMatthias Springer } 493e07a7fd5SMatthias Springer }; 494e07a7fd5SMatthias Springer 495e07a7fd5SMatthias Springer } // namespace func_ext 496e07a7fd5SMatthias Springer } // namespace bufferization 497e07a7fd5SMatthias Springer } // namespace mlir 498e07a7fd5SMatthias Springer 499e07a7fd5SMatthias Springer void mlir::bufferization::func_ext:: 500e07a7fd5SMatthias Springer registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { 501e07a7fd5SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { 502e07a7fd5SMatthias Springer func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx); 503e07a7fd5SMatthias Springer func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx); 504e07a7fd5SMatthias Springer func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx); 505e07a7fd5SMatthias Springer }); 506e07a7fd5SMatthias Springer } 507