//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
#define GEN_PASS_DEF_ONESHOTBUFFERIZE
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

namespace {

static LayoutMapOption parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  if (s == "bottom-up-from-terminators")
    return OneShotBufferizationOptions::AnalysisHeuristic::
        BottomUpFromTerminators;
  if (s == "fuzzer")
    return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer;
  llvm_unreachable("invalid analysis heuristic option");
}

struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() = default;

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
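      // Each assignment below mirrors a tablegen-generated pass option onto
      // the programmatic options struct; e.g., a command-line invocation such
      // as
      //   mlir-opt --one-shot-bufferize="bufferize-function-boundaries=1"
      // (option names are defined in Passes.td) takes effect through these
      // assignments.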
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(
          parseLayoutMapOption(functionBoundaryTypeConversion));

      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' is allowed in "
            << getArgument();
        return signalPassFailure();
      }

      if (mustInferMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }

      if (useEncodingForMemorySpace) {
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }

      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value "
                  "for 'unknown-type-conversion'");
        return signalPassFailure();
      }
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // A filter may be specified via pass options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: all ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
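      // Reject the combination up front rather than silently ignoring one of
      // the two flags.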
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    // Erase if present.
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // We only care about newly created ops.
    if (previous.isSet())
      return;

    erasedOps.erase(op);

    // Gather statistics about allocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
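    // An op has tensor semantics if any of its operands, results, or region
    // block arguments is typed as a tensor.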
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace

LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationStatistics *statistics) {
  if (options.copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (options.isOpAllowed(op) && hasTensorSemantics(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Check for unsupported unstructured control flow.
    if (!bufferizableOp.supportsUnstructuredControlFlow())
      for (Region &r : nextOp->getRegions())
        if (r.getBlocks().size() > 1)
          return nextOp->emitOpError(
              "op or BufferizableOpInterface implementation does not support "
              "unstructured control flow, but at least one region has "
              "multiple blocks");

    // Bufferize the op.
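    // Note: The worklist may grow during this loop. The BufferizationRewriter
    // listener appends newly created ops that still have tensor semantics
    // (and are allowed by the op filter).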
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Return early if the top-level op is entirely gone.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(
        rewriter, cast<ToMemrefOp>(op), options);
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  // Check the result of bufferization. Return an error if an op was not
  // bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }

    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(bbArg, options);
    if (failed(memrefType))
      return failure();
    newTypes.push_back(*memrefType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
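    // Materialize a to_tensor op at the start of the block so that existing
    // users, which still expect a tensor value, keep a valid operand.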
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
          bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Bufferize callers of the block.
  for (Operation *op : block->getUsers()) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");

    auto it = llvm::find(op->getSuccessors(), block);
    assert(it != op->getSuccessors().end() && "could not find successor");
    int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);

    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), newTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BaseMemRefType> operandBufferType =
          bufferization::getBufferType(operand, options);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
          operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = rewriter.create<memref::CastOp>(
            operand.getLoc(), type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}