//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;

using llvm::dbgs;
/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. There
///      are 2 cases:
///      a) buffer case: use the SSA value of the views and a simple alias
///         analysis on subview ops to determine producer-consumer dependences;
///      b) tensor case: use SSA use-def chains on subtensor ops;
///   2. greedily fusing the linalg ops that produce the subview/subtensor;
///   3. inspecting the fused ops and determining whether they have other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses, and profitability heuristics are left
/// for future work.
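///
/// As a quick illustration of the buffer case (a hand-written sketch, not the
/// verbatim output of this pass), consider a producer whose result buffer is
/// read, through a subview, by a consumer nested inside tiled loops:
///
/// ```
///   linalg.matmul ins(%A, %B : ...) outs(%C : ...)
///   scf.for %i = ... {
///     %sv = subview %C[%i, 0] [...] [...] : memref<?x?xf32> to memref<...>
///     linalg.matmul ins(%sv, %D : ...) outs(%E : ...)
///   }
/// ```
///
/// Fusion clones the producer inside the loop, just before the consumer, and
/// restricts it to recompute only the tile that defines %sv.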

// Fill `offsets`, `sizes` and `strides` used to iterate over the shape indexed
// by `permutationMap`.
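// For example (a sketch): with permutationMap = (d0, d1, d2) -> (d2, d0) and
// loopRanges = [r0, r1, r2], the inferred shape ranges are [r2, r0], so
// offsets = [r2.offset, r0.offset], and similarly for sizes and strides.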
static void inferShapeComponents(AffineMap permutationMap,
                                 ArrayRef<Range> loopRanges,
                                 SmallVectorImpl<Value> &offsets,
                                 SmallVectorImpl<Value> &sizes,
                                 SmallVectorImpl<Value> &strides) {
  assert(permutationMap.isProjectedPermutation() &&
         "expected some subset of a permutation map");
  SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
  unsigned idx = 0;
  for (AffineExpr e : permutationMap.getResults()) {
    // loopToOperandRangesMaps are permutations-only, just swap indices.
    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
    shapeRanges[idx++] = loopRanges[loopPos];
  }
  // Construct a new subshape for the tile.
  unsigned rank = shapeRanges.size();
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (auto r : shapeRanges) {
    offsets.push_back(r.offset);
    sizes.push_back(r.size);
    strides.push_back(r.stride);
  }
}

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
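// For example (a sketch): for a matmul producer with loops (i, j, k) and
// loopRanges = [ri, rj, rk], the operand with indexing map (i, k) is replaced
// by subview %A[ri.offset, rk.offset][ri.size, rk.size][ri.stride, rk.stride]
// (and similarly for the other shaped operands) before cloning the op.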
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<Range> loopRanges) {
  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(op.getNumShapedOperands());

  // Iterate over the shape operands in order.
  // Extract the subranges from the linearized ranges.
  for (auto en : llvm::enumerate(op.getShapedOperands())) {
    unsigned shapedOperandIdx = en.index();
    AffineMap map = op.getIndexingMap(shapedOperandIdx);
    LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
                      << " with indexingMap: " << map << "\n");
    SmallVector<Value, 4> offsets, sizes, strides;
    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
    Value shape = en.value();
    Value sub = shape.getType().isa<MemRefType>()
                    ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
                          .getResult()
                    : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
                          .getResult();
    clonedShapes.push_back(sub);
  }
  // Append the other operands.
  auto operands = op.getAssumedNonShapedOperands();
  clonedShapes.append(operands.begin(), operands.end());

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(op.getOperation()->getNumResults());
  for (RankedTensorType t : op.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(SubTensorOp::inferResultType(
        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
        staticStridesVector));
  }

  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      // TODO: replace by an affine_apply.
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }

  return clonedOp;
}

struct ShapeDimension {
  Value shape;
  unsigned dimension;
};

// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantee that at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e., have the same size), and we
// just return the first one.
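// For example (a sketch): for a matmul with indexing maps
// [(i, k), (k, j), (i, j)] and loopDepth = 1 (the j loop), the first shaped
// operand whose map mentions d1 is the second input, at result position 1, so
// {shape = %B, dimension = 1} is returned.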
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
                                                unsigned loopDepth) {
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
    Value shape = en.value();
    SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
                          << loopDepth << "\n");
        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
                          << "\n");
        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a shape defining loop range");
}

/// Fuses the producer of `producerIdx` into the loop immediately enclosing
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
/// is needed just before the `consumer`.
///
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
/// 2 cases:
///   1. Buffer case: `producerIdx` is the index of the buffer in
///      `producer.getOutputBuffers()`.
///   2. Tensor case: `producerIdx` is the index of the tensor in
///      `producer.getResults()`.
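///
/// For example (a sketch, buffer case): if the consumer reads
/// %sv = subview %C[%i, %j][%ti, %tj][1, 1] and `producer` wrote %C, the
/// ranges of %sv determine the loop ranges of the producer dimensions that
/// appear in its output indexing map; the remaining producer loops keep their
/// full ranges, and the producer is cloned on the matching subviews.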
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                     LinalgOp consumer, unsigned consumerIdx,
                     OperationFolder *folder = nullptr) {
  Operation *shapeProducingOp =
      consumer.getShapedOperand(consumerIdx).getDefiningOp();
  assert((isa<SubViewOp>(shapeProducingOp) ||
          isa<SubTensorOp>(shapeProducingOp)) &&
         "SubViewOp or SubTensorOp expected");

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with at least one loop dimension.
  // TODO: extend this with range inference.
  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  auto loc = consumer.getLoc();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        isa<SubViewOp>(shapeProducingOp)
            ? cast<SubViewOp>(shapeProducingOp)
                  .getOrCreateRanges(b, loc)[en.index()]
            : cast<SubTensorOp>(shapeProducingOp)
                  .getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the shape
  // that defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto shapeDim = getShapeDefiningLoopRange(producer, i);
      loopRanges[i] = Range{folded_std_constant_index(folder, 0),
                            std_dim(shapeDim.shape, shapeDim.dimension),
                            folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

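// For example (a sketch), the fusion below must be rejected: the producer is
// not the last write of %A at the consumer, because linalg.copy performs an
// interleaved, covering write to %A:
//
//   linalg.fill(%A, %cst)                            // producer
//   linalg.copy(%B, %A)                              // interleaved write
//   linalg.matmul ins(%A, %B : ...) outs(%C : ...)   // consumer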
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence on any shape read or written
  // that fusion would violate.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

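// Compares two views for structural equality: either the same SSA value, or
// SubViewOps of (recursively) the same source with identical result types,
// static offsets/sizes/strides, and identical dynamic operands.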
static bool isSameSubView(Value a, Value b) {
  if (a == b)
    return true;
  auto sva = a.getDefiningOp<SubViewOp>();
  auto svb = b.getDefiningOp<SubViewOp>();
  if (!sva || !svb)
    return false;
  if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
    return false;
  if (sva.getType() != svb.getType())
    return false;
  if (sva.getNumOperands() != svb.getNumOperands())
    return false;
  if (sva.static_offsets() != svb.static_offsets())
    return false;
  if (sva.static_sizes() != svb.static_sizes())
    return false;
  if (sva.static_strides() != svb.static_strides())
    return false;
  // Skip the "source" operand.
  for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
    if (sva.getOperand(idx) != svb.getOperand(idx))
      return false;
  return true;
}

static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
                    const LinalgDependenceGraph &dependenceGraph) {
  // Only consider RAW and WAW dependences at the moment.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    for (auto dependence :
         dependenceGraph.getDependencesInto(consumer, depType)) {
      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

      // Check that the dependence is indeed on the input `consumerIdx` view.
      auto consumedView = dependence.indexingView;
      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
        continue;

      // Consumer consumes this view; `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producedView = dependence.dependentOpView.view;
      auto producerIdx =
          producer.getIndexOfOutputBuffer(producedView).getValue();
      // `consumerIdx` and `producerIdx` exist by construction.
      LLVM_DEBUG(dbgs() << "\n"
                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
                        << " producer: " << *producer.getOperation()
                        << " view: " << producedView
                        << " output index: " << producerIdx);
      (void)producerIdx;

      // Simple fusability checks.
      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
        continue;

      return dependence;
    }
  }
  return {};
}

Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumer, consumerIdx, graph);
  if (!fusableDependence)
    return {};

  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
  Value producerView = fusableDependence->dependentOpView.view;
  Value consumerView = fusableDependence->indexingView;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerView.getDefiningOp<SubViewOp>();
  auto slice = consumerView.getDefiningOp<SliceOp>();
  if (!subView && !slice) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
    return {};
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  Optional<unsigned> producerIdxOpt =
      producerOp.getIndexOfOutputBuffer(producerView);
  assert(producerIdxOpt.hasValue() && "incorrect operand index");
  unsigned producerIdx = producerIdxOpt.getValue();

  auto fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
  return FusionInfo{producerOp, fusedProducer};
}

/// Walk back the use-def chain through scf::For yields.
/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp.
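// For example (a sketch), starting from %st the walk looks through the
// subtensor and finds the matmul as the producer, with output index 0:
//
//   %t  = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
//   %st = subtensor %t[...] [...] [...] : tensor<?x?xf32> to tensor<?x?xf32>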
static void getProducerOfTensor(Value tensor, LinalgOp &producer,
                                unsigned &outputIndex) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
      producer = linalgOp;
      outputIndex = tensor.cast<OpResult>().getResultNumber();
      return;
    }
    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
      tensor = subTensorOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
        tensor = forOp.getResult(blockArg.getArgNumber());
        continue;
      }
    }
    return;
  }
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
                                   unsigned consumerIdx,
                                   OperationFolder *folder) {
  Value inputTensor = consumer.getInput(consumerIdx);
  LinalgOp producerOp;
  unsigned producerIdx;
  getProducerOfTensor(inputTensor, producerOp, producerIdx);

  // Must be a subtensor to guarantee there are loops we can fuse into.
  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
  if (!subTensor || !producerOp) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
    return {};
  }

  // Insert fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  LinalgOp fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);

  // Replace use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor_cast` op to propagate the transformation
  // invariant that types are compatible.
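  // E.g. (a sketch): if the consumer operand type is tensor<4x4xf32> but the
  // freshly created subtensor is fully dynamic (tensor<?x?xf32>), the cast
  // below reconciles the two types.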
  Value def = fusedProducer.getOperation()->getResult(producerIdx);
  OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
  Type consumerType = use.get().getType();
  if (consumerType != def.getType())
    def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
  use.set(def);
  return FusionInfo{producerOp, fusedProducer};
}

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of %a and %b are
/// both to be fused, then no loops can be tiled while fusing.
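/// Concretely (a sketch of the reasoning): the matmul iterates loops (i, j, k)
/// with iterator types [parallel, parallel, reduction], and %a is accessed
/// through (i, j, k) -> (i, k). Of the outer parallel loops {i, j}, only i
/// appears in that map, so only i is tile+fusable for %a; symmetrically only j
/// for %b, and the intersection of the two sets is empty.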
static DenseSet<unsigned> collectTileAndFuseLoops(
    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
                     fusableDependences) {
  // 1. Only parallel loops can be used for tile + fuse. Find the number of
  // common outer parallel loops between the op and its producers being fused.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
  for (auto dependence : fusableDependences) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
                                            dependence.dependentOpView.op)));
  }

  // Need to compute which tiled loops can be "fused". Given the precondition
  // that the indexing map for each producer view is a projected permutation,
  // we can assert that the producer iterates over the dimensions of the "fused
  // view" only once. To be usable as a fused loop, the producer should use
  // this loop to access the fused view. For example, consider
  //
  // ```
  //   linalg.add ins(%a, %b) outs(%c)
  //   linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if `linalg.add` has the semantics of `c = a + b`, then the following
  // tile+fuse code is correct.
  //
  // ```
  // for %j ... += TSj
  //   %sa = subview %a[0, %j][...]
  //   %sb = subview %b[0, %j][...]
  //   %sc = subview %c[0, %j][...]
  //   %sd = subview %d[0, 0][...]
  //   %se = subview %e[0, %j][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc) outs(%se)
  // ```
  //
  // On the other hand, tiling along i would be incorrect:
  //
  // ```
  // for %i .. += TSi
  //   %sa = subview %a[%i, 0][...]
  //   %sb = subview %b[%i, 0][...]
  //   %sc = subview %c[%i, 0][...]
  //   %sc2 = subview %c[0, 0][...]
  //   %sd = subview %d[%i, 0][...]
  //   %se = subview %e[%i, 0][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc2) outs(%se)
  // ```
  //
  // The write to the subview `%sc` in `linalg.add` is performed after the read
  // from it using `%sc2`, violating the RAW dependence of the original code.
  // To find such loops, the indexing map of the fused view in the consumer op
  // is used. For the above example, this indexing map is
  //
  //   affine_map<(d0, d1, d2) -> (d2, d1)>
  //
  // Since d0 is not in the result expressions of this map, it is not treated
  // as a tile + fuse loop (but d1 is).
  //
  // TODO: The above is probably restrictive and there might be a
  // generalization of these that might allow for more fusion opportunities.
  // Explore based on needs.
  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
  for (auto dependence : fusableDependences) {
    unsigned consumerIdx =
        op.getIndexOfShapedOperand(dependence.indexingView).getValue();
    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
    // Previously asserted that the consumerAccess map is a projected
    // permutation, so all results are known to be AffineDimExprs. To remove
    // this restriction walk the expression to find which dimensions of the
    // consumer loop appear in the `consumerAccess`.
    DenseSet<unsigned> positions;
    for (auto expr : consumerAccess.getResults())
      positions.insert(expr.cast<AffineDimExpr>().getPosition());
    commonTilableLoops.emplace_back(std::move(positions));
  }

  // 2. Of the outer parallel loops, only those that can be tiled + fused for
  // all the fused dependences, as computed above, can be used to tile and
  // fuse.
  DenseSet<unsigned> tilableParallelLoops;
  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
    if (llvm::all_of(commonTilableLoops,
                     [&](const DenseSet<unsigned> &tilableLoops) {
                       return tilableLoops.count(index);
                     }))
      tilableParallelLoops.insert(index);
  }
  return tilableParallelLoops;
}

/// Find all dependences that are fusable.
static Optional<
    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
findAllFusableDependences(LinalgOp op,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgFusionOptions &fusionOptions) {
  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
      fusableDependences;
  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
    if (fusionOptions.indicesToFuse &&
        !fusionOptions.indicesToFuse->count(operand.index()))
      continue;
    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
        fusableDependence =
            findFusableProducer(op, operand.index(), dependenceGraph);
    if (!fusableDependence)
      continue;
    // Make sure that the indexing map of the view used for fusion in the
    // producer is a projected permutation.
    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
    Value producerView = fusableDependence->dependentOpView.view;
    unsigned producerIdx =
        producerOp.getIndexOfOutputBuffer(producerView).getValue();
    AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
    if (!producerMap.isProjectedPermutation()) {
      op.emitError("unhandled non-permutation indexing map for fused view in "
                   "producer for operand at index ")
          << operand.index();
      return llvm::None;
    }
    Value consumerView = fusableDependence->indexingView;
    unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
      op.emitError(
          "unhandled case where indexing map for fused view in the consumer is "
          "not a projected permutation while fusing at index ")
          << operand.index();
      return llvm::None;
    }
    fusableDependences.push_back(*fusableDependence);
    if (!fusionOptions.indicesToFuse)
      break;
  }
  return fusableDependences;
}

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions,
                         const LinalgFusionOptions &fusionOptions) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  // Some of the tiling options might not be supported with tile and fuse.
  // TODO: Support interchange with tile + fuse.
  if (!tilingOptions.interchangeVector.empty()) {
    op.emitError("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  ScopedContext scope(rewriter, op.getLoc());

  // Find all the producers.
  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
      fusableDependencesOpt =
          findAllFusableDependences(op, dependenceGraph, fusionOptions);
  if (!fusableDependencesOpt)
    return llvm::None;
  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
      *fusableDependencesOpt);

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
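  // For example (a sketch): with 3 loops and computed tile sizes {8}, the
  // vector below is padded to {8, 0, 0}, i.e. only the outermost loop is
  // tiled.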
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  TiledAndFusedLinalgOps ret;

  // Find the loops that can be tiled and fused.
  DenseSet<unsigned> tileFuseLoops =
      collectTileAndFuseLoops(op, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (fusableDependences.empty() || tileFuseLoops.empty()) {
    return llvm::None;
  }

  // Get the tile sizes for the first and second tiling steps. For the first
  // step the tile sizes are set to zero for the loops that aren't fused.
  // Similarly for the second step, the tile sizes are set to zero for the
  // loops that are fused. For example, for the following input
  //
  // ```
  //   linalg.add ins(%a, %b) outs(%c)
  //   linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if the tile sizes of the `{i, j, k}` loops were given as `{ti, tj, tk}`
  // respectively, and since only `j` can be tiled and fused, the tile sizes
  // would be `{0, tj, 0}` for the first tiling, which tiles just the fusable
  // loops. The second tiling would then use tile sizes of `{ti, 0, tk}` to
  // tile the tiled matmul generated by the first tiling step.
  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
  for (auto tileSize : enumerate(tileSizeVector)) {
    auto zero = std_constant_index(0);
    if (tileFuseLoops.count(tileSize.index())) {
      tileAndFuseSizes.push_back(tileSize.value());
      tileSizes.push_back(zero);
    } else {
      tileSizes.push_back(tileSize.value());
      tileAndFuseSizes.push_back(zero);
    }
  }

  // Tile for the loops that can be fused.
  LinalgTilingOptions firstTilingOptions = tilingOptions;
  firstTilingOptions.setTileSizes(tileAndFuseSizes);
  Optional<TiledLinalgOp> firstTiledOp =
      tileLinalgOp(rewriter, op, firstTilingOptions);
  if (!firstTiledOp)
    return llvm::None;
  ret.op = firstTiledOp->op;
  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());

  rewriter.setInsertionPoint(ret.op);
  // Fuse the operands.
  for (auto producer : enumerate(fusableDependences)) {
    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
    unsigned producerIdx =
        producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
            .getValue();
    unsigned consumerIdx =
        op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
    LinalgOp fusedOp =
        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
    ret.fusedProducers.push_back(fusedOp);
    ret.originalProducers.push_back(producerOp);
  }

  if (!llvm::all_of(tileSizes, isZero)) {
    // Tile the remaining loops of the root operation.
    LinalgTilingOptions secondTilingOptions = tilingOptions;
    // The distribution is done only for the tile+fused loops.
    secondTilingOptions.distribution = llvm::None;
    secondTilingOptions.setTileSizes(tileSizes);
    Optional<TiledLinalgOp> secondTiledOp =
        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
    if (!secondTiledOp)
      return llvm::None;
    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
                            secondTiledOp->loops.end());
    rewriter.eraseOp(ret.op);
    ret.op = secondTiledOp->op;
  }

  return ret;
}

Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions,
                                   const LinalgFusionOptions &fusionOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
                                                tilingOptions, fusionOptions);
  case LinalgTilingLoopType::ParallelLoops:
    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
  default:;
  }
  return llvm::None;
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    // TODO: support multi-results.
    if (op.getOperation()->getNumResults() <= 1)
      linalgOps.push_back(op);
  });

  // Tile and fuse tensor inputs (TODO: all tensor operands).
  for (auto *op : llvm::reverse(linalgOps)) {
    LinalgOp linalgOp = cast<LinalgOp>(op);
    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
      if (en.value().getType().isa<MemRefType>()) {
        // TODO: LinalgDependenceGraph should be able to update itself.
        // The current naive and expensive reconstruction of the graph should be
        // removed.
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
        if (auto info =
                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
          auto *originalOp = info->originalProducer.getOperation();
          eraseSet.insert(originalOp);
          auto *originalOpInLinalgOpsVector =
              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      } else {
        assert(en.value().getType().isa<RankedTensorType>());
        // Tile and fuse tensor input (TODO: init_tensors too).
        if (en.index() >= linalgOp.getNumInputs())
          continue;
        if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
          auto *originalOp = info->originalProducer.getOperation();
          auto *originalOpInLinalgOpsVector =
              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
          // Don't mark for erasure in the tensor case; let DCE handle this.
        }
      }
    }
  }
  // The `fuseProducerOfBuffer` function performs structural checks, and in
  // particular ensures that no covering read or write exists between the
  // consumer and the producer. As a consequence, the only fusions that may
  // occur preserve subsequent dependences and are guaranteed by construction
  // to produce the whole view. We may thus erase the producer once it is
  // fused.
  for (auto *e : eraseSet)
    e->erase();

  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}