xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (revision 30916b6942371fc314f3ce1bfa4042cae3e6ff28)
//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
876100870SMatthias Springer 
976100870SMatthias Springer #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1076100870SMatthias Springer 
1176100870SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1276100870SMatthias Springer #include "mlir/Dialect/Affine/Transforms/Transforms.h"
1376100870SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1476100870SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
1576100870SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
1676100870SMatthias Springer 
1776100870SMatthias Springer using namespace mlir;
1876100870SMatthias Springer using namespace mlir::memref;
1976100870SMatthias Springer 
2076100870SMatthias Springer /// Make the given OpFoldResult independent of all independencies.
2176100870SMatthias Springer static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
2276100870SMatthias Springer                                                OpFoldResult ofr,
2376100870SMatthias Springer                                                ValueRange independencies) {
24*30916b69SKazu Hirata   if (isa<Attribute>(ofr))
2576100870SMatthias Springer     return ofr;
2676100870SMatthias Springer   AffineMap boundMap;
2776100870SMatthias Springer   ValueDimList mapOperands;
2876100870SMatthias Springer   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
2940dd3aa9SMatthias Springer           boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
3040dd3aa9SMatthias Springer           /*closedUB=*/true)))
3176100870SMatthias Springer     return failure();
3276100870SMatthias Springer   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
3376100870SMatthias Springer }
3476100870SMatthias Springer 
3576100870SMatthias Springer FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
3676100870SMatthias Springer                                             memref::AllocaOp allocaOp,
3776100870SMatthias Springer                                             ValueRange independencies) {
3876100870SMatthias Springer   OpBuilder::InsertionGuard g(b);
3976100870SMatthias Springer   b.setInsertionPoint(allocaOp);
4076100870SMatthias Springer   Location loc = allocaOp.getLoc();
4176100870SMatthias Springer 
4276100870SMatthias Springer   SmallVector<OpFoldResult> newSizes;
4376100870SMatthias Springer   for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
4476100870SMatthias Springer     auto ub = makeIndependent(b, loc, ofr, independencies);
4576100870SMatthias Springer     if (failed(ub))
4676100870SMatthias Springer       return failure();
4776100870SMatthias Springer     newSizes.push_back(*ub);
4876100870SMatthias Springer   }
4976100870SMatthias Springer 
5076100870SMatthias Springer   // Return existing memref::AllocaOp if nothing has changed.
5176100870SMatthias Springer   if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
5276100870SMatthias Springer     return allocaOp.getResult();
5376100870SMatthias Springer 
5476100870SMatthias Springer   // Create a new memref::AllocaOp.
5576100870SMatthias Springer   Value newAllocaOp =
5676100870SMatthias Springer       b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());
5776100870SMatthias Springer 
5876100870SMatthias Springer   // Create a memref::SubViewOp.
5976100870SMatthias Springer   SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
6076100870SMatthias Springer   SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
6176100870SMatthias Springer   return b
6276100870SMatthias Springer       .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
6376100870SMatthias Springer                          strides)
6476100870SMatthias Springer       .getResult();
6576100870SMatthias Springer }
6676100870SMatthias Springer 
6776100870SMatthias Springer /// Push down an UnrealizedConversionCastOp past a SubViewOp.
6876100870SMatthias Springer static UnrealizedConversionCastOp
6976100870SMatthias Springer propagateSubViewOp(RewriterBase &rewriter,
7076100870SMatthias Springer                    UnrealizedConversionCastOp conversionOp, SubViewOp op) {
7176100870SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
7276100870SMatthias Springer   rewriter.setInsertionPoint(op);
735550c821STres Popp   auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
7476100870SMatthias Springer       op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
755550c821STres Popp       op.getMixedSizes(), op.getMixedStrides()));
7676100870SMatthias Springer   Value newSubview = rewriter.create<SubViewOp>(
7776100870SMatthias Springer       op.getLoc(), newResultType, conversionOp.getOperand(0),
7876100870SMatthias Springer       op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
7976100870SMatthias Springer   auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
8076100870SMatthias Springer       op.getLoc(), op.getType(), newSubview);
8176100870SMatthias Springer   rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
8276100870SMatthias Springer   return newConversionOp;
8376100870SMatthias Springer }
8476100870SMatthias Springer 
8576100870SMatthias Springer /// Given an original op and a new, modified op with the same number of results,
8676100870SMatthias Springer /// whose memref return types may differ, replace all uses of the original op
8776100870SMatthias Springer /// with the new op and propagate the new memref types through the IR.
8876100870SMatthias Springer ///
8976100870SMatthias Springer /// Example:
9076100870SMatthias Springer /// %from = memref.alloca(%sz) : memref<?xf32>
9176100870SMatthias Springer /// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
9276100870SMatthias Springer /// memref.store %cst, %from[%c0] : memref<?xf32>
9376100870SMatthias Springer ///
9476100870SMatthias Springer /// In the above example, all uses of %from are replaced with %to. This can be
9576100870SMatthias Springer /// done directly for ops such as memref.store. For ops that have memref results
9676100870SMatthias Springer /// (e.g., memref.subview), the result type may depend on the operand type, so
9776100870SMatthias Springer /// we cannot just replace all uses. There is special handling for common memref
9876100870SMatthias Springer /// ops. For all other ops, unrealized_conversion_cast is inserted.
9976100870SMatthias Springer static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
10076100870SMatthias Springer                                           Operation *from, Operation *to) {
10176100870SMatthias Springer   assert(from->getNumResults() == to->getNumResults() &&
10276100870SMatthias Springer          "expected same number of results");
10376100870SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
10476100870SMatthias Springer   rewriter.setInsertionPointAfter(to);
10576100870SMatthias Springer 
10676100870SMatthias Springer   // Wrap new results in unrealized_conversion_cast and replace all uses of the
10776100870SMatthias Springer   // original op.
10876100870SMatthias Springer   SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
10976100870SMatthias Springer   for (const auto &it :
11076100870SMatthias Springer        llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
11176100870SMatthias Springer     unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
11276100870SMatthias Springer         to->getLoc(), std::get<0>(it.value()).getType(),
11376100870SMatthias Springer         std::get<1>(it.value())));
11476100870SMatthias Springer     rewriter.replaceAllUsesWith(from->getResult(it.index()),
11576100870SMatthias Springer                                 unrealizedConversions.back()->getResult(0));
11676100870SMatthias Springer   }
11776100870SMatthias Springer 
11876100870SMatthias Springer   // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
11976100870SMatthias Springer   // wrap results instead of operands in a cast.
12076100870SMatthias Springer   for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
12176100870SMatthias Springer     UnrealizedConversionCastOp conversion = unrealizedConversions[i];
12276100870SMatthias Springer     assert(conversion->getNumOperands() == 1 &&
12376100870SMatthias Springer            conversion->getNumResults() == 1 &&
12476100870SMatthias Springer            "expected single operand and single result");
12576100870SMatthias Springer     SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
12676100870SMatthias Springer     for (Operation *user : users) {
12776100870SMatthias Springer       // Handle common memref dialect ops that produce new memrefs and must
12876100870SMatthias Springer       // be recreated with the new result type.
12976100870SMatthias Springer       if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
13076100870SMatthias Springer         unrealizedConversions.push_back(
13176100870SMatthias Springer             propagateSubViewOp(rewriter, conversion, subviewOp));
13276100870SMatthias Springer         continue;
13376100870SMatthias Springer       }
13476100870SMatthias Springer 
13576100870SMatthias Springer       // TODO: Other memref ops such as memref.collapse_shape/expand_shape
13676100870SMatthias Springer       // should also be handled here.
13776100870SMatthias Springer 
13876100870SMatthias Springer       // Skip any ops that produce MemRef result or have MemRef region block
13976100870SMatthias Springer       // arguments. These may need special handling (e.g., scf.for).
14076100870SMatthias Springer       if (llvm::any_of(user->getResultTypes(),
14176100870SMatthias Springer                        [](Type t) { return isa<MemRefType>(t); }))
14276100870SMatthias Springer         continue;
14376100870SMatthias Springer       if (llvm::any_of(user->getRegions(), [](Region &r) {
14476100870SMatthias Springer             return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
14576100870SMatthias Springer               return isa<MemRefType>(bbArg.getType());
14676100870SMatthias Springer             });
14776100870SMatthias Springer           }))
14876100870SMatthias Springer         continue;
14976100870SMatthias Springer 
15076100870SMatthias Springer       // For all other ops, we assume that we can directly replace the operand.
15176100870SMatthias Springer       // This may have to be revised in the future; e.g., there may be ops that
15276100870SMatthias Springer       // do not support non-identity layout maps.
15376100870SMatthias Springer       for (OpOperand &operand : user->getOpOperands()) {
1540a0aff2dSMikhail Goncharov         if ([[maybe_unused]] auto castOp =
15576100870SMatthias Springer                 operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
1565fcf907bSMatthias Springer           rewriter.modifyOpInPlace(
15776100870SMatthias Springer               user, [&]() { operand.set(conversion->getOperand(0)); });
15876100870SMatthias Springer         }
15976100870SMatthias Springer       }
16076100870SMatthias Springer     }
16176100870SMatthias Springer   }
16276100870SMatthias Springer 
16376100870SMatthias Springer   // Erase all unrealized_conversion_cast ops without uses.
16476100870SMatthias Springer   for (auto op : unrealizedConversions)
16576100870SMatthias Springer     if (op->getUses().empty())
16676100870SMatthias Springer       rewriter.eraseOp(op);
16776100870SMatthias Springer }
16876100870SMatthias Springer 
16976100870SMatthias Springer FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
17076100870SMatthias Springer                                                   memref::AllocaOp allocaOp,
17176100870SMatthias Springer                                                   ValueRange independencies) {
17276100870SMatthias Springer   auto replacement =
17376100870SMatthias Springer       memref::buildIndependentOp(rewriter, allocaOp, independencies);
17476100870SMatthias Springer   if (failed(replacement))
17576100870SMatthias Springer     return failure();
17676100870SMatthias Springer   replaceAndPropagateMemRefType(rewriter, allocaOp,
17776100870SMatthias Springer                                 replacement->getDefiningOp());
17876100870SMatthias Springer   return replacement;
17976100870SMatthias Springer }
180e55e36deSOleksandr "Alex" Zinenko 
181e55e36deSOleksandr "Alex" Zinenko memref::AllocaOp memref::allocToAlloca(
182e55e36deSOleksandr "Alex" Zinenko     RewriterBase &rewriter, memref::AllocOp alloc,
183e55e36deSOleksandr "Alex" Zinenko     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
184e55e36deSOleksandr "Alex" Zinenko   memref::DeallocOp dealloc = nullptr;
185e55e36deSOleksandr "Alex" Zinenko   for (Operation &candidate :
186e55e36deSOleksandr "Alex" Zinenko        llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
187e55e36deSOleksandr "Alex" Zinenko     dealloc = dyn_cast<memref::DeallocOp>(candidate);
188e55e36deSOleksandr "Alex" Zinenko     if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
189e55e36deSOleksandr "Alex" Zinenko         (!filter || filter(alloc, dealloc))) {
190e55e36deSOleksandr "Alex" Zinenko       break;
191e55e36deSOleksandr "Alex" Zinenko     }
192e55e36deSOleksandr "Alex" Zinenko   }
193e55e36deSOleksandr "Alex" Zinenko 
194e55e36deSOleksandr "Alex" Zinenko   if (!dealloc)
195e55e36deSOleksandr "Alex" Zinenko     return nullptr;
196e55e36deSOleksandr "Alex" Zinenko 
197e55e36deSOleksandr "Alex" Zinenko   OpBuilder::InsertionGuard guard(rewriter);
198e55e36deSOleksandr "Alex" Zinenko   rewriter.setInsertionPoint(alloc);
199e55e36deSOleksandr "Alex" Zinenko   auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
200e55e36deSOleksandr "Alex" Zinenko       alloc, alloc.getMemref().getType(), alloc.getOperands());
201e55e36deSOleksandr "Alex" Zinenko   rewriter.eraseOp(dealloc);
202e55e36deSOleksandr "Alex" Zinenko   return alloca;
203e55e36deSOleksandr "Alex" Zinenko }
204