//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::linalg;

using llvm::dbgs;

/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. There
///      are 2 cases:
///      a) buffer case: use the SSA value of the views and a simple alias
///         analysis on subview ops to determine producer-consumer dependences;
///      b) tensor case: use SSA use-def chains on subtensor ops;
///   2. greedily fusing the linalg ops that produce the subview/subtensor;
///   3. inspecting the fused ops and determining whether they have any
///      remaining LinalgOp uses. If not, erasing the original producing
///      linalg op.
///
/// More advanced use cases and analyses, as well as profitability heuristics,
/// are left for future work.
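///
/// For intuition, a hypothetical buffer example (illustrative only, not taken
/// from an actual test):
///
///   linalg.fill(%C, %cst) : memref<?x?xf32>, f32
///   %sv = memref.subview %C[%i, %j] [%ts0, %ts1] [1, 1] : ...
///   linalg.matmul ins(%A, %B : ...) outs(%sv : ...)
///
/// The fill would be recomputed on just the subview right before the matmul,
/// and the original fill erased once it has no remaining LinalgOp uses.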

struct ShapeDimension {
  Value shape;
  unsigned dimension;
};

// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantee that at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e., have the same size) and we
// just return the first one.
static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                          bool fromSubViewOpOnly = false) {
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    // The method `getRangeFromOperandShape` requires the value to be defined
    // by a SubViewOp or a SubTensorOp. If it isn't, continue.
    // TODO: The method should be adapted to get the values from
    // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
    // currently returns a `linalg.range`. The fix here is to move this op to
    // `std` dialect and add the method to `ViewInterface`.
    if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
                                 opOperand->get().getDefiningOp()))
      continue;

    AffineMap map = op.getTiedIndexingMap(opOperand);
    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
                            << opOperand->getOperandNumber() << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "getShapeDefiningLoopRange map: " << map << "\n");
    for (auto en : llvm::enumerate(map.getResults())) {
      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
      if (!dimExpr)
        continue;
      if (loopDepth == dimExpr.getPosition()) {
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                << loopDepth << "\n");
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
                                << opOperand->get() << "\n");
        return ShapeDimension{opOperand->get(),
                              static_cast<unsigned>(en.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
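
// For intuition, a hypothetical example: a linalg.matmul has indexing maps
//   (d0, d1, d2) -> (d0, d2)   for operand %A,
//   (d0, d1, d2) -> (d2, d1)   for operand %B,
//   (d0, d1, d2) -> (d0, d1)   for operand %C;
// a query at loopDepth = 2 (the reduction loop d2) matches the second result
// expression of %A's map first, so {shape = %A, dimension = 1} is returned.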

// Returns the tiled operands for the fused producer op. When fusing into a
// `linalg.tiled_loop`, one has to update the `input` and `output` arguments
// of the loop accordingly.
// Each input tensor of the producer op has to be added to `inputs` of the
// `tiled_loop` if it is not present there already. Each output tensor has to
// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
// on whether the corresponding result is an input or an output to the loop.
//
// NOTE: This way of updating the arguments of the `tiled_loop` assumes that
// the intermediate result is not used by any operation other than the
// consumer. A more generic way is to append all missing output tensors of the
// producer to the tiled loop outputs and hence modify the number of the
// results, since we would need to add the intermediate results to
// `linalg.yield`. After that a canonicalization pass would move the unused
// output args of the `tiled_loop` to the `input` section.
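//
// A rough sketch of the rewiring (hypothetical IR, for intuition only): a
// loop that previously carried the producer's result as an output,
//
//   linalg.tiled_loop ... outs(%producer_result : tensor<...>)
//
// ends up carrying the producer's init tensor instead,
//
//   linalg.tiled_loop ... outs(%producer_init : tensor<...>)
//
// with all uses of the old block argument redirected to the new one.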
static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
  if (!tiledLoop)
    return llvm::to_vector<4>(llvm::map_range(
        producer.getInputAndOutputOperands(),
        [](OpOperand *opOperand) { return opOperand->get(); }));

  SmallVector<Value> tiledOperands;
  assert(producer.hasTensorSemantics() &&
         "only fusion on tensors is currently supported for TiledLinalgOp");

  for (OpOperand *producerInput : producer.getInputOperands()) {
    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
    if (addedInput == nullptr)
      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
    tiledOperands.push_back(addedBlockArg);
  }
  for (OpOperand *producerOutput : producer.getOutputOperands()) {
    OpResult result = producer.getTiedOpResult(producerOutput);
    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
    assert(((resultInputOperand != nullptr) ^
            (resultOutputOperand != nullptr)) &&
           "the result should be present in `input` or `output` args of "
           "`tiled_loop`");

    bool isInput = resultInputOperand != nullptr;
    int opNumber = isInput ? resultInputOperand->getOperandNumber()
                           : resultOutputOperand->getOperandNumber();

    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
    if (addedOutput == nullptr)
      addedOutput =
          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());

    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
    tiledLoop.eraseOperand(b, resultOperand);
    tiledOperands.push_back(addedBlockArg);
  }
  return tiledOperands;
}

/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
  SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
  SmallVector<Range, 8> loopRanges;
  Location loc = producer.getLoc();
  auto zero = b.create<ConstantIndexOp>(loc, 0);
  auto one = b.create<ConstantIndexOp>(loc, 1);

  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
    auto it = fusedLoopsAndRanges.find(i);
    if (it != fusedLoopsAndRanges.end()) {
      ivs.push_back(it->second.offset);
      tileSizes.push_back(it->second.size);
      sizeBounds.push_back(nullptr);
      loopRanges.push_back(it->second);
      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    } else {
      auto shapeDim = getShapeDefiningLoopRange(producer, i);
      Value dim = b.createOrFold<memref::DimOp>(loc, shapeDim.shape,
                                                shapeDim.dimension);
      tileSizes.push_back(zero);
      sizeBounds.push_back(dim);
      loopRanges.push_back(Range{zero, dim, one});
      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    }
  }

  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(producer.getNumInputsAndOutputs());

  // Compute subranges for all tensor input/output operands.
  clonedShapes.append(makeTiledShapes(b, loc, producer,
                                      getTiledOperands(b, producer), ivs,
                                      tileSizes, sizeBounds));

  // Append the other operands.
  auto operands = producer.getAssumedNonShapedOperands();
  clonedShapes.append(operands.begin(), operands.end());

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(producer->getNumResults());
  for (RankedTensorType t : producer.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(SubTensorOp::inferResultType(
        t, staticOffsetsVector, staticSizesVector, staticStridesVector));
  }

  Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes);
  // When the producer has index semantics, we have to transform the indices of
  // the producer according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
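  // For example (hypothetical IR): if the fused tile of loop 0 starts at
  // %offset, then
  //   %i = linalg.index 0 : index
  // is rewritten so that all of its other users see
  //   %shifted = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %offset)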
  if (producer.hasIndexSemantics()) {
    assert(clonedOp->getNumRegions() == 1 &&
           clonedOp->getRegion(0).getBlocks().size() == 1 &&
           "expected producer to have one region with one block");
    // Shift all indices by the tile offset.
    Block &block = clonedOp->getRegion(0).front();
    for (IndexOp indexOp : block.getOps<IndexOp>()) {
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPointAfter(indexOp);
      AffineExpr index, offset;
      bindDims(b.getContext(), index, offset);
      AffineApplyOp applyOp = b.create<AffineApplyOp>(
          indexOp.getLoc(), index + offset,
          ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
    }
  }

  return clonedOp;
}

/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It
/// is expected to be defined by a subview op or a subtensor op.
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
                                      Value shapedOperand, unsigned dim) {
  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
  if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
    return subViewOp.getOrCreateRanges(b, loc)[dim];
  if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
    return subTensorOp.getOrCreateRanges(b, loc)[dim];
  llvm_unreachable("SubViewOp or SubTensorOp expected");
}
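
// For example (hypothetical IR): given
//   %sv = memref.subview %A[%i, 0] [%ts, %n] [1, 1] : memref<?x?xf32> to ...
// the range returned for dim = 0 is {offset = %i, size = %ts, stride = 1}.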

/// Fuses the producer into the loop immediately enclosing the consumer.
/// This is achieved by "recomputing" the producer at the time it
/// is needed just before the consumer.
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                     OpOperand &consumerOpOperand) {
  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumerOpOperand.get();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
  }
  return fuse(b, producerOp, fusedLoopsAndRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
                            << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
                            << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any shape read or written
  // between the producer and the consumer.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "\n***Not fusable due to an interleaved dependence:\t"
               << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

/// For `consumer` with buffer semantics, find the Linalg operation on buffers
/// that is the last writer of `consumerOpOperand`. For now the fusable
/// dependence is returned as an element of the `dependenceGraph`.
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
                    const LinalgDependenceGraph &dependenceGraph) {
  LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
                          << consumerOpOperand.get() << " @"
                          << consumerOpOperand.getOperandNumber() << " in "
                          << *consumerOpOperand.getOwner() << "\n");
  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return {};

  // For now, only consider RAW and WAW dependences.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    LLVM_DEBUG(llvm::dbgs()
               << "Dependencies into: " << *consumerOp.getOperation() << "\n");
    for (auto dependence : llvm::make_filter_range(
             dependenceGraph.getDependencesInto(consumerOp, depType),
             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
               LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
                                       << elem.getIndexingValue() << " and "
                                       << elem.getDependentValue() << "\n");
               Value v = elem.getIndexingValue();
               Optional<unsigned> operandNum =
                   elem.getIndexingOpViewOperandNum();
               return isa<LinalgOp>(elem.getDependentOp()) &&
                      v == consumerOpOperand.get() && operandNum &&
                      operandNum.getValue() ==
                          consumerOpOperand.getOperandNumber();
             })) {
      // The consumer consumes this view; `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producer = cast<LinalgOp>(dependence.getDependentOp());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n"
                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
                 << " producer: " << *dependence.getDependentOp()
                 << " view: " << dependence.getDependentValue() << "\n");

      // If the producer and consumer have tensor semantics, the only
      // dependence between them is a RAW dependence and they are fusable by
      // construction. For buffer semantics, additional checks are needed.
      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
                        producer))
        return dependence;
      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
        assert(dependence.dependenceType ==
               LinalgDependenceGraph::DependenceType::RAW);
        return dependence;
      }
    }
  }
  return {};
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                   const LinalgDependenceGraph &graph) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumerOpOperand, graph);
  if (!fusableDependence)
    return llvm::None;

  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
  if (!producerOp)
    return llvm::None;

  // If the producer is already in the same block as the consumer, there is
  // nothing to do.
  if (consumerOpOperand.get().getParentBlock() ==
      fusableDependence->getDependentValue().getParentBlock())
    return llvm::None;

  Optional<AffineMap> producerMap =
      fusableDependence->getDependentOpViewIndexingMap();
  if (!producerMap)
    return llvm::None;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
  if (!subView) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
    return llvm::None;
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOpOperand.getOwner());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                          << *consumerOpOperand.getOwner() << "\n");

  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
  return FusionInfo{producerOp, fusedProducer};
}
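
// A typical driver (hypothetical sketch, not part of this file): visit linalg
// ops in reverse order, greedily fuse the producer of each operand, and erase
// producers that end up with no remaining uses, e.g.:
//
//   for (OpOperand *opOperand : consumer.getInputAndOutputOperands())
//     if (Optional<FusionInfo> info =
//             fuseProducerOfBuffer(b, *opOperand, dependenceGraph))
//       if (info->originalProducer->use_empty())
//         info->originalProducer->erase();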

/// Walk back the use-def chain through SubTensorOp sources and scf::For
/// iter_args. Sets `opResult` to the producing result if a producer LinalgOp
/// is found.
// TODO(ravishankarm, ntv): This can be moved into the dependence graph's
// dependence tracking, since that tracking is similar to what is done w.r.t.
// buffers.
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
      opResult = tensor.cast<OpResult>();
      return;
    }
    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
      tensor = subTensorOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      // A block argument never has a defining op; if it is an iter_arg of an
      // enclosing scf.for, step to the corresponding init operand instead
      // (block argument 0 is the induction variable).
      if (auto forOp =
              dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp())) {
        if (blockArg.getArgNumber() >= 1) {
          tensor = *(forOp.getIterOperands().begin() +
                     blockArg.getArgNumber() - 1);
          continue;
        }
      }
    }
    return;
  }
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
  Value inputTensor = consumerOpOperand.get();
  OpResult producerOpResult;
  getProducerOfTensor(inputTensor, producerOpResult);
  if (!producerOpResult) {
    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
    return {};
  }
  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                   OpOperand &consumerOpOperand) {
  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
  if (!producerOp)
    return llvm::None;

  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return llvm::None;

  Value inputTensor = consumerOpOperand.get();

  // Must be a subtensor to guarantee there are loops we can fuse into.
  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
  if (!subTensor) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not a subtensor: " << inputTensor);
    return {};
  }

  // If the producer is already in the same block as the consumer, there is
  // nothing to do.
  if (consumerOpOperand.get().getParentBlock() ==
      producerOpResult.getParentBlock())
    return {};

  // Insert the fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOp);
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
  OpOperand *opOperand =
      producerOp.getOutputOperand(producerOpResult.getResultNumber());
  LinalgOp fusedProducer =
      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
           consumerOpOperand);

  // Replace the use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}

/// Prune all dimensions that are of reduction iterator type from `map`.
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                           AffineMap map) {
  llvm::SmallDenseSet<unsigned> projectedDims;
  for (auto attr : llvm::enumerate(iteratorTypes)) {
    if (!isParallelIterator(attr.value()))
      projectedDims.insert(attr.index());
  }
  return getProjectedMap(map, projectedDims);
}

/// Returns the mapping from iterations in the consumer that write to the same
/// location as the iterations in the producer. To do so, use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
///     consumerLoopToProducerLoop =
///       inverse(producerIndexMap).compose(consumerIndexMap)
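///
/// For instance (illustrative maps): with
///   producerIndexMap = (d0, d1) -> (d1, d0)
///   consumerIndexMap = (d0, d1, d2) -> (d0, d1)
/// we have inverse(producerIndexMap) = (d0, d1) -> (d1, d0), and composing it
/// with the consumer map yields
///   consumerLoopToProducerLoop = (d0, d1, d2) -> (d1, d0).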
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
  if (!producer)
    return None;

  Optional<AffineMap> producerIndexingMap =
      dependence.getDependentOpViewIndexingMap();
  Optional<AffineMap> consumerIndexingMap =
      dependence.getIndexingOpViewIndexingMap();
  if (!producerIndexingMap || !consumerIndexingMap)
    return None;

  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
      producer.iterator_types().getValue(), *producerIndexingMap);
  if (!prunedProducerIndexingMap.isPermutation())
    return None;

  if (consumerIndexingMap->getNumResults() !=
      prunedProducerIndexingMap.getNumResults())
    return None;

  LLVM_DEBUG({
    llvm::dbgs() << "\t producerMap : ";
    producerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "  pruned : ";
    prunedProducerIndexingMap.print(llvm::dbgs());
    llvm::dbgs() << "\n";
    llvm::dbgs() << "\t consumerMap : ";
    consumerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "\n";
  });

  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
  if (!invProducerIndexMap)
    return None;

  return invProducerIndexMap.compose(*consumerIndexingMap);
}

/// Given a projected permutation `map`, returns true if the map changes the
/// order in which the fused loop dimensions appear.
static bool doesTransposeAccess(AffineMap map,
                                const std::set<unsigned> &fusableLoops) {
  Optional<unsigned> lastFusableLoop;
  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })) {
    if (!fusableLoops.count(pos))
      continue;
    if (!lastFusableLoop) {
      lastFusableLoop = pos;
      continue;
    }
    if (pos <= lastFusableLoop.getValue())
      return true;
    lastFusableLoop = pos;
  }
  return false;
}
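
// For example (illustrative): with fusableLoops = {0, 1}, the map
// (d0, d1) -> (d1, d0) visits the fused dimensions out of order and returns
// true, whereas (d0, d1, d2) -> (d0, d2) with fusableLoops = {0, 2} keeps
// them in order and returns false.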

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of both %a and %b
/// are to be fused, then no loops can be tiled while fusing. The conditions
/// used are:
/// 1. Only parallel loops can be used for tile + fuse. Find the number of
///    common outer parallel loops between the op and its producers being
///    fused.
/// 2. Of the parallel loops, only some can be fused: those whose iteration
///    space touches only a single tile of the fused operation. This is
///    because the producer (which is writing the fused subview) has update
///    semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
/// parallel loops and appear in the result of the map.
///
/// Example 1:
///   linalg.fill(%c, %cst)
///   linalg.matmul ins(%a, %b) outs(%c)
///     Number of parallel loops : 2
///     producerIndexMap = affine_map<(i, j) -> (i, j)>
///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
///     Fused dimensions : i, j
///
/// Example 2:
///   linalg.matmul ins(%a, %b) outs(%c)
///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
///                   iterator_types = ["parallel", "parallel"]}
///     ins(%c) ...
///
///     Number of parallel loops = 2:
///     producerIndexMap (projected to parallel loops) =
///       affine_map<(i, j) -> (i, j)>
///     consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)>
///     Fused dimensions : i, j
///
/// Example 3:
///   linalg.copy(%s, %b)
///   linalg.matmul ins(%a, %b) outs(%c)
///
///   Number of parallel loops = 2
///   producerIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
///     submap with only parallel loops = affine_map<(i, j) -> (j)>
///   Fused dimensions : j
static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,
                    const FusableOpDependencesTy &fusableDependences) {
  assert(!ops.empty());
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
  for (auto op : ops.drop_back()) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
  }

  std::set<unsigned> fusableLoops;
  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
  fusableLoops.insert(range.begin(), range.end());

  for (auto op : reverse(ops)) {
    for (auto dependence : fusableDependences.lookup(op)) {
      LLVM_DEBUG({
        llvm::dbgs() << "\t fusable :";
        for (unsigned i : fusableLoops)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });

      Optional<AffineMap> consumerLoopToProducerLoop =
          getConsumerLoopToProducerLoopMap(dependence);
      if (!consumerLoopToProducerLoop) {
        op.emitRemark("failed to get map from consumer loop to producer loop");
        return {};
      }
      // TODO: This condition is only an implementation limitation. When fusing
      // the operation, if the accesses in the producer/consumer are transposes
      // of each other, the loop bounds for the tiled producer can be
      // manipulated accordingly. This requires some additional bookkeeping in
      // the implementation of tile+fuse that is deferred to later.
      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
        op.emitRemark("unhandled fusion when fusion requires permutation");
        return {};
      }

      std::set<unsigned> candidates;
      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
        unsigned position = expr.cast<AffineDimExpr>().getPosition();
        if (fusableLoops.count(position))
          candidates.insert(position);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "\t candidates :";
        for (unsigned i : candidates)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });
      if (candidates.empty())
        return {};
      std::swap(candidates, fusableLoops);
    }
  }

  return fusableLoops;
}

/// Find all dependences that are fusable.
FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
    ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
  FusableOpDependencesTy fusableDependences;
  DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
  for (LinalgOp op : reverse(ops)) {
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
      if (!fusableDependence)
        continue;
      LinalgOp producerOp =
          dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
      if (!producerOp)
        continue;
      // Do not fuse dependences that are to operations not in the same basic
      // block. This avoids moving fused operations across loops that might
      // themselves carry dependences, which would make the fusion illegal.
      if (producerOp->getBlock() != op->getBlock())
        continue;

      // Make sure that the indexing map of the view used for fusion in the
      // producer is a projected permutation.
      Optional<AffineMap> producerMap =
          fusableDependence->getDependentOpViewIndexingMap();
      Optional<AffineMap> consumerMap =
          fusableDependence->getIndexingOpViewIndexingMap();
      assert(
          consumerMap &&
          "unable to find indexing map of operand/result of indexing OpView");
      fusedProducerIndexingMap[producerOp.getOperation()].push_back(
          *consumerMap);
      if (!producerMap || !producerMap->isProjectedPermutation() ||
          !consumerMap->isProjectedPermutation())
        continue;

      fusableDependences[producerOp.getOperation()].push_back(
          *fusableDependence);
    }
  }
  // TODO: Currently fusion would not be legal if the fusable dependence is to
  // the same producer but a different indexing map in the consumer. Fix this,
  // but in the meanwhile disallow such a fusion.
  for (const auto &useIndexingMapsList : fusedProducerIndexingMap) {
    AffineMap map1 = useIndexingMapsList.second.front();
    for (AffineMap map2 :
         ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
      if (map1 != map2) {
        fusableDependences.erase(useIndexingMapsList.first);
        break;
      }
    }
  }
  return fusableDependences;
}

/// Tile the fused loops in the root operation by setting the tile sizes for
/// all other loops to zero (those will be tiled later).
static Optional<TiledLinalgOp>
tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
                  const LinalgTilingOptions &options,
                  const std::set<unsigned> &fusedLoops) {
  SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
  auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
    if (!fusedLoops.count(i))
      tileSizes[i] = zero;
  LinalgTilingOptions tileFusedLoopsOptions = options;
  tileFusedLoopsOptions.setTileSizes(tileSizes);
  return tileLinalgOp(b, op, tileFusedLoopsOptions);
}
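
// For example (hypothetical sizes): with tileSizeVector = [16, 32, 8] and
// fusedLoops = {0, 1}, the root op is tiled with sizes [16, 32, 0]; the
// remaining loop keeps tile size 0 and is left to be tiled in a later step.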

/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
/// expected to be a tiled operation such that it is valid to fuse all
/// operations in `fusionCandidates`, i.e. to move each of them inside the
/// inter-tile loops of `tiledOp`.
static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
               ArrayRef<LinalgOp> fusionCandidates,
               const FusableOpDependencesTy &fusableDependences,
               const std::set<unsigned> &fusedLoops) {
  LinalgOp tiledOp = tiledLinalgOp.op;
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(tiledOp);

  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  for (unsigned loop : fusedLoops) {
    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
        b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
  }

  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
  for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
    LinalgOp origOp = candidate.value();
    LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
    origOpToFusedOp[origOp.getOperation()] = fusedOp;
    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;

    // Prepare the builder for the next insertion point.
    auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
    if (!origOp.hasTensorSemantics())
      continue;

    // If the producer and consumer are linalg operations on tensors, the
    // dependence is due to the value produced (as a return tensor) by the
    // producer being used in the consumer. The value returned by the fused op
    // needs to be made the operand of the tiled/fused consumer operation. By
    // construction, the value returned by the producer is the value used by
    // the consumer.
    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
      if (dependence.dependenceType !=
          LinalgDependenceGraph::DependenceType::RAW)
        continue;

      unsigned resultIndex =
          dependence.getDependentOpViewResultNum().getValue();
      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
      if (!consumer)
        continue;

      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
      consumer.getOperation()->setOperand(
          dependence.getIndexingOpViewOperandNum().getValue(),
          replacementValue);
    }

    // At this point, all Linalg uses of the tensors produced by `origOp` have
    // been replaced. However, there may still be "output tensor"-like uses
    // coming from WAW dependences.
    // All these uses are iter_args of the outermost loop (TODO: add a check).
    // Such iter_args uses serve 2 purposes:
    //  1. give a shape to the output,
    //  2. encode destructive updates that may be inplaceable by bufferization.
    // To keep the second type of information while letting the unfused op die
    // unused, we need to forward the producer output operand.
    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
      for (auto &operand : forOp.getIterOpOperands()) {
        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
          if (opResult.getOwner() == origOp) {
            Value output =
                origOp.getOutputOperand(opResult.getResultNumber())->get();
            assert(output.getType().isa<RankedTensorType>());
            operand.set(output);
          }
        }
      }
    }
  }
  return fusedOps;
}

static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions) {
  if (ops.size() < 2)
    return llvm::None;
  LinalgOp rootOp = ops.back();
  if (!llvm::all_of(
          ops,
          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
      !llvm::all_of(ops, [](LinalgOp linalgOp) {
        return linalgOp.hasTensorSemantics();
      })) {
    rootOp.emitError(
        "unable to fuse operations that have tensor semantics with operations "
        "that have buffer semantics and vice versa");
    return llvm::None;
  }
  // TODO: Support interchange with tile + fuse. This might actually help do
  // better fusion.
  if (!tilingOptions.interchangeVector.empty()) {
    rootOp.emitRemark("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(rootOp);

  // Find all the producers.
  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
  FusableOpDependencesTy fusableDependences =
      findAllFusableDependences(ops, dependenceGraph);
  if (fusableDependences.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable dependences found\n");
    return llvm::None;
  }

  TiledAndFusedLinalgOps ret;
  // Find the loops that can be tiled and fused.
  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (ret.fusedLoopDims.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
    return llvm::None;
  }

  // Tile the fused loops in the last operation in the list.
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(b, rootOp);
  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
      b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
  if (!tiledRootOp) {
    rootOp.emitRemark("failed to tile the fused loops");
    return llvm::None;
  }
  ret.op = tiledRootOp->op;
  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

  // Fuse the other operations into the fused inter-tile loops produced above.
  ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
                                      fusableDependences, ret.fusedLoopDims);

  return ret;
}

Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
  case LinalgTilingLoopType::ParallelLoops:
  case LinalgTilingLoopType::TiledLoops:
    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
  default:
    break;
  }
  return llvm::None;
}
983