1 //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Bufferization dialect to MemRef 10 // dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 18 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 26 namespace mlir { 27 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF 28 #include "mlir/Conversion/Passes.h.inc" 29 } // namespace mlir 30 31 using namespace mlir; 32 33 namespace { 34 /// The CloneOpConversion transforms all bufferization clone operations into 35 /// memref alloc and memref copy operations. In the dynamic-shape case, it also 36 /// emits additional dim and constant operations to determine the shape. This 37 /// conversion does not resolve memory leaks if it is used alone. 38 struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { 39 using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern; 40 41 LogicalResult 42 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor, 43 ConversionPatternRewriter &rewriter) const override { 44 Location loc = op->getLoc(); 45 46 Type type = op.getType(); 47 Value alloc; 48 49 if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) { 50 // Constants 51 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 52 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 53 54 // Dynamically evaluate the size and shape of the unranked memref 55 Value rank = rewriter.create<memref::RankOp>(loc, op.getInput()); 56 MemRefType allocType = 57 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); 58 Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank); 59 60 // Create a loop to query dimension sizes, store them as a shape, and 61 // compute the total size of the memref 62 auto loopBody = [&](OpBuilder &builder, Location loc, Value i, 63 ValueRange args) { 64 auto acc = args.front(); 65 auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i); 66 67 rewriter.create<memref::StoreOp>(loc, dim, shape, i); 68 acc = rewriter.create<arith::MulIOp>(loc, acc, dim); 69 70 rewriter.create<scf::YieldOp>(loc, acc); 71 }; 72 auto size = rewriter 73 .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), 74 loopBody) 75 .getResult(0); 76 77 MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, 78 unrankedType.getElementType()); 79 80 // Allocate new memref with 1D dynamic shape, then reshape into the 81 // shape of the original unranked memref 82 alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size); 83 alloc = 84 rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape); 85 } else { 86 MemRefType memrefType = cast<MemRefType>(type); 87 MemRefLayoutAttrInterface layout; 88 auto allocType = 89 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 90 layout, memrefType.getMemorySpace()); 91 // Since this implementation always allocates, certain result types of 92 // the clone op cannot be lowered. 93 if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) 94 return failure(); 95 96 // Transform a clone operation into alloc + copy operation and pay 97 // attention to the shape dimensions. 98 SmallVector<Value, 4> dynamicOperands; 99 for (int i = 0; i < memrefType.getRank(); ++i) { 100 if (!memrefType.isDynamicDim(i)) 101 continue; 102 Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i); 103 dynamicOperands.push_back(dim); 104 } 105 106 // Allocate a memref with identity layout. 107 alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands); 108 // Cast the allocation to the specified type if needed. 109 if (memrefType != allocType) 110 alloc = 111 rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); 112 } 113 114 rewriter.replaceOp(op, alloc); 115 rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); 116 return success(); 117 } 118 }; 119 120 } // namespace 121 122 namespace { 123 struct BufferizationToMemRefPass 124 : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> { 125 BufferizationToMemRefPass() = default; 126 127 void runOnOperation() override { 128 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) { 129 emitError(getOperation()->getLoc(), 130 "root operation must be a builtin.module or a function"); 131 signalPassFailure(); 132 return; 133 } 134 135 bufferization::DeallocHelperMap deallocHelperFuncMap; 136 if (auto module = dyn_cast<ModuleOp>(getOperation())) { 137 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 138 139 // Build dealloc helper function if there are deallocs. 140 getOperation()->walk([&](bufferization::DeallocOp deallocOp) { 141 Operation *symtableOp = 142 deallocOp->getParentWithTrait<OpTrait::SymbolTable>(); 143 if (deallocOp.getMemrefs().size() > 1 && 144 !deallocHelperFuncMap.contains(symtableOp)) { 145 SymbolTable symbolTable(symtableOp); 146 func::FuncOp helperFuncOp = 147 bufferization::buildDeallocationLibraryFunction( 148 builder, getOperation()->getLoc(), symbolTable); 149 deallocHelperFuncMap[symtableOp] = helperFuncOp; 150 } 151 }); 152 } 153 154 RewritePatternSet patterns(&getContext()); 155 patterns.add<CloneOpConversion>(patterns.getContext()); 156 bufferization::populateBufferizationDeallocLoweringPattern( 157 patterns, deallocHelperFuncMap); 158 159 ConversionTarget target(getContext()); 160 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect, 161 scf::SCFDialect, func::FuncDialect>(); 162 target.addIllegalDialect<bufferization::BufferizationDialect>(); 163 164 if (failed(applyPartialConversion(getOperation(), target, 165 std::move(patterns)))) 166 signalPassFailure(); 167 } 168 }; 169 } // namespace 170 171 std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() { 172 return std::make_unique<BufferizationToMemRefPass>(); 173 } 174