1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 namespace mlir { 17 namespace bufferization { 18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS 19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 20 } // namespace bufferization 21 } // namespace mlir 22 23 using namespace mlir; 24 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; 25 26 /// Return `true` if the given MemRef type has a fully dynamic layout. 27 static bool hasFullyDynamicLayoutMap(MemRefType type) { 28 int64_t offset; 29 SmallVector<int64_t, 4> strides; 30 if (failed(getStridesAndOffset(type, strides, offset))) 31 return false; 32 if (!llvm::all_of(strides, ShapedType::isDynamic)) 33 return false; 34 if (!ShapedType::isDynamic(offset)) 35 return false; 36 return true; 37 } 38 39 /// Return `true` if the given MemRef type has a static identity layout (i.e., 40 /// no layout). 41 static bool hasStaticIdentityLayout(MemRefType type) { 42 return type.getLayout().isIdentity(); 43 } 44 45 // Updates the func op and entry block. 46 // 47 // Any args appended to the entry block are added to `appendedEntryArgs`. 48 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result` 49 // to each newly created function argument. 50 static LogicalResult 51 updateFuncOp(func::FuncOp func, 52 SmallVectorImpl<BlockArgument> &appendedEntryArgs, 53 bool addResultAttribute) { 54 auto functionType = func.getFunctionType(); 55 56 // Collect information about the results will become appended arguments. 57 SmallVector<Type, 6> erasedResultTypes; 58 BitVector erasedResultIndices(functionType.getNumResults()); 59 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 60 if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) { 61 if (!hasStaticIdentityLayout(memrefType) && 62 !hasFullyDynamicLayoutMap(memrefType)) { 63 // Only buffers with static identity layout can be allocated. These can 64 // be casted to memrefs with fully dynamic layout map. Other layout maps 65 // are not supported. 66 return func->emitError() 67 << "cannot create out param for result with unsupported layout"; 68 } 69 erasedResultIndices.set(resultType.index()); 70 erasedResultTypes.push_back(memrefType); 71 } 72 } 73 74 // Add the new arguments to the function type. 75 auto newArgTypes = llvm::to_vector<6>( 76 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 77 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 78 functionType.getResults()); 79 func.setType(newFunctionType); 80 81 // Transfer the result attributes to arg attributes. 82 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 83 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 84 func.setArgAttrs(functionType.getNumInputs() + i, 85 func.getResultAttrs(*erasedIndicesIt)); 86 if (addResultAttribute) 87 func.setArgAttr(functionType.getNumInputs() + i, 88 StringAttr::get(func.getContext(), "bufferize.result"), 89 UnitAttr::get(func.getContext())); 90 } 91 92 // Erase the results. 93 func.eraseResults(erasedResultIndices); 94 95 // Add the new arguments to the entry block if the function is not external. 96 if (func.isExternal()) 97 return success(); 98 Location loc = func.getLoc(); 99 for (Type type : erasedResultTypes) 100 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 101 102 return success(); 103 } 104 105 // Updates all ReturnOps in the scope of the given func::FuncOp by either 106 // keeping them as return values or copying the associated buffer contents into 107 // the given out-params. 108 static LogicalResult updateReturnOps(func::FuncOp func, 109 ArrayRef<BlockArgument> appendedEntryArgs, 110 MemCpyFn memCpyFn, 111 bool hoistStaticAllocs) { 112 auto res = func.walk([&](func::ReturnOp op) { 113 SmallVector<Value, 6> copyIntoOutParams; 114 SmallVector<Value, 6> keepAsReturnOperands; 115 for (Value operand : op.getOperands()) { 116 if (isa<MemRefType>(operand.getType())) 117 copyIntoOutParams.push_back(operand); 118 else 119 keepAsReturnOperands.push_back(operand); 120 } 121 OpBuilder builder(op); 122 for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { 123 if (hoistStaticAllocs && 124 isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) && 125 mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { 126 orig.replaceAllUsesWith(arg); 127 orig.getDefiningOp()->erase(); 128 } else { 129 if (failed(memCpyFn(builder, op.getLoc(), orig, arg))) 130 return WalkResult::interrupt(); 131 } 132 } 133 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 134 op.erase(); 135 return WalkResult::advance(); 136 }); 137 return failure(res.wasInterrupted()); 138 } 139 140 // Updates all CallOps in the scope of the given ModuleOp by allocating 141 // temporary buffers for newly introduced out params. 142 static LogicalResult 143 updateCalls(ModuleOp module, 144 const bufferization::BufferResultsToOutParamsOpts &options) { 145 bool didFail = false; 146 SymbolTable symtab(module); 147 module.walk([&](func::CallOp op) { 148 auto callee = symtab.lookup<func::FuncOp>(op.getCallee()); 149 if (!callee) { 150 op.emitError() << "cannot find callee '" << op.getCallee() << "' in " 151 << "symbol table"; 152 didFail = true; 153 return; 154 } 155 if (!options.filterFn(&callee)) 156 return; 157 SmallVector<Value, 6> replaceWithNewCallResults; 158 SmallVector<Value, 6> replaceWithOutParams; 159 for (OpResult result : op.getResults()) { 160 if (isa<MemRefType>(result.getType())) 161 replaceWithOutParams.push_back(result); 162 else 163 replaceWithNewCallResults.push_back(result); 164 } 165 SmallVector<Value, 6> outParams; 166 OpBuilder builder(op); 167 for (Value memref : replaceWithOutParams) { 168 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { 169 op.emitError() 170 << "cannot create out param for dynamically shaped result"; 171 didFail = true; 172 return; 173 } 174 auto memrefType = cast<MemRefType>(memref.getType()); 175 auto allocType = 176 MemRefType::get(memrefType.getShape(), memrefType.getElementType(), 177 AffineMap(), memrefType.getMemorySpace()); 178 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType); 179 if (!hasStaticIdentityLayout(memrefType)) { 180 // Layout maps are already checked in `updateFuncOp`. 181 assert(hasFullyDynamicLayoutMap(memrefType) && 182 "layout map not supported"); 183 outParam = 184 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); 185 } 186 memref.replaceAllUsesWith(outParam); 187 outParams.push_back(outParam); 188 } 189 190 auto newOperands = llvm::to_vector<6>(op.getOperands()); 191 newOperands.append(outParams.begin(), outParams.end()); 192 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 193 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 194 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 195 newResultTypes, newOperands); 196 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 197 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 198 op.erase(); 199 }); 200 201 return failure(didFail); 202 } 203 204 LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( 205 ModuleOp module, 206 const bufferization::BufferResultsToOutParamsOpts &options) { 207 for (auto func : module.getOps<func::FuncOp>()) { 208 if (!options.filterFn(&func)) 209 continue; 210 SmallVector<BlockArgument, 6> appendedEntryArgs; 211 if (failed( 212 updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) 213 return failure(); 214 if (func.isExternal()) 215 continue; 216 auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from, 217 Value to) { 218 builder.create<memref::CopyOp>(loc, from, to); 219 return success(); 220 }; 221 if (failed(updateReturnOps(func, appendedEntryArgs, 222 options.memCpyFn.value_or(defaultMemCpyFn), 223 options.hoistStaticAllocs))) { 224 return failure(); 225 } 226 } 227 if (failed(updateCalls(module, options))) 228 return failure(); 229 return success(); 230 } 231 232 namespace { 233 struct BufferResultsToOutParamsPass 234 : bufferization::impl::BufferResultsToOutParamsBase< 235 BufferResultsToOutParamsPass> { 236 explicit BufferResultsToOutParamsPass( 237 const bufferization::BufferResultsToOutParamsOpts &options) 238 : options(options) {} 239 240 void runOnOperation() override { 241 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. 242 if (addResultAttribute) 243 options.addResultAttribute = true; 244 if (hoistStaticAllocs) 245 options.hoistStaticAllocs = true; 246 247 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), 248 options))) 249 return signalPassFailure(); 250 } 251 252 private: 253 bufferization::BufferResultsToOutParamsOpts options; 254 }; 255 } // namespace 256 257 std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass( 258 const bufferization::BufferResultsToOutParamsOpts &options) { 259 return std::make_unique<BufferResultsToOutParamsPass>(options); 260 } 261