//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert `bufferization.dealloc` operations
// to the MemRef dialect (using SCF and Arith operations to compute and apply
// the runtime guard conditions), the pass that applies these patterns, and
// the builder for the generic deallocation helper function that the
// general-case lowering calls.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace bufferization {
// Request the tablegen-generated definition of this pass's base class (used
// below as `bufferization::impl::LowerDeallocationsBase`).
#define GEN_PASS_DEF_LOWERDEALLOCATIONS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;

namespace {
/// The DeallocOpConversion pattern transforms `bufferization.dealloc`
/// operations into `memref.dealloc` operations, potentially guarded by `scf.if`
/// operations. Additionally, `memref.extract_aligned_pointer_as_index` and
/// arith operations are inserted to compute the guard conditions at runtime.
/// Multiple cases are distinguished to provide an overall more efficient
/// lowering. In the general case, a call to a helper function is emitted to
/// avoid a quadratic code size explosion (relative to the number of operands of
/// the dealloc operation); that helper is not created by the pattern itself but
/// looked up, per enclosing symbol table, in the DeallocHelperMap passed to the
/// constructor. For examples of each case, refer to the documentation of the
/// member functions of this class.
42 class DeallocOpConversion 43 : public OpConversionPattern<bufferization::DeallocOp> { 44 45 /// Lower a simple case without any retained values and a single memref to 46 /// avoiding the helper function. Ideally, static analysis can provide enough 47 /// aliasing information to split the dealloc operations up into this simple 48 /// case as much as possible before running this pass. 49 /// 50 /// Example: 51 /// ``` 52 /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) 53 /// ``` 54 /// is lowered to 55 /// ``` 56 /// scf.if %arg1 { 57 /// memref.dealloc %arg0 : memref<2xf32> 58 /// } 59 /// ``` 60 LogicalResult 61 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor, 62 ConversionPatternRewriter &rewriter) const { 63 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref"); 64 assert(adaptor.getRetained().empty() && "expected no retained memrefs"); 65 66 rewriter.replaceOpWithNewOp<scf::IfOp>( 67 op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { 68 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); 69 builder.create<scf::YieldOp>(loc); 70 }); 71 return success(); 72 } 73 74 /// A special case lowering for the deallocation operation with exactly one 75 /// memref, but arbitrary number of retained values. This avoids the helper 76 /// function that the general case needs and thus also avoids storing indices 77 /// to specifically allocated memrefs. The size of the code produced by this 78 /// lowering is linear to the number of retained values. 
79 /// 80 /// Example: 81 /// ```mlir 82 /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond) 83 // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) 84 /// return %0#0, %0#1 : i1, i1 85 /// ``` 86 /// ```mlir 87 /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m 88 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 89 /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer 90 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 91 /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer 92 /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1 93 /// %should_dealloc = arith.andi %not_retained, %cond : i1 94 /// scf.if %should_dealloc { 95 /// memref.dealloc %m : memref<2xf32> 96 /// } 97 /// %true = arith.constant true 98 /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1 99 /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1 100 /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1 101 /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1 102 /// return %r0_ownership, %r1_ownership : i1, i1 103 /// ``` 104 LogicalResult rewriteOneMemrefMultipleRetainCase( 105 bufferization::DeallocOp op, OpAdaptor adaptor, 106 ConversionPatternRewriter &rewriter) const { 107 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref"); 108 109 // Compute the base pointer indices, compare all retained indices to the 110 // memref index to check if they alias. 
111 SmallVector<Value> doesNotAliasList; 112 Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>( 113 op->getLoc(), adaptor.getMemrefs()[0]); 114 for (Value retained : adaptor.getRetained()) { 115 Value retainedAsIdx = 116 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(), 117 retained); 118 Value doesNotAlias = rewriter.create<arith::CmpIOp>( 119 op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx); 120 doesNotAliasList.push_back(doesNotAlias); 121 } 122 123 // AND-reduce the list of booleans from above. 124 Value prev = doesNotAliasList.front(); 125 for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front()) 126 prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias); 127 128 // Also consider the condition given by the dealloc operation and perform a 129 // conditional deallocation guarded by that value. 130 Value shouldDealloc = rewriter.create<arith::AndIOp>( 131 op->getLoc(), prev, adaptor.getConditions()[0]); 132 133 rewriter.create<scf::IfOp>( 134 op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { 135 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); 136 builder.create<scf::YieldOp>(loc); 137 }); 138 139 // Compute the replacement values for the dealloc operation results. This 140 // inserts an already canonicalized form of 141 // `select(does_alias_with_memref(r), memref_cond, false)` for each retained 142 // value r. 
143 SmallVector<Value> replacements; 144 Value trueVal = rewriter.create<arith::ConstantOp>( 145 op->getLoc(), rewriter.getBoolAttr(true)); 146 for (Value doesNotAlias : doesNotAliasList) { 147 Value aliases = 148 rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal); 149 Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases, 150 adaptor.getConditions()[0]); 151 replacements.push_back(result); 152 } 153 154 rewriter.replaceOp(op, replacements); 155 156 return success(); 157 } 158 159 /// Lowering that supports all features the dealloc operation has to offer. It 160 /// computes the base pointer of each memref (as an index), stores it in a 161 /// new memref helper structure and passes it to the helper function generated 162 /// in 'buildDeallocationHelperFunction'. The results are stored in two lists 163 /// (represented as memrefs) of booleans passed as arguments. The first list 164 /// stores whether the corresponding condition should be deallocated, the 165 /// second list stores the ownership of the retained values which can be used 166 /// to replace the result values of the `bufferization.dealloc` operation. 
167 /// 168 /// Example: 169 /// ``` 170 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>) 171 /// if (%cond0, %cond1) 172 /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) 173 /// ``` 174 /// lowers to (simplified): 175 /// ``` 176 /// %c0 = arith.constant 0 : index 177 /// %c1 = arith.constant 1 : index 178 /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex> 179 /// %cond_list = memref.alloc() : memref<2xi1> 180 /// %retain_base_pointer_list = memref.alloc() : memref<2xindex> 181 /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0 182 /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0] 183 /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1 184 /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1] 185 /// memref.store %cond0, %cond_list[%c0] 186 /// memref.store %cond1, %cond_list[%c1] 187 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 188 /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0] 189 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 190 /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1] 191 /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list : 192 /// memref<2xindex> to memref<?xindex> 193 /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1> 194 /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list : 195 /// memref<2xindex> to memref<?xindex> 196 /// %dealloc_cond_out = memref.alloc() : memref<2xi1> 197 /// %ownership_out = memref.alloc() : memref<2xi1> 198 /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out : 199 /// memref<2xi1> to memref<?xi1> 200 /// %dyn_ownership_out = memref.cast %ownership_out : 201 /// memref<2xi1> to memref<?xi1> 202 /// call @dealloc_helper(%dyn_dealloc_base_pointer_list, 203 /// %dyn_retain_base_pointer_list, 204 /// %dyn_cond_list, 205 /// %dyn_dealloc_cond_out, 206 /// %dyn_ownership_out) 
: (...) 207 /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1> 208 /// scf.if %m0_dealloc_cond { 209 /// memref.dealloc %m0 : memref<2xf32> 210 /// } 211 /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1> 212 /// scf.if %m1_dealloc_cond { 213 /// memref.dealloc %m1 : memref<5xf32> 214 /// } 215 /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1> 216 /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1> 217 /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex> 218 /// memref.dealloc %retain_base_pointer_list : memref<2xindex> 219 /// memref.dealloc %cond_list : memref<2xi1> 220 /// memref.dealloc %dealloc_cond_out : memref<2xi1> 221 /// memref.dealloc %ownership_out : memref<2xi1> 222 /// // replace %0#0 with %r0_ownership 223 /// // replace %0#1 with %r1_ownership 224 /// ``` 225 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op, 226 OpAdaptor adaptor, 227 ConversionPatternRewriter &rewriter) const { 228 // Allocate two memrefs holding the base pointer indices of the list of 229 // memrefs to be deallocated and the ones to be retained. These can then be 230 // passed to the helper function and the for-loops can iterate over them. 231 // Without storing them to memrefs, we could not use for-loops but only a 232 // completely unrolled version of it, potentially leading to code-size 233 // blow-up. 
234 Value toDeallocMemref = rewriter.create<memref::AllocOp>( 235 op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, 236 rewriter.getIndexType())); 237 Value conditionMemref = rewriter.create<memref::AllocOp>( 238 op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, 239 rewriter.getI1Type())); 240 Value toRetainMemref = rewriter.create<memref::AllocOp>( 241 op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, 242 rewriter.getIndexType())); 243 244 auto getConstValue = [&](uint64_t value) -> Value { 245 return rewriter.create<arith::ConstantOp>(op.getLoc(), 246 rewriter.getIndexAttr(value)); 247 }; 248 249 // Extract the base pointers of the memrefs as indices to check for aliasing 250 // at runtime. 251 for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) { 252 Value memrefAsIdx = 253 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), 254 toDealloc); 255 rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, 256 toDeallocMemref, getConstValue(i)); 257 } 258 259 for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) 260 rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref, 261 getConstValue(i)); 262 263 for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { 264 Value memrefAsIdx = 265 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), 266 toRetain); 267 rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref, 268 getConstValue(i)); 269 } 270 271 // Cast the allocated memrefs to dynamic shape because we want only one 272 // helper function no matter how many operands the bufferization.dealloc 273 // has. 
274 Value castedDeallocMemref = rewriter.create<memref::CastOp>( 275 op->getLoc(), 276 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), 277 toDeallocMemref); 278 Value castedCondsMemref = rewriter.create<memref::CastOp>( 279 op->getLoc(), 280 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), 281 conditionMemref); 282 Value castedRetainMemref = rewriter.create<memref::CastOp>( 283 op->getLoc(), 284 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), 285 toRetainMemref); 286 287 Value deallocCondsMemref = rewriter.create<memref::AllocOp>( 288 op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, 289 rewriter.getI1Type())); 290 Value retainCondsMemref = rewriter.create<memref::AllocOp>( 291 op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, 292 rewriter.getI1Type())); 293 294 Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>( 295 op->getLoc(), 296 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), 297 deallocCondsMemref); 298 Value castedRetainCondsMemref = rewriter.create<memref::CastOp>( 299 op->getLoc(), 300 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), 301 retainCondsMemref); 302 303 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>(); 304 rewriter.create<func::CallOp>( 305 op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), 306 SmallVector<Value>{castedDeallocMemref, castedRetainMemref, 307 castedCondsMemref, castedDeallocCondsMemref, 308 castedRetainCondsMemref}); 309 310 for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { 311 Value idxValue = getConstValue(i); 312 Value shouldDealloc = rewriter.create<memref::LoadOp>( 313 op.getLoc(), deallocCondsMemref, idxValue); 314 rewriter.create<scf::IfOp>( 315 op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { 316 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]); 317 builder.create<scf::YieldOp>(loc); 318 }); 319 } 320 321 
SmallVector<Value> replacements; 322 for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { 323 Value idxValue = getConstValue(i); 324 Value ownership = rewriter.create<memref::LoadOp>( 325 op.getLoc(), retainCondsMemref, idxValue); 326 replacements.push_back(ownership); 327 } 328 329 // Deallocate above allocated memrefs again to avoid memory leaks. 330 // Deallocation will not be run on code after this stage. 331 rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref); 332 rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref); 333 rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref); 334 rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref); 335 rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref); 336 337 rewriter.replaceOp(op, replacements); 338 return success(); 339 } 340 341 public: 342 DeallocOpConversion( 343 MLIRContext *context, 344 const bufferization::DeallocHelperMap &deallocHelperFuncMap) 345 : OpConversionPattern<bufferization::DeallocOp>(context), 346 deallocHelperFuncMap(deallocHelperFuncMap) {} 347 348 LogicalResult 349 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor, 350 ConversionPatternRewriter &rewriter) const override { 351 // Lower the trivial case. 
352 if (adaptor.getMemrefs().empty()) { 353 Value falseVal = rewriter.create<arith::ConstantOp>( 354 op.getLoc(), rewriter.getBoolAttr(false)); 355 rewriter.replaceOp( 356 op, SmallVector<Value>(adaptor.getRetained().size(), falseVal)); 357 return success(); 358 } 359 360 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty()) 361 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter); 362 363 if (adaptor.getMemrefs().size() == 1) 364 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter); 365 366 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>(); 367 if (!deallocHelperFuncMap.contains(symtableOp)) 368 return op->emitError( 369 "library function required for generic lowering, but cannot be " 370 "automatically inserted when operating on functions"); 371 372 return rewriteGeneralCase(op, adaptor, rewriter); 373 } 374 375 private: 376 const bufferization::DeallocHelperMap &deallocHelperFuncMap; 377 }; 378 } // namespace 379 380 namespace { 381 struct LowerDeallocationsPass 382 : public bufferization::impl::LowerDeallocationsBase< 383 LowerDeallocationsPass> { 384 void runOnOperation() override { 385 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) { 386 emitError(getOperation()->getLoc(), 387 "root operation must be a builtin.module or a function"); 388 signalPassFailure(); 389 return; 390 } 391 392 bufferization::DeallocHelperMap deallocHelperFuncMap; 393 if (auto module = dyn_cast<ModuleOp>(getOperation())) { 394 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 395 396 // Build dealloc helper function if there are deallocs. 
397 getOperation()->walk([&](bufferization::DeallocOp deallocOp) { 398 Operation *symtableOp = 399 deallocOp->getParentWithTrait<OpTrait::SymbolTable>(); 400 if (deallocOp.getMemrefs().size() > 1 && 401 !deallocHelperFuncMap.contains(symtableOp)) { 402 SymbolTable symbolTable(symtableOp); 403 func::FuncOp helperFuncOp = 404 bufferization::buildDeallocationLibraryFunction( 405 builder, getOperation()->getLoc(), symbolTable); 406 deallocHelperFuncMap[symtableOp] = helperFuncOp; 407 } 408 }); 409 } 410 411 RewritePatternSet patterns(&getContext()); 412 bufferization::populateBufferizationDeallocLoweringPattern( 413 patterns, deallocHelperFuncMap); 414 415 ConversionTarget target(getContext()); 416 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect, 417 scf::SCFDialect, func::FuncDialect>(); 418 target.addIllegalOp<bufferization::DeallocOp>(); 419 420 if (failed(applyPartialConversion(getOperation(), target, 421 std::move(patterns)))) 422 signalPassFailure(); 423 } 424 }; 425 } // namespace 426 427 func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( 428 OpBuilder &builder, Location loc, SymbolTable &symbolTable) { 429 Type indexMemrefType = 430 MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); 431 Type boolMemrefType = 432 MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); 433 SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType, 434 boolMemrefType, boolMemrefType}; 435 builder.clearInsertionPoint(); 436 437 // Generate the func operation itself. 
438 auto helperFuncOp = func::FuncOp::create( 439 loc, "dealloc_helper", builder.getFunctionType(argTypes, {})); 440 helperFuncOp.setVisibility(SymbolTable::Visibility::Private); 441 symbolTable.insert(helperFuncOp); 442 auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); 443 block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc)); 444 445 builder.setInsertionPointToStart(&block); 446 Value toDeallocMemref = helperFuncOp.getArguments()[0]; 447 Value toRetainMemref = helperFuncOp.getArguments()[1]; 448 Value conditionMemref = helperFuncOp.getArguments()[2]; 449 Value deallocCondsMemref = helperFuncOp.getArguments()[3]; 450 Value retainCondsMemref = helperFuncOp.getArguments()[4]; 451 452 // Insert some prerequisites. 453 Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0)); 454 Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1)); 455 Value trueValue = 456 builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true)); 457 Value falseValue = 458 builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false)); 459 Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0); 460 Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0); 461 462 builder.create<scf::ForOp>( 463 loc, c0, toRetainSize, c1, std::nullopt, 464 [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { 465 builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i); 466 builder.create<scf::YieldOp>(loc); 467 }); 468 469 builder.create<scf::ForOp>( 470 loc, c0, toDeallocSize, c1, std::nullopt, 471 [&](OpBuilder &builder, Location loc, Value outerIter, 472 ValueRange iterArgs) { 473 Value toDealloc = 474 builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter); 475 Value cond = 476 builder.create<memref::LoadOp>(loc, conditionMemref, outerIter); 477 478 // Build the first for loop that computes aliasing with retained 479 // memrefs. 
480 Value noRetainAlias = 481 builder 482 .create<scf::ForOp>( 483 loc, c0, toRetainSize, c1, trueValue, 484 [&](OpBuilder &builder, Location loc, Value i, 485 ValueRange iterArgs) { 486 Value retainValue = builder.create<memref::LoadOp>( 487 loc, toRetainMemref, i); 488 Value doesAlias = builder.create<arith::CmpIOp>( 489 loc, arith::CmpIPredicate::eq, retainValue, 490 toDealloc); 491 builder.create<scf::IfOp>( 492 loc, doesAlias, 493 [&](OpBuilder &builder, Location loc) { 494 Value retainCondValue = 495 builder.create<memref::LoadOp>( 496 loc, retainCondsMemref, i); 497 Value aggregatedRetainCond = 498 builder.create<arith::OrIOp>( 499 loc, retainCondValue, cond); 500 builder.create<memref::StoreOp>( 501 loc, aggregatedRetainCond, retainCondsMemref, 502 i); 503 builder.create<scf::YieldOp>(loc); 504 }); 505 Value doesntAlias = builder.create<arith::CmpIOp>( 506 loc, arith::CmpIPredicate::ne, retainValue, 507 toDealloc); 508 Value yieldValue = builder.create<arith::AndIOp>( 509 loc, iterArgs[0], doesntAlias); 510 builder.create<scf::YieldOp>(loc, yieldValue); 511 }) 512 .getResult(0); 513 514 // Build the second for loop that adds aliasing with previously 515 // deallocated memrefs. 
516 Value noAlias = 517 builder 518 .create<scf::ForOp>( 519 loc, c0, outerIter, c1, noRetainAlias, 520 [&](OpBuilder &builder, Location loc, Value i, 521 ValueRange iterArgs) { 522 Value prevDeallocValue = builder.create<memref::LoadOp>( 523 loc, toDeallocMemref, i); 524 Value doesntAlias = builder.create<arith::CmpIOp>( 525 loc, arith::CmpIPredicate::ne, prevDeallocValue, 526 toDealloc); 527 Value yieldValue = builder.create<arith::AndIOp>( 528 loc, iterArgs[0], doesntAlias); 529 builder.create<scf::YieldOp>(loc, yieldValue); 530 }) 531 .getResult(0); 532 533 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond); 534 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref, 535 outerIter); 536 builder.create<scf::YieldOp>(loc); 537 }); 538 539 builder.create<func::ReturnOp>(loc); 540 return helperFuncOp; 541 } 542 543 void mlir::bufferization::populateBufferizationDeallocLoweringPattern( 544 RewritePatternSet &patterns, 545 const bufferization::DeallocHelperMap &deallocHelperFuncMap) { 546 patterns.add<DeallocOpConversion>(patterns.getContext(), 547 deallocHelperFuncMap); 548 } 549 550 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() { 551 return std::make_unique<LowerDeallocationsPass>(); 552 } 553