1ae1ea0beSJulian Gross //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===// 2ae1ea0beSJulian Gross // 3ae1ea0beSJulian Gross // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ae1ea0beSJulian Gross // See https://llvm.org/LICENSE.txt for license information. 5ae1ea0beSJulian Gross // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ae1ea0beSJulian Gross // 7ae1ea0beSJulian Gross //===----------------------------------------------------------------------===// 8ae1ea0beSJulian Gross // 9ae1ea0beSJulian Gross // This file implements patterns to convert Bufferization dialect to MemRef 10ae1ea0beSJulian Gross // dialect. 11ae1ea0beSJulian Gross // 12ae1ea0beSJulian Gross //===----------------------------------------------------------------------===// 13ae1ea0beSJulian Gross 14ae1ea0beSJulian Gross #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" 1567d0d7acSMichele Scuttari 16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 17ae1ea0beSJulian Gross #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18950f0944SMartin Erhart #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 1907c079a9SMartin Erhart #include "mlir/Dialect/Func/IR/FuncOps.h" 20ae1ea0beSJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 2107c079a9SMartin Erhart #include "mlir/Dialect/SCF/IR/SCF.h" 22ae1ea0beSJulian Gross #include "mlir/IR/BuiltinTypes.h" 2367d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 24ae1ea0beSJulian Gross #include "mlir/Transforms/DialectConversion.h" 25ae1ea0beSJulian Gross 2667d0d7acSMichele Scuttari namespace mlir { 2767d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF 2867d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 2967d0d7acSMichele Scuttari } // namespace mlir 3067d0d7acSMichele Scuttari 31ae1ea0beSJulian Gross using namespace mlir; 32ae1ea0beSJulian Gross 33ae1ea0beSJulian Gross namespace { 34ae1ea0beSJulian Gross /// The CloneOpConversion transforms all bufferization clone operations into 35ae1ea0beSJulian Gross /// memref alloc and memref copy operations. In the dynamic-shape case, it also 36ae1ea0beSJulian Gross /// emits additional dim and constant operations to determine the shape. This 37ae1ea0beSJulian Gross /// conversion does not resolve memory leaks if it is used alone. 38ae1ea0beSJulian Gross struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { 39ae1ea0beSJulian Gross using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern; 40ae1ea0beSJulian Gross 41ae1ea0beSJulian Gross LogicalResult 42ae1ea0beSJulian Gross matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor, 43ae1ea0beSJulian Gross ConversionPatternRewriter &rewriter) const override { 4428d6aa90Sryankima Location loc = op->getLoc(); 4528d6aa90Sryankima 46ae1ea0beSJulian Gross Type type = op.getType(); 4728d6aa90Sryankima Value alloc; 4828d6aa90Sryankima 4928d6aa90Sryankima if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) { 5028d6aa90Sryankima // Constants 5128d6aa90Sryankima Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 5228d6aa90Sryankima Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 5328d6aa90Sryankima 5428d6aa90Sryankima // Dynamically evaluate the size and shape of the unranked memref 5528d6aa90Sryankima Value rank = rewriter.create<memref::RankOp>(loc, op.getInput()); 5628d6aa90Sryankima MemRefType allocType = 5728d6aa90Sryankima MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); 5828d6aa90Sryankima Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank); 5928d6aa90Sryankima 6028d6aa90Sryankima // Create a loop to query dimension sizes, store them as a shape, and 6128d6aa90Sryankima // compute the total size of the memref 6228d6aa90Sryankima auto loopBody = [&](OpBuilder &builder, Location loc, Value i, 6328d6aa90Sryankima ValueRange args) { 6428d6aa90Sryankima auto acc = args.front(); 6528d6aa90Sryankima auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i); 6628d6aa90Sryankima 6728d6aa90Sryankima rewriter.create<memref::StoreOp>(loc, dim, shape, i); 6828d6aa90Sryankima acc = rewriter.create<arith::MulIOp>(loc, acc, dim); 6928d6aa90Sryankima 7028d6aa90Sryankima rewriter.create<scf::YieldOp>(loc, acc); 7128d6aa90Sryankima }; 7228d6aa90Sryankima auto size = rewriter 7328d6aa90Sryankima .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), 7428d6aa90Sryankima loopBody) 7528d6aa90Sryankima .getResult(0); 7628d6aa90Sryankima 7728d6aa90Sryankima MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, 7828d6aa90Sryankima unrankedType.getElementType()); 7928d6aa90Sryankima 8028d6aa90Sryankima // Allocate new memref with 1D dynamic shape, then reshape into the 8128d6aa90Sryankima // shape of the original unranked memref 8228d6aa90Sryankima alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size); 8328d6aa90Sryankima alloc = 8428d6aa90Sryankima rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape); 8528d6aa90Sryankima } else { 865550c821STres Popp MemRefType memrefType = cast<MemRefType>(type); 878dca38d5SMatthias Springer MemRefLayoutAttrInterface layout; 888dca38d5SMatthias Springer auto allocType = 898dca38d5SMatthias Springer MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 908dca38d5SMatthias Springer layout, memrefType.getMemorySpace()); 9128d6aa90Sryankima // Since this implementation always allocates, certain result types of 9228d6aa90Sryankima // the clone op cannot be lowered. 938dca38d5SMatthias Springer if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) 948dca38d5SMatthias Springer return failure(); 95ae1ea0beSJulian Gross 96ae1ea0beSJulian Gross // Transform a clone operation into alloc + copy operation and pay 97ae1ea0beSJulian Gross // attention to the shape dimensions. 98ae1ea0beSJulian Gross SmallVector<Value, 4> dynamicOperands; 99ae1ea0beSJulian Gross for (int i = 0; i < memrefType.getRank(); ++i) { 100ae1ea0beSJulian Gross if (!memrefType.isDynamicDim(i)) 101ae1ea0beSJulian Gross continue; 102b23c8225SMatthias Springer Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i); 103ae1ea0beSJulian Gross dynamicOperands.push_back(dim); 104ae1ea0beSJulian Gross } 1058dca38d5SMatthias Springer 1068dca38d5SMatthias Springer // Allocate a memref with identity layout. 10728d6aa90Sryankima alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands); 1088dca38d5SMatthias Springer // Cast the allocation to the specified type if needed. 1098dca38d5SMatthias Springer if (memrefType != allocType) 11028d6aa90Sryankima alloc = 11128d6aa90Sryankima rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); 11228d6aa90Sryankima } 11328d6aa90Sryankima 1148dca38d5SMatthias Springer rewriter.replaceOp(op, alloc); 11599260e95SMatthias Springer rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); 116ae1ea0beSJulian Gross return success(); 117ae1ea0beSJulian Gross } 118ae1ea0beSJulian Gross }; 119ae1ea0beSJulian Gross 12007c079a9SMartin Erhart } // namespace 12107c079a9SMartin Erhart 122ae1ea0beSJulian Gross namespace { 123039b969bSMichele Scuttari struct BufferizationToMemRefPass 12467d0d7acSMichele Scuttari : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> { 125039b969bSMichele Scuttari BufferizationToMemRefPass() = default; 126ae1ea0beSJulian Gross 127ae1ea0beSJulian Gross void runOnOperation() override { 1283610c82cSMartin Erhart if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) { 1293610c82cSMartin Erhart emitError(getOperation()->getLoc(), 1303610c82cSMartin Erhart "root operation must be a builtin.module or a function"); 1313610c82cSMartin Erhart signalPassFailure(); 1323610c82cSMartin Erhart return; 1333610c82cSMartin Erhart } 1343610c82cSMartin Erhart 135662c6fc7Sdonald chen bufferization::DeallocHelperMap deallocHelperFuncMap; 1363610c82cSMartin Erhart if (auto module = dyn_cast<ModuleOp>(getOperation())) { 137*49df12c0SMatthias Springer OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 13807c079a9SMartin Erhart 13907c079a9SMartin Erhart // Build dealloc helper function if there are deallocs. 14007c079a9SMartin Erhart getOperation()->walk([&](bufferization::DeallocOp deallocOp) { 141662c6fc7Sdonald chen Operation *symtableOp = 142662c6fc7Sdonald chen deallocOp->getParentWithTrait<OpTrait::SymbolTable>(); 143662c6fc7Sdonald chen if (deallocOp.getMemrefs().size() > 1 && 144662c6fc7Sdonald chen !deallocHelperFuncMap.contains(symtableOp)) { 145662c6fc7Sdonald chen SymbolTable symbolTable(symtableOp); 146662c6fc7Sdonald chen func::FuncOp helperFuncOp = 147662c6fc7Sdonald chen bufferization::buildDeallocationLibraryFunction( 14807c079a9SMartin Erhart builder, getOperation()->getLoc(), symbolTable); 149662c6fc7Sdonald chen deallocHelperFuncMap[symtableOp] = helperFuncOp; 15007c079a9SMartin Erhart } 15107c079a9SMartin Erhart }); 1523610c82cSMartin Erhart } 15307c079a9SMartin Erhart 154ae1ea0beSJulian Gross RewritePatternSet patterns(&getContext()); 15507c079a9SMartin Erhart patterns.add<CloneOpConversion>(patterns.getContext()); 156662c6fc7Sdonald chen bufferization::populateBufferizationDeallocLoweringPattern( 157662c6fc7Sdonald chen patterns, deallocHelperFuncMap); 158ae1ea0beSJulian Gross 159ae1ea0beSJulian Gross ConversionTarget target(getContext()); 16007c079a9SMartin Erhart target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect, 16107c079a9SMartin Erhart scf::SCFDialect, func::FuncDialect>(); 162ae1ea0beSJulian Gross target.addIllegalDialect<bufferization::BufferizationDialect>(); 163ae1ea0beSJulian Gross 164ae1ea0beSJulian Gross if (failed(applyPartialConversion(getOperation(), target, 165ae1ea0beSJulian Gross std::move(patterns)))) 166ae1ea0beSJulian Gross signalPassFailure(); 167ae1ea0beSJulian Gross } 168ae1ea0beSJulian Gross }; 169ae1ea0beSJulian Gross } // namespace 170039b969bSMichele Scuttari 1713610c82cSMartin Erhart std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() { 172039b969bSMichele Scuttari return std::make_unique<BufferizationToMemRefPass>(); 173039b969bSMichele Scuttari } 174