1 //===- BufferDeallocationOpInterface.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/IR/AsmState.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "mlir/IR/Value.h" 17 #include "llvm/ADT/SetOperations.h" 18 19 //===----------------------------------------------------------------------===// 20 // BufferDeallocationOpInterface 21 //===----------------------------------------------------------------------===// 22 23 namespace mlir { 24 namespace bufferization { 25 26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc" 27 28 } // namespace bufferization 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace bufferization; 33 34 //===----------------------------------------------------------------------===// 35 // Helpers 36 //===----------------------------------------------------------------------===// 37 38 static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { 39 return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); 40 } 41 42 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } 43 44 //===----------------------------------------------------------------------===// 45 // Ownership 46 //===----------------------------------------------------------------------===// 47 48 Ownership::Ownership(Value indicator) 49 : indicator(indicator), state(State::Unique) {} 50 51 Ownership Ownership::getUnknown() { 52 Ownership unknown; 53 unknown.indicator = Value(); 54 unknown.state = State::Unknown; 55 return unknown; 56 } 57 Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); } 58 Ownership Ownership::getUninitialized() { return Ownership(); } 59 60 bool Ownership::isUninitialized() const { 61 return state == State::Uninitialized; 62 } 63 bool Ownership::isUnique() const { return state == State::Unique; } 64 bool Ownership::isUnknown() const { return state == State::Unknown; } 65 66 Value Ownership::getIndicator() const { 67 assert(isUnique() && "must have unique ownership to get the indicator"); 68 return indicator; 69 } 70 71 Ownership Ownership::getCombined(Ownership other) const { 72 if (other.isUninitialized()) 73 return *this; 74 if (isUninitialized()) 75 return other; 76 77 if (!isUnique() || !other.isUnique()) 78 return getUnknown(); 79 80 // Since we create a new constant i1 value for (almost) each use-site, we 81 // should compare the actual value rather than just the SSA Value to avoid 82 // unnecessary invalidations. 83 if (isEqualConstantIntOrValue(indicator, other.indicator)) 84 return *this; 85 86 // Return the join of the lattice if the indicator of both ownerships cannot 87 // be merged. 88 return getUnknown(); 89 } 90 91 void Ownership::combine(Ownership other) { *this = getCombined(other); } 92 93 //===----------------------------------------------------------------------===// 94 // DeallocationState 95 //===----------------------------------------------------------------------===// 96 97 DeallocationState::DeallocationState(Operation *op) : liveness(op) {} 98 99 void DeallocationState::updateOwnership(Value memref, Ownership ownership, 100 Block *block) { 101 // In most cases we care about the block where the value is defined. 102 if (block == nullptr) 103 block = memref.getParentBlock(); 104 105 // Update ownership of current memref itself. 106 ownershipMap[{memref, block}].combine(ownership); 107 } 108 109 void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) { 110 for (Value val : memrefs) 111 ownershipMap[{val, block}] = Ownership::getUninitialized(); 112 } 113 114 Ownership DeallocationState::getOwnership(Value memref, Block *block) const { 115 return ownershipMap.lookup({memref, block}); 116 } 117 118 void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) { 119 memrefsToDeallocatePerBlock[block].push_back(memref); 120 } 121 122 void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) { 123 llvm::erase(memrefsToDeallocatePerBlock[block], memref); 124 } 125 126 void DeallocationState::getLiveMemrefsIn(Block *block, 127 SmallVectorImpl<Value> &memrefs) { 128 SmallVector<Value> liveMemrefs( 129 llvm::make_filter_range(liveness.getLiveIn(block), isMemref)); 130 llvm::sort(liveMemrefs, ValueComparator()); 131 memrefs.append(liveMemrefs); 132 } 133 134 std::pair<Value, Value> 135 DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, 136 Value memref, Block *block) { 137 auto iter = ownershipMap.find({memref, block}); 138 assert(iter != ownershipMap.end() && 139 "Value must already have been registered in the ownership map"); 140 141 Ownership ownership = iter->second; 142 if (ownership.isUnique()) 143 return {memref, ownership.getIndicator()}; 144 145 // Instead of inserting a clone operation we could also insert a dealloc 146 // operation earlier in the block and use the updated ownerships returned by 147 // the op for the retained values. Alternatively, we could insert code to 148 // check aliasing at runtime and use this information to combine two unique 149 // ownerships more intelligently to not end up with an 'Unknown' ownership in 150 // the first place. 151 auto cloneOp = 152 builder.create<bufferization::CloneOp>(memref.getLoc(), memref); 153 Value condition = buildBoolValue(builder, memref.getLoc(), true); 154 Value newMemref = cloneOp.getResult(); 155 updateOwnership(newMemref, condition); 156 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref); 157 return {newMemref, condition}; 158 } 159 160 void DeallocationState::getMemrefsToRetain( 161 Block *fromBlock, Block *toBlock, ValueRange destOperands, 162 SmallVectorImpl<Value> &toRetain) const { 163 for (Value operand : destOperands) { 164 if (!isMemref(operand)) 165 continue; 166 toRetain.push_back(operand); 167 } 168 169 SmallPtrSet<Value, 16> liveOut; 170 for (auto val : liveness.getLiveOut(fromBlock)) 171 if (isMemref(val)) 172 liveOut.insert(val); 173 174 if (toBlock) 175 llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock)); 176 177 // liveOut has non-deterministic order because it was constructed by iterating 178 // over a hash-set. 179 SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end()); 180 std::sort(retainedByLiveness.begin(), retainedByLiveness.end(), 181 ValueComparator()); 182 toRetain.append(retainedByLiveness); 183 } 184 185 LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( 186 OpBuilder &builder, Location loc, Block *block, 187 SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const { 188 189 for (auto [i, memref] : 190 llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) { 191 Ownership ownership = ownershipMap.lookup({memref, block}); 192 if (!ownership.isUnique()) 193 return emitError(memref.getLoc(), 194 "MemRef value does not have valid ownership"); 195 196 // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such 197 // that we can call extract_strided_metadata on it. 198 if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType())) 199 memref = builder.create<memref::ReinterpretCastOp>( 200 loc, memref, 201 /*offset=*/builder.getIndexAttr(0), 202 /*sizes=*/ArrayRef<OpFoldResult>{}, 203 /*strides=*/ArrayRef<OpFoldResult>{}); 204 205 // Use the `memref.extract_strided_metadata` operation to get the base 206 // memref. This is needed because the same MemRef that was produced by the 207 // alloc operation has to be passed to the dealloc operation. Passing 208 // subviews, etc. to a dealloc operation is not allowed. 209 memrefs.push_back( 210 builder.create<memref::ExtractStridedMetadataOp>(loc, memref) 211 .getResult(0)); 212 conditions.push_back(ownership.getIndicator()); 213 } 214 215 return success(); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // ValueComparator 220 //===----------------------------------------------------------------------===// 221 222 bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { 223 if (lhs == rhs) 224 return false; 225 226 // Block arguments are less than results. 227 bool lhsIsBBArg = isa<BlockArgument>(lhs); 228 if (lhsIsBBArg != isa<BlockArgument>(rhs)) { 229 return lhsIsBBArg; 230 } 231 232 Region *lhsRegion; 233 Region *rhsRegion; 234 if (lhsIsBBArg) { 235 auto lhsBBArg = llvm::cast<BlockArgument>(lhs); 236 auto rhsBBArg = llvm::cast<BlockArgument>(rhs); 237 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { 238 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); 239 } 240 lhsRegion = lhsBBArg.getParentRegion(); 241 rhsRegion = rhsBBArg.getParentRegion(); 242 assert(lhsRegion != rhsRegion && 243 "lhsRegion == rhsRegion implies lhs == rhs"); 244 } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { 245 return llvm::cast<OpResult>(lhs).getResultNumber() < 246 llvm::cast<OpResult>(rhs).getResultNumber(); 247 } else { 248 lhsRegion = lhs.getDefiningOp()->getParentRegion(); 249 rhsRegion = rhs.getDefiningOp()->getParentRegion(); 250 if (lhsRegion == rhsRegion) { 251 return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp()); 252 } 253 } 254 255 // lhsRegion != rhsRegion, so if we look at their ancestor chain, they 256 // - have different heights 257 // - or there's a spot where their region numbers differ 258 // - or their parent regions are the same and their parent ops are 259 // different. 260 while (lhsRegion && rhsRegion) { 261 if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { 262 return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); 263 } 264 if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { 265 return lhsRegion->getParentOp()->isBeforeInBlock( 266 rhsRegion->getParentOp()); 267 } 268 lhsRegion = lhsRegion->getParentRegion(); 269 rhsRegion = rhsRegion->getParentRegion(); 270 } 271 if (rhsRegion) 272 return true; 273 assert(lhsRegion && "this should only happen if lhs == rhs"); 274 return false; 275 } 276 277 //===----------------------------------------------------------------------===// 278 // Implementation utilities 279 //===----------------------------------------------------------------------===// 280 281 FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike( 282 DeallocationState &state, Operation *op, ValueRange operands, 283 SmallVectorImpl<Value> &updatedOperandOwnerships) { 284 assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator"); 285 assert(!op->hasSuccessors() && "must not have any successors"); 286 // Collect the values to deallocate and retain and use them to create the 287 // dealloc operation. 288 OpBuilder builder(op); 289 Block *block = op->getBlock(); 290 SmallVector<Value> memrefs, conditions, toRetain; 291 if (failed(state.getMemrefsAndConditionsToDeallocate( 292 builder, op->getLoc(), block, memrefs, conditions))) 293 return failure(); 294 295 state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain); 296 if (memrefs.empty() && toRetain.empty()) 297 return op; 298 299 auto deallocOp = builder.create<bufferization::DeallocOp>( 300 op->getLoc(), memrefs, conditions, toRetain); 301 302 // We want to replace the current ownership of the retained values with the 303 // result values of the dealloc operation as they are always unique. 304 state.resetOwnerships(deallocOp.getRetained(), block); 305 for (auto [retained, ownership] : 306 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) 307 state.updateOwnership(retained, ownership, block); 308 309 unsigned numMemrefOperands = llvm::count_if(operands, isMemref); 310 auto newOperandOwnerships = 311 deallocOp.getUpdatedConditions().take_front(numMemrefOperands); 312 updatedOperandOwnerships.append(newOperandOwnerships.begin(), 313 newOperandOwnerships.end()); 314 315 return op; 316 } 317