//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::linalg;

using llvm::dbgs;

/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. There
///      are 2 cases:
///      a) buffer case: use the SSA value of the views and a simple alias
///         analysis on subview ops to determine producer-consumer dependences;
///      b) tensor case: use SSA use-def chains on extract_slice ops;
///   2. greedily fusing the linalg ops that produce the subview/extract_slice;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses; if not, erasing the original producing linalg
///      op.
///
/// More advanced use cases, analyses, as well as profitability heuristics are
/// left for future work.
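///
/// For instance (a schematic tensor-case sketch; the exact textual IR syntax
/// and the %i/%j names are illustrative), given a producer feeding a consumer
/// through an extract_slice:
///
///   %0 = linalg.fill(%cst, %init) : ... -> tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[%i, %j] [4, 4] [1, 1] : ...
///   %2 = linalg.matmul ins(%1, ...) outs(...)
///
/// the fill is recomputed on just the 4x4 slice right before the matmul, and
/// the original fill is erased once it has no remaining LinalgOp uses.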

struct ShapeDimension {
  Value shape;
  unsigned dimension;
};

// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantee that at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e., have the same size), and we
// simply return the first one.
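// For example, for a matmul with indexing maps
//   (d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
//   (d0, d1, d2) -> (d0, d1),
// the range of the reduction loop d2 is defined by dimension 1 of the first
// operand (or, equivalently, dimension 0 of the second); the first match wins.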
static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                          bool fromSubViewOpOnly = false) {
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    // The method `getRangeFromOperandShape` requires the value to be defined
    // by a SubViewOp or an ExtractSliceOp. If it is not, skip this operand.
    // TODO: The method should be adapted to get the values from
    // `ViewInterface`. The interface needs a `getOrCreateRanges` method, which
    // currently returns a `linalg.range`. The fix here is to move this op to
    // the `std` dialect and add the method to `ViewInterface`.
    if (fromSubViewOpOnly &&
        !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
            opOperand->get().getDefiningOp()))
      continue;

    AffineMap map = op.getTiedIndexingMap(opOperand);
    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
                            << opOperand->getOperandNumber() << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "getShapeDefiningLoopRange map: " << map << "\n");
    SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
    for (auto en : llvm::enumerate(map.getResults())) {
      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
      if (!dimExpr)
        continue;
      if (loopDepth == dimExpr.getPosition()) {
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                << loopDepth << "\n");
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
                                << opOperand->get() << "\n");
        return ShapeDimension{opOperand->get(),
                              static_cast<unsigned>(en.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a shape defining loop range");
}

// Returns the tiled operands for the fused producer op. When fusing into a
// `linalg.tiled_loop`, one has to update the `input` and `output` arguments of
// the loop accordingly.
// Each input tensor of the producer op has to be added to the `inputs` of the
// `tiled_loop` if it is not present there already. Each output tensor has to
// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
// on whether the corresponding result is an input or an output to the loop.
//
// NOTE: This way of updating the arguments of the `tiled_loop` assumes that
// the intermediate result is not used by any other operation but the consumer.
// A more generic way is to append all missing output tensors of the producer
// to the tiled loop outputs and hence modify the number of the results, since
// we would need to add the intermediate results to `linalg.yield`. After that
// a canonicalization pass would move the unused output args of the
// `tiled_loop` to the `input` section.
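//
// For example (schematic; the %in/%out names and the textual form of the op
// are illustrative), fusing a producer of %in into
//
//   linalg.tiled_loop (%i) = (%c0) to (%dim) step (%c4)
//       ins (%in_ = %in : tensor<?xf32>)
//       outs (%out_ = %out : tensor<?xf32>) { ... }
//
// appends the producer's own input tensors to `ins` and recomputes the
// producer on a per-tile basis inside the loop body.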
static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
  if (!tiledLoop)
    return producer.getInputAndOutputOperands();

  SmallVector<Value> tiledOperands;
  assert(producer.hasTensorSemantics() &&
         "only fusion on tensors is currently supported for TiledLinalgOp");

  for (OpOperand *producerInput : producer.getInputOperands()) {
    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
    if (addedInput == nullptr)
      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
    tiledOperands.push_back(addedBlockArg);
  }
  for (OpOperand *producerOutput : producer.getOutputOperands()) {
    OpResult result = producer.getTiedOpResult(producerOutput);
    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
           "the result should be present in `input` or `output` args of "
           "`tiled_loop`");

    bool isInput = resultInputOperand;
    int opNumber = isInput ? resultInputOperand->getOperandNumber()
                           : resultOutputOperand->getOperandNumber();

    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
    if (addedOutput == nullptr)
      addedOutput =
          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());

    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
    tiledLoop.eraseOperand(b, resultOperand);
    tiledOperands.push_back(addedBlockArg);
  }
  return tiledOperands;
}

/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
  SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
  SmallVector<Range, 8> loopRanges;
  Location loc = producer.getLoc();
  auto zero = b.create<ConstantIndexOp>(loc, 0);
  auto one = b.create<ConstantIndexOp>(loc, 1);

  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
    auto it = fusedLoopsAndRanges.find(i);
    if (it != fusedLoopsAndRanges.end()) {
      ivs.push_back(it->second.offset);
      tileSizes.push_back(it->second.size);
      sizeBounds.push_back(nullptr);
      loopRanges.push_back(it->second);
      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    } else {
      auto shapeDim = getShapeDefiningLoopRange(producer, i);
      Value dim = createOrFoldDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
      tileSizes.push_back(zero);
      sizeBounds.push_back(dim);
      loopRanges.push_back(Range{zero, dim, one});
      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    }
  }

  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(producer.getNumInputsAndOutputs());

  // Compute subranges for all tensor input/output operands.
  clonedShapes.append(makeTiledShapes(b, loc, producer,
                                      getTiledOperands(b, producer), ivs,
                                      tileSizes, sizeBounds));

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(producer->getNumResults());
  for (RankedTensorType t : producer.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
        t, staticOffsetsVector, staticSizesVector, staticStridesVector));
  }

  Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes);
  // When the producer has index semantics, we have to transform the indices of
  // the producer according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
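  // For example, a producer body containing
  //   %i = linalg.index 0 : index
  // is rewritten (schematically) so that all other uses of %i go through
  //   %ii = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %offset0)
  // where %offset0 is the offset of fused loop 0 in `loopRanges`.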
  if (producer.hasIndexSemantics()) {
    assert(clonedOp->getNumRegions() == 1 &&
           clonedOp->getRegion(0).getBlocks().size() == 1 &&
           "expected producer to have one block.");
    // Shift all indices by the tile offset.
    Block &block = clonedOp->getRegion(0).front();
    for (IndexOp indexOp : block.getOps<IndexOp>()) {
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPointAfter(indexOp);
      AffineExpr index, offset;
      bindDims(b.getContext(), index, offset);
      AffineApplyOp applyOp = b.create<AffineApplyOp>(
          indexOp.getLoc(), index + offset,
          ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
    }
  }

  return clonedOp;
}

/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It
/// is expected to be defined by a subview op or an extract_slice op.
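/// For example (illustrative names), given
///   %sv = memref.subview %A[%i, 0] [4, 8] [1, 1] : ...
/// the range for dim 0 is (offset = %i, size = 4, stride = 1).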
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
                                      Value shapedOperand, unsigned dim) {
  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
  if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
    return subViewOp.getOrCreateRanges(b, loc)[dim];
  if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
    return sliceOp.getOrCreateRanges(b, loc)[dim];
  llvm_unreachable("SubviewOp or ExtractSliceOp expected");
}

/// Fuses the producer into the loop immediately enclosing the consumer.
/// This is achieved by "recomputing" the producer at the time it
/// is needed just before the consumer.
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                     OpOperand &consumerOpOperand) {
  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumerOpOperand.get();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
  }
  return fuse(b, producerOp, fusedLoopsAndRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
                            << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
                            << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence on any shape read or written
  // between the producer and the consumer.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "\n***Not fusable due to an interleaved dependence:\t"
               << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

/// For `consumer` with buffer semantics, find the Linalg operation on buffers
/// that is the last writer of `consumerOpOperand`. For now the fusable
/// dependence is returned as an instance of the `dependenceGraph`.
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
                    const LinalgDependenceGraph &dependenceGraph) {
  LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
                          << consumerOpOperand.get() << " @"
                          << consumerOpOperand.getOperandNumber() << " in "
                          << *consumerOpOperand.getOwner() << "\n");
  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return {};

  // Only consider RAW and WAW dependences for now.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    LLVM_DEBUG(llvm::dbgs()
               << "Dependencies into: " << *consumerOp.getOperation() << "\n");
    for (auto dependence : llvm::make_filter_range(
             dependenceGraph.getDependencesInto(consumerOp, depType),
             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
               LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
                                       << elem.getIndexingValue() << " and "
                                       << elem.getDependentValue() << "\n");
               Value v = elem.getIndexingValue();
               Optional<unsigned> operandNum =
                   elem.getIndexingOpViewOperandNum();
               return isa<LinalgOp>(elem.getDependentOp()) &&
                      v == consumerOpOperand.get() && operandNum &&
                      operandNum.getValue() ==
                          consumerOpOperand.getOperandNumber();
             })) {
      // Consumer consumes this view. `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producer = cast<LinalgOp>(dependence.getDependentOp());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n"
                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
                 << " producer: " << *dependence.getDependentOp()
                 << " view: " << dependence.getDependentValue() << "\n");

      // If the producer and consumer have tensor semantics, the only
      // dependence between them is a RAW dependence and they are fusable by
      // construction. For buffer semantics, additional checks are needed.
      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
                        producer))
        return dependence;
      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
        assert(dependence.dependenceType ==
               LinalgDependenceGraph::DependenceType::RAW);
        return dependence;
      }
    }
  }
  return {};
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                   const LinalgDependenceGraph &graph) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumerOpOperand, graph);
  if (!fusableDependence)
    return llvm::None;

  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
  if (!producerOp)
    return llvm::None;

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      fusableDependence->getDependentValue().getParentBlock())
    return llvm::None;

  Optional<AffineMap> producerMap =
      fusableDependence->getDependentOpViewIndexingMap();
  if (!producerMap)
    return llvm::None;

  // Must be a subview or an extract_slice to guarantee there are loops we can
  // fuse into.
  auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
  if (!subView) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
    return llvm::None;
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOpOperand.getOwner());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                          << *consumerOpOperand.getOwner() << "\n");

  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
  return FusionInfo{producerOp, fusedProducer};
}

/// Walks back the use-def chain through scf::For yields.
/// Sets `opResult` to the producer's result if it finds a producer LinalgOp.

// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
// dependence tracking since the dependence tracking is similar to what is done
// w.r.t to buffers.
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
      opResult = tensor.cast<OpResult>();
      return;
    }
    if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      tensor = sliceOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      if (auto forOp =
              dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp())) {
        // Block arguments of scf.for are [induction variable, iter_args...];
        // skip the induction variable and map the iter_arg back to the
        // corresponding iter operand of the loop.
        if (blockArg.getArgNumber() > 0) {
          tensor = forOp.getIterOperands()[blockArg.getArgNumber() - 1];
          continue;
        }
      }
    }
    return;
  }
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
  Value inputTensor = consumerOpOperand.get();
  OpResult producerOpResult;
  getProducerOfTensor(inputTensor, producerOpResult);
  if (!producerOpResult) {
    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
    return {};
  }
  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                   OpOperand &consumerOpOperand) {
  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
  if (!producerOp)
    return llvm::None;

  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return llvm::None;

  Value inputTensor = consumerOpOperand.get();

  // Must be an extract_slice op to guarantee there are loops we can fuse into.
  auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not an extract_slice op: " << inputTensor);
    return {};
  }

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      producerOpResult.getParentBlock())
    return {};

  // Insert fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOp);
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
  OpOperand *opOperand =
      producerOp.getOutputOperand(producerOpResult.getResultNumber());
  LinalgOp fusedProducer =
      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
           consumerOpOperand);

  // Replace use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
/// Prune all dimensions that are of reduction iterator type from `map`.
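/// For example, for a matmul with iterator_types
/// ["parallel", "parallel", "reduction"] and output indexing map
/// affine_map<(d0, d1, d2) -> (d0, d1)>, pruning the reduction dimension d2
/// yields affine_map<(d0, d1) -> (d0, d1)>.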
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                           AffineMap map) {
  llvm::SmallDenseSet<unsigned> projectedDims;
  for (auto attr : llvm::enumerate(iteratorTypes)) {
    if (!isParallelIterator(attr.value()))
      projectedDims.insert(attr.index());
  }
  return getProjectedMap(map, projectedDims);
}

/// Returns the mapping from iterations in the consumer that write to the same
/// location as the iterations in the producer. To do so, use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
///     consumerLoopToProducerLoop =
///       inverse(producerIndexMap).compose(consumerIndexMap)
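/// For example, for a linalg.fill producer feeding a matmul consumer:
///   producerIndexMap = affine_map<(i, j) -> (i, j)>
///   consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>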
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
  if (!producer)
    return None;

  Optional<AffineMap> producerIndexingMap =
      dependence.getDependentOpViewIndexingMap();
  Optional<AffineMap> consumerIndexingMap =
      dependence.getIndexingOpViewIndexingMap();
  if (!producerIndexingMap || !consumerIndexingMap)
    return None;

  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
      producer.iterator_types().getValue(), *producerIndexingMap);
  if (!prunedProducerIndexingMap.isPermutation())
    return None;

  if (consumerIndexingMap->getNumResults() !=
      prunedProducerIndexingMap.getNumResults())
    return None;

  LLVM_DEBUG({
    llvm::dbgs() << "\t producerMap : ";
    producerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "  pruned : ";
    prunedProducerIndexingMap.print(llvm::dbgs());
    llvm::dbgs() << "\n";
    llvm::dbgs() << "\t consumerMap : ";
    consumerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "\n";
  });

  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
  if (!invProducerIndexMap)
    return None;

  return invProducerIndexMap.compose(*consumerIndexingMap);
}

/// Given a projected permutation `map`, returns true if the map changes the
/// order in which the fused loop dimensions appear.
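/// For example, with fusableLoops = {0, 1}, affine_map<(d0, d1) -> (d1, d0)>
/// transposes the fused dimensions and returns true, while
/// affine_map<(d0, d1, d2) -> (d0, d2, d1)> returns false since d2 is not a
/// fusable loop.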
static bool doesTransposeAccess(AffineMap map,
                                const std::set<unsigned> &fusableLoops) {
  Optional<unsigned> lastFusableLoop;
  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })) {
    if (!fusableLoops.count(pos))
      continue;
    if (!lastFusableLoop) {
      lastFusableLoop = pos;
      continue;
    }
    if (pos <= lastFusableLoop.getValue())
      return true;
    lastFusableLoop = pos;
  }
  return false;
}

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of both %a and %b
/// are to be fused, then no loops can be tiled while fusing. The conditions
/// used are:
/// 1. Only parallel loops can be used for tile + fuse. Find the number of
///    common outer parallel loops between the op and its producers being fused.
/// 2. Of the parallel loops, only some can be fused: those for which the
///    iteration space of the fusable loops touches only one tile of the fused
///    operation. This is because the producer (which is writing the fused
///    subview) has update semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
/// parallel loops and appear in the result of the map.
///
/// Example 1:
///   linalg.fill(%cst, %c)
///   linalg.matmul ins(%a, %b) outs(%c)
///     Number of parallel loops : 2
///     producerIndexMap = affine_map<(i, j) -> (i, j)>
///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
///     Fused dimensions : i, j
///
/// Example 2:
///   linalg.matmul ins(%a, %b) outs(%c)
///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
///                   iterator_types = ["parallel", "parallel"]}
///     ins(%c) ...
///
///     Number of parallel loops = 2:
///     producerIndexMap (projected to parallel loops) =
///       affine_map<(i, j) -> (i, j)>
///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
///     Fused dimensions : i, j
///
/// Example 3:
///   linalg.copy(%s, %b)
///   linalg.matmul ins(%a, %b) outs(%c)
///
///   Number of parallel loops = 2
///   producerIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
///     submap with only parallel loops = affine_map<(i, j) -> (j)>
///   Fused dimensions : j
static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,
                    const FusableOpDependencesTy &fusableDependences) {
  assert(!ops.empty());
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
  for (auto op : ops.drop_back()) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
  }

  std::set<unsigned> fusableLoops;
  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
  fusableLoops.insert(range.begin(), range.end());

  for (auto op : reverse(ops)) {
    for (auto dependence : fusableDependences.lookup(op)) {
      LLVM_DEBUG({
        llvm::dbgs() << "\t fusable :";
        for (unsigned i : fusableLoops)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });

      Optional<AffineMap> consumerLoopToProducerLoop =
          getConsumerLoopToProducerLoopMap(dependence);
      if (!consumerLoopToProducerLoop) {
        op.emitRemark("failed to get map from consumer loop to producer loop");
        return {};
      }
      // TODO: This condition is only an implementation limitation. When fusing
      // the operation, if the accesses in the producer/consumer are transposes
      // of each other, the loop bounds for the tiled producer can be
      // manipulated accordingly. This requires some additional bookkeeping in
      // the implementation of tile+fuse that is deferred to later.
      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
        op.emitRemark("unhandled fusion when fusion requires permutation");
        return {};
      }

      std::set<unsigned> candidates;
      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
        unsigned position = expr.cast<AffineDimExpr>().getPosition();
        if (fusableLoops.count(position))
          candidates.insert(position);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "\t candidates :";
        for (unsigned i : candidates)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });
      if (candidates.empty())
        return {};
      std::swap(candidates, fusableLoops);
    }
  }

  return fusableLoops;
}

/// Find all dependences that are fusable.
FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
    ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
  FusableOpDependencesTy fusableDependences;
  DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
  for (LinalgOp op : reverse(ops)) {
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
      if (!fusableDependence)
        continue;
      LinalgOp producerOp =
          dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
      if (!producerOp)
        continue;
      // Do not fuse dependences that are to operations not in the same basic
      // block. This avoids moving fused operations across loops that might
      // themselves carry dependences that would make the fusion illegal.
      if (producerOp->getBlock() != op->getBlock())
        continue;

      // Make sure that the indexing map of the view used for fusion in the
      // producer is a projected permutation.
      Optional<AffineMap> producerMap =
          fusableDependence->getDependentOpViewIndexingMap();
      Optional<AffineMap> consumerMap =
          fusableDependence->getIndexingOpViewIndexingMap();
      assert(
          consumerMap &&
          "unable to find indexing map of operand/result of indexing OpView");
      fusedProducerIndexingMap[producerOp.getOperation()].push_back(
          *consumerMap);
      if (!producerMap || !producerMap->isProjectedPermutation() ||
          !consumerMap->isProjectedPermutation())
        continue;

      fusableDependences[producerOp.getOperation()].push_back(
          *fusableDependence);
    }
  }
  // TODO: Currently fusion would not be legal if the fusable dependence is to
  // the same producer but a different indexing map in the consumer. Fix this,
  // but in the meanwhile disallow such a fusion.
  for (auto useIndexingMapsList : fusedProducerIndexingMap) {
    AffineMap map1 = useIndexingMapsList.second.front();
    for (AffineMap map2 :
         ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
      if (map1 != map2) {
        fusableDependences.erase(useIndexingMapsList.first);
        break;
      }
    }
  }
  return fusableDependences;
}

/// Tile the fused loops in the root operation by setting the tile sizes for
/// all other loops to zero (those will be tiled later).
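/// For example, with tile sizes [10, 20, 30] and fusedLoops = {0, 1}, the root
/// operation is tiled with sizes [10, 20, 0]; loop 2 is left untiled here.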
static Optional<TiledLinalgOp>
tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
                  const LinalgTilingOptions &options,
                  const std::set<unsigned> &fusedLoops) {
  SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
  auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
    if (!fusedLoops.count(i))
      tileSizes[i] = zero;
  LinalgTilingOptions tileFusedLoopsOptions = options;
  tileFusedLoopsOptions.setTileSizes(tileSizes);
  return tileLinalgOp(b, op, tileFusedLoopsOptions);
}

/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
/// expected to be a tiled operation such that it is valid to fuse all the
/// operations in `fusionCandidates`, i.e. to move them within the inter-tile
/// loops of `tiledOp`.
static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
               ArrayRef<LinalgOp> fusionCandidates,
               const FusableOpDependencesTy &fusableDependences,
               const std::set<unsigned> &fusedLoops) {
  LinalgOp tiledOp = tiledLinalgOp.op;
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(tiledOp);

  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  for (unsigned loop : fusedLoops) {
    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
        b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
  }

  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
  for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
    LinalgOp origOp = candidate.value();
    LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
    origOpToFusedOp[origOp.getOperation()] = fusedOp;
    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;

    // Prepare the builder for the next insertion point.
    auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
    if (!origOp.hasTensorSemantics())
      continue;

    // If the producer and consumer operations are linalg operations on
    // tensors, the dependence is due to the value produced (as a return
    // tensor) by the producer and used in the consumer. The returned value of
    // the fused op needs to be made the operand of the tiled/fused consumer
    // operation. By construction the value returned by the producer is the
    // value used by the consumer.
    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
      if (dependence.dependenceType !=
          LinalgDependenceGraph::DependenceType::RAW)
        continue;

      unsigned resultIndex =
          dependence.getDependentOpViewResultNum().getValue();
      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
      if (!consumer)
        continue;

      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
      consumer.getOperation()->setOperand(
          dependence.getIndexingOpViewOperandNum().getValue(),
          replacementValue);
    }

    // At this point, all Linalg uses of the tensors produced by `origOp` have
    // been replaced. However, there may still be "output tensor"-like uses
    // coming from WAW dependencies.
    // All these uses are iter_args of the outermost loop (TODO: add a check).
    // Such iter_args uses serve 2 purposes:
    //  1. give a shape to the output,
    //  2. encode destructive updates that may be inplaceable by bufferization.
    // To keep the second type of information while letting the unfused op die
    // unused, we need to forward the producer output operand.
    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
      for (auto &operand : forOp.getIterOpOperands()) {
        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
          if (opResult.getOwner() == origOp) {
            Value output =
                origOp.getOutputOperand(opResult.getResultNumber())->get();
            assert(output.getType().isa<RankedTensorType>());
            operand.set(output);
          }
        }
      }
    }
  }
  return fusedOps;
}

static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions) {
  if (ops.size() < 2)
    return llvm::None;
  LinalgOp rootOp = ops.back();
  if (!llvm::all_of(
          ops,
          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
      !llvm::all_of(ops, [](LinalgOp linalgOp) {
        return linalgOp.hasTensorSemantics();
      })) {
    rootOp.emitError(
        "unable to fuse operations that have tensor semantics with operations "
        "that have buffer semantics and vice versa.");
    return llvm::None;
  }
  // TODO: Support interchange with tile + fuse. This might actually help do
  // better fusion.
  if (!tilingOptions.interchangeVector.empty()) {
    rootOp.emitRemark("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(rootOp);

  // Find all the producers.
  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
  FusableOpDependencesTy fusableDependences =
      findAllFusableDependences(ops, dependenceGraph);
  if (fusableDependences.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
    return llvm::None;
  }

  TiledAndFusedLinalgOps ret;
  // Find the loops that can be tiled and fused.
  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (ret.fusedLoopDims.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
    return llvm::None;
  }

  // Tile the fused loops in the last operation in the list.
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(b, rootOp);
  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
      b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
  if (!tiledRootOp) {
    rootOp.emitRemark("failed to tile the fused loops");
    return llvm::None;
  }
  ret.op = tiledRootOp->op;
  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

  // Fuse the other operations into the fused inter-tile loops produced above.
  ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
                                      fusableDependences, ret.fusedLoopDims);

  return ret;
}

Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
  case LinalgTilingLoopType::ParallelLoops:
  case LinalgTilingLoopType::TiledLoops:
    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
  default:;
  }
  return llvm::None;
}