1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/Matchers.h" 17 #include <optional> 18 19 using namespace mlir; 20 using namespace mlir::bufferization; 21 22 //===----------------------------------------------------------------------===// 23 // Helper functions 24 //===----------------------------------------------------------------------===// 25 26 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue( 27 OpBuilder &b, Value value, MemRefType destType, 28 const BufferizationOptions &options) { 29 auto srcType = llvm::cast<MemRefType>(value.getType()); 30 31 // Element type, rank and memory space must match. 32 if (srcType.getElementType() != destType.getElementType()) 33 return failure(); 34 if (srcType.getMemorySpace() != destType.getMemorySpace()) 35 return failure(); 36 if (srcType.getRank() != destType.getRank()) 37 return failure(); 38 39 // In case the affine maps are different, we may need to use a copy if we go 40 // from dynamic to static offset or stride (the canonicalization cannot know 41 // at this point that it is really cast compatible). 42 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { 43 int64_t sourceOffset, targetOffset; 44 SmallVector<int64_t, 4> sourceStrides, targetStrides; 45 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) || 46 failed(target.getStridesAndOffset(targetStrides, targetOffset))) 47 return false; 48 auto dynamicToStatic = [](int64_t a, int64_t b) { 49 return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); 50 }; 51 if (dynamicToStatic(sourceOffset, targetOffset)) 52 return false; 53 for (auto it : zip(sourceStrides, targetStrides)) 54 if (dynamicToStatic(std::get<0>(it), std::get<1>(it))) 55 return false; 56 return true; 57 }; 58 59 // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To 60 // ensure that we only generate casts that always succeed at runtime, we check 61 // a fix extra conditions in `isGuaranteedCastCompatible`. 62 if (memref::CastOp::areCastCompatible(srcType, destType) && 63 isGuaranteedCastCompatible(srcType, destType)) { 64 Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value); 65 return casted; 66 } 67 68 auto loc = value.getLoc(); 69 SmallVector<Value, 4> dynamicOperands; 70 for (int i = 0; i < destType.getRank(); ++i) { 71 if (destType.getShape()[i] != ShapedType::kDynamic) 72 continue; 73 Value size = b.create<memref::DimOp>(loc, value, i); 74 dynamicOperands.push_back(size); 75 } 76 77 FailureOr<Value> copy = 78 options.createAlloc(b, loc, destType, dynamicOperands); 79 if (failed(copy)) 80 return failure(); 81 if (failed(options.createMemCpy(b, loc, value, *copy))) 82 return failure(); 83 return copy; 84 } 85 86 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the 87 /// to_memref op are different, a memref.cast is needed. 88 LogicalResult mlir::bufferization::foldToMemrefToTensorPair( 89 RewriterBase &rewriter, ToMemrefOp toMemref, 90 const BufferizationOptions &options) { 91 auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>(); 92 if (!memrefToTensor) 93 return failure(); 94 95 Type srcType = memrefToTensor.getMemref().getType(); 96 Type destType = toMemref.getType(); 97 98 // Directly rewrite if the type did not change. 99 if (srcType == destType) { 100 rewriter.replaceOp(toMemref, memrefToTensor.getMemref()); 101 return success(); 102 } 103 104 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType); 105 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType); 106 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType); 107 108 // Ranked memref -> Ranked memref cast. 109 if (rankedSrcType && rankedDestType) { 110 FailureOr<Value> replacement = castOrReallocMemRefValue( 111 rewriter, memrefToTensor.getMemref(), rankedDestType, options); 112 if (failed(replacement)) 113 return failure(); 114 115 rewriter.replaceOp(toMemref, *replacement); 116 return success(); 117 } 118 119 // Unranked memref -> Ranked memref cast: May require a copy. 120 // TODO: Not implemented at the moment. 121 if (unrankedSrcType && rankedDestType) 122 return failure(); 123 124 // Unranked memref -> unranked memref cast 125 // Ranked memref -> unranked memref cast: No copy needed. 126 assert(memref::CastOp::areCastCompatible(srcType, destType) && 127 "expected that types are cast compatible"); 128 rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType, 129 memrefToTensor.getMemref()); 130 return success(); 131 } 132 133 void mlir::bufferization::populateDynamicDimSizes( 134 OpBuilder &b, Location loc, Value shapedValue, 135 SmallVector<Value> &dynamicDims) { 136 auto shapedType = llvm::cast<ShapedType>(shapedValue.getType()); 137 for (int64_t i = 0; i < shapedType.getRank(); ++i) { 138 if (shapedType.isDynamicDim(i)) { 139 if (llvm::isa<MemRefType>(shapedType)) { 140 dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i)); 141 } else { 142 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor"); 143 dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i)); 144 } 145 } 146 } 147 } 148 149 //===----------------------------------------------------------------------===// 150 // AllocTensorOp 151 //===----------------------------------------------------------------------===// 152 153 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, 154 const BufferizationOptions &options) { 155 OpBuilder::InsertionGuard g(rewriter); 156 Location loc = getLoc(); 157 158 // Nothing to do for dead AllocTensorOps. 159 if (getOperation()->getUses().empty()) { 160 rewriter.eraseOp(getOperation()); 161 return success(); 162 } 163 164 // Get "copy" buffer. 165 Value copyBuffer; 166 if (getCopy()) { 167 FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options); 168 if (failed(maybeCopyBuffer)) 169 return failure(); 170 copyBuffer = *maybeCopyBuffer; 171 } 172 173 // Create memory allocation. 174 auto allocType = bufferization::getBufferType(getResult(), options); 175 if (failed(allocType)) 176 return failure(); 177 SmallVector<Value> dynamicDims = getDynamicSizes(); 178 if (getCopy()) { 179 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`"); 180 populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims); 181 } 182 FailureOr<Value> alloc = options.createAlloc( 183 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims); 184 if (failed(alloc)) 185 return failure(); 186 187 // Create memory copy (if any). 188 if (getCopy()) { 189 if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc))) 190 return failure(); 191 } 192 193 // Replace op. 194 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); 195 196 return success(); 197 } 198 199 bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult, 200 const AnalysisState &state) { 201 // AllocTensorOps do not write unless they have a `copy` value. 202 return static_cast<bool>(getCopy()); 203 } 204 205 bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand, 206 const AnalysisState &state) { 207 assert(opOperand.getOperandNumber() == getNumOperands() - 1 && 208 "expected copy operand"); 209 return true; 210 } 211 212 bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand, 213 const AnalysisState &state) { 214 assert(opOperand.getOperandNumber() == getNumOperands() - 1 && 215 "expected copy operand"); 216 return false; 217 } 218 219 AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand, 220 const AnalysisState &state) { 221 // This is a new allocation. It does not alias with any other buffer. 222 return {}; 223 } 224 225 FailureOr<BaseMemRefType> 226 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, 227 SmallVector<Value> &invocationStack) { 228 assert(value == getResult() && "invalid value"); 229 230 // Compute memory space of this allocation. 231 Attribute memorySpace; 232 if (getMemorySpace().has_value()) { 233 memorySpace = *getMemorySpace(); 234 } else if (getCopy()) { 235 auto copyBufferType = 236 bufferization::getBufferType(getCopy(), options, invocationStack); 237 if (failed(copyBufferType)) 238 return failure(); 239 memorySpace = copyBufferType->getMemorySpace(); 240 } else if (auto ms = options.defaultMemorySpaceFn(getType())) { 241 memorySpace = *ms; 242 } else { 243 return getOperation()->emitError("could not infer memory space"); 244 } 245 246 return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace); 247 } 248 249 LogicalResult AllocTensorOp::verify() { 250 if (getCopy() && !getDynamicSizes().empty()) 251 return emitError("dynamic sizes not needed when copying a tensor"); 252 if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size()) 253 return emitError("expected ") 254 << getType().getNumDynamicDims() << " dynamic sizes"; 255 if (getCopy() && getCopy().getType() != getType()) 256 return emitError("expected that `copy` and return type match"); 257 return success(); 258 } 259 260 void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 261 RankedTensorType type, ValueRange dynamicSizes) { 262 build(builder, result, type, dynamicSizes, /*copy=*/Value(), 263 /*size_hint=*/Value(), 264 /*memory_space=*/IntegerAttr()); 265 } 266 267 void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 268 RankedTensorType type, ValueRange dynamicSizes, 269 Value copy) { 270 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(), 271 /*memory_space=*/IntegerAttr()); 272 } 273 274 void AllocTensorOp::build(OpBuilder &builder, OperationState &result, 275 TensorType type, ValueRange dynamicSizes, Value copy, 276 IntegerAttr memorySpace) { 277 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(), 278 memorySpace); 279 } 280 281 namespace { 282 /// Change the type of the result of a `bufferization.alloc_tensor` by making 283 /// the result type statically sized along dimension that in the original 284 /// operation where defined as dynamic, but the size was defined using a 285 /// `constant` op. For example: 286 /// 287 /// %c5 = arith.constant 5: index 288 /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32> 289 /// 290 /// to 291 /// 292 /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32> 293 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> { 294 using OpRewritePattern<AllocTensorOp>::OpRewritePattern; 295 296 LogicalResult matchAndRewrite(AllocTensorOp op, 297 PatternRewriter &rewriter) const override { 298 if (op.getCopy()) 299 return failure(); 300 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape()); 301 SmallVector<Value> newDynamicSizes; 302 unsigned int dynValCounter = 0; 303 for (int64_t i = 0; i < op.getType().getRank(); ++i) { 304 if (!op.isDynamicDim(i)) 305 continue; 306 Value value = op.getDynamicSizes()[dynValCounter++]; 307 APInt intVal; 308 if (matchPattern(value, m_ConstantInt(&intVal))) { 309 int64_t dim = intVal.getSExtValue(); 310 if (dim >= 0) 311 newShape[i] = intVal.getSExtValue(); 312 else 313 newDynamicSizes.push_back(value); 314 } else { 315 newDynamicSizes.push_back(value); 316 } 317 } 318 RankedTensorType newType = RankedTensorType::get( 319 newShape, op.getType().getElementType(), op.getType().getEncoding()); 320 if (newType == op.getType()) 321 return failure(); 322 auto newOp = rewriter.create<AllocTensorOp>( 323 op.getLoc(), newType, newDynamicSizes, /*copy=*/Value()); 324 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 325 return success(); 326 } 327 }; 328 329 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> { 330 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 331 332 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 333 PatternRewriter &rewriter) const override { 334 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex(); 335 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>(); 336 if (!allocTensorOp || !maybeConstantIndex) 337 return failure(); 338 if (*maybeConstantIndex < 0 || 339 *maybeConstantIndex >= allocTensorOp.getType().getRank()) 340 return failure(); 341 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex)) 342 return failure(); 343 rewriter.replaceOp( 344 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex)); 345 return success(); 346 } 347 }; 348 } // namespace 349 350 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 351 MLIRContext *ctx) { 352 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx); 353 } 354 355 LogicalResult AllocTensorOp::reifyResultShapes( 356 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 357 auto shapes = llvm::to_vector<4>( 358 llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()), 359 [&](int64_t dim) -> OpFoldResult { 360 if (isDynamicDim(dim)) 361 return getDynamicSize(builder, dim); 362 return builder.getIndexAttr(getStaticSize(dim)); 363 })); 364 reifiedReturnShapes.emplace_back(std::move(shapes)); 365 return success(); 366 } 367 368 ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) { 369 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands; 370 if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) || 371 parser.parseRParen()) 372 return failure(); 373 ParseResult copyKeyword = parser.parseOptionalKeyword("copy"); 374 OpAsmParser::UnresolvedOperand copyOperand; 375 if (copyKeyword.succeeded()) 376 if (parser.parseLParen() || parser.parseOperand(copyOperand) || 377 parser.parseRParen()) 378 return failure(); 379 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint"); 380 OpAsmParser::UnresolvedOperand sizeHintOperand; 381 if (sizeHintKeyword.succeeded()) 382 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand)) 383 return failure(); 384 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) 385 return failure(); 386 387 TensorType type; 388 if (parser.parseCustomTypeWithFallback(type)) 389 return failure(); 390 result.addTypes(type); 391 392 Type indexType = parser.getBuilder().getIndexType(); 393 if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands)) 394 return failure(); 395 if (copyKeyword.succeeded()) 396 if (parser.resolveOperand(copyOperand, type, result.operands)) 397 return failure(); 398 if (sizeHintKeyword.succeeded()) 399 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands)) 400 return failure(); 401 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(), 402 parser.getBuilder().getDenseI32ArrayAttr( 403 {static_cast<int32_t>(dynamicSizesOperands.size()), 404 static_cast<int32_t>(copyKeyword.succeeded()), 405 static_cast<int32_t>(sizeHintKeyword.succeeded())})); 406 return success(); 407 } 408 409 void AllocTensorOp::print(OpAsmPrinter &p) { 410 p << "(" << getDynamicSizes() << ")"; 411 if (getCopy()) 412 p << " copy(" << getCopy() << ")"; 413 if (getSizeHint()) 414 p << " size_hint=" << getSizeHint(); 415 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ 416 AllocTensorOp::getOperandSegmentSizeAttr()}); 417 p << " : "; 418 auto type = getResult().getType(); 419 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type)) 420 p.printStrippedAttrOrType(validType); 421 else 422 p << type; 423 } 424 425 Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) { 426 assert(isDynamicDim(idx) && "expected dynamic dim"); 427 if (getCopy()) 428 return b.create<tensor::DimOp>(getLoc(), getCopy(), idx); 429 return getOperand(getIndexOfDynamicSize(idx)); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // CloneOp 434 //===----------------------------------------------------------------------===// 435 436 OpFoldResult CloneOp::fold(FoldAdaptor adaptor) { 437 return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value(); 438 } 439 440 namespace { 441 442 /// Merge the clone and its source (by converting the clone to a cast) when 443 /// possible. 444 struct SimplifyClones : public OpRewritePattern<CloneOp> { 445 using OpRewritePattern<CloneOp>::OpRewritePattern; 446 447 LogicalResult matchAndRewrite(CloneOp cloneOp, 448 PatternRewriter &rewriter) const override { 449 if (cloneOp.use_empty()) { 450 rewriter.eraseOp(cloneOp); 451 return success(); 452 } 453 454 Value source = cloneOp.getInput(); 455 if (source.getType() != cloneOp.getType() && 456 !memref::CastOp::areCastCompatible({source.getType()}, 457 {cloneOp.getType()})) 458 return failure(); 459 460 // Aims to find the dealloc op for the canonical source 461 // which otherwise could prevent removal of unnecessary allocs. 462 Value canonicalSource = source; 463 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>( 464 canonicalSource.getDefiningOp())) 465 canonicalSource = iface.getViewSource(); 466 467 std::optional<Operation *> maybeCloneDeallocOp = 468 memref::findDealloc(cloneOp.getOutput()); 469 // Skip if either of them has > 1 deallocate operations. 470 if (!maybeCloneDeallocOp.has_value()) 471 return failure(); 472 std::optional<Operation *> maybeSourceDeallocOp = 473 memref::findDealloc(canonicalSource); 474 if (!maybeSourceDeallocOp.has_value()) 475 return failure(); 476 Operation *cloneDeallocOp = *maybeCloneDeallocOp; 477 Operation *sourceDeallocOp = *maybeSourceDeallocOp; 478 479 // If both are deallocated in the same block, their in-block lifetimes 480 // might not fully overlap, so we cannot decide which one to drop. 481 if (cloneDeallocOp && sourceDeallocOp && 482 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) 483 return failure(); 484 485 Block *currentBlock = cloneOp->getBlock(); 486 Operation *redundantDealloc = nullptr; 487 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) { 488 redundantDealloc = cloneDeallocOp; 489 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) { 490 redundantDealloc = sourceDeallocOp; 491 } 492 493 if (!redundantDealloc) 494 return failure(); 495 496 // Safety check that there are no other deallocations inbetween 497 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias 498 // of source before the uses of the clone. With alias information, we could 499 // restrict this to only fail of the dealloc's operand is an alias 500 // of the source. 501 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc; 502 pos = pos->getNextNode()) { 503 // Bail if we run out of operations while looking for a deallocation op. 504 if (!pos) 505 return failure(); 506 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); 507 if (!effectInterface) 508 continue; 509 if (effectInterface.hasEffect<MemoryEffects::Free>()) 510 return failure(); 511 } 512 513 if (source.getType() != cloneOp.getType()) 514 source = rewriter.create<memref::CastOp>(cloneOp.getLoc(), 515 cloneOp.getType(), source); 516 rewriter.replaceOp(cloneOp, source); 517 rewriter.eraseOp(redundantDealloc); 518 return success(); 519 } 520 }; 521 522 } // namespace 523 524 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results, 525 MLIRContext *context) { 526 results.add<SimplifyClones>(context); 527 } 528 529 //===----------------------------------------------------------------------===// 530 // DeallocTensorOp 531 //===----------------------------------------------------------------------===// 532 533 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, 534 const BufferizationOptions &options) { 535 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options); 536 if (failed(buffer)) 537 return failure(); 538 rewriter.create<memref::DeallocOp>(getLoc(), *buffer); 539 rewriter.eraseOp(getOperation()); 540 return success(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // MaterializeInDestinationOp 545 //===----------------------------------------------------------------------===// 546 547 bool MaterializeInDestinationOp::bufferizesToMemoryRead( 548 OpOperand &opOperand, const AnalysisState &state) { 549 return opOperand == getSourceMutable(); 550 } 551 552 bool MaterializeInDestinationOp::bufferizesToMemoryWrite( 553 OpOperand &opOperand, const AnalysisState &state) { 554 if (opOperand == getDestMutable()) { 555 assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 556 return true; 557 } 558 return false; 559 } 560 561 bool MaterializeInDestinationOp::mustBufferizeInPlace( 562 OpOperand &opOperand, const AnalysisState &state) { 563 // The source is only read and not written, so it always bufferizes in-place 564 // by default. The destination is written and is forced to bufferize in-place 565 // (if it is a tensor). 566 return true; 567 } 568 569 AliasingValueList 570 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, 571 const AnalysisState &state) { 572 if (opOperand == getDestMutable()) { 573 assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 574 return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; 575 } 576 return {}; 577 } 578 579 LogicalResult 580 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, 581 const BufferizationOptions &options) { 582 bool tensorDest = isa<TensorType>(getDest().getType()); 583 Value buffer; 584 if (tensorDest) { 585 FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options); 586 if (failed(maybeBuffer)) 587 return failure(); 588 buffer = *maybeBuffer; 589 } else { 590 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type"); 591 buffer = getDest(); 592 } 593 auto srcBuffer = getBuffer(rewriter, getSource(), options); 594 if (failed(srcBuffer)) 595 return failure(); 596 if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer))) 597 return failure(); 598 replaceOpWithBufferizedValues(rewriter, getOperation(), 599 tensorDest ? ValueRange(buffer) : ValueRange()); 600 return success(); 601 } 602 603 bool MaterializeInDestinationOp::bufferizesToElementwiseAccess( 604 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) { 605 // As elements are copied from the "source" buffer to the "dest" buffer, 606 // already copied elements are not read a second time. 607 return true; 608 } 609 610 LogicalResult MaterializeInDestinationOp::reifyResultShapes( 611 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 612 if (getOperation()->getNumResults() == 1) { 613 assert(isa<TensorType>(getDest().getType()) && "expected tensor type"); 614 reifiedReturnShapes.resize(1, 615 SmallVector<OpFoldResult>(getType().getRank())); 616 reifiedReturnShapes[0] = 617 tensor::getMixedSizes(builder, getLoc(), getDest()); 618 } 619 return success(); 620 } 621 622 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, 623 Location loc) { 624 if (isa<TensorType>(getDest().getType())) { 625 // The subset is the entire destination tensor. 626 return getDest(); 627 } 628 629 // The "restrict" attribute is transferred from this op to the newly created 630 // to_tensor op. If this op does not the "restrict" attribute, the subset 631 // extraction cannot be built because there is no guarantee that there is no 632 // pre-existing "restrict" to_tensor op with the same/an aliasing destination. 633 if (!getRestrict()) 634 return {}; 635 636 // Build a bufferization.to_tensor op. 637 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type"); 638 assert(getRestrict() && 639 "expected that ops with memrefs dest have 'restrict'"); 640 setRestrict(false); 641 return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true, 642 getWritable()); 643 } 644 645 bool MaterializeInDestinationOp::isEquivalentSubset( 646 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) { 647 return equivalenceFn(getDest(), candidate); 648 } 649 650 SmallVector<Value> 651 MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() { 652 return {getDest()}; 653 } 654 655 OpOperand &MaterializeInDestinationOp::getSourceOperand() { 656 return getOperation()->getOpOperand(0) /*source*/; 657 } 658 659 bool MaterializeInDestinationOp::operatesOnEquivalentSubset( 660 SubsetOpInterface subsetOp, 661 function_ref<bool(Value, Value)> equivalenceFn) { 662 return false; 663 } 664 665 bool MaterializeInDestinationOp::operatesOnDisjointSubset( 666 SubsetOpInterface subsetOp, 667 function_ref<bool(Value, Value)> equivalenceFn) { 668 return false; 669 } 670 671 LogicalResult MaterializeInDestinationOp::verify() { 672 if (!isa<TensorType, BaseMemRefType>(getDest().getType())) 673 return emitOpError("'dest' must be a tensor or a memref"); 674 if (auto destType = dyn_cast<TensorType>(getDest().getType())) { 675 if (getOperation()->getNumResults() != 1) 676 return emitOpError("tensor 'dest' implies exactly one tensor result"); 677 if (destType != getResult().getType()) 678 return emitOpError("result and 'dest' types must match"); 679 } 680 if (isa<BaseMemRefType>(getDest().getType()) && 681 getOperation()->getNumResults() != 0) 682 return emitOpError("memref 'dest' implies zero results"); 683 if (getRestrict() && !isa<BaseMemRefType>(getDest().getType())) 684 return emitOpError("'restrict' is valid only for memref destinations"); 685 if (getWritable() != isa<BaseMemRefType>(getDest().getType())) 686 return emitOpError("'writable' must be specified if and only if the " 687 "destination is of memref type"); 688 TensorType srcType = getSource().getType(); 689 ShapedType destType = cast<ShapedType>(getDest().getType()); 690 if (srcType.hasRank() != destType.hasRank()) 691 return emitOpError("source/destination shapes are incompatible"); 692 if (srcType.hasRank()) { 693 if (srcType.getRank() != destType.getRank()) 694 return emitOpError("rank mismatch between source and destination shape"); 695 for (auto [src, dest] : 696 llvm::zip(srcType.getShape(), destType.getShape())) { 697 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) { 698 // Cannot verify dynamic dimension size. Assume that that they match at 699 // runtime. 700 continue; 701 } 702 if (src != dest) 703 return emitOpError("source/destination shapes are incompatible"); 704 } 705 } 706 return success(); 707 } 708 709 void MaterializeInDestinationOp::build(OpBuilder &builder, 710 OperationState &state, Value source, 711 Value dest) { 712 auto destTensorType = dyn_cast<TensorType>(dest.getType()); 713 build(builder, state, /*result=*/destTensorType ? destTensorType : Type(), 714 source, dest); 715 } 716 717 bool MaterializeInDestinationOp::isWritable(Value value, 718 const AnalysisState &state) { 719 return isa<TensorType>(getDest().getType()) ? true : getWritable(); 720 } 721 722 MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() { 723 return getDestMutable(); 724 } 725 726 void MaterializeInDestinationOp::getEffects( 727 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 728 &effects) { 729 if (isa<BaseMemRefType>(getDest().getType())) 730 effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(), 731 SideEffects::DefaultResource::get()); 732 } 733 734 //===----------------------------------------------------------------------===// 735 // ToTensorOp 736 //===----------------------------------------------------------------------===// 737 738 bool ToTensorOp::isWritable(Value value, const AnalysisState &state) { 739 return getWritable(); 740 } 741 742 OpFoldResult ToTensorOp::fold(FoldAdaptor) { 743 if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>()) 744 // Approximate alias analysis by conservatively folding only when no there 745 // is no interleaved operation. 746 if (toMemref->getBlock() == this->getOperation()->getBlock() && 747 toMemref->getNextNode() == this->getOperation()) 748 return toMemref.getTensor(); 749 return {}; 750 } 751 752 namespace { 753 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> { 754 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 755 756 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 757 PatternRewriter &rewriter) const override { 758 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>(); 759 if (!memrefToTensorOp) 760 return failure(); 761 762 rewriter.replaceOpWithNewOp<memref::DimOp>( 763 dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex()); 764 return success(); 765 } 766 }; 767 } // namespace 768 769 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 770 MLIRContext *context) { 771 results.add<DimOfToTensorFolder>(context); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // ToMemrefOp 776 //===----------------------------------------------------------------------===// 777 778 OpFoldResult ToMemrefOp::fold(FoldAdaptor) { 779 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>()) 780 if (memrefToTensor.getMemref().getType() == getType()) 781 return memrefToTensor.getMemref(); 782 return {}; 783 } 784 785 namespace { 786 787 /// Replace tensor.cast + to_memref by to_memref + memref.cast. 788 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> { 789 using OpRewritePattern<ToMemrefOp>::OpRewritePattern; 790 791 LogicalResult matchAndRewrite(ToMemrefOp toMemref, 792 PatternRewriter &rewriter) const final { 793 auto tensorCastOperand = 794 toMemref.getOperand().getDefiningOp<tensor::CastOp>(); 795 if (!tensorCastOperand) 796 return failure(); 797 auto srcTensorType = llvm::dyn_cast<RankedTensorType>( 798 tensorCastOperand.getOperand().getType()); 799 if (!srcTensorType) 800 return failure(); 801 auto memrefType = MemRefType::get(srcTensorType.getShape(), 802 srcTensorType.getElementType()); 803 Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType, 804 tensorCastOperand.getOperand()); 805 rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(), 806 memref); 807 return success(); 808 } 809 }; 810 811 /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a 812 /// cast if necessary. 813 struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> { 814 using OpRewritePattern<ToMemrefOp>::OpRewritePattern; 815 816 LogicalResult matchAndRewrite(ToMemrefOp toMemref, 817 PatternRewriter &rewriter) const final { 818 BufferizationOptions options; 819 options.bufferAlignment = 0; 820 return foldToMemrefToTensorPair(rewriter, toMemref, options); 821 } 822 }; 823 824 /// Fold a load on a to_memref operation into an tensor.extract on the 825 /// corresponding tensor. 826 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> { 827 using OpRewritePattern<memref::LoadOp>::OpRewritePattern; 828 829 LogicalResult matchAndRewrite(memref::LoadOp load, 830 PatternRewriter &rewriter) const override { 831 auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>(); 832 if (!toMemref) 833 return failure(); 834 835 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(), 836 load.getIndices()); 837 return success(); 838 } 839 }; 840 841 /// Fold dim of a to_memref into the dim of the tensor. 842 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> { 843 using OpRewritePattern<memref::DimOp>::OpRewritePattern; 844 845 LogicalResult matchAndRewrite(memref::DimOp dimOp, 846 PatternRewriter &rewriter) const override { 847 auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>(); 848 if (!castOp) 849 return failure(); 850 Value newSource = castOp.getOperand(); 851 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, 852 dimOp.getIndex()); 853 return success(); 854 } 855 }; 856 857 } // namespace 858 859 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results, 860 MLIRContext *context) { 861 results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, 862 ToMemrefToTensorFolding>(context); 863 } 864 865 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, 866 const BufferizationOptions &options) { 867 // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. 868 (void)foldToMemrefToTensorPair(rewriter, *this, options); 869 // Note: The return value of `bufferize` indicates whether there was an error 870 // or not. (And not whether the pattern matched or not.) 871 return success(); 872 } 873 874 std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, 875 Value alloc) { 876 return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) 877 .getOperation(); 878 } 879 880 std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) { 881 return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult(); 882 } 883 884 //===----------------------------------------------------------------------===// 885 // DeallocOp 886 //===----------------------------------------------------------------------===// 887 888 LogicalResult DeallocOp::inferReturnTypes( 889 MLIRContext *context, std::optional<::mlir::Location> location, 890 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, 891 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { 892 DeallocOpAdaptor adaptor(operands, attributes, properties, regions); 893 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(), 894 IntegerType::get(context, 1)); 895 return success(); 896 } 897 898 LogicalResult DeallocOp::verify() { 899 if (getMemrefs().size() != getConditions().size()) 900 return emitOpError( 901 "must have the same number of conditions as memrefs to deallocate"); 902 if (getRetained().size() != getUpdatedConditions().size()) 903 return emitOpError("must have the same number of updated conditions " 904 "(results) as retained operands"); 905 return success(); 906 } 907 908 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, 909 ValueRange memrefs, 910 ValueRange conditions, 911 PatternRewriter &rewriter) { 912 if (deallocOp.getMemrefs() == memrefs && 913 deallocOp.getConditions() == conditions) 914 return failure(); 915 916 rewriter.modifyOpInPlace(deallocOp, [&]() { 917 deallocOp.getMemrefsMutable().assign(memrefs); 918 deallocOp.getConditionsMutable().assign(conditions); 919 }); 920 return success(); 921 } 922 923 namespace { 924 925 /// Remove duplicate values in the list of memrefs to be deallocated. We need to 926 /// make sure the corresponding condition value is updated accordingly since 927 /// their two conditions might not cover the same set of cases. In that case, we 928 /// have to combine them (by computing the disjunction of them). 929 /// Example: 930 /// ```mlir 931 /// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2) 932 /// ``` 933 /// is canonicalized to 934 /// ```mlir 935 /// %0 = arith.ori %arg1, %arg2 : i1 936 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0) 937 /// ``` 938 struct DeallocRemoveDuplicateDeallocMemrefs 939 : public OpRewritePattern<DeallocOp> { 940 using OpRewritePattern<DeallocOp>::OpRewritePattern; 941 942 LogicalResult matchAndRewrite(DeallocOp deallocOp, 943 PatternRewriter &rewriter) const override { 944 // Unique memrefs to be deallocated. 945 DenseMap<Value, unsigned> memrefToCondition; 946 SmallVector<Value> newMemrefs, newConditions; 947 for (auto [i, memref, cond] : 948 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) { 949 if (memrefToCondition.count(memref)) { 950 // If the dealloc conditions don't match, we need to make sure that the 951 // dealloc happens on the union of cases. 952 Value &newCond = newConditions[memrefToCondition[memref]]; 953 if (newCond != cond) 954 newCond = 955 rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond); 956 } else { 957 memrefToCondition.insert({memref, newConditions.size()}); 958 newMemrefs.push_back(memref); 959 newConditions.push_back(cond); 960 } 961 } 962 963 // Return failure if we don't change anything such that we don't run into an 964 // infinite loop of pattern applications. 965 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 966 rewriter); 967 } 968 }; 969 970 /// Remove duplicate values in the list of retained memrefs. We need to make 971 /// sure the corresponding result condition value is replaced properly. 972 /// Example: 973 /// ```mlir 974 /// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...) 975 /// ``` 976 /// is canonicalized to 977 /// ```mlir 978 /// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>) 979 /// ``` 980 struct DeallocRemoveDuplicateRetainedMemrefs 981 : public OpRewritePattern<DeallocOp> { 982 using OpRewritePattern<DeallocOp>::OpRewritePattern; 983 984 LogicalResult matchAndRewrite(DeallocOp deallocOp, 985 PatternRewriter &rewriter) const override { 986 // Unique retained values 987 DenseMap<Value, unsigned> seen; 988 SmallVector<Value> newRetained; 989 SmallVector<unsigned> resultReplacementIdx; 990 unsigned i = 0; 991 for (auto retained : deallocOp.getRetained()) { 992 if (seen.count(retained)) { 993 resultReplacementIdx.push_back(seen[retained]); 994 continue; 995 } 996 997 seen[retained] = i; 998 newRetained.push_back(retained); 999 resultReplacementIdx.push_back(i++); 1000 } 1001 1002 // Return failure if we don't change anything such that we don't run into an 1003 // infinite loop of pattern applications. 1004 if (newRetained.size() == deallocOp.getRetained().size()) 1005 return failure(); 1006 1007 // We need to create a new op because the number of results is always the 1008 // same as the number of condition operands. 1009 auto newDeallocOp = 1010 rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(), 1011 deallocOp.getConditions(), newRetained); 1012 SmallVector<Value> replacements( 1013 llvm::map_range(resultReplacementIdx, [&](unsigned idx) { 1014 return newDeallocOp.getUpdatedConditions()[idx]; 1015 })); 1016 rewriter.replaceOp(deallocOp, replacements); 1017 return success(); 1018 } 1019 }; 1020 1021 /// Erase deallocation operations where the variadic list of memrefs to 1022 /// deallocate is empty. Example: 1023 /// ```mlir 1024 /// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>) 1025 /// ``` 1026 struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> { 1027 using OpRewritePattern<DeallocOp>::OpRewritePattern; 1028 1029 LogicalResult matchAndRewrite(DeallocOp deallocOp, 1030 PatternRewriter &rewriter) const override { 1031 if (deallocOp.getMemrefs().empty()) { 1032 Value constFalse = rewriter.create<arith::ConstantOp>( 1033 deallocOp.getLoc(), rewriter.getBoolAttr(false)); 1034 rewriter.replaceOp( 1035 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(), 1036 constFalse)); 1037 return success(); 1038 } 1039 return failure(); 1040 } 1041 }; 1042 1043 /// Removes memrefs from the deallocation list if their associated condition is 1044 /// always 'false'. 1045 /// 1046 /// Example: 1047 /// ``` 1048 /// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) 1049 /// if (%arg2, %false) 1050 /// ``` 1051 /// becomes 1052 /// ``` 1053 /// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2) 1054 /// ``` 1055 struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> { 1056 using OpRewritePattern<DeallocOp>::OpRewritePattern; 1057 1058 LogicalResult matchAndRewrite(DeallocOp deallocOp, 1059 PatternRewriter &rewriter) const override { 1060 SmallVector<Value> newMemrefs, newConditions; 1061 for (auto [memref, cond] : 1062 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 1063 if (!matchPattern(cond, m_Zero())) { 1064 newMemrefs.push_back(memref); 1065 newConditions.push_back(cond); 1066 } 1067 } 1068 1069 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 1070 rewriter); 1071 } 1072 }; 1073 1074 /// The `memref.extract_strided_metadata` is often inserted to get the base 1075 /// memref if the operand is not already guaranteed to be the result of a memref 1076 /// allocation operation. This canonicalization pattern removes this extraction 1077 /// operation if the operand is now produced by an allocation operation (e.g., 1078 /// due to other canonicalizations simplifying the IR). 1079 /// 1080 /// Example: 1081 /// ```mlir 1082 /// %alloc = memref.alloc() : memref<2xi32> 1083 /// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata 1084 /// %alloc : memref<2xi32> -> memref<i32>, index, index, index 1085 /// bufferization.dealloc (%base_memref : memref<i32>) if (%cond) 1086 /// ``` 1087 /// is canonicalized to 1088 /// ```mlir 1089 /// %alloc = memref.alloc() : memref<2xi32> 1090 /// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond) 1091 /// ``` 1092 struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> { 1093 using OpRewritePattern<DeallocOp>::OpRewritePattern; 1094 1095 LogicalResult matchAndRewrite(DeallocOp deallocOp, 1096 PatternRewriter &rewriter) const override { 1097 SmallVector<Value> newMemrefs( 1098 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) { 1099 auto extractStridedOp = 1100 memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); 1101 if (!extractStridedOp) 1102 return memref; 1103 Value allocMemref = extractStridedOp.getOperand(); 1104 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>(); 1105 if (!allocOp) 1106 return memref; 1107 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref)) 1108 return allocMemref; 1109 return memref; 1110 })); 1111 1112 return updateDeallocIfChanged(deallocOp, newMemrefs, 1113 deallocOp.getConditions(), rewriter); 1114 } 1115 }; 1116 1117 /// Removes pairs of `bufferization.dealloc` and alloc operations if there is no 1118 /// other user of the allocated value and the allocating operation can be safely 1119 /// removed. If the same value is present multiple times, this pattern relies on 1120 /// other canonicalization patterns to remove the duplicate first. 1121 /// 1122 /// Example: 1123 /// ```mlir 1124 /// %alloc = memref.alloc() : memref<2xi32> 1125 /// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true) 1126 /// ``` 1127 /// is canonicalized to 1128 /// ```mlir 1129 /// bufferization.dealloc (%arg0 : ...) if (%true) 1130 /// ``` 1131 struct RemoveAllocDeallocPairWhenNoOtherUsers 1132 : public OpRewritePattern<DeallocOp> { 1133 using OpRewritePattern<DeallocOp>::OpRewritePattern; 1134 1135 LogicalResult matchAndRewrite(DeallocOp deallocOp, 1136 PatternRewriter &rewriter) const override { 1137 SmallVector<Value> newMemrefs, newConditions; 1138 SmallVector<Operation *> toDelete; 1139 for (auto [memref, cond] : 1140 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { 1141 if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) { 1142 // Check that it is indeed an allocate effect, that the op has no other 1143 // side effects (which would not allow us to remove the op), and that 1144 // there are no other users. 1145 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) && 1146 hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) && 1147 memref.hasOneUse()) { 1148 toDelete.push_back(allocOp); 1149 continue; 1150 } 1151 } 1152 1153 newMemrefs.push_back(memref); 1154 newConditions.push_back(cond); 1155 } 1156 1157 if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, 1158 rewriter))) 1159 return failure(); 1160 1161 for (Operation *op : toDelete) 1162 rewriter.eraseOp(op); 1163 1164 return success(); 1165 } 1166 }; 1167 1168 } // anonymous namespace 1169 1170 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, 1171 MLIRContext *context) { 1172 populateDeallocOpCanonicalizationPatterns(results, context); 1173 } 1174 1175 void bufferization::populateDeallocOpCanonicalizationPatterns( 1176 RewritePatternSet &patterns, MLIRContext *context) { 1177 patterns.add<DeallocRemoveDuplicateDeallocMemrefs, 1178 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc, 1179 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc, 1180 RemoveAllocDeallocPairWhenNoOtherUsers>(context); 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // TableGen'd op method definitions 1185 //===----------------------------------------------------------------------===// 1186 1187 #define GET_OP_CLASSES 1188 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" 1189