//===- IndependenceTransforms.cpp - Make ops independent of values --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; using namespace mlir::memref; /// Make the given OpFoldResult independent of all independencies. static FailureOr makeIndependent(OpBuilder &b, Location loc, OpFoldResult ofr, ValueRange independencies) { if (isa(ofr)) return ofr; AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, /*closedUB=*/true))) return failure(); return affine::materializeComputedBound(b, loc, boundMap, mapOperands); } FailureOr memref::buildIndependentOp(OpBuilder &b, memref::AllocaOp allocaOp, ValueRange independencies) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(allocaOp); Location loc = allocaOp.getLoc(); SmallVector newSizes; for (OpFoldResult ofr : allocaOp.getMixedSizes()) { auto ub = makeIndependent(b, loc, ofr, independencies); if (failed(ub)) return failure(); newSizes.push_back(*ub); } // Return existing memref::AllocaOp if nothing has changed. if (llvm::equal(allocaOp.getMixedSizes(), newSizes)) return allocaOp.getResult(); // Create a new memref::AllocaOp. Value newAllocaOp = b.create(loc, newSizes, allocaOp.getType().getElementType()); // Create a memref::SubViewOp. SmallVector offsets(newSizes.size(), b.getIndexAttr(0)); SmallVector strides(newSizes.size(), b.getIndexAttr(1)); return b .create(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), strides) .getResult(); } /// Push down an UnrealizedConversionCastOp past a SubViewOp. static UnrealizedConversionCastOp propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); auto newResultType = cast(SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())); Value newSubview = rewriter.create( op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); auto newConversionOp = rewriter.create( op.getLoc(), op.getType(), newSubview); rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); return newConversionOp; } /// Given an original op and a new, modified op with the same number of results, /// whose memref return types may differ, replace all uses of the original op /// with the new op and propagate the new memref types through the IR. /// /// Example: /// %from = memref.alloca(%sz) : memref /// %to = memref.subview ... : ... to memref> /// memref.store %cst, %from[%c0] : memref /// /// In the above example, all uses of %from are replaced with %to. This can be /// done directly for ops such as memref.store. For ops that have memref results /// (e.g., memref.subview), the result type may depend on the operand type, so /// we cannot just replace all uses. There is special handling for common memref /// ops. For all other ops, unrealized_conversion_cast is inserted. static void replaceAndPropagateMemRefType(RewriterBase &rewriter, Operation *from, Operation *to) { assert(from->getNumResults() == to->getNumResults() && "expected same number of results"); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(to); // Wrap new results in unrealized_conversion_cast and replace all uses of the // original op. SmallVector unrealizedConversions; for (const auto &it : llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { unrealizedConversions.push_back(rewriter.create( to->getLoc(), std::get<0>(it.value()).getType(), std::get<1>(it.value()))); rewriter.replaceAllUsesWith(from->getResult(it.index()), unrealizedConversions.back()->getResult(0)); } // Push unrealized_conversion_cast ops further down in the IR. I.e., try to // wrap results instead of operands in a cast. for (int i = 0; i < static_cast(unrealizedConversions.size()); ++i) { UnrealizedConversionCastOp conversion = unrealizedConversions[i]; assert(conversion->getNumOperands() == 1 && conversion->getNumResults() == 1 && "expected single operand and single result"); SmallVector users = llvm::to_vector(conversion->getUsers()); for (Operation *user : users) { // Handle common memref dialect ops that produce new memrefs and must // be recreated with the new result type. if (auto subviewOp = dyn_cast(user)) { unrealizedConversions.push_back( propagateSubViewOp(rewriter, conversion, subviewOp)); continue; } // TODO: Other memref ops such as memref.collapse_shape/expand_shape // should also be handled here. // Skip any ops that produce MemRef result or have MemRef region block // arguments. These may need special handling (e.g., scf.for). if (llvm::any_of(user->getResultTypes(), [](Type t) { return isa(t); })) continue; if (llvm::any_of(user->getRegions(), [](Region &r) { return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) { return isa(bbArg.getType()); }); })) continue; // For all other ops, we assume that we can directly replace the operand. // This may have to be revised in the future; e.g., there may be ops that // do not support non-identity layout maps. for (OpOperand &operand : user->getOpOperands()) { if ([[maybe_unused]] auto castOp = operand.get().getDefiningOp()) { rewriter.modifyOpInPlace( user, [&]() { operand.set(conversion->getOperand(0)); }); } } } } // Erase all unrealized_conversion_cast ops without uses. for (auto op : unrealizedConversions) if (op->getUses().empty()) rewriter.eraseOp(op); } FailureOr memref::replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies) { auto replacement = memref::buildIndependentOp(rewriter, allocaOp, independencies); if (failed(replacement)) return failure(); replaceAndPropagateMemRefType(rewriter, allocaOp, replacement->getDefiningOp()); return replacement; } memref::AllocaOp memref::allocToAlloca( RewriterBase &rewriter, memref::AllocOp alloc, function_ref filter) { memref::DeallocOp dealloc = nullptr; for (Operation &candidate : llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) { dealloc = dyn_cast(candidate); if (dealloc && dealloc.getMemref() == alloc.getMemref() && (!filter || filter(alloc, dealloc))) { break; } } if (!dealloc) return nullptr; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(alloc); auto alloca = rewriter.replaceOpWithNewOp( alloc, alloc.getMemref().getType(), alloc.getOperands()); rewriter.eraseOp(dealloc); return alloca; }