xref: /llvm-project/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
1 //===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/IR/Value.h"
20 #include "mlir/Interfaces/ControlFlowInterfaces.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/Debug.h"
23 
24 //===----------------------------------------------------------------------===//
25 // BufferizableOpInterface
26 //===----------------------------------------------------------------------===//
27 
28 namespace mlir {
29 namespace bufferization {
30 
31 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
32 
33 } // namespace bufferization
34 } // namespace mlir
35 
36 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
37 
38 #define DEBUG_TYPE "bufferizable-op-interface"
39 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
40 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
41 
42 using namespace mlir;
43 using namespace bufferization;
44 
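// For illustration: the body region of an scf.for loop is a repetitive region
// (it may execute more than once per invocation of the parent op), whereas the
// "then" region of an scf.if is not. Which regions qualify is ultimately
// decided by the BufferizableOpInterface implementation queried below.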
45 static bool isRepetitiveRegion(Region *region,
46                                const BufferizationOptions &options) {
47   Operation *op = region->getParentOp();
48   if (auto bufferizableOp = options.dynCastBufferizableOp(op))
49     if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
50       return true;
51   return false;
52 }
53 
54 Region *AnalysisState::getEnclosingRepetitiveRegion(
55     Operation *op, const BufferizationOptions &options) {
56   if (!op->getBlock())
57     return nullptr;
58   if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
59       iter != enclosingRepetitiveRegionCache.end())
60     return iter->second;
61   return enclosingRepetitiveRegionCache[op] =
62              getEnclosingRepetitiveRegion(op->getBlock(), options);
63 }
64 
65 Region *AnalysisState::getEnclosingRepetitiveRegion(
66     Value value, const BufferizationOptions &options) {
67   if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
68       iter != enclosingRepetitiveRegionCache.end())
69     return iter->second;
70 
71   Region *region = value.getParentRegion();
72   // Collect all regions that we visit along the way, since we only know
73   // which repetitive region to map them to once the walk has finished.
74   SmallVector<Region *> visitedRegions;
75   while (region) {
76     visitedRegions.push_back(region);
77     if (isRepetitiveRegion(region, options))
78       break;
79     region = region->getParentRegion();
80   }
81   enclosingRepetitiveRegionCache[value] = region;
82   for (Region *r : visitedRegions)
83     enclosingRepetitiveRegionCache[r] = region;
84   return region;
85 }
86 
87 Region *AnalysisState::getEnclosingRepetitiveRegion(
88     Block *block, const BufferizationOptions &options) {
89   if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
90       iter != enclosingRepetitiveRegionCache.end())
91     return iter->second;
92 
93   Region *region = block->getParent();
94   Operation *op = nullptr;
95   // Collect all regions that we visit along the way, since we only know
96   // which repetitive region to map them to once the walk has finished.
97   SmallVector<Region *> visitedRegions;
98   do {
99     op = region->getParentOp();
100     if (isRepetitiveRegion(region, options))
101       break;
102   } while ((region = op->getParentRegion()));
103 
104   enclosingRepetitiveRegionCache[block] = region;
105   for (Region *r : visitedRegions)
106     enclosingRepetitiveRegionCache[r] = region;
107   return region;
108 }
109 
110 void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }
111 
112 Region *bufferization::getNextEnclosingRepetitiveRegion(
113     Region *region, const BufferizationOptions &options) {
114   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
115   while ((region = region->getParentRegion())) {
116     if (isRepetitiveRegion(region, options))
117       break;
118   }
119   return region;
120 }
121 
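// For illustration (assuming the upstream scf interface implementations): the
// body of an scf.forall op is a parallel region, i.e., its iterations may
// execute concurrently, whereas the body of an scf.for op is repetitive but
// not parallel.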
122 Region *bufferization::getParallelRegion(Region *region,
123                                          const BufferizationOptions &options) {
124   while (region) {
125     auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
126     if (bufferizableOp &&
127         bufferizableOp.isParallelRegion(region->getRegionNumber())) {
128       assert(isRepetitiveRegion(region, options) &&
129              "expected that all parallel regions are also repetitive regions");
130       return region;
131     }
132     region = region->getParentRegion();
133   }
134   return nullptr;
135 }
136 
137 Operation *bufferization::getOwnerOfValue(Value value) {
138   if (auto opResult = llvm::dyn_cast<OpResult>(value))
139     return opResult.getDefiningOp();
140   return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
141 }
142 
143 /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
144 /// shaped value is copied. Otherwise, a tensor with undefined contents is
145 /// allocated.
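///
/// Sketch of the IR produced by this helper (the exact form may vary): for a
/// value %t of type tensor<?xf32> and `copy` == false, it emits roughly
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///   %a = bufferization.alloc_tensor(%d) : tensor<?xf32>
///
/// whereas with `copy` == true it emits
///
///   %a = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>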
146 FailureOr<Value> bufferization::allocateTensorForShapedValue(
147     OpBuilder &b, Location loc, Value shapedValue,
148     const BufferizationOptions &options, bool copy) {
149   Value tensor;
150   if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
151     tensor = shapedValue;
152   } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
153     tensor = b.create<ToTensorOp>(loc, shapedValue);
154   } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
155              llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
156     return getOwnerOfValue(shapedValue)
157         ->emitError("copying of unranked tensors is not implemented");
158   } else {
159     llvm_unreachable("expected RankedTensorType or MemRefType");
160   }
161   RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
162   SmallVector<Value> dynamicSizes;
163   if (!copy) {
164     // Compute the dynamic part of the shape.
165     // First try to query the shape via ReifyRankedShapedTypeOpInterface.
166     bool reifiedShapes = false;
167     if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
168         llvm::isa<OpResult>(shapedValue)) {
169       ReifiedRankedShapedTypeDims resultDims;
170       if (succeeded(
171               reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
172         reifiedShapes = true;
173         auto &shape =
174             resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
175         for (const auto &dim : enumerate(tensorType.getShape()))
176           if (ShapedType::isDynamic(dim.value()))
177             dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
178       }
179     }
180 
181     // If the shape could not be reified, create DimOps.
182     if (!reifiedShapes)
183       populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
184   }
185 
186   // Create AllocTensorOp.
187   auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
188                                                copy ? tensor : Value());
189 
190   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
191   if (copy)
192     return allocTensorOp.getResult();
193   FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
194   if (failed(copyBufferType))
195     return failure();
196   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
197   if (!memorySpace)
198     memorySpace = options.defaultMemorySpaceFn(tensorType);
199   if (memorySpace.has_value())
200     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
201   return allocTensorOp.getResult();
202 }
203 
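/// Sketch of the rewrite performed by this function (the analysis decides
/// which operands are actually out-of-place): if the destination operand of a
/// tensor.insert_slice must bufferize out-of-place, the operand is redirected
/// to a new allocation that copies the original destination, roughly
///
///   %copy = bufferization.alloc_tensor() copy(%dest) : tensor<...>
///   %r = tensor.insert_slice %src into %copy [...] : ...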
204 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
205     RewriterBase &rewriter, const AnalysisState &state) {
206   OpBuilder::InsertionGuard g(rewriter);
207   Operation *op = getOperation();
208   SmallVector<OpOperand *> outOfPlaceOpOperands;
209   DenseSet<OpOperand *> copiedOpOperands;
210   SmallVector<Value> outOfPlaceValues;
211   DenseSet<Value> copiedOpValues;
212 
213   // Find all out-of-place OpOperands.
214   for (OpOperand &opOperand : op->getOpOperands()) {
215     Type operandType = opOperand.get().getType();
216     if (!llvm::isa<TensorType>(operandType))
217       continue;
218     if (state.isInPlace(opOperand))
219       continue;
220     if (llvm::isa<UnrankedTensorType>(operandType))
221       return op->emitError("copying of unranked tensors is not implemented");
222 
223     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
224     if (aliasingValues.getNumAliases() == 1 &&
225         isa<OpResult>(aliasingValues.getAliases()[0].value) &&
226         !state.bufferizesToMemoryWrite(opOperand) &&
227         state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
228                 .getNumAliases() == 1 &&
229         !isa<UnrankedTensorType>(
230             aliasingValues.getAliases()[0].value.getType())) {
231       // The op itself does not write but may create exactly one alias. Instead
232       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
233       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
234       // where the result is usually a smaller part of the source). Do not apply
235       // this optimization if the OpResult is an unranked tensor (because those
236       // cannot be copied at the moment).
237       Value value = aliasingValues.getAliases()[0].value;
238       outOfPlaceValues.push_back(value);
239       if (!state.canOmitTensorCopy(opOperand))
240         copiedOpValues.insert(value);
241     } else {
242       // In all other cases, make a copy of the OpOperand.
243       outOfPlaceOpOperands.push_back(&opOperand);
244       if (!state.canOmitTensorCopy(opOperand))
245         copiedOpOperands.insert(&opOperand);
246     }
247   }
248 
249   // Insert copies of OpOperands.
250   rewriter.setInsertionPoint(op);
251   for (OpOperand *opOperand : outOfPlaceOpOperands) {
252     FailureOr<Value> copy = allocateTensorForShapedValue(
253         rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
254         copiedOpOperands.contains(opOperand));
255     if (failed(copy))
256       return failure();
257     rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
258   }
259 
260   // Insert copies of Values.
261   rewriter.setInsertionPointAfter(op);
262   for (Value value : outOfPlaceValues) {
263     FailureOr<Value> copy = allocateTensorForShapedValue(
264         rewriter, op->getLoc(), value, state.getOptions(),
265         copiedOpValues.count(value));
266     if (failed(copy))
267       return failure();
268     SmallVector<OpOperand *> uses = llvm::to_vector(
269         llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
270     for (OpOperand *use : uses) {
271       // Do not update the alloc_tensor op that we just created.
272       if (use->getOwner() == copy->getDefiningOp())
273         continue;
274       // tensor.dim ops may have been created to be used as alloc_tensor op
275       // dynamic extents. Do not update these either.
276       if (isa<tensor::DimOp>(use->getOwner()))
277         continue;
278       rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
279     }
280   }
281 
282   return success();
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // OpFilter
287 //===----------------------------------------------------------------------===//
288 
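// For illustration (hypothetical filter setup): with an ALLOW rule for the
// tensor dialect and a DENY rule for tensor.empty, tensor.insert_slice is
// allowed, tensor.empty is rejected, and ops from other dialects are rejected
// as well because at least one ALLOW rule exists.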
289 bool OpFilter::isOpAllowed(Operation *op) const {
290   // Allow or disallow the op according to the filter entries.
291   bool isAllowed = !hasAllowRule();
292   for (const Entry &entry : entries) {
293     bool filterResult = entry.fn(op);
294     switch (entry.type) {
295     case Entry::ALLOW:
296       isAllowed |= filterResult;
297       break;
298     case Entry::DENY:
299       if (filterResult)
300         // DENY filter matches. This op is no allowed. (Even if other ALLOW
301         // DENY filter matches. This op is not allowed. (Even if other ALLOW
302         return false;
303     };
304   }
305   return isAllowed;
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // BufferizationOptions
310 //===----------------------------------------------------------------------===//
311 
312 namespace {
313 
314 /// Default function arg type converter: Use a fully dynamic layout map.
315 BaseMemRefType
316 defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
317                                 func::FuncOp funcOp,
318                                 const BufferizationOptions &options) {
319   return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
320 }
321 /// Default unknown type converter: Use a fully dynamic layout map.
322 BaseMemRefType
323 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
324                             const BufferizationOptions &options) {
325   return getMemRefTypeWithFullyDynamicLayout(
326       llvm::cast<TensorType>(value.getType()), memorySpace);
327 }
328 
329 } // namespace
330 
331 // Default constructor for BufferizationOptions.
332 BufferizationOptions::BufferizationOptions()
333     : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
334       unknownTypeConverterFn(defaultUnknownTypeConverter) {}
335 
336 bool BufferizationOptions::isOpAllowed(Operation *op) const {
337   // Special case: If function boundary bufferization is deactivated, do not
338   // allow ops that belong to the `func` dialect.
339   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
340   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
341     return false;
342 
343   return opFilter.isOpAllowed(op);
344 }
345 
346 BufferizableOpInterface
347 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
348   if (!isOpAllowed(op))
349     return nullptr;
350   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
351   if (!bufferizableOp)
352     return nullptr;
353   return bufferizableOp;
354 }
355 
356 BufferizableOpInterface
357 BufferizationOptions::dynCastBufferizableOp(Value value) const {
358   return dynCastBufferizableOp(getOwnerOfValue(value));
359 }
360 
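/// Illustration of the two layout choices configured below: for a function
/// argument of type tensor<?xf32>, LayoutMapOption::IdentityLayoutMap yields
/// memref<?xf32>, whereas the fully dynamic layout yields
/// memref<?xf32, strided<[?], offset: ?>>.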
361 void BufferizationOptions::setFunctionBoundaryTypeConversion(
362     LayoutMapOption layoutMapOption) {
363   functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
364                                    func::FuncOp funcOp,
365                                    const BufferizationOptions &options) {
366     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
367       return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
368                                                                   memorySpace);
369     return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
370                                                               memorySpace);
371   };
372   inferFunctionResultLayout =
373       layoutMapOption == LayoutMapOption::InferLayoutMap;
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // Helper functions for BufferizableOpInterface
378 //===----------------------------------------------------------------------===//
379 
380 static void setInsertionPointAfter(OpBuilder &b, Value value) {
381   if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
382     b.setInsertionPointToStart(bbArg.getOwner());
383   } else {
384     b.setInsertionPointAfter(value.getDefiningOp());
385   }
386 }
387 
388 /// Determine which OpOperand* will alias with `value` if the op is bufferized
389 /// in place. Return all tensor OpOperand* if the op is not bufferizable.
390 AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
391   if (Operation *op = getOwnerOfValue(value))
392     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
393       return bufferizableOp.getAliasingOpOperands(value, *this);
394 
395   // The op is not bufferizable.
396   return detail::unknownGetAliasingOpOperands(value);
397 }
398 
399 /// Determine which Values will alias with `opOperand` if the op is bufferized
400 /// in place. Return all tensor Values if the op is not bufferizable.
401 AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
402   if (auto bufferizableOp =
403           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
404     return bufferizableOp.getAliasingValues(opOperand, *this);
405 
406   // The op is not bufferizable.
407   return detail::unknownGetAliasingValues(opOperand);
408 }
409 
410 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
411 /// op is not bufferizable.
412 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
413   if (auto bufferizableOp =
414           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
415     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
416 
417   // Unknown op that returns a tensor. The inplace analysis does not support it.
418   // Conservatively return true.
419   return true;
420 }
421 
422 /// Return true if `opOperand` bufferizes to a memory write. Return
423 /// `true` if the op is not bufferizable.
424 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
425   if (auto bufferizableOp =
426           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
427     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
428 
429   // Unknown op that returns a tensor. The inplace analysis does not support it.
430   // Conservatively return true.
431   return true;
432 }
433 
434 /// Return true if `opOperand` neither reads nor writes but bufferizes to an
435 /// alias. Return false if the op is not bufferizable.
436 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
437   if (auto bufferizableOp =
438           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
439     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
440 
441   // Unknown op that returns a tensor. The inplace analysis does not support it.
442   // Conservatively return false.
443   return false;
444 }
445 
446 bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
447   auto opResult = llvm::dyn_cast<OpResult>(value);
448   if (!opResult)
449     return true;
450   auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
451   if (!bufferizableOp)
452     return true;
453   return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
454 }
455 
456 /// Return true if the given value is read by an op that bufferizes to a memory
457 /// read. Also takes into account ops that create an alias but do not read by
458 /// themselves (e.g., ExtractSliceOp).
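///
/// Example: in
///
///   %0 = tensor.extract_slice %t[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
///   %1 = tensor.extract %0[%i] : tensor<4xf32>
///
/// %t is considered read: the extract_slice merely creates an alias, but the
/// subsequent tensor.extract reads through that alias.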
459 bool AnalysisState::isValueRead(Value value) const {
460   assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
461   SmallVector<OpOperand *> workingSet;
462   DenseSet<OpOperand *> visited;
463   for (OpOperand &use : value.getUses())
464     workingSet.push_back(&use);
465 
466   while (!workingSet.empty()) {
467     OpOperand *uMaybeReading = workingSet.pop_back_val();
468     if (!visited.insert(uMaybeReading).second)
469       continue;
470 
471     // Skip over all ops that neither read nor write (but create an alias).
472     if (bufferizesToAliasOnly(*uMaybeReading))
473       for (AliasingValue alias : getAliasingValues(*uMaybeReading))
474         for (OpOperand &use : alias.value.getUses())
475           workingSet.push_back(&use);
476     if (bufferizesToMemoryRead(*uMaybeReading))
477       return true;
478   }
479 
480   return false;
481 }
482 
483 // Starting from `opOperand`, follow the use-def chain in reverse, always
484 // selecting the aliasing OpOperands. Find and return Values for which
485 // `condition` evaluates to true. Uses of such matching Values are not
486 // traversed any further. If `visitedOpOperands` is non-null, all aliasing
487 // OpOperands visited during the traversal are collected in it.
488 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
489     OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
490     TraversalConfig config,
491     llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
492   llvm::DenseSet<Value> visited;
493   llvm::SetVector<Value> result, workingSet;
494   workingSet.insert(opOperand->get());
495 
496   if (visitedOpOperands)
497     visitedOpOperands->insert(opOperand);
498 
499   while (!workingSet.empty()) {
500     Value value = workingSet.pop_back_val();
501 
502     if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
503       // Stop traversal if value was already visited.
504       if (config.alwaysIncludeLeaves)
505         result.insert(value);
506       continue;
507     }
508     visited.insert(value);
509 
510     if (condition(value)) {
511       result.insert(value);
512       continue;
513     }
514 
515     if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
516       // Stop iterating if `followUnknownOps` is unset and the op is either
517       // not bufferizable or excluded in the OpFilter.
518       if (config.alwaysIncludeLeaves)
519         result.insert(value);
520       continue;
521     }
522 
523     AliasingOpOperandList aliases = getAliasingOpOperands(value);
524     if (aliases.getNumAliases() == 0) {
525       // The traversal ends naturally if there are no more OpOperands that
526       // could be followed.
527       if (config.alwaysIncludeLeaves)
528         result.insert(value);
529       continue;
530     }
531 
532     for (AliasingOpOperand a : aliases) {
533       if (config.followEquivalentOnly &&
534           a.relation != BufferRelation::Equivalent) {
535         // Stop iterating if `followEquivalentOnly` is set but the alias is not
536         // equivalent.
537         if (config.alwaysIncludeLeaves)
538           result.insert(value);
539         continue;
540       }
541 
542       if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
543         // Stop iterating if `followInPlaceOnly` is set but the alias is
544         // out-of-place.
545         if (config.alwaysIncludeLeaves)
546           result.insert(value);
547         continue;
548       }
549 
550       if (config.followSameTypeOrCastsOnly &&
551           a.opOperand->get().getType() != value.getType() &&
552           !value.getDefiningOp<CastOpInterface>()) {
553         // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
554         // has a different type and the op is not a cast.
555         if (config.alwaysIncludeLeaves)
556           result.insert(value);
557         continue;
558       }
559 
560       workingSet.insert(a.opOperand->get());
561       if (visitedOpOperands)
562         visitedOpOperands->insert(a.opOperand);
563     }
564   }
565 
566   return result;
567 }
568 
569 // Find the values that define the contents of the given operand's value.
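// For example (a sketch): given
//
//   %0 = tensor.empty() : tensor<10xf32>
//   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<10xf32>) -> tensor<10xf32>
//   %2 = tensor.extract_slice %1[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
//
// the definitions of an OpOperand that uses %2 are {%1}: the extract_slice only
// creates an alias, linalg.fill is the closest op that bufferizes to a memory
// write, and the traversal stops there without visiting %0.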
570 llvm::SetVector<Value>
571 AnalysisState::findDefinitions(OpOperand *opOperand) const {
572   TraversalConfig config;
573   config.alwaysIncludeLeaves = false;
574   return findValueInReverseUseDefChain(
575       opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
576       config);
577 }
578 
579 AnalysisState::AnalysisState(const BufferizationOptions &options)
580     : AnalysisState(options, TypeID::get<AnalysisState>()) {}
581 
582 AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
583     : options(options), type(type) {
584   for (const BufferizationOptions::AnalysisStateInitFn &fn :
585        options.stateInitializers)
586     fn(*this);
587 }
588 
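// For illustration: an operand whose tensor stems from a
// bufferization.alloc_tensor without a `copy` operand has undefined contents,
// so an analysis that tracks this fact can report that overwriting the operand
// in place requires no preceding copy.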
589 bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
590   // Do not copy if the tensor has undefined contents.
591   if (hasUndefinedContents(&opOperand))
592     return true;
593 
594   // Do not copy if the buffer of the tensor is entirely overwritten (with
595   // values that do not depend on the old tensor).
596   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
597     return true;
598 
599   // Do not copy if the tensor is never read.
600   AliasingValueList aliases = getAliasingValues(opOperand);
601   if (!bufferizesToMemoryRead(opOperand) &&
602       llvm::none_of(aliases,
603                     [&](AliasingValue a) { return isValueRead(a.value); }))
604     return true;
605 
606   // Default: Cannot omit the copy.
607   return false;
608 }
609 
610 bool AnalysisState::isInPlace(OpOperand &opOperand) const {
611   // ToMemrefOps are always in-place.
612   if (isa<ToMemrefOp>(opOperand.getOwner()))
613     return true;
614 
615   // In the absence of analysis information, OpOperands that bufferize to a
616   // memory write are out-of-place, i.e., an alloc and copy is inserted.
617   return !bufferizesToMemoryWrite(opOperand);
618 }
619 
620 bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
621   // In the absence of analysis information, we do not know if the values are
622   // equivalent. The conservative answer is "false".
623   return false;
624 }
625 
626 bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
627   // In the absence of analysis information, we do not know if the values may be
628   // aliasing. The conservative answer is "true".
629   return true;
630 }
631 
632 bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
633   // In the absence of analysis information, the conservative answer is "false".
634   return false;
635 }
636 
637 // bufferization.to_memref is not allowed to change the rank.
638 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
639 #ifndef NDEBUG
640   auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
641   assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
642                                    rankedTensorType.getRank()) &&
643          "to_memref would be invalid: mismatching ranks");
644 #endif
645 }
646 
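/// Sketch of what this helper produces (the exact memref type depends on the
/// options): for a value %t of type tensor<?xf32> that is not the result of a
/// bufferization.to_tensor op, it inserts roughly
///
///   %m = bufferization.to_memref %t : memref<?xf32, strided<[?], offset: ?>>
///
/// If %t is defined by a to_tensor op, the underlying memref is returned
/// directly and no new IR is created.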
647 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
648                                           const BufferizationOptions &options) {
649 #ifndef NDEBUG
650   auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
651   assert(tensorType && "unexpected non-tensor type");
652 #endif // NDEBUG
653 
654   // Replace "%t = to_tensor %m" with %m.
655   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
656     return toTensorOp.getMemref();
657 
658   // Insert to_memref op.
659   OpBuilder::InsertionGuard g(rewriter);
660   setInsertionPointAfter(rewriter, value);
661   FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
662   if (failed(memrefType))
663     return failure();
664   ensureToMemrefOpIsValid(value, *memrefType);
665   return rewriter
666       .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
667       .getResult();
668 }
669 
670 /// Return the buffer type for a given Value (tensor) after bufferization.
671 FailureOr<BaseMemRefType>
672 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
673   SmallVector<Value> invocationStack;
674   return getBufferType(value, options, invocationStack);
675 }
676 
677 /// Return the buffer type for a given Value (tensor) after bufferization.
678 FailureOr<BaseMemRefType>
679 bufferization::getBufferType(Value value, const BufferizationOptions &options,
680                              SmallVector<Value> &invocationStack) {
681   assert(llvm::isa<TensorType>(value.getType()) &&
682          "unexpected non-tensor type");
683   invocationStack.push_back(value);
684   auto popFromStack =
685       llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
686 
687   // Try querying BufferizableOpInterface.
688   Operation *op = getOwnerOfValue(value);
689   auto bufferizableOp = options.dynCastBufferizableOp(op);
690   if (bufferizableOp)
691     return bufferizableOp.getBufferType(value, options, invocationStack);
692 
693   // Op is not bufferizable.
694   auto memSpace =
695       options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
696   if (!memSpace.has_value())
697     return op->emitError("could not infer memory space");
698 
699   return getMemRefType(value, options, /*layout=*/{}, *memSpace);
700 }
701 
702 bool bufferization::hasTensorSemantics(Operation *op) {
703   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
704     return bufferizableOp.hasTensorSemantics();
705   return detail::defaultHasTensorSemantics(op);
706 }
707 
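/// Illustration (a sketch): replacing an op that produces %0 : tensor<4xf32>
/// with a bufferized value %m : memref<4xf32> inserts roughly
///
///   %repl = bufferization.to_tensor %m : memref<4xf32>
///
/// and redirects all uses of %0 to %repl. The to_tensor op is expected to fold
/// away once all of its users have been bufferized.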
708 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
709                                                   Operation *op,
710                                                   ValueRange values) {
711   assert(values.size() == op->getNumResults() &&
712          "expected one value per OpResult");
713   OpBuilder::InsertionGuard g(rewriter);
714 
715   // Replace all OpResults with the given values.
716   SmallVector<Value> replacements;
717   for (OpResult opResult : op->getOpResults()) {
718     Value replacement = values[opResult.getResultNumber()];
719     if (llvm::isa<TensorType>(opResult.getType())) {
720       // The OpResult is a tensor. Such values are replaced with memrefs during
721       // bufferization.
722       assert((llvm::isa<MemRefType>(replacement.getType()) ||
723               llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
724              "tensor op result should be replaced with a memref value");
725       // The existing uses of the OpResult still expect a tensor. Insert a
726       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
727       // lose all of its users and eventually DCE away.
728       rewriter.setInsertionPointAfter(op);
729       replacement = rewriter.create<bufferization::ToTensorOp>(
730           replacement.getLoc(), opResult.getType(), replacement);
731     }
732     replacements.push_back(replacement);
733   }
734 
735   rewriter.replaceOp(op, replacements);
736 }
737 
738 //===----------------------------------------------------------------------===//
739 // Bufferization-specific scoped alloc insertion support.
740 //===----------------------------------------------------------------------===//
741 
742 /// Create a memref allocation with the given type and dynamic extents.
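/// With the default lowering (no custom `allocationFn`), this produces, e.g.,
///
///   %buf = memref.alloc(%d) {alignment = 64 : i64} : memref<?xf32>
///
/// where the alignment attribute is attached only if `bufferAlignment` is
/// non-zero.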
743 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
744                                                    MemRefType type,
745                                                    ValueRange dynShape) const {
746   if (allocationFn)
747     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
748 
749   // Default buffer allocation via AllocOp.
750   if (bufferAlignment != 0)
751     return b
752         .create<memref::AllocOp>(loc, type, dynShape,
753                                  b.getI64IntegerAttr(bufferAlignment))
754         .getResult();
755   return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
756 }
757 
758 /// Create a memory copy between two memref buffers.
759 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
760                                                  Value from, Value to) const {
761   if (memCpyFn)
762     return (*memCpyFn)(b, loc, from, to);
763 
764   b.create<memref::CopyOp>(loc, from, to);
765   return success();
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // Bufferization-specific IRMapping support with debugging.
770 //===----------------------------------------------------------------------===//
771 
772 BaseMemRefType bufferization::getMemRefType(Value value,
773                                             const BufferizationOptions &options,
774                                             MemRefLayoutAttrInterface layout,
775                                             Attribute memorySpace) {
776   auto tensorType = llvm::cast<TensorType>(value.getType());
777 
778   // Case 1: Unranked memref type.
779   if (auto unrankedTensorType =
780           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
781     assert(!layout && "UnrankedTensorType cannot have a layout map");
782     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
783                                    memorySpace);
784   }
785 
786   // Case 2: Ranked memref type with specified layout.
787   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
788   if (layout) {
789     return MemRefType::get(rankedTensorType.getShape(),
790                            rankedTensorType.getElementType(), layout,
791                            memorySpace);
792   }
793 
794   return options.unknownTypeConverterFn(value, memorySpace, options);
795 }
796 
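/// Fully dynamic layout map: for example, tensor<4x?xf32> maps to
/// memref<4x?xf32, strided<[?, ?], offset: ?>>, which can represent any
/// strided in-memory view of a buffer with that shape.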
797 BaseMemRefType
798 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
799                                                    Attribute memorySpace) {
800   // Case 1: Unranked memref type.
801   if (auto unrankedTensorType =
802           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
803     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
804                                    memorySpace);
805   }
806 
807   // Case 2: Ranked memref type.
808   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
809   int64_t dynamicOffset = ShapedType::kDynamic;
810   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
811                                       ShapedType::kDynamic);
812   auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
813                                               dynamicOffset, dynamicStrides);
814   return MemRefType::get(rankedTensorType.getShape(),
815                          rankedTensorType.getElementType(), stridedLayout,
816                          memorySpace);
817 }
818 
819 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
820 /// the given tensor type is unranked, return an unranked MemRef type.
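/// For example, tensor<4x?xf32> maps to memref<4x?xf32> (identity layout) and
/// tensor<*xf32> maps to memref<*xf32>.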
821 BaseMemRefType
822 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
823                                                      Attribute memorySpace) {
824   // Case 1: Unranked memref type.
825   if (auto unrankedTensorType =
826           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
827     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
828                                    memorySpace);
829   }
830 
831   // Case 2: Ranked memref type.
832   auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
833   MemRefLayoutAttrInterface layout = {};
834   return MemRefType::get(rankedTensorType.getShape(),
835                          rankedTensorType.getElementType(), layout,
836                          memorySpace);
837 }
838 
839 //===----------------------------------------------------------------------===//
840 // Default implementations of interface methods
841 //===----------------------------------------------------------------------===//
842 
843 bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
844     OpResult opResult, const AnalysisState &state) {
845   auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
846   AliasingOpOperandList opOperands =
847       bufferizableOp.getAliasingOpOperands(opResult, state);
848 
849   // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
850   // memory writes.
851   if (opOperands.getAliases().empty())
852     return true;
853 
854   // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult
855   // may bufferize to a memory write.
856   if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
857         return state.bufferizesToMemoryWrite(*alias.opOperand);
858       }))
859     return true;
860 
861   // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
862   // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
863   // case, the OpResult bufferizes to a memory write. E.g.:
864   //
865   // %0 = "some_writing_op" : tensor<?xf32>
866   // %r = scf.if ... -> tensor<?xf32> {
867   //   scf.yield %0 : tensor<?xf32>
868   // } else {
869   //   %1 = "another_writing_op"(%0) : tensor<?xf32>
870   //   scf.yield %1 : tensor<?xf32>
871   // }
872   // "some_reading_op"(%r)
873   //
874   // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
875   // bufferizes to a memory write and the defining op is inside the scf.if.
876   //
877   // Note: This treatment of surrounding ops is useful for ops that have a
878   // region but no OpOperand such as scf.if or scf.execute_region. It simplifies
879   // the analysis considerably.
880   //
881   // "another_writing_op" in the above example should be able to bufferize
882   // inplace in the absence of another read of %0. However, if the scf.if op
883   // would not be considered a "write", the analysis would detect the
884   // following conflict:
885   //
886   // * read = some_reading_op
887   // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
888   // * conflictingWrite = %1
889   //
890   auto isMemoryWriteInsideOp = [&](Value v) {
891     Operation *op = getOwnerOfValue(v);
892     if (!opResult.getDefiningOp()->isAncestor(op))
893       return false;
894     return state.bufferizesToMemoryWrite(v);
895   };
896   TraversalConfig config;
897   config.alwaysIncludeLeaves = false;
898   for (AliasingOpOperand alias : opOperands) {
899     if (!state
900              .findValueInReverseUseDefChain(alias.opOperand,
901                                             isMemoryWriteInsideOp, config)
902              .empty())
903       return true;
904   }
905   return false;
906 }
907 
908 // Compute the AliasingOpOperandList for a given Value based on
909 // getAliasingValues.
910 AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
911     Value value, const AnalysisState &state) {
912   Operation *op = getOwnerOfValue(value);
913   SmallVector<AliasingOpOperand> result;
914   for (OpOperand &opOperand : op->getOpOperands()) {
915     if (!llvm::isa<TensorType>(opOperand.get().getType()))
916       continue;
917     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
918     for (const auto &it : aliasingValues)
919       if (it.value == value)
920         result.emplace_back(&opOperand, it.relation, it.isDefinite);
921   }
922   return AliasingOpOperandList(std::move(result));
923 }
924 
925 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
926     Value value, const BufferizationOptions &options,
927     SmallVector<Value> &invocationStack) {
928   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
929 
930   // No further analysis is possible for a block argument.
931   if (llvm::isa<BlockArgument>(value))
932     return bufferization::getMemRefType(value, options);
933 
934   // Value is an OpResult.
935   Operation *op = getOwnerOfValue(value);
936   auto opResult = llvm::cast<OpResult>(value);
937   AnalysisState state(options);
938   AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
939   if (aliases.getNumAliases() > 0 &&
940       aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
941     // If the OpResult has an equivalent OpOperand, both OpResult and
942     // OpOperand bufferize to the exact same buffer type.
943     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
944     return getBufferType(equivalentOperand, options, invocationStack);
945   }
946 
947   // If we do not know the memory space and there is no default memory space,
948   // report a failure.
949   auto memSpace =
950       options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
951   if (!memSpace.has_value())
952     return op->emitError("could not infer memory space");
953 
954   return getMemRefType(value, options, /*layout=*/{}, *memSpace);
955 }
956 
957 bool bufferization::detail::defaultIsRepetitiveRegion(
958     BufferizableOpInterface bufferizableOp, unsigned index) {
959   assert(index < bufferizableOp->getNumRegions() && "invalid region index");
960   auto regionInterface =
961       dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
962   if (!regionInterface)
963     return false;
964   return regionInterface.isRepetitiveRegion(index);
965 }
966 
967 AliasingOpOperandList
968 bufferization::detail::unknownGetAliasingOpOperands(Value value) {
969   // TODO: Take into account successor blocks.
970   // No aliasing in case of non-entry blocks.
971   if (auto bbArg = dyn_cast<BlockArgument>(value))
972     if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
973       return {};
974 
975   // Unknown op: Conservatively assume that each OpResult may alias with every
976   // OpOperand. In addition, each block argument of an entry block may alias
977   // with every OpOperand.
978   AliasingOpOperandList r;
979   for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
980     if (isa<TensorType>(operand.get().getType()))
981       r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
982   return r;
983 }
984 
985 AliasingValueList
986 bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
987   // TODO: Take into account successor blocks.
988   // Unknown op: Conservatively assume that each OpResult may alias with every
989   // OpOperand. In addition, each block argument of an entry block may alias
990   // with every OpOperand.
991   AliasingValueList r;
992   for (OpResult result : opOperand.getOwner()->getOpResults())
993     if (llvm::isa<TensorType>(result.getType()))
994       r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
995   for (Region &region : opOperand.getOwner()->getRegions())
996     if (!region.getBlocks().empty())
997       for (BlockArgument bbArg : region.getBlocks().front().getArguments())
998         if (isa<TensorType>(bbArg.getType()))
999           r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
1000   return r;
1001 }
1002 
1003 bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
1004   auto isaTensor = [](Type t) { return isa<TensorType>(t); };
1005   bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
1006     return any_of(r.getBlocks(), [&](Block &b) {
1007       return any_of(b.getArguments(), [&](BlockArgument bbArg) {
1008         return isaTensor(bbArg.getType());
1009       });
1010     });
1011   });
1012   if (hasTensorBlockArgument)
1013     return true;
1014 
1015   if (any_of(op->getResultTypes(), isaTensor))
1016     return true;
1017   return any_of(op->getOperandTypes(), isaTensor);
1018 }
1019