xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp (revision 2ec27848c00cda734697619047e640eadb254555)
1 //===- NormalizeMemRefs.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an interprocedural pass to normalize memrefs to have
10 // identity layout maps.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/Support/Debug.h"
21 
22 namespace mlir {
23 namespace memref {
24 #define GEN_PASS_DEF_NORMALIZEMEMREFS
25 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
26 } // namespace memref
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "normalize-memrefs"
30 
31 using namespace mlir;
32 using namespace mlir::affine;
33 
34 namespace {
35 
36 /// All memrefs passed across functions with non-trivial layout maps are
37 /// converted to ones with trivial identity layout ones.
38 /// If all the memref types/uses in a function are normalizable, we treat
39 /// such functions as normalizable. Also, if a normalizable function is known
40 /// to call a non-normalizable function, we treat that function as
41 /// non-normalizable as well. We assume external functions to be normalizable.
42 struct NormalizeMemRefs
43     : public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
44   void runOnOperation() override;
45   void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
46   bool areMemRefsNormalizable(func::FuncOp funcOp);
47   void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
48   void setCalleesAndCallersNonNormalizable(
49       func::FuncOp funcOp, ModuleOp moduleOp,
50       DenseSet<func::FuncOp> &normalizableFuncs);
51   Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
52 };
53 
54 } // namespace
55 
56 std::unique_ptr<OperationPass<ModuleOp>>
57 mlir::memref::createNormalizeMemRefsPass() {
58   return std::make_unique<NormalizeMemRefs>();
59 }
60 
61 void NormalizeMemRefs::runOnOperation() {
62   LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
63   ModuleOp moduleOp = getOperation();
64   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
65   // with all the functions within a module and then functions which are not
66   // normalizable are removed from this set.
67   // TODO: Change this to work on FuncLikeOp once there is an operation
68   // interface for it.
69   DenseSet<func::FuncOp> normalizableFuncs;
70   // Initialize `normalizableFuncs` with all the functions within a module.
71   moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
72 
73   // Traverse through all the functions applying a filter which determines
74   // whether that function is normalizable or not. All callers/callees of
75   // a non-normalizable function will also become non-normalizable even if
76   // they aren't passing any or specific non-normalizable memrefs. So,
77   // functions which calls or get called by a non-normalizable becomes non-
78   // normalizable functions themselves.
79   moduleOp.walk([&](func::FuncOp funcOp) {
80     if (normalizableFuncs.contains(funcOp)) {
81       if (!areMemRefsNormalizable(funcOp)) {
82         LLVM_DEBUG(llvm::dbgs()
83                    << "@" << funcOp.getName()
84                    << " contains ops that cannot normalize MemRefs\n");
85         // Since this function is not normalizable, we set all the caller
86         // functions and the callees of this function as not normalizable.
87         // TODO: Drop this conservative assumption in the future.
88         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
89                                             normalizableFuncs);
90       }
91     }
92   });
93 
94   LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
95                           << " functions\n");
96   // Those functions which can be normalized are subjected to normalization.
97   for (func::FuncOp &funcOp : normalizableFuncs)
98     normalizeFuncOpMemRefs(funcOp, moduleOp);
99 }
100 
101 /// Check whether all the uses of oldMemRef are either dereferencing uses or the
102 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
103 /// are satisfied will the value become a candidate for replacement.
104 /// TODO: Extend this for DimOps.
105 static bool isMemRefNormalizable(Value::user_range opUsers) {
106   return llvm::all_of(opUsers, [](Operation *op) {
107     return op->hasTrait<OpTrait::MemRefsNormalizable>();
108   });
109 }
110 
111 /// Set all the calling functions and the callees of the function as not
112 /// normalizable.
113 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
114     func::FuncOp funcOp, ModuleOp moduleOp,
115     DenseSet<func::FuncOp> &normalizableFuncs) {
116   if (!normalizableFuncs.contains(funcOp))
117     return;
118 
119   LLVM_DEBUG(
120       llvm::dbgs() << "@" << funcOp.getName()
121                    << " calls or is called by non-normalizable function\n");
122   normalizableFuncs.erase(funcOp);
123   // Caller of the function.
124   std::optional<SymbolTable::UseRange> symbolUses =
125       funcOp.getSymbolUses(moduleOp);
126   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
127     // TODO: Extend this for ops that are FunctionOpInterface. This would
128     // require creating an OpInterface for FunctionOpInterface ops.
129     func::FuncOp parentFuncOp =
130         symbolUse.getUser()->getParentOfType<func::FuncOp>();
131     for (func::FuncOp &funcOp : normalizableFuncs) {
132       if (parentFuncOp == funcOp) {
133         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
134                                             normalizableFuncs);
135         break;
136       }
137     }
138   }
139 
140   // Functions called by this function.
141   funcOp.walk([&](func::CallOp callOp) {
142     StringAttr callee = callOp.getCalleeAttr().getAttr();
143     for (func::FuncOp &funcOp : normalizableFuncs) {
144       // We compare func::FuncOp and callee's name.
145       if (callee == funcOp.getNameAttr()) {
146         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
147                                             normalizableFuncs);
148         break;
149       }
150     }
151   });
152 }
153 
154 /// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
155 /// arguments of a function are either of dereferencing type or are uses in:
156 /// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
157 /// the function become a candidate for normalization. When the uses of a memref
158 /// are non-normalizable and the memref map layout is trivial (identity), we can
159 /// still label the entire function as normalizable. We assume external
160 /// functions to be normalizable.
161 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
162   // We assume external functions to be normalizable.
163   if (funcOp.isExternal())
164     return true;
165 
166   if (funcOp
167           .walk([&](memref::AllocOp allocOp) -> WalkResult {
168             Value oldMemRef = allocOp.getResult();
169             if (!allocOp.getType().getLayout().isIdentity() &&
170                 !isMemRefNormalizable(oldMemRef.getUsers()))
171               return WalkResult::interrupt();
172             return WalkResult::advance();
173           })
174           .wasInterrupted())
175     return false;
176 
177   if (funcOp
178           .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
179             Value oldMemRef = allocaOp.getResult();
180             if (!allocaOp.getType().getLayout().isIdentity() &&
181                 !isMemRefNormalizable(oldMemRef.getUsers()))
182               return WalkResult::interrupt();
183             return WalkResult::advance();
184           })
185           .wasInterrupted())
186     return false;
187 
188   if (funcOp
189           .walk([&](func::CallOp callOp) -> WalkResult {
190             for (unsigned resIndex :
191                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
192               Value oldMemRef = callOp.getResult(resIndex);
193               if (auto oldMemRefType =
194                       dyn_cast<MemRefType>(oldMemRef.getType()))
195                 if (!oldMemRefType.getLayout().isIdentity() &&
196                     !isMemRefNormalizable(oldMemRef.getUsers()))
197                   return WalkResult::interrupt();
198             }
199             return WalkResult::advance();
200           })
201           .wasInterrupted())
202     return false;
203 
204   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
205     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
206     if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
207       if (!oldMemRefType.getLayout().isIdentity() &&
208           !isMemRefNormalizable(oldMemRef.getUsers()))
209         return false;
210   }
211 
212   return true;
213 }
214 
215 /// Fetch the updated argument list and result of the function and update the
216 /// function signature. This updates the function's return type at the caller
217 /// site and in case the return type is a normalized memref then it updates
218 /// the calling function's signature.
219 /// TODO: An update to the calling function signature is required only if the
220 /// returned value is in turn used in ReturnOp of the calling function.
221 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
222                                                ModuleOp moduleOp) {
223   FunctionType functionType = funcOp.getFunctionType();
224   SmallVector<Type, 4> resultTypes;
225   FunctionType newFuncType;
226   resultTypes = llvm::to_vector<4>(functionType.getResults());
227 
228   // External function's signature was already updated in
229   // 'normalizeFuncOpMemRefs()'.
230   if (!funcOp.isExternal()) {
231     SmallVector<Type, 8> argTypes;
232     for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
233       argTypes.push_back(argEn.value().getType());
234 
235     // Traverse ReturnOps to check if an update to the return type in the
236     // function signature is required.
237     funcOp.walk([&](func::ReturnOp returnOp) {
238       for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
239         Type opType = operandEn.value().getType();
240         MemRefType memrefType = dyn_cast<MemRefType>(opType);
241         // If type is not memref or if the memref type is same as that in
242         // function's return signature then no update is required.
243         if (!memrefType || memrefType == resultTypes[operandEn.index()])
244           continue;
245         // Update function's return type signature.
246         // Return type gets normalized either as a result of function argument
247         // normalization, AllocOp normalization or an update made at CallOp.
248         // There can be many call flows inside a function and an update to a
249         // specific ReturnOp has not yet been made. So we check that the result
250         // memref type is normalized.
251         // TODO: When selective normalization is implemented, handle multiple
252         // results case where some are normalized, some aren't.
253         if (memrefType.getLayout().isIdentity())
254           resultTypes[operandEn.index()] = memrefType;
255       }
256     });
257 
258     // We create a new function type and modify the function signature with this
259     // new type.
260     newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
261                                     /*results=*/resultTypes);
262   }
263 
264   // Since we update the function signature, it might affect the result types at
265   // the caller site. Since this result might even be used by the caller
266   // function in ReturnOps, the caller function's signature will also change.
267   // Hence we record the caller function in 'funcOpsToUpdate' to update their
268   // signature as well.
269   llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
270   // We iterate over all symbolic uses of the function and update the return
271   // type at the caller site.
272   std::optional<SymbolTable::UseRange> symbolUses =
273       funcOp.getSymbolUses(moduleOp);
274   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
275     Operation *userOp = symbolUse.getUser();
276     OpBuilder builder(userOp);
277     // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
278     // that the non-CallOp has no memrefs to be replaced.
279     // TODO: Handle cases where a non-CallOp symbol use of a function deals with
280     // memrefs.
281     auto callOp = dyn_cast<func::CallOp>(userOp);
282     if (!callOp)
283       continue;
284     Operation *newCallOp =
285         builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
286                                      resultTypes, userOp->getOperands());
287     bool replacingMemRefUsesFailed = false;
288     bool returnTypeChanged = false;
289     for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
290       OpResult oldResult = userOp->getResult(resIndex);
291       OpResult newResult = newCallOp->getResult(resIndex);
292       // This condition ensures that if the result is not of type memref or if
293       // the resulting memref was already having a trivial map layout then we
294       // need not perform any use replacement here.
295       if (oldResult.getType() == newResult.getType())
296         continue;
297       AffineMap layoutMap =
298           cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
299       if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
300                                           /*extraIndices=*/{},
301                                           /*indexRemap=*/layoutMap,
302                                           /*extraOperands=*/{},
303                                           /*symbolOperands=*/{},
304                                           /*domOpFilter=*/nullptr,
305                                           /*postDomOpFilter=*/nullptr,
306                                           /*allowNonDereferencingOps=*/true,
307                                           /*replaceInDeallocOp=*/true))) {
308         // If it failed (due to escapes for example), bail out.
309         // It should never hit this part of the code because it is called by
310         // only those functions which are normalizable.
311         newCallOp->erase();
312         replacingMemRefUsesFailed = true;
313         break;
314       }
315       returnTypeChanged = true;
316     }
317     if (replacingMemRefUsesFailed)
318       continue;
319     // Replace all uses for other non-memref result types.
320     userOp->replaceAllUsesWith(newCallOp);
321     userOp->erase();
322     if (returnTypeChanged) {
323       // Since the return type changed it might lead to a change in function's
324       // signature.
325       // TODO: If funcOp doesn't return any memref type then no need to update
326       // signature.
327       // TODO: Further optimization - Check if the memref is indeed part of
328       // ReturnOp at the parentFuncOp and only then updation of signature is
329       // required.
330       // TODO: Extend this for ops that are FunctionOpInterface. This would
331       // require creating an OpInterface for FunctionOpInterface ops.
332       func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
333       funcOpsToUpdate.insert(parentFuncOp);
334     }
335   }
336   // Because external function's signature is already updated in
337   // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
338   if (!funcOp.isExternal())
339     funcOp.setType(newFuncType);
340 
341   // Updating the signature type of those functions which call the current
342   // function. Only if the return type of the current function has a normalized
343   // memref will the caller function become a candidate for signature update.
344   for (func::FuncOp parentFuncOp : funcOpsToUpdate)
345     updateFunctionSignature(parentFuncOp, moduleOp);
346 }
347 
348 /// Normalizes the memrefs within a function which includes those arising as a
349 /// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
350 /// argument is used to help update function's signature after normalization.
351 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
352                                               ModuleOp moduleOp) {
353   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
354   // alloc/alloca ops first and then process since normalizeMemRef
355   // replaces/erases ops during memref rewriting.
356   SmallVector<memref::AllocOp, 4> allocOps;
357   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
358   for (memref::AllocOp allocOp : allocOps)
359     (void)normalizeMemRef(&allocOp);
360 
361   SmallVector<memref::AllocaOp> allocaOps;
362   funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
363   for (memref::AllocaOp allocaOp : allocaOps)
364     (void)normalizeMemRef(&allocaOp);
365 
366   // We use this OpBuilder to create new memref layout later.
367   OpBuilder b(funcOp);
368 
369   FunctionType functionType = funcOp.getFunctionType();
370   SmallVector<Location> functionArgLocs(llvm::map_range(
371       funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
372   SmallVector<Type, 8> inputTypes;
373   // Walk over each argument of a function to perform memref normalization (if
374   for (unsigned argIndex :
375        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
376     Type argType = functionType.getInput(argIndex);
377     MemRefType memrefType = dyn_cast<MemRefType>(argType);
378     // Check whether argument is of MemRef type. Any other argument type can
379     // simply be part of the final function signature.
380     if (!memrefType) {
381       inputTypes.push_back(argType);
382       continue;
383     }
384     // Fetch a new memref type after normalizing the old memref to have an
385     // identity map layout.
386     MemRefType newMemRefType = normalizeMemRefType(memrefType);
387     if (newMemRefType == memrefType || funcOp.isExternal()) {
388       // Either memrefType already had an identity map or the map couldn't be
389       // transformed to an identity map.
390       inputTypes.push_back(newMemRefType);
391       continue;
392     }
393 
394     // Insert a new temporary argument with the new memref type.
395     BlockArgument newMemRef = funcOp.front().insertArgument(
396         argIndex, newMemRefType, functionArgLocs[argIndex]);
397     BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
398     AffineMap layoutMap = memrefType.getLayout().getAffineMap();
399     // Replace all uses of the old memref.
400     if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
401                                         /*extraIndices=*/{},
402                                         /*indexRemap=*/layoutMap,
403                                         /*extraOperands=*/{},
404                                         /*symbolOperands=*/{},
405                                         /*domOpFilter=*/nullptr,
406                                         /*postDomOpFilter=*/nullptr,
407                                         /*allowNonDereferencingOps=*/true,
408                                         /*replaceInDeallocOp=*/true))) {
409       // If it failed (due to escapes for example), bail out. Removing the
410       // temporary argument inserted previously.
411       funcOp.front().eraseArgument(argIndex);
412       continue;
413     }
414 
415     // All uses for the argument with old memref type were replaced
416     // successfully. So we remove the old argument now.
417     funcOp.front().eraseArgument(argIndex + 1);
418   }
419 
420   // Walk over normalizable operations to normalize memrefs of the operation
421   // results. When `op` has memrefs with affine map in the operation results,
422   // new operation containin normalized memrefs is created. Then, the memrefs
423   // are replaced. `CallOp` is skipped here because it is handled in
424   // `updateFunctionSignature()`.
425   funcOp.walk([&](Operation *op) {
426     if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
427         op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
428         !funcOp.isExternal()) {
429       // Create newOp containing normalized memref in the operation result.
430       Operation *newOp = createOpResultsNormalized(funcOp, op);
431       // When all of the operation results have no memrefs or memrefs without
432       // affine map, `newOp` is the same with `op` and following process is
433       // skipped.
434       if (op != newOp) {
435         bool replacingMemRefUsesFailed = false;
436         for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
437           // Replace all uses of the old memrefs.
438           Value oldMemRef = op->getResult(resIndex);
439           Value newMemRef = newOp->getResult(resIndex);
440           MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
441           // Check whether the operation result is MemRef type.
442           if (!oldMemRefType)
443             continue;
444           MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
445           if (oldMemRefType == newMemRefType)
446             continue;
447           // TODO: Assume single layout map. Multiple maps not supported.
448           AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
449           if (failed(replaceAllMemRefUsesWith(oldMemRef,
450                                               /*newMemRef=*/newMemRef,
451                                               /*extraIndices=*/{},
452                                               /*indexRemap=*/layoutMap,
453                                               /*extraOperands=*/{},
454                                               /*symbolOperands=*/{},
455                                               /*domOpFilter=*/nullptr,
456                                               /*postDomOpFilter=*/nullptr,
457                                               /*allowNonDereferencingOps=*/true,
458                                               /*replaceInDeallocOp=*/true))) {
459             newOp->erase();
460             replacingMemRefUsesFailed = true;
461             continue;
462           }
463         }
464         if (!replacingMemRefUsesFailed) {
465           // Replace other ops with new op and delete the old op when the
466           // replacement succeeded.
467           op->replaceAllUsesWith(newOp);
468           op->erase();
469         }
470       }
471     }
472   });
473 
474   // In a normal function, memrefs in the return type signature gets normalized
475   // as a result of normalization of functions arguments, AllocOps or CallOps'
476   // result types. Since an external function doesn't have a body, memrefs in
477   // the return type signature can only get normalized by iterating over the
478   // individual return types.
479   if (funcOp.isExternal()) {
480     SmallVector<Type, 4> resultTypes;
481     for (unsigned resIndex :
482          llvm::seq<unsigned>(0, functionType.getNumResults())) {
483       Type resType = functionType.getResult(resIndex);
484       MemRefType memrefType = dyn_cast<MemRefType>(resType);
485       // Check whether result is of MemRef type. Any other argument type can
486       // simply be part of the final function signature.
487       if (!memrefType) {
488         resultTypes.push_back(resType);
489         continue;
490       }
491       // Computing a new memref type after normalizing the old memref to have an
492       // identity map layout.
493       MemRefType newMemRefType = normalizeMemRefType(memrefType);
494       resultTypes.push_back(newMemRefType);
495     }
496 
497     FunctionType newFuncType =
498         FunctionType::get(&getContext(), /*inputs=*/inputTypes,
499                           /*results=*/resultTypes);
500     // Setting the new function signature for this external function.
501     funcOp.setType(newFuncType);
502   }
503   updateFunctionSignature(funcOp, moduleOp);
504 }
505 
506 /// Create an operation containing normalized memrefs in the operation results.
507 /// When the results of `oldOp` have memrefs with affine map, the memrefs are
508 /// normalized, and new operation containing them in the operation results is
509 /// returned. If all of the results of `oldOp` have no memrefs or memrefs
510 /// without affine map, `oldOp` is returned without modification.
511 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
512                                                        Operation *oldOp) {
513   // Prepare OperationState to create newOp containing normalized memref in
514   // the operation results.
515   OperationState result(oldOp->getLoc(), oldOp->getName());
516   result.addOperands(oldOp->getOperands());
517   result.addAttributes(oldOp->getAttrs());
518   // Add normalized MemRefType to the OperationState.
519   SmallVector<Type, 4> resultTypes;
520   OpBuilder b(funcOp);
521   bool resultTypeNormalized = false;
522   for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
523     auto resultType = oldOp->getResult(resIndex).getType();
524     MemRefType memrefType = dyn_cast<MemRefType>(resultType);
525     // Check whether the operation result is MemRef type.
526     if (!memrefType) {
527       resultTypes.push_back(resultType);
528       continue;
529     }
530 
531     // Fetch a new memref type after normalizing the old memref.
532     MemRefType newMemRefType = normalizeMemRefType(memrefType);
533     if (newMemRefType == memrefType) {
534       // Either memrefType already had an identity map or the map couldn't
535       // be transformed to an identity map.
536       resultTypes.push_back(memrefType);
537       continue;
538     }
539     resultTypes.push_back(newMemRefType);
540     resultTypeNormalized = true;
541   }
542   result.addTypes(resultTypes);
543   // When all of the results of `oldOp` have no memrefs or memrefs without
544   // affine map, `oldOp` is returned without modification.
545   if (resultTypeNormalized) {
546     OpBuilder bb(oldOp);
547     for (auto &oldRegion : oldOp->getRegions()) {
548       Region *newRegion = result.addRegion();
549       newRegion->takeBody(oldRegion);
550     }
551     return bb.create(result);
552   }
553   return oldOp;
554 }
555