17a1579acSMatthias Springer //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===// 27a1579acSMatthias Springer // 37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67a1579acSMatthias Springer // 77a1579acSMatthias Springer //===----------------------------------------------------------------------===// 87a1579acSMatthias Springer // 97cdfc843SMatthias Springer // One-Shot Analysis analyzes function bodies. By default, function boundaries 107cdfc843SMatthias Springer // (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. 117cdfc843SMatthias Springer // OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for 127cdfc843SMatthias Springer // simple call graphs without loops. 137a1579acSMatthias Springer // 147cdfc843SMatthias Springer // One-Shot Bufferize consists of three phases. 157a1579acSMatthias Springer // 167cdfc843SMatthias Springer // 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e., 177cdfc843SMatthias Springer // without inserting buffer copies. The analysis queries op bufferization 187cdfc843SMatthias Springer // semantics via `BufferizableOpInterface`. 197cdfc843SMatthias Springer // 2. Insert copies for OpOperands that were decided to bufferize out-of-place 207cdfc843SMatthias Springer // in tensor land during `TensorCopyInsertion`. 217cdfc843SMatthias Springer // 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`. 227a1579acSMatthias Springer // 237cdfc843SMatthias Springer // This file contains only the analysis. For convenience, this file also 247cdfc843SMatthias Springer // contains a helper function `runOneShotBufferize` that analyzes an op (and its 257cdfc843SMatthias Springer // nested ops) and then bufferizes it. 267a1579acSMatthias Springer // 277a1579acSMatthias Springer // Inplace bufferization decisions are passed from the analysis to the 287cdfc843SMatthias Springer // `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for 297cdfc843SMatthias Springer // debugging purposes with `testAnalysisOnly`. 307a1579acSMatthias Springer // 317a1579acSMatthias Springer // Ops that do not implement `BufferizableOpInterface` can be analyzed but are 327a1579acSMatthias Springer // treated conservatively. E.g., the analysis has to assume that their tensor 337a1579acSMatthias Springer // OpOperands bufferize to memory writes. While such ops can be analyzed, they 347a1579acSMatthias Springer // are not bufferized and remain in the IR. to_tensor and to_memref ops are 357a1579acSMatthias Springer // inserted at the bufferization boundary. 367a1579acSMatthias Springer // 377a1579acSMatthias Springer // This analysis caters to high-performance codegen where buffer reuse is deemed 387a1579acSMatthias Springer // critical: the analysis should fail if the bufferized form of the function 39855a11eeSMatthias Springer // needs to return a buffer, unless `allowReturnAllocs` is enabled. 407a1579acSMatthias Springer 417a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 427a1579acSMatthias Springer 43a1fe1f5fSKazu Hirata #include <optional> 448ee38f3bSMatthias Springer #include <random> 457a1579acSMatthias Springer 467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 4928b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 50d6dab38aSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 517a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 527a1579acSMatthias Springer #include "mlir/IR/AsmState.h" 537a1579acSMatthias Springer #include "mlir/IR/Dominance.h" 5435d3b343SMatthias Springer #include "mlir/IR/Iterators.h" 557a1579acSMatthias Springer #include "mlir/IR/Operation.h" 567a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h" 577a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h" 581abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h" 597a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h" 607a1579acSMatthias Springer #include "llvm/ADT/SetVector.h" 617a1579acSMatthias Springer 62faa9be75SMatthias Springer MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) 63faa9be75SMatthias Springer 64e0b40af7SMatthias Springer // Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug 65e0b40af7SMatthias Springer // output. 66e0b40af7SMatthias Springer #define DEBUG_TYPE "one-shot-analysis" 67e0b40af7SMatthias Springer 687a1579acSMatthias Springer using namespace mlir; 697a1579acSMatthias Springer using namespace mlir::bufferization; 707a1579acSMatthias Springer 715550c821STres Popp static bool isaTensor(Type t) { return isa<TensorType>(t); } 727a1579acSMatthias Springer 737a1579acSMatthias Springer //===----------------------------------------------------------------------===// 747a1579acSMatthias Springer // Bufferization-specific attribute manipulation. 757cdfc843SMatthias Springer // These are for testing and debugging only. Bufferization information is stored 76cf2d374eSMatthias Springer // in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is 777cdfc843SMatthias Springer // annotated with the results of the analysis, so that they can be checked in 787cdfc843SMatthias Springer // tests. 797a1579acSMatthias Springer //===----------------------------------------------------------------------===// 807a1579acSMatthias Springer 817cdfc843SMatthias Springer /// Attribute marker to specify op operands that bufferize in-place. 827cdfc843SMatthias Springer constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__"; 837a1579acSMatthias Springer 84a02ad6c1SMatthias Springer constexpr StringLiteral kOpResultAliasSetAttrName = 85a02ad6c1SMatthias Springer "__opresult_alias_set_attr__"; 86a02ad6c1SMatthias Springer 87a02ad6c1SMatthias Springer constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__"; 88bb9d1b55SMatthias Springer 897a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace. 907a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { 917a1579acSMatthias Springer Operation *op = opOperand.getOwner(); 927a1579acSMatthias Springer SmallVector<StringRef> inPlaceVector; 937cdfc843SMatthias Springer if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) { 947cdfc843SMatthias Springer inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>( 955550c821STres Popp cast<ArrayAttr>(attr).getAsValueRange<StringAttr>())); 967a1579acSMatthias Springer } else { 977a1579acSMatthias Springer inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none"); 987a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 995550c821STres Popp if (isa<TensorType>(opOperand.get().getType())) 1007a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = "false"; 1017a1579acSMatthias Springer } 1027a1579acSMatthias Springer inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; 1037cdfc843SMatthias Springer op->setAttr(kInPlaceOperandsAttrName, 1047a1579acSMatthias Springer OpBuilder(op).getStrArrayAttr(inPlaceVector)); 1057a1579acSMatthias Springer } 1067a1579acSMatthias Springer 1077a1579acSMatthias Springer //===----------------------------------------------------------------------===// 108cf2d374eSMatthias Springer // OneShotAnalysisState 1097a1579acSMatthias Springer //===----------------------------------------------------------------------===// 1107a1579acSMatthias Springer 111cf2d374eSMatthias Springer OneShotAnalysisState::OneShotAnalysisState( 112cf2d374eSMatthias Springer Operation *op, const OneShotBufferizationOptions &options) 113cf2d374eSMatthias Springer : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) { 114cf2d374eSMatthias Springer // Set up alias sets. 115cf2d374eSMatthias Springer op->walk([&](Operation *op) { 1167a1579acSMatthias Springer for (Value v : op->getResults()) 1175550c821STres Popp if (isa<TensorType>(v.getType())) 1187a1579acSMatthias Springer createAliasInfoEntry(v); 1197a1579acSMatthias Springer for (Region &r : op->getRegions()) 1207a1579acSMatthias Springer for (Block &b : r.getBlocks()) 1217a1579acSMatthias Springer for (auto bbArg : b.getArguments()) 1225550c821STres Popp if (isa<TensorType>(bbArg.getType())) 1237a1579acSMatthias Springer createAliasInfoEntry(bbArg); 1247a1579acSMatthias Springer }); 125cf2d374eSMatthias Springer 126cf2d374eSMatthias Springer // Mark OpOperands in-place that must bufferize in-place. 127cf2d374eSMatthias Springer op->walk([&](BufferizableOpInterface bufferizableOp) { 128cf2d374eSMatthias Springer if (!options.isOpAllowed(bufferizableOp)) 129cf2d374eSMatthias Springer return WalkResult::skip(); 130cf2d374eSMatthias Springer for (OpOperand &opOperand : bufferizableOp->getOpOperands()) 1315550c821STres Popp if (isa<TensorType>(opOperand.get().getType())) 132cf2d374eSMatthias Springer if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) 133cf2d374eSMatthias Springer bufferizeInPlace(opOperand); 134cf2d374eSMatthias Springer return WalkResult::advance(); 135cf2d374eSMatthias Springer }); 1367a1579acSMatthias Springer } 1377a1579acSMatthias Springer 138cf2d374eSMatthias Springer void OneShotAnalysisState::applyOnEquivalenceClass( 1397a1579acSMatthias Springer Value v, function_ref<void(Value)> fun) const { 1407a1579acSMatthias Springer auto leaderIt = equivalentInfo.findLeader(v); 1417a1579acSMatthias Springer for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; 1427a1579acSMatthias Springer ++mit) { 1437a1579acSMatthias Springer fun(*mit); 1447a1579acSMatthias Springer } 1457a1579acSMatthias Springer } 1467a1579acSMatthias Springer 147cf2d374eSMatthias Springer void OneShotAnalysisState::applyOnAliases(Value v, 148cf2d374eSMatthias Springer function_ref<void(Value)> fun) const { 1497a1579acSMatthias Springer auto leaderIt = aliasInfo.findLeader(v); 1507a1579acSMatthias Springer for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { 1517a1579acSMatthias Springer fun(*mit); 1527a1579acSMatthias Springer } 1537a1579acSMatthias Springer } 1547a1579acSMatthias Springer 1559597b16aSMatthias Springer bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, 1567a1579acSMatthias Springer Value v2) const { 157cf2d374eSMatthias Springer return equivalentInfo.isEquivalent(v1, v2); 1587a1579acSMatthias Springer } 1597a1579acSMatthias Springer 1603490aadfSMatthias Springer bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, 1613490aadfSMatthias Springer Value v2) const { 162cf2d374eSMatthias Springer return aliasInfo.isEquivalent(v1, v2); 163cf2d374eSMatthias Springer } 164cf2d374eSMatthias Springer 165cf2d374eSMatthias Springer void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) { 166cf2d374eSMatthias Springer if (inplaceBufferized.contains(&operand)) 167cf2d374eSMatthias Springer return; 168cf2d374eSMatthias Springer inplaceBufferized.insert(&operand); 169a02ad6c1SMatthias Springer for (AliasingValue alias : getAliasingValues(operand)) 170a02ad6c1SMatthias Springer aliasInfo.unionSets(alias.value, operand.get()); 171cf2d374eSMatthias Springer ++statNumTensorInPlace; 172cf2d374eSMatthias Springer } 173cf2d374eSMatthias Springer 174cf2d374eSMatthias Springer void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) { 175cf2d374eSMatthias Springer assert(!inplaceBufferized.contains(&operand) && 176cf2d374eSMatthias Springer "OpOperand was already decided to bufferize inplace"); 177cf2d374eSMatthias Springer ++statNumTensorOutOfPlace; 178cf2d374eSMatthias Springer } 179cf2d374eSMatthias Springer 180cf2d374eSMatthias Springer void OneShotAnalysisState::createAliasInfoEntry(Value v) { 181cf2d374eSMatthias Springer aliasInfo.insert(v); 182cf2d374eSMatthias Springer equivalentInfo.insert(v); 1833490aadfSMatthias Springer } 1843490aadfSMatthias Springer 185988748c0SMatthias Springer void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { 186988748c0SMatthias Springer op->walk([&](Operation *op) { 187988748c0SMatthias Springer // Skip unknown ops. 188988748c0SMatthias Springer auto bufferizableOp = getOptions().dynCastBufferizableOp(op); 189988748c0SMatthias Springer if (!bufferizableOp) 190988748c0SMatthias Springer return WalkResult::skip(); 191988748c0SMatthias Springer 192988748c0SMatthias Springer // Check all tensor OpResults. 193988748c0SMatthias Springer for (OpResult opResult : op->getOpResults()) { 1945550c821STres Popp if (!isa<TensorType>(opResult.getType())) 195988748c0SMatthias Springer continue; 196988748c0SMatthias Springer 1971840d18aSMatthias Springer // If there is no preceding definition, the tensor contents are 198988748c0SMatthias Springer // undefined. 199*d9111f19SAmir Bishara if (opResult.getUses().empty()) 200*d9111f19SAmir Bishara continue; 201*d9111f19SAmir Bishara // It does not really matter which use to take to search about 202*d9111f19SAmir Bishara // the value's definitions. 203*d9111f19SAmir Bishara OpOperand *opOperand = &(*opResult.getUses().begin()); 204*d9111f19SAmir Bishara if (findDefinitionsCached(opOperand).empty()) 205988748c0SMatthias Springer for (OpOperand &use : opResult.getUses()) 206988748c0SMatthias Springer undefinedTensorUses.insert(&use); 207988748c0SMatthias Springer } 208988748c0SMatthias Springer 209988748c0SMatthias Springer return WalkResult::advance(); 210988748c0SMatthias Springer }); 211988748c0SMatthias Springer } 212988748c0SMatthias Springer 213988748c0SMatthias Springer bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const { 214988748c0SMatthias Springer return undefinedTensorUses.contains(opOperand); 215988748c0SMatthias Springer } 216988748c0SMatthias Springer 217cf2d374eSMatthias Springer bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { 218cf2d374eSMatthias Springer return inplaceBufferized.contains(&opOperand); 219cf2d374eSMatthias Springer } 220cf2d374eSMatthias Springer 2213490aadfSMatthias Springer bool OneShotAnalysisState::isValueWritten(Value value) const { 2223490aadfSMatthias Springer bool isWritten = false; 223cf2d374eSMatthias Springer applyOnAliases(value, [&](Value val) { 2243490aadfSMatthias Springer for (OpOperand &use : val.getUses()) 2253490aadfSMatthias Springer if (isInPlace(use) && bufferizesToMemoryWrite(use)) 2263490aadfSMatthias Springer isWritten = true; 2273490aadfSMatthias Springer }); 2283490aadfSMatthias Springer return isWritten; 2293490aadfSMatthias Springer } 2303490aadfSMatthias Springer 231032be233SMatthias Springer bool OneShotAnalysisState::isWritable(Value value) const { 232032be233SMatthias Springer // TODO: Out-of-place bufferized value could be considered writable. 233032be233SMatthias Springer // Query BufferizableOpInterface to see if the BlockArgument is writable. 234032be233SMatthias Springer if (auto bufferizableOp = 235a02ad6c1SMatthias Springer getOptions().dynCastBufferizableOp(getOwnerOfValue(value))) 236a02ad6c1SMatthias Springer return bufferizableOp.isWritable(value, *this); 237032be233SMatthias Springer 238032be233SMatthias Springer // Not a bufferizable op: The conservative answer is "not writable". 239032be233SMatthias Springer return false; 240032be233SMatthias Springer } 241032be233SMatthias Springer 242cf2d374eSMatthias Springer void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) { 243cf2d374eSMatthias Springer aliasInfo.unionSets(v1, v2); 244cf2d374eSMatthias Springer } 245cf2d374eSMatthias Springer 246cf2d374eSMatthias Springer void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) { 247cf2d374eSMatthias Springer equivalentInfo.unionSets(v1, v2); 248cf2d374eSMatthias Springer } 249cf2d374eSMatthias Springer 250faa9be75SMatthias Springer OneShotAnalysisState::Extension::~Extension() = default; 251faa9be75SMatthias Springer 2527a1579acSMatthias Springer //===----------------------------------------------------------------------===// 2537a1579acSMatthias Springer // Bufferization-specific alias analysis. 2547a1579acSMatthias Springer //===----------------------------------------------------------------------===// 2557a1579acSMatthias Springer 2567a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place. 2577a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand, 258cf2d374eSMatthias Springer const OneShotAnalysisState &state) { 2597a1579acSMatthias Springer // OpOperands that do not bufferize to a memory write do not write in-place. 2607a1579acSMatthias Springer if (!state.bufferizesToMemoryWrite(opOperand)) 2617a1579acSMatthias Springer return false; 2627a1579acSMatthias Springer // Check current bufferization decisions. 263cf2d374eSMatthias Springer return state.isInPlace(opOperand); 2647a1579acSMatthias Springer } 2657a1579acSMatthias Springer 2667a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 2677a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`. 2687a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b, 2697a1579acSMatthias Springer const DominanceInfo &domInfo) { 2707a1579acSMatthias Springer do { 2717a1579acSMatthias Springer // TODO: Instead of isProperAncestor + properlyDominates, we should use 2727a1579acSMatthias Springer // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false) 2737a1579acSMatthias Springer if (a->isProperAncestor(b)) 2747a1579acSMatthias Springer return false; 2757a1579acSMatthias Springer if (domInfo.properlyDominates(a, b)) 2767a1579acSMatthias Springer return true; 2777a1579acSMatthias Springer } while ((a = a->getParentOp())); 2787a1579acSMatthias Springer return false; 2797a1579acSMatthias Springer } 2807a1579acSMatthias Springer 281c89c31a2SMatthias Springer /// Return `true` if op dominance can be used to rule out a read-after-write 2826ecebb49SMatthias Springer /// conflicts based on the ordering of ops. Returns `false` if op dominance 2836ecebb49SMatthias Springer /// cannot be used to due region-based loops. 2842e210034SMatthias Springer /// 285c89c31a2SMatthias Springer /// Generalized op dominance can often be used to rule out potential conflicts 286c89c31a2SMatthias Springer /// due to "read happens before write". E.g., the following IR is not a RaW 287c89c31a2SMatthias Springer /// conflict because the read happens *before* the write. 2882e210034SMatthias Springer /// 289c89c31a2SMatthias Springer /// Example 1: 290c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32> // DEF 291c89c31a2SMatthias Springer /// "reading_op"(%0) : tensor<?xf32> // READ 292c89c31a2SMatthias Springer /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE 2932e210034SMatthias Springer /// 2942e210034SMatthias Springer /// This is no longer true inside loops (or repetitive regions). In such cases, 2952e210034SMatthias Springer /// there may not be a meaningful `happensBefore` relationship because ops 2962e210034SMatthias Springer /// could be executed multiple times. E.g.: 2972e210034SMatthias Springer /// 298c89c31a2SMatthias Springer /// Example 2: 299c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32> // DEF 3002e210034SMatthias Springer /// scf.for ... { 301c89c31a2SMatthias Springer /// "reading_op"(%0) : tensor<?xf32> // READ 302c89c31a2SMatthias Springer /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE 3032e210034SMatthias Springer /// ... 3042e210034SMatthias Springer /// } 3052e210034SMatthias Springer /// 3062e210034SMatthias Springer /// In the above example, reading_op happens before writing_op according to 3072e210034SMatthias Springer /// op dominance. However, both ops may happen multiple times; in 3082e210034SMatthias Springer /// particular, the second execution of reading_op happens after the first 3092e210034SMatthias Springer /// execution of writing_op. This is problematic because the tensor %0 they 3102e210034SMatthias Springer /// operate on (i.e., the "definition") is defined outside of the loop. 3112e210034SMatthias Springer /// 312c89c31a2SMatthias Springer /// On a high-level, there is a potential RaW in a program if there exists a 313c89c31a2SMatthias Springer /// possible program execution such that there is a sequence of DEF, followed 314c89c31a2SMatthias Springer /// by WRITE, followed by READ. Each additional DEF resets the sequence. 3152e210034SMatthias Springer /// 316c89c31a2SMatthias Springer /// E.g.: 317c89c31a2SMatthias Springer /// No conflict: DEF, WRITE, DEF, READ 318c89c31a2SMatthias Springer /// Potential conflict: DEF, READ, WRITE, READ, WRITE 319c89c31a2SMatthias Springer /// 320c89c31a2SMatthias Springer /// Example 1 has no conflict: DEF, READ, WRITE 321c89c31a2SMatthias Springer /// Example 2 has a potential conflict: DEF, (READ, WRITE)* 322c89c31a2SMatthias Springer // 323c89c31a2SMatthias Springer /// Example 3: 3242e210034SMatthias Springer /// scf.for ... { 3252e210034SMatthias Springer /// %0 = ... : tensor<?xf32> 3262e210034SMatthias Springer /// "reading_op"(%0) : tensor<?xf32> 3272e210034SMatthias Springer /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> 3282e210034SMatthias Springer /// ... 3292e210034SMatthias Springer /// } 330c89c31a2SMatthias Springer /// This has no conflict: (DEF, READ, WRITE)* 3312e210034SMatthias Springer /// 332c89c31a2SMatthias Springer /// Example 4: 333c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32> 3342e210034SMatthias Springer /// scf.for ... { 335c89c31a2SMatthias Springer /// scf.for ... { "reading_op"(%0) } 336c89c31a2SMatthias Springer /// %1 = "writing_op"(%0) 3372e210034SMatthias Springer /// } 338c89c31a2SMatthias Springer /// This has a potential conflict: DEF, ((READ)*, WRITE)* 3392e210034SMatthias Springer /// 340c89c31a2SMatthias Springer /// Example 5: 341c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32> 342c89c31a2SMatthias Springer /// scf.for ... { %1 = "writing_op"(%0) } 343c89c31a2SMatthias Springer /// scf.for ... { "reading_op"(%0) } 344c89c31a2SMatthias Springer /// This has a potential conflict: DEF, WRITE*, READ* 3452e210034SMatthias Springer /// 346c89c31a2SMatthias Springer /// The following rules are used to rule out RaW conflicts via ordering of ops: 3472e210034SMatthias Springer /// 348c89c31a2SMatthias Springer /// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of 349c89c31a2SMatthias Springer /// a repetitive region that enclosing both READ and WRITE, we cannot rule 350c89c31a2SMatthias Springer /// out RaW conflict due to the ordering of ops. 351c89c31a2SMatthias Springer /// 2. Otherwise: There are no loops that interfere with our analysis; for 352c89c31a2SMatthias Springer /// analysis purposes, we can assume that there are no loops/repetitive 353c89c31a2SMatthias Springer /// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE 354c89c31a2SMatthias Springer /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.) 3552e210034SMatthias Springer /// 3566ecebb49SMatthias Springer static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite, 357c89c31a2SMatthias Springer const SetVector<Value> &definitions, 3589312b4f9SMartin Erhart AnalysisState &state) { 3592e210034SMatthias Springer const BufferizationOptions &options = state.getOptions(); 360c89c31a2SMatthias Springer for (Value def : definitions) { 3619312b4f9SMartin Erhart Region *rRead = 3629312b4f9SMartin Erhart state.getEnclosingRepetitiveRegion(uRead->getOwner(), options); 3639312b4f9SMartin Erhart Region *rDef = state.getEnclosingRepetitiveRegion(def, options); 3642e210034SMatthias Springer 365c89c31a2SMatthias Springer // READ and DEF are in the same repetitive region. `happensBefore` can be 366c89c31a2SMatthias Springer // used to rule out RaW conflicts due to op ordering. 367c89c31a2SMatthias Springer if (rRead == rDef) 3682e210034SMatthias Springer continue; 369c89c31a2SMatthias Springer 370c89c31a2SMatthias Springer // Find the enclosing repetitive region of READ that is closest to DEF but 371c89c31a2SMatthias Springer // not the repetitive region of DEF itself. 372c89c31a2SMatthias Springer while (true) { 373c89c31a2SMatthias Springer Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options); 374c89c31a2SMatthias Springer if (nextRegion == rDef) 375c89c31a2SMatthias Springer break; 376c89c31a2SMatthias Springer assert(nextRegion && "expected to find another repetitive region"); 377c89c31a2SMatthias Springer rRead = nextRegion; 3782e210034SMatthias Springer } 379c89c31a2SMatthias Springer 380c89c31a2SMatthias Springer // We cannot use op dominance if WRITE is inside the same repetitive region. 381c89c31a2SMatthias Springer if (rRead->getParentOp()->isAncestor(uWrite->getOwner())) 3822e210034SMatthias Springer return false; 3832e210034SMatthias Springer } 3846ecebb49SMatthias Springer 385c89c31a2SMatthias Springer return true; 3862e210034SMatthias Springer } 3872e210034SMatthias Springer 3886ecebb49SMatthias Springer /// Return `true` if op dominance can be used to rule out a read-after-write 3896ecebb49SMatthias Springer /// conflicts based on the ordering of ops. Returns `false` if op dominance 3906ecebb49SMatthias Springer /// cannot be used to due block-based loops within a region. 3916ecebb49SMatthias Springer /// 3926ecebb49SMatthias Springer /// Refer to the `canUseOpDominanceDueToRegions` documentation for details on 3936ecebb49SMatthias Springer /// how op domiance is used during RaW conflict detection. 3946ecebb49SMatthias Springer /// 3956ecebb49SMatthias Springer /// On a high-level, there is a potential RaW in a program if there exists a 3966ecebb49SMatthias Springer /// possible program execution such that there is a sequence of DEF, followed 3976ecebb49SMatthias Springer /// by WRITE, followed by READ. Each additional DEF resets the sequence. 3986ecebb49SMatthias Springer /// 3996ecebb49SMatthias Springer /// Op dominance cannot be used if there is a path from block(READ) to 4006ecebb49SMatthias Springer /// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should 4016ecebb49SMatthias Springer /// not appear on that path. 4026ecebb49SMatthias Springer static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, 4036ecebb49SMatthias Springer const SetVector<Value> &definitions, 4046ecebb49SMatthias Springer AnalysisState &state) { 4056ecebb49SMatthias Springer // Fast path: If READ and WRITE are in different regions, their block cannot 4066ecebb49SMatthias Springer // be reachable just via unstructured control flow. (Loops due to regions are 4076ecebb49SMatthias Springer // covered by `canUseOpDominanceDueToRegions`.) 4086ecebb49SMatthias Springer if (uRead->getOwner()->getParentRegion() != 4096ecebb49SMatthias Springer uWrite->getOwner()->getParentRegion()) 4106ecebb49SMatthias Springer return true; 4116ecebb49SMatthias Springer 4126ecebb49SMatthias Springer Block *readBlock = uRead->getOwner()->getBlock(); 4136ecebb49SMatthias Springer Block *writeBlock = uWrite->getOwner()->getBlock(); 4146ecebb49SMatthias Springer for (Value def : definitions) { 4156ecebb49SMatthias Springer Block *defBlock = def.getParentBlock(); 416804d3c4cSMatthias Springer if (readBlock->isReachable(writeBlock, {defBlock}) && 417804d3c4cSMatthias Springer writeBlock->isReachable(readBlock, {defBlock})) 4186ecebb49SMatthias Springer return false; 4196ecebb49SMatthias Springer } 4206ecebb49SMatthias Springer 4216ecebb49SMatthias Springer return true; 4226ecebb49SMatthias Springer } 4236ecebb49SMatthias Springer 4246ecebb49SMatthias Springer static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, 4256ecebb49SMatthias Springer const SetVector<Value> &definitions, 4266ecebb49SMatthias Springer AnalysisState &state) { 4276ecebb49SMatthias Springer return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) && 4286ecebb49SMatthias Springer canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state); 4296ecebb49SMatthias Springer } 4306ecebb49SMatthias Springer 4317a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict. 4327a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, 4331840d18aSMatthias Springer Value definition) { 4347a1579acSMatthias Springer static uint64_t counter = 0; 4357a1579acSMatthias Springer Operation *readingOp = uRead->getOwner(); 4367a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 4377a1579acSMatthias Springer 4387a1579acSMatthias Springer OpBuilder b(conflictingWritingOp->getContext()); 4397a1579acSMatthias Springer std::string id = "C_" + std::to_string(counter++); 4407a1579acSMatthias Springer 4417a1579acSMatthias Springer std::string conflictingWriteAttr = 4427a1579acSMatthias Springer id + 4437a1579acSMatthias Springer "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) + 4447a1579acSMatthias Springer "]"; 4457a1579acSMatthias Springer conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr()); 4467a1579acSMatthias Springer 4477a1579acSMatthias Springer std::string readAttr = 4487a1579acSMatthias Springer id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; 4497a1579acSMatthias Springer readingOp->setAttr(readAttr, b.getUnitAttr()); 4507a1579acSMatthias Springer 4515550c821STres Popp if (auto opResult = dyn_cast<OpResult>(definition)) { 4521840d18aSMatthias Springer std::string defAttr = 4531840d18aSMatthias Springer id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]"; 4541840d18aSMatthias Springer opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr()); 4557a1579acSMatthias Springer } else { 4565550c821STres Popp auto bbArg = cast<BlockArgument>(definition); 4571840d18aSMatthias Springer std::string defAttr = 4581840d18aSMatthias Springer id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; 4591840d18aSMatthias Springer bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr()); 4607a1579acSMatthias Springer } 4617a1579acSMatthias Springer } 4627a1579acSMatthias Springer 463f36e1934SMatthias Springer /// Return 'true' if a tensor that is equivalent to `other` can be found in the 464f36e1934SMatthias Springer /// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of 465f36e1934SMatthias Springer /// place along that use-def chain, the two tensors may not materialize as 466f36e1934SMatthias Springer /// equivalent buffers (but separate allocations). 467f36e1934SMatthias Springer /// 468f36e1934SMatthias Springer /// Note: This function also requires that the two tensors have equivalent 469f36e1934SMatthias Springer /// indexing. I.e., the tensor types do not change along the use-def chain, 470f36e1934SMatthias Springer /// apart from static <-> dynamic dim casts. 471f36e1934SMatthias Springer static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, 472*d9111f19SAmir Bishara OpOperand *start, 473*d9111f19SAmir Bishara Value other) { 474f36e1934SMatthias Springer TraversalConfig config; 475f36e1934SMatthias Springer config.followEquivalentOnly = true; 476f36e1934SMatthias Springer config.alwaysIncludeLeaves = false; 477f36e1934SMatthias Springer config.followSameTypeOrCastsOnly = true; 478f36e1934SMatthias Springer return !state 479f36e1934SMatthias Springer .findValueInReverseUseDefChain( 480f36e1934SMatthias Springer start, [&](Value v) { return v == other; }, config) 481f36e1934SMatthias Springer .empty(); 482f36e1934SMatthias Springer } 483f36e1934SMatthias Springer 484*d9111f19SAmir Bishara /// Return "true" if the given operand's value is originating from a subset 485*d9111f19SAmir Bishara /// that is equivalent to the subset that `subsetOp` inserts into. 486*d9111f19SAmir Bishara static bool matchesInsertDestination(const AnalysisState &state, 487*d9111f19SAmir Bishara OpOperand *opOperand, 4888143307bSMatthias Springer SubsetInsertionOpInterface subsetOp) { 4898143307bSMatthias Springer auto matchingSubset = [&](Value val) { 4908143307bSMatthias Springer if (auto opResult = dyn_cast<OpResult>(val)) 4918143307bSMatthias Springer if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) { 4928143307bSMatthias Springer return state.areEquivalentBufferizedValues(v1, v2); 4938143307bSMatthias Springer })) 4948143307bSMatthias Springer return true; 4958143307bSMatthias Springer return false; 4968143307bSMatthias Springer }; 4978143307bSMatthias Springer // There may be multiple leaves at which the reverse SSA use-def chain lookup 4988143307bSMatthias Springer // terminates. All of them must be equivalent subsets. 4998143307bSMatthias Springer SetVector<Value> backwardSlice = 500*d9111f19SAmir Bishara state.findValueInReverseUseDefChain(opOperand, matchingSubset); 5018143307bSMatthias Springer return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); 5028143307bSMatthias Springer } 5038143307bSMatthias Springer 5048143307bSMatthias Springer /// Return "true" if the given "read" and potentially conflicting "write" are 5058143307bSMatthias Springer /// not conflicting due to their subset relationship. The comments in this 5068143307bSMatthias Springer /// function are expressed in terms of tensor.extract_slice/tensor.insert_slice 5078143307bSMatthias Springer /// pairs, but apply to any subset ops that implement the 5088143307bSMatthias Springer /// `SubsetInsertionOpInterface`. 5098143307bSMatthias Springer static bool areNonConflictingSubsets(OpOperand *uRead, 5108143307bSMatthias Springer OpOperand *uConflictingWrite, 5118143307bSMatthias Springer const AnalysisState &state) { 5128143307bSMatthias Springer Operation *readingOp = uRead->getOwner(); 5138143307bSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 5148143307bSMatthias Springer 5158143307bSMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 5168143307bSMatthias Springer // uRead is an InsertSliceOp... 5178143307bSMatthias Springer if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) { 5188143307bSMatthias Springer // As an example, consider the following IR. 5198143307bSMatthias Springer // 5208143307bSMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 5218143307bSMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 5228143307bSMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 5238143307bSMatthias Springer // {inplace= [true] } 5248143307bSMatthias Springer 5258143307bSMatthias Springer if (uRead == &subsetOp.getDestinationOperand() && 526*d9111f19SAmir Bishara matchesInsertDestination(state, uConflictingWrite, subsetOp)) 5278143307bSMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 5288143307bSMatthias Springer // the destination tensor. The overwritten area is not read. If 5298143307bSMatthias Springer // uConflictingWrite writes into exactly the memory location that is 5308143307bSMatthias Springer // being read by uRead, this is not a conflict. 5318143307bSMatthias Springer // 5328143307bSMatthias Springer // In the above example: 5338143307bSMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 5348143307bSMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 5358143307bSMatthias Springer // 5368143307bSMatthias Springer // The read of %t does not conflict with the write of the FillOp 5378143307bSMatthias Springer // (same aliases!) because the area that the FillOp operates on is 5388143307bSMatthias Springer // exactly the one that is *not* read via %t. 5398143307bSMatthias Springer return true; 5408143307bSMatthias Springer 5418143307bSMatthias Springer if (uRead == &subsetOp.getSourceOperand() && 5428143307bSMatthias Springer uConflictingWrite == &subsetOp.getDestinationOperand() && 543*d9111f19SAmir Bishara matchesInsertDestination(state, uRead, subsetOp)) 5448143307bSMatthias Springer // Case 2: The read of the source tensor and the write to the dest 5458143307bSMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 5468143307bSMatthias Springer // reading exactly that part of an equivalent tensor that the 5478143307bSMatthias Springer // InsertSliceOp is writing. 5488143307bSMatthias Springer // 5498143307bSMatthias Springer // In the above example: 5508143307bSMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 5518143307bSMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 5528143307bSMatthias Springer return true; 5538143307bSMatthias Springer } 5548143307bSMatthias Springer 5558143307bSMatthias Springer // If uConflictingWrite is an InsertSliceOp... 5568143307bSMatthias Springer if (auto subsetOp = 5578143307bSMatthias Springer dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp)) 5588143307bSMatthias Springer // As an example, consider the following IR. 5598143307bSMatthias Springer // 5608143307bSMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 5618143307bSMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 5628143307bSMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 5638143307bSMatthias Springer // {inplace= [true] } 5648143307bSMatthias Springer // %3 = vector.transfer_read %1, %cst 5658143307bSMatthias Springer // 5668143307bSMatthias Springer // In the above example: 5678143307bSMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 5688143307bSMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 5698143307bSMatthias Springer // definition = %1 5708143307bSMatthias Springer // 5718143307bSMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 5728143307bSMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 5738143307bSMatthias Springer // is no memory write here.) 5748143307bSMatthias Springer if (uConflictingWrite == &subsetOp.getDestinationOperand() && 5758143307bSMatthias Springer state.areEquivalentBufferizedValues( 5768143307bSMatthias Springer uRead->get(), subsetOp.getSourceOperand().get()) && 577*d9111f19SAmir Bishara matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp)) 5788143307bSMatthias Springer return true; 5798143307bSMatthias Springer 5808143307bSMatthias Springer return false; 5818143307bSMatthias Springer } 5828143307bSMatthias Springer 5837a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under 5847a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that 5857a1579acSMatthias Springer /// all given writes bufferize inplace. 5867a1579acSMatthias Springer /// 5877a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read 5881840d18aSMatthias Springer /// the result of a definition W1. But because of bufferization decisions, R 5891840d18aSMatthias Springer /// actually reads another definition W2. 590cf2d374eSMatthias Springer static bool 591cf2d374eSMatthias Springer hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, 592cf2d374eSMatthias Springer const DenseSet<OpOperand *> &usesWrite, 593cf2d374eSMatthias Springer const DominanceInfo &domInfo, 594cf2d374eSMatthias Springer OneShotAnalysisState &state) { 5957a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions(); 5967a1579acSMatthias Springer 5971e1a3112SMatthias Springer // Before going through the main RaW analysis, find cases where a buffer must 5981e1a3112SMatthias Springer // be privatized due to parallelism. If the result of a write is never read, 5991e1a3112SMatthias Springer // privatization is not necessary (and large parts of the IR are likely dead). 600d5863721SMax191 if (options.checkParallelRegions && !usesRead.empty()) { 6011e1a3112SMatthias Springer for (OpOperand *uConflictingWrite : usesWrite) { 6021e1a3112SMatthias Springer // Find the allocation point or last write (definition) of the buffer. 6031e1a3112SMatthias Springer // Note: In contrast to `findDefinitions`, this also returns results of 6041e1a3112SMatthias Springer // ops that do not bufferize to memory write when no other definition 6051e1a3112SMatthias Springer // could be found. E.g., "bufferization.alloc_tensor" would be included, 6061e1a3112SMatthias Springer // even though that op just bufferizes to an allocation but does define 6071e1a3112SMatthias Springer // the contents of the buffer. 6081e1a3112SMatthias Springer SetVector<Value> definitionsOrLeaves = 609*d9111f19SAmir Bishara state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) { 610*d9111f19SAmir Bishara return state.bufferizesToMemoryWrite(v); 611*d9111f19SAmir Bishara }); 6121e1a3112SMatthias Springer assert(!definitionsOrLeaves.empty() && 6131e1a3112SMatthias Springer "expected at least one definition or leaf"); 6141e1a3112SMatthias Springer 6151e1a3112SMatthias Springer // The writing op must bufferize out-of-place if the definition is in a 6161e1a3112SMatthias Springer // different parallel region than this write. 6171e1a3112SMatthias Springer for (Value def : definitionsOrLeaves) { 6181e1a3112SMatthias Springer if (getParallelRegion(def.getParentRegion(), options) != 6191e1a3112SMatthias Springer getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(), 6201e1a3112SMatthias Springer options)) { 6211e1a3112SMatthias Springer LLVM_DEBUG( 6221e1a3112SMatthias Springer llvm::dbgs() 6231e1a3112SMatthias Springer << "\n- bufferizes out-of-place due to parallel region:\n"); 6241e1a3112SMatthias Springer LLVM_DEBUG(llvm::dbgs() 6251e1a3112SMatthias Springer << " unConflictingWrite = operand " 6261e1a3112SMatthias Springer << uConflictingWrite->getOperandNumber() << " of " 6271e1a3112SMatthias Springer << *uConflictingWrite->getOwner() << "\n"); 6281e1a3112SMatthias Springer return true; 6291e1a3112SMatthias Springer } 6301e1a3112SMatthias Springer } 6311e1a3112SMatthias Springer } 6321e1a3112SMatthias Springer } 6331e1a3112SMatthias Springer 6347a1579acSMatthias Springer for (OpOperand *uRead : usesRead) { 6357a1579acSMatthias Springer Operation *readingOp = uRead->getOwner(); 6361fdf06d6SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); 6371fdf06d6SMatthias Springer LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber() 6381fdf06d6SMatthias Springer << " of " << *readingOp << "\n"); 6397a1579acSMatthias Springer 6401fdf06d6SMatthias Springer // Find the definition of uRead by following the SSA use-def chain. 6417a1579acSMatthias Springer // E.g.: 6427a1579acSMatthias Springer // 6437a1579acSMatthias Springer // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32> 6447a1579acSMatthias Springer // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32> 6457a1579acSMatthias Springer // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type 6467a1579acSMatthias Springer // 6471840d18aSMatthias Springer // In the above example, if uRead is the OpOperand of reading_op, the 6481840d18aSMatthias Springer // definition is %0. Note that operations that create an alias but do not 6491840d18aSMatthias Springer // bufferize to a memory write (such as ExtractSliceOp) are skipped. 650*d9111f19SAmir Bishara const SetVector<Value> &definitions = state.findDefinitionsCached(uRead); 6511fdf06d6SMatthias Springer if (definitions.empty()) { 6521fdf06d6SMatthias Springer // Fast path: No conflict if there are no definitions. 6531fdf06d6SMatthias Springer LLVM_DEBUG(llvm::dbgs() 6541fdf06d6SMatthias Springer << " no conflict: read value has no definitions\n"); 6551fdf06d6SMatthias Springer continue; 6561fdf06d6SMatthias Springer } 6577a1579acSMatthias Springer 6587a1579acSMatthias Springer // Look for conflicting memory writes. Potential conflicts are writes to an 6597a1579acSMatthias Springer // alias that have been decided to bufferize inplace. 6607a1579acSMatthias Springer for (OpOperand *uConflictingWrite : usesWrite) { 661e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand " 662e0b40af7SMatthias Springer << uConflictingWrite->getOperandNumber() << " of " 663e0b40af7SMatthias Springer << *uConflictingWrite->getOwner() << "\n"); 664e0b40af7SMatthias Springer 665c89c31a2SMatthias Springer // Check if op dominance can be used to rule out read-after-write 666c89c31a2SMatthias Springer // conflicts. 667c89c31a2SMatthias Springer bool useDominance = 668c89c31a2SMatthias Springer canUseOpDominance(uRead, uConflictingWrite, definitions, state); 669c89c31a2SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n"); 670c89c31a2SMatthias Springer 6717a1579acSMatthias Springer // Throughout this loop, check for multiple requirements that have to be 6727a1579acSMatthias Springer // met for uConflictingWrite to be an actual conflict. 6737a1579acSMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 6747a1579acSMatthias Springer 6751f8ffbd1SMatthias Springer // Inside of repetitive regions, ops may be executed multiple times and op 6761f8ffbd1SMatthias Springer // dominance cannot be used to rule out conflicts. 6771f8ffbd1SMatthias Springer if (useDominance) { 6781f8ffbd1SMatthias Springer // No conflict if the readingOp dominates conflictingWritingOp, i.e., 6791f8ffbd1SMatthias Springer // the write is not visible when reading. 6809235e597SMatthias Springer // 6811f8ffbd1SMatthias Springer // Note: If ops are executed multiple times (e.g., because they are 6821f8ffbd1SMatthias Springer // inside a loop), there may be no meaningful `happensBefore` 6831f8ffbd1SMatthias Springer // relationship. 684e0b40af7SMatthias Springer if (happensBefore(readingOp, conflictingWritingOp, domInfo)) { 685e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 686e0b40af7SMatthias Springer << " no conflict: read happens before write\n"); 6877a1579acSMatthias Springer continue; 688e0b40af7SMatthias Springer } 6897a1579acSMatthias Springer 6901f8ffbd1SMatthias Springer // No conflict if the reading use equals the use of the conflicting 6911f8ffbd1SMatthias Springer // write. A use cannot conflict with itself. 6929235e597SMatthias Springer // 6931f8ffbd1SMatthias Springer // Note: Just being the same op is not enough. It has to be the same 6941f8ffbd1SMatthias Springer // use. 6951f8ffbd1SMatthias Springer // Note: If the op is executed multiple times (e.g., because it is 6961f8ffbd1SMatthias Springer // inside a loop), it may be conflicting with itself. 697e0b40af7SMatthias Springer if (uConflictingWrite == uRead) { 698e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 699e0b40af7SMatthias Springer << " no conflict: read and write are same use\n"); 7007a1579acSMatthias Springer continue; 701e0b40af7SMatthias Springer } 7027a1579acSMatthias Springer 7031f8ffbd1SMatthias Springer // Ops are not conflicting if they are in mutually exclusive regions. 7041f8ffbd1SMatthias Springer // 7051f8ffbd1SMatthias Springer // Note: If ops are executed multiple times (e.g., because they are 7061f8ffbd1SMatthias Springer // inside a loop), mutually exclusive regions may be executed 7071f8ffbd1SMatthias Springer // multiple times. 708e0b40af7SMatthias Springer if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) { 709e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in " 710e0b40af7SMatthias Springer "mutually exclusive regions\n"); 7111f8ffbd1SMatthias Springer continue; 7121f8ffbd1SMatthias Springer } 7131f8ffbd1SMatthias Springer 71454683405SMatthias Springer // Two equivalent operands of the same op are not conflicting if the op 715cf9b77a6SMatthias Springer // bufferizes to element-wise access. I.e., all loads at a position 716cf9b77a6SMatthias Springer // happen before all stores to the same position. 717f36e1934SMatthias Springer if (conflictingWritingOp == readingOp) { 71854683405SMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) { 719f36e1934SMatthias Springer if (bufferizableOp.bufferizesToElementwiseAccess( 720f36e1934SMatthias Springer state, {uRead, uConflictingWrite})) { 721f36e1934SMatthias Springer if (hasEquivalentValueInReverseUseDefChain( 722*d9111f19SAmir Bishara state, uRead, uConflictingWrite->get()) || 723f36e1934SMatthias Springer hasEquivalentValueInReverseUseDefChain( 724*d9111f19SAmir Bishara state, uConflictingWrite, uRead->get())) { 72554683405SMatthias Springer LLVM_DEBUG( 72654683405SMatthias Springer llvm::dbgs() 72754683405SMatthias Springer << " no conflict: op bufferizes to element-wise access\n"); 72854683405SMatthias Springer continue; 72954683405SMatthias Springer } 73054683405SMatthias Springer } 73154683405SMatthias Springer } 732f36e1934SMatthias Springer } 733cf9b77a6SMatthias Springer } 73454683405SMatthias Springer 7358143307bSMatthias Springer // No conflict if the operands are non-conflicting subsets. 7368143307bSMatthias Springer if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) { 7378143307bSMatthias Springer LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n"); 7388143307bSMatthias Springer continue; 7398143307bSMatthias Springer } 7408143307bSMatthias Springer 7417a1579acSMatthias Springer // No conflict if the op interface says so. 742e0b40af7SMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) { 743e0b40af7SMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) { 744e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 745e0b40af7SMatthias Springer << " no conflict: op interace of reading op says 'no'\n"); 7467a1579acSMatthias Springer continue; 747e0b40af7SMatthias Springer } 748e0b40af7SMatthias Springer } 7497a1579acSMatthias Springer 750e0b40af7SMatthias Springer if (conflictingWritingOp != readingOp) { 7517a1579acSMatthias Springer if (auto bufferizableOp = 752e0b40af7SMatthias Springer options.dynCastBufferizableOp(conflictingWritingOp)) { 753e0b40af7SMatthias Springer if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, 754e0b40af7SMatthias Springer state)) { 755e0b40af7SMatthias Springer LLVM_DEBUG( 756e0b40af7SMatthias Springer llvm::dbgs() 757e0b40af7SMatthias Springer << " no conflict: op interace of writing op says 'no'\n"); 7587a1579acSMatthias Springer continue; 759e0b40af7SMatthias Springer } 760e0b40af7SMatthias Springer } 761e0b40af7SMatthias Springer } 7627a1579acSMatthias Springer 7631840d18aSMatthias Springer // Check all possible definitions. 7641840d18aSMatthias Springer for (Value definition : definitions) { 7651840d18aSMatthias Springer LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n"); 766e0b40af7SMatthias Springer 7671840d18aSMatthias Springer // No conflict if the conflicting write happens before the definition. 7681fdf06d6SMatthias Springer if (Operation *defOp = definition.getDefiningOp()) { 7691fdf06d6SMatthias Springer if (happensBefore(conflictingWritingOp, defOp, domInfo)) { 7701fdf06d6SMatthias Springer // conflictingWritingOp happens before defOp. No conflict. 771e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 7721840d18aSMatthias Springer << " no conflict: write happens before definition\n"); 7737a1579acSMatthias Springer continue; 774e0b40af7SMatthias Springer } 7751fdf06d6SMatthias Springer // No conflict if conflictingWritingOp is contained in defOp. 7761fdf06d6SMatthias Springer if (defOp->isProperAncestor(conflictingWritingOp)) { 777e0b40af7SMatthias Springer LLVM_DEBUG( 778e0b40af7SMatthias Springer llvm::dbgs() 7791840d18aSMatthias Springer << " no conflict: write is contained in definition\n"); 7807a1579acSMatthias Springer continue; 781e0b40af7SMatthias Springer } 7827a1579acSMatthias Springer } else { 7835550c821STres Popp auto bbArg = cast<BlockArgument>(definition); 7847a1579acSMatthias Springer Block *block = bbArg.getOwner(); 785e0b40af7SMatthias Springer if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { 7861840d18aSMatthias Springer LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " 787e0b40af7SMatthias Springer "and write happens outside of block\n"); 7887a1579acSMatthias Springer // conflictingWritingOp happens outside of the block. No 7897a1579acSMatthias Springer // conflict. 7907a1579acSMatthias Springer continue; 7917a1579acSMatthias Springer } 792e0b40af7SMatthias Springer } 7937a1579acSMatthias Springer 7941840d18aSMatthias Springer // No conflict if the conflicting write and the definition are the same 7957a1579acSMatthias Springer // use. 796a02ad6c1SMatthias Springer AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite); 7979fa6b350SMatthias Springer if (aliases.getNumAliases() == 1 && 798a02ad6c1SMatthias Springer aliases.getAliases()[0].value == definition) { 799e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 8001840d18aSMatthias Springer << " no conflict: definition and write are same\n"); 8017a1579acSMatthias Springer continue; 802e0b40af7SMatthias Springer } 8037a1579acSMatthias Springer 8047a1579acSMatthias Springer // All requirements are met. Conflict found! 8057a1579acSMatthias Springer 8067a1579acSMatthias Springer if (options.printConflicts) 8071840d18aSMatthias Springer annotateConflict(uRead, uConflictingWrite, definition); 808e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n"); 8097a1579acSMatthias Springer return true; 8107a1579acSMatthias Springer } 8117a1579acSMatthias Springer } 8127a1579acSMatthias Springer } 8137a1579acSMatthias Springer 8147a1579acSMatthias Springer return false; 8157a1579acSMatthias Springer } 8167a1579acSMatthias Springer 817032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the writes. 818032be233SMatthias Springer static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root, 819cf2d374eSMatthias Springer const OneShotAnalysisState &state) { 820cf2d374eSMatthias Springer state.applyOnAliases(root, [&](Value alias) { 821032be233SMatthias Springer for (auto &use : alias.getUses()) 822032be233SMatthias Springer // Inplace write to a value that aliases root. 823cf2d374eSMatthias Springer if (isInplaceMemoryWrite(use, state)) 824032be233SMatthias Springer res.insert(&use); 825032be233SMatthias Springer }); 826032be233SMatthias Springer } 827032be233SMatthias Springer 828032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the reads. 829032be233SMatthias Springer static void getAliasingReads(DenseSet<OpOperand *> &res, Value root, 830cf2d374eSMatthias Springer const OneShotAnalysisState &state) { 831cf2d374eSMatthias Springer state.applyOnAliases(root, [&](Value alias) { 8321b99f3a2SMatthias Springer for (auto &use : alias.getUses()) { 8331b99f3a2SMatthias Springer // Read of a value that aliases root. 8341b99f3a2SMatthias Springer if (state.bufferizesToMemoryRead(use)) { 835032be233SMatthias Springer res.insert(&use); 8361b99f3a2SMatthias Springer continue; 8371b99f3a2SMatthias Springer } 8381b99f3a2SMatthias Springer 8391b99f3a2SMatthias Springer // Read of a dependent value in the SSA use-def chain. E.g.: 8401b99f3a2SMatthias Springer // 8411b99f3a2SMatthias Springer // %0 = ... 8421b99f3a2SMatthias Springer // %1 = tensor.extract_slice %0 {not_analyzed_yet} 8431b99f3a2SMatthias Springer // "read"(%1) 8441b99f3a2SMatthias Springer // 8451b99f3a2SMatthias Springer // In the above example, getAliasingReads(%0) includes the first OpOperand 8461b99f3a2SMatthias Springer // of the tensor.extract_slice op. The extract_slice itself does not read 8471b99f3a2SMatthias Springer // but its aliasing result is eventually fed into an op that does. 8481b99f3a2SMatthias Springer // 8491b99f3a2SMatthias Springer // Note: This is considered a "read" only if the use does not bufferize to 8501b99f3a2SMatthias Springer // a memory write. (We already ruled out memory reads. In case of a memory 8511b99f3a2SMatthias Springer // write, the buffer would be entirely overwritten; in the above example 8521b99f3a2SMatthias Springer // there would then be no flow of data from the extract_slice operand to 8531b99f3a2SMatthias Springer // its result's uses.) 8541b99f3a2SMatthias Springer if (!state.bufferizesToMemoryWrite(use)) { 855a02ad6c1SMatthias Springer AliasingValueList aliases = state.getAliasingValues(use); 856a02ad6c1SMatthias Springer if (llvm::any_of(aliases, [&](AliasingValue a) { 857a02ad6c1SMatthias Springer return state.isValueRead(a.value); 8589fa6b350SMatthias Springer })) 8591b99f3a2SMatthias Springer res.insert(&use); 8601b99f3a2SMatthias Springer } 8611b99f3a2SMatthias Springer } 862032be233SMatthias Springer }); 863032be233SMatthias Springer } 864032be233SMatthias Springer 8657a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read 8667a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization 8677a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that 8687a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains. 8697a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace 8707a1579acSMatthias Springer /// bufferization decision. 8717a1579acSMatthias Springer /// 8727a1579acSMatthias Springer /// Example: 8737a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?} 8747a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32> 8757a1579acSMatthias Springer /// %e = tensor.extract_slice %1 8767a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32> 8777a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32> 8787a1579acSMatthias Springer /// 8797a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to 8807a1579acSMatthias Springer /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a 8817a1579acSMatthias Springer /// conflict because: 8827a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1. 8837a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second 8841840d18aSMatthias Springer /// TransferWriteOp overwrites the result of the first one. Therefore, the 8851840d18aSMatthias Springer /// TransferReadOp would no longer be reading the result of %1. 8867a1579acSMatthias Springer /// 8877a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a 8887a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would 8897a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions. 8907a1579acSMatthias Springer /// 8917a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null 8927a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions 8937a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked. 8947a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference( 895cf2d374eSMatthias Springer OpOperand &operand, const DominanceInfo &domInfo, 896cf2d374eSMatthias Springer OneShotAnalysisState &state, bool checkConsistencyOnly = false) { 8977a1579acSMatthias Springer // Collect reads and writes of all aliases of OpOperand and OpResult. 8987a1579acSMatthias Springer DenseSet<OpOperand *> usesRead, usesWrite; 899cf2d374eSMatthias Springer getAliasingReads(usesRead, operand.get(), state); 900cf2d374eSMatthias Springer getAliasingInplaceWrites(usesWrite, operand.get(), state); 901a02ad6c1SMatthias Springer for (AliasingValue alias : state.getAliasingValues(operand)) { 902a02ad6c1SMatthias Springer getAliasingReads(usesRead, alias.value, state); 903a02ad6c1SMatthias Springer getAliasingInplaceWrites(usesWrite, alias.value, state); 9047a1579acSMatthias Springer } 9057a1579acSMatthias Springer if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) 9067a1579acSMatthias Springer usesWrite.insert(&operand); 9077a1579acSMatthias Springer 908cf2d374eSMatthias Springer return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state); 9097a1579acSMatthias Springer } 9107a1579acSMatthias Springer 911be630f07SMatthias Springer /// Annotate IR with details about the detected non-writability conflict. 912be630f07SMatthias Springer static void annotateNonWritableTensor(Value value) { 913be630f07SMatthias Springer static int64_t counter = 0; 914be630f07SMatthias Springer OpBuilder b(value.getContext()); 915be630f07SMatthias Springer std::string id = "W_" + std::to_string(counter++); 9165550c821STres Popp if (auto opResult = dyn_cast<OpResult>(value)) { 917be630f07SMatthias Springer std::string attr = id + "[NOT-WRITABLE: result " + 918be630f07SMatthias Springer std::to_string(opResult.getResultNumber()) + "]"; 919be630f07SMatthias Springer opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr()); 920be630f07SMatthias Springer } else { 9215550c821STres Popp auto bbArg = cast<BlockArgument>(value); 922be630f07SMatthias Springer std::string attr = id + "[NOT-WRITABLE: bbArg " + 923be630f07SMatthias Springer std::to_string(bbArg.getArgNumber()) + "]"; 924be630f07SMatthias Springer bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr()); 925be630f07SMatthias Springer } 926be630f07SMatthias Springer } 927be630f07SMatthias Springer 928032be233SMatthias Springer /// Return true if bufferizing `operand` inplace would create a write to a 929032be233SMatthias Springer /// non-writable buffer. 930cf2d374eSMatthias Springer static bool 931cf2d374eSMatthias Springer wouldCreateWriteToNonWritableBuffer(OpOperand &operand, 932cf2d374eSMatthias Springer OneShotAnalysisState &state, 933cf2d374eSMatthias Springer bool checkConsistencyOnly = false) { 934dc7ad194SMatthias Springer bool foundWrite = 935dc7ad194SMatthias Springer !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand); 936dc7ad194SMatthias Springer 937dc7ad194SMatthias Springer if (!foundWrite) { 938032be233SMatthias Springer // Collect writes of all aliases of OpOperand and OpResult. 939032be233SMatthias Springer DenseSet<OpOperand *> usesWrite; 940cf2d374eSMatthias Springer getAliasingInplaceWrites(usesWrite, operand.get(), state); 941a02ad6c1SMatthias Springer for (AliasingValue alias : state.getAliasingValues(operand)) 942a02ad6c1SMatthias Springer getAliasingInplaceWrites(usesWrite, alias.value, state); 943dc7ad194SMatthias Springer foundWrite = !usesWrite.empty(); 944032be233SMatthias Springer } 9457a1579acSMatthias Springer 946dc7ad194SMatthias Springer if (!foundWrite) 947dc7ad194SMatthias Springer return false; 948dc7ad194SMatthias Springer 949dc7ad194SMatthias Springer // Look for a read-only tensor among all aliases. 950dc7ad194SMatthias Springer bool foundReadOnly = false; 951dc7ad194SMatthias Springer auto checkReadOnly = [&](Value v) { 952dc7ad194SMatthias Springer if (!state.isWritable(v)) { 953dc7ad194SMatthias Springer foundReadOnly = true; 954dc7ad194SMatthias Springer if (state.getOptions().printConflicts) 955dc7ad194SMatthias Springer annotateNonWritableTensor(v); 956dc7ad194SMatthias Springer } 957dc7ad194SMatthias Springer }; 958dc7ad194SMatthias Springer state.applyOnAliases(operand.get(), checkReadOnly); 959a02ad6c1SMatthias Springer for (AliasingValue alias : state.getAliasingValues(operand)) 960a02ad6c1SMatthias Springer state.applyOnAliases(alias.value, checkReadOnly); 961dc7ad194SMatthias Springer if (foundReadOnly) { 962e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); 963032be233SMatthias Springer return true; 964e0b40af7SMatthias Springer } 9657a1579acSMatthias Springer 966032be233SMatthias Springer return false; 9677a1579acSMatthias Springer } 9687a1579acSMatthias Springer 9697a1579acSMatthias Springer //===----------------------------------------------------------------------===// 9707a1579acSMatthias Springer // Bufferization analyses. 9717a1579acSMatthias Springer //===----------------------------------------------------------------------===// 9727a1579acSMatthias Springer 973*d9111f19SAmir Bishara // Find the values that define the contents of the given operand's value. 9742b5a020dSMatthias Springer const llvm::SetVector<Value> & 975*d9111f19SAmir Bishara OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) { 976*d9111f19SAmir Bishara Value value = opOperand->get(); 9771f479c1eSMatthias Springer if (!cachedDefinitions.count(value)) 978*d9111f19SAmir Bishara cachedDefinitions[value] = findDefinitions(opOperand); 9792b5a020dSMatthias Springer return cachedDefinitions[value]; 9802b5a020dSMatthias Springer } 9812b5a020dSMatthias Springer 9829312b4f9SMartin Erhart void OneShotAnalysisState::resetCache() { 9839312b4f9SMartin Erhart AnalysisState::resetCache(); 9849312b4f9SMartin Erhart cachedDefinitions.clear(); 9859312b4f9SMartin Erhart } 9862b5a020dSMatthias Springer 9877a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place. 988cf2d374eSMatthias Springer static LogicalResult 989cf2d374eSMatthias Springer bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, 990cf2d374eSMatthias Springer const DominanceInfo &domInfo) { 991e0b40af7SMatthias Springer LLVM_DEBUG( 992e0b40af7SMatthias Springer llvm::dbgs() << "//===-------------------------------------------===//\n" 993e0b40af7SMatthias Springer << "Analyzing operand #" << operand.getOperandNumber() 994e0b40af7SMatthias Springer << " of " << *operand.getOwner() << "\n"); 995e0b40af7SMatthias Springer 9967a1579acSMatthias Springer bool foundInterference = 997cf2d374eSMatthias Springer wouldCreateWriteToNonWritableBuffer(operand, state) || 998cf2d374eSMatthias Springer wouldCreateReadAfterWriteInterference(operand, domInfo, state); 9997a1579acSMatthias Springer 10007a1579acSMatthias Springer if (foundInterference) 1001cf2d374eSMatthias Springer state.bufferizeOutOfPlace(operand); 10027a1579acSMatthias Springer else 1003cf2d374eSMatthias Springer state.bufferizeInPlace(operand); 10047a1579acSMatthias Springer 1005e0b40af7SMatthias Springer LLVM_DEBUG(llvm::dbgs() 1006e0b40af7SMatthias Springer << "//===-------------------------------------------===//\n"); 10077a1579acSMatthias Springer return success(); 10087a1579acSMatthias Springer } 10097a1579acSMatthias Springer 10106d14b110SMatthias Springer LogicalResult 10116d14b110SMatthias Springer OneShotAnalysisState::analyzeSingleOp(Operation *op, 10126d14b110SMatthias Springer const DominanceInfo &domInfo) { 10137a1579acSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 10145550c821STres Popp if (isa<TensorType>(opOperand.get().getType())) 10156d14b110SMatthias Springer if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo))) 10167a1579acSMatthias Springer return failure(); 10171b99f3a2SMatthias Springer return success(); 10187a1579acSMatthias Springer } 10197a1579acSMatthias Springer 10207a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. 10217a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops, 1022cf2d374eSMatthias Springer OneShotAnalysisState &state) { 10239fa6b350SMatthias Springer for (Operation *op : ops) { 10249fa6b350SMatthias Springer if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { 10259fa6b350SMatthias Springer for (OpResult opResult : op->getOpResults()) { 10265550c821STres Popp if (!isa<TensorType>(opResult.getType())) 10279fa6b350SMatthias Springer continue; 10289fa6b350SMatthias Springer AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); 10299fa6b350SMatthias Springer if (aliases.getNumAliases() == 0) 10309fa6b350SMatthias Springer // Nothing to do if there are no aliasing OpOperands. 10319fa6b350SMatthias Springer continue; 10329fa6b350SMatthias Springer 10339fa6b350SMatthias Springer Value firstOperand = aliases.begin()->opOperand->get(); 10349fa6b350SMatthias Springer bool allEquivalent = true; 10359fa6b350SMatthias Springer for (AliasingOpOperand alias : aliases) { 10369fa6b350SMatthias Springer bool isEquiv = alias.relation == BufferRelation::Equivalent; 10379fa6b350SMatthias Springer bool isInPlace = state.isInPlace(*alias.opOperand); 10389fa6b350SMatthias Springer Value operand = alias.opOperand->get(); 10399fa6b350SMatthias Springer if (isEquiv && isInPlace && alias.isDefinite) { 10409fa6b350SMatthias Springer // Found a definite, equivalent alias. Merge equivalence sets. 10419fa6b350SMatthias Springer // There can only be one definite alias, so we can stop here. 10429fa6b350SMatthias Springer state.unionEquivalenceClasses(opResult, operand); 10439fa6b350SMatthias Springer allEquivalent = false; 10449fa6b350SMatthias Springer break; 10459fa6b350SMatthias Springer } 10469fa6b350SMatthias Springer if (!isEquiv || !isInPlace) 10479fa6b350SMatthias Springer allEquivalent = false; 10489fa6b350SMatthias Springer if (!state.areEquivalentBufferizedValues(operand, firstOperand)) 10499fa6b350SMatthias Springer allEquivalent = false; 10509fa6b350SMatthias Springer } 10519fa6b350SMatthias Springer 10529fa6b350SMatthias Springer // If all "maybe" aliases are equivalent and the OpResult is not a new 10539fa6b350SMatthias Springer // allocation, it is a definite, equivalent alias. E.g.: 10549fa6b350SMatthias Springer // 10559fa6b350SMatthias Springer // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)} 1056a02ad6c1SMatthias Springer // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)} 1057a02ad6c1SMatthias Springer // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)} 10589fa6b350SMatthias Springer // %r = arith.select %c, %t0, %t1 : tensor<?xf32> 10599fa6b350SMatthias Springer // 10609fa6b350SMatthias Springer // If %t0 and %t1 are equivalent, it is safe to union the equivalence 10619fa6b350SMatthias Springer // classes of %r, %t0 and %t1. 10629fa6b350SMatthias Springer if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult)) 10639fa6b350SMatthias Springer state.unionEquivalenceClasses(opResult, firstOperand); 10649fa6b350SMatthias Springer } 10659fa6b350SMatthias Springer } 10669fa6b350SMatthias Springer } 10677a1579acSMatthias Springer } 10687a1579acSMatthias Springer 10697a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained 10707a1579acSMatthias Springer /// in `op`. 1071cf2d374eSMatthias Springer static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) { 10727a1579acSMatthias Springer // Traverse ops in PostOrder: Nested ops first, then enclosing ops. 10737a1579acSMatthias Springer SmallVector<Operation *> ops; 10747a1579acSMatthias Springer op->walk<WalkOrder::PostOrder>([&](Operation *op) { 10757a1579acSMatthias Springer // No tensors => no buffers. 10767a1579acSMatthias Springer if (none_of(op->getResultTypes(), isaTensor)) 10777a1579acSMatthias Springer return; 10787a1579acSMatthias Springer ops.push_back(op); 10797a1579acSMatthias Springer }); 10807a1579acSMatthias Springer 1081cf2d374eSMatthias Springer equivalenceAnalysis(ops, state); 10827a1579acSMatthias Springer } 10837a1579acSMatthias Springer 108435d3b343SMatthias Springer /// "Bottom-up from terminators" heuristic. 108535d3b343SMatthias Springer static SmallVector<Operation *> 108635d3b343SMatthias Springer bottomUpFromTerminatorsHeuristic(Operation *op, 108735d3b343SMatthias Springer const OneShotAnalysisState &state) { 108835d3b343SMatthias Springer SetVector<Operation *> traversedOps; 108935d3b343SMatthias Springer 109035d3b343SMatthias Springer // Find region terminators. 109135d3b343SMatthias Springer op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) { 109235d3b343SMatthias Springer if (!traversedOps.insert(term)) 109335d3b343SMatthias Springer return; 109435d3b343SMatthias Springer // Follow the reverse SSA use-def chain from each yielded value as long as 109535d3b343SMatthias Springer // we stay within the same region. 109635d3b343SMatthias Springer SmallVector<OpResult> worklist; 109735d3b343SMatthias Springer for (Value v : term->getOperands()) { 109835d3b343SMatthias Springer if (!isa<TensorType>(v.getType())) 109935d3b343SMatthias Springer continue; 110035d3b343SMatthias Springer auto opResult = dyn_cast<OpResult>(v); 110135d3b343SMatthias Springer if (!opResult) 110235d3b343SMatthias Springer continue; 110335d3b343SMatthias Springer worklist.push_back(opResult); 110435d3b343SMatthias Springer } 110535d3b343SMatthias Springer while (!worklist.empty()) { 110635d3b343SMatthias Springer OpResult opResult = worklist.pop_back_val(); 110735d3b343SMatthias Springer Operation *defOp = opResult.getDefiningOp(); 110835d3b343SMatthias Springer if (!traversedOps.insert(defOp)) 110935d3b343SMatthias Springer continue; 111035d3b343SMatthias Springer if (!term->getParentRegion()->findAncestorOpInRegion(*defOp)) 111135d3b343SMatthias Springer continue; 111235d3b343SMatthias Springer AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); 111335d3b343SMatthias Springer for (auto alias : aliases) { 111435d3b343SMatthias Springer Value v = alias.opOperand->get(); 111535d3b343SMatthias Springer if (!isa<TensorType>(v.getType())) 111635d3b343SMatthias Springer continue; 111735d3b343SMatthias Springer auto opResult = dyn_cast<OpResult>(v); 111835d3b343SMatthias Springer if (!opResult) 111935d3b343SMatthias Springer continue; 112035d3b343SMatthias Springer worklist.push_back(opResult); 112135d3b343SMatthias Springer } 112235d3b343SMatthias Springer } 112335d3b343SMatthias Springer }); 112435d3b343SMatthias Springer 112535d3b343SMatthias Springer // Analyze traversed ops, then all remaining ops. 112635d3b343SMatthias Springer SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end()); 112735d3b343SMatthias Springer op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) { 112835d3b343SMatthias Springer if (!traversedOps.contains(op) && hasTensorSemantics(op)) 112935d3b343SMatthias Springer result.push_back(op); 113035d3b343SMatthias Springer }); 113135d3b343SMatthias Springer return result; 113235d3b343SMatthias Springer } 113335d3b343SMatthias Springer 11346d14b110SMatthias Springer LogicalResult OneShotAnalysisState::analyzeOp(Operation *op, 11356d14b110SMatthias Springer const DominanceInfo &domInfo) { 113635d3b343SMatthias Springer OneShotBufferizationOptions::AnalysisHeuristic heuristic = 113735d3b343SMatthias Springer getOptions().analysisHeuristic; 113835d3b343SMatthias Springer 113935d3b343SMatthias Springer SmallVector<Operation *> orderedOps; 114035d3b343SMatthias Springer if (heuristic == 114135d3b343SMatthias Springer OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) { 114235d3b343SMatthias Springer orderedOps = bottomUpFromTerminatorsHeuristic(op, *this); 114335d3b343SMatthias Springer } else { 11446d14b110SMatthias Springer op->walk([&](Operation *op) { 11456d14b110SMatthias Springer // No tensors => no buffers. 11466d14b110SMatthias Springer if (!hasTensorSemantics(op)) 11476d14b110SMatthias Springer return; 114835d3b343SMatthias Springer orderedOps.push_back(op); 11496d14b110SMatthias Springer }); 115035d3b343SMatthias Springer switch (heuristic) { 115135d3b343SMatthias Springer case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: { 11526d14b110SMatthias Springer // Default: Walk ops in reverse for better interference analysis. 115335d3b343SMatthias Springer std::reverse(orderedOps.begin(), orderedOps.end()); 115435d3b343SMatthias Springer break; 115535d3b343SMatthias Springer } 115635d3b343SMatthias Springer case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: { 115735d3b343SMatthias Springer // Ops are already sorted top-down in `orderedOps`. 115835d3b343SMatthias Springer break; 115935d3b343SMatthias Springer } 116035d3b343SMatthias Springer case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: { 116135d3b343SMatthias Springer assert(getOptions().analysisFuzzerSeed && 116235d3b343SMatthias Springer "expected that fuzzer seed it set"); 116335d3b343SMatthias Springer // This is a fuzzer. For testing purposes only. Randomize the order in 116435d3b343SMatthias Springer // which operations are analyzed. The bufferization quality is likely 116535d3b343SMatthias Springer // worse, but we want to make sure that no assertions are triggered 116635d3b343SMatthias Springer // anywhere. 116735d3b343SMatthias Springer std::mt19937 g(getOptions().analysisFuzzerSeed); 116835d3b343SMatthias Springer llvm::shuffle(orderedOps.begin(), orderedOps.end(), g); 116935d3b343SMatthias Springer break; 117035d3b343SMatthias Springer } 117135d3b343SMatthias Springer default: { 11726d14b110SMatthias Springer llvm_unreachable("unsupported heuristic"); 11736d14b110SMatthias Springer } 117435d3b343SMatthias Springer } 117535d3b343SMatthias Springer } 117635d3b343SMatthias Springer 117735d3b343SMatthias Springer // Analyze ops in the computed order. 117835d3b343SMatthias Springer for (Operation *op : orderedOps) 117935d3b343SMatthias Springer if (failed(analyzeSingleOp(op, domInfo))) 118035d3b343SMatthias Springer return failure(); 11816d14b110SMatthias Springer 11826d14b110SMatthias Springer equivalenceAnalysis(op, *this); 11836d14b110SMatthias Springer return success(); 11846d14b110SMatthias Springer } 11856d14b110SMatthias Springer 1186061aa2e3SMatthias Springer /// Perform various checks on the input IR to see if it contains IR constructs 1187061aa2e3SMatthias Springer /// that are unsupported by One-Shot Bufferize. 1188061aa2e3SMatthias Springer static LogicalResult 1189061aa2e3SMatthias Springer checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo, 1190cf2d374eSMatthias Springer OneShotAnalysisState &state) { 11917a1579acSMatthias Springer const BufferizationOptions &options = state.getOptions(); 11923f914d84SMatthias Springer 1193061aa2e3SMatthias Springer // Note: This walk cannot be combined with the one below because interface 1194061aa2e3SMatthias Springer // methods of invalid/unsupported ops may be called during the second walk. 1195061aa2e3SMatthias Springer // (On ops different from `op`.) 11963f914d84SMatthias Springer WalkResult walkResult = op->walk([&](BufferizableOpInterface op) { 11973f914d84SMatthias Springer // Skip ops that are not in the filter. 11983f914d84SMatthias Springer if (!options.isOpAllowed(op.getOperation())) 11993f914d84SMatthias Springer return WalkResult::advance(); 12003f914d84SMatthias Springer 1201061aa2e3SMatthias Springer // Check for unsupported unstructured control flow. 1202061aa2e3SMatthias Springer if (!op.supportsUnstructuredControlFlow()) { 1203061aa2e3SMatthias Springer for (Region &r : op->getRegions()) { 1204061aa2e3SMatthias Springer if (r.getBlocks().size() > 1) { 1205061aa2e3SMatthias Springer op->emitOpError("op or BufferizableOpInterface implementation does " 1206061aa2e3SMatthias Springer "not support unstructured control flow, but at least " 1207061aa2e3SMatthias Springer "one region has multiple blocks"); 1208061aa2e3SMatthias Springer return WalkResult::interrupt(); 1209061aa2e3SMatthias Springer } 1210061aa2e3SMatthias Springer } 1211061aa2e3SMatthias Springer } 1212061aa2e3SMatthias Springer 1213061aa2e3SMatthias Springer return WalkResult::advance(); 1214061aa2e3SMatthias Springer }); 1215061aa2e3SMatthias Springer if (walkResult.wasInterrupted()) 1216061aa2e3SMatthias Springer return failure(); 1217061aa2e3SMatthias Springer 1218061aa2e3SMatthias Springer walkResult = op->walk([&](BufferizableOpInterface op) { 1219061aa2e3SMatthias Springer // Skip ops that are not in the filter. 1220061aa2e3SMatthias Springer if (!options.isOpAllowed(op.getOperation())) 1221061aa2e3SMatthias Springer return WalkResult::advance(); 1222061aa2e3SMatthias Springer 12238f7e7400SMatthias Springer // Input IR may not contain any ToTensorOps without the "restrict" 12248f7e7400SMatthias Springer // attribute. Such tensors may alias any other tensor, which is currently 12258f7e7400SMatthias Springer // not handled in the analysis. 12268f7e7400SMatthias Springer if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) { 12276badbd6fSMatthias Springer if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) { 12288ee38f3bSMatthias Springer op->emitOpError("to_tensor ops without `restrict` are not supported by " 12298f7e7400SMatthias Springer "One-Shot Analysis"); 12308f7e7400SMatthias Springer return WalkResult::interrupt(); 12318f7e7400SMatthias Springer } 12328f7e7400SMatthias Springer } 12338f7e7400SMatthias Springer 12343f914d84SMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) { 12355550c821STres Popp if (isa<TensorType>(opOperand.get().getType())) { 12367a1579acSMatthias Springer if (wouldCreateReadAfterWriteInterference( 1237cf2d374eSMatthias Springer opOperand, domInfo, state, 12387a1579acSMatthias Springer /*checkConsistencyOnly=*/true)) { 12397a1579acSMatthias Springer // This error can happen if certain "mustBufferizeInPlace" interface 12407a1579acSMatthias Springer // methods are implemented incorrectly, such that the IR already has 12418ee38f3bSMatthias Springer // a RaW conflict before making any bufferization decisions. It can 12428ee38f3bSMatthias Springer // also happen if the bufferization.materialize_in_destination is used 12438ee38f3bSMatthias Springer // in such a way that a RaW conflict is not avoidable. 12448ee38f3bSMatthias Springer op->emitOpError("not bufferizable under the given constraints: " 12458ee38f3bSMatthias Springer "cannot avoid RaW conflict"); 12468ee38f3bSMatthias Springer return WalkResult::interrupt(); 12478ee38f3bSMatthias Springer } 12488ee38f3bSMatthias Springer 12498ee38f3bSMatthias Springer if (state.isInPlace(opOperand) && 12508ee38f3bSMatthias Springer wouldCreateWriteToNonWritableBuffer( 12518ee38f3bSMatthias Springer opOperand, state, /*checkConsistencyOnly=*/true)) { 12528ee38f3bSMatthias Springer op->emitOpError("not bufferizable under the given constraints: would " 12538ee38f3bSMatthias Springer "write to read-only buffer"); 12547a1579acSMatthias Springer return WalkResult::interrupt(); 12557a1579acSMatthias Springer } 12567a1579acSMatthias Springer } 12573f914d84SMatthias Springer } 12583f914d84SMatthias Springer 12597a1579acSMatthias Springer return WalkResult::advance(); 12607a1579acSMatthias Springer }); 12617a1579acSMatthias Springer 12623f914d84SMatthias Springer return success(!walkResult.wasInterrupted()); 12637a1579acSMatthias Springer } 12647a1579acSMatthias Springer 12657a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only. 12667a1579acSMatthias Springer static void 12677a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op, 1268cf2d374eSMatthias Springer const OneShotAnalysisState &state) { 12697cdfc843SMatthias Springer // Add __inplace_operands_attr__. 1270f3483c23SMatthias Springer op->walk([&](Operation *op) { 1271f3483c23SMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) 12725550c821STres Popp if (isa<TensorType>(opOperand.get().getType())) 1273cf2d374eSMatthias Springer setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); 12747a1579acSMatthias Springer }); 12757a1579acSMatthias Springer } 12767a1579acSMatthias Springer 1277bb9d1b55SMatthias Springer static void annotateOpsWithAliasSets(Operation *op, 1278bb9d1b55SMatthias Springer const OneShotAnalysisState &state) { 1279bb9d1b55SMatthias Springer AsmState asmState(op); 1280bb9d1b55SMatthias Springer Builder b(op->getContext()); 1281a02ad6c1SMatthias Springer // Helper function to build an array attribute of aliasing SSA value strings. 1282a02ad6c1SMatthias Springer auto buildAliasesArray = [&](Value v) { 1283bb9d1b55SMatthias Springer SmallVector<Attribute> aliases; 1284a02ad6c1SMatthias Springer state.applyOnAliases(v, [&](Value alias) { 1285bb9d1b55SMatthias Springer std::string buffer; 1286bb9d1b55SMatthias Springer llvm::raw_string_ostream stream(buffer); 1287bb9d1b55SMatthias Springer alias.printAsOperand(stream, asmState); 1288884221edSJOE1994 aliases.push_back(b.getStringAttr(buffer)); 1289bb9d1b55SMatthias Springer }); 1290a02ad6c1SMatthias Springer return b.getArrayAttr(aliases); 1291a02ad6c1SMatthias Springer }; 1292a02ad6c1SMatthias Springer 1293a02ad6c1SMatthias Springer op->walk([&](Operation *op) { 1294a02ad6c1SMatthias Springer // Build alias set array for every OpResult. 1295a02ad6c1SMatthias Springer SmallVector<Attribute> opResultAliasSets; 1296a02ad6c1SMatthias Springer for (OpResult opResult : op->getOpResults()) { 1297a02ad6c1SMatthias Springer if (llvm::isa<TensorType>(opResult.getType())) { 1298a02ad6c1SMatthias Springer opResultAliasSets.push_back(buildAliasesArray(opResult)); 1299bb9d1b55SMatthias Springer } 1300bb9d1b55SMatthias Springer } 1301a02ad6c1SMatthias Springer if (!opResultAliasSets.empty()) 1302a02ad6c1SMatthias Springer op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets)); 1303a02ad6c1SMatthias Springer 1304a02ad6c1SMatthias Springer // Build alias set array for every BlockArgument. 1305a02ad6c1SMatthias Springer SmallVector<Attribute> regionAliasSets; 1306a02ad6c1SMatthias Springer bool hasTensorBbArg = false; 1307a02ad6c1SMatthias Springer for (Region &r : op->getRegions()) { 1308a02ad6c1SMatthias Springer SmallVector<Attribute> blockAliasSets; 1309a02ad6c1SMatthias Springer for (Block &block : r.getBlocks()) { 1310a02ad6c1SMatthias Springer SmallVector<Attribute> bbArgAliasSets; 1311a02ad6c1SMatthias Springer for (BlockArgument bbArg : block.getArguments()) { 1312a02ad6c1SMatthias Springer if (llvm::isa<TensorType>(bbArg.getType())) { 1313a02ad6c1SMatthias Springer bbArgAliasSets.push_back(buildAliasesArray(bbArg)); 1314a02ad6c1SMatthias Springer hasTensorBbArg = true; 1315a02ad6c1SMatthias Springer } 1316a02ad6c1SMatthias Springer } 1317a02ad6c1SMatthias Springer blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets)); 1318a02ad6c1SMatthias Springer } 1319a02ad6c1SMatthias Springer regionAliasSets.push_back(b.getArrayAttr(blockAliasSets)); 1320a02ad6c1SMatthias Springer } 1321a02ad6c1SMatthias Springer if (hasTensorBbArg) 1322a02ad6c1SMatthias Springer op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets)); 1323bb9d1b55SMatthias Springer }); 1324bb9d1b55SMatthias Springer } 1325bb9d1b55SMatthias Springer 13267a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op, 1327ae05bd99SMatthias Springer OneShotAnalysisState &state, 1328ae05bd99SMatthias Springer BufferizationStatistics *statistics) { 13297a1579acSMatthias Springer DominanceInfo domInfo(op); 13307cdfc843SMatthias Springer const OneShotBufferizationOptions &options = state.getOptions(); 13317a1579acSMatthias Springer 1332061aa2e3SMatthias Springer if (failed(checkPreBufferizationAssumptions(op, domInfo, state))) 13337a1579acSMatthias Springer return failure(); 13347a1579acSMatthias Springer 13357a1579acSMatthias Springer // If the analysis fails, just return. 13366d14b110SMatthias Springer if (failed(state.analyzeOp(op, domInfo))) 13377a1579acSMatthias Springer return failure(); 1338ae05bd99SMatthias Springer 1339ae05bd99SMatthias Springer if (statistics) { 1340cf2d374eSMatthias Springer statistics->numTensorInPlace = state.getStatNumTensorInPlace(); 1341cf2d374eSMatthias Springer statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace(); 1342ae05bd99SMatthias Springer } 1343ae05bd99SMatthias Springer 1344d1d79920SMatthias Springer bool failedAnalysis = false; 13457a1579acSMatthias Springer 1346988748c0SMatthias Springer // Gather some extra analysis data. 1347988748c0SMatthias Springer state.gatherUndefinedTensorUses(op); 13489e24f0f4SMatthias Springer 13494ec00fb3SMatthias Springer // Analysis verification: After setting up alias/equivalence sets, each op 13504ec00fb3SMatthias Springer // can check for expected invariants/limitations and fail the analysis if 13514ec00fb3SMatthias Springer // necessary. 13524ec00fb3SMatthias Springer op->walk([&](Operation *op) { 13534ec00fb3SMatthias Springer if (BufferizableOpInterface bufferizableOp = 13544ec00fb3SMatthias Springer options.dynCastBufferizableOp(op)) 1355d1d79920SMatthias Springer failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state)); 13564ec00fb3SMatthias Springer }); 13574ec00fb3SMatthias Springer 13587a1579acSMatthias Springer // Annotate operations if we only want to report the analysis. 13597a1579acSMatthias Springer if (options.testAnalysisOnly) 1360cf2d374eSMatthias Springer annotateOpsWithBufferizationMarkers(op, state); 1361bb9d1b55SMatthias Springer if (options.dumpAliasSets) 1362bb9d1b55SMatthias Springer annotateOpsWithAliasSets(op, state); 13637a1579acSMatthias Springer 1364d1d79920SMatthias Springer return success(!failedAnalysis); 13657a1579acSMatthias Springer } 13667a1579acSMatthias Springer 13679597b16aSMatthias Springer LogicalResult 13689597b16aSMatthias Springer bufferization::runOneShotBufferize(Operation *op, 1369ae05bd99SMatthias Springer const OneShotBufferizationOptions &options, 1370ae05bd99SMatthias Springer BufferizationStatistics *statistics) { 1371179e1749SMatthias Springer // copy-before-write deactivates the analysis. It cannot be used together with 1372179e1749SMatthias Springer // test-analysis-only. 1373f7dd9a32SMatthias Springer assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && 1374f7dd9a32SMatthias Springer "invalid combination of bufferization flags"); 1375179e1749SMatthias Springer 1376179e1749SMatthias Springer if (options.copyBeforeWrite) { 1377179e1749SMatthias Springer // Copy buffer before each write. No analysis is needed. 1378179e1749SMatthias Springer } else { 1379179e1749SMatthias Springer // Run One-Shot Analysis and insert buffer copies (on the tensor level) 1380179e1749SMatthias Springer // only where needed. This is the default and much more efficient than 1381179e1749SMatthias Springer // copy-before-write. 1382ae05bd99SMatthias Springer if (failed(insertTensorCopies(op, options, statistics))) 13837a1579acSMatthias Springer return failure(); 1384179e1749SMatthias Springer 1385179e1749SMatthias Springer // If test-analysis-only is set, the IR was annotated with RaW conflict 1386179e1749SMatthias Springer // markers (attributes) during One-Shot Analysis. 1387d2dacde5SMatthias Springer if (options.testAnalysisOnly) 13887a1579acSMatthias Springer return success(); 1389179e1749SMatthias Springer } 1390179e1749SMatthias Springer 1391179e1749SMatthias Springer // Bufferize the op and its nested ops. If options.copyBeforeWrite is set, 1392179e1749SMatthias Springer // a new buffer copy is allocated every time a buffer is written to. 13939d34c052SMatthias Springer return bufferizeOp(op, options, statistics); 13947a1579acSMatthias Springer } 1395