//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention:
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
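//
// For illustration (hypothetical IR, analogous to the example above): because
// `@ext` below has no body, its tensor argument is conservatively assumed to
// be both read and written. Since `%t0` is read again after the call (by the
// tensor.extract op), the CallOp operand must bufferize out-of-place.
// ```
// func private @ext(tensor<?xf32>) -> tensor<?xf32>
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @ext(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```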

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
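  // For example (a hypothetical function, assuming the tensor.insert
  // bufferizes in-place):
  //   func @f(%t: tensor<?xf32>, %f: f32, %i: index) -> tensor<?xf32> {
  //     %0 = tensor.insert %f into %t[%i] : tensor<?xf32>
  //     return %0 : tensor<?xf32>
  //   }
  // The return operand %0 is equivalent to bbArg #0 (%t), so the helper
  // returns 0 for it.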
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

/// Annotate the given FuncOp argument with its inferred buffer access
/// behavior: "none", "read", "write" or "read-write". For testing purposes
/// only.
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior was specified on the function argument. Skip
      // the analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
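      // For example, a tensor bbArg whose only use is a tensor.extract is
      // inferred as read but not written.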
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                OneShotAnalysisState &state,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
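///
/// For example (a hypothetical module): if @a calls @b, @b calls @c, and all
/// three signatures contain tensors, the resulting order is [@c, @b, @a] and
/// `remainingFuncOps` stays empty. If @a and @b call each other instead,
/// neither can be ordered and both are placed in `remainingFuncOps`.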
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e., the set of
  // FuncOps that contain a func::CallOp to it).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct FuncOps it calls (only counting
  // callees that have tensors in their signatures).
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();

  // Iteratively remove function operations that do not call any of the
  // functions remaining in the map and add them to the ordered list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      break;
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the type of the i-th operand is not the same for all
/// func.return ops, then the i-th returned type is an "empty" (null) type.
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more precise memref type can potentially be used
/// for the return type of the function.
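///
/// For example (a sketch; types are illustrative):
///
///   func @f(...) -> memref<?xf32, strided<[?], offset: ?>> {
///     %alloc = memref.alloc(...) : memref<?xf32>
///     %cast = memref.cast %alloc
///         : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
///     return %cast : memref<?xf32, strided<[?], offset: ?>>
///   }
///
/// is rewritten so that `%alloc` is returned directly and the function type
/// is updated to return `memref<?xf32>`.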
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Only fold casts for which a common result type was found.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Start analyzing the function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }
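
  // Note: At this point, `funcState` contains, for each analyzed function,
  // the read/written bbArg sets as well as the bbArg <-> return value
  // equivalence and aliasing info that the func.call bufferization patterns
  // query when processing callers.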

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order, i.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // For those, we may use unnecessarily "complex" (in terms of layout map)
  // buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here, but we cannot because
    // `aliasInfo` would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op))
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops within these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
    return failure();
  return success();
}