xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp (revision 2ec27848c00cda734697619047e640eadb254555)
12e2c0738SRiver Riddle //===- NormalizeMemRefs.cpp -----------------------------------------------===//
22e2c0738SRiver Riddle //
32e2c0738SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42e2c0738SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
52e2c0738SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62e2c0738SRiver Riddle //
72e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
82e2c0738SRiver Riddle //
92e2c0738SRiver Riddle // This file implements an interprocedural pass to normalize memrefs to have
102e2c0738SRiver Riddle // identity layout maps.
112e2c0738SRiver Riddle //
122e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
132e2c0738SRiver Riddle 
142e2c0738SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h"
161f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
172e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
182e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h"
192e2c0738SRiver Riddle #include "llvm/ADT/SmallSet.h"
202e2c0738SRiver Riddle #include "llvm/Support/Debug.h"
212e2c0738SRiver Riddle 
2267d0d7acSMichele Scuttari namespace mlir {
2367d0d7acSMichele Scuttari namespace memref {
2467d0d7acSMichele Scuttari #define GEN_PASS_DEF_NORMALIZEMEMREFS
2567d0d7acSMichele Scuttari #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
2667d0d7acSMichele Scuttari } // namespace memref
2767d0d7acSMichele Scuttari } // namespace mlir
2867d0d7acSMichele Scuttari 
292e2c0738SRiver Riddle #define DEBUG_TYPE "normalize-memrefs"
302e2c0738SRiver Riddle 
312e2c0738SRiver Riddle using namespace mlir;
324c48f016SMatthias Springer using namespace mlir::affine;
332e2c0738SRiver Riddle 
342e2c0738SRiver Riddle namespace {
352e2c0738SRiver Riddle 
362e2c0738SRiver Riddle /// All memrefs passed across functions with non-trivial layout maps are
372e2c0738SRiver Riddle /// converted to ones with trivial identity layout ones.
382e2c0738SRiver Riddle /// If all the memref types/uses in a function are normalizable, we treat
392e2c0738SRiver Riddle /// such functions as normalizable. Also, if a normalizable function is known
402e2c0738SRiver Riddle /// to call a non-normalizable function, we treat that function as
412e2c0738SRiver Riddle /// non-normalizable as well. We assume external functions to be normalizable.
4267d0d7acSMichele Scuttari struct NormalizeMemRefs
4367d0d7acSMichele Scuttari     : public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
442e2c0738SRiver Riddle   void runOnOperation() override;
4558ceae95SRiver Riddle   void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
4658ceae95SRiver Riddle   bool areMemRefsNormalizable(func::FuncOp funcOp);
4758ceae95SRiver Riddle   void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
4858ceae95SRiver Riddle   void setCalleesAndCallersNonNormalizable(
4958ceae95SRiver Riddle       func::FuncOp funcOp, ModuleOp moduleOp,
5058ceae95SRiver Riddle       DenseSet<func::FuncOp> &normalizableFuncs);
5158ceae95SRiver Riddle   Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
522e2c0738SRiver Riddle };
532e2c0738SRiver Riddle 
542e2c0738SRiver Riddle } // namespace
552e2c0738SRiver Riddle 
562e2c0738SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>>
572e2c0738SRiver Riddle mlir::memref::createNormalizeMemRefsPass() {
582e2c0738SRiver Riddle   return std::make_unique<NormalizeMemRefs>();
592e2c0738SRiver Riddle }
602e2c0738SRiver Riddle 
612e2c0738SRiver Riddle void NormalizeMemRefs::runOnOperation() {
622e2c0738SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
632e2c0738SRiver Riddle   ModuleOp moduleOp = getOperation();
642e2c0738SRiver Riddle   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
652e2c0738SRiver Riddle   // with all the functions within a module and then functions which are not
662e2c0738SRiver Riddle   // normalizable are removed from this set.
672e2c0738SRiver Riddle   // TODO: Change this to work on FuncLikeOp once there is an operation
682e2c0738SRiver Riddle   // interface for it.
6958ceae95SRiver Riddle   DenseSet<func::FuncOp> normalizableFuncs;
702e2c0738SRiver Riddle   // Initialize `normalizableFuncs` with all the functions within a module.
7158ceae95SRiver Riddle   moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
722e2c0738SRiver Riddle 
732e2c0738SRiver Riddle   // Traverse through all the functions applying a filter which determines
742e2c0738SRiver Riddle   // whether that function is normalizable or not. All callers/callees of
752e2c0738SRiver Riddle   // a non-normalizable function will also become non-normalizable even if
762e2c0738SRiver Riddle   // they aren't passing any or specific non-normalizable memrefs. So,
772e2c0738SRiver Riddle   // functions which calls or get called by a non-normalizable becomes non-
782e2c0738SRiver Riddle   // normalizable functions themselves.
7958ceae95SRiver Riddle   moduleOp.walk([&](func::FuncOp funcOp) {
802e2c0738SRiver Riddle     if (normalizableFuncs.contains(funcOp)) {
812e2c0738SRiver Riddle       if (!areMemRefsNormalizable(funcOp)) {
822e2c0738SRiver Riddle         LLVM_DEBUG(llvm::dbgs()
832e2c0738SRiver Riddle                    << "@" << funcOp.getName()
842e2c0738SRiver Riddle                    << " contains ops that cannot normalize MemRefs\n");
852e2c0738SRiver Riddle         // Since this function is not normalizable, we set all the caller
862e2c0738SRiver Riddle         // functions and the callees of this function as not normalizable.
872e2c0738SRiver Riddle         // TODO: Drop this conservative assumption in the future.
882e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
892e2c0738SRiver Riddle                                             normalizableFuncs);
902e2c0738SRiver Riddle       }
912e2c0738SRiver Riddle     }
922e2c0738SRiver Riddle   });
932e2c0738SRiver Riddle 
942e2c0738SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
952e2c0738SRiver Riddle                           << " functions\n");
962e2c0738SRiver Riddle   // Those functions which can be normalized are subjected to normalization.
9758ceae95SRiver Riddle   for (func::FuncOp &funcOp : normalizableFuncs)
982e2c0738SRiver Riddle     normalizeFuncOpMemRefs(funcOp, moduleOp);
992e2c0738SRiver Riddle }
1002e2c0738SRiver Riddle 
1012e2c0738SRiver Riddle /// Check whether all the uses of oldMemRef are either dereferencing uses or the
1022e2c0738SRiver Riddle /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
1032e2c0738SRiver Riddle /// are satisfied will the value become a candidate for replacement.
1042e2c0738SRiver Riddle /// TODO: Extend this for DimOps.
1052e2c0738SRiver Riddle static bool isMemRefNormalizable(Value::user_range opUsers) {
1062e2c0738SRiver Riddle   return llvm::all_of(opUsers, [](Operation *op) {
1072e2c0738SRiver Riddle     return op->hasTrait<OpTrait::MemRefsNormalizable>();
1082e2c0738SRiver Riddle   });
1092e2c0738SRiver Riddle }
1102e2c0738SRiver Riddle 
1112e2c0738SRiver Riddle /// Set all the calling functions and the callees of the function as not
1122e2c0738SRiver Riddle /// normalizable.
1132e2c0738SRiver Riddle void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
11458ceae95SRiver Riddle     func::FuncOp funcOp, ModuleOp moduleOp,
11558ceae95SRiver Riddle     DenseSet<func::FuncOp> &normalizableFuncs) {
1162e2c0738SRiver Riddle   if (!normalizableFuncs.contains(funcOp))
1172e2c0738SRiver Riddle     return;
1182e2c0738SRiver Riddle 
1192e2c0738SRiver Riddle   LLVM_DEBUG(
1202e2c0738SRiver Riddle       llvm::dbgs() << "@" << funcOp.getName()
1212e2c0738SRiver Riddle                    << " calls or is called by non-normalizable function\n");
1222e2c0738SRiver Riddle   normalizableFuncs.erase(funcOp);
1232e2c0738SRiver Riddle   // Caller of the function.
124e8bcc37fSRamkumar Ramachandra   std::optional<SymbolTable::UseRange> symbolUses =
125e8bcc37fSRamkumar Ramachandra       funcOp.getSymbolUses(moduleOp);
1262e2c0738SRiver Riddle   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
1272e2c0738SRiver Riddle     // TODO: Extend this for ops that are FunctionOpInterface. This would
1282e2c0738SRiver Riddle     // require creating an OpInterface for FunctionOpInterface ops.
12958ceae95SRiver Riddle     func::FuncOp parentFuncOp =
13058ceae95SRiver Riddle         symbolUse.getUser()->getParentOfType<func::FuncOp>();
13158ceae95SRiver Riddle     for (func::FuncOp &funcOp : normalizableFuncs) {
1322e2c0738SRiver Riddle       if (parentFuncOp == funcOp) {
1332e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1342e2c0738SRiver Riddle                                             normalizableFuncs);
1352e2c0738SRiver Riddle         break;
1362e2c0738SRiver Riddle       }
1372e2c0738SRiver Riddle     }
1382e2c0738SRiver Riddle   }
1392e2c0738SRiver Riddle 
1402e2c0738SRiver Riddle   // Functions called by this function.
14123aa5a74SRiver Riddle   funcOp.walk([&](func::CallOp callOp) {
1422e2c0738SRiver Riddle     StringAttr callee = callOp.getCalleeAttr().getAttr();
14358ceae95SRiver Riddle     for (func::FuncOp &funcOp : normalizableFuncs) {
14458ceae95SRiver Riddle       // We compare func::FuncOp and callee's name.
1452e2c0738SRiver Riddle       if (callee == funcOp.getNameAttr()) {
1462e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1472e2c0738SRiver Riddle                                             normalizableFuncs);
1482e2c0738SRiver Riddle         break;
1492e2c0738SRiver Riddle       }
1502e2c0738SRiver Riddle     }
1512e2c0738SRiver Riddle   });
1522e2c0738SRiver Riddle }
1532e2c0738SRiver Riddle 
154*2ec27848SMatthias Gehre /// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
155*2ec27848SMatthias Gehre /// arguments of a function are either of dereferencing type or are uses in:
156*2ec27848SMatthias Gehre /// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
157*2ec27848SMatthias Gehre /// the function become a candidate for normalization. When the uses of a memref
158*2ec27848SMatthias Gehre /// are non-normalizable and the memref map layout is trivial (identity), we can
159183c4a39STung D. Le /// still label the entire function as normalizable. We assume external
160183c4a39STung D. Le /// functions to be normalizable.
16158ceae95SRiver Riddle bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
1622e2c0738SRiver Riddle   // We assume external functions to be normalizable.
1632e2c0738SRiver Riddle   if (funcOp.isExternal())
1642e2c0738SRiver Riddle     return true;
1652e2c0738SRiver Riddle 
1662e2c0738SRiver Riddle   if (funcOp
1672e2c0738SRiver Riddle           .walk([&](memref::AllocOp allocOp) -> WalkResult {
1682e2c0738SRiver Riddle             Value oldMemRef = allocOp.getResult();
169d4284bafSTung D. Le             if (!allocOp.getType().getLayout().isIdentity() &&
170183c4a39STung D. Le                 !isMemRefNormalizable(oldMemRef.getUsers()))
1712e2c0738SRiver Riddle               return WalkResult::interrupt();
1722e2c0738SRiver Riddle             return WalkResult::advance();
1732e2c0738SRiver Riddle           })
1742e2c0738SRiver Riddle           .wasInterrupted())
1752e2c0738SRiver Riddle     return false;
1762e2c0738SRiver Riddle 
1772e2c0738SRiver Riddle   if (funcOp
178*2ec27848SMatthias Gehre           .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
179*2ec27848SMatthias Gehre             Value oldMemRef = allocaOp.getResult();
180*2ec27848SMatthias Gehre             if (!allocaOp.getType().getLayout().isIdentity() &&
181*2ec27848SMatthias Gehre                 !isMemRefNormalizable(oldMemRef.getUsers()))
182*2ec27848SMatthias Gehre               return WalkResult::interrupt();
183*2ec27848SMatthias Gehre             return WalkResult::advance();
184*2ec27848SMatthias Gehre           })
185*2ec27848SMatthias Gehre           .wasInterrupted())
186*2ec27848SMatthias Gehre     return false;
187*2ec27848SMatthias Gehre 
188*2ec27848SMatthias Gehre   if (funcOp
18923aa5a74SRiver Riddle           .walk([&](func::CallOp callOp) -> WalkResult {
1902e2c0738SRiver Riddle             for (unsigned resIndex :
1912e2c0738SRiver Riddle                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
1922e2c0738SRiver Riddle               Value oldMemRef = callOp.getResult(resIndex);
193d4284bafSTung D. Le               if (auto oldMemRefType =
1945550c821STres Popp                       dyn_cast<MemRefType>(oldMemRef.getType()))
195d4284bafSTung D. Le                 if (!oldMemRefType.getLayout().isIdentity() &&
196183c4a39STung D. Le                     !isMemRefNormalizable(oldMemRef.getUsers()))
1972e2c0738SRiver Riddle                   return WalkResult::interrupt();
1982e2c0738SRiver Riddle             }
1992e2c0738SRiver Riddle             return WalkResult::advance();
2002e2c0738SRiver Riddle           })
2012e2c0738SRiver Riddle           .wasInterrupted())
2022e2c0738SRiver Riddle     return false;
2032e2c0738SRiver Riddle 
2042e2c0738SRiver Riddle   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
2052e2c0738SRiver Riddle     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
2065550c821STres Popp     if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
207d4284bafSTung D. Le       if (!oldMemRefType.getLayout().isIdentity() &&
208183c4a39STung D. Le           !isMemRefNormalizable(oldMemRef.getUsers()))
2092e2c0738SRiver Riddle         return false;
2102e2c0738SRiver Riddle   }
2112e2c0738SRiver Riddle 
2122e2c0738SRiver Riddle   return true;
2132e2c0738SRiver Riddle }
2142e2c0738SRiver Riddle 
2152e2c0738SRiver Riddle /// Fetch the updated argument list and result of the function and update the
2162e2c0738SRiver Riddle /// function signature. This updates the function's return type at the caller
2172e2c0738SRiver Riddle /// site and in case the return type is a normalized memref then it updates
2182e2c0738SRiver Riddle /// the calling function's signature.
2192e2c0738SRiver Riddle /// TODO: An update to the calling function signature is required only if the
2202e2c0738SRiver Riddle /// returned value is in turn used in ReturnOp of the calling function.
22158ceae95SRiver Riddle void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
2222e2c0738SRiver Riddle                                                ModuleOp moduleOp) {
2234a3460a7SRiver Riddle   FunctionType functionType = funcOp.getFunctionType();
2242e2c0738SRiver Riddle   SmallVector<Type, 4> resultTypes;
2252e2c0738SRiver Riddle   FunctionType newFuncType;
2262e2c0738SRiver Riddle   resultTypes = llvm::to_vector<4>(functionType.getResults());
2272e2c0738SRiver Riddle 
2282e2c0738SRiver Riddle   // External function's signature was already updated in
2292e2c0738SRiver Riddle   // 'normalizeFuncOpMemRefs()'.
2302e2c0738SRiver Riddle   if (!funcOp.isExternal()) {
2312e2c0738SRiver Riddle     SmallVector<Type, 8> argTypes;
2322e2c0738SRiver Riddle     for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
2332e2c0738SRiver Riddle       argTypes.push_back(argEn.value().getType());
2342e2c0738SRiver Riddle 
2352e2c0738SRiver Riddle     // Traverse ReturnOps to check if an update to the return type in the
2362e2c0738SRiver Riddle     // function signature is required.
23723aa5a74SRiver Riddle     funcOp.walk([&](func::ReturnOp returnOp) {
2382e2c0738SRiver Riddle       for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
2392e2c0738SRiver Riddle         Type opType = operandEn.value().getType();
2405550c821STres Popp         MemRefType memrefType = dyn_cast<MemRefType>(opType);
2412e2c0738SRiver Riddle         // If type is not memref or if the memref type is same as that in
2422e2c0738SRiver Riddle         // function's return signature then no update is required.
2432e2c0738SRiver Riddle         if (!memrefType || memrefType == resultTypes[operandEn.index()])
2442e2c0738SRiver Riddle           continue;
2452e2c0738SRiver Riddle         // Update function's return type signature.
2462e2c0738SRiver Riddle         // Return type gets normalized either as a result of function argument
2472e2c0738SRiver Riddle         // normalization, AllocOp normalization or an update made at CallOp.
2482e2c0738SRiver Riddle         // There can be many call flows inside a function and an update to a
2492e2c0738SRiver Riddle         // specific ReturnOp has not yet been made. So we check that the result
2502e2c0738SRiver Riddle         // memref type is normalized.
2512e2c0738SRiver Riddle         // TODO: When selective normalization is implemented, handle multiple
2522e2c0738SRiver Riddle         // results case where some are normalized, some aren't.
2532e2c0738SRiver Riddle         if (memrefType.getLayout().isIdentity())
2542e2c0738SRiver Riddle           resultTypes[operandEn.index()] = memrefType;
2552e2c0738SRiver Riddle       }
2562e2c0738SRiver Riddle     });
2572e2c0738SRiver Riddle 
2582e2c0738SRiver Riddle     // We create a new function type and modify the function signature with this
2592e2c0738SRiver Riddle     // new type.
2602e2c0738SRiver Riddle     newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
2612e2c0738SRiver Riddle                                     /*results=*/resultTypes);
2622e2c0738SRiver Riddle   }
2632e2c0738SRiver Riddle 
2642e2c0738SRiver Riddle   // Since we update the function signature, it might affect the result types at
2652e2c0738SRiver Riddle   // the caller site. Since this result might even be used by the caller
2662e2c0738SRiver Riddle   // function in ReturnOps, the caller function's signature will also change.
2672e2c0738SRiver Riddle   // Hence we record the caller function in 'funcOpsToUpdate' to update their
2682e2c0738SRiver Riddle   // signature as well.
26958ceae95SRiver Riddle   llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
2702e2c0738SRiver Riddle   // We iterate over all symbolic uses of the function and update the return
2712e2c0738SRiver Riddle   // type at the caller site.
272e8bcc37fSRamkumar Ramachandra   std::optional<SymbolTable::UseRange> symbolUses =
273e8bcc37fSRamkumar Ramachandra       funcOp.getSymbolUses(moduleOp);
2742e2c0738SRiver Riddle   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
2752e2c0738SRiver Riddle     Operation *userOp = symbolUse.getUser();
2762e2c0738SRiver Riddle     OpBuilder builder(userOp);
2772e2c0738SRiver Riddle     // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
2782e2c0738SRiver Riddle     // that the non-CallOp has no memrefs to be replaced.
2792e2c0738SRiver Riddle     // TODO: Handle cases where a non-CallOp symbol use of a function deals with
2802e2c0738SRiver Riddle     // memrefs.
28123aa5a74SRiver Riddle     auto callOp = dyn_cast<func::CallOp>(userOp);
2822e2c0738SRiver Riddle     if (!callOp)
2832e2c0738SRiver Riddle       continue;
2842e2c0738SRiver Riddle     Operation *newCallOp =
28523aa5a74SRiver Riddle         builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
2862e2c0738SRiver Riddle                                      resultTypes, userOp->getOperands());
2872e2c0738SRiver Riddle     bool replacingMemRefUsesFailed = false;
2882e2c0738SRiver Riddle     bool returnTypeChanged = false;
2892e2c0738SRiver Riddle     for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
2902e2c0738SRiver Riddle       OpResult oldResult = userOp->getResult(resIndex);
2912e2c0738SRiver Riddle       OpResult newResult = newCallOp->getResult(resIndex);
2922e2c0738SRiver Riddle       // This condition ensures that if the result is not of type memref or if
2932e2c0738SRiver Riddle       // the resulting memref was already having a trivial map layout then we
2942e2c0738SRiver Riddle       // need not perform any use replacement here.
2952e2c0738SRiver Riddle       if (oldResult.getType() == newResult.getType())
2962e2c0738SRiver Riddle         continue;
2972e2c0738SRiver Riddle       AffineMap layoutMap =
2985550c821STres Popp           cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
2992e2c0738SRiver Riddle       if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
3002e2c0738SRiver Riddle                                           /*extraIndices=*/{},
3012e2c0738SRiver Riddle                                           /*indexRemap=*/layoutMap,
3022e2c0738SRiver Riddle                                           /*extraOperands=*/{},
3032e2c0738SRiver Riddle                                           /*symbolOperands=*/{},
3042e2c0738SRiver Riddle                                           /*domOpFilter=*/nullptr,
3052e2c0738SRiver Riddle                                           /*postDomOpFilter=*/nullptr,
3062e2c0738SRiver Riddle                                           /*allowNonDereferencingOps=*/true,
3072e2c0738SRiver Riddle                                           /*replaceInDeallocOp=*/true))) {
3082e2c0738SRiver Riddle         // If it failed (due to escapes for example), bail out.
3092e2c0738SRiver Riddle         // It should never hit this part of the code because it is called by
3102e2c0738SRiver Riddle         // only those functions which are normalizable.
3112e2c0738SRiver Riddle         newCallOp->erase();
3122e2c0738SRiver Riddle         replacingMemRefUsesFailed = true;
3132e2c0738SRiver Riddle         break;
3142e2c0738SRiver Riddle       }
3152e2c0738SRiver Riddle       returnTypeChanged = true;
3162e2c0738SRiver Riddle     }
3172e2c0738SRiver Riddle     if (replacingMemRefUsesFailed)
3182e2c0738SRiver Riddle       continue;
3192e2c0738SRiver Riddle     // Replace all uses for other non-memref result types.
3202e2c0738SRiver Riddle     userOp->replaceAllUsesWith(newCallOp);
3212e2c0738SRiver Riddle     userOp->erase();
3222e2c0738SRiver Riddle     if (returnTypeChanged) {
3232e2c0738SRiver Riddle       // Since the return type changed it might lead to a change in function's
3242e2c0738SRiver Riddle       // signature.
3252e2c0738SRiver Riddle       // TODO: If funcOp doesn't return any memref type then no need to update
3262e2c0738SRiver Riddle       // signature.
3272e2c0738SRiver Riddle       // TODO: Further optimization - Check if the memref is indeed part of
3282e2c0738SRiver Riddle       // ReturnOp at the parentFuncOp and only then updation of signature is
3292e2c0738SRiver Riddle       // required.
3302e2c0738SRiver Riddle       // TODO: Extend this for ops that are FunctionOpInterface. This would
3312e2c0738SRiver Riddle       // require creating an OpInterface for FunctionOpInterface ops.
33258ceae95SRiver Riddle       func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
3332e2c0738SRiver Riddle       funcOpsToUpdate.insert(parentFuncOp);
3342e2c0738SRiver Riddle     }
3352e2c0738SRiver Riddle   }
3362e2c0738SRiver Riddle   // Because external function's signature is already updated in
3372e2c0738SRiver Riddle   // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
3382e2c0738SRiver Riddle   if (!funcOp.isExternal())
3392e2c0738SRiver Riddle     funcOp.setType(newFuncType);
3402e2c0738SRiver Riddle 
3412e2c0738SRiver Riddle   // Updating the signature type of those functions which call the current
3422e2c0738SRiver Riddle   // function. Only if the return type of the current function has a normalized
3432e2c0738SRiver Riddle   // memref will the caller function become a candidate for signature update.
34458ceae95SRiver Riddle   for (func::FuncOp parentFuncOp : funcOpsToUpdate)
3452e2c0738SRiver Riddle     updateFunctionSignature(parentFuncOp, moduleOp);
3462e2c0738SRiver Riddle }
3472e2c0738SRiver Riddle 
3482e2c0738SRiver Riddle /// Normalizes the memrefs within a function which includes those arising as a
349*2ec27848SMatthias Gehre /// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
350*2ec27848SMatthias Gehre /// argument is used to help update function's signature after normalization.
35158ceae95SRiver Riddle void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
3522e2c0738SRiver Riddle                                               ModuleOp moduleOp) {
3532e2c0738SRiver Riddle   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
354*2ec27848SMatthias Gehre   // alloc/alloca ops first and then process since normalizeMemRef
355*2ec27848SMatthias Gehre   // replaces/erases ops during memref rewriting.
3562e2c0738SRiver Riddle   SmallVector<memref::AllocOp, 4> allocOps;
3572e2c0738SRiver Riddle   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
3582e2c0738SRiver Riddle   for (memref::AllocOp allocOp : allocOps)
3592e2c0738SRiver Riddle     (void)normalizeMemRef(&allocOp);
3602e2c0738SRiver Riddle 
361*2ec27848SMatthias Gehre   SmallVector<memref::AllocaOp> allocaOps;
362*2ec27848SMatthias Gehre   funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
363*2ec27848SMatthias Gehre   for (memref::AllocaOp allocaOp : allocaOps)
364*2ec27848SMatthias Gehre     (void)normalizeMemRef(&allocaOp);
365*2ec27848SMatthias Gehre 
3662e2c0738SRiver Riddle   // We use this OpBuilder to create new memref layout later.
3672e2c0738SRiver Riddle   OpBuilder b(funcOp);
3682e2c0738SRiver Riddle 
3694a3460a7SRiver Riddle   FunctionType functionType = funcOp.getFunctionType();
3702e2c0738SRiver Riddle   SmallVector<Location> functionArgLocs(llvm::map_range(
3712e2c0738SRiver Riddle       funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
3722e2c0738SRiver Riddle   SmallVector<Type, 8> inputTypes;
3732e2c0738SRiver Riddle   // Walk over each argument of a function to perform memref normalization (if
3742e2c0738SRiver Riddle   for (unsigned argIndex :
3752e2c0738SRiver Riddle        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
3762e2c0738SRiver Riddle     Type argType = functionType.getInput(argIndex);
3775550c821STres Popp     MemRefType memrefType = dyn_cast<MemRefType>(argType);
3782e2c0738SRiver Riddle     // Check whether argument is of MemRef type. Any other argument type can
3792e2c0738SRiver Riddle     // simply be part of the final function signature.
3802e2c0738SRiver Riddle     if (!memrefType) {
3812e2c0738SRiver Riddle       inputTypes.push_back(argType);
3822e2c0738SRiver Riddle       continue;
3832e2c0738SRiver Riddle     }
3842e2c0738SRiver Riddle     // Fetch a new memref type after normalizing the old memref to have an
3852e2c0738SRiver Riddle     // identity map layout.
3861fee821dSKai Sasaki     MemRefType newMemRefType = normalizeMemRefType(memrefType);
3872e2c0738SRiver Riddle     if (newMemRefType == memrefType || funcOp.isExternal()) {
3882e2c0738SRiver Riddle       // Either memrefType already had an identity map or the map couldn't be
3892e2c0738SRiver Riddle       // transformed to an identity map.
3902e2c0738SRiver Riddle       inputTypes.push_back(newMemRefType);
3912e2c0738SRiver Riddle       continue;
3922e2c0738SRiver Riddle     }
3932e2c0738SRiver Riddle 
3942e2c0738SRiver Riddle     // Insert a new temporary argument with the new memref type.
3952e2c0738SRiver Riddle     BlockArgument newMemRef = funcOp.front().insertArgument(
3962e2c0738SRiver Riddle         argIndex, newMemRefType, functionArgLocs[argIndex]);
3972e2c0738SRiver Riddle     BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
3982e2c0738SRiver Riddle     AffineMap layoutMap = memrefType.getLayout().getAffineMap();
3992e2c0738SRiver Riddle     // Replace all uses of the old memref.
4002e2c0738SRiver Riddle     if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
4012e2c0738SRiver Riddle                                         /*extraIndices=*/{},
4022e2c0738SRiver Riddle                                         /*indexRemap=*/layoutMap,
4032e2c0738SRiver Riddle                                         /*extraOperands=*/{},
4042e2c0738SRiver Riddle                                         /*symbolOperands=*/{},
4052e2c0738SRiver Riddle                                         /*domOpFilter=*/nullptr,
4062e2c0738SRiver Riddle                                         /*postDomOpFilter=*/nullptr,
4072e2c0738SRiver Riddle                                         /*allowNonDereferencingOps=*/true,
4082e2c0738SRiver Riddle                                         /*replaceInDeallocOp=*/true))) {
4092e2c0738SRiver Riddle       // If it failed (due to escapes for example), bail out. Removing the
4102e2c0738SRiver Riddle       // temporary argument inserted previously.
4112e2c0738SRiver Riddle       funcOp.front().eraseArgument(argIndex);
4122e2c0738SRiver Riddle       continue;
4132e2c0738SRiver Riddle     }
4142e2c0738SRiver Riddle 
4152e2c0738SRiver Riddle     // All uses for the argument with old memref type were replaced
4162e2c0738SRiver Riddle     // successfully. So we remove the old argument now.
4172e2c0738SRiver Riddle     funcOp.front().eraseArgument(argIndex + 1);
4182e2c0738SRiver Riddle   }
4192e2c0738SRiver Riddle 
4202e2c0738SRiver Riddle   // Walk over normalizable operations to normalize memrefs of the operation
4212e2c0738SRiver Riddle   // results. When `op` has memrefs with affine map in the operation results,
4222e2c0738SRiver Riddle   // new operation containin normalized memrefs is created. Then, the memrefs
4232e2c0738SRiver Riddle   // are replaced. `CallOp` is skipped here because it is handled in
4242e2c0738SRiver Riddle   // `updateFunctionSignature()`.
4252e2c0738SRiver Riddle   funcOp.walk([&](Operation *op) {
4262e2c0738SRiver Riddle     if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
42723aa5a74SRiver Riddle         op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
42823aa5a74SRiver Riddle         !funcOp.isExternal()) {
4292e2c0738SRiver Riddle       // Create newOp containing normalized memref in the operation result.
4302e2c0738SRiver Riddle       Operation *newOp = createOpResultsNormalized(funcOp, op);
4312e2c0738SRiver Riddle       // When all of the operation results have no memrefs or memrefs without
4322e2c0738SRiver Riddle       // affine map, `newOp` is the same with `op` and following process is
4332e2c0738SRiver Riddle       // skipped.
4342e2c0738SRiver Riddle       if (op != newOp) {
4352e2c0738SRiver Riddle         bool replacingMemRefUsesFailed = false;
4362e2c0738SRiver Riddle         for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
4372e2c0738SRiver Riddle           // Replace all uses of the old memrefs.
4382e2c0738SRiver Riddle           Value oldMemRef = op->getResult(resIndex);
4392e2c0738SRiver Riddle           Value newMemRef = newOp->getResult(resIndex);
4405550c821STres Popp           MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
4412e2c0738SRiver Riddle           // Check whether the operation result is MemRef type.
4422e2c0738SRiver Riddle           if (!oldMemRefType)
4432e2c0738SRiver Riddle             continue;
4445550c821STres Popp           MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
4452e2c0738SRiver Riddle           if (oldMemRefType == newMemRefType)
4462e2c0738SRiver Riddle             continue;
4472e2c0738SRiver Riddle           // TODO: Assume single layout map. Multiple maps not supported.
4482e2c0738SRiver Riddle           AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
4492e2c0738SRiver Riddle           if (failed(replaceAllMemRefUsesWith(oldMemRef,
4502e2c0738SRiver Riddle                                               /*newMemRef=*/newMemRef,
4512e2c0738SRiver Riddle                                               /*extraIndices=*/{},
4522e2c0738SRiver Riddle                                               /*indexRemap=*/layoutMap,
4532e2c0738SRiver Riddle                                               /*extraOperands=*/{},
4542e2c0738SRiver Riddle                                               /*symbolOperands=*/{},
4552e2c0738SRiver Riddle                                               /*domOpFilter=*/nullptr,
4562e2c0738SRiver Riddle                                               /*postDomOpFilter=*/nullptr,
4572e2c0738SRiver Riddle                                               /*allowNonDereferencingOps=*/true,
4582e2c0738SRiver Riddle                                               /*replaceInDeallocOp=*/true))) {
4592e2c0738SRiver Riddle             newOp->erase();
4602e2c0738SRiver Riddle             replacingMemRefUsesFailed = true;
4612e2c0738SRiver Riddle             continue;
4622e2c0738SRiver Riddle           }
4632e2c0738SRiver Riddle         }
4642e2c0738SRiver Riddle         if (!replacingMemRefUsesFailed) {
4652e2c0738SRiver Riddle           // Replace other ops with new op and delete the old op when the
4662e2c0738SRiver Riddle           // replacement succeeded.
4672e2c0738SRiver Riddle           op->replaceAllUsesWith(newOp);
4682e2c0738SRiver Riddle           op->erase();
4692e2c0738SRiver Riddle         }
4702e2c0738SRiver Riddle       }
4712e2c0738SRiver Riddle     }
4722e2c0738SRiver Riddle   });
4732e2c0738SRiver Riddle 
4742e2c0738SRiver Riddle   // In a normal function, memrefs in the return type signature gets normalized
4752e2c0738SRiver Riddle   // as a result of normalization of functions arguments, AllocOps or CallOps'
4762e2c0738SRiver Riddle   // result types. Since an external function doesn't have a body, memrefs in
4772e2c0738SRiver Riddle   // the return type signature can only get normalized by iterating over the
4782e2c0738SRiver Riddle   // individual return types.
4792e2c0738SRiver Riddle   if (funcOp.isExternal()) {
4802e2c0738SRiver Riddle     SmallVector<Type, 4> resultTypes;
4812e2c0738SRiver Riddle     for (unsigned resIndex :
4822e2c0738SRiver Riddle          llvm::seq<unsigned>(0, functionType.getNumResults())) {
4832e2c0738SRiver Riddle       Type resType = functionType.getResult(resIndex);
4845550c821STres Popp       MemRefType memrefType = dyn_cast<MemRefType>(resType);
4852e2c0738SRiver Riddle       // Check whether result is of MemRef type. Any other argument type can
4862e2c0738SRiver Riddle       // simply be part of the final function signature.
4872e2c0738SRiver Riddle       if (!memrefType) {
4882e2c0738SRiver Riddle         resultTypes.push_back(resType);
4892e2c0738SRiver Riddle         continue;
4902e2c0738SRiver Riddle       }
4912e2c0738SRiver Riddle       // Computing a new memref type after normalizing the old memref to have an
4922e2c0738SRiver Riddle       // identity map layout.
4931fee821dSKai Sasaki       MemRefType newMemRefType = normalizeMemRefType(memrefType);
4942e2c0738SRiver Riddle       resultTypes.push_back(newMemRefType);
4952e2c0738SRiver Riddle     }
4962e2c0738SRiver Riddle 
4972e2c0738SRiver Riddle     FunctionType newFuncType =
4982e2c0738SRiver Riddle         FunctionType::get(&getContext(), /*inputs=*/inputTypes,
4992e2c0738SRiver Riddle                           /*results=*/resultTypes);
5002e2c0738SRiver Riddle     // Setting the new function signature for this external function.
5012e2c0738SRiver Riddle     funcOp.setType(newFuncType);
5022e2c0738SRiver Riddle   }
5032e2c0738SRiver Riddle   updateFunctionSignature(funcOp, moduleOp);
5042e2c0738SRiver Riddle }
5052e2c0738SRiver Riddle 
5062e2c0738SRiver Riddle /// Create an operation containing normalized memrefs in the operation results.
5072e2c0738SRiver Riddle /// When the results of `oldOp` have memrefs with affine map, the memrefs are
5082e2c0738SRiver Riddle /// normalized, and new operation containing them in the operation results is
5092e2c0738SRiver Riddle /// returned. If all of the results of `oldOp` have no memrefs or memrefs
5102e2c0738SRiver Riddle /// without affine map, `oldOp` is returned without modification.
51158ceae95SRiver Riddle Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
5122e2c0738SRiver Riddle                                                        Operation *oldOp) {
5132e2c0738SRiver Riddle   // Prepare OperationState to create newOp containing normalized memref in
5142e2c0738SRiver Riddle   // the operation results.
5152e2c0738SRiver Riddle   OperationState result(oldOp->getLoc(), oldOp->getName());
5162e2c0738SRiver Riddle   result.addOperands(oldOp->getOperands());
5172e2c0738SRiver Riddle   result.addAttributes(oldOp->getAttrs());
5182e2c0738SRiver Riddle   // Add normalized MemRefType to the OperationState.
5192e2c0738SRiver Riddle   SmallVector<Type, 4> resultTypes;
5202e2c0738SRiver Riddle   OpBuilder b(funcOp);
5212e2c0738SRiver Riddle   bool resultTypeNormalized = false;
5222e2c0738SRiver Riddle   for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
5232e2c0738SRiver Riddle     auto resultType = oldOp->getResult(resIndex).getType();
5245550c821STres Popp     MemRefType memrefType = dyn_cast<MemRefType>(resultType);
5252e2c0738SRiver Riddle     // Check whether the operation result is MemRef type.
5262e2c0738SRiver Riddle     if (!memrefType) {
5272e2c0738SRiver Riddle       resultTypes.push_back(resultType);
5282e2c0738SRiver Riddle       continue;
5292e2c0738SRiver Riddle     }
5301fee821dSKai Sasaki 
5312e2c0738SRiver Riddle     // Fetch a new memref type after normalizing the old memref.
5321fee821dSKai Sasaki     MemRefType newMemRefType = normalizeMemRefType(memrefType);
5332e2c0738SRiver Riddle     if (newMemRefType == memrefType) {
5342e2c0738SRiver Riddle       // Either memrefType already had an identity map or the map couldn't
5352e2c0738SRiver Riddle       // be transformed to an identity map.
5362e2c0738SRiver Riddle       resultTypes.push_back(memrefType);
5372e2c0738SRiver Riddle       continue;
5382e2c0738SRiver Riddle     }
5392e2c0738SRiver Riddle     resultTypes.push_back(newMemRefType);
5402e2c0738SRiver Riddle     resultTypeNormalized = true;
5412e2c0738SRiver Riddle   }
5422e2c0738SRiver Riddle   result.addTypes(resultTypes);
5432e2c0738SRiver Riddle   // When all of the results of `oldOp` have no memrefs or memrefs without
5442e2c0738SRiver Riddle   // affine map, `oldOp` is returned without modification.
5452e2c0738SRiver Riddle   if (resultTypeNormalized) {
5462e2c0738SRiver Riddle     OpBuilder bb(oldOp);
5472e2c0738SRiver Riddle     for (auto &oldRegion : oldOp->getRegions()) {
5482e2c0738SRiver Riddle       Region *newRegion = result.addRegion();
5492e2c0738SRiver Riddle       newRegion->takeBody(oldRegion);
5502e2c0738SRiver Riddle     }
55114ecafd0SChia-hung Duan     return bb.create(result);
5522e2c0738SRiver Riddle   }
5532e2c0738SRiver Riddle   return oldOp;
5542e2c0738SRiver Riddle }
555