1f89bb3c0SAlexander Belyaev //===- Bufferize.cpp - Bufferization utilities ----------------------------===// 2f89bb3c0SAlexander Belyaev // 3f89bb3c0SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f89bb3c0SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 5f89bb3c0SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f89bb3c0SAlexander Belyaev // 7f89bb3c0SAlexander Belyaev //===----------------------------------------------------------------------===// 8f89bb3c0SAlexander Belyaev 967d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 10f89bb3c0SAlexander Belyaev 117a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14d2dacde5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 15e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 1628b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 1723aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 18eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 19a6236898SMatthias Springer #include "mlir/IR/Diagnostics.h" 20f89bb3c0SAlexander Belyaev #include "mlir/IR/Operation.h" 216ecebb49SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h" 22fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 23d2dacde5SMatthias Springer #include "mlir/Pass/PassManager.h" 24d2dacde5SMatthias Springer #include "mlir/Transforms/Passes.h" 25a1fe1f5fSKazu Hirata #include <optional> 26f89bb3c0SAlexander Belyaev 2767d0d7acSMichele Scuttari namespace mlir { 2867d0d7acSMichele Scuttari namespace bufferization { 2967d0d7acSMichele Scuttari #define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE 3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_ONESHOTBUFFERIZE 3167d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 3267d0d7acSMichele Scuttari } // namespace bufferization 3367d0d7acSMichele Scuttari } // namespace mlir 3467d0d7acSMichele Scuttari 35df23ede2SMatthias Springer #define DEBUG_TYPE "bufferize" 36df23ede2SMatthias Springer 37f89bb3c0SAlexander Belyaev using namespace mlir; 38f89bb3c0SAlexander Belyaev using namespace mlir::bufferization; 39f89bb3c0SAlexander Belyaev 40f89bb3c0SAlexander Belyaev namespace { 41d2dacde5SMatthias Springer 42c780184aSLorenzo Chelini static LayoutMapOption parseLayoutMapOption(const std::string &s) { 43f287da8aSMatthias Springer if (s == "fully-dynamic-layout-map") 44c780184aSLorenzo Chelini return LayoutMapOption::FullyDynamicLayoutMap; 45f287da8aSMatthias Springer if (s == "identity-layout-map") 46c780184aSLorenzo Chelini return LayoutMapOption::IdentityLayoutMap; 47f287da8aSMatthias Springer if (s == "infer-layout-map") 48c780184aSLorenzo Chelini return LayoutMapOption::InferLayoutMap; 49f287da8aSMatthias Springer llvm_unreachable("invalid layout map option"); 50f287da8aSMatthias Springer } 51f287da8aSMatthias Springer 521b99f3a2SMatthias Springer static OneShotBufferizationOptions::AnalysisHeuristic 531b99f3a2SMatthias Springer parseHeuristicOption(const std::string &s) { 541b99f3a2SMatthias Springer if (s == "bottom-up") 551b99f3a2SMatthias Springer return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp; 561b99f3a2SMatthias Springer if (s == "top-down") 571b99f3a2SMatthias Springer return OneShotBufferizationOptions::AnalysisHeuristic::TopDown; 5835d3b343SMatthias Springer if (s == "bottom-up-from-terminators") 5935d3b343SMatthias Springer return OneShotBufferizationOptions::AnalysisHeuristic:: 6035d3b343SMatthias Springer BottomUpFromTerminators; 6135d3b343SMatthias Springer if (s == "fuzzer") 6235d3b343SMatthias Springer return OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer; 631b99f3a2SMatthias Springer llvm_unreachable("invalid analysisheuristic option"); 641b99f3a2SMatthias Springer } 651b99f3a2SMatthias Springer 66d2dacde5SMatthias Springer struct OneShotBufferizePass 6767d0d7acSMichele Scuttari : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> { 680666e50eSMehdi Amini OneShotBufferizePass() = default; 69d2dacde5SMatthias Springer 709597b16aSMatthias Springer explicit OneShotBufferizePass(const OneShotBufferizationOptions &options) 71d2dacde5SMatthias Springer : options(options) {} 72d2dacde5SMatthias Springer 73d2dacde5SMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 74c076fa1cSMatthias Springer registry 75c076fa1cSMatthias Springer .insert<bufferization::BufferizationDialect, memref::MemRefDialect>(); 76d2dacde5SMatthias Springer } 77d2dacde5SMatthias Springer 78d2dacde5SMatthias Springer void runOnOperation() override { 799597b16aSMatthias Springer OneShotBufferizationOptions opt; 80d2dacde5SMatthias Springer if (!options) { 81d2dacde5SMatthias Springer // Make new bufferization options if none were provided when creating the 82d2dacde5SMatthias Springer // pass. 836bf043e7SMartin Erhart opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops; 84d2dacde5SMatthias Springer opt.allowUnknownOps = allowUnknownOps; 85d2dacde5SMatthias Springer opt.analysisFuzzerSeed = analysisFuzzerSeed; 861b99f3a2SMatthias Springer opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic); 87f7dd9a32SMatthias Springer opt.copyBeforeWrite = copyBeforeWrite; 88bb9d1b55SMatthias Springer opt.dumpAliasSets = dumpAliasSets; 8975ef84bfSOleg Shyshkov opt.setFunctionBoundaryTypeConversion( 9075ef84bfSOleg Shyshkov parseLayoutMapOption(functionBoundaryTypeConversion)); 91ced2fc78SChristopher Bate 92ced2fc78SChristopher Bate if (mustInferMemorySpace && useEncodingForMemorySpace) { 93ced2fc78SChristopher Bate emitError(getOperation()->getLoc()) 94ced2fc78SChristopher Bate << "only one of 'must-infer-memory-space' and " 95ced2fc78SChristopher Bate "'use-encoding-for-memory-space' are allowed in " 96ced2fc78SChristopher Bate << getArgument(); 97ced2fc78SChristopher Bate return signalPassFailure(); 98ced2fc78SChristopher Bate } 99ced2fc78SChristopher Bate 100067d2779Sian Bearman if (mustInferMemorySpace) { 101067d2779Sian Bearman opt.defaultMemorySpaceFn = 102067d2779Sian Bearman [](TensorType t) -> std::optional<Attribute> { 103067d2779Sian Bearman return std::nullopt; 104067d2779Sian Bearman }; 105067d2779Sian Bearman } 106ced2fc78SChristopher Bate 107ced2fc78SChristopher Bate if (useEncodingForMemorySpace) { 108ced2fc78SChristopher Bate opt.defaultMemorySpaceFn = 109ced2fc78SChristopher Bate [](TensorType t) -> std::optional<Attribute> { 110ced2fc78SChristopher Bate if (auto rtt = dyn_cast<RankedTensorType>(t)) 111ced2fc78SChristopher Bate return rtt.getEncoding(); 112ced2fc78SChristopher Bate return std::nullopt; 113ced2fc78SChristopher Bate }; 114ced2fc78SChristopher Bate } 115ced2fc78SChristopher Bate 116d2dacde5SMatthias Springer opt.printConflicts = printConflicts; 11770334081SSimon Camphausen opt.bufferAlignment = bufferAlignment; 118d2dacde5SMatthias Springer opt.testAnalysisOnly = testAnalysisOnly; 119d6dab38aSMatthias Springer opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; 120d5863721SMax191 opt.checkParallelRegions = checkParallelRegions; 1219cf96850SMaya Amrami opt.noAnalysisFuncFilter = noAnalysisFuncFilter; 122d2dacde5SMatthias Springer 123606f7c8fSMatthias Springer // Configure type converter. 124c780184aSLorenzo Chelini LayoutMapOption unknownTypeConversionOption = 125606f7c8fSMatthias Springer parseLayoutMapOption(unknownTypeConversion); 126a6236898SMatthias Springer if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) { 127a6236898SMatthias Springer emitError(UnknownLoc::get(&getContext()), 128a6236898SMatthias Springer "Invalid option: 'infer-layout-map' is not a valid value for " 129a6236898SMatthias Springer "'unknown-type-conversion'"); 130a6236898SMatthias Springer return signalPassFailure(); 131a6236898SMatthias Springer } 1329bb63374SLei Zhang opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace, 133606f7c8fSMatthias Springer const BufferizationOptions &options) { 1345550c821STres Popp auto tensorType = cast<TensorType>(value.getType()); 135c780184aSLorenzo Chelini if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap) 136606f7c8fSMatthias Springer return bufferization::getMemRefTypeWithStaticIdentityLayout( 137606f7c8fSMatthias Springer tensorType, memorySpace); 138c780184aSLorenzo Chelini assert(unknownTypeConversionOption == 139c780184aSLorenzo Chelini LayoutMapOption::FullyDynamicLayoutMap && 140606f7c8fSMatthias Springer "invalid layout map option"); 141606f7c8fSMatthias Springer return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, 142606f7c8fSMatthias Springer memorySpace); 143606f7c8fSMatthias Springer }; 144606f7c8fSMatthias Springer 145606f7c8fSMatthias Springer // Configure op filter. 146b7f93c28SJeff Niu OpFilter::Entry::FilterFn filterFn = [&](Operation *op) { 147d2dacde5SMatthias Springer // Filter may be specified via options. 148d2dacde5SMatthias Springer if (this->dialectFilter.hasValue()) 149ad44495aSjacquesguan return llvm::is_contained(this->dialectFilter, 150ad44495aSjacquesguan op->getDialect()->getNamespace()); 151d2dacde5SMatthias Springer // No filter specified: All other ops are allowed. 152d2dacde5SMatthias Springer return true; 153d2dacde5SMatthias Springer }; 1541534177fSMatthias Springer opt.opFilter.allowOperation(filterFn); 155d2dacde5SMatthias Springer } else { 156d2dacde5SMatthias Springer opt = *options; 157d2dacde5SMatthias Springer } 158d2dacde5SMatthias Springer 159a6236898SMatthias Springer if (opt.copyBeforeWrite && opt.testAnalysisOnly) { 160a6236898SMatthias Springer // These two flags do not make sense together: "copy-before-write" 161a6236898SMatthias Springer // indicates that copies should be inserted before every memory write, 162a6236898SMatthias Springer // but "test-analysis-only" indicates that only the analysis should be 163a6236898SMatthias Springer // tested. (I.e., no IR is bufferized.) 164a6236898SMatthias Springer emitError(UnknownLoc::get(&getContext()), 165a6236898SMatthias Springer "Invalid option: 'copy-before-write' cannot be used with " 166a6236898SMatthias Springer "'test-analysis-only'"); 167a6236898SMatthias Springer return signalPassFailure(); 168a6236898SMatthias Springer } 169a6236898SMatthias Springer 170a6236898SMatthias Springer if (opt.printConflicts && !opt.testAnalysisOnly) { 171a6236898SMatthias Springer emitError( 172a6236898SMatthias Springer UnknownLoc::get(&getContext()), 173a6236898SMatthias Springer "Invalid option: 'print-conflicts' requires 'test-analysis-only'"); 174a6236898SMatthias Springer return signalPassFailure(); 175a6236898SMatthias Springer } 176a6236898SMatthias Springer 177a6236898SMatthias Springer if (opt.dumpAliasSets && !opt.testAnalysisOnly) { 178a6236898SMatthias Springer emitError( 179a6236898SMatthias Springer UnknownLoc::get(&getContext()), 180a6236898SMatthias Springer "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'"); 181a6236898SMatthias Springer return signalPassFailure(); 182a6236898SMatthias Springer } 183a6236898SMatthias Springer 184ae05bd99SMatthias Springer BufferizationStatistics statistics; 185d2dacde5SMatthias Springer ModuleOp moduleOp = getOperation(); 186d6dab38aSMatthias Springer if (opt.bufferizeFunctionBoundaries) { 1879cf96850SMaya Amrami if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) { 188e07a7fd5SMatthias Springer signalPassFailure(); 189e07a7fd5SMatthias Springer return; 190e07a7fd5SMatthias Springer } 191e07a7fd5SMatthias Springer } else { 192a6236898SMatthias Springer if (!opt.noAnalysisFuncFilter.empty()) { 193a6236898SMatthias Springer emitError(UnknownLoc::get(&getContext()), 194a6236898SMatthias Springer "Invalid option: 'no-analysis-func-filter' requires " 195a6236898SMatthias Springer "'bufferize-function-boundaries'"); 196a6236898SMatthias Springer return signalPassFailure(); 197a6236898SMatthias Springer } 198ae05bd99SMatthias Springer if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) { 199d2dacde5SMatthias Springer signalPassFailure(); 200d2dacde5SMatthias Springer return; 201d2dacde5SMatthias Springer } 202e07a7fd5SMatthias Springer } 203d2dacde5SMatthias Springer 204ae05bd99SMatthias Springer // Set pass statistics. 205ae05bd99SMatthias Springer this->numBufferAlloc = statistics.numBufferAlloc; 206ae05bd99SMatthias Springer this->numTensorInPlace = statistics.numTensorInPlace; 207ae05bd99SMatthias Springer this->numTensorOutOfPlace = statistics.numTensorOutOfPlace; 208d2dacde5SMatthias Springer } 209d2dacde5SMatthias Springer 210d2dacde5SMatthias Springer private: 2110a81ace0SKazu Hirata std::optional<OneShotBufferizationOptions> options; 212d2dacde5SMatthias Springer }; 213f89bb3c0SAlexander Belyaev } // namespace 214f89bb3c0SAlexander Belyaev 215d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() { 216d2dacde5SMatthias Springer return std::make_unique<OneShotBufferizePass>(); 217d2dacde5SMatthias Springer } 218d2dacde5SMatthias Springer 219d2dacde5SMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass( 2209597b16aSMatthias Springer const OneShotBufferizationOptions &options) { 221d2dacde5SMatthias Springer return std::make_unique<OneShotBufferizePass>(options); 222d2dacde5SMatthias Springer } 223d2dacde5SMatthias Springer 22449e37000SMatthias Springer //===----------------------------------------------------------------------===// 22549e37000SMatthias Springer // BufferizableOpInterface-based Bufferization 22649e37000SMatthias Springer //===----------------------------------------------------------------------===// 22749e37000SMatthias Springer 228d820acddSMatthias Springer namespace { 229d820acddSMatthias Springer /// A rewriter that keeps track of extra information during bufferization. 230c6532830SMatthias Springer class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { 231d820acddSMatthias Springer public: 232d820acddSMatthias Springer BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps, 233d820acddSMatthias Springer DenseSet<Operation *> &toMemrefOps, 234b3ebe3beSMatthias Springer SmallVector<Operation *> &worklist, 2352f0a634cSMatthias Springer const BufferizationOptions &options, 236ae05bd99SMatthias Springer BufferizationStatistics *statistics) 237d820acddSMatthias Springer : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), 2389d34c052SMatthias Springer worklist(worklist), analysisState(options), statistics(statistics) { 239c6532830SMatthias Springer setListener(this); 240c6532830SMatthias Springer } 241d820acddSMatthias Springer 242d820acddSMatthias Springer protected: 243914e6074SMatthias Springer void notifyOperationErased(Operation *op) override { 244d820acddSMatthias Springer erasedOps.insert(op); 2459785eb1bSMatthias Springer // Erase if present. 2469785eb1bSMatthias Springer toMemrefOps.erase(op); 247d820acddSMatthias Springer } 248d820acddSMatthias Springer 2495cc0f76dSMatthias Springer void notifyOperationInserted(Operation *op, InsertPoint previous) override { 2505cc0f76dSMatthias Springer // We only care about newly created ops. 2515cc0f76dSMatthias Springer if (previous.isSet()) 2525cc0f76dSMatthias Springer return; 2535cc0f76dSMatthias Springer 254b3ebe3beSMatthias Springer erasedOps.erase(op); 255d820acddSMatthias Springer 2566bf043e7SMartin Erhart // Gather statistics about allocs. 257ae05bd99SMatthias Springer if (statistics) { 2586bf043e7SMartin Erhart if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op)) 259ae05bd99SMatthias Springer statistics->numBufferAlloc += static_cast<int64_t>( 260ae05bd99SMatthias Springer sideEffectingOp.hasEffect<MemoryEffects::Allocate>()); 261ae05bd99SMatthias Springer } 262ae05bd99SMatthias Springer 263d820acddSMatthias Springer // Keep track of to_memref ops. 264d820acddSMatthias Springer if (isa<ToMemrefOp>(op)) { 265d820acddSMatthias Springer toMemrefOps.insert(op); 266d820acddSMatthias Springer return; 267d820acddSMatthias Springer } 268d820acddSMatthias Springer 269d820acddSMatthias Springer // Skip to_tensor ops. 270d820acddSMatthias Springer if (isa<ToTensorOp>(op)) 271d820acddSMatthias Springer return; 272d820acddSMatthias Springer 2732f0a634cSMatthias Springer // Skip non-tensor ops. 2742f0a634cSMatthias Springer if (!hasTensorSemantics(op)) 2752f0a634cSMatthias Springer return; 2762f0a634cSMatthias Springer 277b3ebe3beSMatthias Springer // Skip ops that are not allowed to be bufferized. 278b3ebe3beSMatthias Springer auto const &options = analysisState.getOptions(); 2799d34c052SMatthias Springer if (!options.isOpAllowed(op)) 2802f0a634cSMatthias Springer return; 2812f0a634cSMatthias Springer 282b3ebe3beSMatthias Springer // Add op to worklist. 283b3ebe3beSMatthias Springer worklist.push_back(op); 284d820acddSMatthias Springer } 285d820acddSMatthias Springer 286d820acddSMatthias Springer private: 287d820acddSMatthias Springer /// A set of all erased ops. 288d820acddSMatthias Springer DenseSet<Operation *> &erasedOps; 289d820acddSMatthias Springer 290d820acddSMatthias Springer /// A set of all to_memref ops. 291d820acddSMatthias Springer DenseSet<Operation *> &toMemrefOps; 292d820acddSMatthias Springer 293b3ebe3beSMatthias Springer /// The worklist of ops to be bufferized. 294b3ebe3beSMatthias Springer SmallVector<Operation *> &worklist; 2952f0a634cSMatthias Springer 296b3ebe3beSMatthias Springer /// The analysis state. Used for debug assertions and access to the 297b3ebe3beSMatthias Springer /// bufferization options. 298b3ebe3beSMatthias Springer const AnalysisState analysisState; 299b3ebe3beSMatthias Springer 300ae05bd99SMatthias Springer /// Bufferization statistics for debugging. 301ae05bd99SMatthias Springer BufferizationStatistics *statistics; 302d820acddSMatthias Springer }; 303d820acddSMatthias Springer } // namespace 304d820acddSMatthias Springer 3052f0a634cSMatthias Springer LogicalResult bufferization::bufferizeOp(Operation *op, 306b3ebe3beSMatthias Springer const BufferizationOptions &options, 307ae05bd99SMatthias Springer BufferizationStatistics *statistics) { 3089d34c052SMatthias Springer if (options.copyBeforeWrite) { 309b3ebe3beSMatthias Springer AnalysisState state(options); 310b3ebe3beSMatthias Springer if (failed(insertTensorCopies(op, state))) 311b3ebe3beSMatthias Springer return failure(); 312b3ebe3beSMatthias Springer } 313b3ebe3beSMatthias Springer 314d820acddSMatthias Springer // Keep track of to_memref ops. 315d820acddSMatthias Springer DenseSet<Operation *> toMemrefOps; 316d820acddSMatthias Springer op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); }); 317d820acddSMatthias Springer 318d820acddSMatthias Springer // Gather all bufferizable ops in top-to-bottom order. 31976b16010SMatthias Springer // 320d820acddSMatthias Springer // We should ideally know the exact memref type of all operands when 321d820acddSMatthias Springer // bufferizing an op. (This is the case when bufferizing top-to-bottom.) 322f287da8aSMatthias Springer // Otherwise, we have to use a memref type with a fully dynamic layout map to 323f287da8aSMatthias Springer // avoid copies. We are currently missing patterns for layout maps to 324f287da8aSMatthias Springer // canonicalize away (or canonicalize to more precise layouts). 325d820acddSMatthias Springer SmallVector<Operation *> worklist; 326ba9d886dSMatthias Springer op->walk<WalkOrder::PostOrder>([&](Operation *op) { 327db604911SBenjamin Kramer if (options.isOpAllowed(op) && hasTensorSemantics(op)) 328d820acddSMatthias Springer worklist.push_back(op); 329d820acddSMatthias Springer }); 3306fc753adSMatthias Springer 331d820acddSMatthias Springer // Keep track of all erased ops. 332d820acddSMatthias Springer DenseSet<Operation *> erasedOps; 3337a1579acSMatthias Springer 334d820acddSMatthias Springer // Bufferize all ops. 335d820acddSMatthias Springer BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, 3369d34c052SMatthias Springer worklist, options, statistics); 337d820acddSMatthias Springer for (unsigned i = 0; i < worklist.size(); ++i) { 338df23ede2SMatthias Springer Operation *nextOp = worklist[i]; 339d820acddSMatthias Springer // Skip ops that were erased. 340df23ede2SMatthias Springer if (erasedOps.contains(nextOp)) 341d820acddSMatthias Springer continue; 3429785eb1bSMatthias Springer // Skip ops that are not bufferizable or not allowed. 343df23ede2SMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(nextOp); 344d820acddSMatthias Springer if (!bufferizableOp) 345d820acddSMatthias Springer continue; 3469785eb1bSMatthias Springer // Skip ops that no longer have tensor semantics. 347df23ede2SMatthias Springer if (!hasTensorSemantics(nextOp)) 348d820acddSMatthias Springer continue; 349061aa2e3SMatthias Springer // Check for unsupported unstructured control flow. 350061aa2e3SMatthias Springer if (!bufferizableOp.supportsUnstructuredControlFlow()) 351061aa2e3SMatthias Springer for (Region &r : nextOp->getRegions()) 352061aa2e3SMatthias Springer if (r.getBlocks().size() > 1) 353061aa2e3SMatthias Springer return nextOp->emitOpError( 354061aa2e3SMatthias Springer "op or BufferizableOpInterface implementation does not support " 355061aa2e3SMatthias Springer "unstructured control flow, but at least one region has multiple " 356061aa2e3SMatthias Springer "blocks"); 357061aa2e3SMatthias Springer 358d820acddSMatthias Springer // Bufferize the op. 359df23ede2SMatthias Springer LLVM_DEBUG(llvm::dbgs() 360df23ede2SMatthias Springer << "//===-------------------------------------------===//\n" 361df23ede2SMatthias Springer << "IR after bufferizing: " << nextOp->getName() << "\n"); 362df23ede2SMatthias Springer rewriter.setInsertionPoint(nextOp); 363df23ede2SMatthias Springer if (failed(bufferizableOp.bufferize(rewriter, options))) { 364df23ede2SMatthias Springer LLVM_DEBUG(llvm::dbgs() 365df23ede2SMatthias Springer << "failed to bufferize\n" 366df23ede2SMatthias Springer << "//===-------------------------------------------===//\n"); 367df23ede2SMatthias Springer return nextOp->emitError("failed to bufferize op"); 368df23ede2SMatthias Springer } 369df23ede2SMatthias Springer LLVM_DEBUG(llvm::dbgs() 370df23ede2SMatthias Springer << *op 371df23ede2SMatthias Springer << "\n//===-------------------------------------------===//\n"); 372d820acddSMatthias Springer } 373d820acddSMatthias Springer 374fa101214SRyan Holt // Return early if the top-level op is entirely gone. 375fa101214SRyan Holt if (erasedOps.contains(op)) 376fa101214SRyan Holt return success(); 377fa101214SRyan Holt 378d820acddSMatthias Springer // Fold all to_memref(to_tensor(x)) pairs. 379d820acddSMatthias Springer for (Operation *op : toMemrefOps) { 380d820acddSMatthias Springer rewriter.setInsertionPoint(op); 381c515c780SMatthias Gehre (void)bufferization::foldToMemrefToTensorPair( 382c515c780SMatthias Gehre rewriter, cast<ToMemrefOp>(op), options); 383d820acddSMatthias Springer } 384d820acddSMatthias Springer 385199f368eSMatthias Springer // Remove all dead to_tensor ops. 386199f368eSMatthias Springer op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) { 387199f368eSMatthias Springer if (toTensorOp->getUses().empty()) { 388199f368eSMatthias Springer rewriter.eraseOp(toTensorOp); 389199f368eSMatthias Springer return WalkResult::skip(); 390199f368eSMatthias Springer } 391199f368eSMatthias Springer return WalkResult::advance(); 392199f368eSMatthias Springer }); 393199f368eSMatthias Springer 394d820acddSMatthias Springer /// Check the result of bufferization. Return an error if an op was not 395d820acddSMatthias Springer /// bufferized, unless partial bufferization is allowed. 396b55d55ecSMatthias Springer if (options.allowUnknownOps) 397d820acddSMatthias Springer return success(); 398d820acddSMatthias Springer 399d820acddSMatthias Springer for (Operation *op : worklist) { 400d820acddSMatthias Springer // Skip ops that are entirely gone. 401d820acddSMatthias Springer if (erasedOps.contains(op)) 402d820acddSMatthias Springer continue; 403d820acddSMatthias Springer // Ops that no longer have tensor semantics (because they were updated 404d820acddSMatthias Springer // in-place) are allowed. 405d820acddSMatthias Springer if (!hasTensorSemantics(op)) 406d820acddSMatthias Springer continue; 407d820acddSMatthias Springer // Continue ops that are not allowed. 408d820acddSMatthias Springer if (!options.isOpAllowed(op)) 409d820acddSMatthias Springer continue; 410d820acddSMatthias Springer // Ops without any uses and no side effects will fold away. 411fc367dfaSMahesh Ravishankar if (op->getUses().empty() && isMemoryEffectFree(op)) 412d820acddSMatthias Springer continue; 413b3ebe3beSMatthias Springer // ToTensorOps/ToMemrefOps are allowed in the output. 414b3ebe3beSMatthias Springer if (isa<ToTensorOp, ToMemrefOp>(op)) 415b3ebe3beSMatthias Springer continue; 416d820acddSMatthias Springer return op->emitError("op was not bufferized"); 417d820acddSMatthias Springer } 41805e0495fSMatthias Springer 41905e0495fSMatthias Springer return success(); 4207a1579acSMatthias Springer } 421daf18108SMatthias Springer 422a88732d9SMatthias Springer LogicalResult 423a88732d9SMatthias Springer bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, 424a88732d9SMatthias Springer const BufferizationOptions &options) { 425a88732d9SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 426a88732d9SMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp()); 427a88732d9SMatthias Springer if (!bufferizableOp) 428a88732d9SMatthias Springer return failure(); 429a88732d9SMatthias Springer 430a88732d9SMatthias Springer // Compute the new signature. 431a88732d9SMatthias Springer SmallVector<Type> newTypes; 432a88732d9SMatthias Springer for (BlockArgument &bbArg : block->getArguments()) { 433a88732d9SMatthias Springer auto tensorType = dyn_cast<TensorType>(bbArg.getType()); 434a88732d9SMatthias Springer if (!tensorType) { 435a88732d9SMatthias Springer newTypes.push_back(bbArg.getType()); 436a88732d9SMatthias Springer continue; 437a88732d9SMatthias Springer } 438a88732d9SMatthias Springer 439a88732d9SMatthias Springer FailureOr<BaseMemRefType> memrefType = 440a88732d9SMatthias Springer bufferization::getBufferType(bbArg, options); 441a88732d9SMatthias Springer if (failed(memrefType)) 442a88732d9SMatthias Springer return failure(); 443a88732d9SMatthias Springer newTypes.push_back(*memrefType); 444a88732d9SMatthias Springer } 445a88732d9SMatthias Springer 446a88732d9SMatthias Springer // Change the type of all block arguments. 447a88732d9SMatthias Springer for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) { 448a88732d9SMatthias Springer if (bbArg.getType() == type) 449a88732d9SMatthias Springer continue; 450a88732d9SMatthias Springer 451a88732d9SMatthias Springer // Collect all uses of the bbArg. 452a88732d9SMatthias Springer SmallVector<OpOperand *> bbArgUses; 453a88732d9SMatthias Springer for (OpOperand &use : bbArg.getUses()) 454a88732d9SMatthias Springer bbArgUses.push_back(&use); 455a88732d9SMatthias Springer 456*bfefa15cSYi Zhang Type tensorType = bbArg.getType(); 457a88732d9SMatthias Springer // Change the bbArg type to memref. 458a88732d9SMatthias Springer bbArg.setType(type); 459a88732d9SMatthias Springer 460a88732d9SMatthias Springer // Replace all uses of the original tensor bbArg. 461a88732d9SMatthias Springer rewriter.setInsertionPointToStart(block); 462a88732d9SMatthias Springer if (!bbArgUses.empty()) { 463*bfefa15cSYi Zhang Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( 464*bfefa15cSYi Zhang bbArg.getLoc(), tensorType, bbArg); 465a88732d9SMatthias Springer for (OpOperand *use : bbArgUses) 466a88732d9SMatthias Springer use->set(toTensorOp); 467a88732d9SMatthias Springer } 468a88732d9SMatthias Springer } 469a88732d9SMatthias Springer 4706ecebb49SMatthias Springer // Bufferize callers of the block. 4716ecebb49SMatthias Springer for (Operation *op : block->getUsers()) { 4726ecebb49SMatthias Springer auto branchOp = dyn_cast<BranchOpInterface>(op); 4736ecebb49SMatthias Springer if (!branchOp) 4746ecebb49SMatthias Springer return op->emitOpError("cannot bufferize ops with block references that " 4756ecebb49SMatthias Springer "do not implement BranchOpInterface"); 4766ecebb49SMatthias Springer 4776ecebb49SMatthias Springer auto it = llvm::find(op->getSuccessors(), block); 4786ecebb49SMatthias Springer assert(it != op->getSuccessors().end() && "could find successor"); 4796ecebb49SMatthias Springer int64_t successorIdx = std::distance(op->getSuccessors().begin(), it); 4806ecebb49SMatthias Springer 4816ecebb49SMatthias Springer SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx); 4826ecebb49SMatthias Springer SmallVector<Value> newOperands; 4836ecebb49SMatthias Springer for (auto [operand, type] : 4846ecebb49SMatthias Springer llvm::zip(operands.getForwardedOperands(), newTypes)) { 4856ecebb49SMatthias Springer if (operand.getType() == type) { 4866ecebb49SMatthias Springer // Not a tensor type. Nothing to do for this operand. 4876ecebb49SMatthias Springer newOperands.push_back(operand); 4886ecebb49SMatthias Springer continue; 4896ecebb49SMatthias Springer } 4906ecebb49SMatthias Springer FailureOr<BaseMemRefType> operandBufferType = 4916ecebb49SMatthias Springer bufferization::getBufferType(operand, options); 4926ecebb49SMatthias Springer if (failed(operandBufferType)) 4936ecebb49SMatthias Springer return failure(); 4946ecebb49SMatthias Springer rewriter.setInsertionPointAfterValue(operand); 4956ecebb49SMatthias Springer Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>( 4966ecebb49SMatthias Springer operand.getLoc(), *operandBufferType, operand); 4976ecebb49SMatthias Springer // A cast is needed if the operand and the block argument have different 4986ecebb49SMatthias Springer // bufferized types. 4996ecebb49SMatthias Springer if (type != *operandBufferType) 5006ecebb49SMatthias Springer bufferizedOperand = rewriter.create<memref::CastOp>( 5016ecebb49SMatthias Springer operand.getLoc(), type, bufferizedOperand); 5026ecebb49SMatthias Springer newOperands.push_back(bufferizedOperand); 5036ecebb49SMatthias Springer } 5046ecebb49SMatthias Springer operands.getMutableForwardedOperands().assign(newOperands); 5056ecebb49SMatthias Springer } 5066ecebb49SMatthias Springer 507a88732d9SMatthias Springer return success(); 508a88732d9SMatthias Springer } 509