xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
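
// Example (illustrative, not from the original source): `castBuffer` is
// typically used to relax a static layout to the fully dynamic one expected
// by a loop iter_arg:
//
//   %cast = memref.cast %buf
//       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
//
// The opposite direction (casting to a more specific layout) is not
// cast-compatible in general; that is what the TODO above refers to.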

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}
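
// Example (illustrative, "test.produce" is a placeholder op): given
//
//   %0 = "test.produce"() : () -> tensor<5xf32>
//   scf.for ... iter_args(%arg0 = %0) -> (tensor<5xf32>) {
//     scf.yield %0 : tensor<5xf32>
//   }
//
// the yielded value aliases %0, which is defined outside of the loop region,
// so `doesNotAliasExternalValue` returns "false" for it (unless %0 is listed
// in `exceptions`).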

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
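
// Example (illustrative): each tensor argument of scf.condition is replaced
// by a buffer whose type matches the bufferized type of the corresponding
// "after" region bbArg, inserting a cast if needed:
//
//   scf.condition(%cond) %t : tensor<5xf32>
//
// may become (buffer type assumed for illustration):
//
//   scf.condition(%cond) %buf : memref<5xf32, strided<[?], offset: ?>>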

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region ops with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};
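
// Example (illustrative, buffer types assumed): an scf.execute_region that
// yields a tensor is rebuilt to yield a memref, and the old tensor uses are
// reconnected through bufferization.to_tensor:
//
//   %r = scf.execute_region -> tensor<5xf32> {
//     ...
//     scf.yield %t : tensor<5xf32>
//   }
//
// becomes
//
//   %m = scf.execute_region -> memref<5xf32> {
//     ...
//     scf.yield %b : memref<5xf32>
//   }
//   %r = bufferization.to_tensor %m : memref<5xf32>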

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
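
// Example (illustrative): if the "then" branch yields a buffer of type
// memref<5xf32> and the "else" branch yields
// memref<5xf32, strided<[1], offset: 2>>, the result type is promoted to the
// fully dynamic layout memref<5xf32, strided<[?], offset: ?>>, and both
// yielded buffers are cast to it when the scf.yield ops are bufferized.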

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    SmallVector<BaseMemRefType> yieldedTypes;
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: The case has the exact same buffer type as the cases
      // processed so far.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
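
// Example (illustrative, buffer type assumed): a bufferized loop bbArg %arg of
// type memref<5xf32, strided<[?], offset: ?>> is wrapped as
//
//   %t = bufferization.to_tensor %arg : memref<5xf32, strided<[?], offset: ?>>
//
// so that the (still tensor-based) loop body can be merged into the new loop
// without any further changes.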

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
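
// Example (illustrative): if the init_arg bufferizes to memref<5xf32> but the
// loop body yields a buffer with a different (e.g., strided) layout, the
// iter_arg is given the fully dynamic layout
// memref<5xf32, strided<[?], offset: ?>>, and the init_arg is cast to that
// type via `castBuffer` before entering the loop.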

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
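
// Example (illustrative): `scf.for %i = %c0 to %c0 step %c1` has zero
// iterations. A loop whose bounds are not compile-time constants is
// conservatively assumed to possibly have zero iterations as well.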

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants ||
        state.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg, but not
    // with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
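
// Example (illustrative, with a hypothetical "test.update" op and assumed
// buffer types): a tensor loop such as
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<5xf32>) {
//     %u = "test.update"(%t) : (tensor<5xf32>) -> tensor<5xf32>
//     scf.yield %u : tensor<5xf32>
//   }
//
// bufferizes to a loop over memrefs, with to_tensor ops bridging the not yet
// bufferized body:
//
//   %m = scf.for %i = %lb to %ub step %s iter_args(%b = %buf)
//       -> (memref<5xf32, strided<[?], offset: ?>>) {
//     %bt = bufferization.to_tensor %b : memref<5xf32, strided<[?], offset: ?>>
//     ...
//     scf.yield %b2 : memref<5xf32, strided<[?], offset: ?>>
//   }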

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants ||
        state.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
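
// Example (illustrative, buffer types assumed): after bufferization, both
// regions of the scf.while carry memref bbArgs, and casts may have been
// inserted at scf.condition / scf.yield where layouts differ:
//
//   scf.while (%b = %init)
//       : (memref<5xf32, strided<[?], offset: ?>>)
//       -> memref<5xf32, strided<[?], offset: ?>> {
//     ...
//     scf.condition(%cond) %b : memref<5xf32, strided<[?], offset: ?>>
//   } do {
//   ^bb0(%b2: memref<5xf32, strided<[?], offset: ?>>):
//     ...
//     scf.yield %b3 : memref<5xf32, strided<[?], offset: ?>>
//   }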

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are used only during analysis, not
/// during bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop may have zero iterations, the results of the op are equal
    // to their corresponding shared_outs, meaning that the shared_outs
    // bufferize to a read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp itself does not bufferize to a memory read, but one of
    // the uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // The output operands of an scf::ForallOp are always considered a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments: wrap each output buffer in a
    // to_tensor op and replace all uses of the matching bbArg.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (auto [bbArg, buffer] : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create a new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    auto newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over the block contents of the old op. The new op has no tensor
    // bbArgs; the corresponding replacement entries are left empty.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses with the output buffers.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // The region is repetitive if the body may execute more than once, i.e.,
    // if for some dimension lb + step < ub. Conservatively return "true" if
    // any lower bound, upper bound, or step is dynamic.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
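
// A minimal usage sketch (hypothetical client code, not part of this file):
// clients register the external models on a DialectRegistry before creating
// the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);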