17a1579acSMatthias Springer //===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===// 27a1579acSMatthias Springer // 37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 67a1579acSMatthias Springer // 77a1579acSMatthias Springer //===----------------------------------------------------------------------===// 87a1579acSMatthias Springer 97a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 107a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1136550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 127a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1379f11591SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 147a1579acSMatthias Springer #include "mlir/IR/AsmState.h" 157a1579acSMatthias Springer #include "mlir/IR/BuiltinOps.h" 164d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 177a1579acSMatthias Springer #include "mlir/IR/Operation.h" 187a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h" 197a1579acSMatthias Springer #include "mlir/IR/Value.h" 20f7f0c7f7SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h" 21878950b8SMatthias Springer #include "llvm/ADT/ScopeExit.h" 227a1579acSMatthias Springer #include "llvm/Support/Debug.h" 237a1579acSMatthias Springer 2487c770bbSMatthias Springer //===----------------------------------------------------------------------===// 2587c770bbSMatthias Springer // BufferizableOpInterface 2687c770bbSMatthias Springer //===----------------------------------------------------------------------===// 2787c770bbSMatthias Springer 287a1579acSMatthias Springer namespace mlir { 297a1579acSMatthias Springer namespace bufferization { 307a1579acSMatthias Springer 317a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc" 327a1579acSMatthias Springer 337a1579acSMatthias Springer } // namespace bufferization 347a1579acSMatthias Springer } // namespace mlir 357a1579acSMatthias Springer 36faa9be75SMatthias Springer MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState) 37faa9be75SMatthias Springer 387a1579acSMatthias Springer #define DEBUG_TYPE "bufferizable-op-interface" 397a1579acSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 407a1579acSMatthias Springer #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 417a1579acSMatthias Springer 427a1579acSMatthias Springer using namespace mlir; 437a1579acSMatthias Springer using namespace bufferization; 447a1579acSMatthias Springer 45c89c31a2SMatthias Springer static bool isRepetitiveRegion(Region *region, 46c89c31a2SMatthias Springer const BufferizationOptions &options) { 47c89c31a2SMatthias Springer Operation *op = region->getParentOp(); 48c89c31a2SMatthias Springer if (auto bufferizableOp = options.dynCastBufferizableOp(op)) 49c89c31a2SMatthias Springer if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber())) 50c89c31a2SMatthias Springer return true; 51c89c31a2SMatthias Springer return false; 52c89c31a2SMatthias Springer } 53c89c31a2SMatthias Springer 549312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion( 5545cd0e45SMatthias Springer Operation *op, const BufferizationOptions &options) { 5645cd0e45SMatthias Springer if (!op->getBlock()) 5745cd0e45SMatthias Springer return nullptr; 589312b4f9SMartin Erhart if (auto iter = enclosingRepetitiveRegionCache.find_as(op); 599312b4f9SMartin Erhart iter != enclosingRepetitiveRegionCache.end()) 609312b4f9SMartin Erhart return iter->second; 619312b4f9SMartin Erhart return enclosingRepetitiveRegionCache[op] = 629312b4f9SMartin Erhart getEnclosingRepetitiveRegion(op->getBlock(), options); 6345cd0e45SMatthias Springer } 6445cd0e45SMatthias Springer 659312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion( 6645cd0e45SMatthias Springer Value value, const BufferizationOptions &options) { 679312b4f9SMartin Erhart if (auto iter = enclosingRepetitiveRegionCache.find_as(value); 689312b4f9SMartin Erhart iter != enclosingRepetitiveRegionCache.end()) 699312b4f9SMartin Erhart return iter->second; 709312b4f9SMartin Erhart 7145cd0e45SMatthias Springer Region *region = value.getParentRegion(); 729312b4f9SMartin Erhart // Collect all visited regions since we only know the repetitive region we 739312b4f9SMartin Erhart // want to map it to later on 749312b4f9SMartin Erhart SmallVector<Region *> visitedRegions; 7545cd0e45SMatthias Springer while (region) { 769312b4f9SMartin Erhart visitedRegions.push_back(region); 77c89c31a2SMatthias Springer if (isRepetitiveRegion(region, options)) 789312b4f9SMartin Erhart break; 79c89c31a2SMatthias Springer region = region->getParentRegion(); 8045cd0e45SMatthias Springer } 819312b4f9SMartin Erhart enclosingRepetitiveRegionCache[value] = region; 829312b4f9SMartin Erhart for (Region *r : visitedRegions) 839312b4f9SMartin Erhart enclosingRepetitiveRegionCache[r] = region; 849312b4f9SMartin Erhart return region; 8545cd0e45SMatthias Springer } 8645cd0e45SMatthias Springer 879312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion( 8845cd0e45SMatthias Springer Block *block, const BufferizationOptions &options) { 899312b4f9SMartin Erhart if (auto iter = enclosingRepetitiveRegionCache.find_as(block); 909312b4f9SMartin Erhart iter != enclosingRepetitiveRegionCache.end()) 919312b4f9SMartin Erhart return iter->second; 929312b4f9SMartin Erhart 9345cd0e45SMatthias Springer Region *region = block->getParent(); 9445cd0e45SMatthias Springer Operation *op = nullptr; 959312b4f9SMartin Erhart // Collect all visited regions since we only know the repetitive region we 969312b4f9SMartin Erhart // want to map it to later on 979312b4f9SMartin Erhart SmallVector<Region *> visitedRegions; 9845cd0e45SMatthias Springer do { 9945cd0e45SMatthias Springer op = region->getParentOp(); 100c89c31a2SMatthias Springer if (isRepetitiveRegion(region, options)) 1019312b4f9SMartin Erhart break; 10245cd0e45SMatthias Springer } while ((region = op->getParentRegion())); 1039312b4f9SMartin Erhart 1049312b4f9SMartin Erhart enclosingRepetitiveRegionCache[block] = region; 1059312b4f9SMartin Erhart for (Region *r : visitedRegions) 1069312b4f9SMartin Erhart enclosingRepetitiveRegionCache[r] = region; 1079312b4f9SMartin Erhart return region; 10845cd0e45SMatthias Springer } 10945cd0e45SMatthias Springer 1109312b4f9SMartin Erhart void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); } 1119312b4f9SMartin Erhart 112c89c31a2SMatthias Springer Region *bufferization::getNextEnclosingRepetitiveRegion( 113c89c31a2SMatthias Springer Region *region, const BufferizationOptions &options) { 114c89c31a2SMatthias Springer assert(isRepetitiveRegion(region, options) && "expected repetitive region"); 115c89c31a2SMatthias Springer while ((region = region->getParentRegion())) { 116c89c31a2SMatthias Springer if (isRepetitiveRegion(region, options)) 117c89c31a2SMatthias Springer break; 118c89c31a2SMatthias Springer } 119c89c31a2SMatthias Springer return region; 120c89c31a2SMatthias Springer } 121c89c31a2SMatthias Springer 1221e1a3112SMatthias Springer Region *bufferization::getParallelRegion(Region *region, 1231e1a3112SMatthias Springer const BufferizationOptions &options) { 1241e1a3112SMatthias Springer while (region) { 1251e1a3112SMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp()); 1261e1a3112SMatthias Springer if (bufferizableOp && 1271e1a3112SMatthias Springer bufferizableOp.isParallelRegion(region->getRegionNumber())) { 1281e1a3112SMatthias Springer assert(isRepetitiveRegion(region, options) && 1291e1a3112SMatthias Springer "expected that all parallel regions are also repetitive regions"); 1301e1a3112SMatthias Springer return region; 1311e1a3112SMatthias Springer } 1321e1a3112SMatthias Springer region = region->getParentRegion(); 1331e1a3112SMatthias Springer } 1341e1a3112SMatthias Springer return nullptr; 1351e1a3112SMatthias Springer } 1361e1a3112SMatthias Springer 137111c9196SMatthias Springer Operation *bufferization::getOwnerOfValue(Value value) { 138c1fa60b4STres Popp if (auto opResult = llvm::dyn_cast<OpResult>(value)) 139c0b0b6a0SMatthias Springer return opResult.getDefiningOp(); 140c1fa60b4STres Popp return llvm::cast<BlockArgument>(value).getOwner()->getParentOp(); 141c0b0b6a0SMatthias Springer } 142c0b0b6a0SMatthias Springer 143b3ebe3beSMatthias Springer /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the 144b3ebe3beSMatthias Springer /// shaped value is copied. Otherwise, a tensor with undefined contents is 145b3ebe3beSMatthias Springer /// allocated. 14645b995cdSMatthias Springer FailureOr<Value> bufferization::allocateTensorForShapedValue( 1476bf043e7SMartin Erhart OpBuilder &b, Location loc, Value shapedValue, 14845b995cdSMatthias Springer const BufferizationOptions &options, bool copy) { 149b3ebe3beSMatthias Springer Value tensor; 150c1fa60b4STres Popp if (llvm::isa<RankedTensorType>(shapedValue.getType())) { 151b3ebe3beSMatthias Springer tensor = shapedValue; 152c1fa60b4STres Popp } else if (llvm::isa<MemRefType>(shapedValue.getType())) { 153b3ebe3beSMatthias Springer tensor = b.create<ToTensorOp>(loc, shapedValue); 154c1fa60b4STres Popp } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) || 155c1fa60b4STres Popp llvm::isa<UnrankedMemRefType>(shapedValue.getType())) { 156d7f72d4bSMatthias Springer return getOwnerOfValue(shapedValue) 157d7f72d4bSMatthias Springer ->emitError("copying of unranked tensors is not implemented"); 15879f11591SMatthias Springer } else { 159b3ebe3beSMatthias Springer llvm_unreachable("expected RankedTensorType or MemRefType"); 16079f11591SMatthias Springer } 161c1fa60b4STres Popp RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType()); 162b3ebe3beSMatthias Springer SmallVector<Value> dynamicSizes; 163b3ebe3beSMatthias Springer if (!copy) { 164b3ebe3beSMatthias Springer // Compute the dynamic part of the shape. 165b3ebe3beSMatthias Springer // First try to query the shape via ReifyRankedShapedTypeOpInterface. 166b3ebe3beSMatthias Springer bool reifiedShapes = false; 167c1fa60b4STres Popp if (llvm::isa<RankedTensorType>(shapedValue.getType()) && 168c1fa60b4STres Popp llvm::isa<OpResult>(shapedValue)) { 169b3ebe3beSMatthias Springer ReifiedRankedShapedTypeDims resultDims; 170758329dcSMatthias Springer if (succeeded( 171758329dcSMatthias Springer reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) { 172b3ebe3beSMatthias Springer reifiedShapes = true; 173b3ebe3beSMatthias Springer auto &shape = 174c1fa60b4STres Popp resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()]; 175b3ebe3beSMatthias Springer for (const auto &dim : enumerate(tensorType.getShape())) 176b3ebe3beSMatthias Springer if (ShapedType::isDynamic(dim.value())) 177129f1001SKazu Hirata dynamicSizes.push_back(cast<Value>(shape[dim.index()])); 178b3ebe3beSMatthias Springer } 179b3ebe3beSMatthias Springer } 180b3ebe3beSMatthias Springer 181b3ebe3beSMatthias Springer // If the shape could not be reified, create DimOps. 182b3ebe3beSMatthias Springer if (!reifiedShapes) 183b3ebe3beSMatthias Springer populateDynamicDimSizes(b, loc, tensor, dynamicSizes); 184b3ebe3beSMatthias Springer } 185b3ebe3beSMatthias Springer 186c0b0b6a0SMatthias Springer // Create AllocTensorOp. 1873474d10eSMatthias Springer auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes, 1883474d10eSMatthias Springer copy ? tensor : Value()); 189c0b0b6a0SMatthias Springer 190c0b0b6a0SMatthias Springer // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. 191c0b0b6a0SMatthias Springer if (copy) 192c0b0b6a0SMatthias Springer return allocTensorOp.getResult(); 193c0b0b6a0SMatthias Springer FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options); 194c0b0b6a0SMatthias Springer if (failed(copyBufferType)) 195c0b0b6a0SMatthias Springer return failure(); 1966f1e23b4SKunwar Grover std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace(); 1979bb63374SLei Zhang if (!memorySpace) 1986f1e23b4SKunwar Grover memorySpace = options.defaultMemorySpaceFn(tensorType); 1996f1e23b4SKunwar Grover if (memorySpace.has_value()) 2006f1e23b4SKunwar Grover allocTensorOp.setMemorySpaceAttr(memorySpace.value()); 20145b995cdSMatthias Springer return allocTensorOp.getResult(); 20279f11591SMatthias Springer } 20379f11591SMatthias Springer 20487c770bbSMatthias Springer LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( 20587c770bbSMatthias Springer RewriterBase &rewriter, const AnalysisState &state) { 20687b46776SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 20787c770bbSMatthias Springer Operation *op = getOperation(); 20887b46776SMatthias Springer SmallVector<OpOperand *> outOfPlaceOpOperands; 20979f11591SMatthias Springer DenseSet<OpOperand *> copiedOpOperands; 210a02ad6c1SMatthias Springer SmallVector<Value> outOfPlaceValues; 211a02ad6c1SMatthias Springer DenseSet<Value> copiedOpValues; 21287b46776SMatthias Springer 21387b46776SMatthias Springer // Find all out-of-place OpOperands. 21487c770bbSMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) { 21587c770bbSMatthias Springer Type operandType = opOperand.get().getType(); 216c1fa60b4STres Popp if (!llvm::isa<TensorType>(operandType)) 21787c770bbSMatthias Springer continue; 21887c770bbSMatthias Springer if (state.isInPlace(opOperand)) 21987c770bbSMatthias Springer continue; 220c1fa60b4STres Popp if (llvm::isa<UnrankedTensorType>(operandType)) 221d7f72d4bSMatthias Springer return op->emitError("copying of unranked tensors is not implemented"); 22287b46776SMatthias Springer 223a02ad6c1SMatthias Springer AliasingValueList aliasingValues = state.getAliasingValues(opOperand); 224a02ad6c1SMatthias Springer if (aliasingValues.getNumAliases() == 1 && 2256ecebb49SMatthias Springer isa<OpResult>(aliasingValues.getAliases()[0].value) && 22687b46776SMatthias Springer !state.bufferizesToMemoryWrite(opOperand) && 227a02ad6c1SMatthias Springer state.getAliasingOpOperands(aliasingValues.getAliases()[0].value) 2289fa6b350SMatthias Springer .getNumAliases() == 1 && 229a02ad6c1SMatthias Springer !isa<UnrankedTensorType>( 230a02ad6c1SMatthias Springer aliasingValues.getAliases()[0].value.getType())) { 23187b46776SMatthias Springer // The op itself does not write but may create exactly one alias. Instead 23287b46776SMatthias Springer // of copying the OpOperand, copy the OpResult. The OpResult can sometimes 23387b46776SMatthias Springer // be smaller than the OpOperand (e.g., in the case of an extract_slice, 234d7f72d4bSMatthias Springer // where the result is usually a smaller part of the source). Do not apply 235d7f72d4bSMatthias Springer // this optimization if the OpResult is an unranked tensor (because those 236d7f72d4bSMatthias Springer // cannot be copied at the moment). 237a02ad6c1SMatthias Springer Value value = aliasingValues.getAliases()[0].value; 238a02ad6c1SMatthias Springer outOfPlaceValues.push_back(value); 23979f11591SMatthias Springer if (!state.canOmitTensorCopy(opOperand)) 240a02ad6c1SMatthias Springer copiedOpValues.insert(value); 24187b46776SMatthias Springer } else { 24287b46776SMatthias Springer // In all other cases, make a copy of the OpOperand. 24387b46776SMatthias Springer outOfPlaceOpOperands.push_back(&opOperand); 24479f11591SMatthias Springer if (!state.canOmitTensorCopy(opOperand)) 24579f11591SMatthias Springer copiedOpOperands.insert(&opOperand); 24687b46776SMatthias Springer } 24787b46776SMatthias Springer } 24887b46776SMatthias Springer 24987b46776SMatthias Springer // Insert copies of OpOperands. 25087b46776SMatthias Springer rewriter.setInsertionPoint(op); 25187b46776SMatthias Springer for (OpOperand *opOperand : outOfPlaceOpOperands) { 25245b995cdSMatthias Springer FailureOr<Value> copy = allocateTensorForShapedValue( 2536bf043e7SMartin Erhart rewriter, op->getLoc(), opOperand->get(), state.getOptions(), 25479f11591SMatthias Springer copiedOpOperands.contains(opOperand)); 25545b995cdSMatthias Springer if (failed(copy)) 25645b995cdSMatthias Springer return failure(); 2575fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); }); 25887c770bbSMatthias Springer } 25987b46776SMatthias Springer 260a02ad6c1SMatthias Springer // Insert copies of Values. 26187b46776SMatthias Springer rewriter.setInsertionPointAfter(op); 262a02ad6c1SMatthias Springer for (Value value : outOfPlaceValues) { 26345b995cdSMatthias Springer FailureOr<Value> copy = allocateTensorForShapedValue( 2646bf043e7SMartin Erhart rewriter, op->getLoc(), value, state.getOptions(), 2656bf043e7SMartin Erhart copiedOpValues.count(value)); 26645b995cdSMatthias Springer if (failed(copy)) 26745b995cdSMatthias Springer return failure(); 268a02ad6c1SMatthias Springer SmallVector<OpOperand *> uses = llvm::to_vector( 269a02ad6c1SMatthias Springer llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; })); 27087b46776SMatthias Springer for (OpOperand *use : uses) { 27187b46776SMatthias Springer // Do not update the alloc_tensor op that we just created. 2725f6d5ca0SMatthias Springer if (use->getOwner() == copy->getDefiningOp()) 2735f6d5ca0SMatthias Springer continue; 2745f6d5ca0SMatthias Springer // tensor.dim ops may have been created to be used as alloc_tensor op 2755f6d5ca0SMatthias Springer // dynamic extents. Do not update these either. 2765f6d5ca0SMatthias Springer if (isa<tensor::DimOp>(use->getOwner())) 2775f6d5ca0SMatthias Springer continue; 2785fcf907bSMatthias Springer rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); }); 27987b46776SMatthias Springer } 28087b46776SMatthias Springer } 28187b46776SMatthias Springer 28287c770bbSMatthias Springer return success(); 28387c770bbSMatthias Springer } 28487c770bbSMatthias Springer 2857a1579acSMatthias Springer //===----------------------------------------------------------------------===// 2861534177fSMatthias Springer // OpFilter 2871534177fSMatthias Springer //===----------------------------------------------------------------------===// 2881534177fSMatthias Springer 2891534177fSMatthias Springer bool OpFilter::isOpAllowed(Operation *op) const { 2901534177fSMatthias Springer // All other ops: Allow/disallow according to filter. 2911534177fSMatthias Springer bool isAllowed = !hasAllowRule(); 2921534177fSMatthias Springer for (const Entry &entry : entries) { 2931534177fSMatthias Springer bool filterResult = entry.fn(op); 2941534177fSMatthias Springer switch (entry.type) { 2951534177fSMatthias Springer case Entry::ALLOW: 2961534177fSMatthias Springer isAllowed |= filterResult; 2971534177fSMatthias Springer break; 2981534177fSMatthias Springer case Entry::DENY: 2991534177fSMatthias Springer if (filterResult) 3001534177fSMatthias Springer // DENY filter matches. This op is no allowed. (Even if other ALLOW 3011534177fSMatthias Springer // filters may match.) 3021534177fSMatthias Springer return false; 3031534177fSMatthias Springer }; 3041534177fSMatthias Springer } 3051534177fSMatthias Springer return isAllowed; 3061534177fSMatthias Springer } 3071534177fSMatthias Springer 3081534177fSMatthias Springer //===----------------------------------------------------------------------===// 3097a1579acSMatthias Springer // BufferizationOptions 3107a1579acSMatthias Springer //===----------------------------------------------------------------------===// 3117a1579acSMatthias Springer 31275ef84bfSOleg Shyshkov namespace { 31375ef84bfSOleg Shyshkov 31475ef84bfSOleg Shyshkov /// Default function arg type converter: Use a fully dynamic layout map. 31575ef84bfSOleg Shyshkov BaseMemRefType 31675ef84bfSOleg Shyshkov defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, 31791c11574SAndrzej Warzyński func::FuncOp funcOp, 31875ef84bfSOleg Shyshkov const BufferizationOptions &options) { 31975ef84bfSOleg Shyshkov return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); 32075ef84bfSOleg Shyshkov } 321606f7c8fSMatthias Springer /// Default unknown type converter: Use a fully dynamic layout map. 32275ef84bfSOleg Shyshkov BaseMemRefType 3239bb63374SLei Zhang defaultUnknownTypeConverter(Value value, Attribute memorySpace, 324606f7c8fSMatthias Springer const BufferizationOptions &options) { 325c1fa60b4STres Popp return getMemRefTypeWithFullyDynamicLayout( 326c1fa60b4STres Popp llvm::cast<TensorType>(value.getType()), memorySpace); 327606f7c8fSMatthias Springer } 328606f7c8fSMatthias Springer 329ea3e8d3bSJie Fu } // namespace 33075ef84bfSOleg Shyshkov 3317a1579acSMatthias Springer // Default constructor for BufferizationOptions. 332606f7c8fSMatthias Springer BufferizationOptions::BufferizationOptions() 33375ef84bfSOleg Shyshkov : functionArgTypeConverterFn(defaultFunctionArgTypeConverter), 33475ef84bfSOleg Shyshkov unknownTypeConverterFn(defaultUnknownTypeConverter) {} 3357a1579acSMatthias Springer 336d6dab38aSMatthias Springer bool BufferizationOptions::isOpAllowed(Operation *op) const { 337d6dab38aSMatthias Springer // Special case: If function boundary bufferization is deactivated, do not 338d6dab38aSMatthias Springer // allow ops that belong to the `func` dialect. 339d6dab38aSMatthias Springer bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect()); 340d6dab38aSMatthias Springer if (!bufferizeFunctionBoundaries && isFuncBoundaryOp) 341d6dab38aSMatthias Springer return false; 342d6dab38aSMatthias Springer 3431534177fSMatthias Springer return opFilter.isOpAllowed(op); 344d6dab38aSMatthias Springer } 345d6dab38aSMatthias Springer 3467a1579acSMatthias Springer BufferizableOpInterface 3477a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Operation *op) const { 348db604911SBenjamin Kramer if (!isOpAllowed(op)) 349db604911SBenjamin Kramer return nullptr; 3509785eb1bSMatthias Springer auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 3519785eb1bSMatthias Springer if (!bufferizableOp) 3527a1579acSMatthias Springer return nullptr; 3539785eb1bSMatthias Springer return bufferizableOp; 3547a1579acSMatthias Springer } 3557a1579acSMatthias Springer 3567a1579acSMatthias Springer BufferizableOpInterface 3577a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Value value) const { 3581f479c1eSMatthias Springer return dynCastBufferizableOp(getOwnerOfValue(value)); 3597a1579acSMatthias Springer } 3607a1579acSMatthias Springer 36175ef84bfSOleg Shyshkov void BufferizationOptions::setFunctionBoundaryTypeConversion( 36275ef84bfSOleg Shyshkov LayoutMapOption layoutMapOption) { 36375ef84bfSOleg Shyshkov functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, 36491c11574SAndrzej Warzyński func::FuncOp funcOp, 36575ef84bfSOleg Shyshkov const BufferizationOptions &options) { 36675ef84bfSOleg Shyshkov if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) 36775ef84bfSOleg Shyshkov return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, 36875ef84bfSOleg Shyshkov memorySpace); 36975ef84bfSOleg Shyshkov return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, 37075ef84bfSOleg Shyshkov memorySpace); 37175ef84bfSOleg Shyshkov }; 37275ef84bfSOleg Shyshkov inferFunctionResultLayout = 37375ef84bfSOleg Shyshkov layoutMapOption == LayoutMapOption::InferLayoutMap; 37475ef84bfSOleg Shyshkov } 37575ef84bfSOleg Shyshkov 3767a1579acSMatthias Springer //===----------------------------------------------------------------------===// 3777a1579acSMatthias Springer // Helper functions for BufferizableOpInterface 3787a1579acSMatthias Springer //===----------------------------------------------------------------------===// 3797a1579acSMatthias Springer 3807a1579acSMatthias Springer static void setInsertionPointAfter(OpBuilder &b, Value value) { 381c1fa60b4STres Popp if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) { 3827a1579acSMatthias Springer b.setInsertionPointToStart(bbArg.getOwner()); 3837a1579acSMatthias Springer } else { 3847a1579acSMatthias Springer b.setInsertionPointAfter(value.getDefiningOp()); 3857a1579acSMatthias Springer } 3867a1579acSMatthias Springer } 3877a1579acSMatthias Springer 388a02ad6c1SMatthias Springer /// Determine which OpOperand* will alias with `value` if the op is bufferized 389a02ad6c1SMatthias Springer /// in place. Return all tensor OpOperand* if the op is not bufferizable. 390a02ad6c1SMatthias Springer AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const { 391a02ad6c1SMatthias Springer if (Operation *op = getOwnerOfValue(value)) 3922fe40c34SMatthias Springer if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) 393a02ad6c1SMatthias Springer return bufferizableOp.getAliasingOpOperands(value, *this); 394f3483c23SMatthias Springer 395f3483c23SMatthias Springer // The op is not bufferizable. 396a02ad6c1SMatthias Springer return detail::unknownGetAliasingOpOperands(value); 3977a1579acSMatthias Springer } 3987a1579acSMatthias Springer 399a02ad6c1SMatthias Springer /// Determine which Values will alias with `opOperand` if the op is bufferized 400a02ad6c1SMatthias Springer /// in place. Return all tensor Values if the op is not bufferizable. 401a02ad6c1SMatthias Springer AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const { 4027a1579acSMatthias Springer if (auto bufferizableOp = 4032fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner())) 404a02ad6c1SMatthias Springer return bufferizableOp.getAliasingValues(opOperand, *this); 405f3483c23SMatthias Springer 406f3483c23SMatthias Springer // The op is not bufferizable. 407a02ad6c1SMatthias Springer return detail::unknownGetAliasingValues(opOperand); 4087a1579acSMatthias Springer } 4097a1579acSMatthias Springer 4107a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the 4117a1579acSMatthias Springer /// op is not bufferizable. 4129597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { 4137a1579acSMatthias Springer if (auto bufferizableOp = 4142fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner())) 4157a1579acSMatthias Springer return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); 4167a1579acSMatthias Springer 4177a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it. 4187a1579acSMatthias Springer // Conservatively return true. 4197a1579acSMatthias Springer return true; 4207a1579acSMatthias Springer } 4217a1579acSMatthias Springer 4227a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory write. Return 4237a1579acSMatthias Springer /// `true` if the op is not bufferizable. 4249597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { 4257a1579acSMatthias Springer if (auto bufferizableOp = 4262fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner())) 4277a1579acSMatthias Springer return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); 4287a1579acSMatthias Springer 4297a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it. 4307a1579acSMatthias Springer // Conservatively return true. 4317a1579acSMatthias Springer return true; 4327a1579acSMatthias Springer } 4337a1579acSMatthias Springer 4347a1579acSMatthias Springer /// Return true if `opOperand` does neither read nor write but bufferizes to an 4357a1579acSMatthias Springer /// alias. Return false if the op is not bufferizable. 4369597b16aSMatthias Springer bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { 4377a1579acSMatthias Springer if (auto bufferizableOp = 4382fe40c34SMatthias Springer getOptions().dynCastBufferizableOp(opOperand.getOwner())) 4397a1579acSMatthias Springer return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); 4407a1579acSMatthias Springer 4417a1579acSMatthias Springer // Unknown op that returns a tensor. The inplace analysis does not support it. 4427a1579acSMatthias Springer // Conservatively return false. 4437a1579acSMatthias Springer return false; 4447a1579acSMatthias Springer } 4457a1579acSMatthias Springer 44634d65e81SMatthias Springer bool AnalysisState::bufferizesToMemoryWrite(Value value) const { 447c1fa60b4STres Popp auto opResult = llvm::dyn_cast<OpResult>(value); 44834d65e81SMatthias Springer if (!opResult) 44934d65e81SMatthias Springer return true; 45034d65e81SMatthias Springer auto bufferizableOp = getOptions().dynCastBufferizableOp(value); 45134d65e81SMatthias Springer if (!bufferizableOp) 45234d65e81SMatthias Springer return true; 45334d65e81SMatthias Springer return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this); 45434d65e81SMatthias Springer } 45534d65e81SMatthias Springer 4567a1579acSMatthias Springer /// Return true if the given value is read by an op that bufferizes to a memory 4577a1579acSMatthias Springer /// read. Also takes into account ops that create an alias but do not read by 4587a1579acSMatthias Springer /// themselves (e.g., ExtractSliceOp). 4599597b16aSMatthias Springer bool AnalysisState::isValueRead(Value value) const { 460c1fa60b4STres Popp assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType"); 4617a1579acSMatthias Springer SmallVector<OpOperand *> workingSet; 4626ecebb49SMatthias Springer DenseSet<OpOperand *> visited; 4637a1579acSMatthias Springer for (OpOperand &use : value.getUses()) 4647a1579acSMatthias Springer workingSet.push_back(&use); 4657a1579acSMatthias Springer 4667a1579acSMatthias Springer while (!workingSet.empty()) { 4677a1579acSMatthias Springer OpOperand *uMaybeReading = workingSet.pop_back_val(); 4687be6ea12SKazu Hirata if (!visited.insert(uMaybeReading).second) 4696ecebb49SMatthias Springer continue; 4706ecebb49SMatthias Springer 4717a1579acSMatthias Springer // Skip over all ops that neither read nor write (but create an alias). 4727a1579acSMatthias Springer if (bufferizesToAliasOnly(*uMaybeReading)) 473a02ad6c1SMatthias Springer for (AliasingValue alias : getAliasingValues(*uMaybeReading)) 474a02ad6c1SMatthias Springer for (OpOperand &use : alias.value.getUses()) 4757a1579acSMatthias Springer workingSet.push_back(&use); 4767a1579acSMatthias Springer if (bufferizesToMemoryRead(*uMaybeReading)) 4777a1579acSMatthias Springer return true; 4787a1579acSMatthias Springer } 4797a1579acSMatthias Springer 4807a1579acSMatthias Springer return false; 4817a1579acSMatthias Springer } 4827a1579acSMatthias Springer 483*d9111f19SAmir Bishara // Starting from `opOperand`, follow the use-def chain in reverse, always 484*d9111f19SAmir Bishara // selecting the aliasing OpOperands. Find and return Values for which 485*d9111f19SAmir Bishara // `condition` evaluates to true. Uses of such matching Values are not 486*d9111f19SAmir Bishara // traversed any further, the visited aliasing opOperands will be preserved 487*d9111f19SAmir Bishara // through `visitedOpOperands`. 4889597b16aSMatthias Springer llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( 489*d9111f19SAmir Bishara OpOperand *opOperand, llvm::function_ref<bool(Value)> condition, 49008aa9563SAmir Bishara TraversalConfig config, 49108aa9563SAmir Bishara llvm::DenseSet<OpOperand *> *visitedOpOperands) const { 4926ecebb49SMatthias Springer llvm::DenseSet<Value> visited; 4937a1579acSMatthias Springer llvm::SetVector<Value> result, workingSet; 494*d9111f19SAmir Bishara workingSet.insert(opOperand->get()); 495*d9111f19SAmir Bishara 496*d9111f19SAmir Bishara if (visitedOpOperands) 497*d9111f19SAmir Bishara visitedOpOperands->insert(opOperand); 4987a1579acSMatthias Springer 4997a1579acSMatthias Springer while (!workingSet.empty()) { 5007a1579acSMatthias Springer Value value = workingSet.pop_back_val(); 5016ecebb49SMatthias Springer 5026ecebb49SMatthias Springer if (!config.revisitAlreadyVisitedValues && visited.contains(value)) { 5036ecebb49SMatthias Springer // Stop traversal if value was already visited. 5046ecebb49SMatthias Springer if (config.alwaysIncludeLeaves) 5056ecebb49SMatthias Springer result.insert(value); 5066ecebb49SMatthias Springer continue; 5076ecebb49SMatthias Springer } 5086ecebb49SMatthias Springer visited.insert(value); 5096ecebb49SMatthias Springer 510fdb9e6a3SMatthias Springer if (condition(value)) { 511fdb9e6a3SMatthias Springer result.insert(value); 512fdb9e6a3SMatthias Springer continue; 513fdb9e6a3SMatthias Springer } 514fdb9e6a3SMatthias Springer 515a02ad6c1SMatthias Springer if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) { 5161f479c1eSMatthias Springer // Stop iterating if `followUnknownOps` is unset and the op is either 5171f479c1eSMatthias Springer // not bufferizable or excluded in the OpFilter. 5181f479c1eSMatthias Springer if (config.alwaysIncludeLeaves) 5191f479c1eSMatthias Springer result.insert(value); 5201f479c1eSMatthias Springer continue; 5211f479c1eSMatthias Springer } 5226cdd34b9SMatthias Springer 523a02ad6c1SMatthias Springer AliasingOpOperandList aliases = getAliasingOpOperands(value); 5241f479c1eSMatthias Springer if (aliases.getNumAliases() == 0) { 5251f479c1eSMatthias Springer // The traversal ends naturally if there are no more OpOperands that 5261f479c1eSMatthias Springer // could be followed. 5271f479c1eSMatthias Springer if (config.alwaysIncludeLeaves) 5287a1579acSMatthias Springer result.insert(value); 5297a1579acSMatthias Springer continue; 5307a1579acSMatthias Springer } 5317a1579acSMatthias Springer 5329fa6b350SMatthias Springer for (AliasingOpOperand a : aliases) { 5331f479c1eSMatthias Springer if (config.followEquivalentOnly && 5341f479c1eSMatthias Springer a.relation != BufferRelation::Equivalent) { 5359fa6b350SMatthias Springer // Stop iterating if `followEquivalentOnly` is set but the alias is not 5369fa6b350SMatthias Springer // equivalent. 5371f479c1eSMatthias Springer if (config.alwaysIncludeLeaves) 5389fa6b350SMatthias Springer result.insert(value); 539aa909483SMatthias Springer continue; 5409fa6b350SMatthias Springer } 5411f479c1eSMatthias Springer 5421f479c1eSMatthias Springer if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) { 5431f479c1eSMatthias Springer // Stop iterating if `followInPlaceOnly` is set but the alias is 5441f479c1eSMatthias Springer // out-of-place. 5451f479c1eSMatthias Springer if (config.alwaysIncludeLeaves) 5461f479c1eSMatthias Springer result.insert(value); 5471f479c1eSMatthias Springer continue; 5481f479c1eSMatthias Springer } 5491f479c1eSMatthias Springer 550aba0ef70SMatthias Springer if (config.followSameTypeOrCastsOnly && 551aba0ef70SMatthias Springer a.opOperand->get().getType() != value.getType() && 552a02ad6c1SMatthias Springer !value.getDefiningOp<CastOpInterface>()) { 553aba0ef70SMatthias Springer // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias is 554aba0ef70SMatthias Springer // has a different type and the op is not a cast. 555aba0ef70SMatthias Springer if (config.alwaysIncludeLeaves) 556aba0ef70SMatthias Springer result.insert(value); 557aba0ef70SMatthias Springer continue; 558aba0ef70SMatthias Springer } 559aba0ef70SMatthias Springer 5601f479c1eSMatthias Springer workingSet.insert(a.opOperand->get()); 56108aa9563SAmir Bishara if (visitedOpOperands) 56208aa9563SAmir Bishara visitedOpOperands->insert(a.opOperand); 5639fa6b350SMatthias Springer } 5647a1579acSMatthias Springer } 5657a1579acSMatthias Springer 5667a1579acSMatthias Springer return result; 5677a1579acSMatthias Springer } 5687a1579acSMatthias Springer 569*d9111f19SAmir Bishara // Find the values that define the contents of the given operand's value. 570*d9111f19SAmir Bishara llvm::SetVector<Value> 571*d9111f19SAmir Bishara AnalysisState::findDefinitions(OpOperand *opOperand) const { 5721f479c1eSMatthias Springer TraversalConfig config; 5731f479c1eSMatthias Springer config.alwaysIncludeLeaves = false; 57434d65e81SMatthias Springer return findValueInReverseUseDefChain( 575*d9111f19SAmir Bishara opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, 576*d9111f19SAmir Bishara config); 5777a1579acSMatthias Springer } 5787a1579acSMatthias Springer 5799597b16aSMatthias Springer AnalysisState::AnalysisState(const BufferizationOptions &options) 580faa9be75SMatthias Springer : AnalysisState(options, TypeID::get<AnalysisState>()) {} 581faa9be75SMatthias Springer 582faa9be75SMatthias Springer AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type) 583faa9be75SMatthias Springer : options(options), type(type) { 5849597b16aSMatthias Springer for (const BufferizationOptions::AnalysisStateInitFn &fn : 5856fc11d4dSMatthias Springer options.stateInitializers) 5866fc11d4dSMatthias Springer fn(*this); 5876fc11d4dSMatthias Springer } 5887a1579acSMatthias Springer 58979f11591SMatthias Springer bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const { 59079f11591SMatthias Springer // Do not copy if the tensor has undefined contents. 59179f11591SMatthias Springer if (hasUndefinedContents(&opOperand)) 59279f11591SMatthias Springer return true; 59379f11591SMatthias Springer 59479f11591SMatthias Springer // Do not copy if the buffer of the tensor is entirely overwritten (with 59579f11591SMatthias Springer // values that do not depend on the old tensor). 59679f11591SMatthias Springer if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) 59779f11591SMatthias Springer return true; 59879f11591SMatthias Springer 59979f11591SMatthias Springer // Do not copy if the tensor is never read. 600a02ad6c1SMatthias Springer AliasingValueList aliases = getAliasingValues(opOperand); 60179f11591SMatthias Springer if (!bufferizesToMemoryRead(opOperand) && 602a02ad6c1SMatthias Springer llvm::none_of(aliases, 603a02ad6c1SMatthias Springer [&](AliasingValue a) { return isValueRead(a.value); })) 60479f11591SMatthias Springer return true; 60579f11591SMatthias Springer 60679f11591SMatthias Springer // Default: Cannot omit the copy. 60779f11591SMatthias Springer return false; 60879f11591SMatthias Springer } 60979f11591SMatthias Springer 610a3bca118SMatthias Springer bool AnalysisState::isInPlace(OpOperand &opOperand) const { 611b3ebe3beSMatthias Springer // ToMemrefOps are always in-place. 612b3ebe3beSMatthias Springer if (isa<ToMemrefOp>(opOperand.getOwner())) 613b3ebe3beSMatthias Springer return true; 614b3ebe3beSMatthias Springer 615a3bca118SMatthias Springer // In the absence of analysis information, OpOperands that bufferize to a 616a3bca118SMatthias Springer // memory write are out-of-place, i.e., an alloc and copy is inserted. 617a3bca118SMatthias Springer return !bufferizesToMemoryWrite(opOperand); 618a3bca118SMatthias Springer } 619a3bca118SMatthias Springer 620a3bca118SMatthias Springer bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const { 621a3bca118SMatthias Springer // In the absence of analysis information, we do not know if the values are 622a3bca118SMatthias Springer // equivalent. The conservative answer is "false". 623a3bca118SMatthias Springer return false; 624a3bca118SMatthias Springer } 625a3bca118SMatthias Springer 626a3bca118SMatthias Springer bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const { 627a3bca118SMatthias Springer // In the absence of analysis information, we do not know if the values may be 628a3bca118SMatthias Springer // aliasing. The conservative answer is "true". 629f2ada383Slorenzo chelini return true; 630a3bca118SMatthias Springer } 631a3bca118SMatthias Springer 632a3bca118SMatthias Springer bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { 633a3bca118SMatthias Springer // In the absence of analysis information, the conservative answer is "false". 634a3bca118SMatthias Springer return false; 635a3bca118SMatthias Springer } 636a3bca118SMatthias Springer 6377a1579acSMatthias Springer // bufferization.to_memref is not allowed to change the rank. 6387a1579acSMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 6397a1579acSMatthias Springer #ifndef NDEBUG 640c1fa60b4STres Popp auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType()); 641c1fa60b4STres Popp assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() == 6427a1579acSMatthias Springer rankedTensorType.getRank()) && 6437a1579acSMatthias Springer "to_memref would be invalid: mismatching ranks"); 6447a1579acSMatthias Springer #endif 6457a1579acSMatthias Springer } 6467a1579acSMatthias Springer 6475d50f51cSMatthias Springer FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, 648b55d55ecSMatthias Springer const BufferizationOptions &options) { 649ba9d886dSMatthias Springer #ifndef NDEBUG 650c1fa60b4STres Popp auto tensorType = llvm::dyn_cast<TensorType>(value.getType()); 65126852423SMatthias Springer assert(tensorType && "unexpected non-tensor type"); 652ba9d886dSMatthias Springer #endif // NDEBUG 6537a1579acSMatthias Springer 6547a1579acSMatthias Springer // Replace "%t = to_tensor %m" with %m. 655b3ebe3beSMatthias Springer if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>()) 65699260e95SMatthias Springer return toTensorOp.getMemref(); 6577a1579acSMatthias Springer 6587a1579acSMatthias Springer // Insert to_memref op. 6597a1579acSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 660b3ebe3beSMatthias Springer setInsertionPointAfter(rewriter, value); 6615d50f51cSMatthias Springer FailureOr<BaseMemRefType> memrefType = getBufferType(value, options); 6625d50f51cSMatthias Springer if (failed(memrefType)) 6635d50f51cSMatthias Springer return failure(); 6645d50f51cSMatthias Springer ensureToMemrefOpIsValid(value, *memrefType); 6655d50f51cSMatthias Springer return rewriter 6665d50f51cSMatthias Springer .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value) 6675d50f51cSMatthias Springer .getResult(); 6687a1579acSMatthias Springer } 6697a1579acSMatthias Springer 670996834e6SMatthias Springer /// Return the buffer type for a given Value (tensor) after bufferization. 6715d50f51cSMatthias Springer FailureOr<BaseMemRefType> 672b55d55ecSMatthias Springer bufferization::getBufferType(Value value, const BufferizationOptions &options) { 673878950b8SMatthias Springer SmallVector<Value> invocationStack; 674878950b8SMatthias Springer return getBufferType(value, options, invocationStack); 675123c4b02SMatthias Springer } 676123c4b02SMatthias Springer 677123c4b02SMatthias Springer /// Return the buffer type for a given Value (tensor) after bufferization. 678878950b8SMatthias Springer FailureOr<BaseMemRefType> 679878950b8SMatthias Springer bufferization::getBufferType(Value value, const BufferizationOptions &options, 680878950b8SMatthias Springer SmallVector<Value> &invocationStack) { 681c1fa60b4STres Popp assert(llvm::isa<TensorType>(value.getType()) && 682c1fa60b4STres Popp "unexpected non-tensor type"); 683878950b8SMatthias Springer invocationStack.push_back(value); 684878950b8SMatthias Springer auto popFromStack = 685878950b8SMatthias Springer llvm::make_scope_exit([&]() { invocationStack.pop_back(); }); 686123c4b02SMatthias Springer 687123c4b02SMatthias Springer // Try querying BufferizableOpInterface. 688c0b0b6a0SMatthias Springer Operation *op = getOwnerOfValue(value); 689111c9196SMatthias Springer auto bufferizableOp = options.dynCastBufferizableOp(op); 690111c9196SMatthias Springer if (bufferizableOp) 691878950b8SMatthias Springer return bufferizableOp.getBufferType(value, options, invocationStack); 692d7a9bf91SMatthias Springer 693111c9196SMatthias Springer // Op is not bufferizable. 694067d2779Sian Bearman auto memSpace = 695a5757c5bSChristian Sigg options.defaultMemorySpaceFn(cast<TensorType>(value.getType())); 696067d2779Sian Bearman if (!memSpace.has_value()) 697c0b0b6a0SMatthias Springer return op->emitError("could not infer memory space"); 698c0b0b6a0SMatthias Springer 699067d2779Sian Bearman return getMemRefType(value, options, /*layout=*/{}, *memSpace); 700d7a9bf91SMatthias Springer } 701d7a9bf91SMatthias Springer 7028f2d83daSMatthias Springer bool bufferization::hasTensorSemantics(Operation *op) { 7038f2d83daSMatthias Springer if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) 7048f2d83daSMatthias Springer return bufferizableOp.hasTensorSemantics(); 7058f2d83daSMatthias Springer return detail::defaultHasTensorSemantics(op); 7068f2d83daSMatthias Springer } 7078f2d83daSMatthias Springer 7087a1579acSMatthias Springer void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, 7097a1579acSMatthias Springer Operation *op, 7107a1579acSMatthias Springer ValueRange values) { 7119106d35bSMatthias Springer assert(values.size() == op->getNumResults() && 7129106d35bSMatthias Springer "expected one value per OpResult"); 7137a1579acSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 7147a1579acSMatthias Springer 7157a1579acSMatthias Springer // Replace all OpResults with the given values. 7169106d35bSMatthias Springer SmallVector<Value> replacements; 7177a1579acSMatthias Springer for (OpResult opResult : op->getOpResults()) { 7187a1579acSMatthias Springer Value replacement = values[opResult.getResultNumber()]; 719c1fa60b4STres Popp if (llvm::isa<TensorType>(opResult.getType())) { 7207a1579acSMatthias Springer // The OpResult is a tensor. Such values are replaced with memrefs during 7217a1579acSMatthias Springer // bufferization. 722c1fa60b4STres Popp assert((llvm::isa<MemRefType>(replacement.getType()) || 723c1fa60b4STres Popp llvm::isa<UnrankedMemRefType>(replacement.getType())) && 7247a1579acSMatthias Springer "tensor op result should be replaced with a memref value"); 7257a1579acSMatthias Springer // The existing uses of the OpResult still expect a tensor. Insert a 7267a1579acSMatthias Springer // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually 7277a1579acSMatthias Springer // loose all of its users and eventually DCE away. 728c30d2893SMatthias Springer rewriter.setInsertionPointAfter(op); 7297a1579acSMatthias Springer replacement = rewriter.create<bufferization::ToTensorOp>( 730ced2fc78SChristopher Bate replacement.getLoc(), opResult.getType(), replacement); 7317a1579acSMatthias Springer } 7329106d35bSMatthias Springer replacements.push_back(replacement); 7337a1579acSMatthias Springer } 7347a1579acSMatthias Springer 7359106d35bSMatthias Springer rewriter.replaceOp(op, replacements); 7367a1579acSMatthias Springer } 7377a1579acSMatthias Springer 7387a1579acSMatthias Springer //===----------------------------------------------------------------------===// 739caa2a4aeSMatthias Springer // Bufferization-specific scoped alloc insertion support. 7407a1579acSMatthias Springer //===----------------------------------------------------------------------===// 7417a1579acSMatthias Springer 74205e0495fSMatthias Springer /// Create a memref allocation with the given type and dynamic extents. 743248e113eSMatthias Springer FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, 744248e113eSMatthias Springer MemRefType type, 745248e113eSMatthias Springer ValueRange dynShape) const { 746248e113eSMatthias Springer if (allocationFn) 747248e113eSMatthias Springer return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); 74805e0495fSMatthias Springer 74905e0495fSMatthias Springer // Default bufferallocation via AllocOp. 750b3ebe3beSMatthias Springer if (bufferAlignment != 0) 751b3ebe3beSMatthias Springer return b 752b3ebe3beSMatthias Springer .create<memref::AllocOp>(loc, type, dynShape, 753b3ebe3beSMatthias Springer b.getI64IntegerAttr(bufferAlignment)) 754b3ebe3beSMatthias Springer .getResult(); 755b3ebe3beSMatthias Springer return b.create<memref::AllocOp>(loc, type, dynShape).getResult(); 75605e0495fSMatthias Springer } 75705e0495fSMatthias Springer 7587a1579acSMatthias Springer /// Create a memory copy between two memref buffers. 759248e113eSMatthias Springer LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, 760248e113eSMatthias Springer Value from, Value to) const { 761248e113eSMatthias Springer if (memCpyFn) 762248e113eSMatthias Springer return (*memCpyFn)(b, loc, from, to); 7637a1579acSMatthias Springer 7647a1579acSMatthias Springer b.create<memref::CopyOp>(loc, from, to); 7657a1579acSMatthias Springer return success(); 7667a1579acSMatthias Springer } 7677a1579acSMatthias Springer 7687a1579acSMatthias Springer //===----------------------------------------------------------------------===// 7694d67b278SJeff Niu // Bufferization-specific IRMapping support with debugging. 7707a1579acSMatthias Springer //===----------------------------------------------------------------------===// 7717a1579acSMatthias Springer 772606f7c8fSMatthias Springer BaseMemRefType bufferization::getMemRefType(Value value, 77326852423SMatthias Springer const BufferizationOptions &options, 77426852423SMatthias Springer MemRefLayoutAttrInterface layout, 7759bb63374SLei Zhang Attribute memorySpace) { 776c1fa60b4STres Popp auto tensorType = llvm::cast<TensorType>(value.getType()); 777b06614e2SMatthias Springer 77826852423SMatthias Springer // Case 1: Unranked memref type. 779c1fa60b4STres Popp if (auto unrankedTensorType = 780c1fa60b4STres Popp llvm::dyn_cast<UnrankedTensorType>(tensorType)) { 78126852423SMatthias Springer assert(!layout && "UnrankedTensorType cannot have a layout map"); 78226852423SMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 7839bb63374SLei Zhang memorySpace); 7847a1579acSMatthias Springer } 7857a1579acSMatthias Springer 786f287da8aSMatthias Springer // Case 2: Ranked memref type with specified layout. 787c1fa60b4STres Popp auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType); 788f287da8aSMatthias Springer if (layout) { 78926852423SMatthias Springer return MemRefType::get(rankedTensorType.getShape(), 79026852423SMatthias Springer rankedTensorType.getElementType(), layout, 7919bb63374SLei Zhang memorySpace); 79226852423SMatthias Springer } 79326852423SMatthias Springer 794606f7c8fSMatthias Springer return options.unknownTypeConverterFn(value, memorySpace, options); 795f287da8aSMatthias Springer } 796f287da8aSMatthias Springer 797f287da8aSMatthias Springer BaseMemRefType 798f287da8aSMatthias Springer bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, 7999bb63374SLei Zhang Attribute memorySpace) { 800f287da8aSMatthias Springer // Case 1: Unranked memref type. 801c1fa60b4STres Popp if (auto unrankedTensorType = 802c1fa60b4STres Popp llvm::dyn_cast<UnrankedTensorType>(tensorType)) { 803f287da8aSMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 804f287da8aSMatthias Springer memorySpace); 805f287da8aSMatthias Springer } 806f287da8aSMatthias Springer 807f287da8aSMatthias Springer // Case 2: Ranked memref type. 808c1fa60b4STres Popp auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType); 809399638f9SAliia Khasanova int64_t dynamicOffset = ShapedType::kDynamic; 81026852423SMatthias Springer SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(), 811399638f9SAliia Khasanova ShapedType::kDynamic); 812f096e72cSAlex Zinenko auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(), 813f096e72cSAlex Zinenko dynamicOffset, dynamicStrides); 81426852423SMatthias Springer return MemRefType::get(rankedTensorType.getShape(), 81526852423SMatthias Springer rankedTensorType.getElementType(), stridedLayout, 8169bb63374SLei Zhang memorySpace); 8177a1579acSMatthias Springer } 818f287da8aSMatthias Springer 819f287da8aSMatthias Springer /// Return a MemRef type with a static identity layout (i.e., no layout map). If 820f287da8aSMatthias Springer /// the given tensor type is unranked, return an unranked MemRef type. 821f287da8aSMatthias Springer BaseMemRefType 822f287da8aSMatthias Springer bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, 8239bb63374SLei Zhang Attribute memorySpace) { 824f287da8aSMatthias Springer // Case 1: Unranked memref type. 825c1fa60b4STres Popp if (auto unrankedTensorType = 826c1fa60b4STres Popp llvm::dyn_cast<UnrankedTensorType>(tensorType)) { 827f287da8aSMatthias Springer return UnrankedMemRefType::get(unrankedTensorType.getElementType(), 828f287da8aSMatthias Springer memorySpace); 829f287da8aSMatthias Springer } 830f287da8aSMatthias Springer 831f287da8aSMatthias Springer // Case 2: Ranked memref type. 832c1fa60b4STres Popp auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType); 833f287da8aSMatthias Springer MemRefLayoutAttrInterface layout = {}; 834f287da8aSMatthias Springer return MemRefType::get(rankedTensorType.getShape(), 835f287da8aSMatthias Springer rankedTensorType.getElementType(), layout, 8369bb63374SLei Zhang memorySpace); 837f287da8aSMatthias Springer } 838f7f0c7f7SMatthias Springer 839f3483c23SMatthias Springer //===----------------------------------------------------------------------===// 840f3483c23SMatthias Springer // Default implementations of interface methods 841f3483c23SMatthias Springer //===----------------------------------------------------------------------===// 842f3483c23SMatthias Springer 843f3483c23SMatthias Springer bool bufferization::detail::defaultResultBufferizesToMemoryWrite( 844f3483c23SMatthias Springer OpResult opResult, const AnalysisState &state) { 845f3483c23SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp()); 8461ac248e4SMatthias Springer AliasingOpOperandList opOperands = 8471ac248e4SMatthias Springer bufferizableOp.getAliasingOpOperands(opResult, state); 848f3483c23SMatthias Springer 849f3483c23SMatthias Springer // Case 1: OpResults that have no aliasing OpOperand usually bufferize to 850f3483c23SMatthias Springer // memory writes. 8519fa6b350SMatthias Springer if (opOperands.getAliases().empty()) 852f3483c23SMatthias Springer return true; 853f3483c23SMatthias Springer 854f3483c23SMatthias Springer // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult 855f3483c23SMatthias Springer // may bufferize to a memory write. 8569fa6b350SMatthias Springer if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) { 8579fa6b350SMatthias Springer return state.bufferizesToMemoryWrite(*alias.opOperand); 858f3483c23SMatthias Springer })) 859f3483c23SMatthias Springer return true; 860f3483c23SMatthias Springer 861f3483c23SMatthias Springer // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory 862f3483c23SMatthias Springer // write. (Or: The reverse SSA use-def chain ends inside the reigon.) In that 863f3483c23SMatthias Springer // case, the OpResult bufferizes to a memory write. E.g.: 864f3483c23SMatthias Springer // 865f3483c23SMatthias Springer // %0 = "some_writing_op" : tensor<?xf32> 866f3483c23SMatthias Springer // %r = scf.if ... -> tensor<?xf32> { 867f3483c23SMatthias Springer // scf.yield %0 : tensor<?xf32> 868f3483c23SMatthias Springer // } else { 869f3483c23SMatthias Springer // %1 = "another_writing_op"(%0) : tensor<?xf32> 870f3483c23SMatthias Springer // scf.yield %1 : tensor<?xf32> 871f3483c23SMatthias Springer // } 872f3483c23SMatthias Springer // "some_reading_op"(%r) 873f3483c23SMatthias Springer // 874f3483c23SMatthias Springer // %r bufferizes to a memory write because an aliasing OpOperand value (%1) 875f3483c23SMatthias Springer // bufferizes to a memory write and the defining op is inside the scf.if. 876f3483c23SMatthias Springer // 877f3483c23SMatthias Springer // Note: This treatment of surrouding ops is useful for ops that have a 878f3483c23SMatthias Springer // region but no OpOperand such as scf.if or scf.execute_region. It simplifies 879f3483c23SMatthias Springer // the analysis considerably. 880f3483c23SMatthias Springer // 881f3483c23SMatthias Springer // "another_writing_op" in the above example should be able to bufferize 882f3483c23SMatthias Springer // inplace in the absence of another read of %0. However, if the scf.if op 883f3483c23SMatthias Springer // would not be considered a "write", the analysis would detect the 884f3483c23SMatthias Springer // following conflict: 885f3483c23SMatthias Springer // 886f3483c23SMatthias Springer // * read = some_reading_op 887f3483c23SMatthias Springer // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 888f3483c23SMatthias Springer // * conflictingWrite = %1 889f3483c23SMatthias Springer // 890f3483c23SMatthias Springer auto isMemoryWriteInsideOp = [&](Value v) { 891f3483c23SMatthias Springer Operation *op = getOwnerOfValue(v); 892f3483c23SMatthias Springer if (!opResult.getDefiningOp()->isAncestor(op)) 893f3483c23SMatthias Springer return false; 894f3483c23SMatthias Springer return state.bufferizesToMemoryWrite(v); 895f3483c23SMatthias Springer }; 8961f479c1eSMatthias Springer TraversalConfig config; 8971f479c1eSMatthias Springer config.alwaysIncludeLeaves = false; 8989fa6b350SMatthias Springer for (AliasingOpOperand alias : opOperands) { 899f3483c23SMatthias Springer if (!state 900*d9111f19SAmir Bishara .findValueInReverseUseDefChain(alias.opOperand, 9011f479c1eSMatthias Springer isMemoryWriteInsideOp, config) 902f3483c23SMatthias Springer .empty()) 903f3483c23SMatthias Springer return true; 904f3483c23SMatthias Springer } 905f3483c23SMatthias Springer return false; 906f3483c23SMatthias Springer } 907f3483c23SMatthias Springer 908a02ad6c1SMatthias Springer // Compute the AliasingOpOperandList for a given Value based on 909a02ad6c1SMatthias Springer // getAliasingValues. 9101ac248e4SMatthias Springer AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( 911a02ad6c1SMatthias Springer Value value, const AnalysisState &state) { 912a02ad6c1SMatthias Springer Operation *op = getOwnerOfValue(value); 9139fa6b350SMatthias Springer SmallVector<AliasingOpOperand> result; 9141ac248e4SMatthias Springer for (OpOperand &opOperand : op->getOpOperands()) { 915c1fa60b4STres Popp if (!llvm::isa<TensorType>(opOperand.get().getType())) 9161ac248e4SMatthias Springer continue; 917a02ad6c1SMatthias Springer AliasingValueList aliasingValues = state.getAliasingValues(opOperand); 918a02ad6c1SMatthias Springer for (const auto &it : aliasingValues) 919a02ad6c1SMatthias Springer if (it.value == value) 9209fa6b350SMatthias Springer result.emplace_back(&opOperand, it.relation, it.isDefinite); 9211ac248e4SMatthias Springer } 9229fa6b350SMatthias Springer return AliasingOpOperandList(std::move(result)); 9231ac248e4SMatthias Springer } 9241ac248e4SMatthias Springer 925f3483c23SMatthias Springer FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType( 926f3483c23SMatthias Springer Value value, const BufferizationOptions &options, 927878950b8SMatthias Springer SmallVector<Value> &invocationStack) { 928c1fa60b4STres Popp assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type"); 929f3483c23SMatthias Springer 930f3483c23SMatthias Springer // No further analysis is possible for a block argument. 931c1fa60b4STres Popp if (llvm::isa<BlockArgument>(value)) 932f3483c23SMatthias Springer return bufferization::getMemRefType(value, options); 933f3483c23SMatthias Springer 934f3483c23SMatthias Springer // Value is an OpResult. 935f3483c23SMatthias Springer Operation *op = getOwnerOfValue(value); 936c1fa60b4STres Popp auto opResult = llvm::cast<OpResult>(value); 937f3483c23SMatthias Springer AnalysisState state(options); 9381ac248e4SMatthias Springer AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); 9399fa6b350SMatthias Springer if (aliases.getNumAliases() > 0 && 9409fa6b350SMatthias Springer aliases.getAliases()[0].relation == BufferRelation::Equivalent) { 941f3483c23SMatthias Springer // If the OpResult has an equivalent OpOperand, both OpResult and 942f3483c23SMatthias Springer // OpOperand bufferize to the exact same buffer type. 9439fa6b350SMatthias Springer Value equivalentOperand = aliases.getAliases().front().opOperand->get(); 944878950b8SMatthias Springer return getBufferType(equivalentOperand, options, invocationStack); 945f3483c23SMatthias Springer } 946f3483c23SMatthias Springer 947f3483c23SMatthias Springer // If we do not know the memory space and there is no default memory space, 948f3483c23SMatthias Springer // report a failure. 949067d2779Sian Bearman auto memSpace = 950a5757c5bSChristian Sigg options.defaultMemorySpaceFn(cast<TensorType>(value.getType())); 951067d2779Sian Bearman if (!memSpace.has_value()) 952f3483c23SMatthias Springer return op->emitError("could not infer memory space"); 953f3483c23SMatthias Springer 954067d2779Sian Bearman return getMemRefType(value, options, /*layout=*/{}, *memSpace); 955f3483c23SMatthias Springer } 956f3483c23SMatthias Springer 957f7f0c7f7SMatthias Springer bool bufferization::detail::defaultIsRepetitiveRegion( 958f7f0c7f7SMatthias Springer BufferizableOpInterface bufferizableOp, unsigned index) { 959f7f0c7f7SMatthias Springer assert(index < bufferizableOp->getNumRegions() && "invalid region index"); 960f7f0c7f7SMatthias Springer auto regionInterface = 961f7f0c7f7SMatthias Springer dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation()); 962f7f0c7f7SMatthias Springer if (!regionInterface) 963f7f0c7f7SMatthias Springer return false; 964f7f0c7f7SMatthias Springer return regionInterface.isRepetitiveRegion(index); 965f7f0c7f7SMatthias Springer } 966f3483c23SMatthias Springer 9671ac248e4SMatthias Springer AliasingOpOperandList 968a02ad6c1SMatthias Springer bufferization::detail::unknownGetAliasingOpOperands(Value value) { 969a02ad6c1SMatthias Springer // TODO: Take into account successor blocks. 970a02ad6c1SMatthias Springer // No aliasing in case of non-entry blocks. 971a02ad6c1SMatthias Springer if (auto bbArg = dyn_cast<BlockArgument>(value)) 972a02ad6c1SMatthias Springer if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front()) 973a02ad6c1SMatthias Springer return {}; 974a02ad6c1SMatthias Springer 975a02ad6c1SMatthias Springer // Unknown op: Conservatively assume that each OpResult may alias with every 976a02ad6c1SMatthias Springer // OpOperand. In addition, each block argument of an entry block may alias 977a02ad6c1SMatthias Springer // with every OpOperand. 9781ac248e4SMatthias Springer AliasingOpOperandList r; 979a02ad6c1SMatthias Springer for (OpOperand &operand : value.getDefiningOp()->getOpOperands()) 980a02ad6c1SMatthias Springer if (isa<TensorType>(operand.get().getType())) 9819fa6b350SMatthias Springer r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); 982f3483c23SMatthias Springer return r; 983f3483c23SMatthias Springer } 984f3483c23SMatthias Springer 985a02ad6c1SMatthias Springer AliasingValueList 986a02ad6c1SMatthias Springer bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { 987a02ad6c1SMatthias Springer // TODO: Take into account successor blocks. 988a02ad6c1SMatthias Springer // Unknown op: Conservatively assume that each OpResult may alias with every 989a02ad6c1SMatthias Springer // OpOperand. In addition, each block argument of an entry block may alias 990a02ad6c1SMatthias Springer // with every OpOperand. 991a02ad6c1SMatthias Springer AliasingValueList r; 992f3483c23SMatthias Springer for (OpResult result : opOperand.getOwner()->getOpResults()) 993c1fa60b4STres Popp if (llvm::isa<TensorType>(result.getType())) 9949fa6b350SMatthias Springer r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); 995a02ad6c1SMatthias Springer for (Region ®ion : opOperand.getOwner()->getRegions()) 996a02ad6c1SMatthias Springer if (!region.getBlocks().empty()) 997a02ad6c1SMatthias Springer for (BlockArgument bbArg : region.getBlocks().front().getArguments()) 998a5757c5bSChristian Sigg if (isa<TensorType>(bbArg.getType())) 999a02ad6c1SMatthias Springer r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false}); 1000f3483c23SMatthias Springer return r; 1001f3483c23SMatthias Springer } 10028f2d83daSMatthias Springer 10038f2d83daSMatthias Springer bool bufferization::detail::defaultHasTensorSemantics(Operation *op) { 10048f2d83daSMatthias Springer auto isaTensor = [](Type t) { return isa<TensorType>(t); }; 10058f2d83daSMatthias Springer bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) { 10068f2d83daSMatthias Springer return any_of(r.getBlocks(), [&](Block &b) { 10078f2d83daSMatthias Springer return any_of(b.getArguments(), [&](BlockArgument bbArg) { 10088f2d83daSMatthias Springer return isaTensor(bbArg.getType()); 10098f2d83daSMatthias Springer }); 10108f2d83daSMatthias Springer }); 10118f2d83daSMatthias Springer }); 10128f2d83daSMatthias Springer if (hasTensorBlockArgument) 10138f2d83daSMatthias Springer return true; 10148f2d83daSMatthias Springer 10158f2d83daSMatthias Springer if (any_of(op->getResultTypes(), isaTensor)) 10168f2d83daSMatthias Springer return true; 10178f2d83daSMatthias Springer return any_of(op->getOperandTypes(), isaTensor); 10188f2d83daSMatthias Springer } 1019