xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
17a1579acSMatthias Springer //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
27a1579acSMatthias Springer //
37a1579acSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47a1579acSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
57a1579acSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67a1579acSMatthias Springer //
77a1579acSMatthias Springer //===----------------------------------------------------------------------===//
87a1579acSMatthias Springer //
97cdfc843SMatthias Springer // One-Shot Analysis analyzes function bodies. By default, function boundaries
107cdfc843SMatthias Springer // (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
117cdfc843SMatthias Springer // OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
127cdfc843SMatthias Springer // simple call graphs without loops.
137a1579acSMatthias Springer //
147cdfc843SMatthias Springer // One-Shot Bufferize consists of three phases.
157a1579acSMatthias Springer //
167cdfc843SMatthias Springer // 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
177cdfc843SMatthias Springer //    without inserting buffer copies. The analysis queries op bufferization
187cdfc843SMatthias Springer //    semantics via `BufferizableOpInterface`.
197cdfc843SMatthias Springer // 2. Insert copies for OpOperands that were decided to bufferize out-of-place
207cdfc843SMatthias Springer //    in tensor land during `TensorCopyInsertion`.
217cdfc843SMatthias Springer // 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
227a1579acSMatthias Springer //
237cdfc843SMatthias Springer // This file contains only the analysis. For convenience, this file also
247cdfc843SMatthias Springer // contains a helper function `runOneShotBufferize` that analyzes an op (and its
257cdfc843SMatthias Springer // nested ops) and then bufferizes it.
267a1579acSMatthias Springer //
277a1579acSMatthias Springer // Inplace bufferization decisions are passed from the analysis to the
287cdfc843SMatthias Springer // `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
297cdfc843SMatthias Springer // debugging purposes with `testAnalysisOnly`.
307a1579acSMatthias Springer //
317a1579acSMatthias Springer // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
327a1579acSMatthias Springer // treated conservatively. E.g., the analysis has to assume that their tensor
337a1579acSMatthias Springer // OpOperands bufferize to memory writes. While such ops can be analyzed, they
347a1579acSMatthias Springer // are not bufferized and remain in the IR. to_tensor and to_memref ops are
357a1579acSMatthias Springer // inserted at the bufferization boundary.
367a1579acSMatthias Springer //
377a1579acSMatthias Springer // This analysis caters to high-performance codegen where buffer reuse is deemed
387a1579acSMatthias Springer // critical: the analysis should fail if the bufferized form of the function
39855a11eeSMatthias Springer // needs to return a buffer, unless `allowReturnAllocs` is enabled.
407a1579acSMatthias Springer 
417a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
427a1579acSMatthias Springer 
43a1fe1f5fSKazu Hirata #include <optional>
448ee38f3bSMatthias Springer #include <random>
457a1579acSMatthias Springer 
467a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
477a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
487a1579acSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
4928b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
50d6dab38aSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
517a1579acSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
527a1579acSMatthias Springer #include "mlir/IR/AsmState.h"
537a1579acSMatthias Springer #include "mlir/IR/Dominance.h"
5435d3b343SMatthias Springer #include "mlir/IR/Iterators.h"
557a1579acSMatthias Springer #include "mlir/IR/Operation.h"
567a1579acSMatthias Springer #include "mlir/IR/TypeUtilities.h"
577a1579acSMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
581abd8d1aSMatthias Springer #include "mlir/Interfaces/SubsetOpInterface.h"
597a1579acSMatthias Springer #include "llvm/ADT/DenseSet.h"
607a1579acSMatthias Springer #include "llvm/ADT/SetVector.h"
617a1579acSMatthias Springer 
62faa9be75SMatthias Springer MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
63faa9be75SMatthias Springer 
64e0b40af7SMatthias Springer // Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
65e0b40af7SMatthias Springer // output.
66e0b40af7SMatthias Springer #define DEBUG_TYPE "one-shot-analysis"
67e0b40af7SMatthias Springer 
687a1579acSMatthias Springer using namespace mlir;
697a1579acSMatthias Springer using namespace mlir::bufferization;
707a1579acSMatthias Springer 
715550c821STres Popp static bool isaTensor(Type t) { return isa<TensorType>(t); }
727a1579acSMatthias Springer 
737a1579acSMatthias Springer //===----------------------------------------------------------------------===//
747a1579acSMatthias Springer // Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is stored
// in OneShotAnalysisState. When run with `testAnalysisOnly`, the IR is
777cdfc843SMatthias Springer // annotated with the results of the analysis, so that they can be checked in
787cdfc843SMatthias Springer // tests.
797a1579acSMatthias Springer //===----------------------------------------------------------------------===//
807a1579acSMatthias Springer 
817cdfc843SMatthias Springer /// Attribute marker to specify op operands that bufferize in-place.
827cdfc843SMatthias Springer constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
837a1579acSMatthias Springer 
84a02ad6c1SMatthias Springer constexpr StringLiteral kOpResultAliasSetAttrName =
85a02ad6c1SMatthias Springer     "__opresult_alias_set_attr__";
86a02ad6c1SMatthias Springer 
87a02ad6c1SMatthias Springer constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";
88bb9d1b55SMatthias Springer 
897a1579acSMatthias Springer /// Mark whether OpOperand will be bufferized inplace.
907a1579acSMatthias Springer static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
917a1579acSMatthias Springer   Operation *op = opOperand.getOwner();
927a1579acSMatthias Springer   SmallVector<StringRef> inPlaceVector;
937cdfc843SMatthias Springer   if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
947cdfc843SMatthias Springer     inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
955550c821STres Popp         cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
967a1579acSMatthias Springer   } else {
977a1579acSMatthias Springer     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
987a1579acSMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
995550c821STres Popp       if (isa<TensorType>(opOperand.get().getType()))
1007a1579acSMatthias Springer         inPlaceVector[opOperand.getOperandNumber()] = "false";
1017a1579acSMatthias Springer   }
1027a1579acSMatthias Springer   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
1037cdfc843SMatthias Springer   op->setAttr(kInPlaceOperandsAttrName,
1047a1579acSMatthias Springer               OpBuilder(op).getStrArrayAttr(inPlaceVector));
1057a1579acSMatthias Springer }
1067a1579acSMatthias Springer 
1077a1579acSMatthias Springer //===----------------------------------------------------------------------===//
108cf2d374eSMatthias Springer // OneShotAnalysisState
1097a1579acSMatthias Springer //===----------------------------------------------------------------------===//
1107a1579acSMatthias Springer 
111cf2d374eSMatthias Springer OneShotAnalysisState::OneShotAnalysisState(
112cf2d374eSMatthias Springer     Operation *op, const OneShotBufferizationOptions &options)
113cf2d374eSMatthias Springer     : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
114cf2d374eSMatthias Springer   // Set up alias sets.
115cf2d374eSMatthias Springer   op->walk([&](Operation *op) {
1167a1579acSMatthias Springer     for (Value v : op->getResults())
1175550c821STres Popp       if (isa<TensorType>(v.getType()))
1187a1579acSMatthias Springer         createAliasInfoEntry(v);
1197a1579acSMatthias Springer     for (Region &r : op->getRegions())
1207a1579acSMatthias Springer       for (Block &b : r.getBlocks())
1217a1579acSMatthias Springer         for (auto bbArg : b.getArguments())
1225550c821STres Popp           if (isa<TensorType>(bbArg.getType()))
1237a1579acSMatthias Springer             createAliasInfoEntry(bbArg);
1247a1579acSMatthias Springer   });
125cf2d374eSMatthias Springer 
126cf2d374eSMatthias Springer   // Mark OpOperands in-place that must bufferize in-place.
127cf2d374eSMatthias Springer   op->walk([&](BufferizableOpInterface bufferizableOp) {
128cf2d374eSMatthias Springer     if (!options.isOpAllowed(bufferizableOp))
129cf2d374eSMatthias Springer       return WalkResult::skip();
130cf2d374eSMatthias Springer     for (OpOperand &opOperand : bufferizableOp->getOpOperands())
1315550c821STres Popp       if (isa<TensorType>(opOperand.get().getType()))
132cf2d374eSMatthias Springer         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
133cf2d374eSMatthias Springer           bufferizeInPlace(opOperand);
134cf2d374eSMatthias Springer     return WalkResult::advance();
135cf2d374eSMatthias Springer   });
1367a1579acSMatthias Springer }
1377a1579acSMatthias Springer 
138cf2d374eSMatthias Springer void OneShotAnalysisState::applyOnEquivalenceClass(
1397a1579acSMatthias Springer     Value v, function_ref<void(Value)> fun) const {
1407a1579acSMatthias Springer   auto leaderIt = equivalentInfo.findLeader(v);
1417a1579acSMatthias Springer   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
1427a1579acSMatthias Springer        ++mit) {
1437a1579acSMatthias Springer     fun(*mit);
1447a1579acSMatthias Springer   }
1457a1579acSMatthias Springer }
1467a1579acSMatthias Springer 
147cf2d374eSMatthias Springer void OneShotAnalysisState::applyOnAliases(Value v,
148cf2d374eSMatthias Springer                                           function_ref<void(Value)> fun) const {
1497a1579acSMatthias Springer   auto leaderIt = aliasInfo.findLeader(v);
1507a1579acSMatthias Springer   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
1517a1579acSMatthias Springer     fun(*mit);
1527a1579acSMatthias Springer   }
1537a1579acSMatthias Springer }
1547a1579acSMatthias Springer 
1559597b16aSMatthias Springer bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
1567a1579acSMatthias Springer                                                          Value v2) const {
157cf2d374eSMatthias Springer   return equivalentInfo.isEquivalent(v1, v2);
1587a1579acSMatthias Springer }
1597a1579acSMatthias Springer 
1603490aadfSMatthias Springer bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
1613490aadfSMatthias Springer                                                        Value v2) const {
162cf2d374eSMatthias Springer   return aliasInfo.isEquivalent(v1, v2);
163cf2d374eSMatthias Springer }
164cf2d374eSMatthias Springer 
165cf2d374eSMatthias Springer void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
166cf2d374eSMatthias Springer   if (inplaceBufferized.contains(&operand))
167cf2d374eSMatthias Springer     return;
168cf2d374eSMatthias Springer   inplaceBufferized.insert(&operand);
169a02ad6c1SMatthias Springer   for (AliasingValue alias : getAliasingValues(operand))
170a02ad6c1SMatthias Springer     aliasInfo.unionSets(alias.value, operand.get());
171cf2d374eSMatthias Springer   ++statNumTensorInPlace;
172cf2d374eSMatthias Springer }
173cf2d374eSMatthias Springer 
174cf2d374eSMatthias Springer void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
175cf2d374eSMatthias Springer   assert(!inplaceBufferized.contains(&operand) &&
176cf2d374eSMatthias Springer          "OpOperand was already decided to bufferize inplace");
177cf2d374eSMatthias Springer   ++statNumTensorOutOfPlace;
178cf2d374eSMatthias Springer }
179cf2d374eSMatthias Springer 
180cf2d374eSMatthias Springer void OneShotAnalysisState::createAliasInfoEntry(Value v) {
181cf2d374eSMatthias Springer   aliasInfo.insert(v);
182cf2d374eSMatthias Springer   equivalentInfo.insert(v);
1833490aadfSMatthias Springer }
1843490aadfSMatthias Springer 
185988748c0SMatthias Springer void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
186988748c0SMatthias Springer   op->walk([&](Operation *op) {
187988748c0SMatthias Springer     // Skip unknown ops.
188988748c0SMatthias Springer     auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
189988748c0SMatthias Springer     if (!bufferizableOp)
190988748c0SMatthias Springer       return WalkResult::skip();
191988748c0SMatthias Springer 
192988748c0SMatthias Springer     // Check all tensor OpResults.
193988748c0SMatthias Springer     for (OpResult opResult : op->getOpResults()) {
1945550c821STres Popp       if (!isa<TensorType>(opResult.getType()))
195988748c0SMatthias Springer         continue;
196988748c0SMatthias Springer 
1971840d18aSMatthias Springer       // If there is no preceding definition, the tensor contents are
198988748c0SMatthias Springer       // undefined.
199*d9111f19SAmir Bishara       if (opResult.getUses().empty())
200*d9111f19SAmir Bishara         continue;
201*d9111f19SAmir Bishara       // It does not really matter which use to take to search about
202*d9111f19SAmir Bishara       // the value's definitions.
203*d9111f19SAmir Bishara       OpOperand *opOperand = &(*opResult.getUses().begin());
204*d9111f19SAmir Bishara       if (findDefinitionsCached(opOperand).empty())
205988748c0SMatthias Springer         for (OpOperand &use : opResult.getUses())
206988748c0SMatthias Springer           undefinedTensorUses.insert(&use);
207988748c0SMatthias Springer     }
208988748c0SMatthias Springer 
209988748c0SMatthias Springer     return WalkResult::advance();
210988748c0SMatthias Springer   });
211988748c0SMatthias Springer }
212988748c0SMatthias Springer 
213988748c0SMatthias Springer bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
214988748c0SMatthias Springer   return undefinedTensorUses.contains(opOperand);
215988748c0SMatthias Springer }
216988748c0SMatthias Springer 
217cf2d374eSMatthias Springer bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
218cf2d374eSMatthias Springer   return inplaceBufferized.contains(&opOperand);
219cf2d374eSMatthias Springer }
220cf2d374eSMatthias Springer 
2213490aadfSMatthias Springer bool OneShotAnalysisState::isValueWritten(Value value) const {
2223490aadfSMatthias Springer   bool isWritten = false;
223cf2d374eSMatthias Springer   applyOnAliases(value, [&](Value val) {
2243490aadfSMatthias Springer     for (OpOperand &use : val.getUses())
2253490aadfSMatthias Springer       if (isInPlace(use) && bufferizesToMemoryWrite(use))
2263490aadfSMatthias Springer         isWritten = true;
2273490aadfSMatthias Springer   });
2283490aadfSMatthias Springer   return isWritten;
2293490aadfSMatthias Springer }
2303490aadfSMatthias Springer 
231032be233SMatthias Springer bool OneShotAnalysisState::isWritable(Value value) const {
232032be233SMatthias Springer   // TODO: Out-of-place bufferized value could be considered writable.
233032be233SMatthias Springer   // Query BufferizableOpInterface to see if the BlockArgument is writable.
234032be233SMatthias Springer   if (auto bufferizableOp =
235a02ad6c1SMatthias Springer           getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
236a02ad6c1SMatthias Springer     return bufferizableOp.isWritable(value, *this);
237032be233SMatthias Springer 
238032be233SMatthias Springer   // Not a bufferizable op: The conservative answer is "not writable".
239032be233SMatthias Springer   return false;
240032be233SMatthias Springer }
241032be233SMatthias Springer 
242cf2d374eSMatthias Springer void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
243cf2d374eSMatthias Springer   aliasInfo.unionSets(v1, v2);
244cf2d374eSMatthias Springer }
245cf2d374eSMatthias Springer 
246cf2d374eSMatthias Springer void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
247cf2d374eSMatthias Springer   equivalentInfo.unionSets(v1, v2);
248cf2d374eSMatthias Springer }
249cf2d374eSMatthias Springer 
250faa9be75SMatthias Springer OneShotAnalysisState::Extension::~Extension() = default;
251faa9be75SMatthias Springer 
2527a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2537a1579acSMatthias Springer // Bufferization-specific alias analysis.
2547a1579acSMatthias Springer //===----------------------------------------------------------------------===//
2557a1579acSMatthias Springer 
2567a1579acSMatthias Springer /// Return true if opOperand has been decided to bufferize in-place.
2577a1579acSMatthias Springer static bool isInplaceMemoryWrite(OpOperand &opOperand,
258cf2d374eSMatthias Springer                                  const OneShotAnalysisState &state) {
2597a1579acSMatthias Springer   // OpOperands that do not bufferize to a memory write do not write in-place.
2607a1579acSMatthias Springer   if (!state.bufferizesToMemoryWrite(opOperand))
2617a1579acSMatthias Springer     return false;
2627a1579acSMatthias Springer   // Check current bufferization decisions.
263cf2d374eSMatthias Springer   return state.isInPlace(opOperand);
2647a1579acSMatthias Springer }
2657a1579acSMatthias Springer 
2667a1579acSMatthias Springer /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
2677a1579acSMatthias Springer /// properly dominates `b` and `b` is not inside `a`.
2687a1579acSMatthias Springer static bool happensBefore(Operation *a, Operation *b,
2697a1579acSMatthias Springer                           const DominanceInfo &domInfo) {
2707a1579acSMatthias Springer   do {
2717a1579acSMatthias Springer     // TODO: Instead of isProperAncestor + properlyDominates, we should use
2727a1579acSMatthias Springer     // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
2737a1579acSMatthias Springer     if (a->isProperAncestor(b))
2747a1579acSMatthias Springer       return false;
2757a1579acSMatthias Springer     if (domInfo.properlyDominates(a, b))
2767a1579acSMatthias Springer       return true;
2777a1579acSMatthias Springer   } while ((a = a->getParentOp()));
2787a1579acSMatthias Springer   return false;
2797a1579acSMatthias Springer }
2807a1579acSMatthias Springer 
281c89c31a2SMatthias Springer /// Return `true` if op dominance can be used to rule out a read-after-write
2826ecebb49SMatthias Springer /// conflicts based on the ordering of ops. Returns `false` if op dominance
2836ecebb49SMatthias Springer /// cannot be used to due region-based loops.
2842e210034SMatthias Springer ///
285c89c31a2SMatthias Springer /// Generalized op dominance can often be used to rule out potential conflicts
286c89c31a2SMatthias Springer /// due to "read happens before write". E.g., the following IR is not a RaW
287c89c31a2SMatthias Springer /// conflict because the read happens *before* the write.
2882e210034SMatthias Springer ///
289c89c31a2SMatthias Springer /// Example 1:
290c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32>                                // DEF
291c89c31a2SMatthias Springer /// "reading_op"(%0) : tensor<?xf32>                        // READ
292c89c31a2SMatthias Springer /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
2932e210034SMatthias Springer ///
2942e210034SMatthias Springer /// This is no longer true inside loops (or repetitive regions). In such cases,
2952e210034SMatthias Springer /// there may not be a meaningful `happensBefore` relationship because ops
2962e210034SMatthias Springer /// could be executed multiple times. E.g.:
2972e210034SMatthias Springer ///
298c89c31a2SMatthias Springer /// Example 2:
299c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32>                                  // DEF
3002e210034SMatthias Springer /// scf.for ... {
301c89c31a2SMatthias Springer ///   "reading_op"(%0) : tensor<?xf32>                        // READ
302c89c31a2SMatthias Springer ///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
3032e210034SMatthias Springer ///   ...
3042e210034SMatthias Springer /// }
3052e210034SMatthias Springer ///
3062e210034SMatthias Springer /// In the above example, reading_op happens before writing_op according to
3072e210034SMatthias Springer /// op dominance. However, both ops may happen multiple times; in
3082e210034SMatthias Springer /// particular, the second execution of reading_op happens after the first
3092e210034SMatthias Springer /// execution of writing_op. This is problematic because the tensor %0 they
3102e210034SMatthias Springer /// operate on (i.e., the "definition") is defined outside of the loop.
3112e210034SMatthias Springer ///
312c89c31a2SMatthias Springer /// On a high-level, there is a potential RaW in a program if there exists a
313c89c31a2SMatthias Springer /// possible program execution such that there is a sequence of DEF, followed
314c89c31a2SMatthias Springer /// by WRITE, followed by READ. Each additional DEF resets the sequence.
3152e210034SMatthias Springer ///
316c89c31a2SMatthias Springer /// E.g.:
317c89c31a2SMatthias Springer /// No conflict:        DEF, WRITE, DEF, READ
318c89c31a2SMatthias Springer /// Potential conflict: DEF, READ, WRITE, READ, WRITE
319c89c31a2SMatthias Springer ///
320c89c31a2SMatthias Springer /// Example 1 has no conflict:          DEF, READ, WRITE
321c89c31a2SMatthias Springer /// Example 2 has a potential conflict: DEF, (READ, WRITE)*
322c89c31a2SMatthias Springer //
323c89c31a2SMatthias Springer /// Example 3:
3242e210034SMatthias Springer /// scf.for ... {
3252e210034SMatthias Springer ///   %0 = ... : tensor<?xf32>
3262e210034SMatthias Springer ///   "reading_op"(%0) : tensor<?xf32>
3272e210034SMatthias Springer ///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
3282e210034SMatthias Springer ///   ...
3292e210034SMatthias Springer /// }
330c89c31a2SMatthias Springer /// This has no conflict: (DEF, READ, WRITE)*
3312e210034SMatthias Springer ///
332c89c31a2SMatthias Springer /// Example 4:
333c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32>
3342e210034SMatthias Springer /// scf.for ... {
335c89c31a2SMatthias Springer ///   scf.for ... { "reading_op"(%0) }
336c89c31a2SMatthias Springer ///   %1 = "writing_op"(%0)
3372e210034SMatthias Springer /// }
338c89c31a2SMatthias Springer /// This has a potential conflict: DEF, ((READ)*, WRITE)*
3392e210034SMatthias Springer ///
340c89c31a2SMatthias Springer /// Example 5:
341c89c31a2SMatthias Springer /// %0 = ... : tensor<?xf32>
342c89c31a2SMatthias Springer /// scf.for ... { %1 = "writing_op"(%0) }
343c89c31a2SMatthias Springer /// scf.for ... { "reading_op"(%0) }
344c89c31a2SMatthias Springer /// This has a potential conflict: DEF, WRITE*, READ*
3452e210034SMatthias Springer ///
346c89c31a2SMatthias Springer /// The following rules are used to rule out RaW conflicts via ordering of ops:
3472e210034SMatthias Springer ///
348c89c31a2SMatthias Springer /// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
349c89c31a2SMatthias Springer ///    a repetitive region that enclosing both READ and WRITE, we cannot rule
350c89c31a2SMatthias Springer ///    out RaW conflict due to the ordering of ops.
351c89c31a2SMatthias Springer /// 2. Otherwise: There are no loops that interfere with our analysis; for
352c89c31a2SMatthias Springer ///    analysis purposes, we can assume that there are no loops/repetitive
353c89c31a2SMatthias Springer ///    regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
354c89c31a2SMatthias Springer ///    or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
3552e210034SMatthias Springer ///
3566ecebb49SMatthias Springer static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
357c89c31a2SMatthias Springer                                           const SetVector<Value> &definitions,
3589312b4f9SMartin Erhart                                           AnalysisState &state) {
3592e210034SMatthias Springer   const BufferizationOptions &options = state.getOptions();
360c89c31a2SMatthias Springer   for (Value def : definitions) {
3619312b4f9SMartin Erhart     Region *rRead =
3629312b4f9SMartin Erhart         state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
3639312b4f9SMartin Erhart     Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
3642e210034SMatthias Springer 
365c89c31a2SMatthias Springer     // READ and DEF are in the same repetitive region. `happensBefore` can be
366c89c31a2SMatthias Springer     // used to rule out RaW conflicts due to op ordering.
367c89c31a2SMatthias Springer     if (rRead == rDef)
3682e210034SMatthias Springer       continue;
369c89c31a2SMatthias Springer 
370c89c31a2SMatthias Springer     // Find the enclosing repetitive region of READ that is closest to DEF but
371c89c31a2SMatthias Springer     // not the repetitive region of DEF itself.
372c89c31a2SMatthias Springer     while (true) {
373c89c31a2SMatthias Springer       Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
374c89c31a2SMatthias Springer       if (nextRegion == rDef)
375c89c31a2SMatthias Springer         break;
376c89c31a2SMatthias Springer       assert(nextRegion && "expected to find another repetitive region");
377c89c31a2SMatthias Springer       rRead = nextRegion;
3782e210034SMatthias Springer     }
379c89c31a2SMatthias Springer 
380c89c31a2SMatthias Springer     // We cannot use op dominance if WRITE is inside the same repetitive region.
381c89c31a2SMatthias Springer     if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
3822e210034SMatthias Springer       return false;
3832e210034SMatthias Springer   }
3846ecebb49SMatthias Springer 
385c89c31a2SMatthias Springer   return true;
3862e210034SMatthias Springer }
3872e210034SMatthias Springer 
3886ecebb49SMatthias Springer /// Return `true` if op dominance can be used to rule out a read-after-write
3896ecebb49SMatthias Springer /// conflicts based on the ordering of ops. Returns `false` if op dominance
3906ecebb49SMatthias Springer /// cannot be used to due block-based loops within a region.
3916ecebb49SMatthias Springer ///
3926ecebb49SMatthias Springer /// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
3936ecebb49SMatthias Springer /// how op domiance is used during RaW conflict detection.
3946ecebb49SMatthias Springer ///
3956ecebb49SMatthias Springer /// On a high-level, there is a potential RaW in a program if there exists a
3966ecebb49SMatthias Springer /// possible program execution such that there is a sequence of DEF, followed
3976ecebb49SMatthias Springer /// by WRITE, followed by READ. Each additional DEF resets the sequence.
3986ecebb49SMatthias Springer ///
3996ecebb49SMatthias Springer /// Op dominance cannot be used if there is a path from block(READ) to
4006ecebb49SMatthias Springer /// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
4016ecebb49SMatthias Springer /// not appear on that path.
4026ecebb49SMatthias Springer static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
4036ecebb49SMatthias Springer                                          const SetVector<Value> &definitions,
4046ecebb49SMatthias Springer                                          AnalysisState &state) {
4056ecebb49SMatthias Springer   // Fast path: If READ and WRITE are in different regions, their block cannot
4066ecebb49SMatthias Springer   // be reachable just via unstructured control flow. (Loops due to regions are
4076ecebb49SMatthias Springer   // covered by `canUseOpDominanceDueToRegions`.)
4086ecebb49SMatthias Springer   if (uRead->getOwner()->getParentRegion() !=
4096ecebb49SMatthias Springer       uWrite->getOwner()->getParentRegion())
4106ecebb49SMatthias Springer     return true;
4116ecebb49SMatthias Springer 
4126ecebb49SMatthias Springer   Block *readBlock = uRead->getOwner()->getBlock();
4136ecebb49SMatthias Springer   Block *writeBlock = uWrite->getOwner()->getBlock();
4146ecebb49SMatthias Springer   for (Value def : definitions) {
4156ecebb49SMatthias Springer     Block *defBlock = def.getParentBlock();
416804d3c4cSMatthias Springer     if (readBlock->isReachable(writeBlock, {defBlock}) &&
417804d3c4cSMatthias Springer         writeBlock->isReachable(readBlock, {defBlock}))
4186ecebb49SMatthias Springer       return false;
4196ecebb49SMatthias Springer   }
4206ecebb49SMatthias Springer 
4216ecebb49SMatthias Springer   return true;
4226ecebb49SMatthias Springer }
4236ecebb49SMatthias Springer 
4246ecebb49SMatthias Springer static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
4256ecebb49SMatthias Springer                               const SetVector<Value> &definitions,
4266ecebb49SMatthias Springer                               AnalysisState &state) {
4276ecebb49SMatthias Springer   return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
4286ecebb49SMatthias Springer          canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
4296ecebb49SMatthias Springer }
4306ecebb49SMatthias Springer 
4317a1579acSMatthias Springer /// Annotate IR with details about the detected RaW conflict.
4327a1579acSMatthias Springer static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
4331840d18aSMatthias Springer                              Value definition) {
4347a1579acSMatthias Springer   static uint64_t counter = 0;
4357a1579acSMatthias Springer   Operation *readingOp = uRead->getOwner();
4367a1579acSMatthias Springer   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
4377a1579acSMatthias Springer 
4387a1579acSMatthias Springer   OpBuilder b(conflictingWritingOp->getContext());
4397a1579acSMatthias Springer   std::string id = "C_" + std::to_string(counter++);
4407a1579acSMatthias Springer 
4417a1579acSMatthias Springer   std::string conflictingWriteAttr =
4427a1579acSMatthias Springer       id +
4437a1579acSMatthias Springer       "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
4447a1579acSMatthias Springer       "]";
4457a1579acSMatthias Springer   conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
4467a1579acSMatthias Springer 
4477a1579acSMatthias Springer   std::string readAttr =
4487a1579acSMatthias Springer       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
4497a1579acSMatthias Springer   readingOp->setAttr(readAttr, b.getUnitAttr());
4507a1579acSMatthias Springer 
4515550c821STres Popp   if (auto opResult = dyn_cast<OpResult>(definition)) {
4521840d18aSMatthias Springer     std::string defAttr =
4531840d18aSMatthias Springer         id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
4541840d18aSMatthias Springer     opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
4557a1579acSMatthias Springer   } else {
4565550c821STres Popp     auto bbArg = cast<BlockArgument>(definition);
4571840d18aSMatthias Springer     std::string defAttr =
4581840d18aSMatthias Springer         id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
4591840d18aSMatthias Springer     bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
4607a1579acSMatthias Springer   }
4617a1579acSMatthias Springer }
4627a1579acSMatthias Springer 
463f36e1934SMatthias Springer /// Return 'true' if a tensor that is equivalent to `other` can be found in the
464f36e1934SMatthias Springer /// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
465f36e1934SMatthias Springer /// place along that use-def chain, the two tensors may not materialize as
466f36e1934SMatthias Springer /// equivalent buffers (but separate allocations).
467f36e1934SMatthias Springer ///
468f36e1934SMatthias Springer /// Note: This function also requires that the two tensors have equivalent
469f36e1934SMatthias Springer /// indexing. I.e., the tensor types do not change along the use-def chain,
470f36e1934SMatthias Springer /// apart from static <-> dynamic dim casts.
471f36e1934SMatthias Springer static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
472*d9111f19SAmir Bishara                                                    OpOperand *start,
473*d9111f19SAmir Bishara                                                    Value other) {
474f36e1934SMatthias Springer   TraversalConfig config;
475f36e1934SMatthias Springer   config.followEquivalentOnly = true;
476f36e1934SMatthias Springer   config.alwaysIncludeLeaves = false;
477f36e1934SMatthias Springer   config.followSameTypeOrCastsOnly = true;
478f36e1934SMatthias Springer   return !state
479f36e1934SMatthias Springer               .findValueInReverseUseDefChain(
480f36e1934SMatthias Springer                   start, [&](Value v) { return v == other; }, config)
481f36e1934SMatthias Springer               .empty();
482f36e1934SMatthias Springer }
483f36e1934SMatthias Springer 
484*d9111f19SAmir Bishara /// Return "true" if the given operand's value is originating from a subset
485*d9111f19SAmir Bishara /// that is equivalent to the subset that `subsetOp` inserts into.
486*d9111f19SAmir Bishara static bool matchesInsertDestination(const AnalysisState &state,
487*d9111f19SAmir Bishara                                      OpOperand *opOperand,
4888143307bSMatthias Springer                                      SubsetInsertionOpInterface subsetOp) {
4898143307bSMatthias Springer   auto matchingSubset = [&](Value val) {
4908143307bSMatthias Springer     if (auto opResult = dyn_cast<OpResult>(val))
4918143307bSMatthias Springer       if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
4928143307bSMatthias Springer             return state.areEquivalentBufferizedValues(v1, v2);
4938143307bSMatthias Springer           }))
4948143307bSMatthias Springer         return true;
4958143307bSMatthias Springer     return false;
4968143307bSMatthias Springer   };
4978143307bSMatthias Springer   // There may be multiple leaves at which the reverse SSA use-def chain lookup
4988143307bSMatthias Springer   // terminates. All of them must be equivalent subsets.
4998143307bSMatthias Springer   SetVector<Value> backwardSlice =
500*d9111f19SAmir Bishara       state.findValueInReverseUseDefChain(opOperand, matchingSubset);
5018143307bSMatthias Springer   return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
5028143307bSMatthias Springer }
5038143307bSMatthias Springer 
5048143307bSMatthias Springer /// Return "true" if the given "read" and potentially conflicting "write" are
5058143307bSMatthias Springer /// not conflicting due to their subset relationship. The comments in this
5068143307bSMatthias Springer /// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
5078143307bSMatthias Springer /// pairs, but apply to any subset ops that implement the
5088143307bSMatthias Springer /// `SubsetInsertionOpInterface`.
5098143307bSMatthias Springer static bool areNonConflictingSubsets(OpOperand *uRead,
5108143307bSMatthias Springer                                      OpOperand *uConflictingWrite,
5118143307bSMatthias Springer                                      const AnalysisState &state) {
5128143307bSMatthias Springer   Operation *readingOp = uRead->getOwner();
5138143307bSMatthias Springer   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
5148143307bSMatthias Springer 
5158143307bSMatthias Springer   // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
5168143307bSMatthias Springer   // uRead is an InsertSliceOp...
5178143307bSMatthias Springer   if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
5188143307bSMatthias Springer     // As an example, consider the following IR.
5198143307bSMatthias Springer     //
5208143307bSMatthias Springer     // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
5218143307bSMatthias Springer     // %1 = linalg.fill %cst, %0 {inplace= [true] }
5228143307bSMatthias Springer     // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
5238143307bSMatthias Springer     //     {inplace= [true] }
5248143307bSMatthias Springer 
5258143307bSMatthias Springer     if (uRead == &subsetOp.getDestinationOperand() &&
526*d9111f19SAmir Bishara         matchesInsertDestination(state, uConflictingWrite, subsetOp))
5278143307bSMatthias Springer       // Case 1: The main insight is that InsertSliceOp reads only part of
5288143307bSMatthias Springer       // the destination tensor. The overwritten area is not read. If
5298143307bSMatthias Springer       // uConflictingWrite writes into exactly the memory location that is
5308143307bSMatthias Springer       // being read by uRead, this is not a conflict.
5318143307bSMatthias Springer       //
5328143307bSMatthias Springer       // In the above example:
5338143307bSMatthias Springer       // uRead             = OpOperand 1 (%t) of tensor.insert_slice
5348143307bSMatthias Springer       // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
5358143307bSMatthias Springer       //
5368143307bSMatthias Springer       // The read of %t does not conflict with the write of the FillOp
5378143307bSMatthias Springer       // (same aliases!) because the area that the FillOp operates on is
5388143307bSMatthias Springer       // exactly the one that is *not* read via %t.
5398143307bSMatthias Springer       return true;
5408143307bSMatthias Springer 
5418143307bSMatthias Springer     if (uRead == &subsetOp.getSourceOperand() &&
5428143307bSMatthias Springer         uConflictingWrite == &subsetOp.getDestinationOperand() &&
543*d9111f19SAmir Bishara         matchesInsertDestination(state, uRead, subsetOp))
5448143307bSMatthias Springer       // Case 2: The read of the source tensor and the write to the dest
5458143307bSMatthias Springer       // tensor via an InsertSliceOp is not a conflict if the read is
5468143307bSMatthias Springer       // reading exactly that part of an equivalent tensor that the
5478143307bSMatthias Springer       // InsertSliceOp is writing.
5488143307bSMatthias Springer       //
5498143307bSMatthias Springer       // In the above example:
5508143307bSMatthias Springer       // uRead             = OpOperand 0 (%1) of tensor.insert_slice
5518143307bSMatthias Springer       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
5528143307bSMatthias Springer       return true;
5538143307bSMatthias Springer   }
5548143307bSMatthias Springer 
5558143307bSMatthias Springer   // If uConflictingWrite is an InsertSliceOp...
5568143307bSMatthias Springer   if (auto subsetOp =
5578143307bSMatthias Springer           dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
5588143307bSMatthias Springer     // As an example, consider the following IR.
5598143307bSMatthias Springer     //
5608143307bSMatthias Springer     // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
5618143307bSMatthias Springer     // %1 = linalg.fill %cst, %0 {inplace= [true] }
5628143307bSMatthias Springer     // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
5638143307bSMatthias Springer     //     {inplace= [true] }
5648143307bSMatthias Springer     // %3 = vector.transfer_read %1, %cst
5658143307bSMatthias Springer     //
5668143307bSMatthias Springer     // In the above example:
5678143307bSMatthias Springer     // uRead             = OpOperand 0 (%1) of vector.transfer_read
5688143307bSMatthias Springer     // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
5698143307bSMatthias Springer     // definition        = %1
5708143307bSMatthias Springer     //
5718143307bSMatthias Springer     // This is not a conflict because the InsertSliceOp overwrites the
5728143307bSMatthias Springer     // memory segment of %1 with the exact same data. (Effectively, there
5738143307bSMatthias Springer     // is no memory write here.)
5748143307bSMatthias Springer     if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
5758143307bSMatthias Springer         state.areEquivalentBufferizedValues(
5768143307bSMatthias Springer             uRead->get(), subsetOp.getSourceOperand().get()) &&
577*d9111f19SAmir Bishara         matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
5788143307bSMatthias Springer       return true;
5798143307bSMatthias Springer 
5808143307bSMatthias Springer   return false;
5818143307bSMatthias Springer }
5828143307bSMatthias Springer 
5837a1579acSMatthias Springer /// Given sets of uses and writes, return true if there is a RaW conflict under
5847a1579acSMatthias Springer /// the assumption that all given reads/writes alias the same buffer and that
5857a1579acSMatthias Springer /// all given writes bufferize inplace.
5867a1579acSMatthias Springer ///
5877a1579acSMatthias Springer /// A conflict is: According to SSA use-def chains, a read R is supposed to read
5881840d18aSMatthias Springer /// the result of a definition W1. But because of bufferization decisions, R
5891840d18aSMatthias Springer /// actually reads another definition W2.
590cf2d374eSMatthias Springer static bool
591cf2d374eSMatthias Springer hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
592cf2d374eSMatthias Springer                               const DenseSet<OpOperand *> &usesWrite,
593cf2d374eSMatthias Springer                               const DominanceInfo &domInfo,
594cf2d374eSMatthias Springer                               OneShotAnalysisState &state) {
5957a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
5967a1579acSMatthias Springer 
5971e1a3112SMatthias Springer   // Before going through the main RaW analysis, find cases where a buffer must
5981e1a3112SMatthias Springer   // be privatized due to parallelism. If the result of a write is never read,
5991e1a3112SMatthias Springer   // privatization is not necessary (and large parts of the IR are likely dead).
600d5863721SMax191   if (options.checkParallelRegions && !usesRead.empty()) {
6011e1a3112SMatthias Springer     for (OpOperand *uConflictingWrite : usesWrite) {
6021e1a3112SMatthias Springer       // Find the allocation point or last write (definition) of the buffer.
6031e1a3112SMatthias Springer       // Note: In contrast to `findDefinitions`, this also returns results of
6041e1a3112SMatthias Springer       // ops that do not bufferize to memory write when no other definition
6051e1a3112SMatthias Springer       // could be found. E.g., "bufferization.alloc_tensor" would be included,
6061e1a3112SMatthias Springer       // even though that op just bufferizes to an allocation but does define
6071e1a3112SMatthias Springer       // the contents of the buffer.
6081e1a3112SMatthias Springer       SetVector<Value> definitionsOrLeaves =
609*d9111f19SAmir Bishara           state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
610*d9111f19SAmir Bishara             return state.bufferizesToMemoryWrite(v);
611*d9111f19SAmir Bishara           });
6121e1a3112SMatthias Springer       assert(!definitionsOrLeaves.empty() &&
6131e1a3112SMatthias Springer              "expected at least one definition or leaf");
6141e1a3112SMatthias Springer 
6151e1a3112SMatthias Springer       // The writing op must bufferize out-of-place if the definition is in a
6161e1a3112SMatthias Springer       // different parallel region than this write.
6171e1a3112SMatthias Springer       for (Value def : definitionsOrLeaves) {
6181e1a3112SMatthias Springer         if (getParallelRegion(def.getParentRegion(), options) !=
6191e1a3112SMatthias Springer             getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
6201e1a3112SMatthias Springer                               options)) {
6211e1a3112SMatthias Springer           LLVM_DEBUG(
6221e1a3112SMatthias Springer               llvm::dbgs()
6231e1a3112SMatthias Springer               << "\n- bufferizes out-of-place due to parallel region:\n");
6241e1a3112SMatthias Springer           LLVM_DEBUG(llvm::dbgs()
6251e1a3112SMatthias Springer                      << "  unConflictingWrite = operand "
6261e1a3112SMatthias Springer                      << uConflictingWrite->getOperandNumber() << " of "
6271e1a3112SMatthias Springer                      << *uConflictingWrite->getOwner() << "\n");
6281e1a3112SMatthias Springer           return true;
6291e1a3112SMatthias Springer         }
6301e1a3112SMatthias Springer       }
6311e1a3112SMatthias Springer     }
6321e1a3112SMatthias Springer   }
6331e1a3112SMatthias Springer 
6347a1579acSMatthias Springer   for (OpOperand *uRead : usesRead) {
6357a1579acSMatthias Springer     Operation *readingOp = uRead->getOwner();
6361fdf06d6SMatthias Springer     LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
6371fdf06d6SMatthias Springer     LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
6381fdf06d6SMatthias Springer                             << " of " << *readingOp << "\n");
6397a1579acSMatthias Springer 
6401fdf06d6SMatthias Springer     // Find the definition of uRead by following the SSA use-def chain.
6417a1579acSMatthias Springer     // E.g.:
6427a1579acSMatthias Springer     //
6437a1579acSMatthias Springer     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
6447a1579acSMatthias Springer     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
6457a1579acSMatthias Springer     // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
6467a1579acSMatthias Springer     //
6471840d18aSMatthias Springer     // In the above example, if uRead is the OpOperand of reading_op, the
6481840d18aSMatthias Springer     // definition is %0. Note that operations that create an alias but do not
6491840d18aSMatthias Springer     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
650*d9111f19SAmir Bishara     const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
6511fdf06d6SMatthias Springer     if (definitions.empty()) {
6521fdf06d6SMatthias Springer       // Fast path: No conflict if there are no definitions.
6531fdf06d6SMatthias Springer       LLVM_DEBUG(llvm::dbgs()
6541fdf06d6SMatthias Springer                  << "  no conflict: read value has no definitions\n");
6551fdf06d6SMatthias Springer       continue;
6561fdf06d6SMatthias Springer     }
6577a1579acSMatthias Springer 
6587a1579acSMatthias Springer     // Look for conflicting memory writes. Potential conflicts are writes to an
6597a1579acSMatthias Springer     // alias that have been decided to bufferize inplace.
6607a1579acSMatthias Springer     for (OpOperand *uConflictingWrite : usesWrite) {
661e0b40af7SMatthias Springer       LLVM_DEBUG(llvm::dbgs() << "  unConflictingWrite = operand "
662e0b40af7SMatthias Springer                               << uConflictingWrite->getOperandNumber() << " of "
663e0b40af7SMatthias Springer                               << *uConflictingWrite->getOwner() << "\n");
664e0b40af7SMatthias Springer 
665c89c31a2SMatthias Springer       // Check if op dominance can be used to rule out read-after-write
666c89c31a2SMatthias Springer       // conflicts.
667c89c31a2SMatthias Springer       bool useDominance =
668c89c31a2SMatthias Springer           canUseOpDominance(uRead, uConflictingWrite, definitions, state);
669c89c31a2SMatthias Springer       LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
670c89c31a2SMatthias Springer 
6717a1579acSMatthias Springer       // Throughout this loop, check for multiple requirements that have to be
6727a1579acSMatthias Springer       // met for uConflictingWrite to be an actual conflict.
6737a1579acSMatthias Springer       Operation *conflictingWritingOp = uConflictingWrite->getOwner();
6747a1579acSMatthias Springer 
6751f8ffbd1SMatthias Springer       // Inside of repetitive regions, ops may be executed multiple times and op
6761f8ffbd1SMatthias Springer       // dominance cannot be used to rule out conflicts.
6771f8ffbd1SMatthias Springer       if (useDominance) {
6781f8ffbd1SMatthias Springer         // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
6791f8ffbd1SMatthias Springer         // the write is not visible when reading.
6809235e597SMatthias Springer         //
6811f8ffbd1SMatthias Springer         // Note: If ops are executed multiple times (e.g., because they are
6821f8ffbd1SMatthias Springer         //       inside a loop), there may be no meaningful `happensBefore`
6831f8ffbd1SMatthias Springer         //       relationship.
684e0b40af7SMatthias Springer         if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
685e0b40af7SMatthias Springer           LLVM_DEBUG(llvm::dbgs()
686e0b40af7SMatthias Springer                      << "  no conflict: read happens before write\n");
6877a1579acSMatthias Springer           continue;
688e0b40af7SMatthias Springer         }
6897a1579acSMatthias Springer 
6901f8ffbd1SMatthias Springer         // No conflict if the reading use equals the use of the conflicting
6911f8ffbd1SMatthias Springer         // write. A use cannot conflict with itself.
6929235e597SMatthias Springer         //
6931f8ffbd1SMatthias Springer         // Note: Just being the same op is not enough. It has to be the same
6941f8ffbd1SMatthias Springer         //       use.
6951f8ffbd1SMatthias Springer         // Note: If the op is executed multiple times (e.g., because it is
6961f8ffbd1SMatthias Springer         //       inside a loop), it may be conflicting with itself.
697e0b40af7SMatthias Springer         if (uConflictingWrite == uRead) {
698e0b40af7SMatthias Springer           LLVM_DEBUG(llvm::dbgs()
699e0b40af7SMatthias Springer                      << "  no conflict: read and write are same use\n");
7007a1579acSMatthias Springer           continue;
701e0b40af7SMatthias Springer         }
7027a1579acSMatthias Springer 
7031f8ffbd1SMatthias Springer         // Ops are not conflicting if they are in mutually exclusive regions.
7041f8ffbd1SMatthias Springer         //
7051f8ffbd1SMatthias Springer         // Note: If ops are executed multiple times (e.g., because they are
7061f8ffbd1SMatthias Springer         //       inside a loop), mutually exclusive regions may be executed
7071f8ffbd1SMatthias Springer         //       multiple times.
708e0b40af7SMatthias Springer         if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
709e0b40af7SMatthias Springer           LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
710e0b40af7SMatthias Springer                                      "mutually exclusive regions\n");
7111f8ffbd1SMatthias Springer           continue;
7121f8ffbd1SMatthias Springer         }
7131f8ffbd1SMatthias Springer 
71454683405SMatthias Springer         // Two equivalent operands of the same op are not conflicting if the op
715cf9b77a6SMatthias Springer         // bufferizes to element-wise access. I.e., all loads at a position
716cf9b77a6SMatthias Springer         // happen before all stores to the same position.
717f36e1934SMatthias Springer         if (conflictingWritingOp == readingOp) {
71854683405SMatthias Springer           if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
719f36e1934SMatthias Springer             if (bufferizableOp.bufferizesToElementwiseAccess(
720f36e1934SMatthias Springer                     state, {uRead, uConflictingWrite})) {
721f36e1934SMatthias Springer               if (hasEquivalentValueInReverseUseDefChain(
722*d9111f19SAmir Bishara                       state, uRead, uConflictingWrite->get()) ||
723f36e1934SMatthias Springer                   hasEquivalentValueInReverseUseDefChain(
724*d9111f19SAmir Bishara                       state, uConflictingWrite, uRead->get())) {
72554683405SMatthias Springer                 LLVM_DEBUG(
72654683405SMatthias Springer                     llvm::dbgs()
72754683405SMatthias Springer                     << "  no conflict: op bufferizes to element-wise access\n");
72854683405SMatthias Springer                 continue;
72954683405SMatthias Springer               }
73054683405SMatthias Springer             }
73154683405SMatthias Springer           }
732f36e1934SMatthias Springer         }
733cf9b77a6SMatthias Springer       }
73454683405SMatthias Springer 
7358143307bSMatthias Springer       // No conflict if the operands are non-conflicting subsets.
7368143307bSMatthias Springer       if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
7378143307bSMatthias Springer         LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
7388143307bSMatthias Springer         continue;
7398143307bSMatthias Springer       }
7408143307bSMatthias Springer 
7417a1579acSMatthias Springer       // No conflict if the op interface says so.
742e0b40af7SMatthias Springer       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
743e0b40af7SMatthias Springer         if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
744e0b40af7SMatthias Springer           LLVM_DEBUG(llvm::dbgs()
745e0b40af7SMatthias Springer                      << "  no conflict: op interace of reading op says 'no'\n");
7467a1579acSMatthias Springer           continue;
747e0b40af7SMatthias Springer         }
748e0b40af7SMatthias Springer       }
7497a1579acSMatthias Springer 
750e0b40af7SMatthias Springer       if (conflictingWritingOp != readingOp) {
7517a1579acSMatthias Springer         if (auto bufferizableOp =
752e0b40af7SMatthias Springer                 options.dynCastBufferizableOp(conflictingWritingOp)) {
753e0b40af7SMatthias Springer           if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
754e0b40af7SMatthias Springer                                               state)) {
755e0b40af7SMatthias Springer             LLVM_DEBUG(
756e0b40af7SMatthias Springer                 llvm::dbgs()
757e0b40af7SMatthias Springer                 << "  no conflict: op interace of writing op says 'no'\n");
7587a1579acSMatthias Springer             continue;
759e0b40af7SMatthias Springer           }
760e0b40af7SMatthias Springer         }
761e0b40af7SMatthias Springer       }
7627a1579acSMatthias Springer 
7631840d18aSMatthias Springer       // Check all possible definitions.
7641840d18aSMatthias Springer       for (Value definition : definitions) {
7651840d18aSMatthias Springer         LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");
766e0b40af7SMatthias Springer 
7671840d18aSMatthias Springer         // No conflict if the conflicting write happens before the definition.
7681fdf06d6SMatthias Springer         if (Operation *defOp = definition.getDefiningOp()) {
7691fdf06d6SMatthias Springer           if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
7701fdf06d6SMatthias Springer             // conflictingWritingOp happens before defOp. No conflict.
771e0b40af7SMatthias Springer             LLVM_DEBUG(llvm::dbgs()
7721840d18aSMatthias Springer                        << "    no conflict: write happens before definition\n");
7737a1579acSMatthias Springer             continue;
774e0b40af7SMatthias Springer           }
7751fdf06d6SMatthias Springer           // No conflict if conflictingWritingOp is contained in defOp.
7761fdf06d6SMatthias Springer           if (defOp->isProperAncestor(conflictingWritingOp)) {
777e0b40af7SMatthias Springer             LLVM_DEBUG(
778e0b40af7SMatthias Springer                 llvm::dbgs()
7791840d18aSMatthias Springer                 << "    no conflict: write is contained in definition\n");
7807a1579acSMatthias Springer             continue;
781e0b40af7SMatthias Springer           }
7827a1579acSMatthias Springer         } else {
7835550c821STres Popp           auto bbArg = cast<BlockArgument>(definition);
7847a1579acSMatthias Springer           Block *block = bbArg.getOwner();
785e0b40af7SMatthias Springer           if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
7861840d18aSMatthias Springer             LLVM_DEBUG(llvm::dbgs() << "    no conflict: definition is bbArg "
787e0b40af7SMatthias Springer                                        "and write happens outside of block\n");
7887a1579acSMatthias Springer             // conflictingWritingOp happens outside of the block. No
7897a1579acSMatthias Springer             // conflict.
7907a1579acSMatthias Springer             continue;
7917a1579acSMatthias Springer           }
792e0b40af7SMatthias Springer         }
7937a1579acSMatthias Springer 
7941840d18aSMatthias Springer         // No conflict if the conflicting write and the definition are the same
7957a1579acSMatthias Springer         // use.
796a02ad6c1SMatthias Springer         AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
7979fa6b350SMatthias Springer         if (aliases.getNumAliases() == 1 &&
798a02ad6c1SMatthias Springer             aliases.getAliases()[0].value == definition) {
799e0b40af7SMatthias Springer           LLVM_DEBUG(llvm::dbgs()
8001840d18aSMatthias Springer                      << "    no conflict: definition and write are same\n");
8017a1579acSMatthias Springer           continue;
802e0b40af7SMatthias Springer         }
8037a1579acSMatthias Springer 
8047a1579acSMatthias Springer         // All requirements are met. Conflict found!
8057a1579acSMatthias Springer 
8067a1579acSMatthias Springer         if (options.printConflicts)
8071840d18aSMatthias Springer           annotateConflict(uRead, uConflictingWrite, definition);
808e0b40af7SMatthias Springer         LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
8097a1579acSMatthias Springer         return true;
8107a1579acSMatthias Springer       }
8117a1579acSMatthias Springer     }
8127a1579acSMatthias Springer   }
8137a1579acSMatthias Springer 
8147a1579acSMatthias Springer   return false;
8157a1579acSMatthias Springer }
8167a1579acSMatthias Springer 
817032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the writes.
818032be233SMatthias Springer static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
819cf2d374eSMatthias Springer                                      const OneShotAnalysisState &state) {
820cf2d374eSMatthias Springer   state.applyOnAliases(root, [&](Value alias) {
821032be233SMatthias Springer     for (auto &use : alias.getUses())
822032be233SMatthias Springer       // Inplace write to a value that aliases root.
823cf2d374eSMatthias Springer       if (isInplaceMemoryWrite(use, state))
824032be233SMatthias Springer         res.insert(&use);
825032be233SMatthias Springer   });
826032be233SMatthias Springer }
827032be233SMatthias Springer 
828032be233SMatthias Springer // Helper function to iterate on aliases of `root` and capture the reads.
829032be233SMatthias Springer static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
830cf2d374eSMatthias Springer                              const OneShotAnalysisState &state) {
831cf2d374eSMatthias Springer   state.applyOnAliases(root, [&](Value alias) {
8321b99f3a2SMatthias Springer     for (auto &use : alias.getUses()) {
8331b99f3a2SMatthias Springer       // Read of a value that aliases root.
8341b99f3a2SMatthias Springer       if (state.bufferizesToMemoryRead(use)) {
835032be233SMatthias Springer         res.insert(&use);
8361b99f3a2SMatthias Springer         continue;
8371b99f3a2SMatthias Springer       }
8381b99f3a2SMatthias Springer 
8391b99f3a2SMatthias Springer       // Read of a dependent value in the SSA use-def chain. E.g.:
8401b99f3a2SMatthias Springer       //
8411b99f3a2SMatthias Springer       // %0 = ...
8421b99f3a2SMatthias Springer       // %1 = tensor.extract_slice %0 {not_analyzed_yet}
8431b99f3a2SMatthias Springer       // "read"(%1)
8441b99f3a2SMatthias Springer       //
8451b99f3a2SMatthias Springer       // In the above example, getAliasingReads(%0) includes the first OpOperand
8461b99f3a2SMatthias Springer       // of the tensor.extract_slice op. The extract_slice itself does not read
8471b99f3a2SMatthias Springer       // but its aliasing result is eventually fed into an op that does.
8481b99f3a2SMatthias Springer       //
8491b99f3a2SMatthias Springer       // Note: This is considered a "read" only if the use does not bufferize to
8501b99f3a2SMatthias Springer       // a memory write. (We already ruled out memory reads. In case of a memory
8511b99f3a2SMatthias Springer       // write, the buffer would be entirely overwritten; in the above example
8521b99f3a2SMatthias Springer       // there would then be no flow of data from the extract_slice operand to
8531b99f3a2SMatthias Springer       // its result's uses.)
8541b99f3a2SMatthias Springer       if (!state.bufferizesToMemoryWrite(use)) {
855a02ad6c1SMatthias Springer         AliasingValueList aliases = state.getAliasingValues(use);
856a02ad6c1SMatthias Springer         if (llvm::any_of(aliases, [&](AliasingValue a) {
857a02ad6c1SMatthias Springer               return state.isValueRead(a.value);
8589fa6b350SMatthias Springer             }))
8591b99f3a2SMatthias Springer           res.insert(&use);
8601b99f3a2SMatthias Springer       }
8611b99f3a2SMatthias Springer     }
862032be233SMatthias Springer   });
863032be233SMatthias Springer }
864032be233SMatthias Springer 
8657a1579acSMatthias Springer /// Return true if bufferizing `operand` inplace would create a conflict. A read
8667a1579acSMatthias Springer /// R and a write W of the same alias set is a conflict if inplace bufferization
8677a1579acSMatthias Springer /// of W changes the value read by R to a value different from the one that
8687a1579acSMatthias Springer /// would be expected by tracing back R's origin through SSA use-def chains.
8697a1579acSMatthias Springer /// A conflict can only be introduced by a new alias and/or an inplace
8707a1579acSMatthias Springer /// bufferization decision.
8717a1579acSMatthias Springer ///
8727a1579acSMatthias Springer /// Example:
8737a1579acSMatthias Springer /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
8747a1579acSMatthias Springer /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
8757a1579acSMatthias Springer /// %e = tensor.extract_slice %1
8767a1579acSMatthias Springer /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
8777a1579acSMatthias Springer /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
8787a1579acSMatthias Springer ///
8797a1579acSMatthias Springer /// In the above example, the two TransferWriteOps have already been decided to
8807a1579acSMatthias Springer /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
8817a1579acSMatthias Springer /// conflict because:
8827a1579acSMatthias Springer /// * According to SSA use-def chains, we expect to read the result of %1.
8837a1579acSMatthias Springer /// * However, adding an alias {%0, %t} would mean that the second
8841840d18aSMatthias Springer ///   TransferWriteOp overwrites the result of the first one. Therefore, the
8851840d18aSMatthias Springer ///   TransferReadOp would no longer be reading the result of %1.
8867a1579acSMatthias Springer ///
8877a1579acSMatthias Springer /// If `checkConsistencyOnly` is true, this function checks if there is a
8887a1579acSMatthias Springer /// read-after-write conflict without bufferizing `operand` inplace. This would
8897a1579acSMatthias Springer /// indicate a problem with the current inplace bufferization decisions.
8907a1579acSMatthias Springer ///
8917a1579acSMatthias Springer /// Note: If `checkConsistencyOnly`, this function may be called with a null
8927a1579acSMatthias Springer /// OpResult. In that case, only the consistency of bufferization decisions
8937a1579acSMatthias Springer /// involving aliases of the given OpOperand are checked.
8947a1579acSMatthias Springer static bool wouldCreateReadAfterWriteInterference(
895cf2d374eSMatthias Springer     OpOperand &operand, const DominanceInfo &domInfo,
896cf2d374eSMatthias Springer     OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
8977a1579acSMatthias Springer   // Collect reads and writes of all aliases of OpOperand and OpResult.
8987a1579acSMatthias Springer   DenseSet<OpOperand *> usesRead, usesWrite;
899cf2d374eSMatthias Springer   getAliasingReads(usesRead, operand.get(), state);
900cf2d374eSMatthias Springer   getAliasingInplaceWrites(usesWrite, operand.get(), state);
901a02ad6c1SMatthias Springer   for (AliasingValue alias : state.getAliasingValues(operand)) {
902a02ad6c1SMatthias Springer     getAliasingReads(usesRead, alias.value, state);
903a02ad6c1SMatthias Springer     getAliasingInplaceWrites(usesWrite, alias.value, state);
9047a1579acSMatthias Springer   }
9057a1579acSMatthias Springer   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
9067a1579acSMatthias Springer     usesWrite.insert(&operand);
9077a1579acSMatthias Springer 
908cf2d374eSMatthias Springer   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
9097a1579acSMatthias Springer }
9107a1579acSMatthias Springer 
911be630f07SMatthias Springer /// Annotate IR with details about the detected non-writability conflict.
912be630f07SMatthias Springer static void annotateNonWritableTensor(Value value) {
913be630f07SMatthias Springer   static int64_t counter = 0;
914be630f07SMatthias Springer   OpBuilder b(value.getContext());
915be630f07SMatthias Springer   std::string id = "W_" + std::to_string(counter++);
9165550c821STres Popp   if (auto opResult = dyn_cast<OpResult>(value)) {
917be630f07SMatthias Springer     std::string attr = id + "[NOT-WRITABLE: result " +
918be630f07SMatthias Springer                        std::to_string(opResult.getResultNumber()) + "]";
919be630f07SMatthias Springer     opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
920be630f07SMatthias Springer   } else {
9215550c821STres Popp     auto bbArg = cast<BlockArgument>(value);
922be630f07SMatthias Springer     std::string attr = id + "[NOT-WRITABLE: bbArg " +
923be630f07SMatthias Springer                        std::to_string(bbArg.getArgNumber()) + "]";
924be630f07SMatthias Springer     bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
925be630f07SMatthias Springer   }
926be630f07SMatthias Springer }
927be630f07SMatthias Springer 
928032be233SMatthias Springer /// Return true if bufferizing `operand` inplace would create a write to a
929032be233SMatthias Springer /// non-writable buffer.
930cf2d374eSMatthias Springer static bool
931cf2d374eSMatthias Springer wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
932cf2d374eSMatthias Springer                                     OneShotAnalysisState &state,
933cf2d374eSMatthias Springer                                     bool checkConsistencyOnly = false) {
934dc7ad194SMatthias Springer   bool foundWrite =
935dc7ad194SMatthias Springer       !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
936dc7ad194SMatthias Springer 
937dc7ad194SMatthias Springer   if (!foundWrite) {
938032be233SMatthias Springer     // Collect writes of all aliases of OpOperand and OpResult.
939032be233SMatthias Springer     DenseSet<OpOperand *> usesWrite;
940cf2d374eSMatthias Springer     getAliasingInplaceWrites(usesWrite, operand.get(), state);
941a02ad6c1SMatthias Springer     for (AliasingValue alias : state.getAliasingValues(operand))
942a02ad6c1SMatthias Springer       getAliasingInplaceWrites(usesWrite, alias.value, state);
943dc7ad194SMatthias Springer     foundWrite = !usesWrite.empty();
944032be233SMatthias Springer   }
9457a1579acSMatthias Springer 
946dc7ad194SMatthias Springer   if (!foundWrite)
947dc7ad194SMatthias Springer     return false;
948dc7ad194SMatthias Springer 
949dc7ad194SMatthias Springer   // Look for a read-only tensor among all aliases.
950dc7ad194SMatthias Springer   bool foundReadOnly = false;
951dc7ad194SMatthias Springer   auto checkReadOnly = [&](Value v) {
952dc7ad194SMatthias Springer     if (!state.isWritable(v)) {
953dc7ad194SMatthias Springer       foundReadOnly = true;
954dc7ad194SMatthias Springer       if (state.getOptions().printConflicts)
955dc7ad194SMatthias Springer         annotateNonWritableTensor(v);
956dc7ad194SMatthias Springer     }
957dc7ad194SMatthias Springer   };
958dc7ad194SMatthias Springer   state.applyOnAliases(operand.get(), checkReadOnly);
959a02ad6c1SMatthias Springer   for (AliasingValue alias : state.getAliasingValues(operand))
960a02ad6c1SMatthias Springer     state.applyOnAliases(alias.value, checkReadOnly);
961dc7ad194SMatthias Springer   if (foundReadOnly) {
962e0b40af7SMatthias Springer     LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
963032be233SMatthias Springer     return true;
964e0b40af7SMatthias Springer   }
9657a1579acSMatthias Springer 
966032be233SMatthias Springer   return false;
9677a1579acSMatthias Springer }
9687a1579acSMatthias Springer 
9697a1579acSMatthias Springer //===----------------------------------------------------------------------===//
9707a1579acSMatthias Springer // Bufferization analyses.
9717a1579acSMatthias Springer //===----------------------------------------------------------------------===//
9727a1579acSMatthias Springer 
973*d9111f19SAmir Bishara // Find the values that define the contents of the given operand's value.
9742b5a020dSMatthias Springer const llvm::SetVector<Value> &
975*d9111f19SAmir Bishara OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
976*d9111f19SAmir Bishara   Value value = opOperand->get();
9771f479c1eSMatthias Springer   if (!cachedDefinitions.count(value))
978*d9111f19SAmir Bishara     cachedDefinitions[value] = findDefinitions(opOperand);
9792b5a020dSMatthias Springer   return cachedDefinitions[value];
9802b5a020dSMatthias Springer }
9812b5a020dSMatthias Springer 
9829312b4f9SMartin Erhart void OneShotAnalysisState::resetCache() {
9839312b4f9SMartin Erhart   AnalysisState::resetCache();
9849312b4f9SMartin Erhart   cachedDefinitions.clear();
9859312b4f9SMartin Erhart }
9862b5a020dSMatthias Springer 
9877a1579acSMatthias Springer /// Determine if `operand` can be bufferized in-place.
988cf2d374eSMatthias Springer static LogicalResult
989cf2d374eSMatthias Springer bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
990cf2d374eSMatthias Springer                                 const DominanceInfo &domInfo) {
991e0b40af7SMatthias Springer   LLVM_DEBUG(
992e0b40af7SMatthias Springer       llvm::dbgs() << "//===-------------------------------------------===//\n"
993e0b40af7SMatthias Springer                    << "Analyzing operand #" << operand.getOperandNumber()
994e0b40af7SMatthias Springer                    << " of " << *operand.getOwner() << "\n");
995e0b40af7SMatthias Springer 
9967a1579acSMatthias Springer   bool foundInterference =
997cf2d374eSMatthias Springer       wouldCreateWriteToNonWritableBuffer(operand, state) ||
998cf2d374eSMatthias Springer       wouldCreateReadAfterWriteInterference(operand, domInfo, state);
9997a1579acSMatthias Springer 
10007a1579acSMatthias Springer   if (foundInterference)
1001cf2d374eSMatthias Springer     state.bufferizeOutOfPlace(operand);
10027a1579acSMatthias Springer   else
1003cf2d374eSMatthias Springer     state.bufferizeInPlace(operand);
10047a1579acSMatthias Springer 
1005e0b40af7SMatthias Springer   LLVM_DEBUG(llvm::dbgs()
1006e0b40af7SMatthias Springer              << "//===-------------------------------------------===//\n");
10077a1579acSMatthias Springer   return success();
10087a1579acSMatthias Springer }
10097a1579acSMatthias Springer 
10106d14b110SMatthias Springer LogicalResult
10116d14b110SMatthias Springer OneShotAnalysisState::analyzeSingleOp(Operation *op,
10126d14b110SMatthias Springer                                       const DominanceInfo &domInfo) {
10137a1579acSMatthias Springer   for (OpOperand &opOperand : op->getOpOperands())
10145550c821STres Popp     if (isa<TensorType>(opOperand.get().getType()))
10156d14b110SMatthias Springer       if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
10167a1579acSMatthias Springer         return failure();
10171b99f3a2SMatthias Springer   return success();
10187a1579acSMatthias Springer }
10197a1579acSMatthias Springer 
10207a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
10217a1579acSMatthias Springer static void equivalenceAnalysis(SmallVector<Operation *> &ops,
1022cf2d374eSMatthias Springer                                 OneShotAnalysisState &state) {
10239fa6b350SMatthias Springer   for (Operation *op : ops) {
10249fa6b350SMatthias Springer     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
10259fa6b350SMatthias Springer       for (OpResult opResult : op->getOpResults()) {
10265550c821STres Popp         if (!isa<TensorType>(opResult.getType()))
10279fa6b350SMatthias Springer           continue;
10289fa6b350SMatthias Springer         AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
10299fa6b350SMatthias Springer         if (aliases.getNumAliases() == 0)
10309fa6b350SMatthias Springer           // Nothing to do if there are no aliasing OpOperands.
10319fa6b350SMatthias Springer           continue;
10329fa6b350SMatthias Springer 
10339fa6b350SMatthias Springer         Value firstOperand = aliases.begin()->opOperand->get();
10349fa6b350SMatthias Springer         bool allEquivalent = true;
10359fa6b350SMatthias Springer         for (AliasingOpOperand alias : aliases) {
10369fa6b350SMatthias Springer           bool isEquiv = alias.relation == BufferRelation::Equivalent;
10379fa6b350SMatthias Springer           bool isInPlace = state.isInPlace(*alias.opOperand);
10389fa6b350SMatthias Springer           Value operand = alias.opOperand->get();
10399fa6b350SMatthias Springer           if (isEquiv && isInPlace && alias.isDefinite) {
10409fa6b350SMatthias Springer             // Found a definite, equivalent alias. Merge equivalence sets.
10419fa6b350SMatthias Springer             // There can only be one definite alias, so we can stop here.
10429fa6b350SMatthias Springer             state.unionEquivalenceClasses(opResult, operand);
10439fa6b350SMatthias Springer             allEquivalent = false;
10449fa6b350SMatthias Springer             break;
10459fa6b350SMatthias Springer           }
10469fa6b350SMatthias Springer           if (!isEquiv || !isInPlace)
10479fa6b350SMatthias Springer             allEquivalent = false;
10489fa6b350SMatthias Springer           if (!state.areEquivalentBufferizedValues(operand, firstOperand))
10499fa6b350SMatthias Springer             allEquivalent = false;
10509fa6b350SMatthias Springer         }
10519fa6b350SMatthias Springer 
10529fa6b350SMatthias Springer         // If all "maybe" aliases are equivalent and the OpResult is not a new
10539fa6b350SMatthias Springer         // allocation, it is a definite, equivalent alias. E.g.:
10549fa6b350SMatthias Springer         //
10559fa6b350SMatthias Springer         // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
1056a02ad6c1SMatthias Springer         // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
1057a02ad6c1SMatthias Springer         // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
10589fa6b350SMatthias Springer         // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
10599fa6b350SMatthias Springer         //
10609fa6b350SMatthias Springer         // If %t0 and %t1 are equivalent, it is safe to union the equivalence
10619fa6b350SMatthias Springer         // classes of %r, %t0 and %t1.
10629fa6b350SMatthias Springer         if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
10639fa6b350SMatthias Springer           state.unionEquivalenceClasses(opResult, firstOperand);
10649fa6b350SMatthias Springer       }
10659fa6b350SMatthias Springer     }
10669fa6b350SMatthias Springer   }
10677a1579acSMatthias Springer }
10687a1579acSMatthias Springer 
10697a1579acSMatthias Springer /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
10707a1579acSMatthias Springer /// in `op`.
1071cf2d374eSMatthias Springer static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
10727a1579acSMatthias Springer   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
10737a1579acSMatthias Springer   SmallVector<Operation *> ops;
10747a1579acSMatthias Springer   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
10757a1579acSMatthias Springer     // No tensors => no buffers.
10767a1579acSMatthias Springer     if (none_of(op->getResultTypes(), isaTensor))
10777a1579acSMatthias Springer       return;
10787a1579acSMatthias Springer     ops.push_back(op);
10797a1579acSMatthias Springer   });
10807a1579acSMatthias Springer 
1081cf2d374eSMatthias Springer   equivalenceAnalysis(ops, state);
10827a1579acSMatthias Springer }
10837a1579acSMatthias Springer 
108435d3b343SMatthias Springer /// "Bottom-up from terminators" heuristic.
108535d3b343SMatthias Springer static SmallVector<Operation *>
108635d3b343SMatthias Springer bottomUpFromTerminatorsHeuristic(Operation *op,
108735d3b343SMatthias Springer                                  const OneShotAnalysisState &state) {
108835d3b343SMatthias Springer   SetVector<Operation *> traversedOps;
108935d3b343SMatthias Springer 
109035d3b343SMatthias Springer   // Find region terminators.
109135d3b343SMatthias Springer   op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
109235d3b343SMatthias Springer     if (!traversedOps.insert(term))
109335d3b343SMatthias Springer       return;
109435d3b343SMatthias Springer     // Follow the reverse SSA use-def chain from each yielded value as long as
109535d3b343SMatthias Springer     // we stay within the same region.
109635d3b343SMatthias Springer     SmallVector<OpResult> worklist;
109735d3b343SMatthias Springer     for (Value v : term->getOperands()) {
109835d3b343SMatthias Springer       if (!isa<TensorType>(v.getType()))
109935d3b343SMatthias Springer         continue;
110035d3b343SMatthias Springer       auto opResult = dyn_cast<OpResult>(v);
110135d3b343SMatthias Springer       if (!opResult)
110235d3b343SMatthias Springer         continue;
110335d3b343SMatthias Springer       worklist.push_back(opResult);
110435d3b343SMatthias Springer     }
110535d3b343SMatthias Springer     while (!worklist.empty()) {
110635d3b343SMatthias Springer       OpResult opResult = worklist.pop_back_val();
110735d3b343SMatthias Springer       Operation *defOp = opResult.getDefiningOp();
110835d3b343SMatthias Springer       if (!traversedOps.insert(defOp))
110935d3b343SMatthias Springer         continue;
111035d3b343SMatthias Springer       if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
111135d3b343SMatthias Springer         continue;
111235d3b343SMatthias Springer       AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
111335d3b343SMatthias Springer       for (auto alias : aliases) {
111435d3b343SMatthias Springer         Value v = alias.opOperand->get();
111535d3b343SMatthias Springer         if (!isa<TensorType>(v.getType()))
111635d3b343SMatthias Springer           continue;
111735d3b343SMatthias Springer         auto opResult = dyn_cast<OpResult>(v);
111835d3b343SMatthias Springer         if (!opResult)
111935d3b343SMatthias Springer           continue;
112035d3b343SMatthias Springer         worklist.push_back(opResult);
112135d3b343SMatthias Springer       }
112235d3b343SMatthias Springer     }
112335d3b343SMatthias Springer   });
112435d3b343SMatthias Springer 
112535d3b343SMatthias Springer   // Analyze traversed ops, then all remaining ops.
112635d3b343SMatthias Springer   SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
112735d3b343SMatthias Springer   op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
112835d3b343SMatthias Springer     if (!traversedOps.contains(op) && hasTensorSemantics(op))
112935d3b343SMatthias Springer       result.push_back(op);
113035d3b343SMatthias Springer   });
113135d3b343SMatthias Springer   return result;
113235d3b343SMatthias Springer }
113335d3b343SMatthias Springer 
11346d14b110SMatthias Springer LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
11356d14b110SMatthias Springer                                               const DominanceInfo &domInfo) {
113635d3b343SMatthias Springer   OneShotBufferizationOptions::AnalysisHeuristic heuristic =
113735d3b343SMatthias Springer       getOptions().analysisHeuristic;
113835d3b343SMatthias Springer 
113935d3b343SMatthias Springer   SmallVector<Operation *> orderedOps;
114035d3b343SMatthias Springer   if (heuristic ==
114135d3b343SMatthias Springer       OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
114235d3b343SMatthias Springer     orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
114335d3b343SMatthias Springer   } else {
11446d14b110SMatthias Springer     op->walk([&](Operation *op) {
11456d14b110SMatthias Springer       // No tensors => no buffers.
11466d14b110SMatthias Springer       if (!hasTensorSemantics(op))
11476d14b110SMatthias Springer         return;
114835d3b343SMatthias Springer       orderedOps.push_back(op);
11496d14b110SMatthias Springer     });
115035d3b343SMatthias Springer     switch (heuristic) {
115135d3b343SMatthias Springer     case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
11526d14b110SMatthias Springer       // Default: Walk ops in reverse for better interference analysis.
115335d3b343SMatthias Springer       std::reverse(orderedOps.begin(), orderedOps.end());
115435d3b343SMatthias Springer       break;
115535d3b343SMatthias Springer     }
115635d3b343SMatthias Springer     case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
115735d3b343SMatthias Springer       // Ops are already sorted top-down in `orderedOps`.
115835d3b343SMatthias Springer       break;
115935d3b343SMatthias Springer     }
116035d3b343SMatthias Springer     case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
116135d3b343SMatthias Springer       assert(getOptions().analysisFuzzerSeed &&
116235d3b343SMatthias Springer              "expected that fuzzer seed it set");
116335d3b343SMatthias Springer       // This is a fuzzer. For testing purposes only. Randomize the order in
116435d3b343SMatthias Springer       // which operations are analyzed. The bufferization quality is likely
116535d3b343SMatthias Springer       // worse, but we want to make sure that no assertions are triggered
116635d3b343SMatthias Springer       // anywhere.
116735d3b343SMatthias Springer       std::mt19937 g(getOptions().analysisFuzzerSeed);
116835d3b343SMatthias Springer       llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
116935d3b343SMatthias Springer       break;
117035d3b343SMatthias Springer     }
117135d3b343SMatthias Springer     default: {
11726d14b110SMatthias Springer       llvm_unreachable("unsupported heuristic");
11736d14b110SMatthias Springer     }
117435d3b343SMatthias Springer     }
117535d3b343SMatthias Springer   }
117635d3b343SMatthias Springer 
117735d3b343SMatthias Springer   // Analyze ops in the computed order.
117835d3b343SMatthias Springer   for (Operation *op : orderedOps)
117935d3b343SMatthias Springer     if (failed(analyzeSingleOp(op, domInfo)))
118035d3b343SMatthias Springer       return failure();
11816d14b110SMatthias Springer 
11826d14b110SMatthias Springer   equivalenceAnalysis(op, *this);
11836d14b110SMatthias Springer   return success();
11846d14b110SMatthias Springer }
11856d14b110SMatthias Springer 
1186061aa2e3SMatthias Springer /// Perform various checks on the input IR to see if it contains IR constructs
1187061aa2e3SMatthias Springer /// that are unsupported by One-Shot Bufferize.
1188061aa2e3SMatthias Springer static LogicalResult
1189061aa2e3SMatthias Springer checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
1190cf2d374eSMatthias Springer                                  OneShotAnalysisState &state) {
11917a1579acSMatthias Springer   const BufferizationOptions &options = state.getOptions();
11923f914d84SMatthias Springer 
1193061aa2e3SMatthias Springer   // Note: This walk cannot be combined with the one below because interface
1194061aa2e3SMatthias Springer   // methods of invalid/unsupported ops may be called during the second walk.
1195061aa2e3SMatthias Springer   // (On ops different from `op`.)
11963f914d84SMatthias Springer   WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
11973f914d84SMatthias Springer     // Skip ops that are not in the filter.
11983f914d84SMatthias Springer     if (!options.isOpAllowed(op.getOperation()))
11993f914d84SMatthias Springer       return WalkResult::advance();
12003f914d84SMatthias Springer 
1201061aa2e3SMatthias Springer     // Check for unsupported unstructured control flow.
1202061aa2e3SMatthias Springer     if (!op.supportsUnstructuredControlFlow()) {
1203061aa2e3SMatthias Springer       for (Region &r : op->getRegions()) {
1204061aa2e3SMatthias Springer         if (r.getBlocks().size() > 1) {
1205061aa2e3SMatthias Springer           op->emitOpError("op or BufferizableOpInterface implementation does "
1206061aa2e3SMatthias Springer                           "not support unstructured control flow, but at least "
1207061aa2e3SMatthias Springer                           "one region has multiple blocks");
1208061aa2e3SMatthias Springer           return WalkResult::interrupt();
1209061aa2e3SMatthias Springer         }
1210061aa2e3SMatthias Springer       }
1211061aa2e3SMatthias Springer     }
1212061aa2e3SMatthias Springer 
1213061aa2e3SMatthias Springer     return WalkResult::advance();
1214061aa2e3SMatthias Springer   });
1215061aa2e3SMatthias Springer   if (walkResult.wasInterrupted())
1216061aa2e3SMatthias Springer     return failure();
1217061aa2e3SMatthias Springer 
1218061aa2e3SMatthias Springer   walkResult = op->walk([&](BufferizableOpInterface op) {
1219061aa2e3SMatthias Springer     // Skip ops that are not in the filter.
1220061aa2e3SMatthias Springer     if (!options.isOpAllowed(op.getOperation()))
1221061aa2e3SMatthias Springer       return WalkResult::advance();
1222061aa2e3SMatthias Springer 
12238f7e7400SMatthias Springer     // Input IR may not contain any ToTensorOps without the "restrict"
12248f7e7400SMatthias Springer     // attribute. Such tensors may alias any other tensor, which is currently
12258f7e7400SMatthias Springer     // not handled in the analysis.
12268f7e7400SMatthias Springer     if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
12276badbd6fSMatthias Springer       if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
12288ee38f3bSMatthias Springer         op->emitOpError("to_tensor ops without `restrict` are not supported by "
12298f7e7400SMatthias Springer                         "One-Shot Analysis");
12308f7e7400SMatthias Springer         return WalkResult::interrupt();
12318f7e7400SMatthias Springer       }
12328f7e7400SMatthias Springer     }
12338f7e7400SMatthias Springer 
12343f914d84SMatthias Springer     for (OpOperand &opOperand : op->getOpOperands()) {
12355550c821STres Popp       if (isa<TensorType>(opOperand.get().getType())) {
12367a1579acSMatthias Springer         if (wouldCreateReadAfterWriteInterference(
1237cf2d374eSMatthias Springer                 opOperand, domInfo, state,
12387a1579acSMatthias Springer                 /*checkConsistencyOnly=*/true)) {
12397a1579acSMatthias Springer           // This error can happen if certain "mustBufferizeInPlace" interface
12407a1579acSMatthias Springer           // methods are implemented incorrectly, such that the IR already has
12418ee38f3bSMatthias Springer           // a RaW conflict before making any bufferization decisions. It can
12428ee38f3bSMatthias Springer           // also happen if the bufferization.materialize_in_destination is used
12438ee38f3bSMatthias Springer           // in such a way that a RaW conflict is not avoidable.
12448ee38f3bSMatthias Springer           op->emitOpError("not bufferizable under the given constraints: "
12458ee38f3bSMatthias Springer                           "cannot avoid RaW conflict");
12468ee38f3bSMatthias Springer           return WalkResult::interrupt();
12478ee38f3bSMatthias Springer         }
12488ee38f3bSMatthias Springer 
12498ee38f3bSMatthias Springer         if (state.isInPlace(opOperand) &&
12508ee38f3bSMatthias Springer             wouldCreateWriteToNonWritableBuffer(
12518ee38f3bSMatthias Springer                 opOperand, state, /*checkConsistencyOnly=*/true)) {
12528ee38f3bSMatthias Springer           op->emitOpError("not bufferizable under the given constraints: would "
12538ee38f3bSMatthias Springer                           "write to read-only buffer");
12547a1579acSMatthias Springer           return WalkResult::interrupt();
12557a1579acSMatthias Springer         }
12567a1579acSMatthias Springer       }
12573f914d84SMatthias Springer     }
12583f914d84SMatthias Springer 
12597a1579acSMatthias Springer     return WalkResult::advance();
12607a1579acSMatthias Springer   });
12617a1579acSMatthias Springer 
12623f914d84SMatthias Springer   return success(!walkResult.wasInterrupted());
12637a1579acSMatthias Springer }
12647a1579acSMatthias Springer 
12657a1579acSMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only.
12667a1579acSMatthias Springer static void
12677a1579acSMatthias Springer annotateOpsWithBufferizationMarkers(Operation *op,
1268cf2d374eSMatthias Springer                                     const OneShotAnalysisState &state) {
12697cdfc843SMatthias Springer   // Add __inplace_operands_attr__.
1270f3483c23SMatthias Springer   op->walk([&](Operation *op) {
1271f3483c23SMatthias Springer     for (OpOperand &opOperand : op->getOpOperands())
12725550c821STres Popp       if (isa<TensorType>(opOperand.get().getType()))
1273cf2d374eSMatthias Springer         setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
12747a1579acSMatthias Springer   });
12757a1579acSMatthias Springer }
12767a1579acSMatthias Springer 
1277bb9d1b55SMatthias Springer static void annotateOpsWithAliasSets(Operation *op,
1278bb9d1b55SMatthias Springer                                      const OneShotAnalysisState &state) {
1279bb9d1b55SMatthias Springer   AsmState asmState(op);
1280bb9d1b55SMatthias Springer   Builder b(op->getContext());
1281a02ad6c1SMatthias Springer   // Helper function to build an array attribute of aliasing SSA value strings.
1282a02ad6c1SMatthias Springer   auto buildAliasesArray = [&](Value v) {
1283bb9d1b55SMatthias Springer     SmallVector<Attribute> aliases;
1284a02ad6c1SMatthias Springer     state.applyOnAliases(v, [&](Value alias) {
1285bb9d1b55SMatthias Springer       std::string buffer;
1286bb9d1b55SMatthias Springer       llvm::raw_string_ostream stream(buffer);
1287bb9d1b55SMatthias Springer       alias.printAsOperand(stream, asmState);
1288884221edSJOE1994       aliases.push_back(b.getStringAttr(buffer));
1289bb9d1b55SMatthias Springer     });
1290a02ad6c1SMatthias Springer     return b.getArrayAttr(aliases);
1291a02ad6c1SMatthias Springer   };
1292a02ad6c1SMatthias Springer 
1293a02ad6c1SMatthias Springer   op->walk([&](Operation *op) {
1294a02ad6c1SMatthias Springer     // Build alias set array for every OpResult.
1295a02ad6c1SMatthias Springer     SmallVector<Attribute> opResultAliasSets;
1296a02ad6c1SMatthias Springer     for (OpResult opResult : op->getOpResults()) {
1297a02ad6c1SMatthias Springer       if (llvm::isa<TensorType>(opResult.getType())) {
1298a02ad6c1SMatthias Springer         opResultAliasSets.push_back(buildAliasesArray(opResult));
1299bb9d1b55SMatthias Springer       }
1300bb9d1b55SMatthias Springer     }
1301a02ad6c1SMatthias Springer     if (!opResultAliasSets.empty())
1302a02ad6c1SMatthias Springer       op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));
1303a02ad6c1SMatthias Springer 
1304a02ad6c1SMatthias Springer     // Build alias set array for every BlockArgument.
1305a02ad6c1SMatthias Springer     SmallVector<Attribute> regionAliasSets;
1306a02ad6c1SMatthias Springer     bool hasTensorBbArg = false;
1307a02ad6c1SMatthias Springer     for (Region &r : op->getRegions()) {
1308a02ad6c1SMatthias Springer       SmallVector<Attribute> blockAliasSets;
1309a02ad6c1SMatthias Springer       for (Block &block : r.getBlocks()) {
1310a02ad6c1SMatthias Springer         SmallVector<Attribute> bbArgAliasSets;
1311a02ad6c1SMatthias Springer         for (BlockArgument bbArg : block.getArguments()) {
1312a02ad6c1SMatthias Springer           if (llvm::isa<TensorType>(bbArg.getType())) {
1313a02ad6c1SMatthias Springer             bbArgAliasSets.push_back(buildAliasesArray(bbArg));
1314a02ad6c1SMatthias Springer             hasTensorBbArg = true;
1315a02ad6c1SMatthias Springer           }
1316a02ad6c1SMatthias Springer         }
1317a02ad6c1SMatthias Springer         blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
1318a02ad6c1SMatthias Springer       }
1319a02ad6c1SMatthias Springer       regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
1320a02ad6c1SMatthias Springer     }
1321a02ad6c1SMatthias Springer     if (hasTensorBbArg)
1322a02ad6c1SMatthias Springer       op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
1323bb9d1b55SMatthias Springer   });
1324bb9d1b55SMatthias Springer }
1325bb9d1b55SMatthias Springer 
13267a1579acSMatthias Springer LogicalResult bufferization::analyzeOp(Operation *op,
1327ae05bd99SMatthias Springer                                        OneShotAnalysisState &state,
1328ae05bd99SMatthias Springer                                        BufferizationStatistics *statistics) {
13297a1579acSMatthias Springer   DominanceInfo domInfo(op);
13307cdfc843SMatthias Springer   const OneShotBufferizationOptions &options = state.getOptions();
13317a1579acSMatthias Springer 
1332061aa2e3SMatthias Springer   if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
13337a1579acSMatthias Springer     return failure();
13347a1579acSMatthias Springer 
13357a1579acSMatthias Springer   // If the analysis fails, just return.
13366d14b110SMatthias Springer   if (failed(state.analyzeOp(op, domInfo)))
13377a1579acSMatthias Springer     return failure();
1338ae05bd99SMatthias Springer 
1339ae05bd99SMatthias Springer   if (statistics) {
1340cf2d374eSMatthias Springer     statistics->numTensorInPlace = state.getStatNumTensorInPlace();
1341cf2d374eSMatthias Springer     statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
1342ae05bd99SMatthias Springer   }
1343ae05bd99SMatthias Springer 
1344d1d79920SMatthias Springer   bool failedAnalysis = false;
13457a1579acSMatthias Springer 
1346988748c0SMatthias Springer   // Gather some extra analysis data.
1347988748c0SMatthias Springer   state.gatherUndefinedTensorUses(op);
13489e24f0f4SMatthias Springer 
13494ec00fb3SMatthias Springer   // Analysis verification: After setting up alias/equivalence sets, each op
13504ec00fb3SMatthias Springer   // can check for expected invariants/limitations and fail the analysis if
13514ec00fb3SMatthias Springer   // necessary.
13524ec00fb3SMatthias Springer   op->walk([&](Operation *op) {
13534ec00fb3SMatthias Springer     if (BufferizableOpInterface bufferizableOp =
13544ec00fb3SMatthias Springer             options.dynCastBufferizableOp(op))
1355d1d79920SMatthias Springer       failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
13564ec00fb3SMatthias Springer   });
13574ec00fb3SMatthias Springer 
13587a1579acSMatthias Springer   // Annotate operations if we only want to report the analysis.
13597a1579acSMatthias Springer   if (options.testAnalysisOnly)
1360cf2d374eSMatthias Springer     annotateOpsWithBufferizationMarkers(op, state);
1361bb9d1b55SMatthias Springer   if (options.dumpAliasSets)
1362bb9d1b55SMatthias Springer     annotateOpsWithAliasSets(op, state);
13637a1579acSMatthias Springer 
1364d1d79920SMatthias Springer   return success(!failedAnalysis);
13657a1579acSMatthias Springer }
13667a1579acSMatthias Springer 
13679597b16aSMatthias Springer LogicalResult
13689597b16aSMatthias Springer bufferization::runOneShotBufferize(Operation *op,
1369ae05bd99SMatthias Springer                                    const OneShotBufferizationOptions &options,
1370ae05bd99SMatthias Springer                                    BufferizationStatistics *statistics) {
1371179e1749SMatthias Springer   // copy-before-write deactivates the analysis. It cannot be used together with
1372179e1749SMatthias Springer   // test-analysis-only.
1373f7dd9a32SMatthias Springer   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
1374f7dd9a32SMatthias Springer          "invalid combination of bufferization flags");
1375179e1749SMatthias Springer 
1376179e1749SMatthias Springer   if (options.copyBeforeWrite) {
1377179e1749SMatthias Springer     // Copy buffer before each write. No analysis is needed.
1378179e1749SMatthias Springer   } else {
1379179e1749SMatthias Springer     // Run One-Shot Analysis and insert buffer copies (on the tensor level)
1380179e1749SMatthias Springer     // only where needed. This is the default and much more efficient than
1381179e1749SMatthias Springer     // copy-before-write.
1382ae05bd99SMatthias Springer     if (failed(insertTensorCopies(op, options, statistics)))
13837a1579acSMatthias Springer       return failure();
1384179e1749SMatthias Springer 
1385179e1749SMatthias Springer     // If test-analysis-only is set, the IR was annotated with RaW conflict
1386179e1749SMatthias Springer     // markers (attributes) during One-Shot Analysis.
1387d2dacde5SMatthias Springer     if (options.testAnalysisOnly)
13887a1579acSMatthias Springer       return success();
1389179e1749SMatthias Springer   }
1390179e1749SMatthias Springer 
1391179e1749SMatthias Springer   // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
1392179e1749SMatthias Springer   // a new buffer copy is allocated every time a buffer is written to.
13939d34c052SMatthias Springer   return bufferizeOp(op, options, statistics);
13947a1579acSMatthias Springer }
1395