xref: /llvm-project/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (revision 0259f92711599c45d229fb12f6f51915fffac6bd)
1 //===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/SetOperations.h"
18 
19 //===----------------------------------------------------------------------===//
20 // BufferDeallocationOpInterface
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace bufferization {
25 
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27 
28 } // namespace bufferization
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace bufferization;
33 
34 //===----------------------------------------------------------------------===//
35 // Helpers
36 //===----------------------------------------------------------------------===//
37 
38 static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
39   return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
40 }
41 
42 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
43 
44 //===----------------------------------------------------------------------===//
45 // Ownership
46 //===----------------------------------------------------------------------===//
47 
48 Ownership::Ownership(Value indicator)
49     : indicator(indicator), state(State::Unique) {}
50 
51 Ownership Ownership::getUnknown() {
52   Ownership unknown;
53   unknown.indicator = Value();
54   unknown.state = State::Unknown;
55   return unknown;
56 }
57 Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
58 Ownership Ownership::getUninitialized() { return Ownership(); }
59 
60 bool Ownership::isUninitialized() const {
61   return state == State::Uninitialized;
62 }
63 bool Ownership::isUnique() const { return state == State::Unique; }
64 bool Ownership::isUnknown() const { return state == State::Unknown; }
65 
66 Value Ownership::getIndicator() const {
67   assert(isUnique() && "must have unique ownership to get the indicator");
68   return indicator;
69 }
70 
71 Ownership Ownership::getCombined(Ownership other) const {
72   if (other.isUninitialized())
73     return *this;
74   if (isUninitialized())
75     return other;
76 
77   if (!isUnique() || !other.isUnique())
78     return getUnknown();
79 
80   // Since we create a new constant i1 value for (almost) each use-site, we
81   // should compare the actual value rather than just the SSA Value to avoid
82   // unnecessary invalidations.
83   if (isEqualConstantIntOrValue(indicator, other.indicator))
84     return *this;
85 
86   // Return the join of the lattice if the indicator of both ownerships cannot
87   // be merged.
88   return getUnknown();
89 }
90 
91 void Ownership::combine(Ownership other) { *this = getCombined(other); }
92 
93 //===----------------------------------------------------------------------===//
94 // DeallocationState
95 //===----------------------------------------------------------------------===//
96 
97 DeallocationState::DeallocationState(Operation *op) : liveness(op) {}
98 
99 void DeallocationState::updateOwnership(Value memref, Ownership ownership,
100                                         Block *block) {
101   // In most cases we care about the block where the value is defined.
102   if (block == nullptr)
103     block = memref.getParentBlock();
104 
105   // Update ownership of current memref itself.
106   ownershipMap[{memref, block}].combine(ownership);
107 }
108 
109 void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) {
110   for (Value val : memrefs)
111     ownershipMap[{val, block}] = Ownership::getUninitialized();
112 }
113 
114 Ownership DeallocationState::getOwnership(Value memref, Block *block) const {
115   return ownershipMap.lookup({memref, block});
116 }
117 
118 void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) {
119   memrefsToDeallocatePerBlock[block].push_back(memref);
120 }
121 
122 void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) {
123   llvm::erase(memrefsToDeallocatePerBlock[block], memref);
124 }
125 
126 void DeallocationState::getLiveMemrefsIn(Block *block,
127                                          SmallVectorImpl<Value> &memrefs) {
128   SmallVector<Value> liveMemrefs(
129       llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
130   llvm::sort(liveMemrefs, ValueComparator());
131   memrefs.append(liveMemrefs);
132 }
133 
134 std::pair<Value, Value>
135 DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
136                                                 Value memref, Block *block) {
137   auto iter = ownershipMap.find({memref, block});
138   assert(iter != ownershipMap.end() &&
139          "Value must already have been registered in the ownership map");
140 
141   Ownership ownership = iter->second;
142   if (ownership.isUnique())
143     return {memref, ownership.getIndicator()};
144 
145   // Instead of inserting a clone operation we could also insert a dealloc
146   // operation earlier in the block and use the updated ownerships returned by
147   // the op for the retained values. Alternatively, we could insert code to
148   // check aliasing at runtime and use this information to combine two unique
149   // ownerships more intelligently to not end up with an 'Unknown' ownership in
150   // the first place.
151   auto cloneOp =
152       builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
153   Value condition = buildBoolValue(builder, memref.getLoc(), true);
154   Value newMemref = cloneOp.getResult();
155   updateOwnership(newMemref, condition);
156   memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
157   return {newMemref, condition};
158 }
159 
160 void DeallocationState::getMemrefsToRetain(
161     Block *fromBlock, Block *toBlock, ValueRange destOperands,
162     SmallVectorImpl<Value> &toRetain) const {
163   for (Value operand : destOperands) {
164     if (!isMemref(operand))
165       continue;
166     toRetain.push_back(operand);
167   }
168 
169   SmallPtrSet<Value, 16> liveOut;
170   for (auto val : liveness.getLiveOut(fromBlock))
171     if (isMemref(val))
172       liveOut.insert(val);
173 
174   if (toBlock)
175     llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
176 
177   // liveOut has non-deterministic order because it was constructed by iterating
178   // over a hash-set.
179   SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
180   std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
181             ValueComparator());
182   toRetain.append(retainedByLiveness);
183 }
184 
185 LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
186     OpBuilder &builder, Location loc, Block *block,
187     SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
188 
189   for (auto [i, memref] :
190        llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
191     Ownership ownership = ownershipMap.lookup({memref, block});
192     if (!ownership.isUnique())
193       return emitError(memref.getLoc(),
194                        "MemRef value does not have valid ownership");
195 
196     // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
197     // that we can call extract_strided_metadata on it.
198     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199       memref = builder.create<memref::ReinterpretCastOp>(
200           loc, memref,
201           /*offset=*/builder.getIndexAttr(0),
202           /*sizes=*/ArrayRef<OpFoldResult>{},
203           /*strides=*/ArrayRef<OpFoldResult>{});
204 
205     // Use the `memref.extract_strided_metadata` operation to get the base
206     // memref. This is needed because the same MemRef that was produced by the
207     // alloc operation has to be passed to the dealloc operation. Passing
208     // subviews, etc. to a dealloc operation is not allowed.
209     memrefs.push_back(
210         builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
211             .getResult(0));
212     conditions.push_back(ownership.getIndicator());
213   }
214 
215   return success();
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // ValueComparator
220 //===----------------------------------------------------------------------===//
221 
222 bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
223   if (lhs == rhs)
224     return false;
225 
226   // Block arguments are less than results.
227   bool lhsIsBBArg = isa<BlockArgument>(lhs);
228   if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
229     return lhsIsBBArg;
230   }
231 
232   Region *lhsRegion;
233   Region *rhsRegion;
234   if (lhsIsBBArg) {
235     auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
236     auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
237     if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
238       return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
239     }
240     lhsRegion = lhsBBArg.getParentRegion();
241     rhsRegion = rhsBBArg.getParentRegion();
242     assert(lhsRegion != rhsRegion &&
243            "lhsRegion == rhsRegion implies lhs == rhs");
244   } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
245     return llvm::cast<OpResult>(lhs).getResultNumber() <
246            llvm::cast<OpResult>(rhs).getResultNumber();
247   } else {
248     lhsRegion = lhs.getDefiningOp()->getParentRegion();
249     rhsRegion = rhs.getDefiningOp()->getParentRegion();
250     if (lhsRegion == rhsRegion) {
251       return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
252     }
253   }
254 
255   // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
256   // - have different heights
257   // - or there's a spot where their region numbers differ
258   // - or their parent regions are the same and their parent ops are
259   //   different.
260   while (lhsRegion && rhsRegion) {
261     if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
262       return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
263     }
264     if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
265       return lhsRegion->getParentOp()->isBeforeInBlock(
266           rhsRegion->getParentOp());
267     }
268     lhsRegion = lhsRegion->getParentRegion();
269     rhsRegion = rhsRegion->getParentRegion();
270   }
271   if (rhsRegion)
272     return true;
273   assert(lhsRegion && "this should only happen if lhs == rhs");
274   return false;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Implementation utilities
279 //===----------------------------------------------------------------------===//
280 
281 FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike(
282     DeallocationState &state, Operation *op, ValueRange operands,
283     SmallVectorImpl<Value> &updatedOperandOwnerships) {
284   assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
285   assert(!op->hasSuccessors() && "must not have any successors");
286   // Collect the values to deallocate and retain and use them to create the
287   // dealloc operation.
288   OpBuilder builder(op);
289   Block *block = op->getBlock();
290   SmallVector<Value> memrefs, conditions, toRetain;
291   if (failed(state.getMemrefsAndConditionsToDeallocate(
292           builder, op->getLoc(), block, memrefs, conditions)))
293     return failure();
294 
295   state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
296   if (memrefs.empty() && toRetain.empty())
297     return op;
298 
299   auto deallocOp = builder.create<bufferization::DeallocOp>(
300       op->getLoc(), memrefs, conditions, toRetain);
301 
302   // We want to replace the current ownership of the retained values with the
303   // result values of the dealloc operation as they are always unique.
304   state.resetOwnerships(deallocOp.getRetained(), block);
305   for (auto [retained, ownership] :
306        llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
307     state.updateOwnership(retained, ownership, block);
308 
309   unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
310   auto newOperandOwnerships =
311       deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
312   updatedOperandOwnerships.append(newOperandOwnerships.begin(),
313                                   newOperandOwnerships.end());
314 
315   return op;
316 }
317