//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases.
//
// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
//    in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, this file also
// contains a helper function `runOneShotBufferize` that analyzes an op (and its
// nested ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
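//
// For illustration (a sketch; the op name "my_dialect.unknown_op" is
// hypothetical), an unknown op is left in tensor land and bridged at the
// boundary roughly like this:
//
//   %0 = bufferization.to_tensor %buf : memref<?xf32>
//   %1 = "my_dialect.unknown_op"(%0) : (tensor<?xf32>) -> tensor<?xf32>
//   %2 = bufferization.to_memref %1 : memref<?xf32>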
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnAllocs` is enabled.

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <optional>
#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)

// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return isa<TensorType>(t); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is stored
// in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is
// annotated with the results of the analysis, so that they can be checked in
// tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place.
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
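
// Example (a sketch; the exact values depend on the analysis results): when
// run with `testAnalysisOnly`, an op may be annotated as
//
//   %0 = tensor.insert %f into %t[%idx]
//       {__inplace_operands_attr__ = ["none", "true", "none"]}
//
// where "true"/"false" mark tensor operands that bufferize in-place or
// out-of-place, and "none" marks non-tensor operands.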

constexpr StringLiteral kOpResultAliasSetAttrName =
    "__opresult_alias_set_attr__";

constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";

/// Mark whether OpOperand will be bufferized inplace.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  SmallVector<StringRef> inPlaceVector;
  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
    inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }
  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceOperandsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Set up alias sets.
  op->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            createAliasInfoEntry(bbArg);
  });

  // Mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
          bufferizeInPlace(opOperand);
    return WalkResult::advance();
  });
}

void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

void OneShotAnalysisState::applyOnAliases(Value v,
                                          function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return equivalentInfo.isEquivalent(v1, v2);
}

bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  return aliasInfo.isEquivalent(v1, v2);
}

void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (inplaceBufferized.contains(&operand))
    return;
  inplaceBufferized.insert(&operand);
  for (AliasingValue alias : getAliasingValues(operand))
    aliasInfo.unionSets(alias.value, operand.get());
  ++statNumTensorInPlace;
}

void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}

void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}

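// Gather all uses of tensors with undefined contents. For example (a sketch),
// the result of a bufferization.alloc_tensor op without a `copy` operand has
// no preceding definition, so all of its uses read undefined contents:
//
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
//   "reading_op"(%0) : (tensor<?xf32>) -> ()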
void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *op) {
    // Skip unknown ops.
    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // If there is no preceding definition, the tensor contents are
      // undefined.
      if (opResult.getUses().empty())
        continue;
      // It does not matter which use we pick when searching for the value's
      // definitions.
      OpOperand *opOperand = &(*opResult.getUses().begin());
      if (findDefinitionsCached(opOperand).empty())
        for (OpOperand &use : opResult.getUses())
          undefinedTensorUses.insert(&use);
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.contains(opOperand);
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.contains(&opOperand);
}

bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool isWritten = false;
  applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (isInPlace(use) && bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized value could be considered writable.
  // Query BufferizableOpInterface to see if the BlockArgument is writable.
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.isWritable(value, *this);

  // Not a bufferizable op: The conservative answer is "not writable".
  return false;
}

void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
  aliasInfo.unionSets(v1, v2);
}

void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
  equivalentInfo.unionSets(v1, v2);
}

OneShotAnalysisState::Extension::~Extension() = default;

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write and was decided to
/// bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return state.isInPlace(opOperand);
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
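///
/// Example (a sketch):
///   op1
///   scf.if ... { op2 }
/// happensBefore(op1, op2) is true because op1 properly dominates op2.
/// happensBefore(op2, op1) is false because neither op2 nor its enclosing
/// scf.if properly dominates op1.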
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32>                                // DEF
/// "reading_op"(%0) : tensor<?xf32>                        // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// scf.for ... {
///   "reading_op"(%0) : tensor<?xf32>                        // READ
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///   ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict:        DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict:          DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
///   %0 = ... : tensor<?xf32>
///   "reading_op"(%0) : tensor<?xf32>
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///   ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
///   scf.for ... { "reading_op"(%0) }
///   %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor
///    of a repetitive region that encloses both READ and WRITE, we cannot rule
///    out a RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
///    analysis purposes, we can assume that there are no loops/repetitive
///    regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
///    or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
///
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  for (Value def : definitions) {
    Region *rRead =
        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (rRead == rDef)
      continue;

    // Find the enclosing repetitive region of READ that is closest to DEF but
    // not the repetitive region of DEF itself.
    while (true) {
      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
      if (nextRegion == rDef)
        break;
      assert(nextRegion && "expected to find another repetitive region");
      rRead = nextRegion;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive region.
    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
      return false;
  }

  return true;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
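///
/// Example (a sketch of an unstructured loop between two blocks):
///
///   ^bb1:
///     "reading_op"(%0) ...
///     cf.br ^bb2
///   ^bb2:
///     %1 = "writing_op"(%0) ...
///     cf.cond_br %c, ^bb1, ^bb3
///
/// block(READ) and block(WRITE) are mutually reachable without passing
/// through block(DEF), so op dominance cannot be used.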
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  // Fast path: If READ and WRITE are in different regions, their blocks cannot
  // reach each other via unstructured control flow alone. (Loops due to
  // regions are covered by `canUseOpDominanceDueToRegions`.)
  if (uRead->getOwner()->getParentRegion() !=
      uWrite->getOwner()->getParentRegion())
    return true;

  Block *readBlock = uRead->getOwner()->getBlock();
  Block *writeBlock = uWrite->getOwner()->getBlock();
  for (Value def : definitions) {
    Block *defBlock = def.getParentBlock();
    if (readBlock->isReachable(writeBlock, {defBlock}) &&
        writeBlock->isReachable(readBlock, {defBlock}))
      return false;
  }

  return true;
}

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
         canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    std::string defAttr =
        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(definition);
    std::string defAttr =
        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
  }
}
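
// Example of the annotations produced above (a sketch; the C_<n> prefix is
// unique per conflict):
//
//   %0 = "writing_op"(%t) {"C_0[CONFL-WRITE: 0]"} : ...
//   "reading_op"(%t) {"C_0[READ: 0]"} : ...
//
// All attributes that share the same prefix describe one RaW conflict.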

/// Return 'true' if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   OpOperand *start,
                                                   Value other) {
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followSameTypeOrCastsOnly = true;
  return !state
              .findValueInReverseUseDefChain(
                  start, [&](Value v) { return v == other; }, config)
              .empty();
}

/// Return "true" if the given operand's value is originating from a subset
/// that is equivalent to the subset that `subsetOp` inserts into.
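///
/// Example (a sketch): %1 below originates from the same subset of %t
/// ([%a][%c]) that the insert_slice writes into, so the lookup starting from
/// %1's use matches the insert destination:
///
///   %0 = tensor.extract_slice %t[%a][%c][1]
///   %1 = linalg.fill %cst, %0
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]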
static bool matchesInsertDestination(const AnalysisState &state,
                                     OpOperand *opOperand,
                                     SubsetInsertionOpInterface subsetOp) {
  auto matchingSubset = [&](Value val) {
    if (auto opResult = dyn_cast<OpResult>(val))
      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
            return state.areEquivalentBufferizedValues(v1, v2);
          }))
        return true;
    return false;
  };
  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}

/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }

    if (uRead == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uConflictingWrite, subsetOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &subsetOp.getSourceOperand() &&
        uConflictingWrite == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead, subsetOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto subsetOp =
          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // definition        = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
        state.areEquivalentBufferizedValues(
            uRead->get(), subsetOp.getSourceOperand().get()) &&
        matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
      return true;

  return false;
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a definition W1. But because of bufferization decisions, R
/// actually reads another definition W2.
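///
/// Example (a sketch): if %0 and %1 below alias the same buffer and both
/// writes bufferize in-place, R reads the effect of W2 instead of W1:
///
///   %0 = "W1"(%t) : tensor<?xf32> -> tensor<?xf32>
///   %1 = "W2"(%t) : tensor<?xf32> -> tensor<?xf32>
///   "R"(%0) : tensor<?xf32> -> ()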
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo,
                              OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Before going through the main RaW analysis, find cases where a buffer must
  // be privatized due to parallelism. If the result of a write is never read,
  // privatization is not necessary (and large parts of the IR are likely dead).
  if (options.checkParallelRegions && !usesRead.empty()) {
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Find the allocation point or last write (definition) of the buffer.
      // Note: In contrast to `findDefinitions`, this also returns results of
      // ops that do not bufferize to a memory write when no other definition
      // could be found. E.g., "bufferization.alloc_tensor" would be included,
      // even though that op just bufferizes to an allocation and does not
      // define the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
            return state.bufferizesToMemoryWrite(v);
          });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");

      // The writing op must bufferize out-of-place if the definition is in a
      // different parallel region than this write.
      for (Value def : definitionsOrLeaves) {
        if (getParallelRegion(def.getParentRegion(), options) !=
            getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
                              options)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "\n- bufferizes out-of-place due to parallel region:\n");
          LLVM_DEBUG(llvm::dbgs()
                     << "  uConflictingWrite = operand "
                     << uConflictingWrite->getOperandNumber() << " of "
                     << *uConflictingWrite->getOwner() << "\n");
          return true;
        }
      }
    }
  }

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();
    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
                            << " of " << *readingOp << "\n");

    // Find the definition of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, the
    // definition is %0. Note that operations that create an alias but do not
    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
    const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
    if (definitions.empty()) {
      // Fast path: No conflict if there are no definitions.
      LLVM_DEBUG(llvm::dbgs()
                 << "  no conflict: read value has no definitions\n");
      continue;
    }

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      LLVM_DEBUG(llvm::dbgs() << "  uConflictingWrite = operand "
                              << uConflictingWrite->getOperandNumber() << " of "
                              << *uConflictingWrite->getOwner() << "\n");

      // Check if op dominance can be used to rule out read-after-write
      // conflicts.
      bool useDominance =
          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Inside of repetitive regions, ops may be executed multiple times and op
      // dominance cannot be used to rule out conflicts.
      if (useDominance) {
        // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
        // the write is not visible when reading.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        //       inside a loop), there may be no meaningful `happensBefore`
        //       relationship.
        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read happens before write\n");
          continue;
        }

        // No conflict if the reading use equals the use of the conflicting
        // write. A use cannot conflict with itself.
        //
        // Note: Just being the same op is not enough. It has to be the same
        //       use.
        // Note: If the op is executed multiple times (e.g., because it is
        //       inside a loop), it may be conflicting with itself.
        if (uConflictingWrite == uRead) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read and write are same use\n");
          continue;
        }

        // Ops are not conflicting if they are in mutually exclusive regions.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        //       inside a loop), mutually exclusive regions may be executed
        //       multiple times.
        if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
          LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
                                     "mutually exclusive regions\n");
          continue;
        }

        // Two equivalent operands of the same op are not conflicting if the op
        // bufferizes to element-wise access. I.e., all loads at a position
        // happen before all stores to the same position.
        if (conflictingWritingOp == readingOp) {
          if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
            if (bufferizableOp.bufferizesToElementwiseAccess(
                    state, {uRead, uConflictingWrite})) {
              if (hasEquivalentValueInReverseUseDefChain(
                      state, uRead, uConflictingWrite->get()) ||
                  hasEquivalentValueInReverseUseDefChain(
                      state, uConflictingWrite, uRead->get())) {
                LLVM_DEBUG(
                    llvm::dbgs()
                    << "  no conflict: op bufferizes to element-wise access\n");
                continue;
              }
            }
          }
        }
      }

      // No conflict if the operands are non-conflicting subsets.
      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
        continue;
      }

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: op interface of reading op says 'no'\n");
          continue;
        }
      }

      if (conflictingWritingOp != readingOp) {
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp)) {
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
                                              state)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: op interface of writing op says 'no'\n");
            continue;
          }
        }
      }

      // Check all possible definitions.
      for (Value definition : definitions) {
        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");

        // No conflict if the conflicting write happens before the definition.
        if (Operation *defOp = definition.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
            // conflictingWritingOp happens before defOp. No conflict.
            LLVM_DEBUG(llvm::dbgs()
                       << "    no conflict: write happens before definition\n");
            continue;
          }
          // No conflict if conflictingWritingOp is contained in defOp.
          if (defOp->isProperAncestor(conflictingWritingOp)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "    no conflict: write is contained in definition\n");
            continue;
          }
        } else {
          auto bbArg = cast<BlockArgument>(definition);
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
            LLVM_DEBUG(llvm::dbgs() << "    no conflict: definition is bbArg "
                                       "and write happens outside of block\n");
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
          }
        }

        // No conflict if the conflicting write is the definition itself, i.e.,
        // its only aliasing value is the definition.
        AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
        if (aliases.getNumAliases() == 1 &&
            aliases.getAliases()[0].value == definition) {
          LLVM_DEBUG(llvm::dbgs()
                     << "    no conflict: definition and write are same\n");
          continue;
        }

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, definition);
        LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses()) {
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use)) {
        res.insert(&use);
        continue;
      }

      // Read of a dependent value in the SSA use-def chain. E.g.:
      //
      // %0 = ...
      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
      // "read"(%1)
      //
      // In the above example, getAliasingReads(%0) includes the first OpOperand
      // of the tensor.extract_slice op. The extract_slice itself does not read
      // but its aliasing result is eventually fed into an op that does.
      //
      // Note: This is considered a "read" only if the use does not bufferize to
      // a memory write. (We already ruled out memory reads. In case of a memory
      // write, the buffer would be entirely overwritten; in the above example
      // there would then be no flow of data from the extract_slice operand to
      // its result's uses.)
      if (!state.bufferizesToMemoryWrite(use)) {
        AliasingValueList aliases = state.getAliasingValues(use);
        if (llvm::any_of(aliases, [&](AliasingValue a) {
              return state.isValueRead(a.value);
            }))
          res.insert(&use);
      }
    }
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict. A read
/// R and a write W on the same alias set form a conflict if inplace
/// bufferization of W changes the value read by R to a value different from the
/// one that would be expected by tracing back R's origin through SSA use-def
/// chains. A conflict can only be introduced by a new alias and/or an inplace
/// bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the result of the first one. Therefore, the
///   TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), state);
  getAliasingInplaceWrites(usesWrite, operand.get(), state);
  for (AliasingValue alias : state.getAliasingValues(operand)) {
    getAliasingReads(usesRead, alias.value, state);
    getAliasingInplaceWrites(usesWrite, alias.value, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}

/// Annotate IR with details about the detected non-writability conflict.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  std::string id = "W_" + std::to_string(counter++);
  if (auto opResult = dyn_cast<OpResult>(value)) {
    std::string attr = id + "[NOT-WRITABLE: result " +
                       std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(value);
    std::string attr = id + "[NOT-WRITABLE: bbArg " +
                       std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
  }
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
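///
/// Example (a sketch): if %t is not writable (e.g., a function argument that
/// must not be modified), bufferizing the write below in-place would be
/// illegal:
///
///   %0 = vector.transfer_write %v, %t[%c0] : vector<5xf32>, tensor<?xf32>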
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  bool foundWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  if (!foundWrite) {
    // Collect writes of all aliases of OpOperand and OpResult.
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    foundWrite = !usesWrite.empty();
  }

  if (!foundWrite)
    return false;

  // Look for a read-only tensor among all aliases.
  bool foundReadOnly = false;
  auto checkReadOnly = [&](Value v) {
    if (!state.isWritable(v)) {
      foundReadOnly = true;
      if (state.getOptions().printConflicts)
        annotateNonWritableTensor(v);
    }
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);
  if (foundReadOnly) {
    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
  Value value = opOperand->get();
  if (!cachedDefinitions.count(value))
    cachedDefinitions[value] = findDefinitions(opOperand);
  return cachedDefinitions[value];
}

void OneShotAnalysisState::resetCache() {
  AnalysisState::resetCache();
  cachedDefinitions.clear();
}

/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state);

  if (foundInterference)
    state.bufferizeOutOfPlace(operand);
  else
    state.bufferizeInPlace(operand);

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}

LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands())
    if (isa<TensorType>(opOperand.get().getType()))
      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
        return failure();
  return success();
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
      for (OpResult opResult : op->getOpResults()) {
        if (!isa<TensorType>(opResult.getType()))
          continue;
        AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
        if (aliases.getNumAliases() == 0)
          // Nothing to do if there are no aliasing OpOperands.
          continue;

        Value firstOperand = aliases.begin()->opOperand->get();
        bool allEquivalent = true;
        for (AliasingOpOperand alias : aliases) {
          bool isEquiv = alias.relation == BufferRelation::Equivalent;
          bool isInPlace = state.isInPlace(*alias.opOperand);
          Value operand = alias.opOperand->get();
          if (isEquiv && isInPlace && alias.isDefinite) {
            // Found a definite, equivalent alias. Merge equivalence sets.
            // There can only be one definite alias, so we can stop here.
            state.unionEquivalenceClasses(opResult, operand);
            allEquivalent = false;
            break;
          }
          if (!isEquiv || !isInPlace)
            allEquivalent = false;
          if (!state.areEquivalentBufferizedValues(operand, firstOperand))
            allEquivalent = false;
        }

        // If all "maybe" aliases are equivalent and the OpResult is not a new
        // allocation, it is a definite, equivalent alias. E.g.:
        //
        // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
        // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
        // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
        // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
        //
        // If %t0 and %t1 are equivalent, it is safe to union the equivalence
        // classes of %r, %t0 and %t1.
        if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
          state.unionEquivalenceClasses(opResult, firstOperand);
      }
    }
  }
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, state);
}

1084 /// "Bottom-up from terminators" heuristic.
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Find region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;
    // Follow the reverse SSA use-def chain from each yielded value as long as
    // we stay within the same region.
    SmallVector<OpResult> worklist;
    for (Value v : term->getOperands()) {
      if (!isa<TensorType>(v.getType()))
        continue;
      auto opResult = dyn_cast<OpResult>(v);
      if (!opResult)
        continue;
      worklist.push_back(opResult);
    }
    while (!worklist.empty()) {
      OpResult opResult = worklist.pop_back_val();
      Operation *defOp = opResult.getDefiningOp();
      if (!traversedOps.insert(defOp))
        continue;
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;
      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      for (auto alias : aliases) {
        Value v = alias.opOperand->get();
        if (!isa<TensorType>(v.getType()))
          continue;
        auto opResult = dyn_cast<OpResult>(v);
        if (!opResult)
          continue;
        worklist.push_back(opResult);
      }
    }
  });

  // Analyze traversed ops, then all remaining ops.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
    if (!traversedOps.contains(op) && hasTensorSemantics(op))
      result.push_back(op);
  });
  return result;
}

LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
      getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic ==
      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    op->walk([&](Operation *op) {
      // No tensors => no buffers.
      if (!hasTensorSemantics(op))
        return;
      orderedOps.push_back(op);
    });
    switch (heuristic) {
    case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
      // Default: Walk ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
      // Ops are already sorted top-down in `orderedOps`.
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that the fuzzer seed is set");
      // This is a fuzzer. For testing purposes only. Randomize the order in
      // which operations are analyzed. The bufferization quality is likely
      // worse, but we want to make sure that no assertions are triggered
      // anywhere.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Analyze ops in the computed order.
  for (Operation *op : orderedOps)
    if (failed(analyzeSingleOp(op, domInfo)))
      return failure();

  equivalenceAnalysis(op, *this);
  return success();
}

/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
                                 OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Note: This walk cannot be combined with the one below because interface
  // methods of invalid/unsupported ops may be called during the second walk.
  // (On ops different from `op`.)
  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Check for unsupported unstructured control flow.
    if (!op.supportsUnstructuredControlFlow()) {
      for (Region &r : op->getRegions()) {
        if (r.getBlocks().size() > 1) {
          op->emitOpError("op or BufferizableOpInterface implementation does "
                          "not support unstructured control flow, but at least "
                          "one region has multiple blocks");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Input IR may not contain any ToTensorOps without the "restrict"
    // attribute. Such tensors may alias any other tensor, which is currently
    // not handled in the analysis.
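    //
    // For example (a sketch; the exact assembly format may vary), a to_tensor
    // op that is acceptable to the analysis looks like:
    //
    //   %0 = bufferization.to_tensor %m restrict : memref<?xf32>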
    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
      if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
        op->emitOpError("to_tensor ops without `restrict` are not supported by "
                        "One-Shot Analysis");
        return WalkResult::interrupt();
      }
    }

    for (OpOperand &opOperand : op->getOpOperands()) {
      if (isa<TensorType>(opOperand.get().getType())) {
        if (wouldCreateReadAfterWriteInterference(
                opOperand, domInfo, state,
                /*checkConsistencyOnly=*/true)) {
          // This error can happen if certain "mustBufferizeInPlace" interface
          // methods are implemented incorrectly, such that the IR already has
          // a RaW conflict before making any bufferization decisions. It can
          // also happen if bufferization.materialize_in_destination is used
          // in such a way that a RaW conflict is not avoidable.
          op->emitOpError("not bufferizable under the given constraints: "
                          "cannot avoid RaW conflict");
          return WalkResult::interrupt();
        }

        if (state.isInPlace(opOperand) &&
            wouldCreateWriteToNonWritableBuffer(
                opOperand, state, /*checkConsistencyOnly=*/true)) {
          op->emitOpError("not bufferizable under the given constraints: would "
                          "write to read-only buffer");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });

  return success(!walkResult.wasInterrupted());
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
  });
}

static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());
  // Helper function to build an array attribute of aliasing SSA value strings.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliases;
    state.applyOnAliases(v, [&](Value alias) {
      std::string buffer;
      llvm::raw_string_ostream stream(buffer);
      alias.printAsOperand(stream, asmState);
      aliases.push_back(b.getStringAttr(buffer));
    });
    return b.getArrayAttr(aliases);
  };

  op->walk([&](Operation *op) {
    // Build alias set array for every OpResult.
    SmallVector<Attribute> opResultAliasSets;
    for (OpResult opResult : op->getOpResults()) {
      if (llvm::isa<TensorType>(opResult.getType())) {
        opResultAliasSets.push_back(buildAliasesArray(opResult));
      }
    }
    if (!opResultAliasSets.empty())
      op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));

    // Build alias set array for every BlockArgument.
    SmallVector<Attribute> regionAliasSets;
    bool hasTensorBbArg = false;
    for (Region &r : op->getRegions()) {
      SmallVector<Attribute> blockAliasSets;
      for (Block &block : r.getBlocks()) {
        SmallVector<Attribute> bbArgAliasSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (llvm::isa<TensorType>(bbArg.getType())) {
            bbArgAliasSets.push_back(buildAliasesArray(bbArg));
            hasTensorBbArg = true;
          }
        }
        blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
      }
      regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
    }
    if (hasTensorBbArg)
      op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
  });
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // If the analysis fails, just return.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  bool failedAnalysis = false;

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options,
                                   BufferizationStatistics *statistics) {
  // copy-before-write deactivates the analysis. It cannot be used together with
  // test-analysis-only.
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");

  if (options.copyBeforeWrite) {
    // Copy buffer before each write. No analysis is needed.
  } else {
    // Run One-Shot Analysis and insert buffer copies (on the tensor level)
    // only where needed. This is the default and much more efficient than
    // copy-before-write.
    if (failed(insertTensorCopies(op, options, statistics)))
      return failure();

    // If test-analysis-only is set, the IR was annotated with RaW conflict
    // markers (attributes) during One-Shot Analysis.
    if (options.testAnalysisOnly)
      return success();
  }

  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
  // a new buffer copy is allocated every time a buffer is written to.
  return bufferizeOp(op, options, statistics);
}

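// Example invocation (a sketch; assumes the standard -one-shot-bufferize
// test pass, whose options mirror OneShotBufferizationOptions):
//
//   mlir-opt --one-shot-bufferize="test-analysis-only" input.mlir
//   mlir-opt --one-shot-bufferize="copy-before-write" input.mlir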