xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp (revision a5757c5b65f1894de16f549212b1c37793312703)
//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
10 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
11 
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 namespace mlir {
18 namespace memref {
19 #define GEN_PASS_DEF_EXPANDREALLOC
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
21 } // namespace memref
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// The `realloc` operation performs a conditional allocation and copy to
29 /// increase the size of a buffer if necessary. This pattern converts the
30 /// `realloc` operation into this sequence of simpler operations.
31 
32 /// Example of an expansion:
33 /// ```mlir
34 /// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
35 /// ```
36 /// is expanded to
37 /// ```mlir
38 /// %c0 = arith.constant 0 : index
39 /// %dim = memref.dim %alloc, %c0 : memref<?xf32>
40 /// %is_old_smaller = arith.cmpi ult, %dim, %arg1
41 /// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
42 ///   %new_alloc = memref.alloc(%size) : memref<?xf32>
43 ///   %subview = memref.subview %new_alloc[0] [%dim] [1]
44 ///   memref.copy %alloc, %subview
45 ///   memref.dealloc %alloc
46 ///   scf.yield %alloc_0 : memref<?xf32>
47 /// } else {
48 ///   %reinterpret_cast = memref.reinterpret_cast %alloc to
49 ///     offset: [0], sizes: [%size], strides: [1]
50 ///   scf.yield %reinterpret_cast : memref<?xf32>
51 /// }
52 /// ```
53 struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
ExpandReallocOpPattern__anonded970ac0111::ExpandReallocOpPattern54   ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
55       : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
56 
matchAndRewrite__anonded970ac0111::ExpandReallocOpPattern57   LogicalResult matchAndRewrite(memref::ReallocOp op,
58                                 PatternRewriter &rewriter) const final {
59     Location loc = op.getLoc();
60     assert(op.getType().getRank() == 1 &&
61            "result MemRef must have exactly one rank");
62     assert(op.getSource().getType().getRank() == 1 &&
63            "source MemRef must have exactly one rank");
64     assert(op.getType().getLayout().isIdentity() &&
65            "result MemRef must have identity layout (or none)");
66     assert(op.getSource().getType().getLayout().isIdentity() &&
67            "source MemRef must have identity layout (or none)");
68 
69     // Get the size of the original buffer.
70     int64_t inputSize =
71         cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72     OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73     if (ShapedType::isDynamic(inputSize)) {
74       Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
75                                                       rewriter.getIndexAttr(0));
76       currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
77                      .getResult();
78     }
79 
80     // Get the requested size that the new buffer should have.
81     int64_t outputSize =
82         cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83     OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84                                   ? OpFoldResult{op.getDynamicResultSize()}
85                                   : rewriter.getIndexAttr(outputSize);
86 
87     // Only allocate a new buffer and copy over the values in the old buffer if
88     // the old buffer is smaller than the requested size.
89     Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
90     Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
91     Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
92                                                 lhs, rhs);
93     auto ifOp = rewriter.create<scf::IfOp>(
94         loc, cond,
95         [&](OpBuilder &builder, Location loc) {
96           // Allocate the new buffer. If it is a dynamic memref we need to pass
97           // an additional operand for the size at runtime, otherwise the static
98           // size is encoded in the result type.
99           SmallVector<Value> dynamicSizeOperands;
100           if (op.getDynamicResultSize())
101             dynamicSizeOperands.push_back(op.getDynamicResultSize());
102 
103           Value newAlloc = builder.create<memref::AllocOp>(
104               loc, op.getResult().getType(), dynamicSizeOperands,
105               op.getAlignmentAttr());
106 
107           // Take a subview of the new (bigger) buffer such that we can copy the
108           // old values over (the copy operation requires both operands to have
109           // the same shape).
110           Value subview = builder.create<memref::SubViewOp>(
111               loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
112               ArrayRef<OpFoldResult>{currSize},
113               ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
114           builder.create<memref::CopyOp>(loc, op.getSource(), subview);
115 
116           // Insert the deallocation of the old buffer only if requested
117           // (enabled by default).
118           if (emitDeallocs)
119             builder.create<memref::DeallocOp>(loc, op.getSource());
120 
121           builder.create<scf::YieldOp>(loc, newAlloc);
122         },
123         [&](OpBuilder &builder, Location loc) {
124           // We need to reinterpret-cast here because either the input or output
125           // type might be static, which means we need to cast from static to
126           // dynamic or vice-versa. If both are static and the original buffer
127           // is already bigger than the requested size, the cast represents a
128           // subview operation.
129           Value casted = builder.create<memref::ReinterpretCastOp>(
130               loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
131               rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
132               ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
133           builder.create<scf::YieldOp>(loc, casted);
134         });
135 
136     rewriter.replaceOp(op, ifOp.getResult(0));
137     return success();
138   }
139 
140 private:
141   const bool emitDeallocs;
142 };
143 
144 struct ExpandReallocPass
145     : public memref::impl::ExpandReallocBase<ExpandReallocPass> {
ExpandReallocPass__anonded970ac0111::ExpandReallocPass146   ExpandReallocPass(bool emitDeallocs)
147       : memref::impl::ExpandReallocBase<ExpandReallocPass>() {
148     this->emitDeallocs.setValue(emitDeallocs);
149   }
runOnOperation__anonded970ac0111::ExpandReallocPass150   void runOnOperation() override {
151     MLIRContext &ctx = getContext();
152 
153     RewritePatternSet patterns(&ctx);
154     memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());
155     ConversionTarget target(ctx);
156 
157     target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158                            memref::MemRefDialect>();
159     target.addIllegalOp<memref::ReallocOp>();
160     if (failed(applyPartialConversion(getOperation(), target,
161                                       std::move(patterns))))
162       signalPassFailure();
163   }
164 };
165 
166 } // namespace
167 
populateExpandReallocPatterns(RewritePatternSet & patterns,bool emitDeallocs)168 void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
169                                                  bool emitDeallocs) {
170   patterns.add<ExpandReallocOpPattern>(patterns.getContext(), emitDeallocs);
171 }
172 
createExpandReallocPass(bool emitDeallocs)173 std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
174   return std::make_unique<ExpandReallocPass>(emitDeallocs);
175 }
176