//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
28037deb7SMartin Erhart //
38037deb7SMartin Erhart // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48037deb7SMartin Erhart // See https://llvm.org/LICENSE.txt for license information.
58037deb7SMartin Erhart // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68037deb7SMartin Erhart //
78037deb7SMartin Erhart //===----------------------------------------------------------------------===//
88037deb7SMartin Erhart
98037deb7SMartin Erhart #include "mlir/Dialect/MemRef/Transforms/Passes.h"
108037deb7SMartin Erhart #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
118037deb7SMartin Erhart
128037deb7SMartin Erhart #include "mlir/Dialect/Arith/IR/Arith.h"
138037deb7SMartin Erhart #include "mlir/Dialect/MemRef/IR/MemRef.h"
148037deb7SMartin Erhart #include "mlir/Dialect/SCF/IR/SCF.h"
158037deb7SMartin Erhart #include "mlir/Transforms/DialectConversion.h"
168037deb7SMartin Erhart
178037deb7SMartin Erhart namespace mlir {
188037deb7SMartin Erhart namespace memref {
198037deb7SMartin Erhart #define GEN_PASS_DEF_EXPANDREALLOC
208037deb7SMartin Erhart #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
218037deb7SMartin Erhart } // namespace memref
228037deb7SMartin Erhart } // namespace mlir
238037deb7SMartin Erhart
248037deb7SMartin Erhart using namespace mlir;
258037deb7SMartin Erhart
268037deb7SMartin Erhart namespace {
278037deb7SMartin Erhart
288037deb7SMartin Erhart /// The `realloc` operation performs a conditional allocation and copy to
298037deb7SMartin Erhart /// increase the size of a buffer if necessary. This pattern converts the
308037deb7SMartin Erhart /// `realloc` operation into this sequence of simpler operations.
318037deb7SMartin Erhart
328037deb7SMartin Erhart /// Example of an expansion:
338037deb7SMartin Erhart /// ```mlir
348037deb7SMartin Erhart /// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
358037deb7SMartin Erhart /// ```
368037deb7SMartin Erhart /// is expanded to
378037deb7SMartin Erhart /// ```mlir
388037deb7SMartin Erhart /// %c0 = arith.constant 0 : index
398037deb7SMartin Erhart /// %dim = memref.dim %alloc, %c0 : memref<?xf32>
408037deb7SMartin Erhart /// %is_old_smaller = arith.cmpi ult, %dim, %arg1
418037deb7SMartin Erhart /// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
428037deb7SMartin Erhart /// %new_alloc = memref.alloc(%size) : memref<?xf32>
438037deb7SMartin Erhart /// %subview = memref.subview %new_alloc[0] [%dim] [1]
448037deb7SMartin Erhart /// memref.copy %alloc, %subview
458037deb7SMartin Erhart /// memref.dealloc %alloc
468037deb7SMartin Erhart /// scf.yield %alloc_0 : memref<?xf32>
478037deb7SMartin Erhart /// } else {
488037deb7SMartin Erhart /// %reinterpret_cast = memref.reinterpret_cast %alloc to
498037deb7SMartin Erhart /// offset: [0], sizes: [%size], strides: [1]
508037deb7SMartin Erhart /// scf.yield %reinterpret_cast : memref<?xf32>
518037deb7SMartin Erhart /// }
528037deb7SMartin Erhart /// ```
538037deb7SMartin Erhart struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
ExpandReallocOpPattern__anonded970ac0111::ExpandReallocOpPattern548037deb7SMartin Erhart ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
558037deb7SMartin Erhart : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
568037deb7SMartin Erhart
matchAndRewrite__anonded970ac0111::ExpandReallocOpPattern578037deb7SMartin Erhart LogicalResult matchAndRewrite(memref::ReallocOp op,
588037deb7SMartin Erhart PatternRewriter &rewriter) const final {
598037deb7SMartin Erhart Location loc = op.getLoc();
608037deb7SMartin Erhart assert(op.getType().getRank() == 1 &&
618037deb7SMartin Erhart "result MemRef must have exactly one rank");
628037deb7SMartin Erhart assert(op.getSource().getType().getRank() == 1 &&
638037deb7SMartin Erhart "source MemRef must have exactly one rank");
648037deb7SMartin Erhart assert(op.getType().getLayout().isIdentity() &&
658037deb7SMartin Erhart "result MemRef must have identity layout (or none)");
668037deb7SMartin Erhart assert(op.getSource().getType().getLayout().isIdentity() &&
678037deb7SMartin Erhart "source MemRef must have identity layout (or none)");
688037deb7SMartin Erhart
698037deb7SMartin Erhart // Get the size of the original buffer.
708037deb7SMartin Erhart int64_t inputSize =
71*a5757c5bSChristian Sigg cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
728037deb7SMartin Erhart OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
738037deb7SMartin Erhart if (ShapedType::isDynamic(inputSize)) {
748037deb7SMartin Erhart Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
758037deb7SMartin Erhart rewriter.getIndexAttr(0));
768037deb7SMartin Erhart currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
778037deb7SMartin Erhart .getResult();
788037deb7SMartin Erhart }
798037deb7SMartin Erhart
808037deb7SMartin Erhart // Get the requested size that the new buffer should have.
818037deb7SMartin Erhart int64_t outputSize =
82*a5757c5bSChristian Sigg cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
838037deb7SMartin Erhart OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
848037deb7SMartin Erhart ? OpFoldResult{op.getDynamicResultSize()}
858037deb7SMartin Erhart : rewriter.getIndexAttr(outputSize);
868037deb7SMartin Erhart
878037deb7SMartin Erhart // Only allocate a new buffer and copy over the values in the old buffer if
888037deb7SMartin Erhart // the old buffer is smaller than the requested size.
898037deb7SMartin Erhart Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
908037deb7SMartin Erhart Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
918037deb7SMartin Erhart Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
928037deb7SMartin Erhart lhs, rhs);
938037deb7SMartin Erhart auto ifOp = rewriter.create<scf::IfOp>(
948037deb7SMartin Erhart loc, cond,
958037deb7SMartin Erhart [&](OpBuilder &builder, Location loc) {
968037deb7SMartin Erhart // Allocate the new buffer. If it is a dynamic memref we need to pass
978037deb7SMartin Erhart // an additional operand for the size at runtime, otherwise the static
988037deb7SMartin Erhart // size is encoded in the result type.
998037deb7SMartin Erhart SmallVector<Value> dynamicSizeOperands;
1008037deb7SMartin Erhart if (op.getDynamicResultSize())
1018037deb7SMartin Erhart dynamicSizeOperands.push_back(op.getDynamicResultSize());
1028037deb7SMartin Erhart
1038037deb7SMartin Erhart Value newAlloc = builder.create<memref::AllocOp>(
1048037deb7SMartin Erhart loc, op.getResult().getType(), dynamicSizeOperands,
1058037deb7SMartin Erhart op.getAlignmentAttr());
1068037deb7SMartin Erhart
1078037deb7SMartin Erhart // Take a subview of the new (bigger) buffer such that we can copy the
1088037deb7SMartin Erhart // old values over (the copy operation requires both operands to have
1098037deb7SMartin Erhart // the same shape).
1108037deb7SMartin Erhart Value subview = builder.create<memref::SubViewOp>(
1118037deb7SMartin Erhart loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
1128037deb7SMartin Erhart ArrayRef<OpFoldResult>{currSize},
1138037deb7SMartin Erhart ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
1148037deb7SMartin Erhart builder.create<memref::CopyOp>(loc, op.getSource(), subview);
1158037deb7SMartin Erhart
1168037deb7SMartin Erhart // Insert the deallocation of the old buffer only if requested
1178037deb7SMartin Erhart // (enabled by default).
1188037deb7SMartin Erhart if (emitDeallocs)
1198037deb7SMartin Erhart builder.create<memref::DeallocOp>(loc, op.getSource());
1208037deb7SMartin Erhart
1218037deb7SMartin Erhart builder.create<scf::YieldOp>(loc, newAlloc);
1228037deb7SMartin Erhart },
1238037deb7SMartin Erhart [&](OpBuilder &builder, Location loc) {
1248037deb7SMartin Erhart // We need to reinterpret-cast here because either the input or output
1258037deb7SMartin Erhart // type might be static, which means we need to cast from static to
1268037deb7SMartin Erhart // dynamic or vice-versa. If both are static and the original buffer
1278037deb7SMartin Erhart // is already bigger than the requested size, the cast represents a
1288037deb7SMartin Erhart // subview operation.
1298037deb7SMartin Erhart Value casted = builder.create<memref::ReinterpretCastOp>(
130*a5757c5bSChristian Sigg loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
1318037deb7SMartin Erhart rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
1328037deb7SMartin Erhart ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
1338037deb7SMartin Erhart builder.create<scf::YieldOp>(loc, casted);
1348037deb7SMartin Erhart });
1358037deb7SMartin Erhart
1368037deb7SMartin Erhart rewriter.replaceOp(op, ifOp.getResult(0));
1378037deb7SMartin Erhart return success();
1388037deb7SMartin Erhart }
1398037deb7SMartin Erhart
1408037deb7SMartin Erhart private:
1418037deb7SMartin Erhart const bool emitDeallocs;
1428037deb7SMartin Erhart };
1438037deb7SMartin Erhart
1448037deb7SMartin Erhart struct ExpandReallocPass
1458037deb7SMartin Erhart : public memref::impl::ExpandReallocBase<ExpandReallocPass> {
ExpandReallocPass__anonded970ac0111::ExpandReallocPass1468037deb7SMartin Erhart ExpandReallocPass(bool emitDeallocs)
1478037deb7SMartin Erhart : memref::impl::ExpandReallocBase<ExpandReallocPass>() {
1488037deb7SMartin Erhart this->emitDeallocs.setValue(emitDeallocs);
1498037deb7SMartin Erhart }
runOnOperation__anonded970ac0111::ExpandReallocPass1508037deb7SMartin Erhart void runOnOperation() override {
1518037deb7SMartin Erhart MLIRContext &ctx = getContext();
1528037deb7SMartin Erhart
1538037deb7SMartin Erhart RewritePatternSet patterns(&ctx);
1548037deb7SMartin Erhart memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());
1558037deb7SMartin Erhart ConversionTarget target(ctx);
1568037deb7SMartin Erhart
1578037deb7SMartin Erhart target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
1588037deb7SMartin Erhart memref::MemRefDialect>();
1598037deb7SMartin Erhart target.addIllegalOp<memref::ReallocOp>();
1608037deb7SMartin Erhart if (failed(applyPartialConversion(getOperation(), target,
1618037deb7SMartin Erhart std::move(patterns))))
1628037deb7SMartin Erhart signalPassFailure();
1638037deb7SMartin Erhart }
1648037deb7SMartin Erhart };
1658037deb7SMartin Erhart
1668037deb7SMartin Erhart } // namespace
1678037deb7SMartin Erhart
populateExpandReallocPatterns(RewritePatternSet & patterns,bool emitDeallocs)1688037deb7SMartin Erhart void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
1698037deb7SMartin Erhart bool emitDeallocs) {
1708037deb7SMartin Erhart patterns.add<ExpandReallocOpPattern>(patterns.getContext(), emitDeallocs);
1718037deb7SMartin Erhart }
1728037deb7SMartin Erhart
createExpandReallocPass(bool emitDeallocs)1738037deb7SMartin Erhart std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
1748037deb7SMartin Erhart return std::make_unique<ExpandReallocPass>(emitDeallocs);
1758037deb7SMartin Erhart }
176