xref: /llvm-project/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp (revision 49df12c01e99af6e091fedc123f775580064740a)
1ae1ea0beSJulian Gross //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2ae1ea0beSJulian Gross //
3ae1ea0beSJulian Gross // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ae1ea0beSJulian Gross // See https://llvm.org/LICENSE.txt for license information.
5ae1ea0beSJulian Gross // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ae1ea0beSJulian Gross //
7ae1ea0beSJulian Gross //===----------------------------------------------------------------------===//
8ae1ea0beSJulian Gross //
9ae1ea0beSJulian Gross // This file implements patterns to convert Bufferization dialect to MemRef
10ae1ea0beSJulian Gross // dialect.
11ae1ea0beSJulian Gross //
12ae1ea0beSJulian Gross //===----------------------------------------------------------------------===//
13ae1ea0beSJulian Gross 
14ae1ea0beSJulian Gross #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
1567d0d7acSMichele Scuttari 
16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
17ae1ea0beSJulian Gross #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18950f0944SMartin Erhart #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1907c079a9SMartin Erhart #include "mlir/Dialect/Func/IR/FuncOps.h"
20ae1ea0beSJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
2107c079a9SMartin Erhart #include "mlir/Dialect/SCF/IR/SCF.h"
22ae1ea0beSJulian Gross #include "mlir/IR/BuiltinTypes.h"
2367d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
24ae1ea0beSJulian Gross #include "mlir/Transforms/DialectConversion.h"
25ae1ea0beSJulian Gross 
2667d0d7acSMichele Scuttari namespace mlir {
2767d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
2867d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2967d0d7acSMichele Scuttari } // namespace mlir
3067d0d7acSMichele Scuttari 
31ae1ea0beSJulian Gross using namespace mlir;
32ae1ea0beSJulian Gross 
33ae1ea0beSJulian Gross namespace {
34ae1ea0beSJulian Gross /// The CloneOpConversion transforms all bufferization clone operations into
35ae1ea0beSJulian Gross /// memref alloc and memref copy operations. In the dynamic-shape case, it also
36ae1ea0beSJulian Gross /// emits additional dim and constant operations to determine the shape. This
37ae1ea0beSJulian Gross /// conversion does not resolve memory leaks if it is used alone.
38ae1ea0beSJulian Gross struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
39ae1ea0beSJulian Gross   using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
40ae1ea0beSJulian Gross 
41ae1ea0beSJulian Gross   LogicalResult
42ae1ea0beSJulian Gross   matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
43ae1ea0beSJulian Gross                   ConversionPatternRewriter &rewriter) const override {
4428d6aa90Sryankima     Location loc = op->getLoc();
4528d6aa90Sryankima 
46ae1ea0beSJulian Gross     Type type = op.getType();
4728d6aa90Sryankima     Value alloc;
4828d6aa90Sryankima 
4928d6aa90Sryankima     if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
5028d6aa90Sryankima       // Constants
5128d6aa90Sryankima       Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
5228d6aa90Sryankima       Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
5328d6aa90Sryankima 
5428d6aa90Sryankima       // Dynamically evaluate the size and shape of the unranked memref
5528d6aa90Sryankima       Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
5628d6aa90Sryankima       MemRefType allocType =
5728d6aa90Sryankima           MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
5828d6aa90Sryankima       Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
5928d6aa90Sryankima 
6028d6aa90Sryankima       // Create a loop to query dimension sizes, store them as a shape, and
6128d6aa90Sryankima       // compute the total size of the memref
6228d6aa90Sryankima       auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
6328d6aa90Sryankima                           ValueRange args) {
6428d6aa90Sryankima         auto acc = args.front();
6528d6aa90Sryankima         auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
6628d6aa90Sryankima 
6728d6aa90Sryankima         rewriter.create<memref::StoreOp>(loc, dim, shape, i);
6828d6aa90Sryankima         acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
6928d6aa90Sryankima 
7028d6aa90Sryankima         rewriter.create<scf::YieldOp>(loc, acc);
7128d6aa90Sryankima       };
7228d6aa90Sryankima       auto size = rewriter
7328d6aa90Sryankima                       .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
7428d6aa90Sryankima                                           loopBody)
7528d6aa90Sryankima                       .getResult(0);
7628d6aa90Sryankima 
7728d6aa90Sryankima       MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
7828d6aa90Sryankima                                               unrankedType.getElementType());
7928d6aa90Sryankima 
8028d6aa90Sryankima       // Allocate new memref with 1D dynamic shape, then reshape into the
8128d6aa90Sryankima       // shape of the original unranked memref
8228d6aa90Sryankima       alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
8328d6aa90Sryankima       alloc =
8428d6aa90Sryankima           rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
8528d6aa90Sryankima     } else {
865550c821STres Popp       MemRefType memrefType = cast<MemRefType>(type);
878dca38d5SMatthias Springer       MemRefLayoutAttrInterface layout;
888dca38d5SMatthias Springer       auto allocType =
898dca38d5SMatthias Springer           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
908dca38d5SMatthias Springer                           layout, memrefType.getMemorySpace());
9128d6aa90Sryankima       // Since this implementation always allocates, certain result types of
9228d6aa90Sryankima       // the clone op cannot be lowered.
938dca38d5SMatthias Springer       if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
948dca38d5SMatthias Springer         return failure();
95ae1ea0beSJulian Gross 
96ae1ea0beSJulian Gross       // Transform a clone operation into alloc + copy operation and pay
97ae1ea0beSJulian Gross       // attention to the shape dimensions.
98ae1ea0beSJulian Gross       SmallVector<Value, 4> dynamicOperands;
99ae1ea0beSJulian Gross       for (int i = 0; i < memrefType.getRank(); ++i) {
100ae1ea0beSJulian Gross         if (!memrefType.isDynamicDim(i))
101ae1ea0beSJulian Gross           continue;
102b23c8225SMatthias Springer         Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
103ae1ea0beSJulian Gross         dynamicOperands.push_back(dim);
104ae1ea0beSJulian Gross       }
1058dca38d5SMatthias Springer 
1068dca38d5SMatthias Springer       // Allocate a memref with identity layout.
10728d6aa90Sryankima       alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
1088dca38d5SMatthias Springer       // Cast the allocation to the specified type if needed.
1098dca38d5SMatthias Springer       if (memrefType != allocType)
11028d6aa90Sryankima         alloc =
11128d6aa90Sryankima             rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
11228d6aa90Sryankima     }
11328d6aa90Sryankima 
1148dca38d5SMatthias Springer     rewriter.replaceOp(op, alloc);
11599260e95SMatthias Springer     rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
116ae1ea0beSJulian Gross     return success();
117ae1ea0beSJulian Gross   }
118ae1ea0beSJulian Gross };
119ae1ea0beSJulian Gross 
12007c079a9SMartin Erhart } // namespace
12107c079a9SMartin Erhart 
122ae1ea0beSJulian Gross namespace {
123039b969bSMichele Scuttari struct BufferizationToMemRefPass
12467d0d7acSMichele Scuttari     : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
125039b969bSMichele Scuttari   BufferizationToMemRefPass() = default;
126ae1ea0beSJulian Gross 
127ae1ea0beSJulian Gross   void runOnOperation() override {
1283610c82cSMartin Erhart     if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
1293610c82cSMartin Erhart       emitError(getOperation()->getLoc(),
1303610c82cSMartin Erhart                 "root operation must be a builtin.module or a function");
1313610c82cSMartin Erhart       signalPassFailure();
1323610c82cSMartin Erhart       return;
1333610c82cSMartin Erhart     }
1343610c82cSMartin Erhart 
135662c6fc7Sdonald chen     bufferization::DeallocHelperMap deallocHelperFuncMap;
1363610c82cSMartin Erhart     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
137*49df12c0SMatthias Springer       OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
13807c079a9SMartin Erhart 
13907c079a9SMartin Erhart       // Build dealloc helper function if there are deallocs.
14007c079a9SMartin Erhart       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
141662c6fc7Sdonald chen         Operation *symtableOp =
142662c6fc7Sdonald chen             deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
143662c6fc7Sdonald chen         if (deallocOp.getMemrefs().size() > 1 &&
144662c6fc7Sdonald chen             !deallocHelperFuncMap.contains(symtableOp)) {
145662c6fc7Sdonald chen           SymbolTable symbolTable(symtableOp);
146662c6fc7Sdonald chen           func::FuncOp helperFuncOp =
147662c6fc7Sdonald chen               bufferization::buildDeallocationLibraryFunction(
14807c079a9SMartin Erhart                   builder, getOperation()->getLoc(), symbolTable);
149662c6fc7Sdonald chen           deallocHelperFuncMap[symtableOp] = helperFuncOp;
15007c079a9SMartin Erhart         }
15107c079a9SMartin Erhart       });
1523610c82cSMartin Erhart     }
15307c079a9SMartin Erhart 
154ae1ea0beSJulian Gross     RewritePatternSet patterns(&getContext());
15507c079a9SMartin Erhart     patterns.add<CloneOpConversion>(patterns.getContext());
156662c6fc7Sdonald chen     bufferization::populateBufferizationDeallocLoweringPattern(
157662c6fc7Sdonald chen         patterns, deallocHelperFuncMap);
158ae1ea0beSJulian Gross 
159ae1ea0beSJulian Gross     ConversionTarget target(getContext());
16007c079a9SMartin Erhart     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
16107c079a9SMartin Erhart                            scf::SCFDialect, func::FuncDialect>();
162ae1ea0beSJulian Gross     target.addIllegalDialect<bufferization::BufferizationDialect>();
163ae1ea0beSJulian Gross 
164ae1ea0beSJulian Gross     if (failed(applyPartialConversion(getOperation(), target,
165ae1ea0beSJulian Gross                                       std::move(patterns))))
166ae1ea0beSJulian Gross       signalPassFailure();
167ae1ea0beSJulian Gross   }
168ae1ea0beSJulian Gross };
169ae1ea0beSJulian Gross } // namespace
170039b969bSMichele Scuttari 
1713610c82cSMartin Erhart std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
172039b969bSMichele Scuttari   return std::make_unique<BufferizationToMemRefPass>();
173039b969bSMichele Scuttari }
174