1b6281940SNicolas Vasilache //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2b6281940SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b6281940SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8b6281940SNicolas Vasilache //
9b6281940SNicolas Vasilache // This file implements the linalg dialect Fusion pass.
10b6281940SNicolas Vasilache //
11b6281940SNicolas Vasilache //===----------------------------------------------------------------------===//
12b6281940SNicolas Vasilache
13cc11cedaSHanhan Wang #include "mlir/Dialect/Affine/IR/AffineOps.h"
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
16b6281940SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
17c694588fSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18b6281940SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h"
19e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
20129d6e55SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
210c8ad3aaSNicolas Vasilache #include "mlir/IR/AffineExpr.h"
220c8ad3aaSNicolas Vasilache #include "mlir/IR/AffineMap.h"
2357818885SStephen Neuendorffer #include "mlir/IR/Dominance.h"
24b6281940SNicolas Vasilache #include "mlir/Support/LLVM.h"
25b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2601defcc8SMaheshRavishankar #include "mlir/Transforms/RegionUtils.h"
275ca20851SMaheshRavishankar #include "llvm/ADT/MapVector.h"
288cf650c5SNicolas Vasilache #include "llvm/ADT/ScopeExit.h"
29b6281940SNicolas Vasilache #include "llvm/Support/CommandLine.h"
30b6281940SNicolas Vasilache #include "llvm/Support/Debug.h"
31b6281940SNicolas Vasilache
32a1fe1f5fSKazu Hirata #include <optional>
33197a73f0SMehdi Amini #include <set>
345ca20851SMaheshRavishankar
35b6281940SNicolas Vasilache #define DEBUG_TYPE "linalg-fusion"
36b6281940SNicolas Vasilache
37b6281940SNicolas Vasilache using namespace mlir;
38b6281940SNicolas Vasilache using namespace mlir::linalg;
39b6281940SNicolas Vasilache
4037e0fdd0SNicolas Vasilache /// Implements a simple high-level fusion pass on linalg structured operations.
41b6281940SNicolas Vasilache ///
42b6281940SNicolas Vasilache /// In each block, linalg ops are processed in reverse textual order.
43445232dfSNicolas Vasilache /// Given a linalg op `O`, fusion occurs by:
4437e0fdd0SNicolas Vasilache /// 1. inspecting the linalg ops that write into the views read by `O`. There
4537e0fdd0SNicolas Vasilache /// are 2 cases:
4637e0fdd0SNicolas Vasilache /// a) buffer case: use the SSA value of the views and a simple alias
4737e0fdd0SNicolas Vasilache /// analysis on subview ops to determine producer-consumer dependences;
48060208b4SMatthias Springer /// b) tensor case: use SSA use-def chains on extract_slice ops;
49060208b4SMatthias Springer /// 2. greedily fuse the linalg ops that produce the subview/extract_slice.
50445232dfSNicolas Vasilache /// 3. inspect the fused ops and determine whether they have other remaining
51b6281940SNicolas Vasilache /// LinalgOp uses. If not, then erase the original producing linalg op.
52b6281940SNicolas Vasilache ///
53b6281940SNicolas Vasilache /// More advanced use cases, analyses as well as profitability heuristics are
54b6281940SNicolas Vasilache /// left for future work.
55b6281940SNicolas Vasilache
5637e0fdd0SNicolas Vasilache struct ShapeDimension {
5737e0fdd0SNicolas Vasilache Value shape;
58b6281940SNicolas Vasilache unsigned dimension;
59b6281940SNicolas Vasilache };
60b6281940SNicolas Vasilache
6137e0fdd0SNicolas Vasilache // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
62445232dfSNicolas Vasilache // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
63445232dfSNicolas Vasilache // guarantees at least one such dimension is found. If multiple candidates exist
64445232dfSNicolas Vasilache // they must agree by construction (i.e. have the same size) and we just return
65445232dfSNicolas Vasilache // the first one.
66e65a5e5bSMaheshRavishankar static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op,unsigned loopDepth,bool fromSubViewOpOnly=false)67e65a5e5bSMaheshRavishankar getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
68e65a5e5bSMaheshRavishankar bool fromSubViewOpOnly = false) {
69b6281940SNicolas Vasilache // Iterate over the inputs and outputs in order.
70b6281940SNicolas Vasilache // Extract the subranges from the linearized ranges.
71a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : op->getOpOperands()) {
72e65a5e5bSMaheshRavishankar // The method `getRangeFromOperandShape` requires using SubViewOp or
73060208b4SMatthias Springer // ExtractSliceOps. If the value isn't defined from there continue.
74e65a5e5bSMaheshRavishankar // todo: The method should be adapted to get the values from
75e65a5e5bSMaheshRavishankar // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
76e65a5e5bSMaheshRavishankar // currently returns a `linalg.range`. The fix here is to move this op to
77e65a5e5bSMaheshRavishankar // `std` dialect and add the method to `ViewInterface`.
78060208b4SMatthias Springer if (fromSubViewOpOnly &&
79060208b4SMatthias Springer !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
80a7cccb9cSAlexander Belyaev opOperand.get().getDefiningOp()))
81e65a5e5bSMaheshRavishankar continue;
82e65a5e5bSMaheshRavishankar
83a7cccb9cSAlexander Belyaev AffineMap map = op.getMatchingIndexingMap(&opOperand);
847594f502STobias Gysi LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
85a7cccb9cSAlexander Belyaev << opOperand.getOperandNumber() << "\n");
865ca20851SMaheshRavishankar LLVM_DEBUG(llvm::dbgs()
875ca20851SMaheshRavishankar << "getShapeDefiningLoopRange map: " << map << "\n");
8837e0fdd0SNicolas Vasilache SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
89e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(map.getResults())) {
901609f1c2Slong.chen auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
91e65a5e5bSMaheshRavishankar if (!dimExpr)
92e65a5e5bSMaheshRavishankar continue;
931609f1c2Slong.chen if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
945ca20851SMaheshRavishankar LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
9537e0fdd0SNicolas Vasilache << loopDepth << "\n");
967594f502STobias Gysi LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
97a7cccb9cSAlexander Belyaev << opOperand.get() << "\n");
98a7cccb9cSAlexander Belyaev return ShapeDimension{opOperand.get(),
997594f502STobias Gysi static_cast<unsigned>(en.index())};
100b6281940SNicolas Vasilache }
101b6281940SNicolas Vasilache }
102b6281940SNicolas Vasilache }
10337e0fdd0SNicolas Vasilache llvm_unreachable("Expect to be able to extract a shape defining loop range");
104b6281940SNicolas Vasilache }
105b6281940SNicolas Vasilache
getTiledOperands(LinalgOp producer)1061a829d2dSAlexander Belyaev static SmallVector<Value> getTiledOperands(LinalgOp producer) {
107a7cccb9cSAlexander Belyaev return producer->getOperands();
1089ecc8178SAlexander Belyaev }
1099ecc8178SAlexander Belyaev
110e58597eeSLei Zhang /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
111e65a5e5bSMaheshRavishankar /// provides the loop range information for the fused loops. The rest are
112e65a5e5bSMaheshRavishankar /// obtained from the producer itself, since they are not tiled + fused.
fuse(OpBuilder & b,LinalgOp producer,const DenseMap<unsigned,Range> & fusedLoopsAndRanges)113e3cf7c88SNicolas Vasilache static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
114e65a5e5bSMaheshRavishankar const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
115e99fae89SAlex Zinenko SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds;
116e99fae89SAlex Zinenko SmallVector<Range> loopRanges;
117e58597eeSLei Zhang Location loc = producer.getLoc();
118b6281940SNicolas Vasilache
119e58597eeSLei Zhang for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
120c95a7246SMatthias Springer auto shapeDim = getShapeDefiningLoopRange(producer, i);
121e99fae89SAlex Zinenko OpFoldResult dim =
122e99fae89SAlex Zinenko createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
123c95a7246SMatthias Springer sizeBounds.push_back(dim);
124e58597eeSLei Zhang auto it = fusedLoopsAndRanges.find(i);
125e58597eeSLei Zhang if (it != fusedLoopsAndRanges.end()) {
126e99fae89SAlex Zinenko ivs.push_back(it->second.offset);
127e99fae89SAlex Zinenko tileSizes.push_back(it->second.size);
128e58597eeSLei Zhang loopRanges.push_back(it->second);
129e58597eeSLei Zhang LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
130e58597eeSLei Zhang << loopRanges.back() << "\n");
131e58597eeSLei Zhang } else {
132e99fae89SAlex Zinenko tileSizes.push_back(b.getIndexAttr(0));
13370e99f38SAlex Zinenko loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)});
134e58597eeSLei Zhang LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
135e58597eeSLei Zhang << loopRanges.back() << "\n");
136b6281940SNicolas Vasilache }
137b6281940SNicolas Vasilache }
138b6281940SNicolas Vasilache
139e58597eeSLei Zhang SmallVector<Value, 8> clonedShapes;
140a7cccb9cSAlexander Belyaev clonedShapes.reserve(producer->getNumOperands());
141e58597eeSLei Zhang
142e58597eeSLei Zhang // Compute subranges for all tensor input/output operands.
14365bdeddbSOkwan Kwon clonedShapes.append(makeTiledShapes(
14465bdeddbSOkwan Kwon b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
14565bdeddbSOkwan Kwon /**omitPartialTileCheck=*/false));
146e58597eeSLei Zhang
1473a087c15SMatthias Springer // Take result types from the tiled init operands.
1483a087c15SMatthias Springer MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
149e58597eeSLei Zhang SmallVector<Type, 4> resultTypes;
150e58597eeSLei Zhang resultTypes.reserve(producer->getNumResults());
1513a087c15SMatthias Springer int64_t firstInitOperandIdx =
152*a9304edfSThomas Preud'homme producerDpsInits.getAsOperandRange().getBeginOperandIndex();
1533a087c15SMatthias Springer for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
1543a087c15SMatthias Springer resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
155e58597eeSLei Zhang }
156e58597eeSLei Zhang
1573a087c15SMatthias Springer // Clone the producer with new operands and result types.
1586089d612SRahul Kayaith LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
15990b7817eSTobias Gysi
16090b7817eSTobias Gysi // Shift all IndexOp results by the tile offset.
161e99fae89SAlex Zinenko SmallVector<OpFoldResult> allIvs = llvm::to_vector(
162e99fae89SAlex Zinenko llvm::map_range(loopRanges, [&](Range range) { return range.offset; }));
16381b62f7fSAlex Zinenko offsetIndices(b, clonedOp, allIvs);
164e58597eeSLei Zhang
165e58597eeSLei Zhang return clonedOp;
166e65a5e5bSMaheshRavishankar }
167e65a5e5bSMaheshRavishankar
168e65a5e5bSMaheshRavishankar /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
169060208b4SMatthias Springer /// expected to be defined by a subview op or an extract_slice op.
getRangeFromOperandShape(OpBuilder & b,Location loc,Value shapedOperand,unsigned dim)170e65a5e5bSMaheshRavishankar static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
171e65a5e5bSMaheshRavishankar Value shapedOperand, unsigned dim) {
172e65a5e5bSMaheshRavishankar Operation *shapeProducingOp = shapedOperand.getDefiningOp();
173e2310704SJulian Gross if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
174e65a5e5bSMaheshRavishankar return subViewOp.getOrCreateRanges(b, loc)[dim];
175060208b4SMatthias Springer if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
176060208b4SMatthias Springer return sliceOp.getOrCreateRanges(b, loc)[dim];
177060208b4SMatthias Springer llvm_unreachable("SubviewOp or ExtractSliceOp expected");
178e65a5e5bSMaheshRavishankar }
179e65a5e5bSMaheshRavishankar
1807594f502STobias Gysi /// Fuses the producer into the loop immediately enclosing the consumer.
1817594f502STobias Gysi /// This is achieved by "recomputing" the producer at the time it
1827594f502STobias Gysi /// is needed just before the consumer.
fuse(OpBuilder & b,LinalgOp producerOp,AffineMap producerMap,OpOperand & consumerOpOperand)183bce318f5SMaheshRavishankar static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
184bce318f5SMaheshRavishankar OpOperand &consumerOpOperand) {
185bce318f5SMaheshRavishankar LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
186e65a5e5bSMaheshRavishankar DenseMap<unsigned, Range> fusedLoopsAndRanges;
18780f07854SNicolas Vasilache Value shapedOperand = consumerOpOperand.get();
188e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(producerMap.getResults())) {
1891609f1c2Slong.chen unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
19080f07854SNicolas Vasilache fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
19180f07854SNicolas Vasilache b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
192e65a5e5bSMaheshRavishankar }
19380f07854SNicolas Vasilache return fuse(b, producerOp, fusedLoopsAndRanges);
194b6281940SNicolas Vasilache }
195b6281940SNicolas Vasilache
19637e0fdd0SNicolas Vasilache /// Walk back use-def chain through scf::For yields.
19737e0fdd0SNicolas Vasilache /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
19801defcc8SMaheshRavishankar
19901defcc8SMaheshRavishankar // TODO(ravishankarm, ntv): This can be moved into the dependence graphs
20001defcc8SMaheshRavishankar // dependence tracking since the dependence tracking is similar to what is done
20101defcc8SMaheshRavishankar // w.r.t to buffers.
getProducerOfTensor(Value tensor,OpResult & opResult)20280f07854SNicolas Vasilache static void getProducerOfTensor(Value tensor, OpResult &opResult) {
2035550c821STres Popp if (!isa<RankedTensorType>(tensor.getType()))
20437e0fdd0SNicolas Vasilache return;
20537e0fdd0SNicolas Vasilache
20637e0fdd0SNicolas Vasilache while (true) {
20780f07854SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
20837e0fdd0SNicolas Vasilache if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
2095550c821STres Popp opResult = cast<OpResult>(tensor);
21037e0fdd0SNicolas Vasilache return;
21137e0fdd0SNicolas Vasilache }
212060208b4SMatthias Springer if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
21304235d07SJacques Pienaar tensor = sliceOp.getSource();
21437e0fdd0SNicolas Vasilache continue;
21537e0fdd0SNicolas Vasilache }
2165550c821STres Popp if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
21737e0fdd0SNicolas Vasilache if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
2185cf714bbSMatthias Springer tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
21937e0fdd0SNicolas Vasilache continue;
22037e0fdd0SNicolas Vasilache }
22137e0fdd0SNicolas Vasilache }
22237e0fdd0SNicolas Vasilache return;
22337e0fdd0SNicolas Vasilache }
22437e0fdd0SNicolas Vasilache }
22537e0fdd0SNicolas Vasilache
226489fec27SNicolas Vasilache FailureOr<FusionInfo>
fuseProducerOfTensor(OpBuilder & b,OpOperand & consumerOpOperand)22780f07854SNicolas Vasilache mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
22880f07854SNicolas Vasilache Value inputTensor = consumerOpOperand.get();
22980f07854SNicolas Vasilache OpResult producerOpResult;
23080f07854SNicolas Vasilache getProducerOfTensor(inputTensor, producerOpResult);
23180f07854SNicolas Vasilache if (!producerOpResult) {
23280f07854SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
233489fec27SNicolas Vasilache return failure();
23480f07854SNicolas Vasilache }
23580f07854SNicolas Vasilache return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
23680f07854SNicolas Vasilache }
23780f07854SNicolas Vasilache
238489fec27SNicolas Vasilache FailureOr<FusionInfo>
fuseProducerOfTensor(OpBuilder & b,OpResult producerOpResult,OpOperand & consumerOpOperand)23980f07854SNicolas Vasilache mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
24080f07854SNicolas Vasilache OpOperand &consumerOpOperand) {
24180f07854SNicolas Vasilache auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
242bce318f5SMaheshRavishankar if (!producerOp)
243489fec27SNicolas Vasilache return failure();
244bce318f5SMaheshRavishankar
245bce318f5SMaheshRavishankar LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
246bce318f5SMaheshRavishankar if (!consumerOp)
247489fec27SNicolas Vasilache return failure();
248bce318f5SMaheshRavishankar
24980f07854SNicolas Vasilache Value inputTensor = consumerOpOperand.get();
25037e0fdd0SNicolas Vasilache
251060208b4SMatthias Springer // Must be an extract_slice op to guarantee there are loops we can fuse into.
252060208b4SMatthias Springer auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
253060208b4SMatthias Springer if (!sliceOp) {
25480f07854SNicolas Vasilache LLVM_DEBUG(llvm::dbgs()
255060208b4SMatthias Springer << "\nNot fusable, not an extract_slice op: " << inputTensor);
256489fec27SNicolas Vasilache return failure();
25737e0fdd0SNicolas Vasilache }
25837e0fdd0SNicolas Vasilache
2599b17bf2eSNicolas Vasilache // If producer is already in the same block as consumer, we are done.
26080f07854SNicolas Vasilache if (consumerOpOperand.get().getParentBlock() ==
26180f07854SNicolas Vasilache producerOpResult.getParentBlock())
262489fec27SNicolas Vasilache return failure();
2639b17bf2eSNicolas Vasilache
26437e0fdd0SNicolas Vasilache // Insert fused `producer` just before `consumer`.
26537e0fdd0SNicolas Vasilache OpBuilder::InsertionGuard g(b);
26680f07854SNicolas Vasilache b.setInsertionPoint(consumerOp);
26780f07854SNicolas Vasilache LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
2687594f502STobias Gysi OpOperand *opOperand =
269b4db15a9SAlexander Belyaev producerOp.getDpsInitOperand(producerOpResult.getResultNumber());
270bce318f5SMaheshRavishankar LinalgOp fusedProducer =
2711227b8abSOleg Shyshkov fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
272bce318f5SMaheshRavishankar consumerOpOperand);
27337e0fdd0SNicolas Vasilache
27437e0fdd0SNicolas Vasilache // Replace use.
27537e0fdd0SNicolas Vasilache // Canonicalizations are not guaranteed to have happened before constructing
27637e0fdd0SNicolas Vasilache // `fusedProducer`. In the tensor case this can result in temporary type
277129d6e55SSean Silva // mismatches. Insert a `tensor.cast` op to propagate the transformation
27837e0fdd0SNicolas Vasilache // invariant that types are compatible.
27980f07854SNicolas Vasilache Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
28080f07854SNicolas Vasilache Type consumerType = consumerOpOperand.get().getType();
28137e0fdd0SNicolas Vasilache if (consumerType != def.getType())
282129d6e55SSean Silva def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
28380f07854SNicolas Vasilache consumerOpOperand.set(def);
28480f07854SNicolas Vasilache return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
28537e0fdd0SNicolas Vasilache }
286