//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases.
//
// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
//    in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, this file also
// contains a helper function `runOneShotBufferize` that analyzes an op (and
// its nested ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is
// deemed critical: the analysis should fail if the bufferized form of the
// function needs to return a buffer, unless `allowReturnAllocs` is enabled.
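//
// As a usage sketch (exact pass option spelling assumed; see the pass
// documentation), the analysis can be exercised without rewriting the IR via:
//
//   mlir-opt --one-shot-bufferize="test-analysis-only" input.mlir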
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <optional>
#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)

// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return isa<TensorType>(t); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in OneShotAnalysisState. When run with `testAnalysisOnly`, the IR is
// annotated with the results of the analysis, so that they can be checked in
// tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place.
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";

constexpr StringLiteral kOpResultAliasSetAttrName =
    "__opresult_alias_set_attr__";

constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";

/// Mark whether OpOperand will be bufferized inplace.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  SmallVector<StringRef> inPlaceVector;
  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
    inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }
  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceOperandsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
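
// As an illustration (hypothetical op; operand order assumed), an op with one
// non-tensor operand and two tensor operands may end up annotated as follows
// when running with `testAnalysisOnly`: "none" for the non-tensor operand,
// "true"/"false" for the tensor operands.
//
//   %r = "test.op"(%idx, %t0, %t1)
//       {__inplace_operands_attr__ = ["none", "true", "false"]}
//       : (index, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>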

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Set up alias sets.
  op->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            createAliasInfoEntry(bbArg);
  });

  // Mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
          bufferizeInPlace(opOperand);
    return WalkResult::advance();
  });
}

void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

void OneShotAnalysisState::applyOnAliases(Value v,
                                          function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return equivalentInfo.isEquivalent(v1, v2);
}

bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  return aliasInfo.isEquivalent(v1, v2);
}

void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (inplaceBufferized.contains(&operand))
    return;
  inplaceBufferized.insert(&operand);
  for (AliasingValue alias : getAliasingValues(operand))
    aliasInfo.unionSets(alias.value, operand.get());
  ++statNumTensorInPlace;
}

void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}

void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}
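
/// Find and record all uses of tensors with undefined contents, i.e., tensor
/// values for which no preceding definition (memory write) exists. A sketch
/// (op choice for illustration only): a `bufferization.alloc_tensor` just
/// bufferizes to an allocation and does not write to it, so all uses of its
/// result read undefined contents:
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   "reading_op"(%0) ...   // reads undefined contents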
void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *op) {
    // Skip unknown ops.
    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Nothing to do if the result is unused.
      if (opResult.getUses().empty())
        continue;

      // If there is no preceding definition, the tensor contents are
      // undefined. It does not matter which use we pick to search for the
      // value's definitions.
      OpOperand *opOperand = &(*opResult.getUses().begin());
      if (findDefinitionsCached(opOperand).empty())
        for (OpOperand &use : opResult.getUses())
          undefinedTensorUses.insert(&use);
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.contains(opOperand);
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.contains(&opOperand);
}

bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool isWritten = false;
  applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (isInPlace(use) && bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized values could be considered writable.
  // Query BufferizableOpInterface to see if the value is writable.
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.isWritable(value, *this);

  // Not a bufferizable op: The conservative answer is "not writable".
  return false;
}

void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
  aliasInfo.unionSets(v1, v2);
}

void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
  equivalentInfo.unionSets(v1, v2);
}

OneShotAnalysisState::Extension::~Extension() = default;

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write that has been
/// decided to happen in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return state.isInPlace(opOperand);
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}
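
// For example (sketch): given
//
//   op1
//   scf.for ... {
//     op2
//   }
//
// `happensBefore(op1, op2)` holds because op1 properly dominates the loop that
// encloses op2, whereas `happensBefore(scf.for, op2)` does not hold because
// op2 is nested inside the loop.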

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32>                                // DEF
/// "reading_op"(%0) : tensor<?xf32>                        // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// scf.for ... {
///   "reading_op"(%0) : tensor<?xf32>                        // READ
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///   ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict:        DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict:          DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
///   %0 = ... : tensor<?xf32>
///   "reading_op"(%0) : tensor<?xf32>
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///   ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
///   scf.for ... { "reading_op"(%0) }
///   %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor
///    of a repetitive region that encloses both READ and WRITE, we cannot rule
///    out a RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
///    analysis purposes, we can assume that there are no loops/repetitive
///    regions. I.e., we can rule out a RaW conflict if READ happensBefore
///    WRITE or WRITE happensBefore DEF. (Checked in
///    `hasReadAfterWriteInterference`.)
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  for (Value def : definitions) {
    Region *rRead =
        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (rRead == rDef)
      continue;

    // Find the enclosing repetitive region of READ that is closest to DEF but
    // not the repetitive region of DEF itself.
    while (true) {
      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
      if (nextRegion == rDef)
        break;
      assert(nextRegion && "expected to find another repetitive region");
      rRead = nextRegion;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive
    // region.
    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
      return false;
  }

  return true;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
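///
/// Example (sketch, unstructured control flow):
/// %0 = ... : tensor<?xf32>          // DEF in ^bb0
/// cf.br ^bb1
/// ^bb1:
///   "reading_op"(%0) ...            // READ
///   cf.br ^bb2
/// ^bb2:
///   %1 = "writing_op"(%0) ...       // WRITE
///   cf.cond_br %c, ^bb1, ^bb3       // back edge: block(READ) and
/// ^bb3:                             // block(WRITE) are mutually reachable
///   ...                             // without passing through block(DEF)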
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  // Fast path: If READ and WRITE are in different regions, their block cannot
  // be reachable just via unstructured control flow. (Loops due to regions are
  // covered by `canUseOpDominanceDueToRegions`.)
  if (uRead->getOwner()->getParentRegion() !=
      uWrite->getOwner()->getParentRegion())
    return true;

  Block *readBlock = uRead->getOwner()->getBlock();
  Block *writeBlock = uWrite->getOwner()->getBlock();
  for (Value def : definitions) {
    Block *defBlock = def.getParentBlock();
    if (readBlock->isReachable(writeBlock, {defBlock}) &&
        writeBlock->isReachable(readBlock, {defBlock}))
      return false;
  }

  return true;
}

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
         canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    std::string defAttr =
        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(definition);
    std::string defAttr =
        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
  }
}

/// Return "true" if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   OpOperand *start,
                                                   Value other) {
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followSameTypeOrCastsOnly = true;
  return !state
              .findValueInReverseUseDefChain(
                  start, [&](Value v) { return v == other; }, config)
              .empty();
}
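
// For example (sketch): starting from the "reading_op" operand %1 below, the
// reverse use-def chain contains %0 because `tensor.cast` is an
// equivalence-preserving step that only changes static/dynamic dims; a query
// with `other = %0` would therefore succeed:
//
//   %0 = ... : tensor<5xf32>
//   %1 = tensor.cast %0 : tensor<5xf32> to tensor<?xf32>
//   "reading_op"(%1) ...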

/// Return "true" if the given operand's value is originating from a subset
/// that is equivalent to the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state,
                                     OpOperand *opOperand,
                                     SubsetInsertionOpInterface subsetOp) {
  auto matchingSubset = [&](Value val) {
    if (auto opResult = dyn_cast<OpResult>(val))
      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
            return state.areEquivalentBufferizedValues(v1, v2);
          }))
        return true;
    return false;
  };
  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}

/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }

    if (uRead == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uConflictingWrite, subsetOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &subsetOp.getSourceOperand() &&
        uConflictingWrite == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead, subsetOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto subsetOp =
          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // definition        = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
        state.areEquivalentBufferizedValues(
            uRead->get(), subsetOp.getSourceOperand().get()) &&
        matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
      return true;

  return false;
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a definition W1. But because of bufferization decisions, R
/// actually reads another definition W2.
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo,
                              OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Before going through the main RaW analysis, find cases where a buffer must
  // be privatized due to parallelism. If the result of a write is never read,
  // privatization is not necessary (and large parts of the IR are likely
  // dead).
  if (options.checkParallelRegions && !usesRead.empty()) {
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Find the allocation point or last write (definition) of the buffer.
      // Note: In contrast to `findDefinitions`, this also returns results of
      // ops that do not bufferize to a memory write when no other definition
      // could be found. E.g., "bufferization.alloc_tensor" would be included,
      // even though that op just bufferizes to an allocation but does not
      // define the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
            return state.bufferizesToMemoryWrite(v);
          });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");

      // The writing op must bufferize out-of-place if the definition is in a
      // different parallel region than this write.
      for (Value def : definitionsOrLeaves) {
        if (getParallelRegion(def.getParentRegion(), options) !=
            getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
                              options)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "\n- bufferizes out-of-place due to parallel region:\n");
          LLVM_DEBUG(llvm::dbgs()
                     << "  uConflictingWrite = operand "
                     << uConflictingWrite->getOperandNumber() << " of "
                     << *uConflictingWrite->getOwner() << "\n");
          return true;
        }
      }
    }
  }
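
  // For example (a sketch; assumes the loop op is treated as a parallel
  // region): a value defined outside of a parallel loop but written inside of
  // it must be privatized, because iterations may execute concurrently:
  //
  //   %0 = ... : tensor<?xf32>
  //   scf.forall ... {
  //     %1 = "writing_op"(%0) ...   // must bufferize out-of-place
  //   }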

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();
    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
                            << " of " << *readingOp << "\n");

    // Find the definition of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, the
    // definition is %0. Note that operations that create an alias but do not
    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
    const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
    if (definitions.empty()) {
      // Fast path: No conflict if there are no definitions.
      LLVM_DEBUG(llvm::dbgs()
                 << "  no conflict: read value has no definitions\n");
      continue;
    }

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      LLVM_DEBUG(llvm::dbgs() << "  uConflictingWrite = operand "
                              << uConflictingWrite->getOperandNumber() << " of "
                              << *uConflictingWrite->getOwner() << "\n");

      // Check if op dominance can be used to rule out read-after-write
      // conflicts.
      bool useDominance =
          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Inside of repetitive regions, ops may be executed multiple times and
      // op dominance cannot be used to rule out conflicts.
      if (useDominance) {
        // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
        // the write is not visible when reading.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), there may be no meaningful `happensBefore`
        // relationship.
        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read happens before write\n");
          continue;
        }

        // No conflict if the reading use equals the use of the conflicting
        // write. A use cannot conflict with itself.
        //
        // Note: Just being the same op is not enough. It has to be the same
        // use.
        // Note: If the op is executed multiple times (e.g., because it is
        // inside a loop), it may be conflicting with itself.
        if (uConflictingWrite == uRead) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read and write are same use\n");
          continue;
        }

        // Ops are not conflicting if they are in mutually exclusive regions.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), mutually exclusive regions may be executed
        // multiple times.
        if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
          LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
                                     "mutually exclusive regions\n");
          continue;
        }

        // Two equivalent operands of the same op are not conflicting if the op
        // bufferizes to element-wise access. I.e., all loads at a position
        // happen before all stores to the same position.
        if (conflictingWritingOp == readingOp) {
          if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
            if (bufferizableOp.bufferizesToElementwiseAccess(
                    state, {uRead, uConflictingWrite})) {
              if (hasEquivalentValueInReverseUseDefChain(
                      state, uRead, uConflictingWrite->get()) ||
                  hasEquivalentValueInReverseUseDefChain(
                      state, uConflictingWrite, uRead->get())) {
                LLVM_DEBUG(
                    llvm::dbgs()
                    << "  no conflict: op bufferizes to element-wise access\n");
                continue;
              }
            }
          }
        }
      }

      // No conflict if the operands are non-conflicting subsets.
      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
        continue;
      }

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "  no conflict: op interface of reading op says 'no'\n");
          continue;
        }
      }

      if (conflictingWritingOp != readingOp) {
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp)) {
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
                                              state)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: op interface of writing op says 'no'\n");
            continue;
          }
        }
      }

      // Check all possible definitions.
      for (Value definition : definitions) {
        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");

        // No conflict if the conflicting write happens before the definition.
        if (Operation *defOp = definition.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
            // conflictingWritingOp happens before defOp. No conflict.
            LLVM_DEBUG(llvm::dbgs()
                       << "    no conflict: write happens before definition\n");
            continue;
          }
          // No conflict if conflictingWritingOp is contained in defOp.
          if (defOp->isProperAncestor(conflictingWritingOp)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "    no conflict: write is contained in definition\n");
            continue;
          }
        } else {
          auto bbArg = cast<BlockArgument>(definition);
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
            LLVM_DEBUG(llvm::dbgs() << "    no conflict: definition is bbArg "
                                       "and write happens outside of block\n");
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
          }
        }

        // No conflict if the conflicting write and the definition are the
        // same, i.e., the value written by the conflicting write is the
        // definition itself.
        AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
        if (aliases.getNumAliases() == 1 &&
            aliases.getAliases()[0].value == definition) {
          LLVM_DEBUG(llvm::dbgs()
                     << "    no conflict: definition and write are same\n");
          continue;
        }

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, definition);
        LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses()) {
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use)) {
        res.insert(&use);
        continue;
      }

      // Read of a dependent value in the SSA use-def chain. E.g.:
      //
      // %0 = ...
      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
      // "read"(%1)
      //
      // In the above example, getAliasingReads(%0) includes the first
      // OpOperand of the tensor.extract_slice op. The extract_slice itself
      // does not read but its aliasing result is eventually fed into an op
      // that does.
      //
      // Note: This is considered a "read" only if the use does not bufferize
      // to a memory write. (We already ruled out memory reads. In case of a
      // memory write, the buffer would be entirely overwritten; in the above
      // example there would then be no flow of data from the extract_slice
      // operand to its result's uses.)
      if (!state.bufferizesToMemoryWrite(use)) {
        AliasingValueList aliases = state.getAliasingValues(use);
        if (llvm::any_of(aliases, [&](AliasingValue a) {
              return state.isValueRead(a.value);
            }))
          res.insert(&use);
      }
    }
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict. A
/// read R and a write W of the same alias set constitute a conflict if inplace
/// bufferization of W changes the value read by R to a value different from
/// the one that would be expected by tracing back R's origin through SSA
/// use-def chains. A conflict can only be introduced by a new alias and/or an
/// inplace bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the result of the first one. Therefore, the
///   TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions. In
/// that case, only the consistency of bufferization decisions involving
/// aliases of the given OpOperand is checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), state);
  getAliasingInplaceWrites(usesWrite, operand.get(), state);
  for (AliasingValue alias : state.getAliasingValues(operand)) {
    getAliasingReads(usesRead, alias.value, state);
    getAliasingInplaceWrites(usesWrite, alias.value, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}

/// Annotate IR with details about the detected non-writability conflict.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  std::string id = "W_" + std::to_string(counter++);
  if (auto opResult = dyn_cast<OpResult>(value)) {
    std::string attr = id + "[NOT-WRITABLE: result " +
                       std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(value);
    std::string attr = id + "[NOT-WRITABLE: bbArg " +
                       std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
  }
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  bool foundWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  if (!foundWrite) {
    // Collect writes of all aliases of OpOperand and OpResult.
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    foundWrite = !usesWrite.empty();
  }

  if (!foundWrite)
    return false;

  // Look for a read-only tensor among all aliases.
  bool foundReadOnly = false;
  auto checkReadOnly = [&](Value v) {
    if (!state.isWritable(v)) {
      foundReadOnly = true;
      if (state.getOptions().printConflicts)
        annotateNonWritableTensor(v);
    }
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);
  if (foundReadOnly) {
    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

/// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
  Value value = opOperand->get();
  if (!cachedDefinitions.count(value))
    cachedDefinitions[value] = findDefinitions(opOperand);
  return cachedDefinitions[value];
}

void OneShotAnalysisState::resetCache() {
  AnalysisState::resetCache();
  cachedDefinitions.clear();
}

/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state);

  if (foundInterference)
    state.bufferizeOutOfPlace(operand);
  else
    state.bufferizeInPlace(operand);

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}

LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands())
    if (isa<TensorType>(opOperand.get().getType()))
      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
        return failure();
  return success();
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
      for (OpResult opResult : op->getOpResults()) {
        if (!isa<TensorType>(opResult.getType()))
          continue;
        AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
        if (aliases.getNumAliases() == 0)
          // Nothing to do if there are no aliasing OpOperands.
          continue;

        Value firstOperand = aliases.begin()->opOperand->get();
        bool allEquivalent = true;
        for (AliasingOpOperand alias : aliases) {
          bool isEquiv = alias.relation == BufferRelation::Equivalent;
          bool isInPlace = state.isInPlace(*alias.opOperand);
          Value operand = alias.opOperand->get();
          if (isEquiv && isInPlace && alias.isDefinite) {
            // Found a definite, equivalent alias. Merge equivalence sets.
            // There can only be one definite alias, so we can stop here.
            state.unionEquivalenceClasses(opResult, operand);
            allEquivalent = false;
            break;
          }
          if (!isEquiv || !isInPlace)
            allEquivalent = false;
          if (!state.areEquivalentBufferizedValues(operand, firstOperand))
            allEquivalent = false;
        }

        // If all "maybe" aliases are equivalent and the OpResult is not a new
        // allocation, it is a definite, equivalent alias. E.g.:
        //
        // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
        // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
        // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
        // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
        //
        // If %t0 and %t1 are equivalent, it is safe to union the equivalence
        // classes of %r, %t0 and %t1.
        if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
          state.unionEquivalenceClasses(opResult, firstOperand);
      }
    }
  }
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, state);
}

/// "Bottom-up from terminators" heuristic: traverse the reverse SSA use-def
/// chains of all tensors yielded by region terminators and analyze the ops on
/// those chains first; then analyze all remaining ops in reverse order.
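///
/// For example (sketch), given
///
///   %0 = "producer"() ...
///   %1 = scf.for ... iter_args(%arg = %0) -> (tensor<?xf32>) {
///     %2 = "compute"(%arg) ...
///     scf.yield %2 : tensor<?xf32>
///   }
///
/// the heuristic analyzes "compute" early because it defines the yielded
/// value %2, before turning to ops that do not feed a terminator.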
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Find region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;
    // Follow the reverse SSA use-def chain from each yielded value as long as
    // we stay within the same region.
    SmallVector<OpResult> worklist;
    for (Value v : term->getOperands()) {
      if (!isa<TensorType>(v.getType()))
        continue;
      auto opResult = dyn_cast<OpResult>(v);
      if (!opResult)
        continue;
      worklist.push_back(opResult);
    }
    while (!worklist.empty()) {
      OpResult opResult = worklist.pop_back_val();
      Operation *defOp = opResult.getDefiningOp();
      if (!traversedOps.insert(defOp))
        continue;
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;
      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      for (auto alias : aliases) {
        Value v = alias.opOperand->get();
        if (!isa<TensorType>(v.getType()))
          continue;
        auto opResult = dyn_cast<OpResult>(v);
        if (!opResult)
          continue;
        worklist.push_back(opResult);
      }
    }
  });

  // Analyze traversed ops, then all remaining ops.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
    if (!traversedOps.contains(op) && hasTensorSemantics(op))
      result.push_back(op);
  });
  return result;
}

LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
      getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic ==
      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    op->walk([&](Operation *op) {
      // No tensors => no buffers.
      if (!hasTensorSemantics(op))
        return;
      orderedOps.push_back(op);
    });
    switch (heuristic) {
    case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
      // Default: Walk ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
      // Ops are already sorted top-down in `orderedOps`.
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that fuzzer seed is set");
      // This is a fuzzer. For testing purposes only. Randomize the order in
      // which operations are analyzed. The bufferization quality is likely
      // worse, but we want to make sure that no assertions are triggered
      // anywhere.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Analyze ops in the computed order.
  for (Operation *op : orderedOps)
    if (failed(analyzeSingleOp(op, domInfo)))
      return failure();

  equivalenceAnalysis(op, *this);
  return success();
}

/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
                                 OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Note: This walk cannot be combined with the one below because interface
  // methods of invalid/unsupported ops may be called during the second walk.
  // (On ops different from `op`.)
  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Check for unsupported unstructured control flow.
    if (!op.supportsUnstructuredControlFlow()) {
      for (Region &r : op->getRegions()) {
        if (r.getBlocks().size() > 1) {
          op->emitOpError("op or BufferizableOpInterface implementation does "
                          "not support unstructured control flow, but at least "
                          "one region has multiple blocks");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();
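
  // The second walk below rejects to_tensor ops without the `restrict`
  // attribute. For example (a sketch; exact assembly format may differ by
  // version), the first op would be rejected and the second accepted:
  //
  //   %t0 = bufferization.to_tensor %m : memref<?xf32>
  //   %t1 = bufferization.to_tensor %m restrict : memref<?xf32>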

  walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Input IR may not contain any ToTensorOps without the "restrict"
    // attribute. Such tensors may alias any other tensor, which is currently
    // not handled in the analysis.
    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
      if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
        op->emitOpError("to_tensor ops without `restrict` are not supported by "
                        "One-Shot Analysis");
        return WalkResult::interrupt();
      }
    }

    for (OpOperand &opOperand : op->getOpOperands()) {
      if (isa<TensorType>(opOperand.get().getType())) {
        if (wouldCreateReadAfterWriteInterference(
                opOperand, domInfo, state,
                /*checkConsistencyOnly=*/true)) {
          // This error can happen if certain "mustBufferizeInPlace" interface
          // methods are implemented incorrectly, such that the IR already has
          // a RaW conflict before making any bufferization decisions. It can
          // also happen if the bufferization.materialize_in_destination op is
          // used in such a way that a RaW conflict is not avoidable.
          op->emitOpError("not bufferizable under the given constraints: "
                          "cannot avoid RaW conflict");
          return WalkResult::interrupt();
        }

        if (state.isInPlace(opOperand) &&
            wouldCreateWriteToNonWritableBuffer(
                opOperand, state, /*checkConsistencyOnly=*/true)) {
          op->emitOpError("not bufferizable under the given constraints: would "
                          "write to read-only buffer");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });

  return success(!walkResult.wasInterrupted());
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
  });
}
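
/// Annotate each op that has tensor OpResults or tensor region bbArgs with its
/// alias sets. For testing/debugging only. As an illustration (hypothetical
/// op), an op whose single tensor result aliases %t may be annotated as:
///
///   %r = "test.op"(%t) {__opresult_alias_set_attr__ = [["%r", "%t"]]} ...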
static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());
  // Helper function to build an array attribute of aliasing SSA value strings.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliases;
    state.applyOnAliases(v, [&](Value alias) {
      std::string buffer;
      llvm::raw_string_ostream stream(buffer);
      alias.printAsOperand(stream, asmState);
      aliases.push_back(b.getStringAttr(buffer));
    });
    return b.getArrayAttr(aliases);
  };

  op->walk([&](Operation *op) {
    // Build alias set array for every OpResult.
    SmallVector<Attribute> opResultAliasSets;
    for (OpResult opResult : op->getOpResults()) {
      if (llvm::isa<TensorType>(opResult.getType())) {
        opResultAliasSets.push_back(buildAliasesArray(opResult));
      }
    }
    if (!opResultAliasSets.empty())
      op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));

    // Build alias set array for every BlockArgument.
    SmallVector<Attribute> regionAliasSets;
    bool hasTensorBbArg = false;
    for (Region &r : op->getRegions()) {
      SmallVector<Attribute> blockAliasSets;
      for (Block &block : r.getBlocks()) {
        SmallVector<Attribute> bbArgAliasSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (llvm::isa<TensorType>(bbArg.getType())) {
            bbArgAliasSets.push_back(buildAliasesArray(bbArg));
            hasTensorBbArg = true;
          }
        }
        blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
      }
      regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
    }
    if (hasTensorBbArg)
      op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
  });
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // If the analysis fails, just return.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  bool failedAnalysis = false;

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options,
                                   BufferizationStatistics *statistics) {
  // copy-before-write deactivates the analysis. It cannot be used together
  // with test-analysis-only.
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");

  if (options.copyBeforeWrite) {
    // Copy buffer before each write. No analysis is needed.
  } else {
    // Run One-Shot Analysis and insert buffer copies (on the tensor level)
    // only where needed. This is the default and much more efficient than
    // copy-before-write.
    if (failed(insertTensorCopies(op, options, statistics)))
      return failure();

    // If test-analysis-only is set, the IR was annotated with RaW conflict
    // markers (attributes) during One-Shot Analysis.
    if (options.testAnalysisOnly)
      return success();
  }

  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
  // a new buffer copy is allocated every time a buffer is written to.
  return bufferizeOp(op, options, statistics);
}