xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision e3de249a4c94d6962b36c2b4747c134d152bed37)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Support/LLVM.h"
28 #include "mlir/Transforms/FoldUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
41 
42 using llvm::dbgs;
43 
44 /// Implements a simple high-level fusion pass of linalg library operations.
45 ///
46 /// In each block, linalg ops are processed in reverse textual order.
47 /// Given a linalg op `O`, fusion occurs by:
48 ///   1. inspecting the linalg ops that write into the views read by `O`. This
49 ///      uses the SSA value of the views and a simple subview/slice analysis to
50 ///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subview;
///   3. inspecting the fused ops and determining whether they have other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
59 // a subset of the original loop ranges of `op`.
60 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
61 // to the `loopRanges` in order to obtain view ranges.
62 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
63                                     ArrayRef<Range> loopRanges) {
64   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
65   auto maps = op.indexing_maps();
66   SmallVector<Value, 8> clonedViews;
67   clonedViews.reserve(op.getNumInputsAndOutputs());
68   // Iterate over the inputs and outputs in order.
69   // Extract the subranges from the linearized ranges.
70   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
71   for (auto en : llvm::enumerate(ios)) {
72     unsigned idx = en.index();
73     auto map = maps[idx].cast<AffineMapAttr>().getValue();
74     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
75     Value view = en.value();
76     SmallVector<Range, 4> viewRanges(map.getNumResults());
77     for (auto en2 : llvm::enumerate(map.getResults())) {
78       unsigned d = en2.index();
79       // loopToOperandRangesMaps are permutations-only.
80       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
81       viewRanges[d] = loopRanges[loopPos];
82       LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
83                         << "\t"
84                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
85     }
86     // Construct a new subview for the tile.
87     unsigned rank = viewRanges.size();
88     SmallVector<Value, 4> offsets, sizes, strides;
89     offsets.reserve(rank);
90     sizes.reserve(rank);
91     strides.reserve(rank);
92     for (auto r : viewRanges) {
93       offsets.push_back(r.offset);
94       sizes.push_back(r.size);
95       strides.push_back(r.stride);
96     }
97     clonedViews.push_back(
98         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
99   }
100   auto operands = getAssumedNonViewOperands(op);
101   clonedViews.append(operands.begin(), operands.end());
102 
103   Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
104   // When the producer is an IndexedGenercOp, we have to transform its block
105   // IV arguments according to the tiling of the consumer, i.e. offset them by
106   // the values computed in `loopRanges`.
107   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
108     auto &block = indexedGenericOp.region().front();
109 
110     OpBuilder::InsertionGuard g(b);
111     b.setInsertionPointToStart(&block);
112     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
113       Value oldIndex = block.getArgument(i);
114       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
115                                          loopRanges[i].offset);
116       oldIndex.replaceAllUsesExcept(newIndex,
117                                     SmallPtrSet<Operation *, 1>{newIndex});
118     }
119   }
120   return clonedOp;
121 }
122 
123 struct ViewDimension {
124   Value view;
125   unsigned dimension;
126 };
127 
128 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
129 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
130 // guarantees at least one such dimension is found. If multiple candidates exist
131 // they must agree by construction (i.e. have the same size) and we just return
132 // the first one.
133 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
134   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
135   auto maps = op.indexing_maps();
136   // Iterate over the inputs and outputs in order.
137   // Extract the subranges from the linearized ranges.
138   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
139   for (auto en : llvm::enumerate(ios)) {
140     unsigned idx = en.index();
141     auto map = maps[idx].cast<AffineMapAttr>().getValue();
142     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
143     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
144     Value view = en.value();
145     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
146     for (auto en2 : llvm::enumerate(map.getResults())) {
147       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
148         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
149                           << "\n");
150         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
151         return ViewDimension{view, static_cast<unsigned>(en2.index())};
152       }
153     }
154   }
155   llvm_unreachable("Expect to be able to extract a view defining loop range");
156 }
157 
158 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
159                      LinalgOp consumer, unsigned consumerIdx,
160                      OperationFolder *folder = nullptr) {
161   assert(producer.hasBufferSemantics() &&
162          "expected linalg op with buffer semantics");
163   assert(consumer.hasBufferSemantics() &&
164          "expected linalg op with buffer semantics");
165 
166   auto subView = dyn_cast_or_null<SubViewOp>(
167       consumer.getBuffer(consumerIdx).getDefiningOp());
168   auto slice = dyn_cast_or_null<SliceOp>(
169       consumer.getBuffer(consumerIdx).getDefiningOp());
170   assert(subView || slice);
171   (void)subView;
172   (void)slice;
173 
174   // loopToOperandRangesMaps are permutations-only by construction:
175   //   we can always identify a data dimension with a (at least one) loop
176   //   dimension.
177   AffineMap producerMap =
178       producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
179   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
180                     << ", producer map: " << producerMap << "\n");
181 
182   unsigned nPar = producer.getNumParallelLoops();
183   unsigned nRed = producer.getNumReductionLoops();
184   unsigned nWin = producer.getNumWindowLoops();
185   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
186 
187   // Iterate over dimensions identified by the producer map for `producerIdx`.
188   // This defines a subset of the loop ranges that we need to complete later.
189   auto loc = consumer.getLoc();
190   for (auto en : llvm::enumerate(producerMap.getResults())) {
191     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
192     loopRanges[posInProducerLoop] =
193         subView.getOrCreateRanges(b, loc)[en.index()];
194   }
195 
196   // Iterate over all dimensions. For the dimensions not identified by the
197   // producer map for `producerIdx`, we need to explicitly compute the view that
198   // defines the loop ranges using the `producer`.
199   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
200     if (loopRanges[i].offset)
201       LLVM_DEBUG(llvm::dbgs()
202                  << "existing LoopRange: " << loopRanges[i] << "\n");
203     else {
204       auto viewDim = getViewDefiningLoopRange(producer, i);
205       loopRanges[i] = Range{folded_std_constant_index(folder, 0),
206                             std_dim(viewDim.view, viewDim.dimension),
207                             folded_std_constant_index(folder, 1)};
208       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
209     }
210   }
211 
212   return cloneWithLoopRanges(b, loc, producer, loopRanges);
213 }
214 
215 // Encode structural fusion safety preconditions.
216 // Some of these will be lifted in the future with better analysis.
217 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
218                                           LinalgOp consumer) {
219   assert(producer.hasBufferSemantics() &&
220          "expected linalg op with buffer semantics");
221   assert(consumer.hasBufferSemantics() &&
222          "expected linalg op with buffer semantics");
223   if (producer.getNumOutputs() != 1) {
224     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
225     return false;
226   }
227   // Only fuse when the producer block dominates.
228   DominanceInfo dom(producer.getOperation());
229   if (!dom.dominates(producer.getOperation()->getBlock(),
230                      consumer.getOperation()->getBlock())) {
231     LLVM_DEBUG(
232         dbgs()
233         << "\nNot structurally fusable (producer block does not dominate)");
234     return false;
235   }
236   return true;
237 }
238 
239 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
240                                              LinalgOp consumer,
241                                              Value consumedView,
242                                              LinalgOp producer) {
243   assert(producer.hasBufferSemantics() &&
244          "expected linalg op with buffer semantics");
245   assert(consumer.hasBufferSemantics() &&
246          "expected linalg op with buffer semantics");
247   // Make some simple structural checks that alleviate the need for more
248   // complex analyses.
249   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
250     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
251                       << *producer.getOperation());
252     return false;
253   }
254   // Check for any interleaved write to consumedView.
255   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
256     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
257                       << *producer.getOperation());
258     return false;
259   }
260   return true;
261 }
262 
263 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
264                                  LinalgOp consumer, Value consumedView,
265                                  LinalgOp producer) {
266   assert(producer.hasBufferSemantics() &&
267          "expected linalg op with buffer semantics");
268   assert(consumer.hasBufferSemantics() &&
269          "expected linalg op with buffer semantics");
270   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
271     return false;
272   // Check for any fusion-preventing dependence to any view read/written that
273   // would violate dependences.
274   if (!graph.findCoveringDependences(producer, consumer).empty()) {
275     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
276                       << *producer.getOperation());
277     return false;
278   }
279   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
280     // TODO: add a level of indirection to linalg.generic.
281     if (convOp.padding())
282       return false;
283   }
284   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
285     // TODO: add a level of indirection to linalg.generic.
286     if (convOp.padding())
287       return false;
288   }
289   return true;
290 }
291 
292 static bool isSameSubView(Value a, Value b) {
293   if (a == b)
294     return true;
295   auto sva = a.getDefiningOp<SubViewOp>();
296   auto svb = b.getDefiningOp<SubViewOp>();
297   if (!sva || !svb)
298     return false;
299   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
300     return false;
301   if (sva.getType() != svb.getType())
302     return false;
303   if (sva.getNumOperands() != svb.getNumOperands())
304     return false;
305   if (sva.static_offsets() != svb.static_offsets())
306     return false;
307   if (sva.static_sizes() != svb.static_sizes())
308     return false;
309   if (sva.static_strides() != svb.static_strides())
310     return false;
311   /// Skip the "viewSource" operand.
312   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
313     if (sva.getOperand(idx) != svb.getOperand(idx))
314       return false;
315   return true;
316 }
317 
318 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
319 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
320                     const LinalgDependenceGraph &dependenceGraph) {
321   // Only consider RAW and WAW atm.
322   for (auto depType : {
323            LinalgDependenceGraph::DependenceType::RAW,
324            LinalgDependenceGraph::DependenceType::WAW,
325        }) {
326     for (auto dependence :
327          dependenceGraph.getDependencesInto(consumer, depType)) {
328       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
329 
330       // Check that the dependence is indeed on the input `consumerIdx` view.
331       auto consumedView = dependence.indexingView;
332       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
333         continue;
334 
335       // Consumer consumes this view, `isStructurallyFusableProducer` also
336       // checks whether it is a strict subview of the producer view.
337       auto producedView = dependence.dependentOpView.view;
338       auto producerIdx =
339           producer.getIndexOfOutputBuffer(producedView).getValue();
340       // `consumerIdx` and `producerIdx` exist by construction.
341       LLVM_DEBUG(dbgs() << "\n"
342                         << LinalgDependenceGraph::getDependenceTypeStr(depType)
343                         << "producer: " << *producer.getOperation() << " view: "
344                         << producedView << " output index: " << producerIdx);
345       (void)producerIdx;
346 
347       // Simple fusability checks.
348       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
349         continue;
350 
351       return dependence;
352     }
353   }
354   return {};
355 }
356 
357 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
358     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
359     const LinalgDependenceGraph &graph, OperationFolder *folder) {
360   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
361       findFusableProducer(consumer, consumerIdx, graph);
362   if (!fusableDependence)
363     return {};
364 
365   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
366   Value producerView = fusableDependence->dependentOpView.view;
367   Value consumerView = fusableDependence->indexingView;
368 
369   // Must be a subview or a slice to guarantee there are loops we can fuse
370   // into.
371   auto subView = consumerView.getDefiningOp<SubViewOp>();
372   auto slice = consumerView.getDefiningOp<SliceOp>();
373   if (!subView && !slice) {
374     LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
375     return {};
376   }
377 
378   // Fuse `producer` just before `consumer`.
379   OpBuilder::InsertionGuard g(b);
380   b.setInsertionPoint(consumer.getOperation());
381   ScopedContext scope(b, consumer.getLoc());
382   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
383   Optional<unsigned> producerIdxOpt =
384       producerOp.getIndexOfInputAndOutputBuffer(producerView);
385   assert(producerIdxOpt.hasValue() && "incorrect operand index");
386   unsigned producerIdx = producerIdxOpt.getValue();
387 
388   auto fusedProducer =
389       fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
390   return FusionInfo{producerOp, fusedProducer};
391 }
392 
393 /// Returns the positions of the loop in `op` that can be tiled based on the
394 /// operations that are to be fused with it. For example, in a
395 ///
396 ///   linalg. matmul ins(%a, %b : ...) outs(%c : ...)
397 ///
398 /// if the producer of %a needs to be fused with this op, only the `i` loop of
399 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
400 /// fused, then no loops can be tiled while fusing.
401 static DenseSet<unsigned> collectTileAndFuseLoops(
402     LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
403                      fusableDependences) {
404   // 1. Only parallel loops can be used for tile + fuse. Find the number of
405   // common outer parallel loops between the op and its producers being fused.
406   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
407     return linalgOp.iterator_types()
408         .getValue()
409         .take_while([](Attribute attr) -> bool {
410           return attr.cast<StringAttr>().getValue() ==
411                  getParallelIteratorTypeName();
412         })
413         .size();
414   };
415 
416   size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
417   for (auto dependence : fusableDependences) {
418     numOuterParallelLoops =
419         std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
420                                             dependence.dependentOpView.op)));
421   }
422 
423   // Need to compute what tiled loops can be "fused". Given the precondition
424   // that all indexing map for the producer view is a projected permutation, we
425   // can assert that the producer iterates over the dimensions of the "fused
426   // view" only once. To be used a fused loop the producer should use this loop
427   // to access the fused view. For example, consider
428   //
429   // ```
430   //   linalg.add ins(%a, %b) outs(%c)
431   //   linalg.matmul ins(%d, %c) outs(%e)
432   // ```
433   //
434   // if `linalg.add` has the semantics of `c = a + b`, then the following
435   // tile+fuse code is correct.
436   //
437   // ```
438   // for j ... += TSj
439   //   %sa = subview %a[0, %j][...]
440   //   %sb = subview %b[0, %j][...]
441   //   %sc = subview %c[0, %j][...]
442   //   %sd = subview %d[0, 0][...]
443   //   %se = subview %e[0, %j][...]
444   //   linalg.add ins(%sa, %sb) outs(%sc)
445   //   linalg.matmul ins(%sd, %sc) outs(%se)
446   // ```
447   //
448   // On the other hand tiling along i would be incorrect
449   //
450   // ```
451   // for %i .. += TSi
452   //   %sa = subview %a[%i, 0][...]
453   //   %sb = subview %b[%i, 0][...]
454   //   %sc = subview %c[%i, 0][...]
455   //   %sc2 = subview %c[0, 0][...]
456   //   %sd = subview %d[%i, 0][...]
457   //   %se = subview %e[%i, 0][...]
458   //   linalg.add ins(%sa, %sb) outs(%sc)
459   //   linalg.matmul ins(%sd, %sc2) outs(%se)
460   // ```
461   //
462   // The write to the subview `%sc` in `linalg.add` is performed after the read
463   // from it using `%sc2` violating the RAW dependence of the original code. To
464   // find such loops indexing map of the fused view in the consumer op is
465   // used. For the above example, this indexing map is
466   //
467   //   affine_map<(d0, d1, d2) -> (d2, d1)>
468   //
469   // Since d0 is not in the result expressions of this map, it is not treated as
470   // tile + fuse loop, (but d1 is).
471   //
472   // TODO: The above is probably restrictive and there might be a generalization
473   // of these that might allow for more fusion opportunities. Explore based on
474   // needs.
475   SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
476   for (auto dependence : fusableDependences) {
477     unsigned consumerIdx =
478         op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
479     AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
480     // Previously asserted that the consumerAccess map is a projected
481     // permutation, so all results are known to be AffineDimExprs. To remove
482     // this restriction walk the expression to find which dimensions of the
483     // consumer loop appear in the `consumerAccess`.
484     DenseSet<unsigned> positions;
485     for (auto expr : consumerAccess.getResults())
486       positions.insert(expr.cast<AffineDimExpr>().getPosition());
487     commonTilableLoops.emplace_back(std::move(positions));
488   }
489 
490   // 2. Of the outer parallel loops, only those loops can be tiled + fused as
491   // computed above for all the fused dependences can be used to tile and fuse.
492   DenseSet<unsigned> tilableParallelLoops;
493   for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
494     if (llvm::all_of(commonTilableLoops,
495                      [&](const DenseSet<unsigned> &tilableLoops) {
496                        return tilableLoops.count(index);
497                      }))
498       tilableParallelLoops.insert(index);
499   }
500   return tilableParallelLoops;
501 }
502 
503 /// Find all dependences that are to be fusable.
504 static Optional<
505     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
506 findAllFusableDependences(LinalgOp op,
507                           const LinalgDependenceGraph &dependenceGraph,
508                           const LinalgFusionOptions &fusionOptions) {
509   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
510       fusableDependences;
511   for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
512     if (fusionOptions.indicesToFuse &&
513         !fusionOptions.indicesToFuse->count(operand.index()))
514       continue;
515     Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
516         fusableDependence =
517             findFusableProducer(op, operand.index(), dependenceGraph);
518     if (!fusableDependence)
519       continue;
520     // Make sure that the indexing map of the view used for fusion in the
521     // producer is a projected permutation.
522     LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
523     Value producerView = fusableDependence->dependentOpView.view;
524     unsigned producerIdx =
525         producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
526     AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
527     if (!producerMap.isProjectedPermutation()) {
528       op.emitError("unhandled non permutation indexing map for fused view in "
529                    "producer for operand at index ")
530           << operand.index();
531       return llvm::None;
532     }
533     Value consumerView = fusableDependence->indexingView;
534     unsigned consumerIdx =
535         op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
536     if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
537       op.emitError(
538           "unhandled case where indexing map for fused view in the consumer is "
539           "not a projected permuration while fusing at index ")
540           << operand.index();
541       return llvm::None;
542     }
543     fusableDependences.push_back(*fusableDependence);
544     if (!fusionOptions.indicesToFuse)
545       break;
546   }
547   return fusableDependences;
548 }
549 
550 static bool isZero(Value v) {
551   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
552     return cst.getValue() == 0;
553   return false;
554 }
555 
556 template <typename LoopType>
557 static Optional<TiledAndFusedLinalgOps>
558 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
559                          const LinalgDependenceGraph &dependenceGraph,
560                          const LinalgTilingOptions &tilingOptions,
561                          const LinalgFusionOptions &fusionOptions) {
562   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
563   // Some of the tiling options might not be supportable with tile and fuse.
564   // TODO: Support interchange with tile + fuse.
565   if (!tilingOptions.interchangeVector.empty()) {
566     op.emitError("unable to handle tile and fuse with interchange");
567     return llvm::None;
568   }
569 
570   OpBuilder::InsertionGuard g(rewriter);
571   rewriter.setInsertionPoint(op);
572   ScopedContext scope(rewriter, op.getLoc());
573 
574   // Find all the producers.
575   Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
576       fusableDependencesOpt =
577           findAllFusableDependences(op, dependenceGraph, fusionOptions);
578   if (!fusableDependencesOpt)
579     return llvm::None;
580   ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
581       *fusableDependencesOpt);
582 
583   // Enforce the convention that "tiling by zero" skips tiling a particular
584   // dimension. This convention is significantly simpler to handle instead of
585   // adjusting affine maps to account for missing dimensions.
586   auto nLoops = op.getNumLoops();
587   SmallVector<Value, 4> tileSizeVector =
588       tilingOptions.tileSizeComputationFunction(rewriter, op);
589   if (tileSizeVector.size() < nLoops) {
590     auto zero = std_constant_index(0);
591     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
592   }
593 
594   TiledAndFusedLinalgOps ret;
595 
596   // Find the loops that can be tiled and fused.
597   DenseSet<unsigned> tileFuseLoops =
598       collectTileAndFuseLoops(op, fusableDependences);
599 
600   // If there are no fusable dependences or there are no tile+fusable loops,
601   // just return.
602   if (fusableDependences.empty() || tileFuseLoops.empty()) {
603     return llvm::None;
604   }
605 
606   // Get the tile sizes for the first and second tiling steps. For the first
607   // step the tile size are set to zero for the loops that arent
608   // fused. Similarly for the second step, the tile sizes are set to zero for
609   // the loops that are fused. For example, if for the following input
610   //
611   // ```
612   //   linalg.add ins(%a, %b) outs(%c)
613   //   linalg.matmul ins(%d, %c) outs(%e)
614   // ```
615   //
616   // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
617   // respectively, and since only `j` can be tiled and fused. The tile sizes
618   // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
619   // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
620   // the tiled matmul generated by the first tiling step.
621   SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
622   for (auto tileSize : enumerate(tileSizeVector)) {
623     auto zero = std_constant_index(0);
624     if (tileFuseLoops.count(tileSize.index())) {
625       tileAndFuseSizes.push_back(tileSize.value());
626       tileSizes.push_back(zero);
627     } else {
628       tileSizes.push_back(tileSize.value());
629       tileAndFuseSizes.push_back(zero);
630     }
631   }
632 
633   // Tile for the loops that can be fused.
634   LinalgTilingOptions firstTilingOptions = tilingOptions;
635   firstTilingOptions.setTileSizes(tileAndFuseSizes);
636   Optional<TiledLinalgOp> firstTiledOp =
637       tileLinalgOp(rewriter, op, firstTilingOptions);
638   if (!firstTiledOp)
639     return llvm::None;
640   ret.op = firstTiledOp->op;
641   ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
642 
643   rewriter.setInsertionPoint(ret.op);
644   // Fuse the operands.
645   for (auto producer : enumerate(fusableDependences)) {
646     LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
647     unsigned producerIdx = producerOp
648                                .getIndexOfInputAndOutputBuffer(
649                                    producer.value().dependentOpView.view)
650                                .getValue();
651     unsigned consumerIdx =
652         op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
653             .getValue();
654     LinalgOp fusedOp =
655         fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
656     ret.fusedProducers.push_back(fusedOp);
657     ret.originalProducers.push_back(producerOp);
658   }
659 
660   if (!llvm::all_of(tileSizes, isZero)) {
661     // Tile the remaining loops of the root operation.
662     LinalgTilingOptions secondTilingOptions = tilingOptions;
663     // The distribution is done only for the tile+fused loops.
664     secondTilingOptions.distribution = llvm::None;
665     secondTilingOptions.setTileSizes(tileSizes);
666     Optional<TiledLinalgOp> secondTiledOp =
667         tileLinalgOp(rewriter, ret.op, secondTilingOptions);
668     if (!secondTiledOp)
669       return llvm::None;
670     ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
671                             secondTiledOp->loops.end());
672     rewriter.eraseOp(ret.op);
673     ret.op = secondTiledOp->op;
674   }
675 
676   return ret;
677 }
678 
679 Optional<TiledAndFusedLinalgOps>
680 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
681                                    const LinalgDependenceGraph &dependenceGraph,
682                                    const LinalgTilingOptions &tilingOptions,
683                                    const LinalgFusionOptions &fusionOptions) {
684   switch (tilingOptions.loopType) {
685   case LinalgTilingLoopType::Loops:
686     return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
687                                                 tilingOptions, fusionOptions);
688   case LinalgTilingLoopType::ParallelLoops:
689     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
690         rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
691   default:;
692   }
693   return llvm::None;
694 }
695 
696 static void fuseLinalgOpsGreedily(FuncOp f) {
697   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
698 
699   OpBuilder b(f);
700   OperationFolder folder(f.getContext());
701   DenseSet<Operation *> eraseSet;
702 
703   // Save original Linalg ops, we only want to make a pass over those.
704   SmallVector<Operation *, 8> linalgOps;
705   f.walk([&](LinalgOp op) {
706     if (op.hasBufferSemantics())
707       linalgOps.push_back(op);
708   });
709 
710   // TODO: LinalgDependenceGraph should be able to update itself.
711   // The current naive and expensive reconstruction of the graph should be
712   // removed.
713   for (auto *op : llvm::reverse(linalgOps)) {
714     for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
715          id < e; ++id) {
716       linalg::Aliases aliases;
717       linalg::LinalgDependenceGraph graph(aliases, linalgOps);
718       if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
719         auto *originalOp = info->originalProducer.getOperation();
720         eraseSet.insert(originalOp);
721         auto *originalOpInLinalgOpsVector =
722             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
723         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
724       }
725     }
726   }
727   // The `fuseProducerOf` function performs structural checks and in particular
728   // that no covering read or write exist between the consumer and the producer.
729   // As a consequence, the only fusions that may occur preserve subsequent
730   // dependences and are guaranteed by construction to produce the whole view.
731   // We may thus erase the producer once it is fused.
732   for (auto *e : eraseSet)
733     e->erase();
734   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
735 }
736 
737 namespace {
738 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
739   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
740 };
741 } // namespace
742 
743 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
744   return std::make_unique<LinalgFusionPass>();
745 }
746