//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  Operation *op = region->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
    if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
      return true;
  return false;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  if (!op->getBlock())
    return nullptr;
  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;
  return enclosingRepetitiveRegionCache[op] =
             getEnclosingRepetitiveRegion(op->getBlock(), options);
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = value.getParentRegion();
  // Collect all visited regions since we only know the repetitive region we
  // want to map it to later on.
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = block->getParent();
  Operation *op = nullptr;
  // Collect all visited regions since we only know the repetitive region we
  // want to map it to later on.
  SmallVector<Region *> visitedRegions;
  do {
    op = region->getParentOp();
    if (isRepetitiveRegion(region, options))
      break;
  } while ((region = op->getParentRegion()));

  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }

Region *bufferization::getNextEnclosingRepetitiveRegion(
    Region *region, const BufferizationOptions &options) {
  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
  while ((region = region->getParentRegion())) {
    if (isRepetitiveRegion(region, options))
      break;
  }
  return region;
}

Region *bufferization::getParallelRegion(Region *region,
                                         const BufferizationOptions &options) {
  while (region) {
    auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
    if (bufferizableOp &&
        bufferizableOp.isParallelRegion(region->getRegionNumber())) {
      assert(isRepetitiveRegion(region, options) &&
             "expected that all parallel regions are also repetitive regions");
      return region;
    }
    region = region->getParentRegion();
  }
  return nullptr;
}

Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto opResult = llvm::dyn_cast<OpResult>(value))
    return opResult.getDefiningOp();
  return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
        llvm::isa<OpResult>(shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
        reifiedShapes = true;
        auto &shape =
            resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
        for (const auto &dim : enumerate(tensorType.getShape()))
          if (ShapedType::isDynamic(dim.value()))
            dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source). Do not apply this optimization if the OpResult is an unranked
      // tensor (because those cannot be copied at the moment).
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of Values.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, state.getOptions(),
        copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(
        llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created to be used as alloc_tensor op
      // dynamic extents. Do not update these either.
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow the op according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

namespace {

/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      llvm::cast<TensorType>(value.getType()), memorySpace);
}

} // namespace

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (!isOpAllowed(op))
    return nullptr;
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(getOwnerOfValue(value));
}

void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                              memorySpace);
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the builder's insertion point right after the definition of `value`:
/// after the defining op for an OpResult, or at the start of the owning block
/// for a BlockArgument.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  if (Operation *op = getOwnerOfValue(value))
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperands(value, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingOpOperands(value);
}

