176100870SMatthias Springer //===- IndependenceTransforms.cpp - Make ops independent of values --------===// 276100870SMatthias Springer // 376100870SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 476100870SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 576100870SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 676100870SMatthias Springer // 776100870SMatthias Springer //===----------------------------------------------------------------------===// 876100870SMatthias Springer 976100870SMatthias Springer #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 1076100870SMatthias Springer 1176100870SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 1276100870SMatthias Springer #include "mlir/Dialect/Affine/Transforms/Transforms.h" 1376100870SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1476100870SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h" 1576100870SMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 1676100870SMatthias Springer 1776100870SMatthias Springer using namespace mlir; 1876100870SMatthias Springer using namespace mlir::memref; 1976100870SMatthias Springer 2076100870SMatthias Springer /// Make the given OpFoldResult independent of all independencies. 2176100870SMatthias Springer static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, 2276100870SMatthias Springer OpFoldResult ofr, 2376100870SMatthias Springer ValueRange independencies) { 24*30916b69SKazu Hirata if (isa<Attribute>(ofr)) 2576100870SMatthias Springer return ofr; 2676100870SMatthias Springer AffineMap boundMap; 2776100870SMatthias Springer ValueDimList mapOperands; 2876100870SMatthias Springer if (failed(ValueBoundsConstraintSet::computeIndependentBound( 2940dd3aa9SMatthias Springer boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, 3040dd3aa9SMatthias Springer /*closedUB=*/true))) 3176100870SMatthias Springer return failure(); 3276100870SMatthias Springer return affine::materializeComputedBound(b, loc, boundMap, mapOperands); 3376100870SMatthias Springer } 3476100870SMatthias Springer 3576100870SMatthias Springer FailureOr<Value> memref::buildIndependentOp(OpBuilder &b, 3676100870SMatthias Springer memref::AllocaOp allocaOp, 3776100870SMatthias Springer ValueRange independencies) { 3876100870SMatthias Springer OpBuilder::InsertionGuard g(b); 3976100870SMatthias Springer b.setInsertionPoint(allocaOp); 4076100870SMatthias Springer Location loc = allocaOp.getLoc(); 4176100870SMatthias Springer 4276100870SMatthias Springer SmallVector<OpFoldResult> newSizes; 4376100870SMatthias Springer for (OpFoldResult ofr : allocaOp.getMixedSizes()) { 4476100870SMatthias Springer auto ub = makeIndependent(b, loc, ofr, independencies); 4576100870SMatthias Springer if (failed(ub)) 4676100870SMatthias Springer return failure(); 4776100870SMatthias Springer newSizes.push_back(*ub); 4876100870SMatthias Springer } 4976100870SMatthias Springer 5076100870SMatthias Springer // Return existing memref::AllocaOp if nothing has changed. 5176100870SMatthias Springer if (llvm::equal(allocaOp.getMixedSizes(), newSizes)) 5276100870SMatthias Springer return allocaOp.getResult(); 5376100870SMatthias Springer 5476100870SMatthias Springer // Create a new memref::AllocaOp. 5576100870SMatthias Springer Value newAllocaOp = 5676100870SMatthias Springer b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType()); 5776100870SMatthias Springer 5876100870SMatthias Springer // Create a memref::SubViewOp. 5976100870SMatthias Springer SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); 6076100870SMatthias Springer SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); 6176100870SMatthias Springer return b 6276100870SMatthias Springer .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), 6376100870SMatthias Springer strides) 6476100870SMatthias Springer .getResult(); 6576100870SMatthias Springer } 6676100870SMatthias Springer 6776100870SMatthias Springer /// Push down an UnrealizedConversionCastOp past a SubViewOp. 6876100870SMatthias Springer static UnrealizedConversionCastOp 6976100870SMatthias Springer propagateSubViewOp(RewriterBase &rewriter, 7076100870SMatthias Springer UnrealizedConversionCastOp conversionOp, SubViewOp op) { 7176100870SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 7276100870SMatthias Springer rewriter.setInsertionPoint(op); 735550c821STres Popp auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType( 7476100870SMatthias Springer op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), 755550c821STres Popp op.getMixedSizes(), op.getMixedStrides())); 7676100870SMatthias Springer Value newSubview = rewriter.create<SubViewOp>( 7776100870SMatthias Springer op.getLoc(), newResultType, conversionOp.getOperand(0), 7876100870SMatthias Springer op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); 7976100870SMatthias Springer auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>( 8076100870SMatthias Springer op.getLoc(), op.getType(), newSubview); 8176100870SMatthias Springer rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); 8276100870SMatthias Springer return newConversionOp; 8376100870SMatthias Springer } 8476100870SMatthias Springer 8576100870SMatthias Springer /// Given an original op and a new, modified op with the same number of results, 8676100870SMatthias Springer /// whose memref return types may differ, replace all uses of the original op 8776100870SMatthias Springer /// with the new op and propagate the new memref types through the IR. 8876100870SMatthias Springer /// 8976100870SMatthias Springer /// Example: 9076100870SMatthias Springer /// %from = memref.alloca(%sz) : memref<?xf32> 9176100870SMatthias Springer /// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>> 9276100870SMatthias Springer /// memref.store %cst, %from[%c0] : memref<?xf32> 9376100870SMatthias Springer /// 9476100870SMatthias Springer /// In the above example, all uses of %from are replaced with %to. This can be 9576100870SMatthias Springer /// done directly for ops such as memref.store. For ops that have memref results 9676100870SMatthias Springer /// (e.g., memref.subview), the result type may depend on the operand type, so 9776100870SMatthias Springer /// we cannot just replace all uses. There is special handling for common memref 9876100870SMatthias Springer /// ops. For all other ops, unrealized_conversion_cast is inserted. 9976100870SMatthias Springer static void replaceAndPropagateMemRefType(RewriterBase &rewriter, 10076100870SMatthias Springer Operation *from, Operation *to) { 10176100870SMatthias Springer assert(from->getNumResults() == to->getNumResults() && 10276100870SMatthias Springer "expected same number of results"); 10376100870SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 10476100870SMatthias Springer rewriter.setInsertionPointAfter(to); 10576100870SMatthias Springer 10676100870SMatthias Springer // Wrap new results in unrealized_conversion_cast and replace all uses of the 10776100870SMatthias Springer // original op. 10876100870SMatthias Springer SmallVector<UnrealizedConversionCastOp> unrealizedConversions; 10976100870SMatthias Springer for (const auto &it : 11076100870SMatthias Springer llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { 11176100870SMatthias Springer unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>( 11276100870SMatthias Springer to->getLoc(), std::get<0>(it.value()).getType(), 11376100870SMatthias Springer std::get<1>(it.value()))); 11476100870SMatthias Springer rewriter.replaceAllUsesWith(from->getResult(it.index()), 11576100870SMatthias Springer unrealizedConversions.back()->getResult(0)); 11676100870SMatthias Springer } 11776100870SMatthias Springer 11876100870SMatthias Springer // Push unrealized_conversion_cast ops further down in the IR. I.e., try to 11976100870SMatthias Springer // wrap results instead of operands in a cast. 12076100870SMatthias Springer for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) { 12176100870SMatthias Springer UnrealizedConversionCastOp conversion = unrealizedConversions[i]; 12276100870SMatthias Springer assert(conversion->getNumOperands() == 1 && 12376100870SMatthias Springer conversion->getNumResults() == 1 && 12476100870SMatthias Springer "expected single operand and single result"); 12576100870SMatthias Springer SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers()); 12676100870SMatthias Springer for (Operation *user : users) { 12776100870SMatthias Springer // Handle common memref dialect ops that produce new memrefs and must 12876100870SMatthias Springer // be recreated with the new result type. 12976100870SMatthias Springer if (auto subviewOp = dyn_cast<SubViewOp>(user)) { 13076100870SMatthias Springer unrealizedConversions.push_back( 13176100870SMatthias Springer propagateSubViewOp(rewriter, conversion, subviewOp)); 13276100870SMatthias Springer continue; 13376100870SMatthias Springer } 13476100870SMatthias Springer 13576100870SMatthias Springer // TODO: Other memref ops such as memref.collapse_shape/expand_shape 13676100870SMatthias Springer // should also be handled here. 13776100870SMatthias Springer 13876100870SMatthias Springer // Skip any ops that produce MemRef result or have MemRef region block 13976100870SMatthias Springer // arguments. These may need special handling (e.g., scf.for). 14076100870SMatthias Springer if (llvm::any_of(user->getResultTypes(), 14176100870SMatthias Springer [](Type t) { return isa<MemRefType>(t); })) 14276100870SMatthias Springer continue; 14376100870SMatthias Springer if (llvm::any_of(user->getRegions(), [](Region &r) { 14476100870SMatthias Springer return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) { 14576100870SMatthias Springer return isa<MemRefType>(bbArg.getType()); 14676100870SMatthias Springer }); 14776100870SMatthias Springer })) 14876100870SMatthias Springer continue; 14976100870SMatthias Springer 15076100870SMatthias Springer // For all other ops, we assume that we can directly replace the operand. 15176100870SMatthias Springer // This may have to be revised in the future; e.g., there may be ops that 15276100870SMatthias Springer // do not support non-identity layout maps. 15376100870SMatthias Springer for (OpOperand &operand : user->getOpOperands()) { 1540a0aff2dSMikhail Goncharov if ([[maybe_unused]] auto castOp = 15576100870SMatthias Springer operand.get().getDefiningOp<UnrealizedConversionCastOp>()) { 1565fcf907bSMatthias Springer rewriter.modifyOpInPlace( 15776100870SMatthias Springer user, [&]() { operand.set(conversion->getOperand(0)); }); 15876100870SMatthias Springer } 15976100870SMatthias Springer } 16076100870SMatthias Springer } 16176100870SMatthias Springer } 16276100870SMatthias Springer 16376100870SMatthias Springer // Erase all unrealized_conversion_cast ops without uses. 16476100870SMatthias Springer for (auto op : unrealizedConversions) 16576100870SMatthias Springer if (op->getUses().empty()) 16676100870SMatthias Springer rewriter.eraseOp(op); 16776100870SMatthias Springer } 16876100870SMatthias Springer 16976100870SMatthias Springer FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter, 17076100870SMatthias Springer memref::AllocaOp allocaOp, 17176100870SMatthias Springer ValueRange independencies) { 17276100870SMatthias Springer auto replacement = 17376100870SMatthias Springer memref::buildIndependentOp(rewriter, allocaOp, independencies); 17476100870SMatthias Springer if (failed(replacement)) 17576100870SMatthias Springer return failure(); 17676100870SMatthias Springer replaceAndPropagateMemRefType(rewriter, allocaOp, 17776100870SMatthias Springer replacement->getDefiningOp()); 17876100870SMatthias Springer return replacement; 17976100870SMatthias Springer } 180e55e36deSOleksandr "Alex" Zinenko 181e55e36deSOleksandr "Alex" Zinenko memref::AllocaOp memref::allocToAlloca( 182e55e36deSOleksandr "Alex" Zinenko RewriterBase &rewriter, memref::AllocOp alloc, 183e55e36deSOleksandr "Alex" Zinenko function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) { 184e55e36deSOleksandr "Alex" Zinenko memref::DeallocOp dealloc = nullptr; 185e55e36deSOleksandr "Alex" Zinenko for (Operation &candidate : 186e55e36deSOleksandr "Alex" Zinenko llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) { 187e55e36deSOleksandr "Alex" Zinenko dealloc = dyn_cast<memref::DeallocOp>(candidate); 188e55e36deSOleksandr "Alex" Zinenko if (dealloc && dealloc.getMemref() == alloc.getMemref() && 189e55e36deSOleksandr "Alex" Zinenko (!filter || filter(alloc, dealloc))) { 190e55e36deSOleksandr "Alex" Zinenko break; 191e55e36deSOleksandr "Alex" Zinenko } 192e55e36deSOleksandr "Alex" Zinenko } 193e55e36deSOleksandr "Alex" Zinenko 194e55e36deSOleksandr "Alex" Zinenko if (!dealloc) 195e55e36deSOleksandr "Alex" Zinenko return nullptr; 196e55e36deSOleksandr "Alex" Zinenko 197e55e36deSOleksandr "Alex" Zinenko OpBuilder::InsertionGuard guard(rewriter); 198e55e36deSOleksandr "Alex" Zinenko rewriter.setInsertionPoint(alloc); 199e55e36deSOleksandr "Alex" Zinenko auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>( 200e55e36deSOleksandr "Alex" Zinenko alloc, alloc.getMemref().getType(), alloc.getOperands()); 201e55e36deSOleksandr "Alex" Zinenko rewriter.eraseOp(dealloc); 202e55e36deSOleksandr "Alex" Zinenko return alloca; 203e55e36deSOleksandr "Alex" Zinenko } 204