xref: /llvm-project/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp (revision 49df12c01e99af6e091fedc123f775580064740a)
1 //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Bufferization dialect to MemRef
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 /// The CloneOpConversion transforms all bufferization clone operations into
35 /// memref alloc and memref copy operations. In the dynamic-shape case, it also
36 /// emits additional dim and constant operations to determine the shape. This
37 /// conversion does not resolve memory leaks if it is used alone.
38 struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
39   using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
40 
41   LogicalResult
42   matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
43                   ConversionPatternRewriter &rewriter) const override {
44     Location loc = op->getLoc();
45 
46     Type type = op.getType();
47     Value alloc;
48 
49     if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
50       // Constants
51       Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
52       Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
53 
54       // Dynamically evaluate the size and shape of the unranked memref
55       Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
56       MemRefType allocType =
57           MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
58       Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
59 
60       // Create a loop to query dimension sizes, store them as a shape, and
61       // compute the total size of the memref
62       auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
63                           ValueRange args) {
64         auto acc = args.front();
65         auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
66 
67         rewriter.create<memref::StoreOp>(loc, dim, shape, i);
68         acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
69 
70         rewriter.create<scf::YieldOp>(loc, acc);
71       };
72       auto size = rewriter
73                       .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
74                                           loopBody)
75                       .getResult(0);
76 
77       MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
78                                               unrankedType.getElementType());
79 
80       // Allocate new memref with 1D dynamic shape, then reshape into the
81       // shape of the original unranked memref
82       alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
83       alloc =
84           rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
85     } else {
86       MemRefType memrefType = cast<MemRefType>(type);
87       MemRefLayoutAttrInterface layout;
88       auto allocType =
89           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
90                           layout, memrefType.getMemorySpace());
91       // Since this implementation always allocates, certain result types of
92       // the clone op cannot be lowered.
93       if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
94         return failure();
95 
96       // Transform a clone operation into alloc + copy operation and pay
97       // attention to the shape dimensions.
98       SmallVector<Value, 4> dynamicOperands;
99       for (int i = 0; i < memrefType.getRank(); ++i) {
100         if (!memrefType.isDynamicDim(i))
101           continue;
102         Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
103         dynamicOperands.push_back(dim);
104       }
105 
106       // Allocate a memref with identity layout.
107       alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
108       // Cast the allocation to the specified type if needed.
109       if (memrefType != allocType)
110         alloc =
111             rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
112     }
113 
114     rewriter.replaceOp(op, alloc);
115     rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
116     return success();
117   }
118 };
119 
120 } // namespace
121 
122 namespace {
123 struct BufferizationToMemRefPass
124     : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
125   BufferizationToMemRefPass() = default;
126 
127   void runOnOperation() override {
128     if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
129       emitError(getOperation()->getLoc(),
130                 "root operation must be a builtin.module or a function");
131       signalPassFailure();
132       return;
133     }
134 
135     bufferization::DeallocHelperMap deallocHelperFuncMap;
136     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
137       OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
138 
139       // Build dealloc helper function if there are deallocs.
140       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
141         Operation *symtableOp =
142             deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
143         if (deallocOp.getMemrefs().size() > 1 &&
144             !deallocHelperFuncMap.contains(symtableOp)) {
145           SymbolTable symbolTable(symtableOp);
146           func::FuncOp helperFuncOp =
147               bufferization::buildDeallocationLibraryFunction(
148                   builder, getOperation()->getLoc(), symbolTable);
149           deallocHelperFuncMap[symtableOp] = helperFuncOp;
150         }
151       });
152     }
153 
154     RewritePatternSet patterns(&getContext());
155     patterns.add<CloneOpConversion>(patterns.getContext());
156     bufferization::populateBufferizationDeallocLoweringPattern(
157         patterns, deallocHelperFuncMap);
158 
159     ConversionTarget target(getContext());
160     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161                            scf::SCFDialect, func::FuncDialect>();
162     target.addIllegalDialect<bufferization::BufferizationDialect>();
163 
164     if (failed(applyPartialConversion(getOperation(), target,
165                                       std::move(patterns))))
166       signalPassFailure();
167   }
168 };
169 } // namespace
170 
171 std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
172   return std::make_unique<BufferizationToMemRefPass>();
173 }
174