/// Determine which Values will alias with `opOperand` if the op is bufferized
/// in place. Return all tensor Values if the op is not bufferizable.
AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingValues(opOperand, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingValues(opOperand);
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto opResult = llvm::dyn_cast<OpResult>(value);
  if (!opResult)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    if (!visited.insert(uMaybeReading).second)
      continue;

    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (AliasingValue alias : getAliasingValues(*uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `opOperand`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. Uses of such matching Values are not
// traversed any further; the aliasing OpOperands visited along the way are
// recorded in `visitedOpOperands` (if provided).
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config,
    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(opOperand->get());

  if (visitedOpOperands)
    visitedOpOperands->insert(opOperand);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
      // Stop traversal if value was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }
    visited.insert(value);

    if (condition(value)) {
      result.insert(value);
      continue;
    }

    if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
      // Stop iterating if `followUnknownOps` is unset and the op is either
      // not bufferizable or excluded in the OpFilter.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(value);
    if (aliases.getNumAliases() == 0) {
      // The traversal ends naturally if there are no more OpOperands that
      // could be followed.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    for (AliasingOpOperand a : aliases) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop iterating if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
        // Stop iterating if `followInPlaceOnly` is set but the alias is
        // out-of-place.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
        // has a different type and the op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      workingSet.insert(a.opOperand->get());
      if (visitedOpOperands)
        visitedOpOperands->insert(a.opOperand);
    }
  }

  return result;
}

// Find the values that define the contents of the given operand's value.
llvm::SetVector<Value>
AnalysisState::findDefinitions(OpOperand *opOperand) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return findValueInReverseUseDefChain(
      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
      config);
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(options, TypeID::get<AnalysisState>()) {}

AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  AliasingValueList aliases = getAliasingValues(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliases,
                    [&](AliasingValue a) { return isValueRead(a.value); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popFromStack =
      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, invocationStack);

  // Op is not bufferizable.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.hasTensorSemantics();
  return detail::defaultHasTensorSemantics(op);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (llvm::isa<TensorType>(opResult.getType())) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((llvm::isa<MemRefType>(replacement.getType()) ||
              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), opResult.getType(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  auto tensorType = llvm::cast<TensorType>(value.getType());

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
                                              dynamicOffset, dynamicStrides);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}

//===----------------------------------------------------------------------===//
// Default implementations of interface methods
//===----------------------------------------------------------------------===//

bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
  // memory writes.
  if (opOperands.getAliases().empty())
    return true;

  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the
  // OpResult may bufferize to a memory write.
  if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
        return state.bufferizesToMemoryWrite(*alias.opOperand);
      }))
    return true;

  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
  // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
  // case, the OpResult bufferizes to a memory write. E.g.:
  //
  // %0 = "some_writing_op" : tensor<?xf32>
  // %r = scf.if ... -> tensor<?xf32> {
  //   scf.yield %0 : tensor<?xf32>
  // } else {
  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
  //   scf.yield %1 : tensor<?xf32>
  // }
  // "some_reading_op"(%r)
  //
  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
  // bufferizes to a memory write and the defining op is inside the scf.if.
  //
  // Note: This treatment of surrounding ops is useful for ops that have a
  // region but no OpOperand such as scf.if or scf.execute_region. It
  // simplifies the analysis considerably.
  //
  // "another_writing_op" in the above example should be able to bufferize
  // inplace in the absence of another read of %0. However, if the scf.if op
  // would not be considered a "write", the analysis would detect the
  // following conflict:
  //
  // * read = some_reading_op
  // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
  // * conflictingWrite = %1
  //
  auto isMemoryWriteInsideOp = [&](Value v) {
    Operation *op = getOwnerOfValue(v);
    if (!opResult.getDefiningOp()->isAncestor(op))
      return false;
    return state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(alias.opOperand,
                                            isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}

// Compute the AliasingOpOperandList for a given Value based on
// getAliasingValues.
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  Operation *op = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> result;
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (!llvm::isa<TensorType>(opOperand.get().getType()))
      continue;
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    for (const auto &it : aliasingValues)
      if (it.value == value)
        result.emplace_back(&opOperand, it.relation, it.isDefinite);
  }
  return AliasingOpOperandList(std::move(result));
}

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(value))
    return bufferization::getMemRefType(value, options);

  // Value is an OpResult.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState state(options);
  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    // If the OpResult has an equivalent OpOperand, both OpResult and
    // OpOperand bufferize to the exact same buffer type.
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, invocationStack);
  }

  // If we do not know the memory space and there is no default memory space,
  // report a failure.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto regionInterface =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  if (!regionInterface)
    return false;
  return regionInterface.isRepetitiveRegion(index);
}

AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // TODO: Take into account successor blocks.
  // No aliasing in case of non-entry blocks.
  if (auto bbArg = dyn_cast<BlockArgument>(value))
    if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
      return {};

  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingOpOperandList r;
  for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  // TODO: Take into account successor blocks.
  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingValueList r;
  for (OpResult result : opOperand.getOwner()->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  for (Region &region : opOperand.getOwner()->getRegions())
    if (!region.getBlocks().empty())
      for (BlockArgument bbArg : region.getBlocks().front().getArguments())
        if (isa<TensorType>(bbArg.getType()))
          r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
  bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
    return any_of(r.getBlocks(), [&](Block &b) {
      return any_of(b.getArguments(), [&](BlockArgument bbArg) {
        return isaTensor(bbArg.getType());
      });
    });
  });
  if (hasTensorBlockArgument)
    return true;

  if (any_of(op->getResultTypes(), isaTensor))
    return true;
  return any_of(op->getOperandTypes(), isaTensor);
}