//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements functions concerned with hoisting invariant operations // in the context of Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using llvm::dbgs; #define DEBUG_TYPE "linalg-hoisting" #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; using namespace mlir::linalg; /// Replace `loop` with a new loop that has a different init operand at /// position `index`. The body of this loop is moved over to the new loop. /// /// `newInitOperands` specifies the replacement "init" operands. /// `newYieldValue` is the replacement yield value of the loop at position /// `index`. static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(loop.getOperation()); auto inits = llvm::to_vector(loop.getInits()); // Replace the init value with the new operand. assert(index < inits.size()); inits[index] = newInitOperand; scf::ForOp newLoop = rewriter.create( loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); // Generate the new yield with the replaced operand. auto yieldOp = cast(loop.getBody()->getTerminator()); yieldOp.setOperand(index, newYieldValue); // Move the loop body to the new op. rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), newLoop.getBody()->getArguments()); // Replace the old loop. rewriter.replaceOp(loop.getOperation(), newLoop->getResults()); return newLoop; } // Hoist out a pair of corresponding vector.extract+vector.broadcast // operations. This function transforms a loop like this: // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { // %e = vector.extract %iarg : t1 to t2 // %u = "some_use"(%e) : (t2) -> t2 // %b = vector.broadcast %u : t2 to t1 // scf.yield %b : t1 // } // into the following: // %e = vector.extract %v: t1 to t2 // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { // %u' = "some_use"(%iarg) : (t2) -> t2 // scf.yield %u' : t2 // } // %res = vector.broadcast %res' : t2 to t1 void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. root->walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); root->walk([&](vector::ExtractOp extractOp) { LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *extractOp.getOperation() << "\n"); auto loop = dyn_cast(extractOp->getParentOp()); if (!loop) return WalkResult::advance(); // Check that the vector to extract from is a BlockArgument. auto blockArg = dyn_cast(extractOp.getVector()); if (!blockArg) return WalkResult::advance(); // Check that the blockArg is an iter_arg of the loop. OpOperand *initArg = loop.getTiedLoopInit(blockArg); if (!initArg) return WalkResult::advance(); // If the iter_arg does not have only one use, it won't be possible to // hoist the extractOp out. if (!blockArg.hasOneUse()) return WalkResult::advance(); unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars(); // Check that the loop yields a broadcast that has just one use. Operation *yieldedVal = loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp(); auto broadcast = dyn_cast(yieldedVal); if (!broadcast || !broadcast.getResult().hasOneUse()) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n"); Type broadcastInputType = broadcast.getSourceType(); if (broadcastInputType != extractOp.getType()) return WalkResult::advance(); // The position of the extract must be defined outside of the loop if // it is dynamic. for (auto operand : extractOp.getDynamicPosition()) if (!loop.isDefinedOutsideOfLoop(operand)) return WalkResult::advance(); rewriter.modifyOpInPlace(broadcast, [&] { extractOp.getVectorMutable().assign(initArg->get()); }); loop.moveOutOfLoop(extractOp); rewriter.moveOpAfter(broadcast, loop); scf::ForOp newLoop = replaceWithDifferentYield( rewriter, loop, extractOp.getResult(), index, broadcast.getSource()); LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n"); rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast); rewriter.modifyOpInPlace( broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); }); changed = true; return WalkResult::interrupt(); }); } } static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop) { Value source = transferRead.getSource(); // Skip view-like Ops and retrive the actual soruce Operation while (auto srcOp = dyn_cast_or_null(source.getDefiningOp())) source = srcOp.getViewSource(); llvm::SmallVector users(source.getUsers().begin(), source.getUsers().end()); llvm::SmallDenseSet processed; while (!users.empty()) { Operation *user = users.pop_back_val(); // If the user has already been processed skip. if (!processed.insert(user).second) continue; if (auto viewLike = dyn_cast(user)) { users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa(user)) continue; if (!loop->isAncestor(user)) continue; return false; } return true; } void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. root->walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); // Find all loops that are certain to have non zero trip count. Any loops // that are not part of this set cannot be hoisted from, since hoisting from // a potentially zero trip count loop may cause a vector transfer to be // executed when it shouldn't be. llvm::DenseSet definiteNonZeroTripCountLoops; if (verifyNonZeroTrip) { root->walk([&](LoopLikeOpInterface loopLike) { std::optional> lbs = loopLike.getLoopLowerBounds(); std::optional> ubs = loopLike.getLoopUpperBounds(); // If loop bounds cannot be found, assume possibly zero trip count. if (!lbs || !ubs) return; // Otherwise, use ValueBounds to find the maximum lower bound and // minimum upper bound. If the bounds are found, and maxLb is less // than the minUb, then the loop will not have zero trip count. for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) { FailureOr maxLb = ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::UB, lb, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(maxLb)) return; FailureOr minUb = ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::LB, ub); if (failed(minUb)) return; if (minUb.value() <= maxLb.value()) return; definiteNonZeroTripCountLoops.insert(loopLike); } }); } root->walk([&](vector::TransferReadOp transferRead) { if (!isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); if (!isa_and_nonnull(loop)) return WalkResult::advance(); if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) { LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop << "\n"); return WalkResult::advance(); } LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() << "\n"); SetVector forwardSlice; getForwardSlice(transferRead.getOperation(), &forwardSlice); // Look for the last TransferWriteOp in the forwardSlice of // `transferRead` that operates on the same memref. vector::TransferWriteOp transferWrite; for (auto *sliceOp : llvm::reverse(forwardSlice)) { auto candidateWrite = dyn_cast(sliceOp); if (!candidateWrite || candidateWrite.getSource() != transferRead.getSource()) continue; transferWrite = candidateWrite; } // All operands of the TransferRead must be defined outside of the loop. for (auto operand : transferRead.getOperands()) if (!loop.isDefinedOutsideOfLoop(operand)) return WalkResult::advance(); // Only hoist transfer_read / transfer_write pairs and singleton // transfer_reads for now. if (!transferWrite) { // Make sure there are no other accesses to the memref before // hoisting transfer_read. if (noAliasingUseInLoop(transferRead, loop)) loop.moveOutOfLoop(transferRead); return WalkResult::advance(); } LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() << "\n"); // Approximate aliasing by checking that: // 1. indices, vector type and permutation map are the same (i.e., the // transfer_read/transfer_write ops are matching), // 2. source operands for transfer.{read|write} do not originate from // Ops implementing ViewLikeOpInterface. // 3. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. if (transferRead.getIndices() != transferWrite.getIndices() || transferRead.getVectorType() != transferWrite.getVectorType() || transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); auto *source = transferRead.getSource().getDefiningOp(); if (source && isa_and_nonnull(source)) return WalkResult::advance(); source = transferWrite.getSource().getDefiningOp(); if (source && isa_and_nonnull(source)) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it // likely gets invalidated often. DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); for (auto &use : transferRead.getSource().getUses()) { if (!loop->isAncestor(use.getOwner())) continue; if (use.getOwner() == transferRead.getOperation() || use.getOwner() == transferWrite.getOperation()) continue; if (auto transferWriteUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(*transferWrite), cast(*transferWriteUse), /*testDynamicValueUsingBounds=*/true)) return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(*transferWrite), cast(*transferReadUse), /*testDynamicValueUsingBounds=*/true)) return WalkResult::advance(); } else { // Unknown use, we cannot prove that it doesn't alias with the // transferRead/transferWrite operations. return WalkResult::advance(); } } // Hoist read before. loop.moveOutOfLoop(transferRead); // Hoist write after. transferWrite->moveAfter(loop); // Rewrite `loop` with new yields by cloning and erase the original loop. IRRewriter rewriter(transferRead.getContext()); NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; auto maybeNewLoop = loop.replaceWithAdditionalYields( rewriter, transferRead.getVector(), /*replaceInitOperandUsesInLoop=*/true, yieldFn); if (failed(maybeNewLoop)) return WalkResult::interrupt(); transferWrite.getVectorMutable().assign( maybeNewLoop->getOperation()->getResults().back()); changed = true; // Need to interrupt and restart because erasing the loop messes up // the walk. return WalkResult::interrupt(); }); } }