1e07a7fd5SMatthias Springer //===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// 2e07a7fd5SMatthias Springer // 3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e07a7fd5SMatthias Springer // 7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===// 8e07a7fd5SMatthias Springer // 9e07a7fd5SMatthias Springer // Module Bufferization is an extension of One-Shot Bufferize that 10e07a7fd5SMatthias Springer // bufferizes function boundaries. It provides `BufferizableOpInterface` 11e07a7fd5SMatthias Springer // implementations for FuncOp, CallOp and ReturnOp. 12e07a7fd5SMatthias Springer // 13e07a7fd5SMatthias Springer // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. 14e07a7fd5SMatthias Springer // This function analyzes the given module and determines the order of analysis 15e07a7fd5SMatthias Springer // and bufferization: Functions that are called are processed before their 16e07a7fd5SMatthias Springer // respective callers. 17e07a7fd5SMatthias Springer // 18e07a7fd5SMatthias Springer // After analyzing a FuncOp, additional information about its bbArgs is 193490aadfSMatthias Springer // gathered and stored in `FuncAnalysisState`. 20e07a7fd5SMatthias Springer // 21e07a7fd5SMatthias Springer // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs 22e07a7fd5SMatthias Springer // for 23e07a7fd5SMatthias Springer // each tensor return value (if any). 24e07a7fd5SMatthias Springer // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is 25e07a7fd5SMatthias Springer // read/written. 26e07a7fd5SMatthias Springer // 27e07a7fd5SMatthias Springer // Module Bufferization implements the following calling convention. 28e07a7fd5SMatthias Springer // 29e07a7fd5SMatthias Springer // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always 30e07a7fd5SMatthias Springer // be written to in-place. 31e07a7fd5SMatthias Springer // * If a tensor operand of a CallOp is read after the CallOp, the operand of 32e07a7fd5SMatthias Springer // the CallOp must bufferize out-of-place. 33e07a7fd5SMatthias Springer // 34e07a7fd5SMatthias Springer // Example: The tensor.insert op bufferizes in-place because it is allowed to 35e07a7fd5SMatthias Springer // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize 36e07a7fd5SMatthias Springer // out-of-place because `%t0` is modified by the callee but read by the 37e07a7fd5SMatthias Springer // tensor.extract op. The analysis of CallOps decides whether an OpOperand must 38e07a7fd5SMatthias Springer // bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. 39e07a7fd5SMatthias Springer // ``` 40e07a7fd5SMatthias Springer // func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> { 41e07a7fd5SMatthias Springer // %f = ... : f32 42e07a7fd5SMatthias Springer // %0 = tensor.insert %f into %t1[...] : tensor<?xf32> 43e07a7fd5SMatthias Springer // return %0 : tensor<?xf32> 44e07a7fd5SMatthias Springer // } 45e07a7fd5SMatthias Springer // 46e07a7fd5SMatthias Springer // func @caller() -> () { 47e07a7fd5SMatthias Springer // %t0 = ... : tensor<?xf32> 48e07a7fd5SMatthias Springer // %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>) 49e07a7fd5SMatthias Springer // %2 = tensor.extract %1[...] : tensor<?xf32> 50e07a7fd5SMatthias Springer // } 51e07a7fd5SMatthias Springer // ``` 52e07a7fd5SMatthias Springer // 53e07a7fd5SMatthias Springer // Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot 54e07a7fd5SMatthias Springer // analyze the function body. In such a case, the CallOp analysis conservatively 55e07a7fd5SMatthias Springer // assumes that each tensor OpOperand is both read and written. 56e07a7fd5SMatthias Springer // 57e07a7fd5SMatthias Springer // TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked 58e07a7fd5SMatthias Springer // as "not reading" and/or "not writing". 59e07a7fd5SMatthias Springer 60e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 61e07a7fd5SMatthias Springer 62e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 63e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 64e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 65e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 66e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 6728b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 68e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 69e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 70971b8525SJakub Kuderski #include "mlir/IR/BuiltinTypes.h" 71e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 72e07a7fd5SMatthias Springer 73e07a7fd5SMatthias Springer using namespace mlir; 74e07a7fd5SMatthias Springer using namespace mlir::bufferization; 75e07a7fd5SMatthias Springer using namespace mlir::bufferization::func_ext; 76e07a7fd5SMatthias Springer 77e07a7fd5SMatthias Springer /// A mapping of FuncOps to their callers. 7891c11574SAndrzej Warzyński using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>; 79e07a7fd5SMatthias Springer 80e07a7fd5SMatthias Springer /// Get or create FuncAnalysisState. 81faa9be75SMatthias Springer static FuncAnalysisState & 82faa9be75SMatthias Springer getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { 83faa9be75SMatthias Springer auto *result = state.getExtension<FuncAnalysisState>(); 84faa9be75SMatthias Springer if (result) 85faa9be75SMatthias Springer return *result; 86faa9be75SMatthias Springer return state.addExtension<FuncAnalysisState>(); 87e07a7fd5SMatthias Springer } 88e07a7fd5SMatthias Springer 89e07a7fd5SMatthias Springer namespace { 90e07a7fd5SMatthias Springer 91e07a7fd5SMatthias Springer /// Annotate IR with the results of the analysis. For testing purposes only. 92e07a7fd5SMatthias Springer static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 93e07a7fd5SMatthias Springer BlockArgument bbArg) { 94e07a7fd5SMatthias Springer const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 95e07a7fd5SMatthias Springer Operation *op = returnVal.getOwner(); 96e07a7fd5SMatthias Springer 97e07a7fd5SMatthias Springer SmallVector<int64_t> equivBbArgs; 98e07a7fd5SMatthias Springer if (op->hasAttr(kEquivalentArgsAttr)) { 995550c821STres Popp auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr)); 100e07a7fd5SMatthias Springer equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 1015550c821STres Popp return cast<IntegerAttr>(a).getValue().getSExtValue(); 102e07a7fd5SMatthias Springer })); 103e07a7fd5SMatthias Springer } else { 104e07a7fd5SMatthias Springer equivBbArgs.append(op->getNumOperands(), -1); 105e07a7fd5SMatthias Springer } 106e07a7fd5SMatthias Springer equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 107e07a7fd5SMatthias Springer 108e07a7fd5SMatthias Springer OpBuilder b(op->getContext()); 109e07a7fd5SMatthias Springer op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 110e07a7fd5SMatthias Springer } 111e07a7fd5SMatthias Springer 112e07a7fd5SMatthias Springer /// Store function BlockArguments that are equivalent to/aliasing a returned 113e07a7fd5SMatthias Springer /// value in FuncAnalysisState. 114faa9be75SMatthias Springer static LogicalResult 11591c11574SAndrzej Warzyński aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, 116faa9be75SMatthias Springer FuncAnalysisState &funcState) { 11791c11574SAndrzej Warzyński if (funcOp.getBody().empty()) { 1184002eaaaSMatthias Springer // No function body available. Conservatively assume that every tensor 1194002eaaaSMatthias Springer // return value may alias with any tensor bbArg. 12091c11574SAndrzej Warzyński FunctionType type = funcOp.getFunctionType(); 12191c11574SAndrzej Warzyński for (const auto &inputIt : llvm::enumerate(type.getInputs())) { 1225550c821STres Popp if (!isa<TensorType>(inputIt.value())) 1234002eaaaSMatthias Springer continue; 12491c11574SAndrzej Warzyński for (const auto &resultIt : llvm::enumerate(type.getResults())) { 1255550c821STres Popp if (!isa<TensorType>(resultIt.value())) 1264002eaaaSMatthias Springer continue; 1274002eaaaSMatthias Springer int64_t returnIdx = resultIt.index(); 1284002eaaaSMatthias Springer int64_t bbArgIdx = inputIt.index(); 1294002eaaaSMatthias Springer funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); 1304002eaaaSMatthias Springer } 1314002eaaaSMatthias Springer } 1324002eaaaSMatthias Springer return success(); 1334002eaaaSMatthias Springer } 1344002eaaaSMatthias Springer 135*b0a4e958SMatthias Springer // Find all func.return ops. 136*b0a4e958SMatthias Springer SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp); 137*b0a4e958SMatthias Springer assert(!returnOps.empty() && "expected at least one ReturnOp"); 138e07a7fd5SMatthias Springer 139*b0a4e958SMatthias Springer // Build alias sets. Merge all aliases from all func.return ops. 140*b0a4e958SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 1415550c821STres Popp if (isa<RankedTensorType>(bbArg.getType())) { 142e07a7fd5SMatthias Springer int64_t bbArgIdx = bbArg.getArgNumber(); 143*b0a4e958SMatthias Springer // Store aliases in a set, so that we don't add the same alias twice. 144*b0a4e958SMatthias Springer SetVector<int64_t> aliases; 145*b0a4e958SMatthias Springer for (func::ReturnOp returnOp : returnOps) { 146*b0a4e958SMatthias Springer for (OpOperand &returnVal : returnOp->getOpOperands()) { 147*b0a4e958SMatthias Springer if (isa<RankedTensorType>(returnVal.get().getType())) { 148*b0a4e958SMatthias Springer int64_t returnIdx = returnVal.getOperandNumber(); 149b7858f85SMatthias Springer if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) 150*b0a4e958SMatthias Springer aliases.insert(returnIdx); 151*b0a4e958SMatthias Springer } 152*b0a4e958SMatthias Springer } 153*b0a4e958SMatthias Springer } 154*b0a4e958SMatthias Springer for (int64_t alias : aliases) 155*b0a4e958SMatthias Springer funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias); 156*b0a4e958SMatthias Springer } 157*b0a4e958SMatthias Springer } 158*b0a4e958SMatthias Springer 159*b0a4e958SMatthias Springer // Build equivalence sets. 160*b0a4e958SMatthias Springer // Helper function that finds an equivalent block argument index for the 161*b0a4e958SMatthias Springer // given OpOperand. Return std::nullopt if no equivalent block argument could 162*b0a4e958SMatthias Springer // be found. 163*b0a4e958SMatthias Springer auto findEquivalentBlockArgIdx = 164*b0a4e958SMatthias Springer [&](OpOperand &opOperand) -> std::optional<int64_t> { 165*b0a4e958SMatthias Springer Value v = opOperand.get(); 166*b0a4e958SMatthias Springer if (!isa<TensorType>(v.getType())) 167*b0a4e958SMatthias Springer return std::nullopt; 168*b0a4e958SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 169*b0a4e958SMatthias Springer if (isa<RankedTensorType>(bbArg.getType())) { 170*b0a4e958SMatthias Springer if (state.areEquivalentBufferizedValues(v, bbArg)) { 171*b0a4e958SMatthias Springer if (state.getOptions().testAnalysisOnly) 172*b0a4e958SMatthias Springer annotateEquivalentReturnBbArg(opOperand, bbArg); 173*b0a4e958SMatthias Springer return bbArg.getArgNumber(); 174*b0a4e958SMatthias Springer } 175*b0a4e958SMatthias Springer } 176*b0a4e958SMatthias Springer } 177*b0a4e958SMatthias Springer return std::nullopt; 178*b0a4e958SMatthias Springer }; 179*b0a4e958SMatthias Springer 180*b0a4e958SMatthias Springer int64_t numResults = returnOps.front()->getNumOperands(); 181*b0a4e958SMatthias Springer for (int64_t i = 0; i < numResults; ++i) { 182*b0a4e958SMatthias Springer // Find the equivalent block argument index for the i-th operand of the 183*b0a4e958SMatthias Springer // first func.return op. 184*b0a4e958SMatthias Springer std::optional<int64_t> maybeEquiv = 185*b0a4e958SMatthias Springer findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i)); 186*b0a4e958SMatthias Springer if (!maybeEquiv.has_value()) 187*b0a4e958SMatthias Springer continue; 188*b0a4e958SMatthias Springer int64_t bbArgIdx = *maybeEquiv; 189*b0a4e958SMatthias Springer bool allEquiv = true; 190*b0a4e958SMatthias Springer 191*b0a4e958SMatthias Springer // Check if all other func.return ops have the same equivalent block 192*b0a4e958SMatthias Springer // argument for the i-th operand. In contrast to aliasing information, 193*b0a4e958SMatthias Springer // which is just "merged", equivalence information must match across all 194*b0a4e958SMatthias Springer // func.return ops. 195*b0a4e958SMatthias Springer for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) { 196*b0a4e958SMatthias Springer std::optional<int64_t> maybeEquiv = 197*b0a4e958SMatthias Springer findEquivalentBlockArgIdx(returnOp->getOpOperand(i)); 198*b0a4e958SMatthias Springer if (maybeEquiv != bbArgIdx) { 199*b0a4e958SMatthias Springer allEquiv = false; 200*b0a4e958SMatthias Springer break; 201*b0a4e958SMatthias Springer } 202*b0a4e958SMatthias Springer } 203*b0a4e958SMatthias Springer 204*b0a4e958SMatthias Springer // All func.return ops have the same equivalent block argument for the i-th 205*b0a4e958SMatthias Springer // operand. 206*b0a4e958SMatthias Springer if (allEquiv) 207*b0a4e958SMatthias Springer funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx; 208e07a7fd5SMatthias Springer } 209e07a7fd5SMatthias Springer 210e07a7fd5SMatthias Springer return success(); 211e07a7fd5SMatthias Springer } 212e07a7fd5SMatthias Springer 21391c11574SAndrzej Warzyński static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, 21491c11574SAndrzej Warzyński bool isWritten) { 215e07a7fd5SMatthias Springer OpBuilder b(funcOp.getContext()); 216e07a7fd5SMatthias Springer Attribute accessType; 217e07a7fd5SMatthias Springer if (isRead && isWritten) { 218e07a7fd5SMatthias Springer accessType = b.getStringAttr("read-write"); 219e07a7fd5SMatthias Springer } else if (isRead) { 220e07a7fd5SMatthias Springer accessType = b.getStringAttr("read"); 221e07a7fd5SMatthias Springer } else if (isWritten) { 222e07a7fd5SMatthias Springer accessType = b.getStringAttr("write"); 223e07a7fd5SMatthias Springer } else { 224e07a7fd5SMatthias Springer accessType = b.getStringAttr("none"); 225e07a7fd5SMatthias Springer } 2264002eaaaSMatthias Springer funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName, 2274002eaaaSMatthias Springer accessType); 228e07a7fd5SMatthias Springer } 229e07a7fd5SMatthias Springer 2303490aadfSMatthias Springer /// Determine which FuncOp bbArgs are read and which are written. When run on a 2313490aadfSMatthias Springer /// function with unknown ops, we conservatively assume that such ops bufferize 2323490aadfSMatthias Springer /// to a read + write. 233faa9be75SMatthias Springer static LogicalResult 23491c11574SAndrzej Warzyński funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, 235faa9be75SMatthias Springer FuncAnalysisState &funcState) { 23691c11574SAndrzej Warzyński for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; 23791c11574SAndrzej Warzyński ++idx) { 2384002eaaaSMatthias Springer // Skip non-tensor arguments. 23991c11574SAndrzej Warzyński if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx))) 2404002eaaaSMatthias Springer continue; 2414002eaaaSMatthias Springer bool isRead; 2424002eaaaSMatthias Springer bool isWritten; 2434002eaaaSMatthias Springer if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>( 2444002eaaaSMatthias Springer idx, BufferizationDialect::kBufferAccessAttrName)) { 2454002eaaaSMatthias Springer // Buffer access behavior is specified on the function. Skip the analysis. 2464002eaaaSMatthias Springer StringRef str = accessAttr.getValue(); 2474002eaaaSMatthias Springer isRead = str == "read" || str == "read-write"; 2484002eaaaSMatthias Springer isWritten = str == "write" || str == "read-write"; 24991c11574SAndrzej Warzyński } else if (funcOp.getBody().empty()) { 250e07a7fd5SMatthias Springer // If the function has no body, conservatively assume that all args are 251e07a7fd5SMatthias Springer // read + written. 2524002eaaaSMatthias Springer isRead = true; 2534002eaaaSMatthias Springer isWritten = true; 2544002eaaaSMatthias Springer } else { 2554002eaaaSMatthias Springer // Analyze the body of the function. 2564002eaaaSMatthias Springer BlockArgument bbArg = funcOp.getArgument(idx); 2574002eaaaSMatthias Springer isRead = state.isValueRead(bbArg); 2584002eaaaSMatthias Springer isWritten = state.isValueWritten(bbArg); 259e07a7fd5SMatthias Springer } 260e07a7fd5SMatthias Springer 261e07a7fd5SMatthias Springer if (state.getOptions().testAnalysisOnly) 2624002eaaaSMatthias Springer annotateFuncArgAccess(funcOp, idx, isRead, isWritten); 263e07a7fd5SMatthias Springer if (isRead) 2644002eaaaSMatthias Springer funcState.readBbArgs[funcOp].insert(idx); 265e07a7fd5SMatthias Springer if (isWritten) 2664002eaaaSMatthias Springer funcState.writtenBbArgs[funcOp].insert(idx); 267e07a7fd5SMatthias Springer } 268e07a7fd5SMatthias Springer 269e07a7fd5SMatthias Springer return success(); 270e07a7fd5SMatthias Springer } 271e07a7fd5SMatthias Springer } // namespace 272e07a7fd5SMatthias Springer 273e07a7fd5SMatthias Springer /// Remove bufferization attributes on FuncOp arguments. 274e07a7fd5SMatthias Springer static void removeBufferizationAttributes(BlockArgument bbArg) { 27591c11574SAndrzej Warzyński auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 276e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 277e07a7fd5SMatthias Springer BufferizationDialect::kBufferLayoutAttrName); 278e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 279e07a7fd5SMatthias Springer BufferizationDialect::kWritableAttrName); 280e07a7fd5SMatthias Springer } 281e07a7fd5SMatthias Springer 28291c11574SAndrzej Warzyński /// Return the func::FuncOp called by `callOp`. 28391c11574SAndrzej Warzyński static func::FuncOp getCalledFunction(func::CallOp callOp) { 2849d34c052SMatthias Springer SymbolRefAttr sym = 2859d34c052SMatthias Springer llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); 286e07a7fd5SMatthias Springer if (!sym) 287e07a7fd5SMatthias Springer return nullptr; 28891c11574SAndrzej Warzyński return dyn_cast_or_null<func::FuncOp>( 289e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 290e07a7fd5SMatthias Springer } 291e07a7fd5SMatthias Springer 292e07a7fd5SMatthias Springer /// Gather equivalence info of CallOps. 293e07a7fd5SMatthias Springer /// Note: This only adds new equivalence info if the called function was already 294e07a7fd5SMatthias Springer /// analyzed. 295e07a7fd5SMatthias Springer // TODO: This does not handle cyclic function call graphs etc. 29691c11574SAndrzej Warzyński static void equivalenceAnalysis(func::FuncOp funcOp, 297faa9be75SMatthias Springer OneShotAnalysisState &state, 298faa9be75SMatthias Springer FuncAnalysisState &funcState) { 29991c11574SAndrzej Warzyński funcOp->walk([&](func::CallOp callOp) { 30091c11574SAndrzej Warzyński func::FuncOp calledFunction = getCalledFunction(callOp); 30191c11574SAndrzej Warzyński assert(calledFunction && "could not retrieved called func::FuncOp"); 302e07a7fd5SMatthias Springer 303e07a7fd5SMatthias Springer // No equivalence info available for the called function. 304e07a7fd5SMatthias Springer if (!funcState.equivalentFuncArgs.count(calledFunction)) 305e07a7fd5SMatthias Springer return WalkResult::skip(); 306e07a7fd5SMatthias Springer 307e07a7fd5SMatthias Springer for (auto it : funcState.equivalentFuncArgs[calledFunction]) { 308e07a7fd5SMatthias Springer int64_t returnIdx = it.first; 309e07a7fd5SMatthias Springer int64_t bbargIdx = it.second; 310bf582569SMatthias Springer if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) 311bf582569SMatthias Springer continue; 31291c11574SAndrzej Warzyński Value returnVal = callOp.getResult(returnIdx); 313e07a7fd5SMatthias Springer Value argVal = callOp->getOperand(bbargIdx); 314cf2d374eSMatthias Springer state.unionEquivalenceClasses(returnVal, argVal); 315e07a7fd5SMatthias Springer } 316e07a7fd5SMatthias Springer 317e07a7fd5SMatthias Springer return WalkResult::advance(); 318e07a7fd5SMatthias Springer }); 319e07a7fd5SMatthias Springer } 320e07a7fd5SMatthias Springer 3213d0ca2cfSMatthias Springer /// Return "true" if the given function signature has tensor semantics. 32291c11574SAndrzej Warzyński static bool hasTensorSignature(func::FuncOp funcOp) { 32391c11574SAndrzej Warzyński return llvm::any_of(funcOp.getFunctionType().getInputs(), 32491c11574SAndrzej Warzyński llvm::IsaPred<TensorType>) || 32591c11574SAndrzej Warzyński llvm::any_of(funcOp.getFunctionType().getResults(), 32691c11574SAndrzej Warzyński llvm::IsaPred<TensorType>); 3273d0ca2cfSMatthias Springer } 3283d0ca2cfSMatthias Springer 329e07a7fd5SMatthias Springer /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by 330c271ba7fSMatthias Springer /// callee-caller order (i.e., callees without callers first). Store all 331c271ba7fSMatthias Springer /// remaining functions (i.e., the ones that call each other recursively) in 332c271ba7fSMatthias Springer /// `remainingFuncOps`. 333c271ba7fSMatthias Springer /// 334e07a7fd5SMatthias Springer /// Store the map of FuncOp to all its callers in `callerMap`. 335c271ba7fSMatthias Springer /// 336c271ba7fSMatthias Springer /// Return `failure()` if we are unable to retrieve the called FuncOp from 337c271ba7fSMatthias Springer /// any func::CallOp. 338c271ba7fSMatthias Springer static LogicalResult getFuncOpsOrderedByCalls( 339c271ba7fSMatthias Springer ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, 340c271ba7fSMatthias Springer SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) { 341e07a7fd5SMatthias Springer // For each FuncOp, the set of functions called by it (i.e. the union of 342dc700f1eSIngo Müller // symbols of all nested func::CallOp). 34391c11574SAndrzej Warzyński DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; 344dc700f1eSIngo Müller // For each FuncOp, the number of func::CallOp it contains. 34591c11574SAndrzej Warzyński DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; 34691c11574SAndrzej Warzyński WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { 3473d0ca2cfSMatthias Springer // Collect function calls and populate the caller map. 348e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp] = 0; 34991c11574SAndrzej Warzyński return funcOp.walk([&](func::CallOp callOp) -> WalkResult { 35091c11574SAndrzej Warzyński func::FuncOp calledFunction = getCalledFunction(callOp); 35191c11574SAndrzej Warzyński assert(calledFunction && "could not retrieved called func::FuncOp"); 3523d0ca2cfSMatthias Springer // If the called function does not have any tensors in its signature, then 3533d0ca2cfSMatthias Springer // it is not necessary to bufferize the callee before the caller. 3543d0ca2cfSMatthias Springer if (!hasTensorSignature(calledFunction)) 3553d0ca2cfSMatthias Springer return WalkResult::skip(); 3563d0ca2cfSMatthias Springer 35786fd1c13SBenjamin Kramer callerMap[calledFunction].insert(callOp); 35886fd1c13SBenjamin Kramer if (calledBy[calledFunction].insert(funcOp).second) { 359e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp]++; 360e07a7fd5SMatthias Springer } 361e07a7fd5SMatthias Springer return WalkResult::advance(); 362e07a7fd5SMatthias Springer }); 363e07a7fd5SMatthias Springer }); 364e07a7fd5SMatthias Springer if (res.wasInterrupted()) 365e07a7fd5SMatthias Springer return failure(); 366c271ba7fSMatthias Springer 3673d0ca2cfSMatthias Springer // Iteratively remove function operations that do not call any of the 368c271ba7fSMatthias Springer // functions remaining in the callCounter map and add them to ordered list. 369e07a7fd5SMatthias Springer while (!numberCallOpsContainedInFuncOp.empty()) { 370e07a7fd5SMatthias Springer auto it = llvm::find_if(numberCallOpsContainedInFuncOp, 371e07a7fd5SMatthias Springer [](auto entry) { return entry.getSecond() == 0; }); 372e07a7fd5SMatthias Springer if (it == numberCallOpsContainedInFuncOp.end()) 373c271ba7fSMatthias Springer break; 374e07a7fd5SMatthias Springer orderedFuncOps.push_back(it->getFirst()); 375e07a7fd5SMatthias Springer for (auto callee : calledBy[it->getFirst()]) 376e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[callee]--; 377e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp.erase(it); 378e07a7fd5SMatthias Springer } 379c271ba7fSMatthias Springer 380c271ba7fSMatthias Springer // Put all other functions in the list of remaining functions. These are 381c271ba7fSMatthias Springer // functions that call each other circularly. 382c271ba7fSMatthias Springer for (auto it : numberCallOpsContainedInFuncOp) 383c271ba7fSMatthias Springer remainingFuncOps.push_back(it.first); 384c271ba7fSMatthias Springer 385e07a7fd5SMatthias Springer return success(); 386e07a7fd5SMatthias Springer } 387e07a7fd5SMatthias Springer 388*b0a4e958SMatthias Springer /// Helper function that extracts the source from a memref.cast. If the given 389*b0a4e958SMatthias Springer /// value is not a memref.cast result, simply returns the given value. 390*b0a4e958SMatthias Springer static Value unpackCast(Value v) { 391*b0a4e958SMatthias Springer auto castOp = v.getDefiningOp<memref::CastOp>(); 392*b0a4e958SMatthias Springer if (!castOp) 393*b0a4e958SMatthias Springer return v; 394*b0a4e958SMatthias Springer return castOp.getSource(); 395*b0a4e958SMatthias Springer } 396*b0a4e958SMatthias Springer 397*b0a4e958SMatthias Springer /// Helper function that returns the return types (skipping casts) of the given 398*b0a4e958SMatthias Springer /// func.return ops. This function returns as many types as the return ops have 399*b0a4e958SMatthias Springer /// operands. If the i-th operand is not the same for all func.return ops, then 400*b0a4e958SMatthias Springer /// the i-th returned type is an "empty" type. 401*b0a4e958SMatthias Springer static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) { 402*b0a4e958SMatthias Springer assert(!returnOps.empty() && "expected at least one ReturnOp"); 403*b0a4e958SMatthias Springer int numOperands = returnOps.front()->getNumOperands(); 404*b0a4e958SMatthias Springer 405*b0a4e958SMatthias Springer // Helper function that unpacks memref.cast ops and returns the type. 406*b0a4e958SMatthias Springer auto getSourceType = [&](Value v) { return unpackCast(v).getType(); }; 407*b0a4e958SMatthias Springer 408*b0a4e958SMatthias Springer SmallVector<Type> result; 409*b0a4e958SMatthias Springer for (int i = 0; i < numOperands; ++i) { 410*b0a4e958SMatthias Springer // Get the type of the i-th operand of the first func.return ops. 411*b0a4e958SMatthias Springer Type t = getSourceType(returnOps.front()->getOperand(i)); 412*b0a4e958SMatthias Springer 413*b0a4e958SMatthias Springer // Check if all other func.return ops have a matching operand type. 414*b0a4e958SMatthias Springer for (int j = 1; j < static_cast<int>(returnOps.size()); ++j) 415*b0a4e958SMatthias Springer if (getSourceType(returnOps[j]->getOperand(i)) != t) 416*b0a4e958SMatthias Springer t = Type(); 417*b0a4e958SMatthias Springer 418*b0a4e958SMatthias Springer result.push_back(t); 419*b0a4e958SMatthias Springer } 420*b0a4e958SMatthias Springer 421*b0a4e958SMatthias Springer return result; 422*b0a4e958SMatthias Springer } 423*b0a4e958SMatthias Springer 424e07a7fd5SMatthias Springer /// Fold return values that are memref casts and update function return types. 425e07a7fd5SMatthias Springer /// 426e07a7fd5SMatthias Springer /// During FuncOp bufferization, the exact type of the returned memrefs (if any) 427e07a7fd5SMatthias Springer /// is not known yet. Therefore, the bufferization uses memref types with the 428e07a7fd5SMatthias Springer /// most generic layout map as function return types. After bufferizing the 429e07a7fd5SMatthias Springer /// entire function body, a more concise memref type can potentially be used for 430e07a7fd5SMatthias Springer /// the return type of the function. 43191c11574SAndrzej Warzyński static void foldMemRefCasts(func::FuncOp funcOp) { 432*b0a4e958SMatthias Springer // There is nothing to do for bodiless ops. 43391c11574SAndrzej Warzyński if (funcOp.getBody().empty()) 434e07a7fd5SMatthias Springer return; 435e07a7fd5SMatthias Springer 436*b0a4e958SMatthias Springer // Compute the common result types of all return ops. 437*b0a4e958SMatthias Springer SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp); 438*b0a4e958SMatthias Springer SmallVector<Type> resultTypes = getReturnTypes(returnOps); 439e07a7fd5SMatthias Springer 440*b0a4e958SMatthias Springer // Remove direct casts. 441*b0a4e958SMatthias Springer for (func::ReturnOp returnOp : returnOps) { 442e07a7fd5SMatthias Springer for (OpOperand &operand : returnOp->getOpOperands()) { 443*b0a4e958SMatthias Springer // Bail if no common result type was found. 444*b0a4e958SMatthias Springer if (resultTypes[operand.getOperandNumber()]) { 445*b0a4e958SMatthias Springer operand.set(unpackCast(operand.get())); 446*b0a4e958SMatthias Springer } 447e07a7fd5SMatthias Springer } 448e07a7fd5SMatthias Springer } 449e07a7fd5SMatthias Springer 450*b0a4e958SMatthias Springer // Fill in the missing result types that were not the same among all 451*b0a4e958SMatthias Springer // func.return ops. 452*b0a4e958SMatthias Springer for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) { 453*b0a4e958SMatthias Springer if (resultTypes[i]) 454*b0a4e958SMatthias Springer continue; 455*b0a4e958SMatthias Springer resultTypes[i] = funcOp.getFunctionType().getResult(i); 456*b0a4e958SMatthias Springer } 457*b0a4e958SMatthias Springer 458*b0a4e958SMatthias Springer // Update the function type. 45991c11574SAndrzej Warzyński auto newFuncType = FunctionType::get( 46091c11574SAndrzej Warzyński funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); 461e07a7fd5SMatthias Springer funcOp.setType(newFuncType); 462e07a7fd5SMatthias Springer } 463e07a7fd5SMatthias Springer 464f470f8cbSMatthias Springer LogicalResult 465f470f8cbSMatthias Springer mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, 466ae05bd99SMatthias Springer OneShotAnalysisState &state, 467ae05bd99SMatthias Springer BufferizationStatistics *statistics) { 4687cdfc843SMatthias Springer assert(state.getOptions().bufferizeFunctionBoundaries && 469d6dab38aSMatthias Springer "expected that function boundary bufferization is activated"); 470faa9be75SMatthias Springer FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); 471e07a7fd5SMatthias Springer 472c271ba7fSMatthias Springer // A list of non-circular functions in the order in which they are analyzed 473c271ba7fSMatthias Springer // and bufferized. 47491c11574SAndrzej Warzyński SmallVector<func::FuncOp> orderedFuncOps; 475c271ba7fSMatthias Springer // A list of all other functions. I.e., functions that call each other 476c271ba7fSMatthias Springer // recursively. For these, we analyze the function body but not the function 477c271ba7fSMatthias Springer // boundary. 478c271ba7fSMatthias Springer SmallVector<func::FuncOp> remainingFuncOps; 479e07a7fd5SMatthias Springer 480e07a7fd5SMatthias Springer // A mapping of FuncOps to their callers. 481e07a7fd5SMatthias Springer FuncCallerMap callerMap; 482e07a7fd5SMatthias Springer 483c271ba7fSMatthias Springer if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, 484c271ba7fSMatthias Springer remainingFuncOps, callerMap))) 485e07a7fd5SMatthias Springer return failure(); 486e07a7fd5SMatthias Springer 487c271ba7fSMatthias Springer // Analyze functions in order. Starting with functions that are not calling 488c271ba7fSMatthias Springer // any other functions. 48991c11574SAndrzej Warzyński for (func::FuncOp funcOp : orderedFuncOps) { 490060c8be5SMaya Amrami if (!state.getOptions().isOpAllowed(funcOp)) 491060c8be5SMaya Amrami continue; 492060c8be5SMaya Amrami 493e07a7fd5SMatthias Springer // Now analyzing function. 494e07a7fd5SMatthias Springer funcState.startFunctionAnalysis(funcOp); 495e07a7fd5SMatthias Springer 496e07a7fd5SMatthias Springer // Gather equivalence info for CallOps. 497cf2d374eSMatthias Springer equivalenceAnalysis(funcOp, state, funcState); 498e07a7fd5SMatthias Springer 499e07a7fd5SMatthias Springer // Analyze funcOp. 500ae05bd99SMatthias Springer if (failed(analyzeOp(funcOp, state, statistics))) 501e07a7fd5SMatthias Springer return failure(); 502e07a7fd5SMatthias Springer 5033490aadfSMatthias Springer // Run some extra function analyses. 504faa9be75SMatthias Springer if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) || 505faa9be75SMatthias Springer failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState))) 5063490aadfSMatthias Springer return failure(); 5073490aadfSMatthias Springer 508e07a7fd5SMatthias Springer // Mark op as fully analyzed. 509e07a7fd5SMatthias Springer funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; 510e07a7fd5SMatthias Springer } 511e07a7fd5SMatthias Springer 512c271ba7fSMatthias Springer // Analyze all other functions. All function boundary analyses are skipped. 513c271ba7fSMatthias Springer for (func::FuncOp funcOp : remainingFuncOps) { 514c271ba7fSMatthias Springer if (!state.getOptions().isOpAllowed(funcOp)) 515c271ba7fSMatthias Springer continue; 516c271ba7fSMatthias Springer 517c271ba7fSMatthias Springer // Gather equivalence info for CallOps. 518c271ba7fSMatthias Springer equivalenceAnalysis(funcOp, state, funcState); 519c271ba7fSMatthias Springer 520c271ba7fSMatthias Springer // Analyze funcOp. 521c271ba7fSMatthias Springer if (failed(analyzeOp(funcOp, state, statistics))) 522c271ba7fSMatthias Springer return failure(); 523c271ba7fSMatthias Springer 524c271ba7fSMatthias Springer // TODO: We currently skip all function argument analyses for functions 525c271ba7fSMatthias Springer // that call each other circularly. These analyses do not support recursive 526c271ba7fSMatthias Springer // calls yet. The `BufferizableOpInterface` implementations of `func` 527c271ba7fSMatthias Springer // dialect ops return conservative results in the absence of analysis 528c271ba7fSMatthias Springer // information. 529c271ba7fSMatthias Springer } 530c271ba7fSMatthias Springer 531e07a7fd5SMatthias Springer return success(); 532f470f8cbSMatthias Springer } 533f470f8cbSMatthias Springer 534c7a9e5e5SPeiming Liu void mlir::bufferization::removeBufferizationAttributesInModule( 535c7a9e5e5SPeiming Liu ModuleOp moduleOp) { 53691c11574SAndrzej Warzyński moduleOp.walk([&](func::FuncOp op) { 537c7a9e5e5SPeiming Liu for (BlockArgument bbArg : op.getArguments()) 538c7a9e5e5SPeiming Liu removeBufferizationAttributes(bbArg); 539c7a9e5e5SPeiming Liu }); 540c7a9e5e5SPeiming Liu } 541c7a9e5e5SPeiming Liu 542f470f8cbSMatthias Springer LogicalResult mlir::bufferization::bufferizeModuleOp( 543ae05bd99SMatthias Springer ModuleOp moduleOp, const OneShotBufferizationOptions &options, 5449cf96850SMaya Amrami BufferizationStatistics *statistics) { 545f470f8cbSMatthias Springer assert(options.bufferizeFunctionBoundaries && 546f470f8cbSMatthias Springer "expected that function boundary bufferization is activated"); 547f470f8cbSMatthias Springer IRRewriter rewriter(moduleOp.getContext()); 548f470f8cbSMatthias Springer 549c271ba7fSMatthias Springer // A list of non-circular functions in the order in which they are analyzed 550c271ba7fSMatthias Springer // and bufferized. 55191c11574SAndrzej Warzyński SmallVector<func::FuncOp> orderedFuncOps; 552c271ba7fSMatthias Springer // A list of all other functions. I.e., functions that call each other 553c271ba7fSMatthias Springer // recursively. For these, we analyze the function body but not the function 554c271ba7fSMatthias Springer // boundary. 555c271ba7fSMatthias Springer SmallVector<func::FuncOp> remainingFuncOps; 556f470f8cbSMatthias Springer 557f470f8cbSMatthias Springer // A mapping of FuncOps to their callers. 558f470f8cbSMatthias Springer FuncCallerMap callerMap; 559f470f8cbSMatthias Springer 560c271ba7fSMatthias Springer // Try to bufferize functions in calling order. I.e., first bufferize 561c271ba7fSMatthias Springer // functions that do not call other functions. This allows us to infer 562c271ba7fSMatthias Springer // accurate buffer types for function return values. Functions that call 563c271ba7fSMatthias Springer // each other recursively are bufferized in an unspecified order at the end. 564c271ba7fSMatthias Springer // We may use unnecessarily "complex" (in terms of layout map) buffer types. 565c271ba7fSMatthias Springer if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, 566c271ba7fSMatthias Springer remainingFuncOps, callerMap))) 567f470f8cbSMatthias Springer return failure(); 568c271ba7fSMatthias Springer llvm::append_range(orderedFuncOps, remainingFuncOps); 569e07a7fd5SMatthias Springer 570e07a7fd5SMatthias Springer // Bufferize functions. 57191c11574SAndrzej Warzyński for (func::FuncOp funcOp : orderedFuncOps) { 572e07a7fd5SMatthias Springer // Note: It would be good to apply cleanups here but we cannot as aliasInfo 573e07a7fd5SMatthias Springer // would be invalidated. 5749d34c052SMatthias Springer 57591c11574SAndrzej Warzyński if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) { 5769d34c052SMatthias Springer // This function was not analyzed and RaW conflicts were not resolved. 5779d34c052SMatthias Springer // Buffer copies must be inserted before every write. 5789d34c052SMatthias Springer OneShotBufferizationOptions updatedOptions = options; 5799d34c052SMatthias Springer updatedOptions.copyBeforeWrite = true; 5809d34c052SMatthias Springer if (failed(bufferizeOp(funcOp, updatedOptions, statistics))) 581e07a7fd5SMatthias Springer return failure(); 5829d34c052SMatthias Springer } else { 5839d34c052SMatthias Springer if (failed(bufferizeOp(funcOp, options, statistics))) 5849d34c052SMatthias Springer return failure(); 5859d34c052SMatthias Springer } 5869d34c052SMatthias Springer 587f287da8aSMatthias Springer // Change buffer return types to more precise layout maps. 58875ef84bfSOleg Shyshkov if (options.inferFunctionResultLayout) 589e07a7fd5SMatthias Springer foldMemRefCasts(funcOp); 590e07a7fd5SMatthias Springer } 591e07a7fd5SMatthias Springer 5928f2d83daSMatthias Springer // Bufferize all other ops. 593fa101214SRyan Holt for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { 5948f2d83daSMatthias Springer // Functions were already bufferized. 59591c11574SAndrzej Warzyński if (isa<func::FuncOp>(&op)) 5968f2d83daSMatthias Springer continue; 5978f2d83daSMatthias Springer if (failed(bufferizeOp(&op, options, statistics))) 5988f2d83daSMatthias Springer return failure(); 5998f2d83daSMatthias Springer } 6008f2d83daSMatthias Springer 601e07a7fd5SMatthias Springer // Post-pass cleanup of function argument attributes. 602c7a9e5e5SPeiming Liu removeBufferizationAttributesInModule(moduleOp); 603e07a7fd5SMatthias Springer 604e07a7fd5SMatthias Springer return success(); 605e07a7fd5SMatthias Springer } 606f470f8cbSMatthias Springer 607f470f8cbSMatthias Springer LogicalResult mlir::bufferization::runOneShotModuleBufferize( 608ae05bd99SMatthias Springer ModuleOp moduleOp, const OneShotBufferizationOptions &options, 6099cf96850SMaya Amrami BufferizationStatistics *statistics) { 610f470f8cbSMatthias Springer assert(options.bufferizeFunctionBoundaries && 611f470f8cbSMatthias Springer "expected that function boundary bufferization is activated"); 612f7dd9a32SMatthias Springer assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && 613f7dd9a32SMatthias Springer "invalid combination of bufferization flags"); 614f7dd9a32SMatthias Springer if (!options.copyBeforeWrite) { 6159cf96850SMaya Amrami if (options.noAnalysisFuncFilter.empty()) { 616ae05bd99SMatthias Springer if (failed(insertTensorCopies(moduleOp, options, statistics))) 617f470f8cbSMatthias Springer return failure(); 618060c8be5SMaya Amrami } else { 6199cf96850SMaya Amrami // FuncOps whose names are specified in options.noAnalysisFuncFilter will 6209cf96850SMaya Amrami // not be analyzed. Ops in these FuncOps will not be analyzed as well. 6219cf96850SMaya Amrami OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) { 62291c11574SAndrzej Warzyński auto func = dyn_cast<func::FuncOp>(op); 6239cf96850SMaya Amrami if (!func) 62491c11574SAndrzej Warzyński func = op->getParentOfType<func::FuncOp>(); 6259cf96850SMaya Amrami if (func) 6269cf96850SMaya Amrami return llvm::is_contained(options.noAnalysisFuncFilter, 62791c11574SAndrzej Warzyński func.getSymName()); 6289cf96850SMaya Amrami return false; 6299cf96850SMaya Amrami }; 630060c8be5SMaya Amrami OneShotBufferizationOptions updatedOptions(options); 631060c8be5SMaya Amrami updatedOptions.opFilter.denyOperation(analysisFilterFn); 632060c8be5SMaya Amrami if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics))) 633060c8be5SMaya Amrami return failure(); 634060c8be5SMaya Amrami } 635f7dd9a32SMatthias Springer } 636f470f8cbSMatthias Springer if (options.testAnalysisOnly) 637f470f8cbSMatthias Springer return success(); 6389cf96850SMaya Amrami if (failed(bufferizeModuleOp(moduleOp, options, statistics))) 639f470f8cbSMatthias Springer return failure(); 640f470f8cbSMatthias Springer return success(); 641f470f8cbSMatthias Springer } 642