1 //===- BufferDeallocationSimplification.cpp -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for optimizing `bufferization.dealloc` operations 10 // that requires more analysis than what can be supported by regular 11 // canonicalization patterns. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 17 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 namespace mlir { 24 namespace bufferization { 25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION 26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 27 } // namespace bufferization 28 } // namespace mlir 29 30 using namespace mlir; 31 using namespace mlir::bufferization; 32 33 //===----------------------------------------------------------------------===// 34 // Helpers 35 //===----------------------------------------------------------------------===// 36 37 /// Given a memref value, return the "base" value by skipping over all 38 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. 39 static Value getViewBase(Value value) { 40 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) 41 value = viewLikeOp.getViewSource(); 42 return value; 43 } 44 45 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, 46 ValueRange memrefs, 47 ValueRange conditions, 48 PatternRewriter &rewriter) { 49 if (deallocOp.getMemrefs() == memrefs && 50 deallocOp.getConditions() == conditions) 51 return failure(); 52 53 rewriter.modifyOpInPlace(deallocOp, [&]() { 54 deallocOp.getMemrefsMutable().assign(memrefs); 55 deallocOp.getConditionsMutable().assign(conditions); 56 }); 57 return success(); 58 } 59 60 /// Return "true" if the given values are guaranteed to be different (and 61 /// non-aliasing) allocations based on the fact that one value is the result 62 /// of an allocation and the other value is a block argument of a parent block. 63 /// Note: This is a best-effort analysis that will eventually be replaced by a 64 /// proper "is same allocation" analysis. This function may return "false" even 65 /// though the two values are distinct allocations. 66 static bool distinctAllocAndBlockArgument(Value v1, Value v2) { 67 Value v1Base = getViewBase(v1); 68 Value v2Base = getViewBase(v2); 69 auto areDistinct = [](Value v1, Value v2) { 70 if (Operation *op = v1.getDefiningOp()) 71 if (hasEffect<MemoryEffects::Allocate>(op, v1)) 72 if (auto bbArg = dyn_cast<BlockArgument>(v2)) 73 if (bbArg.getOwner()->findAncestorOpInBlock(*op)) 74 return true; 75 return false; 76 }; 77 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base); 78 } 79 80 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is 81 /// often a requirement of optimization patterns that there cannot be any 82 /// aliasing memref in order to perform the desired simplification. 83 static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, 84 ValueRange otherList, Value memref) { 85 for (auto other : otherList) { 86 if (distinctAllocAndBlockArgument(other, memref)) 87 continue; 88 std::optional<bool> analysisResult = 89 analysis.isSameAllocation(other, memref); 90 if (!analysisResult.has_value() || analysisResult == true) 91 return true; 92 } 93 return false; 94 } 95 96 //===----------------------------------------------------------------------===// 97 // Patterns 98 //===----------------------------------------------------------------------===// 99 100 namespace { 101 102 /// Remove values from the `memref` operand list that are also present in the 103 /// `retained` list (or a guaranteed alias of it) because they will never 104 /// actually be deallocated. However, we also need to be certain about which 105 /// other memrefs in the `retained` list can alias, i.e., there must not by any 106 /// may-aliasing memref. This is necessary because the `dealloc` operation is 107 /// defined to return one `i1` value per memref in the `retained` list which 108 /// represents the disjunction of the condition values corresponding to all 109 /// aliasing values in the `memref` list. In particular, this means that if 110 /// there is some value R in the `retained` list which aliases with a value M in 111 /// the `memref` list (but can only be staticaly determined to may-alias) and M 112 /// is also present in the `retained` list, then it would be illegal to remove M 113 /// because the result corresponding to R would be computed incorrectly 114 /// afterwards. Because we require an alias analysis, this pattern cannot be 115 /// applied as a regular canonicalization pattern. 116 /// 117 /// Example: 118 /// ```mlir 119 /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0) 120 /// retain (%m0, %r0, %r1 : ...) 121 /// ``` 122 /// is canonicalized to 123 /// ```mlir 124 /// // bufferization.dealloc without memrefs and conditions returns %false for 125 /// // every retained value 126 /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...) 127 /// %1 = arith.ori %0#0, %cond0 : i1 128 /// // replace %0#0 with %1 129 /// ``` 130 /// given that `%r0` and `%r1` may not alias with `%m0`. 131 struct RemoveDeallocMemrefsContainedInRetained 132 : public OpRewritePattern<DeallocOp> { 133 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context, 134 BufferOriginAnalysis &analysis) 135 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} 136 137 /// The passed 'memref' must not have a may-alias relation to any retained 138 /// memref, and at least one must-alias relation. If there is no must-aliasing 139 /// memref in the retain list, we cannot simply remove the memref as there 140 /// could be situations in which it actually has to be deallocated. If it's 141 /// no-alias, then just proceed, if it's must-alias we need to update the 142 /// updated condition returned by the dealloc operation for that alias. 143 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond, 144 PatternRewriter &rewriter) const { 145 rewriter.setInsertionPointAfter(deallocOp); 146 147 // Check that there is no may-aliasing memref and that at least one memref 148 // in the retain list aliases (because otherwise it might have to be 149 // deallocated in some situations and can thus not be dropped). 150 bool atLeastOneMustAlias = false; 151 for (Value retained : deallocOp.getRetained()) { 152 std::optional<bool> analysisResult = 153 analysis.isSameAllocation(retained, memref); 154 if (!analysisResult.has_value()) 155 return failure(); 156 if (analysisResult == true) 157 atLeastOneMustAlias = true; 158 } 159 if (!atLeastOneMustAlias) 160 return failure(); 161 162 // Insert arith.ori operations to update the corresponding dealloc result 163 // values to incorporate the condition of the must-aliasing memref such that 164 // we can remove that operand later on. 165 for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) { 166 Value updatedCondition = deallocOp.getUpdatedConditions()[i]; 167 std::optional<bool> analysisResult = 168 analysis.isSameAllocation(retained, memref); 169 if (analysisResult == true) { 170 auto disjunction = rewriter.create<arith::OrIOp>( 171 deallocOp.getLoc(), updatedCondition, cond); 172 rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), 173 disjunction); 174 } 175 } 176 177 return success(); 178 } 179 180 LogicalResult matchAndRewrite(DeallocOp deallocOp, 181 PatternRewriter &rewriter) const override { 182 // There must not be any duplicates in the retain list anymore because we 183 // would miss updating one of the result values otherwise. 184 DenseSet<Value> retained(deallocOp.getRetained().begin(), 185 deallocOp.getRetained().end()); 186 if (retained.size() != deallocOp.getRetained().size()) 187 return failure(); 188 189 SmallVector<Value> newMemrefs, newConditions; 190 for (auto [memref, cond] : 191 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 192 193 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter))) 194 continue; 195 196 if (auto extractOp = 197 memref.getDefiningOp<memref::ExtractStridedMetadataOp>()) 198 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond, 199 rewriter))) 200 continue; 201 202 newMemrefs.push_back(memref); 203 newConditions.push_back(cond); 204 } 205 206 // Return failure if we don't change anything such that we don't run into an 207 // infinite loop of pattern applications. 208 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 209 rewriter); 210 } 211 212 private: 213 BufferOriginAnalysis &analysis; 214 }; 215 216 /// Remove memrefs from the `retained` list which are guaranteed to not alias 217 /// any memref in the `memrefs` list. The corresponding result value can be 218 /// replaced with `false` in that case according to the operation description. 219 /// 220 /// Example: 221 /// ```mlir 222 /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond) 223 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) 224 /// return %0#0, %0#1 225 /// ``` 226 /// can be canonicalized to the following given that `%r0` and `%r1` do not 227 /// alias `%m`: 228 /// ```mlir 229 /// bufferization.dealloc (%m : memref<2xi32>) if (%cond) 230 /// return %false, %false 231 /// ``` 232 struct RemoveRetainedMemrefsGuaranteedToNotAlias 233 : public OpRewritePattern<DeallocOp> { 234 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context, 235 BufferOriginAnalysis &analysis) 236 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} 237 238 LogicalResult matchAndRewrite(DeallocOp deallocOp, 239 PatternRewriter &rewriter) const override { 240 SmallVector<Value> newRetainedMemrefs, replacements; 241 242 for (auto retainedMemref : deallocOp.getRetained()) { 243 if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(), 244 retainedMemref)) { 245 newRetainedMemrefs.push_back(retainedMemref); 246 replacements.push_back({}); 247 continue; 248 } 249 250 replacements.push_back(rewriter.create<arith::ConstantOp>( 251 deallocOp.getLoc(), rewriter.getBoolAttr(false))); 252 } 253 254 if (newRetainedMemrefs.size() == deallocOp.getRetained().size()) 255 return failure(); 256 257 auto newDeallocOp = rewriter.create<DeallocOp>( 258 deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(), 259 newRetainedMemrefs); 260 int i = 0; 261 for (auto &repl : replacements) { 262 if (!repl) 263 repl = newDeallocOp.getUpdatedConditions()[i++]; 264 } 265 266 rewriter.replaceOp(deallocOp, replacements); 267 return success(); 268 } 269 270 private: 271 BufferOriginAnalysis &analysis; 272 }; 273 274 /// Split off memrefs to separate dealloc operations to reduce the number of 275 /// runtime checks required and enable further canonicalization of the new and 276 /// simpler dealloc operations. A memref can be split off if it is guaranteed to 277 /// not alias with any other memref in the `memref` operand list. The results 278 /// of the old and the new dealloc operation have to be combined by computing 279 /// the element-wise disjunction of them. 280 /// 281 /// Example: 282 /// ```mlir 283 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>) 284 /// if (%cond0, %cond1) 285 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) 286 /// return %0#0, %0#1 287 /// ``` 288 /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is 289 /// canonicalized to the following, thus reducing the number of runtime alias 290 /// checks by 1 and potentially enabling further canonicalization of the new 291 /// split-up dealloc operations. 292 /// ```mlir 293 /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0) 294 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) 295 /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1) 296 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) 297 /// %2 = arith.ori %0#0, %1#0 298 /// %3 = arith.ori %0#1, %1#1 299 /// return %2, %3 300 /// ``` 301 struct SplitDeallocWhenNotAliasingAnyOther 302 : public OpRewritePattern<DeallocOp> { 303 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context, 304 BufferOriginAnalysis &analysis) 305 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} 306 307 LogicalResult matchAndRewrite(DeallocOp deallocOp, 308 PatternRewriter &rewriter) const override { 309 Location loc = deallocOp.getLoc(); 310 if (deallocOp.getMemrefs().size() <= 1) 311 return failure(); 312 313 SmallVector<Value> remainingMemrefs, remainingConditions; 314 SmallVector<SmallVector<Value>> updatedConditions; 315 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) { 316 Value memref = deallocOp.getMemrefs()[i]; 317 Value cond = deallocOp.getConditions()[i]; 318 SmallVector<Value> otherMemrefs(deallocOp.getMemrefs()); 319 otherMemrefs.erase(otherMemrefs.begin() + i); 320 // Check if `memref` can split off into a separate bufferization.dealloc. 321 if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) { 322 // `memref` alias with other memrefs, do not split off. 323 remainingMemrefs.push_back(memref); 324 remainingConditions.push_back(cond); 325 continue; 326 } 327 328 // Create new bufferization.dealloc op for `memref`. 329 auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond, 330 deallocOp.getRetained()); 331 updatedConditions.push_back( 332 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()))); 333 } 334 335 // Fail if no memref was split off. 336 if (remainingMemrefs.size() == deallocOp.getMemrefs().size()) 337 return failure(); 338 339 // Create bufferization.dealloc op for all remaining memrefs. 340 auto newDeallocOp = rewriter.create<DeallocOp>( 341 loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); 342 343 // Bit-or all conditions. 344 SmallVector<Value> replacements = 345 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())); 346 for (auto additionalConditions : updatedConditions) { 347 assert(replacements.size() == additionalConditions.size() && 348 "expected same number of updated conditions"); 349 for (int64_t i = 0, e = replacements.size(); i < e; ++i) { 350 replacements[i] = rewriter.create<arith::OrIOp>( 351 loc, replacements[i], additionalConditions[i]); 352 } 353 } 354 rewriter.replaceOp(deallocOp, replacements); 355 return success(); 356 } 357 358 private: 359 BufferOriginAnalysis &analysis; 360 }; 361 362 /// Check for every retained memref if a must-aliasing memref exists in the 363 /// 'memref' operand list with constant 'true' condition. If so, we can replace 364 /// the operation result corresponding to that retained memref with 'true'. If 365 /// this condition holds for all retained memrefs we can also remove the 366 /// aliasing memrefs and their conditions since they will never be deallocated 367 /// due to the must-alias and we don't need them to compute the result value 368 /// anymore since it got replaced with 'true'. 369 /// 370 /// Example: 371 /// ```mlir 372 /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...) 373 /// if (%true, %true, %true) 374 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) 375 /// ``` 376 /// becomes 377 /// ```mlir 378 /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true) 379 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) 380 /// // replace %0#0 with %true 381 /// // replace %0#1 with %true 382 /// ``` 383 /// Note that the dealloc operation will still have the result values, but they 384 /// don't have uses anymore. 385 struct RetainedMemrefAliasingAlwaysDeallocatedMemref 386 : public OpRewritePattern<DeallocOp> { 387 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context, 388 BufferOriginAnalysis &analysis) 389 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} 390 391 LogicalResult matchAndRewrite(DeallocOp deallocOp, 392 PatternRewriter &rewriter) const override { 393 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size()); 394 SmallVector<Value> newMemrefs, newConditions; 395 for (auto [memref, cond] : 396 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 397 bool canDropMemref = false; 398 for (auto [i, retained, res] : llvm::enumerate( 399 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { 400 if (!matchPattern(cond, m_One())) 401 continue; 402 403 std::optional<bool> analysisResult = 404 analysis.isSameAllocation(retained, memref); 405 if (analysisResult == true) { 406 rewriter.replaceAllUsesWith(res, cond); 407 aliasesWithConstTrueMemref[i] = true; 408 canDropMemref = true; 409 continue; 410 } 411 412 // TODO: once our alias analysis is powerful enough we can remove the 413 // rest of this loop body 414 auto extractOp = 415 memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); 416 if (!extractOp) 417 continue; 418 419 std::optional<bool> extractAnalysisResult = 420 analysis.isSameAllocation(retained, extractOp.getOperand()); 421 if (extractAnalysisResult == true) { 422 rewriter.replaceAllUsesWith(res, cond); 423 aliasesWithConstTrueMemref[i] = true; 424 canDropMemref = true; 425 } 426 } 427 428 if (!canDropMemref) { 429 newMemrefs.push_back(memref); 430 newConditions.push_back(cond); 431 } 432 } 433 if (!aliasesWithConstTrueMemref.all()) 434 return failure(); 435 436 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 437 rewriter); 438 } 439 440 private: 441 BufferOriginAnalysis &analysis; 442 }; 443 444 } // namespace 445 446 //===----------------------------------------------------------------------===// 447 // BufferDeallocationSimplificationPass 448 //===----------------------------------------------------------------------===// 449 450 namespace { 451 452 /// The actual buffer deallocation pass that inserts and moves dealloc nodes 453 /// into the right positions. Furthermore, it inserts additional clones if 454 /// necessary. It uses the algorithm described at the top of the file. 455 struct BufferDeallocationSimplificationPass 456 : public bufferization::impl::BufferDeallocationSimplificationBase< 457 BufferDeallocationSimplificationPass> { 458 void runOnOperation() override { 459 BufferOriginAnalysis analysis(getOperation()); 460 RewritePatternSet patterns(&getContext()); 461 patterns.add<RemoveDeallocMemrefsContainedInRetained, 462 RemoveRetainedMemrefsGuaranteedToNotAlias, 463 SplitDeallocWhenNotAliasingAnyOther, 464 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(), 465 analysis); 466 // We don't want that the block structure changes invalidating the 467 // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of 468 // region simplification 469 GreedyRewriteConfig config; 470 config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal; 471 populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); 472 473 if (failed( 474 applyPatternsGreedily(getOperation(), std::move(patterns), config))) 475 signalPassFailure(); 476 } 477 }; 478 479 } // namespace 480 481 std::unique_ptr<Pass> 482 mlir::bufferization::createBufferDeallocationSimplificationPass() { 483 return std::make_unique<BufferDeallocationSimplificationPass>(); 484 } 485