xref: /llvm-project/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
17a1579acSMatthias Springer //===- BufferizableOpInterface.cpp - Bufferizable Ops  ---=----------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer 
97a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
107a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1136550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
127a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1379f11591SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
147a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
157a1579acSMatthias Springer #include "mlir/IR/BuiltinOps.h"
164d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
177a1579acSMatthias Springer #include "mlir/IR/Operation.h"
187a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
197a1579acSMatthias Springer #include "mlir/IR/Value.h"
20f7f0c7f7SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
21878950b8SMatthias Springer #include "llvm/ADT/ScopeExit.h"
227a1579acSMatthias Springer #include "llvm/Support/Debug.h"
237a1579acSMatthias Springer 
2487c770bbSMatthias Springer //===----------------------------------------------------------------------===//
2587c770bbSMatthias Springer // BufferizableOpInterface
2687c770bbSMatthias Springer //===----------------------------------------------------------------------===//
2787c770bbSMatthias Springer 
287a1579acSMatthias Springer namespace mlir {
297a1579acSMatthias Springer namespace bufferization {
307a1579acSMatthias Springer 
317a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
327a1579acSMatthias Springer 
337a1579acSMatthias Springer } // namespace bufferization
347a1579acSMatthias Springer } // namespace mlir
357a1579acSMatthias Springer 
36faa9be75SMatthias Springer MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
37faa9be75SMatthias Springer 
387a1579acSMatthias Springer #define DEBUG_TYPE "bufferizable-op-interface"
397a1579acSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
407a1579acSMatthias Springer #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
417a1579acSMatthias Springer 
427a1579acSMatthias Springer using namespace mlir;
437a1579acSMatthias Springer using namespace bufferization;
447a1579acSMatthias Springer 
45c89c31a2SMatthias Springer static bool isRepetitiveRegion(Region *region,
46c89c31a2SMatthias Springer                                const BufferizationOptions &options) {
47c89c31a2SMatthias Springer   Operation *op = region->getParentOp();
48c89c31a2SMatthias Springer   if (auto bufferizableOp = options.dynCastBufferizableOp(op))
49c89c31a2SMatthias Springer     if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
50c89c31a2SMatthias Springer       return true;
51c89c31a2SMatthias Springer   return false;
52c89c31a2SMatthias Springer }
53c89c31a2SMatthias Springer 
549312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion(
5545cd0e45SMatthias Springer     Operation *op, const BufferizationOptions &options) {
5645cd0e45SMatthias Springer   if (!op->getBlock())
5745cd0e45SMatthias Springer     return nullptr;
589312b4f9SMartin Erhart   if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
599312b4f9SMartin Erhart       iter != enclosingRepetitiveRegionCache.end())
609312b4f9SMartin Erhart     return iter->second;
619312b4f9SMartin Erhart   return enclosingRepetitiveRegionCache[op] =
629312b4f9SMartin Erhart              getEnclosingRepetitiveRegion(op->getBlock(), options);
6345cd0e45SMatthias Springer }
6445cd0e45SMatthias Springer 
659312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion(
6645cd0e45SMatthias Springer     Value value, const BufferizationOptions &options) {
679312b4f9SMartin Erhart   if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
689312b4f9SMartin Erhart       iter != enclosingRepetitiveRegionCache.end())
699312b4f9SMartin Erhart     return iter->second;
709312b4f9SMartin Erhart 
7145cd0e45SMatthias Springer   Region *region = value.getParentRegion();
729312b4f9SMartin Erhart   // Collect all visited regions since we only know the repetitive region we
739312b4f9SMartin Erhart   // want to map it to later on
749312b4f9SMartin Erhart   SmallVector<Region *> visitedRegions;
7545cd0e45SMatthias Springer   while (region) {
769312b4f9SMartin Erhart     visitedRegions.push_back(region);
77c89c31a2SMatthias Springer     if (isRepetitiveRegion(region, options))
789312b4f9SMartin Erhart       break;
79c89c31a2SMatthias Springer     region = region->getParentRegion();
8045cd0e45SMatthias Springer   }
819312b4f9SMartin Erhart   enclosingRepetitiveRegionCache[value] = region;
829312b4f9SMartin Erhart   for (Region *r : visitedRegions)
839312b4f9SMartin Erhart     enclosingRepetitiveRegionCache[r] = region;
849312b4f9SMartin Erhart   return region;
8545cd0e45SMatthias Springer }
8645cd0e45SMatthias Springer 
879312b4f9SMartin Erhart Region *AnalysisState::getEnclosingRepetitiveRegion(
8845cd0e45SMatthias Springer     Block *block, const BufferizationOptions &options) {
899312b4f9SMartin Erhart   if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
909312b4f9SMartin Erhart       iter != enclosingRepetitiveRegionCache.end())
919312b4f9SMartin Erhart     return iter->second;
929312b4f9SMartin Erhart 
9345cd0e45SMatthias Springer   Region *region = block->getParent();
9445cd0e45SMatthias Springer   Operation *op = nullptr;
959312b4f9SMartin Erhart   // Collect all visited regions since we only know the repetitive region we
969312b4f9SMartin Erhart   // want to map it to later on
979312b4f9SMartin Erhart   SmallVector<Region *> visitedRegions;
9845cd0e45SMatthias Springer   do {
9945cd0e45SMatthias Springer     op = region->getParentOp();
100c89c31a2SMatthias Springer     if (isRepetitiveRegion(region, options))
1019312b4f9SMartin Erhart       break;
10245cd0e45SMatthias Springer   } while ((region = op->getParentRegion()));
1039312b4f9SMartin Erhart 
1049312b4f9SMartin Erhart   enclosingRepetitiveRegionCache[block] = region;
1059312b4f9SMartin Erhart   for (Region *r : visitedRegions)
1069312b4f9SMartin Erhart     enclosingRepetitiveRegionCache[r] = region;
1079312b4f9SMartin Erhart   return region;
10845cd0e45SMatthias Springer }
10945cd0e45SMatthias Springer 
1109312b4f9SMartin Erhart void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }
1119312b4f9SMartin Erhart 
112c89c31a2SMatthias Springer Region *bufferization::getNextEnclosingRepetitiveRegion(
113c89c31a2SMatthias Springer     Region *region, const BufferizationOptions &options) {
114c89c31a2SMatthias Springer   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
115c89c31a2SMatthias Springer   while ((region = region->getParentRegion())) {
116c89c31a2SMatthias Springer     if (isRepetitiveRegion(region, options))
117c89c31a2SMatthias Springer       break;
118c89c31a2SMatthias Springer   }
119c89c31a2SMatthias Springer   return region;
120c89c31a2SMatthias Springer }
121c89c31a2SMatthias Springer 
1221e1a3112SMatthias Springer Region *bufferization::getParallelRegion(Region *region,
1231e1a3112SMatthias Springer                                          const BufferizationOptions &options) {
1241e1a3112SMatthias Springer   while (region) {
1251e1a3112SMatthias Springer     auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
1261e1a3112SMatthias Springer     if (bufferizableOp &&
1271e1a3112SMatthias Springer         bufferizableOp.isParallelRegion(region->getRegionNumber())) {
1281e1a3112SMatthias Springer       assert(isRepetitiveRegion(region, options) &&
1291e1a3112SMatthias Springer              "expected that all parallel regions are also repetitive regions");
1301e1a3112SMatthias Springer       return region;
1311e1a3112SMatthias Springer     }
1321e1a3112SMatthias Springer     region = region->getParentRegion();
1331e1a3112SMatthias Springer   }
1341e1a3112SMatthias Springer   return nullptr;
1351e1a3112SMatthias Springer }
1361e1a3112SMatthias Springer 
137111c9196SMatthias Springer Operation *bufferization::getOwnerOfValue(Value value) {
138c1fa60b4STres Popp   if (auto opResult = llvm::dyn_cast<OpResult>(value))
139c0b0b6a0SMatthias Springer     return opResult.getDefiningOp();
140c1fa60b4STres Popp   return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
141c0b0b6a0SMatthias Springer }
142c0b0b6a0SMatthias Springer 
143b3ebe3beSMatthias Springer /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
144b3ebe3beSMatthias Springer /// shaped value is copied. Otherwise, a tensor with undefined contents is
145b3ebe3beSMatthias Springer /// allocated.
14645b995cdSMatthias Springer FailureOr<Value> bufferization::allocateTensorForShapedValue(
1476bf043e7SMartin Erhart     OpBuilder &b, Location loc, Value shapedValue,
14845b995cdSMatthias Springer     const BufferizationOptions &options, bool copy) {
149b3ebe3beSMatthias Springer   Value tensor;
150c1fa60b4STres Popp   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
151b3ebe3beSMatthias Springer     tensor = shapedValue;
152c1fa60b4STres Popp   } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
153b3ebe3beSMatthias Springer     tensor = b.create<ToTensorOp>(loc, shapedValue);
154c1fa60b4STres Popp   } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
155c1fa60b4STres Popp              llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
156d7f72d4bSMatthias Springer     return getOwnerOfValue(shapedValue)
157d7f72d4bSMatthias Springer         ->emitError("copying of unranked tensors is not implemented");
15879f11591SMatthias Springer   } else {
159b3ebe3beSMatthias Springer     llvm_unreachable("expected RankedTensorType or MemRefType");
16079f11591SMatthias Springer   }
161c1fa60b4STres Popp   RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
162b3ebe3beSMatthias Springer   SmallVector<Value> dynamicSizes;
163b3ebe3beSMatthias Springer   if (!copy) {
164b3ebe3beSMatthias Springer     // Compute the dynamic part of the shape.
165b3ebe3beSMatthias Springer     // First try to query the shape via ReifyRankedShapedTypeOpInterface.
166b3ebe3beSMatthias Springer     bool reifiedShapes = false;
167c1fa60b4STres Popp     if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
168c1fa60b4STres Popp         llvm::isa<OpResult>(shapedValue)) {
169b3ebe3beSMatthias Springer       ReifiedRankedShapedTypeDims resultDims;
170758329dcSMatthias Springer       if (succeeded(
171758329dcSMatthias Springer               reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
172b3ebe3beSMatthias Springer         reifiedShapes = true;
173b3ebe3beSMatthias Springer         auto &shape =
174c1fa60b4STres Popp             resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
175b3ebe3beSMatthias Springer         for (const auto &dim : enumerate(tensorType.getShape()))
176b3ebe3beSMatthias Springer           if (ShapedType::isDynamic(dim.value()))
177129f1001SKazu Hirata             dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
178b3ebe3beSMatthias Springer       }
179b3ebe3beSMatthias Springer     }
180b3ebe3beSMatthias Springer 
181b3ebe3beSMatthias Springer     // If the shape could not be reified, create DimOps.
182b3ebe3beSMatthias Springer     if (!reifiedShapes)
183b3ebe3beSMatthias Springer       populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
184b3ebe3beSMatthias Springer   }
185b3ebe3beSMatthias Springer 
186c0b0b6a0SMatthias Springer   // Create AllocTensorOp.
1873474d10eSMatthias Springer   auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
1883474d10eSMatthias Springer                                                copy ? tensor : Value());
189c0b0b6a0SMatthias Springer 
190c0b0b6a0SMatthias Springer   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
191c0b0b6a0SMatthias Springer   if (copy)
192c0b0b6a0SMatthias Springer     return allocTensorOp.getResult();
193c0b0b6a0SMatthias Springer   FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
194c0b0b6a0SMatthias Springer   if (failed(copyBufferType))
195c0b0b6a0SMatthias Springer     return failure();
1966f1e23b4SKunwar Grover   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
1979bb63374SLei Zhang   if (!memorySpace)
1986f1e23b4SKunwar Grover     memorySpace = options.defaultMemorySpaceFn(tensorType);
1996f1e23b4SKunwar Grover   if (memorySpace.has_value())
2006f1e23b4SKunwar Grover     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
20145b995cdSMatthias Springer   return allocTensorOp.getResult();
20279f11591SMatthias Springer }
20379f11591SMatthias Springer 
20487c770bbSMatthias Springer LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
20587c770bbSMatthias Springer     RewriterBase &rewriter, const AnalysisState &state) {
20687b46776SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
20787c770bbSMatthias Springer   Operation *op = getOperation();
20887b46776SMatthias Springer   SmallVector<OpOperand *> outOfPlaceOpOperands;
20979f11591SMatthias Springer   DenseSet<OpOperand *> copiedOpOperands;
210a02ad6c1SMatthias Springer   SmallVector<Value> outOfPlaceValues;
211a02ad6c1SMatthias Springer   DenseSet<Value> copiedOpValues;
21287b46776SMatthias Springer 
21387b46776SMatthias Springer   // Find all out-of-place OpOperands.
21487c770bbSMatthias Springer   for (OpOperand &opOperand : op->getOpOperands()) {
21587c770bbSMatthias Springer     Type operandType = opOperand.get().getType();
216c1fa60b4STres Popp     if (!llvm::isa<TensorType>(operandType))
21787c770bbSMatthias Springer       continue;
21887c770bbSMatthias Springer     if (state.isInPlace(opOperand))
21987c770bbSMatthias Springer       continue;
220c1fa60b4STres Popp     if (llvm::isa<UnrankedTensorType>(operandType))
221d7f72d4bSMatthias Springer       return op->emitError("copying of unranked tensors is not implemented");
22287b46776SMatthias Springer 
223a02ad6c1SMatthias Springer     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
224a02ad6c1SMatthias Springer     if (aliasingValues.getNumAliases() == 1 &&
2256ecebb49SMatthias Springer         isa<OpResult>(aliasingValues.getAliases()[0].value) &&
22687b46776SMatthias Springer         !state.bufferizesToMemoryWrite(opOperand) &&
227a02ad6c1SMatthias Springer         state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
2289fa6b350SMatthias Springer                 .getNumAliases() == 1 &&
229a02ad6c1SMatthias Springer         !isa<UnrankedTensorType>(
230a02ad6c1SMatthias Springer             aliasingValues.getAliases()[0].value.getType())) {
23187b46776SMatthias Springer       // The op itself does not write but may create exactly one alias. Instead
23287b46776SMatthias Springer       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
23387b46776SMatthias Springer       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
234d7f72d4bSMatthias Springer       // where the result is usually a smaller part of the source). Do not apply
235d7f72d4bSMatthias Springer       // this optimization if the OpResult is an unranked tensor (because those
236d7f72d4bSMatthias Springer       // cannot be copied at the moment).
237a02ad6c1SMatthias Springer       Value value = aliasingValues.getAliases()[0].value;
238a02ad6c1SMatthias Springer       outOfPlaceValues.push_back(value);
23979f11591SMatthias Springer       if (!state.canOmitTensorCopy(opOperand))
240a02ad6c1SMatthias Springer         copiedOpValues.insert(value);
24187b46776SMatthias Springer     } else {
24287b46776SMatthias Springer       // In all other cases, make a copy of the OpOperand.
24387b46776SMatthias Springer       outOfPlaceOpOperands.push_back(&opOperand);
24479f11591SMatthias Springer       if (!state.canOmitTensorCopy(opOperand))
24579f11591SMatthias Springer         copiedOpOperands.insert(&opOperand);
24687b46776SMatthias Springer     }
24787b46776SMatthias Springer   }
24887b46776SMatthias Springer 
24987b46776SMatthias Springer   // Insert copies of OpOperands.
25087b46776SMatthias Springer   rewriter.setInsertionPoint(op);
25187b46776SMatthias Springer   for (OpOperand *opOperand : outOfPlaceOpOperands) {
25245b995cdSMatthias Springer     FailureOr<Value> copy = allocateTensorForShapedValue(
2536bf043e7SMartin Erhart         rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
25479f11591SMatthias Springer         copiedOpOperands.contains(opOperand));
25545b995cdSMatthias Springer     if (failed(copy))
25645b995cdSMatthias Springer       return failure();
2575fcf907bSMatthias Springer     rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
25887c770bbSMatthias Springer   }
25987b46776SMatthias Springer 
260a02ad6c1SMatthias Springer   // Insert copies of Values.
26187b46776SMatthias Springer   rewriter.setInsertionPointAfter(op);
262a02ad6c1SMatthias Springer   for (Value value : outOfPlaceValues) {
26345b995cdSMatthias Springer     FailureOr<Value> copy = allocateTensorForShapedValue(
2646bf043e7SMartin Erhart         rewriter, op->getLoc(), value, state.getOptions(),
2656bf043e7SMartin Erhart         copiedOpValues.count(value));
26645b995cdSMatthias Springer     if (failed(copy))
26745b995cdSMatthias Springer       return failure();
268a02ad6c1SMatthias Springer     SmallVector<OpOperand *> uses = llvm::to_vector(
269a02ad6c1SMatthias Springer         llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
27087b46776SMatthias Springer     for (OpOperand *use : uses) {
27187b46776SMatthias Springer       // Do not update the alloc_tensor op that we just created.
2725f6d5ca0SMatthias Springer       if (use->getOwner() == copy->getDefiningOp())
2735f6d5ca0SMatthias Springer         continue;
2745f6d5ca0SMatthias Springer       // tensor.dim ops may have been created to be used as alloc_tensor op
2755f6d5ca0SMatthias Springer       // dynamic extents. Do not update these either.
2765f6d5ca0SMatthias Springer       if (isa<tensor::DimOp>(use->getOwner()))
2775f6d5ca0SMatthias Springer         continue;
2785fcf907bSMatthias Springer       rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
27987b46776SMatthias Springer     }
28087b46776SMatthias Springer   }
28187b46776SMatthias Springer 
28287c770bbSMatthias Springer   return success();
28387c770bbSMatthias Springer }
28487c770bbSMatthias Springer 
2857a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2861534177fSMatthias Springer // OpFilter
2871534177fSMatthias Springer //===----------------------------------------------------------------------===//
2881534177fSMatthias Springer 
2891534177fSMatthias Springer bool OpFilter::isOpAllowed(Operation *op) const {
2901534177fSMatthias Springer   // All other ops: Allow/disallow according to filter.
2911534177fSMatthias Springer   bool isAllowed = !hasAllowRule();
2921534177fSMatthias Springer   for (const Entry &entry : entries) {
2931534177fSMatthias Springer     bool filterResult = entry.fn(op);
2941534177fSMatthias Springer     switch (entry.type) {
2951534177fSMatthias Springer     case Entry::ALLOW:
2961534177fSMatthias Springer       isAllowed |= filterResult;
2971534177fSMatthias Springer       break;
2981534177fSMatthias Springer     case Entry::DENY:
2991534177fSMatthias Springer       if (filterResult)
3001534177fSMatthias Springer         // DENY filter matches. This op is no allowed. (Even if other ALLOW
3011534177fSMatthias Springer         // filters may match.)
3021534177fSMatthias Springer         return false;
3031534177fSMatthias Springer     };
3041534177fSMatthias Springer   }
3051534177fSMatthias Springer   return isAllowed;
3061534177fSMatthias Springer }
3071534177fSMatthias Springer 
3081534177fSMatthias Springer //===----------------------------------------------------------------------===//
3097a1579acSMatthias Springer // BufferizationOptions
3107a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3117a1579acSMatthias Springer 
31275ef84bfSOleg Shyshkov namespace {
31375ef84bfSOleg Shyshkov 
31475ef84bfSOleg Shyshkov /// Default function arg type converter: Use a fully dynamic layout map.
31575ef84bfSOleg Shyshkov BaseMemRefType
31675ef84bfSOleg Shyshkov defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
31791c11574SAndrzej Warzyński                                 func::FuncOp funcOp,
31875ef84bfSOleg Shyshkov                                 const BufferizationOptions &options) {
31975ef84bfSOleg Shyshkov   return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
32075ef84bfSOleg Shyshkov }
321606f7c8fSMatthias Springer /// Default unknown type converter: Use a fully dynamic layout map.
32275ef84bfSOleg Shyshkov BaseMemRefType
3239bb63374SLei Zhang defaultUnknownTypeConverter(Value value, Attribute memorySpace,
324606f7c8fSMatthias Springer                             const BufferizationOptions &options) {
325c1fa60b4STres Popp   return getMemRefTypeWithFullyDynamicLayout(
326c1fa60b4STres Popp       llvm::cast<TensorType>(value.getType()), memorySpace);
327606f7c8fSMatthias Springer }
328606f7c8fSMatthias Springer 
329ea3e8d3bSJie Fu } // namespace
33075ef84bfSOleg Shyshkov 
3317a1579acSMatthias Springer // Default constructor for BufferizationOptions.
332606f7c8fSMatthias Springer BufferizationOptions::BufferizationOptions()
33375ef84bfSOleg Shyshkov     : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
33475ef84bfSOleg Shyshkov       unknownTypeConverterFn(defaultUnknownTypeConverter) {}
3357a1579acSMatthias Springer 
336d6dab38aSMatthias Springer bool BufferizationOptions::isOpAllowed(Operation *op) const {
337d6dab38aSMatthias Springer   // Special case: If function boundary bufferization is deactivated, do not
338d6dab38aSMatthias Springer   // allow ops that belong to the `func` dialect.
339d6dab38aSMatthias Springer   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
340d6dab38aSMatthias Springer   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
341d6dab38aSMatthias Springer     return false;
342d6dab38aSMatthias Springer 
3431534177fSMatthias Springer   return opFilter.isOpAllowed(op);
344d6dab38aSMatthias Springer }
345d6dab38aSMatthias Springer 
3467a1579acSMatthias Springer BufferizableOpInterface
3477a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
348db604911SBenjamin Kramer   if (!isOpAllowed(op))
349db604911SBenjamin Kramer     return nullptr;
3509785eb1bSMatthias Springer   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
3519785eb1bSMatthias Springer   if (!bufferizableOp)
3527a1579acSMatthias Springer     return nullptr;
3539785eb1bSMatthias Springer   return bufferizableOp;
3547a1579acSMatthias Springer }
3557a1579acSMatthias Springer 
3567a1579acSMatthias Springer BufferizableOpInterface
3577a1579acSMatthias Springer BufferizationOptions::dynCastBufferizableOp(Value value) const {
3581f479c1eSMatthias Springer   return dynCastBufferizableOp(getOwnerOfValue(value));
3597a1579acSMatthias Springer }
3607a1579acSMatthias Springer 
36175ef84bfSOleg Shyshkov void BufferizationOptions::setFunctionBoundaryTypeConversion(
36275ef84bfSOleg Shyshkov     LayoutMapOption layoutMapOption) {
36375ef84bfSOleg Shyshkov   functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
36491c11574SAndrzej Warzyński                                    func::FuncOp funcOp,
36575ef84bfSOleg Shyshkov                                    const BufferizationOptions &options) {
36675ef84bfSOleg Shyshkov     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
36775ef84bfSOleg Shyshkov       return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
36875ef84bfSOleg Shyshkov                                                                   memorySpace);
36975ef84bfSOleg Shyshkov     return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
37075ef84bfSOleg Shyshkov                                                               memorySpace);
37175ef84bfSOleg Shyshkov   };
37275ef84bfSOleg Shyshkov   inferFunctionResultLayout =
37375ef84bfSOleg Shyshkov       layoutMapOption == LayoutMapOption::InferLayoutMap;
37475ef84bfSOleg Shyshkov }
37575ef84bfSOleg Shyshkov 
3767a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3777a1579acSMatthias Springer // Helper functions for BufferizableOpInterface
3787a1579acSMatthias Springer //===----------------------------------------------------------------------===//
3797a1579acSMatthias Springer 
3807a1579acSMatthias Springer static void setInsertionPointAfter(OpBuilder &b, Value value) {
381c1fa60b4STres Popp   if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
3827a1579acSMatthias Springer     b.setInsertionPointToStart(bbArg.getOwner());
3837a1579acSMatthias Springer   } else {
3847a1579acSMatthias Springer     b.setInsertionPointAfter(value.getDefiningOp());
3857a1579acSMatthias Springer   }
3867a1579acSMatthias Springer }
3877a1579acSMatthias Springer 
388a02ad6c1SMatthias Springer /// Determine which OpOperand* will alias with `value` if the op is bufferized
389a02ad6c1SMatthias Springer /// in place. Return all tensor OpOperand* if the op is not bufferizable.
390a02ad6c1SMatthias Springer AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
391a02ad6c1SMatthias Springer   if (Operation *op = getOwnerOfValue(value))
3922fe40c34SMatthias Springer     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
393a02ad6c1SMatthias Springer       return bufferizableOp.getAliasingOpOperands(value, *this);
394f3483c23SMatthias Springer 
395f3483c23SMatthias Springer   // The op is not bufferizable.
396a02ad6c1SMatthias Springer   return detail::unknownGetAliasingOpOperands(value);
3977a1579acSMatthias Springer }
3987a1579acSMatthias Springer 
399a02ad6c1SMatthias Springer /// Determine which Values will alias with `opOperand` if the op is bufferized
400a02ad6c1SMatthias Springer /// in place. Return all tensor Values if the op is not bufferizable.
401a02ad6c1SMatthias Springer AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
4027a1579acSMatthias Springer   if (auto bufferizableOp =
4032fe40c34SMatthias Springer           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
404a02ad6c1SMatthias Springer     return bufferizableOp.getAliasingValues(opOperand, *this);
405f3483c23SMatthias Springer 
406f3483c23SMatthias Springer   // The op is not bufferizable.
407a02ad6c1SMatthias Springer   return detail::unknownGetAliasingValues(opOperand);
4087a1579acSMatthias Springer }
4097a1579acSMatthias Springer 
4107a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
4117a1579acSMatthias Springer /// op is not bufferizable.
4129597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
4137a1579acSMatthias Springer   if (auto bufferizableOp =
4142fe40c34SMatthias Springer           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
4157a1579acSMatthias Springer     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
4167a1579acSMatthias Springer 
4177a1579acSMatthias Springer   // Unknown op that returns a tensor. The inplace analysis does not support it.
4187a1579acSMatthias Springer   // Conservatively return true.
4197a1579acSMatthias Springer   return true;
4207a1579acSMatthias Springer }
4217a1579acSMatthias Springer 
4227a1579acSMatthias Springer /// Return true if `opOperand` bufferizes to a memory write. Return
4237a1579acSMatthias Springer /// `true` if the op is not bufferizable.
4249597b16aSMatthias Springer bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
4257a1579acSMatthias Springer   if (auto bufferizableOp =
4262fe40c34SMatthias Springer           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
4277a1579acSMatthias Springer     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
4287a1579acSMatthias Springer 
4297a1579acSMatthias Springer   // Unknown op that returns a tensor. The inplace analysis does not support it.
4307a1579acSMatthias Springer   // Conservatively return true.
4317a1579acSMatthias Springer   return true;
4327a1579acSMatthias Springer }
4337a1579acSMatthias Springer 
4347a1579acSMatthias Springer /// Return true if `opOperand` does neither read nor write but bufferizes to an
4357a1579acSMatthias Springer /// alias. Return false if the op is not bufferizable.
4369597b16aSMatthias Springer bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
4377a1579acSMatthias Springer   if (auto bufferizableOp =
4382fe40c34SMatthias Springer           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
4397a1579acSMatthias Springer     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
4407a1579acSMatthias Springer 
4417a1579acSMatthias Springer   // Unknown op that returns a tensor. The inplace analysis does not support it.
4427a1579acSMatthias Springer   // Conservatively return false.
4437a1579acSMatthias Springer   return false;
4447a1579acSMatthias Springer }
4457a1579acSMatthias Springer 
44634d65e81SMatthias Springer bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
447c1fa60b4STres Popp   auto opResult = llvm::dyn_cast<OpResult>(value);
44834d65e81SMatthias Springer   if (!opResult)
44934d65e81SMatthias Springer     return true;
45034d65e81SMatthias Springer   auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
45134d65e81SMatthias Springer   if (!bufferizableOp)
45234d65e81SMatthias Springer     return true;
45334d65e81SMatthias Springer   return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
45434d65e81SMatthias Springer }
45534d65e81SMatthias Springer 
4567a1579acSMatthias Springer /// Return true if the given value is read by an op that bufferizes to a memory
4577a1579acSMatthias Springer /// read. Also takes into account ops that create an alias but do not read by
4587a1579acSMatthias Springer /// themselves (e.g., ExtractSliceOp).
4599597b16aSMatthias Springer bool AnalysisState::isValueRead(Value value) const {
460c1fa60b4STres Popp   assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
4617a1579acSMatthias Springer   SmallVector<OpOperand *> workingSet;
4626ecebb49SMatthias Springer   DenseSet<OpOperand *> visited;
4637a1579acSMatthias Springer   for (OpOperand &use : value.getUses())
4647a1579acSMatthias Springer     workingSet.push_back(&use);
4657a1579acSMatthias Springer 
4667a1579acSMatthias Springer   while (!workingSet.empty()) {
4677a1579acSMatthias Springer     OpOperand *uMaybeReading = workingSet.pop_back_val();
4687be6ea12SKazu Hirata     if (!visited.insert(uMaybeReading).second)
4696ecebb49SMatthias Springer       continue;
4706ecebb49SMatthias Springer 
4717a1579acSMatthias Springer     // Skip over all ops that neither read nor write (but create an alias).
4727a1579acSMatthias Springer     if (bufferizesToAliasOnly(*uMaybeReading))
473a02ad6c1SMatthias Springer       for (AliasingValue alias : getAliasingValues(*uMaybeReading))
474a02ad6c1SMatthias Springer         for (OpOperand &use : alias.value.getUses())
4757a1579acSMatthias Springer           workingSet.push_back(&use);
4767a1579acSMatthias Springer     if (bufferizesToMemoryRead(*uMaybeReading))
4777a1579acSMatthias Springer       return true;
4787a1579acSMatthias Springer   }
4797a1579acSMatthias Springer 
4807a1579acSMatthias Springer   return false;
4817a1579acSMatthias Springer }
4827a1579acSMatthias Springer 
483*d9111f19SAmir Bishara // Starting from `opOperand`, follow the use-def chain in reverse, always
484*d9111f19SAmir Bishara // selecting the aliasing OpOperands. Find and return Values for which
485*d9111f19SAmir Bishara // `condition` evaluates to true. Uses of such matching Values are not
486*d9111f19SAmir Bishara // traversed any further, the visited aliasing opOperands will be preserved
487*d9111f19SAmir Bishara // through `visitedOpOperands`.
4889597b16aSMatthias Springer llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
489*d9111f19SAmir Bishara     OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
49008aa9563SAmir Bishara     TraversalConfig config,
49108aa9563SAmir Bishara     llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
4926ecebb49SMatthias Springer   llvm::DenseSet<Value> visited;
4937a1579acSMatthias Springer   llvm::SetVector<Value> result, workingSet;
494*d9111f19SAmir Bishara   workingSet.insert(opOperand->get());
495*d9111f19SAmir Bishara 
496*d9111f19SAmir Bishara   if (visitedOpOperands)
497*d9111f19SAmir Bishara     visitedOpOperands->insert(opOperand);
4987a1579acSMatthias Springer 
4997a1579acSMatthias Springer   while (!workingSet.empty()) {
5007a1579acSMatthias Springer     Value value = workingSet.pop_back_val();
5016ecebb49SMatthias Springer 
5026ecebb49SMatthias Springer     if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
5036ecebb49SMatthias Springer       // Stop traversal if value was already visited.
5046ecebb49SMatthias Springer       if (config.alwaysIncludeLeaves)
5056ecebb49SMatthias Springer         result.insert(value);
5066ecebb49SMatthias Springer       continue;
5076ecebb49SMatthias Springer     }
5086ecebb49SMatthias Springer     visited.insert(value);
5096ecebb49SMatthias Springer 
510fdb9e6a3SMatthias Springer     if (condition(value)) {
511fdb9e6a3SMatthias Springer       result.insert(value);
512fdb9e6a3SMatthias Springer       continue;
513fdb9e6a3SMatthias Springer     }
514fdb9e6a3SMatthias Springer 
515a02ad6c1SMatthias Springer     if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
5161f479c1eSMatthias Springer       // Stop iterating if `followUnknownOps` is unset and the op is either
5171f479c1eSMatthias Springer       // not bufferizable or excluded in the OpFilter.
5181f479c1eSMatthias Springer       if (config.alwaysIncludeLeaves)
5191f479c1eSMatthias Springer         result.insert(value);
5201f479c1eSMatthias Springer       continue;
5211f479c1eSMatthias Springer     }
5226cdd34b9SMatthias Springer 
523a02ad6c1SMatthias Springer     AliasingOpOperandList aliases = getAliasingOpOperands(value);
5241f479c1eSMatthias Springer     if (aliases.getNumAliases() == 0) {
5251f479c1eSMatthias Springer       // The traversal ends naturally if there are no more OpOperands that
5261f479c1eSMatthias Springer       // could be followed.
5271f479c1eSMatthias Springer       if (config.alwaysIncludeLeaves)
5287a1579acSMatthias Springer         result.insert(value);
5297a1579acSMatthias Springer       continue;
5307a1579acSMatthias Springer     }
5317a1579acSMatthias Springer 
5329fa6b350SMatthias Springer     for (AliasingOpOperand a : aliases) {
5331f479c1eSMatthias Springer       if (config.followEquivalentOnly &&
5341f479c1eSMatthias Springer           a.relation != BufferRelation::Equivalent) {
5359fa6b350SMatthias Springer         // Stop iterating if `followEquivalentOnly` is set but the alias is not
5369fa6b350SMatthias Springer         // equivalent.
5371f479c1eSMatthias Springer         if (config.alwaysIncludeLeaves)
5389fa6b350SMatthias Springer           result.insert(value);
539aa909483SMatthias Springer         continue;
5409fa6b350SMatthias Springer       }
5411f479c1eSMatthias Springer 
5421f479c1eSMatthias Springer       if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
5431f479c1eSMatthias Springer         // Stop iterating if `followInPlaceOnly` is set but the alias is
5441f479c1eSMatthias Springer         // out-of-place.
5451f479c1eSMatthias Springer         if (config.alwaysIncludeLeaves)
5461f479c1eSMatthias Springer           result.insert(value);
5471f479c1eSMatthias Springer         continue;
5481f479c1eSMatthias Springer       }
5491f479c1eSMatthias Springer 
550aba0ef70SMatthias Springer       if (config.followSameTypeOrCastsOnly &&
551aba0ef70SMatthias Springer           a.opOperand->get().getType() != value.getType() &&
552a02ad6c1SMatthias Springer           !value.getDefiningOp<CastOpInterface>()) {
553aba0ef70SMatthias Springer         // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias is
554aba0ef70SMatthias Springer         // has a different type and the op is not a cast.
555aba0ef70SMatthias Springer         if (config.alwaysIncludeLeaves)
556aba0ef70SMatthias Springer           result.insert(value);
557aba0ef70SMatthias Springer         continue;
558aba0ef70SMatthias Springer       }
559aba0ef70SMatthias Springer 
5601f479c1eSMatthias Springer       workingSet.insert(a.opOperand->get());
56108aa9563SAmir Bishara       if (visitedOpOperands)
56208aa9563SAmir Bishara         visitedOpOperands->insert(a.opOperand);
5639fa6b350SMatthias Springer     }
5647a1579acSMatthias Springer   }
5657a1579acSMatthias Springer 
5667a1579acSMatthias Springer   return result;
5677a1579acSMatthias Springer }
5687a1579acSMatthias Springer 
569*d9111f19SAmir Bishara // Find the values that define the contents of the given operand's value.
570*d9111f19SAmir Bishara llvm::SetVector<Value>
571*d9111f19SAmir Bishara AnalysisState::findDefinitions(OpOperand *opOperand) const {
5721f479c1eSMatthias Springer   TraversalConfig config;
5731f479c1eSMatthias Springer   config.alwaysIncludeLeaves = false;
57434d65e81SMatthias Springer   return findValueInReverseUseDefChain(
575*d9111f19SAmir Bishara       opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
576*d9111f19SAmir Bishara       config);
5777a1579acSMatthias Springer }
5787a1579acSMatthias Springer 
5799597b16aSMatthias Springer AnalysisState::AnalysisState(const BufferizationOptions &options)
580faa9be75SMatthias Springer     : AnalysisState(options, TypeID::get<AnalysisState>()) {}
581faa9be75SMatthias Springer 
582faa9be75SMatthias Springer AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
583faa9be75SMatthias Springer     : options(options), type(type) {
5849597b16aSMatthias Springer   for (const BufferizationOptions::AnalysisStateInitFn &fn :
5856fc11d4dSMatthias Springer        options.stateInitializers)
5866fc11d4dSMatthias Springer     fn(*this);
5876fc11d4dSMatthias Springer }
5887a1579acSMatthias Springer 
58979f11591SMatthias Springer bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
59079f11591SMatthias Springer   // Do not copy if the tensor has undefined contents.
59179f11591SMatthias Springer   if (hasUndefinedContents(&opOperand))
59279f11591SMatthias Springer     return true;
59379f11591SMatthias Springer 
59479f11591SMatthias Springer   // Do not copy if the buffer of the tensor is entirely overwritten (with
59579f11591SMatthias Springer   // values that do not depend on the old tensor).
59679f11591SMatthias Springer   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
59779f11591SMatthias Springer     return true;
59879f11591SMatthias Springer 
59979f11591SMatthias Springer   // Do not copy if the tensor is never read.
600a02ad6c1SMatthias Springer   AliasingValueList aliases = getAliasingValues(opOperand);
60179f11591SMatthias Springer   if (!bufferizesToMemoryRead(opOperand) &&
602a02ad6c1SMatthias Springer       llvm::none_of(aliases,
603a02ad6c1SMatthias Springer                     [&](AliasingValue a) { return isValueRead(a.value); }))
60479f11591SMatthias Springer     return true;
60579f11591SMatthias Springer 
60679f11591SMatthias Springer   // Default: Cannot omit the copy.
60779f11591SMatthias Springer   return false;
60879f11591SMatthias Springer }
60979f11591SMatthias Springer 
610a3bca118SMatthias Springer bool AnalysisState::isInPlace(OpOperand &opOperand) const {
611b3ebe3beSMatthias Springer   // ToMemrefOps are always in-place.
612b3ebe3beSMatthias Springer   if (isa<ToMemrefOp>(opOperand.getOwner()))
613b3ebe3beSMatthias Springer     return true;
614b3ebe3beSMatthias Springer 
615a3bca118SMatthias Springer   // In the absence of analysis information, OpOperands that bufferize to a
616a3bca118SMatthias Springer   // memory write are out-of-place, i.e., an alloc and copy is inserted.
617a3bca118SMatthias Springer   return !bufferizesToMemoryWrite(opOperand);
618a3bca118SMatthias Springer }
619a3bca118SMatthias Springer 
620a3bca118SMatthias Springer bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
621a3bca118SMatthias Springer   // In the absence of analysis information, we do not know if the values are
622a3bca118SMatthias Springer   // equivalent. The conservative answer is "false".
623a3bca118SMatthias Springer   return false;
624a3bca118SMatthias Springer }
625a3bca118SMatthias Springer 
626a3bca118SMatthias Springer bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
627a3bca118SMatthias Springer   // In the absence of analysis information, we do not know if the values may be
628a3bca118SMatthias Springer   // aliasing. The conservative answer is "true".
629f2ada383Slorenzo chelini   return true;
630a3bca118SMatthias Springer }
631a3bca118SMatthias Springer 
632a3bca118SMatthias Springer bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
633a3bca118SMatthias Springer   // In the absence of analysis information, the conservative answer is "false".
634a3bca118SMatthias Springer   return false;
635a3bca118SMatthias Springer }
636a3bca118SMatthias Springer 
6377a1579acSMatthias Springer // bufferization.to_memref is not allowed to change the rank.
6387a1579acSMatthias Springer static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
6397a1579acSMatthias Springer #ifndef NDEBUG
640c1fa60b4STres Popp   auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
641c1fa60b4STres Popp   assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
6427a1579acSMatthias Springer                                    rankedTensorType.getRank()) &&
6437a1579acSMatthias Springer          "to_memref would be invalid: mismatching ranks");
6447a1579acSMatthias Springer #endif
6457a1579acSMatthias Springer }
6467a1579acSMatthias Springer 
6475d50f51cSMatthias Springer FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
648b55d55ecSMatthias Springer                                           const BufferizationOptions &options) {
649ba9d886dSMatthias Springer #ifndef NDEBUG
650c1fa60b4STres Popp   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
65126852423SMatthias Springer   assert(tensorType && "unexpected non-tensor type");
652ba9d886dSMatthias Springer #endif // NDEBUG
6537a1579acSMatthias Springer 
6547a1579acSMatthias Springer   // Replace "%t = to_tensor %m" with %m.
655b3ebe3beSMatthias Springer   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
65699260e95SMatthias Springer     return toTensorOp.getMemref();
6577a1579acSMatthias Springer 
6587a1579acSMatthias Springer   // Insert to_memref op.
6597a1579acSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
660b3ebe3beSMatthias Springer   setInsertionPointAfter(rewriter, value);
6615d50f51cSMatthias Springer   FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
6625d50f51cSMatthias Springer   if (failed(memrefType))
6635d50f51cSMatthias Springer     return failure();
6645d50f51cSMatthias Springer   ensureToMemrefOpIsValid(value, *memrefType);
6655d50f51cSMatthias Springer   return rewriter
6665d50f51cSMatthias Springer       .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
6675d50f51cSMatthias Springer       .getResult();
6687a1579acSMatthias Springer }
6697a1579acSMatthias Springer 
670996834e6SMatthias Springer /// Return the buffer type for a given Value (tensor) after bufferization.
6715d50f51cSMatthias Springer FailureOr<BaseMemRefType>
672b55d55ecSMatthias Springer bufferization::getBufferType(Value value, const BufferizationOptions &options) {
673878950b8SMatthias Springer   SmallVector<Value> invocationStack;
674878950b8SMatthias Springer   return getBufferType(value, options, invocationStack);
675123c4b02SMatthias Springer }
676123c4b02SMatthias Springer 
677123c4b02SMatthias Springer /// Return the buffer type for a given Value (tensor) after bufferization.
678878950b8SMatthias Springer FailureOr<BaseMemRefType>
679878950b8SMatthias Springer bufferization::getBufferType(Value value, const BufferizationOptions &options,
680878950b8SMatthias Springer                              SmallVector<Value> &invocationStack) {
681c1fa60b4STres Popp   assert(llvm::isa<TensorType>(value.getType()) &&
682c1fa60b4STres Popp          "unexpected non-tensor type");
683878950b8SMatthias Springer   invocationStack.push_back(value);
684878950b8SMatthias Springer   auto popFromStack =
685878950b8SMatthias Springer       llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
686123c4b02SMatthias Springer 
687123c4b02SMatthias Springer   // Try querying BufferizableOpInterface.
688c0b0b6a0SMatthias Springer   Operation *op = getOwnerOfValue(value);
689111c9196SMatthias Springer   auto bufferizableOp = options.dynCastBufferizableOp(op);
690111c9196SMatthias Springer   if (bufferizableOp)
691878950b8SMatthias Springer     return bufferizableOp.getBufferType(value, options, invocationStack);
692d7a9bf91SMatthias Springer 
693111c9196SMatthias Springer   // Op is not bufferizable.
694067d2779Sian Bearman   auto memSpace =
695a5757c5bSChristian Sigg       options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
696067d2779Sian Bearman   if (!memSpace.has_value())
697c0b0b6a0SMatthias Springer     return op->emitError("could not infer memory space");
698c0b0b6a0SMatthias Springer 
699067d2779Sian Bearman   return getMemRefType(value, options, /*layout=*/{}, *memSpace);
700d7a9bf91SMatthias Springer }
701d7a9bf91SMatthias Springer 
7028f2d83daSMatthias Springer bool bufferization::hasTensorSemantics(Operation *op) {
7038f2d83daSMatthias Springer   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
7048f2d83daSMatthias Springer     return bufferizableOp.hasTensorSemantics();
7058f2d83daSMatthias Springer   return detail::defaultHasTensorSemantics(op);
7068f2d83daSMatthias Springer }
7078f2d83daSMatthias Springer 
7087a1579acSMatthias Springer void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
7097a1579acSMatthias Springer                                                   Operation *op,
7107a1579acSMatthias Springer                                                   ValueRange values) {
7119106d35bSMatthias Springer   assert(values.size() == op->getNumResults() &&
7129106d35bSMatthias Springer          "expected one value per OpResult");
7137a1579acSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
7147a1579acSMatthias Springer 
7157a1579acSMatthias Springer   // Replace all OpResults with the given values.
7169106d35bSMatthias Springer   SmallVector<Value> replacements;
7177a1579acSMatthias Springer   for (OpResult opResult : op->getOpResults()) {
7187a1579acSMatthias Springer     Value replacement = values[opResult.getResultNumber()];
719c1fa60b4STres Popp     if (llvm::isa<TensorType>(opResult.getType())) {
7207a1579acSMatthias Springer       // The OpResult is a tensor. Such values are replaced with memrefs during
7217a1579acSMatthias Springer       // bufferization.
722c1fa60b4STres Popp       assert((llvm::isa<MemRefType>(replacement.getType()) ||
723c1fa60b4STres Popp               llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
7247a1579acSMatthias Springer              "tensor op result should be replaced with a memref value");
7257a1579acSMatthias Springer       // The existing uses of the OpResult still expect a tensor. Insert a
7267a1579acSMatthias Springer       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
7277a1579acSMatthias Springer       // loose all of its users and eventually DCE away.
728c30d2893SMatthias Springer       rewriter.setInsertionPointAfter(op);
7297a1579acSMatthias Springer       replacement = rewriter.create<bufferization::ToTensorOp>(
730ced2fc78SChristopher Bate           replacement.getLoc(), opResult.getType(), replacement);
7317a1579acSMatthias Springer     }
7329106d35bSMatthias Springer     replacements.push_back(replacement);
7337a1579acSMatthias Springer   }
7347a1579acSMatthias Springer 
7359106d35bSMatthias Springer   rewriter.replaceOp(op, replacements);
7367a1579acSMatthias Springer }
7377a1579acSMatthias Springer 
7387a1579acSMatthias Springer //===----------------------------------------------------------------------===//
739caa2a4aeSMatthias Springer // Bufferization-specific scoped alloc insertion support.
7407a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7417a1579acSMatthias Springer 
74205e0495fSMatthias Springer /// Create a memref allocation with the given type and dynamic extents.
743248e113eSMatthias Springer FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
744248e113eSMatthias Springer                                                    MemRefType type,
745248e113eSMatthias Springer                                                    ValueRange dynShape) const {
746248e113eSMatthias Springer   if (allocationFn)
747248e113eSMatthias Springer     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
74805e0495fSMatthias Springer 
74905e0495fSMatthias Springer   // Default bufferallocation via AllocOp.
750b3ebe3beSMatthias Springer   if (bufferAlignment != 0)
751b3ebe3beSMatthias Springer     return b
752b3ebe3beSMatthias Springer         .create<memref::AllocOp>(loc, type, dynShape,
753b3ebe3beSMatthias Springer                                  b.getI64IntegerAttr(bufferAlignment))
754b3ebe3beSMatthias Springer         .getResult();
755b3ebe3beSMatthias Springer   return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
75605e0495fSMatthias Springer }
75705e0495fSMatthias Springer 
7587a1579acSMatthias Springer /// Create a memory copy between two memref buffers.
759248e113eSMatthias Springer LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
760248e113eSMatthias Springer                                                  Value from, Value to) const {
761248e113eSMatthias Springer   if (memCpyFn)
762248e113eSMatthias Springer     return (*memCpyFn)(b, loc, from, to);
7637a1579acSMatthias Springer 
7647a1579acSMatthias Springer   b.create<memref::CopyOp>(loc, from, to);
7657a1579acSMatthias Springer   return success();
7667a1579acSMatthias Springer }
7677a1579acSMatthias Springer 
7687a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7694d67b278SJeff Niu // Bufferization-specific IRMapping support with debugging.
7707a1579acSMatthias Springer //===----------------------------------------------------------------------===//
7717a1579acSMatthias Springer 
772606f7c8fSMatthias Springer BaseMemRefType bufferization::getMemRefType(Value value,
77326852423SMatthias Springer                                             const BufferizationOptions &options,
77426852423SMatthias Springer                                             MemRefLayoutAttrInterface layout,
7759bb63374SLei Zhang                                             Attribute memorySpace) {
776c1fa60b4STres Popp   auto tensorType = llvm::cast<TensorType>(value.getType());
777b06614e2SMatthias Springer 
77826852423SMatthias Springer   // Case 1: Unranked memref type.
779c1fa60b4STres Popp   if (auto unrankedTensorType =
780c1fa60b4STres Popp           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
78126852423SMatthias Springer     assert(!layout && "UnrankedTensorType cannot have a layout map");
78226852423SMatthias Springer     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
7839bb63374SLei Zhang                                    memorySpace);
7847a1579acSMatthias Springer   }
7857a1579acSMatthias Springer 
786f287da8aSMatthias Springer   // Case 2: Ranked memref type with specified layout.
787c1fa60b4STres Popp   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
788f287da8aSMatthias Springer   if (layout) {
78926852423SMatthias Springer     return MemRefType::get(rankedTensorType.getShape(),
79026852423SMatthias Springer                            rankedTensorType.getElementType(), layout,
7919bb63374SLei Zhang                            memorySpace);
79226852423SMatthias Springer   }
79326852423SMatthias Springer 
794606f7c8fSMatthias Springer   return options.unknownTypeConverterFn(value, memorySpace, options);
795f287da8aSMatthias Springer }
796f287da8aSMatthias Springer 
797f287da8aSMatthias Springer BaseMemRefType
798f287da8aSMatthias Springer bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
7999bb63374SLei Zhang                                                    Attribute memorySpace) {
800f287da8aSMatthias Springer   // Case 1: Unranked memref type.
801c1fa60b4STres Popp   if (auto unrankedTensorType =
802c1fa60b4STres Popp           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
803f287da8aSMatthias Springer     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
804f287da8aSMatthias Springer                                    memorySpace);
805f287da8aSMatthias Springer   }
806f287da8aSMatthias Springer 
807f287da8aSMatthias Springer   // Case 2: Ranked memref type.
808c1fa60b4STres Popp   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
809399638f9SAliia Khasanova   int64_t dynamicOffset = ShapedType::kDynamic;
81026852423SMatthias Springer   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
811399638f9SAliia Khasanova                                       ShapedType::kDynamic);
812f096e72cSAlex Zinenko   auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
813f096e72cSAlex Zinenko                                               dynamicOffset, dynamicStrides);
81426852423SMatthias Springer   return MemRefType::get(rankedTensorType.getShape(),
81526852423SMatthias Springer                          rankedTensorType.getElementType(), stridedLayout,
8169bb63374SLei Zhang                          memorySpace);
8177a1579acSMatthias Springer }
818f287da8aSMatthias Springer 
819f287da8aSMatthias Springer /// Return a MemRef type with a static identity layout (i.e., no layout map). If
820f287da8aSMatthias Springer /// the given tensor type is unranked, return an unranked MemRef type.
821f287da8aSMatthias Springer BaseMemRefType
822f287da8aSMatthias Springer bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
8239bb63374SLei Zhang                                                      Attribute memorySpace) {
824f287da8aSMatthias Springer   // Case 1: Unranked memref type.
825c1fa60b4STres Popp   if (auto unrankedTensorType =
826c1fa60b4STres Popp           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
827f287da8aSMatthias Springer     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
828f287da8aSMatthias Springer                                    memorySpace);
829f287da8aSMatthias Springer   }
830f287da8aSMatthias Springer 
831f287da8aSMatthias Springer   // Case 2: Ranked memref type.
832c1fa60b4STres Popp   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
833f287da8aSMatthias Springer   MemRefLayoutAttrInterface layout = {};
834f287da8aSMatthias Springer   return MemRefType::get(rankedTensorType.getShape(),
835f287da8aSMatthias Springer                          rankedTensorType.getElementType(), layout,
8369bb63374SLei Zhang                          memorySpace);
837f287da8aSMatthias Springer }
838f7f0c7f7SMatthias Springer 
839f3483c23SMatthias Springer //===----------------------------------------------------------------------===//
840f3483c23SMatthias Springer // Default implementations of interface methods
841f3483c23SMatthias Springer //===----------------------------------------------------------------------===//
842f3483c23SMatthias Springer 
843f3483c23SMatthias Springer bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
844f3483c23SMatthias Springer     OpResult opResult, const AnalysisState &state) {
845f3483c23SMatthias Springer   auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
8461ac248e4SMatthias Springer   AliasingOpOperandList opOperands =
8471ac248e4SMatthias Springer       bufferizableOp.getAliasingOpOperands(opResult, state);
848f3483c23SMatthias Springer 
849f3483c23SMatthias Springer   // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
850f3483c23SMatthias Springer   // memory writes.
8519fa6b350SMatthias Springer   if (opOperands.getAliases().empty())
852f3483c23SMatthias Springer     return true;
853f3483c23SMatthias Springer 
854f3483c23SMatthias Springer   // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult
855f3483c23SMatthias Springer   // may bufferize to a memory write.
8569fa6b350SMatthias Springer   if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
8579fa6b350SMatthias Springer         return state.bufferizesToMemoryWrite(*alias.opOperand);
858f3483c23SMatthias Springer       }))
859f3483c23SMatthias Springer     return true;
860f3483c23SMatthias Springer 
861f3483c23SMatthias Springer   // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
862f3483c23SMatthias Springer   // write. (Or: The reverse SSA use-def chain ends inside the reigon.) In that
863f3483c23SMatthias Springer   // case, the OpResult bufferizes to a memory write. E.g.:
864f3483c23SMatthias Springer   //
865f3483c23SMatthias Springer   // %0 = "some_writing_op" : tensor<?xf32>
866f3483c23SMatthias Springer   // %r = scf.if ... -> tensor<?xf32> {
867f3483c23SMatthias Springer   //   scf.yield %0 : tensor<?xf32>
868f3483c23SMatthias Springer   // } else {
869f3483c23SMatthias Springer   //   %1 = "another_writing_op"(%0) : tensor<?xf32>
870f3483c23SMatthias Springer   //   scf.yield %1 : tensor<?xf32>
871f3483c23SMatthias Springer   // }
872f3483c23SMatthias Springer   // "some_reading_op"(%r)
873f3483c23SMatthias Springer   //
874f3483c23SMatthias Springer   // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
875f3483c23SMatthias Springer   // bufferizes to a memory write and the defining op is inside the scf.if.
876f3483c23SMatthias Springer   //
877f3483c23SMatthias Springer   // Note: This treatment of surrouding ops is useful for ops that have a
878f3483c23SMatthias Springer   // region but no OpOperand such as scf.if or scf.execute_region. It simplifies
879f3483c23SMatthias Springer   // the analysis considerably.
880f3483c23SMatthias Springer   //
881f3483c23SMatthias Springer   // "another_writing_op" in the above example should be able to bufferize
882f3483c23SMatthias Springer   // inplace in the absence of another read of %0. However, if the scf.if op
883f3483c23SMatthias Springer   // would not be considered a "write", the analysis would detect the
884f3483c23SMatthias Springer   // following conflict:
885f3483c23SMatthias Springer   //
886f3483c23SMatthias Springer   // * read = some_reading_op
887f3483c23SMatthias Springer   // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
888f3483c23SMatthias Springer   // * conflictingWrite = %1
889f3483c23SMatthias Springer   //
890f3483c23SMatthias Springer   auto isMemoryWriteInsideOp = [&](Value v) {
891f3483c23SMatthias Springer     Operation *op = getOwnerOfValue(v);
892f3483c23SMatthias Springer     if (!opResult.getDefiningOp()->isAncestor(op))
893f3483c23SMatthias Springer       return false;
894f3483c23SMatthias Springer     return state.bufferizesToMemoryWrite(v);
895f3483c23SMatthias Springer   };
8961f479c1eSMatthias Springer   TraversalConfig config;
8971f479c1eSMatthias Springer   config.alwaysIncludeLeaves = false;
8989fa6b350SMatthias Springer   for (AliasingOpOperand alias : opOperands) {
899f3483c23SMatthias Springer     if (!state
900*d9111f19SAmir Bishara              .findValueInReverseUseDefChain(alias.opOperand,
9011f479c1eSMatthias Springer                                             isMemoryWriteInsideOp, config)
902f3483c23SMatthias Springer              .empty())
903f3483c23SMatthias Springer       return true;
904f3483c23SMatthias Springer   }
905f3483c23SMatthias Springer   return false;
906f3483c23SMatthias Springer }
907f3483c23SMatthias Springer 
908a02ad6c1SMatthias Springer // Compute the AliasingOpOperandList for a given Value based on
909a02ad6c1SMatthias Springer // getAliasingValues.
9101ac248e4SMatthias Springer AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
911a02ad6c1SMatthias Springer     Value value, const AnalysisState &state) {
912a02ad6c1SMatthias Springer   Operation *op = getOwnerOfValue(value);
9139fa6b350SMatthias Springer   SmallVector<AliasingOpOperand> result;
9141ac248e4SMatthias Springer   for (OpOperand &opOperand : op->getOpOperands()) {
915c1fa60b4STres Popp     if (!llvm::isa<TensorType>(opOperand.get().getType()))
9161ac248e4SMatthias Springer       continue;
917a02ad6c1SMatthias Springer     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
918a02ad6c1SMatthias Springer     for (const auto &it : aliasingValues)
919a02ad6c1SMatthias Springer       if (it.value == value)
9209fa6b350SMatthias Springer         result.emplace_back(&opOperand, it.relation, it.isDefinite);
9211ac248e4SMatthias Springer   }
9229fa6b350SMatthias Springer   return AliasingOpOperandList(std::move(result));
9231ac248e4SMatthias Springer }
9241ac248e4SMatthias Springer 
925f3483c23SMatthias Springer FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
926f3483c23SMatthias Springer     Value value, const BufferizationOptions &options,
927878950b8SMatthias Springer     SmallVector<Value> &invocationStack) {
928c1fa60b4STres Popp   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
929f3483c23SMatthias Springer 
930f3483c23SMatthias Springer   // No further analysis is possible for a block argument.
931c1fa60b4STres Popp   if (llvm::isa<BlockArgument>(value))
932f3483c23SMatthias Springer     return bufferization::getMemRefType(value, options);
933f3483c23SMatthias Springer 
934f3483c23SMatthias Springer   // Value is an OpResult.
935f3483c23SMatthias Springer   Operation *op = getOwnerOfValue(value);
936c1fa60b4STres Popp   auto opResult = llvm::cast<OpResult>(value);
937f3483c23SMatthias Springer   AnalysisState state(options);
9381ac248e4SMatthias Springer   AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
9399fa6b350SMatthias Springer   if (aliases.getNumAliases() > 0 &&
9409fa6b350SMatthias Springer       aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
941f3483c23SMatthias Springer     // If the OpResult has an equivalent OpOperand, both OpResult and
942f3483c23SMatthias Springer     // OpOperand bufferize to the exact same buffer type.
9439fa6b350SMatthias Springer     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
944878950b8SMatthias Springer     return getBufferType(equivalentOperand, options, invocationStack);
945f3483c23SMatthias Springer   }
946f3483c23SMatthias Springer 
947f3483c23SMatthias Springer   // If we do not know the memory space and there is no default memory space,
948f3483c23SMatthias Springer   // report a failure.
949067d2779Sian Bearman   auto memSpace =
950a5757c5bSChristian Sigg       options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
951067d2779Sian Bearman   if (!memSpace.has_value())
952f3483c23SMatthias Springer     return op->emitError("could not infer memory space");
953f3483c23SMatthias Springer 
954067d2779Sian Bearman   return getMemRefType(value, options, /*layout=*/{}, *memSpace);
955f3483c23SMatthias Springer }
956f3483c23SMatthias Springer 
957f7f0c7f7SMatthias Springer bool bufferization::detail::defaultIsRepetitiveRegion(
958f7f0c7f7SMatthias Springer     BufferizableOpInterface bufferizableOp, unsigned index) {
959f7f0c7f7SMatthias Springer   assert(index < bufferizableOp->getNumRegions() && "invalid region index");
960f7f0c7f7SMatthias Springer   auto regionInterface =
961f7f0c7f7SMatthias Springer       dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
962f7f0c7f7SMatthias Springer   if (!regionInterface)
963f7f0c7f7SMatthias Springer     return false;
964f7f0c7f7SMatthias Springer   return regionInterface.isRepetitiveRegion(index);
965f7f0c7f7SMatthias Springer }
966f3483c23SMatthias Springer 
9671ac248e4SMatthias Springer AliasingOpOperandList
968a02ad6c1SMatthias Springer bufferization::detail::unknownGetAliasingOpOperands(Value value) {
969a02ad6c1SMatthias Springer   // TODO: Take into account successor blocks.
970a02ad6c1SMatthias Springer   // No aliasing in case of non-entry blocks.
971a02ad6c1SMatthias Springer   if (auto bbArg = dyn_cast<BlockArgument>(value))
972a02ad6c1SMatthias Springer     if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
973a02ad6c1SMatthias Springer       return {};
974a02ad6c1SMatthias Springer 
975a02ad6c1SMatthias Springer   // Unknown op: Conservatively assume that each OpResult may alias with every
976a02ad6c1SMatthias Springer   // OpOperand. In addition, each block argument of an entry block may alias
977a02ad6c1SMatthias Springer   // with every OpOperand.
9781ac248e4SMatthias Springer   AliasingOpOperandList r;
979a02ad6c1SMatthias Springer   for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
980a02ad6c1SMatthias Springer     if (isa<TensorType>(operand.get().getType()))
9819fa6b350SMatthias Springer       r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
982f3483c23SMatthias Springer   return r;
983f3483c23SMatthias Springer }
984f3483c23SMatthias Springer 
985a02ad6c1SMatthias Springer AliasingValueList
986a02ad6c1SMatthias Springer bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
987a02ad6c1SMatthias Springer   // TODO: Take into account successor blocks.
988a02ad6c1SMatthias Springer   // Unknown op: Conservatively assume that each OpResult may alias with every
989a02ad6c1SMatthias Springer   // OpOperand. In addition, each block argument of an entry block may alias
990a02ad6c1SMatthias Springer   // with every OpOperand.
991a02ad6c1SMatthias Springer   AliasingValueList r;
992f3483c23SMatthias Springer   for (OpResult result : opOperand.getOwner()->getOpResults())
993c1fa60b4STres Popp     if (llvm::isa<TensorType>(result.getType()))
9949fa6b350SMatthias Springer       r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
995a02ad6c1SMatthias Springer   for (Region &region : opOperand.getOwner()->getRegions())
996a02ad6c1SMatthias Springer     if (!region.getBlocks().empty())
997a02ad6c1SMatthias Springer       for (BlockArgument bbArg : region.getBlocks().front().getArguments())
998a5757c5bSChristian Sigg         if (isa<TensorType>(bbArg.getType()))
999a02ad6c1SMatthias Springer           r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
1000f3483c23SMatthias Springer   return r;
1001f3483c23SMatthias Springer }
10028f2d83daSMatthias Springer 
10038f2d83daSMatthias Springer bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
10048f2d83daSMatthias Springer   auto isaTensor = [](Type t) { return isa<TensorType>(t); };
10058f2d83daSMatthias Springer   bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
10068f2d83daSMatthias Springer     return any_of(r.getBlocks(), [&](Block &b) {
10078f2d83daSMatthias Springer       return any_of(b.getArguments(), [&](BlockArgument bbArg) {
10088f2d83daSMatthias Springer         return isaTensor(bbArg.getType());
10098f2d83daSMatthias Springer       });
10108f2d83daSMatthias Springer     });
10118f2d83daSMatthias Springer   });
10128f2d83daSMatthias Springer   if (hasTensorBlockArgument)
10138f2d83daSMatthias Springer     return true;
10148f2d83daSMatthias Springer 
10158f2d83daSMatthias Springer   if (any_of(op->getResultTypes(), isaTensor))
10168f2d83daSMatthias Springer     return true;
10178f2d83daSMatthias Springer   return any_of(op->getOperandTypes(), isaTensor);
10188f2d83daSMatthias Springer }
1019