12e2c0738SRiver Riddle //===- NormalizeMemRefs.cpp -----------------------------------------------===// 22e2c0738SRiver Riddle // 32e2c0738SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42e2c0738SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 52e2c0738SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62e2c0738SRiver Riddle // 72e2c0738SRiver Riddle //===----------------------------------------------------------------------===// 82e2c0738SRiver Riddle // 92e2c0738SRiver Riddle // This file implements an interprocedural pass to normalize memrefs to have 102e2c0738SRiver Riddle // identity layout maps. 112e2c0738SRiver Riddle // 122e2c0738SRiver Riddle //===----------------------------------------------------------------------===// 132e2c0738SRiver Riddle 142e2c0738SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h" 161f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 172e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 182e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h" 192e2c0738SRiver Riddle #include "llvm/ADT/SmallSet.h" 202e2c0738SRiver Riddle #include "llvm/Support/Debug.h" 212e2c0738SRiver Riddle 2267d0d7acSMichele Scuttari namespace mlir { 2367d0d7acSMichele Scuttari namespace memref { 2467d0d7acSMichele Scuttari #define GEN_PASS_DEF_NORMALIZEMEMREFS 2567d0d7acSMichele Scuttari #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 2667d0d7acSMichele Scuttari } // namespace memref 2767d0d7acSMichele Scuttari } // namespace mlir 2867d0d7acSMichele Scuttari 292e2c0738SRiver Riddle #define DEBUG_TYPE "normalize-memrefs" 302e2c0738SRiver Riddle 312e2c0738SRiver Riddle using namespace mlir; 324c48f016SMatthias Springer using namespace mlir::affine; 332e2c0738SRiver Riddle 342e2c0738SRiver Riddle namespace { 352e2c0738SRiver Riddle 362e2c0738SRiver Riddle /// All memrefs passed across functions with non-trivial layout maps are 372e2c0738SRiver Riddle /// converted to ones with trivial identity layout ones. 382e2c0738SRiver Riddle /// If all the memref types/uses in a function are normalizable, we treat 392e2c0738SRiver Riddle /// such functions as normalizable. Also, if a normalizable function is known 402e2c0738SRiver Riddle /// to call a non-normalizable function, we treat that function as 412e2c0738SRiver Riddle /// non-normalizable as well. We assume external functions to be normalizable. 4267d0d7acSMichele Scuttari struct NormalizeMemRefs 4367d0d7acSMichele Scuttari : public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> { 442e2c0738SRiver Riddle void runOnOperation() override; 4558ceae95SRiver Riddle void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp); 4658ceae95SRiver Riddle bool areMemRefsNormalizable(func::FuncOp funcOp); 4758ceae95SRiver Riddle void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp); 4858ceae95SRiver Riddle void setCalleesAndCallersNonNormalizable( 4958ceae95SRiver Riddle func::FuncOp funcOp, ModuleOp moduleOp, 5058ceae95SRiver Riddle DenseSet<func::FuncOp> &normalizableFuncs); 5158ceae95SRiver Riddle Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp); 522e2c0738SRiver Riddle }; 532e2c0738SRiver Riddle 542e2c0738SRiver Riddle } // namespace 552e2c0738SRiver Riddle 562e2c0738SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>> 572e2c0738SRiver Riddle mlir::memref::createNormalizeMemRefsPass() { 582e2c0738SRiver Riddle return std::make_unique<NormalizeMemRefs>(); 592e2c0738SRiver Riddle } 602e2c0738SRiver Riddle 612e2c0738SRiver Riddle void NormalizeMemRefs::runOnOperation() { 622e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n"); 632e2c0738SRiver Riddle ModuleOp moduleOp = getOperation(); 642e2c0738SRiver Riddle // We maintain all normalizable FuncOps in a DenseSet. It is initialized 652e2c0738SRiver Riddle // with all the functions within a module and then functions which are not 662e2c0738SRiver Riddle // normalizable are removed from this set. 672e2c0738SRiver Riddle // TODO: Change this to work on FuncLikeOp once there is an operation 682e2c0738SRiver Riddle // interface for it. 6958ceae95SRiver Riddle DenseSet<func::FuncOp> normalizableFuncs; 702e2c0738SRiver Riddle // Initialize `normalizableFuncs` with all the functions within a module. 7158ceae95SRiver Riddle moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); }); 722e2c0738SRiver Riddle 732e2c0738SRiver Riddle // Traverse through all the functions applying a filter which determines 742e2c0738SRiver Riddle // whether that function is normalizable or not. All callers/callees of 752e2c0738SRiver Riddle // a non-normalizable function will also become non-normalizable even if 762e2c0738SRiver Riddle // they aren't passing any or specific non-normalizable memrefs. So, 772e2c0738SRiver Riddle // functions which calls or get called by a non-normalizable becomes non- 782e2c0738SRiver Riddle // normalizable functions themselves. 7958ceae95SRiver Riddle moduleOp.walk([&](func::FuncOp funcOp) { 802e2c0738SRiver Riddle if (normalizableFuncs.contains(funcOp)) { 812e2c0738SRiver Riddle if (!areMemRefsNormalizable(funcOp)) { 822e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs() 832e2c0738SRiver Riddle << "@" << funcOp.getName() 842e2c0738SRiver Riddle << " contains ops that cannot normalize MemRefs\n"); 852e2c0738SRiver Riddle // Since this function is not normalizable, we set all the caller 862e2c0738SRiver Riddle // functions and the callees of this function as not normalizable. 872e2c0738SRiver Riddle // TODO: Drop this conservative assumption in the future. 882e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 892e2c0738SRiver Riddle normalizableFuncs); 902e2c0738SRiver Riddle } 912e2c0738SRiver Riddle } 922e2c0738SRiver Riddle }); 932e2c0738SRiver Riddle 942e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() 952e2c0738SRiver Riddle << " functions\n"); 962e2c0738SRiver Riddle // Those functions which can be normalized are subjected to normalization. 9758ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs) 982e2c0738SRiver Riddle normalizeFuncOpMemRefs(funcOp, moduleOp); 992e2c0738SRiver Riddle } 1002e2c0738SRiver Riddle 1012e2c0738SRiver Riddle /// Check whether all the uses of oldMemRef are either dereferencing uses or the 1022e2c0738SRiver Riddle /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints 1032e2c0738SRiver Riddle /// are satisfied will the value become a candidate for replacement. 1042e2c0738SRiver Riddle /// TODO: Extend this for DimOps. 1052e2c0738SRiver Riddle static bool isMemRefNormalizable(Value::user_range opUsers) { 1062e2c0738SRiver Riddle return llvm::all_of(opUsers, [](Operation *op) { 1072e2c0738SRiver Riddle return op->hasTrait<OpTrait::MemRefsNormalizable>(); 1082e2c0738SRiver Riddle }); 1092e2c0738SRiver Riddle } 1102e2c0738SRiver Riddle 1112e2c0738SRiver Riddle /// Set all the calling functions and the callees of the function as not 1122e2c0738SRiver Riddle /// normalizable. 1132e2c0738SRiver Riddle void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( 11458ceae95SRiver Riddle func::FuncOp funcOp, ModuleOp moduleOp, 11558ceae95SRiver Riddle DenseSet<func::FuncOp> &normalizableFuncs) { 1162e2c0738SRiver Riddle if (!normalizableFuncs.contains(funcOp)) 1172e2c0738SRiver Riddle return; 1182e2c0738SRiver Riddle 1192e2c0738SRiver Riddle LLVM_DEBUG( 1202e2c0738SRiver Riddle llvm::dbgs() << "@" << funcOp.getName() 1212e2c0738SRiver Riddle << " calls or is called by non-normalizable function\n"); 1222e2c0738SRiver Riddle normalizableFuncs.erase(funcOp); 1232e2c0738SRiver Riddle // Caller of the function. 124e8bcc37fSRamkumar Ramachandra std::optional<SymbolTable::UseRange> symbolUses = 125e8bcc37fSRamkumar Ramachandra funcOp.getSymbolUses(moduleOp); 1262e2c0738SRiver Riddle for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 1272e2c0738SRiver Riddle // TODO: Extend this for ops that are FunctionOpInterface. This would 1282e2c0738SRiver Riddle // require creating an OpInterface for FunctionOpInterface ops. 12958ceae95SRiver Riddle func::FuncOp parentFuncOp = 13058ceae95SRiver Riddle symbolUse.getUser()->getParentOfType<func::FuncOp>(); 13158ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs) { 1322e2c0738SRiver Riddle if (parentFuncOp == funcOp) { 1332e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 1342e2c0738SRiver Riddle normalizableFuncs); 1352e2c0738SRiver Riddle break; 1362e2c0738SRiver Riddle } 1372e2c0738SRiver Riddle } 1382e2c0738SRiver Riddle } 1392e2c0738SRiver Riddle 1402e2c0738SRiver Riddle // Functions called by this function. 14123aa5a74SRiver Riddle funcOp.walk([&](func::CallOp callOp) { 1422e2c0738SRiver Riddle StringAttr callee = callOp.getCalleeAttr().getAttr(); 14358ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs) { 14458ceae95SRiver Riddle // We compare func::FuncOp and callee's name. 1452e2c0738SRiver Riddle if (callee == funcOp.getNameAttr()) { 1462e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 1472e2c0738SRiver Riddle normalizableFuncs); 1482e2c0738SRiver Riddle break; 1492e2c0738SRiver Riddle } 1502e2c0738SRiver Riddle } 1512e2c0738SRiver Riddle }); 1522e2c0738SRiver Riddle } 1532e2c0738SRiver Riddle 154*2ec27848SMatthias Gehre /// Check whether all the uses of AllocOps, AllocaOps, CallOps and function 155*2ec27848SMatthias Gehre /// arguments of a function are either of dereferencing type or are uses in: 156*2ec27848SMatthias Gehre /// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will 157*2ec27848SMatthias Gehre /// the function become a candidate for normalization. When the uses of a memref 158*2ec27848SMatthias Gehre /// are non-normalizable and the memref map layout is trivial (identity), we can 159183c4a39STung D. Le /// still label the entire function as normalizable. We assume external 160183c4a39STung D. Le /// functions to be normalizable. 16158ceae95SRiver Riddle bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { 1622e2c0738SRiver Riddle // We assume external functions to be normalizable. 1632e2c0738SRiver Riddle if (funcOp.isExternal()) 1642e2c0738SRiver Riddle return true; 1652e2c0738SRiver Riddle 1662e2c0738SRiver Riddle if (funcOp 1672e2c0738SRiver Riddle .walk([&](memref::AllocOp allocOp) -> WalkResult { 1682e2c0738SRiver Riddle Value oldMemRef = allocOp.getResult(); 169d4284bafSTung D. Le if (!allocOp.getType().getLayout().isIdentity() && 170183c4a39STung D. Le !isMemRefNormalizable(oldMemRef.getUsers())) 1712e2c0738SRiver Riddle return WalkResult::interrupt(); 1722e2c0738SRiver Riddle return WalkResult::advance(); 1732e2c0738SRiver Riddle }) 1742e2c0738SRiver Riddle .wasInterrupted()) 1752e2c0738SRiver Riddle return false; 1762e2c0738SRiver Riddle 1772e2c0738SRiver Riddle if (funcOp 178*2ec27848SMatthias Gehre .walk([&](memref::AllocaOp allocaOp) -> WalkResult { 179*2ec27848SMatthias Gehre Value oldMemRef = allocaOp.getResult(); 180*2ec27848SMatthias Gehre if (!allocaOp.getType().getLayout().isIdentity() && 181*2ec27848SMatthias Gehre !isMemRefNormalizable(oldMemRef.getUsers())) 182*2ec27848SMatthias Gehre return WalkResult::interrupt(); 183*2ec27848SMatthias Gehre return WalkResult::advance(); 184*2ec27848SMatthias Gehre }) 185*2ec27848SMatthias Gehre .wasInterrupted()) 186*2ec27848SMatthias Gehre return false; 187*2ec27848SMatthias Gehre 188*2ec27848SMatthias Gehre if (funcOp 18923aa5a74SRiver Riddle .walk([&](func::CallOp callOp) -> WalkResult { 1902e2c0738SRiver Riddle for (unsigned resIndex : 1912e2c0738SRiver Riddle llvm::seq<unsigned>(0, callOp.getNumResults())) { 1922e2c0738SRiver Riddle Value oldMemRef = callOp.getResult(resIndex); 193d4284bafSTung D. Le if (auto oldMemRefType = 1945550c821STres Popp dyn_cast<MemRefType>(oldMemRef.getType())) 195d4284bafSTung D. Le if (!oldMemRefType.getLayout().isIdentity() && 196183c4a39STung D. Le !isMemRefNormalizable(oldMemRef.getUsers())) 1972e2c0738SRiver Riddle return WalkResult::interrupt(); 1982e2c0738SRiver Riddle } 1992e2c0738SRiver Riddle return WalkResult::advance(); 2002e2c0738SRiver Riddle }) 2012e2c0738SRiver Riddle .wasInterrupted()) 2022e2c0738SRiver Riddle return false; 2032e2c0738SRiver Riddle 2042e2c0738SRiver Riddle for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { 2052e2c0738SRiver Riddle BlockArgument oldMemRef = funcOp.getArgument(argIndex); 2065550c821STres Popp if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType())) 207d4284bafSTung D. Le if (!oldMemRefType.getLayout().isIdentity() && 208183c4a39STung D. Le !isMemRefNormalizable(oldMemRef.getUsers())) 2092e2c0738SRiver Riddle return false; 2102e2c0738SRiver Riddle } 2112e2c0738SRiver Riddle 2122e2c0738SRiver Riddle return true; 2132e2c0738SRiver Riddle } 2142e2c0738SRiver Riddle 2152e2c0738SRiver Riddle /// Fetch the updated argument list and result of the function and update the 2162e2c0738SRiver Riddle /// function signature. This updates the function's return type at the caller 2172e2c0738SRiver Riddle /// site and in case the return type is a normalized memref then it updates 2182e2c0738SRiver Riddle /// the calling function's signature. 2192e2c0738SRiver Riddle /// TODO: An update to the calling function signature is required only if the 2202e2c0738SRiver Riddle /// returned value is in turn used in ReturnOp of the calling function. 22158ceae95SRiver Riddle void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, 2222e2c0738SRiver Riddle ModuleOp moduleOp) { 2234a3460a7SRiver Riddle FunctionType functionType = funcOp.getFunctionType(); 2242e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes; 2252e2c0738SRiver Riddle FunctionType newFuncType; 2262e2c0738SRiver Riddle resultTypes = llvm::to_vector<4>(functionType.getResults()); 2272e2c0738SRiver Riddle 2282e2c0738SRiver Riddle // External function's signature was already updated in 2292e2c0738SRiver Riddle // 'normalizeFuncOpMemRefs()'. 2302e2c0738SRiver Riddle if (!funcOp.isExternal()) { 2312e2c0738SRiver Riddle SmallVector<Type, 8> argTypes; 2322e2c0738SRiver Riddle for (const auto &argEn : llvm::enumerate(funcOp.getArguments())) 2332e2c0738SRiver Riddle argTypes.push_back(argEn.value().getType()); 2342e2c0738SRiver Riddle 2352e2c0738SRiver Riddle // Traverse ReturnOps to check if an update to the return type in the 2362e2c0738SRiver Riddle // function signature is required. 23723aa5a74SRiver Riddle funcOp.walk([&](func::ReturnOp returnOp) { 2382e2c0738SRiver Riddle for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { 2392e2c0738SRiver Riddle Type opType = operandEn.value().getType(); 2405550c821STres Popp MemRefType memrefType = dyn_cast<MemRefType>(opType); 2412e2c0738SRiver Riddle // If type is not memref or if the memref type is same as that in 2422e2c0738SRiver Riddle // function's return signature then no update is required. 2432e2c0738SRiver Riddle if (!memrefType || memrefType == resultTypes[operandEn.index()]) 2442e2c0738SRiver Riddle continue; 2452e2c0738SRiver Riddle // Update function's return type signature. 2462e2c0738SRiver Riddle // Return type gets normalized either as a result of function argument 2472e2c0738SRiver Riddle // normalization, AllocOp normalization or an update made at CallOp. 2482e2c0738SRiver Riddle // There can be many call flows inside a function and an update to a 2492e2c0738SRiver Riddle // specific ReturnOp has not yet been made. So we check that the result 2502e2c0738SRiver Riddle // memref type is normalized. 2512e2c0738SRiver Riddle // TODO: When selective normalization is implemented, handle multiple 2522e2c0738SRiver Riddle // results case where some are normalized, some aren't. 2532e2c0738SRiver Riddle if (memrefType.getLayout().isIdentity()) 2542e2c0738SRiver Riddle resultTypes[operandEn.index()] = memrefType; 2552e2c0738SRiver Riddle } 2562e2c0738SRiver Riddle }); 2572e2c0738SRiver Riddle 2582e2c0738SRiver Riddle // We create a new function type and modify the function signature with this 2592e2c0738SRiver Riddle // new type. 2602e2c0738SRiver Riddle newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, 2612e2c0738SRiver Riddle /*results=*/resultTypes); 2622e2c0738SRiver Riddle } 2632e2c0738SRiver Riddle 2642e2c0738SRiver Riddle // Since we update the function signature, it might affect the result types at 2652e2c0738SRiver Riddle // the caller site. Since this result might even be used by the caller 2662e2c0738SRiver Riddle // function in ReturnOps, the caller function's signature will also change. 2672e2c0738SRiver Riddle // Hence we record the caller function in 'funcOpsToUpdate' to update their 2682e2c0738SRiver Riddle // signature as well. 26958ceae95SRiver Riddle llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate; 2702e2c0738SRiver Riddle // We iterate over all symbolic uses of the function and update the return 2712e2c0738SRiver Riddle // type at the caller site. 272e8bcc37fSRamkumar Ramachandra std::optional<SymbolTable::UseRange> symbolUses = 273e8bcc37fSRamkumar Ramachandra funcOp.getSymbolUses(moduleOp); 2742e2c0738SRiver Riddle for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 2752e2c0738SRiver Riddle Operation *userOp = symbolUse.getUser(); 2762e2c0738SRiver Riddle OpBuilder builder(userOp); 2772e2c0738SRiver Riddle // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes 2782e2c0738SRiver Riddle // that the non-CallOp has no memrefs to be replaced. 2792e2c0738SRiver Riddle // TODO: Handle cases where a non-CallOp symbol use of a function deals with 2802e2c0738SRiver Riddle // memrefs. 28123aa5a74SRiver Riddle auto callOp = dyn_cast<func::CallOp>(userOp); 2822e2c0738SRiver Riddle if (!callOp) 2832e2c0738SRiver Riddle continue; 2842e2c0738SRiver Riddle Operation *newCallOp = 28523aa5a74SRiver Riddle builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), 2862e2c0738SRiver Riddle resultTypes, userOp->getOperands()); 2872e2c0738SRiver Riddle bool replacingMemRefUsesFailed = false; 2882e2c0738SRiver Riddle bool returnTypeChanged = false; 2892e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { 2902e2c0738SRiver Riddle OpResult oldResult = userOp->getResult(resIndex); 2912e2c0738SRiver Riddle OpResult newResult = newCallOp->getResult(resIndex); 2922e2c0738SRiver Riddle // This condition ensures that if the result is not of type memref or if 2932e2c0738SRiver Riddle // the resulting memref was already having a trivial map layout then we 2942e2c0738SRiver Riddle // need not perform any use replacement here. 2952e2c0738SRiver Riddle if (oldResult.getType() == newResult.getType()) 2962e2c0738SRiver Riddle continue; 2972e2c0738SRiver Riddle AffineMap layoutMap = 2985550c821STres Popp cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap(); 2992e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, 3002e2c0738SRiver Riddle /*extraIndices=*/{}, 3012e2c0738SRiver Riddle /*indexRemap=*/layoutMap, 3022e2c0738SRiver Riddle /*extraOperands=*/{}, 3032e2c0738SRiver Riddle /*symbolOperands=*/{}, 3042e2c0738SRiver Riddle /*domOpFilter=*/nullptr, 3052e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr, 3062e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true, 3072e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) { 3082e2c0738SRiver Riddle // If it failed (due to escapes for example), bail out. 3092e2c0738SRiver Riddle // It should never hit this part of the code because it is called by 3102e2c0738SRiver Riddle // only those functions which are normalizable. 3112e2c0738SRiver Riddle newCallOp->erase(); 3122e2c0738SRiver Riddle replacingMemRefUsesFailed = true; 3132e2c0738SRiver Riddle break; 3142e2c0738SRiver Riddle } 3152e2c0738SRiver Riddle returnTypeChanged = true; 3162e2c0738SRiver Riddle } 3172e2c0738SRiver Riddle if (replacingMemRefUsesFailed) 3182e2c0738SRiver Riddle continue; 3192e2c0738SRiver Riddle // Replace all uses for other non-memref result types. 3202e2c0738SRiver Riddle userOp->replaceAllUsesWith(newCallOp); 3212e2c0738SRiver Riddle userOp->erase(); 3222e2c0738SRiver Riddle if (returnTypeChanged) { 3232e2c0738SRiver Riddle // Since the return type changed it might lead to a change in function's 3242e2c0738SRiver Riddle // signature. 3252e2c0738SRiver Riddle // TODO: If funcOp doesn't return any memref type then no need to update 3262e2c0738SRiver Riddle // signature. 3272e2c0738SRiver Riddle // TODO: Further optimization - Check if the memref is indeed part of 3282e2c0738SRiver Riddle // ReturnOp at the parentFuncOp and only then updation of signature is 3292e2c0738SRiver Riddle // required. 3302e2c0738SRiver Riddle // TODO: Extend this for ops that are FunctionOpInterface. This would 3312e2c0738SRiver Riddle // require creating an OpInterface for FunctionOpInterface ops. 33258ceae95SRiver Riddle func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>(); 3332e2c0738SRiver Riddle funcOpsToUpdate.insert(parentFuncOp); 3342e2c0738SRiver Riddle } 3352e2c0738SRiver Riddle } 3362e2c0738SRiver Riddle // Because external function's signature is already updated in 3372e2c0738SRiver Riddle // 'normalizeFuncOpMemRefs()', we don't need to update it here again. 3382e2c0738SRiver Riddle if (!funcOp.isExternal()) 3392e2c0738SRiver Riddle funcOp.setType(newFuncType); 3402e2c0738SRiver Riddle 3412e2c0738SRiver Riddle // Updating the signature type of those functions which call the current 3422e2c0738SRiver Riddle // function. Only if the return type of the current function has a normalized 3432e2c0738SRiver Riddle // memref will the caller function become a candidate for signature update. 34458ceae95SRiver Riddle for (func::FuncOp parentFuncOp : funcOpsToUpdate) 3452e2c0738SRiver Riddle updateFunctionSignature(parentFuncOp, moduleOp); 3462e2c0738SRiver Riddle } 3472e2c0738SRiver Riddle 3482e2c0738SRiver Riddle /// Normalizes the memrefs within a function which includes those arising as a 349*2ec27848SMatthias Gehre /// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp 350*2ec27848SMatthias Gehre /// argument is used to help update function's signature after normalization. 35158ceae95SRiver Riddle void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, 3522e2c0738SRiver Riddle ModuleOp moduleOp) { 3532e2c0738SRiver Riddle // Turn memrefs' non-identity layouts maps into ones with identity. Collect 354*2ec27848SMatthias Gehre // alloc/alloca ops first and then process since normalizeMemRef 355*2ec27848SMatthias Gehre // replaces/erases ops during memref rewriting. 3562e2c0738SRiver Riddle SmallVector<memref::AllocOp, 4> allocOps; 3572e2c0738SRiver Riddle funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); }); 3582e2c0738SRiver Riddle for (memref::AllocOp allocOp : allocOps) 3592e2c0738SRiver Riddle (void)normalizeMemRef(&allocOp); 3602e2c0738SRiver Riddle 361*2ec27848SMatthias Gehre SmallVector<memref::AllocaOp> allocaOps; 362*2ec27848SMatthias Gehre funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); }); 363*2ec27848SMatthias Gehre for (memref::AllocaOp allocaOp : allocaOps) 364*2ec27848SMatthias Gehre (void)normalizeMemRef(&allocaOp); 365*2ec27848SMatthias Gehre 3662e2c0738SRiver Riddle // We use this OpBuilder to create new memref layout later. 3672e2c0738SRiver Riddle OpBuilder b(funcOp); 3682e2c0738SRiver Riddle 3694a3460a7SRiver Riddle FunctionType functionType = funcOp.getFunctionType(); 3702e2c0738SRiver Riddle SmallVector<Location> functionArgLocs(llvm::map_range( 3712e2c0738SRiver Riddle funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); 3722e2c0738SRiver Riddle SmallVector<Type, 8> inputTypes; 3732e2c0738SRiver Riddle // Walk over each argument of a function to perform memref normalization (if 3742e2c0738SRiver Riddle for (unsigned argIndex : 3752e2c0738SRiver Riddle llvm::seq<unsigned>(0, functionType.getNumInputs())) { 3762e2c0738SRiver Riddle Type argType = functionType.getInput(argIndex); 3775550c821STres Popp MemRefType memrefType = dyn_cast<MemRefType>(argType); 3782e2c0738SRiver Riddle // Check whether argument is of MemRef type. Any other argument type can 3792e2c0738SRiver Riddle // simply be part of the final function signature. 3802e2c0738SRiver Riddle if (!memrefType) { 3812e2c0738SRiver Riddle inputTypes.push_back(argType); 3822e2c0738SRiver Riddle continue; 3832e2c0738SRiver Riddle } 3842e2c0738SRiver Riddle // Fetch a new memref type after normalizing the old memref to have an 3852e2c0738SRiver Riddle // identity map layout. 3861fee821dSKai Sasaki MemRefType newMemRefType = normalizeMemRefType(memrefType); 3872e2c0738SRiver Riddle if (newMemRefType == memrefType || funcOp.isExternal()) { 3882e2c0738SRiver Riddle // Either memrefType already had an identity map or the map couldn't be 3892e2c0738SRiver Riddle // transformed to an identity map. 3902e2c0738SRiver Riddle inputTypes.push_back(newMemRefType); 3912e2c0738SRiver Riddle continue; 3922e2c0738SRiver Riddle } 3932e2c0738SRiver Riddle 3942e2c0738SRiver Riddle // Insert a new temporary argument with the new memref type. 3952e2c0738SRiver Riddle BlockArgument newMemRef = funcOp.front().insertArgument( 3962e2c0738SRiver Riddle argIndex, newMemRefType, functionArgLocs[argIndex]); 3972e2c0738SRiver Riddle BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); 3982e2c0738SRiver Riddle AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 3992e2c0738SRiver Riddle // Replace all uses of the old memref. 4002e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, 4012e2c0738SRiver Riddle /*extraIndices=*/{}, 4022e2c0738SRiver Riddle /*indexRemap=*/layoutMap, 4032e2c0738SRiver Riddle /*extraOperands=*/{}, 4042e2c0738SRiver Riddle /*symbolOperands=*/{}, 4052e2c0738SRiver Riddle /*domOpFilter=*/nullptr, 4062e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr, 4072e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true, 4082e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) { 4092e2c0738SRiver Riddle // If it failed (due to escapes for example), bail out. Removing the 4102e2c0738SRiver Riddle // temporary argument inserted previously. 4112e2c0738SRiver Riddle funcOp.front().eraseArgument(argIndex); 4122e2c0738SRiver Riddle continue; 4132e2c0738SRiver Riddle } 4142e2c0738SRiver Riddle 4152e2c0738SRiver Riddle // All uses for the argument with old memref type were replaced 4162e2c0738SRiver Riddle // successfully. So we remove the old argument now. 4172e2c0738SRiver Riddle funcOp.front().eraseArgument(argIndex + 1); 4182e2c0738SRiver Riddle } 4192e2c0738SRiver Riddle 4202e2c0738SRiver Riddle // Walk over normalizable operations to normalize memrefs of the operation 4212e2c0738SRiver Riddle // results. When `op` has memrefs with affine map in the operation results, 4222e2c0738SRiver Riddle // new operation containin normalized memrefs is created. Then, the memrefs 4232e2c0738SRiver Riddle // are replaced. `CallOp` is skipped here because it is handled in 4242e2c0738SRiver Riddle // `updateFunctionSignature()`. 4252e2c0738SRiver Riddle funcOp.walk([&](Operation *op) { 4262e2c0738SRiver Riddle if (op->hasTrait<OpTrait::MemRefsNormalizable>() && 42723aa5a74SRiver Riddle op->getNumResults() > 0 && !isa<func::CallOp>(op) && 42823aa5a74SRiver Riddle !funcOp.isExternal()) { 4292e2c0738SRiver Riddle // Create newOp containing normalized memref in the operation result. 4302e2c0738SRiver Riddle Operation *newOp = createOpResultsNormalized(funcOp, op); 4312e2c0738SRiver Riddle // When all of the operation results have no memrefs or memrefs without 4322e2c0738SRiver Riddle // affine map, `newOp` is the same with `op` and following process is 4332e2c0738SRiver Riddle // skipped. 4342e2c0738SRiver Riddle if (op != newOp) { 4352e2c0738SRiver Riddle bool replacingMemRefUsesFailed = false; 4362e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) { 4372e2c0738SRiver Riddle // Replace all uses of the old memrefs. 4382e2c0738SRiver Riddle Value oldMemRef = op->getResult(resIndex); 4392e2c0738SRiver Riddle Value newMemRef = newOp->getResult(resIndex); 4405550c821STres Popp MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()); 4412e2c0738SRiver Riddle // Check whether the operation result is MemRef type. 4422e2c0738SRiver Riddle if (!oldMemRefType) 4432e2c0738SRiver Riddle continue; 4445550c821STres Popp MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType()); 4452e2c0738SRiver Riddle if (oldMemRefType == newMemRefType) 4462e2c0738SRiver Riddle continue; 4472e2c0738SRiver Riddle // TODO: Assume single layout map. Multiple maps not supported. 4482e2c0738SRiver Riddle AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); 4492e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldMemRef, 4502e2c0738SRiver Riddle /*newMemRef=*/newMemRef, 4512e2c0738SRiver Riddle /*extraIndices=*/{}, 4522e2c0738SRiver Riddle /*indexRemap=*/layoutMap, 4532e2c0738SRiver Riddle /*extraOperands=*/{}, 4542e2c0738SRiver Riddle /*symbolOperands=*/{}, 4552e2c0738SRiver Riddle /*domOpFilter=*/nullptr, 4562e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr, 4572e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true, 4582e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) { 4592e2c0738SRiver Riddle newOp->erase(); 4602e2c0738SRiver Riddle replacingMemRefUsesFailed = true; 4612e2c0738SRiver Riddle continue; 4622e2c0738SRiver Riddle } 4632e2c0738SRiver Riddle } 4642e2c0738SRiver Riddle if (!replacingMemRefUsesFailed) { 4652e2c0738SRiver Riddle // Replace other ops with new op and delete the old op when the 4662e2c0738SRiver Riddle // replacement succeeded. 4672e2c0738SRiver Riddle op->replaceAllUsesWith(newOp); 4682e2c0738SRiver Riddle op->erase(); 4692e2c0738SRiver Riddle } 4702e2c0738SRiver Riddle } 4712e2c0738SRiver Riddle } 4722e2c0738SRiver Riddle }); 4732e2c0738SRiver Riddle 4742e2c0738SRiver Riddle // In a normal function, memrefs in the return type signature gets normalized 4752e2c0738SRiver Riddle // as a result of normalization of functions arguments, AllocOps or CallOps' 4762e2c0738SRiver Riddle // result types. Since an external function doesn't have a body, memrefs in 4772e2c0738SRiver Riddle // the return type signature can only get normalized by iterating over the 4782e2c0738SRiver Riddle // individual return types. 4792e2c0738SRiver Riddle if (funcOp.isExternal()) { 4802e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes; 4812e2c0738SRiver Riddle for (unsigned resIndex : 4822e2c0738SRiver Riddle llvm::seq<unsigned>(0, functionType.getNumResults())) { 4832e2c0738SRiver Riddle Type resType = functionType.getResult(resIndex); 4845550c821STres Popp MemRefType memrefType = dyn_cast<MemRefType>(resType); 4852e2c0738SRiver Riddle // Check whether result is of MemRef type. Any other argument type can 4862e2c0738SRiver Riddle // simply be part of the final function signature. 4872e2c0738SRiver Riddle if (!memrefType) { 4882e2c0738SRiver Riddle resultTypes.push_back(resType); 4892e2c0738SRiver Riddle continue; 4902e2c0738SRiver Riddle } 4912e2c0738SRiver Riddle // Computing a new memref type after normalizing the old memref to have an 4922e2c0738SRiver Riddle // identity map layout. 4931fee821dSKai Sasaki MemRefType newMemRefType = normalizeMemRefType(memrefType); 4942e2c0738SRiver Riddle resultTypes.push_back(newMemRefType); 4952e2c0738SRiver Riddle } 4962e2c0738SRiver Riddle 4972e2c0738SRiver Riddle FunctionType newFuncType = 4982e2c0738SRiver Riddle FunctionType::get(&getContext(), /*inputs=*/inputTypes, 4992e2c0738SRiver Riddle /*results=*/resultTypes); 5002e2c0738SRiver Riddle // Setting the new function signature for this external function. 5012e2c0738SRiver Riddle funcOp.setType(newFuncType); 5022e2c0738SRiver Riddle } 5032e2c0738SRiver Riddle updateFunctionSignature(funcOp, moduleOp); 5042e2c0738SRiver Riddle } 5052e2c0738SRiver Riddle 5062e2c0738SRiver Riddle /// Create an operation containing normalized memrefs in the operation results. 5072e2c0738SRiver Riddle /// When the results of `oldOp` have memrefs with affine map, the memrefs are 5082e2c0738SRiver Riddle /// normalized, and new operation containing them in the operation results is 5092e2c0738SRiver Riddle /// returned. If all of the results of `oldOp` have no memrefs or memrefs 5102e2c0738SRiver Riddle /// without affine map, `oldOp` is returned without modification. 51158ceae95SRiver Riddle Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp, 5122e2c0738SRiver Riddle Operation *oldOp) { 5132e2c0738SRiver Riddle // Prepare OperationState to create newOp containing normalized memref in 5142e2c0738SRiver Riddle // the operation results. 5152e2c0738SRiver Riddle OperationState result(oldOp->getLoc(), oldOp->getName()); 5162e2c0738SRiver Riddle result.addOperands(oldOp->getOperands()); 5172e2c0738SRiver Riddle result.addAttributes(oldOp->getAttrs()); 5182e2c0738SRiver Riddle // Add normalized MemRefType to the OperationState. 5192e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes; 5202e2c0738SRiver Riddle OpBuilder b(funcOp); 5212e2c0738SRiver Riddle bool resultTypeNormalized = false; 5222e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) { 5232e2c0738SRiver Riddle auto resultType = oldOp->getResult(resIndex).getType(); 5245550c821STres Popp MemRefType memrefType = dyn_cast<MemRefType>(resultType); 5252e2c0738SRiver Riddle // Check whether the operation result is MemRef type. 5262e2c0738SRiver Riddle if (!memrefType) { 5272e2c0738SRiver Riddle resultTypes.push_back(resultType); 5282e2c0738SRiver Riddle continue; 5292e2c0738SRiver Riddle } 5301fee821dSKai Sasaki 5312e2c0738SRiver Riddle // Fetch a new memref type after normalizing the old memref. 5321fee821dSKai Sasaki MemRefType newMemRefType = normalizeMemRefType(memrefType); 5332e2c0738SRiver Riddle if (newMemRefType == memrefType) { 5342e2c0738SRiver Riddle // Either memrefType already had an identity map or the map couldn't 5352e2c0738SRiver Riddle // be transformed to an identity map. 5362e2c0738SRiver Riddle resultTypes.push_back(memrefType); 5372e2c0738SRiver Riddle continue; 5382e2c0738SRiver Riddle } 5392e2c0738SRiver Riddle resultTypes.push_back(newMemRefType); 5402e2c0738SRiver Riddle resultTypeNormalized = true; 5412e2c0738SRiver Riddle } 5422e2c0738SRiver Riddle result.addTypes(resultTypes); 5432e2c0738SRiver Riddle // When all of the results of `oldOp` have no memrefs or memrefs without 5442e2c0738SRiver Riddle // affine map, `oldOp` is returned without modification. 5452e2c0738SRiver Riddle if (resultTypeNormalized) { 5462e2c0738SRiver Riddle OpBuilder bb(oldOp); 5472e2c0738SRiver Riddle for (auto &oldRegion : oldOp->getRegions()) { 5482e2c0738SRiver Riddle Region *newRegion = result.addRegion(); 5492e2c0738SRiver Riddle newRegion->takeBody(oldRegion); 5502e2c0738SRiver Riddle } 55114ecafd0SChia-hung Duan return bb.create(result); 5522e2c0738SRiver Riddle } 5532e2c0738SRiver Riddle return oldOp; 5542e2c0738SRiver Riddle } 555