//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include <optional>

namespace mlir {
/// Return all func.return ops in the given function.
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
  SmallVector<func::ReturnOp> result;
  for (Block &b : funcOp.getBody())
    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
      result.push_back(returnOp);
  return result;
}

namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user (as per `options.functionArgTypeConverterFn`).
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
  assert(rankedMemrefType &&
         "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
}
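// Illustrative example, assuming an argument type converter that produces
// fully dynamic layout maps: an argument of type `tensor<4x8xf32>` is
// bufferized to `memref<4x8xf32, strided<[?, ?], offset: ?>>`. If the
// argument additionally carries an `AffineMapAttr` under the buffer layout
// argument attribute (`kBufferLayoutAttrName`), e.g. the identity map, that
// layout takes precedence and the argument becomes `memref<4x8xf32>`.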
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
  auto *result = static_cast<const OneShotAnalysisState &>(state)
                     .getExtension<FuncAnalysisState>();
  assert(result && "FuncAnalysisState does not exist");
  return *result;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  if (!isa<OneShotAnalysisState>(state))
    return FuncOpAnalysisState::NotAnalyzed;
  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
                        .getExtension<FuncAnalysisState>();
  if (!funcState)
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static std::optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return std::nullopt;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return std::nullopt;

  return retValIt->getSecond();
}
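/// Bufferization of func.call. Read/write effects and aliasing of the call
/// operands are taken from the analysis results of the callee when available;
/// otherwise, conservative assumptions are made (each tensor operand may be
/// read and written, and each result may alias each operand).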
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      return detail::unknownGetAliasingValues(opOperand);

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());

    // Check if the aliasing OpResult is equivalent to the OpOperand.
    std::optional<int64_t> equivalent = {};
    if (aliasingReturnVals.size() == 1) {
      equivalent = getEquivalentFuncArgIdx(funcOp, funcState,
                                           aliasingReturnVals.front());
      assert((!equivalent.has_value() ||
              *equivalent == opOperand.getOperandNumber()) &&
             "inconsistent analysis state");
    }
    AliasingValueList result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.addAlias({callOp->getOpResult(resultIdx),
                       equivalent.has_value() ? BufferRelation::Equivalent
                                              : BufferRelation::Unknown,
                       /*isDefinite=*/equivalent.has_value()});
    return result;
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    // If the callee was already bufferized, we can directly take the type
    // from its signature.
    FunctionType funcType = funcOp.getFunctionType();
    Type resultType =
        funcType.getResult(cast<OpResult>(value).getResultNumber());
    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
      return bufferizedType;

    // Otherwise, call the type converter to compute the bufferized type.
    auto tensorType = cast<TensorType>(resultType);
    return options.functionArgTypeConverterFn(
        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
        options);
  }
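  // Illustrative sketch of the call rewrite implemented in `bufferize` below,
  // assuming the callee was bufferized with fully dynamic layout maps:
  //
  //   %0 = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  //
  // becomes
  //
  //   %m = <buffer of %t> : memref<?xf32, strided<[?], offset: ?>>
  //   %r = func.call @callee(%m)
  //       : (memref<?xf32, strided<[?], offset: ?>>)
  //       -> memref<?xf32, strided<[?], offset: ?>>
  //
  // with a memref.cast (or realloc + copy) inserted for operands whose buffer
  // type does not match the type expected by the callee.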
  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);

    // 1. Compute the result types of the new CallOp.
    SmallVector<Type> resultTypes;
    for (Value result : callOp.getResults()) {
      Type returnType = result.getType();
      if (!isa<TensorType>(returnType)) {
        // Non-tensor values are returned.
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      FailureOr<BaseMemRefType> resultType =
          bufferization::getBufferType(result, options);
      if (failed(resultType))
        return failure();
      resultTypes.push_back(*resultType);
    }

    // 2. Rewrite tensor operands as memrefs based on the type of the already
    //    bufferized callee.
    SmallVector<Value> newOperands;
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    for (OpOperand &opOperand : callOp->getOpOperands()) {
      // Non-tensor operands are just copied.
      if (!isa<TensorType>(opOperand.get().getType())) {
        newOperands.push_back(opOperand.get());
        continue;
      }

      // Retrieve buffers for tensor operands.
      FailureOr<Value> maybeBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(maybeBuffer))
        return failure();
      Value buffer = *maybeBuffer;

      // Caller / callee type mismatch is handled with
      // castOrReallocMemRefValue.
      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
      if (!isa<BaseMemRefType>(memRefType)) {
        // The called function was not bufferized yet. This can happen when
        // there are cycles in the function call graph. Compute the bufferized
        // result type.
        FailureOr<BaseMemRefType> maybeMemRefType =
            bufferization::getBufferType(
                funcOp.getArgument(opOperand.getOperandNumber()), options);
        if (failed(maybeMemRefType))
          return failure();
        memRefType = *maybeMemRefType;
      }

      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into a more dynamic memref than
      // necessary. If the buffer type does not match the memref type expected
      // by the callee, introduce an extra memref.cast that will either
      // canonicalize away or fail compilation until we can do something
      // better. Insert a reallocation + copy if it cannot be statically
      // guaranteed that a direct cast would be valid.
      if (buffer.getType() != memRefType) {
        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
        assert(memrefDstType &&
               "buffer layout not supported on unranked tensors");
        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
            rewriter, buffer, memrefDstType, options);
        if (failed(replacement))
          return failure();
        buffer = *replacement;
      }
      newOperands.push_back(buffer);
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());

    return success();
  }
};
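/// Bufferization of func.return. Return values are rewritten as part of the
/// enclosing FuncOp (see FuncOpInterface::bufferize); this model only reports
/// that every operand is read and no operand is written.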
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          FuncOpInterface, FuncOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool hasTensorSemantics(Operation *op) const {
    auto isaTensor = llvm::IsaPred<TensorType>;

    // A function has tensor semantics if it has tensor arguments/results.
    auto funcOp = cast<FuncOp>(op);
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    if (hasTensorArg || hasTensorResult)
      return true;

    // It also has tensor semantics if it has tensor block arguments.
    // TODO: Decouple bufferization of unstructured control flow from
    // BufferizableOpInterface implementations. We should only care about
    // region entry block arguments here (which are already covered by the
    // argument types of the function).
    for (Block &block : funcOp.getBody())
      if (any_of(block.getArgumentTypes(), isaTensor))
        return true;

    return false;
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto funcOp = cast<FuncOp>(op);
    auto bbArg = cast<BlockArgument>(value);

    // Function arguments are special.
    if (bbArg.getOwner() == &funcOp.getBody().front())
      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
                                          options);

    return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
        getBufferType(op, value, options, invocationStack);
  }

  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Compute the argument types.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (isa<TensorType>(argType)) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Compute the result types.
    SmallVector<Type> retTypes;
    for (Type resultType : funcType.getResults()) {
      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
        BaseMemRefType resultType = options.functionArgTypeConverterFn(
            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
            options);
        retTypes.push_back(resultType);
        continue;
      }
      retTypes.push_back(resultType);
    }
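    // For example (illustrative), with a converter that produces fully
    // dynamic layout maps, a signature
    //   (tensor<?xf32>, i64) -> (tensor<8xf32>, f32)
    // is rewritten to
    //   (memref<?xf32, strided<[?], offset: ?>>, i64)
    //       -> (memref<8xf32, strided<[?], offset: ?>>, f32).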
    // Compute the new function type.
    auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

    // If the function has no body, set the new function type and we are done.
    if (funcOp.isExternal()) {
      funcOp.setType(newFuncType);
      return success();
    }

    // 1. Bufferize every block.
    for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // 2. Bufferize the operands of all return ops.
    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
      assert(returnOp->getNumOperands() == retTypes.size() &&
             "incorrect number of return values");
      SmallVector<Value> returnValues;
      for (auto [returnVal, bufferizedType] :
           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
        rewriter.setInsertionPoint(returnOp);

        // If not a tensor type just forward it.
        if (!tensorType) {
          returnValues.push_back(returnVal);
          continue;
        }

        // Note: If `inferFunctionResultLayout = true`, casts are later folded
        // away.
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            returnOp.getLoc(), bufferizedType, returnVal);
        returnValues.push_back(toMemrefOp);
      }

      returnOp.getOperandsMutable().assign(returnValues);
    }

    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = dyn_cast<BlockArgument>(value);
    assert(bbArg && "expected BlockArgument");

    // Non-entry block arguments are always writable. (They may alias with
    // values that are not writable, which will turn them into read-only.)
    if (bbArg.getOwner() != &funcOp.getBody().front())
      return true;

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}