xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision c694588fc52a8845174fee06ad0bcfa338e87816)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Support/LLVM.h"
28 #include "mlir/Transforms/FoldUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
41 
42 using llvm::dbgs;
43 
44 /// Implements a simple high-level fusion pass of linalg library operations.
45 ///
46 /// In each block, linalg ops are processed in reverse textual order.
47 /// Given a linalg op `O`, fusion occurs by:
48 ///   1. inspecting the linalg ops that write into the views read by `O`. This
49 ///      uses the SSA value of the views and a simple subview/slice analysis to
50 ///      determine producer-consumer dependences;
51 ///   2. greedily fuse the linalg ops that produce subview
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only, so every result is an
      // AffineDimExpr whose position picks the loop range for view dim `d`.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile from the per-dimension ranges.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands are forwarded to the clone unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      // Rebase the index: newIndex = oldIndex + loopRanges[i].offset. All uses
      // of the old index except the AddIOp itself are redirected, avoiding a
      // self-referencing cycle.
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}
122 
// A (view, dimension) pair: `dimension` indexes into the shape of `view` and
// is used to recover a loop range from the view's size in that dimension.
struct ViewDimension {
  Value view;
  unsigned dimension;
};
127 
128 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
129 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
130 // guarantees at least one such dimension is found. If multiple candidates exist
131 // they must agree by construction (i.e. have the same size) and we just return
132 // the first one.
133 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
134   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
135   auto maps = op.indexing_maps();
136   // Iterate over the inputs and outputs in order.
137   // Extract the subranges from the linearized ranges.
138   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
139   for (auto en : llvm::enumerate(ios)) {
140     unsigned idx = en.index();
141     auto map = maps[idx].cast<AffineMapAttr>().getValue();
142     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
143     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
144     Value view = en.value();
145     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
146     for (auto en2 : llvm::enumerate(map.getResults())) {
147       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
148         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
149                           << "\n");
150         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
151         return ViewDimension{view, static_cast<unsigned>(en2.index())};
152       }
153     }
154   }
155   llvm_unreachable("Expect to be able to extract a view defining loop range");
156 }
157 
/// Fuses `producer` into the block of `consumer` by cloning it on loop ranges
/// derived from the subview/slice that feeds operand `consumerIdx` of the
/// consumer. `producerIdx` identifies the producer view matched to the
/// consumed view. Returns the cloned (fused) producer op.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                     LinalgOp consumer, unsigned consumerIdx,
                     OperationFolder *folder = nullptr) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // The consumed buffer must be defined by a subview or a slice, otherwise
  // there is no tile to fuse into.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with a (at least one) loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  // NOTE(review): this dereferences `subView` unconditionally, so it assumes
  // the defining op is a SubViewOp; the slice-only case would hit a null op —
  // confirm slices never reach this path.
  auto loc = consumer.getLoc();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        subView.getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Unconstrained dimension: use the full range [0, dim) with stride 1,
      // taking the size from the first view that defines this loop.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
214 
215 // Encode structural fusion safety preconditions.
216 // Some of these will be lifted in the future with better analysis.
217 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
218                                           LinalgOp consumer) {
219   assert(producer.hasBufferSemantics() &&
220          "expected linalg op with buffer semantics");
221   assert(consumer.hasBufferSemantics() &&
222          "expected linalg op with buffer semantics");
223   if (producer.getNumOutputs() != 1) {
224     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
225     return false;
226   }
227   // Only fuse when the producer block dominates.
228   DominanceInfo dom(producer.getOperation());
229   if (!dom.dominates(producer.getOperation()->getBlock(),
230                      consumer.getOperation()->getBlock())) {
231     LLVM_DEBUG(
232         dbgs()
233         << "\nNot structurally fusable (producer block does not dominate)");
234     return false;
235   }
236   return true;
237 }
238 
239 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
240                                              LinalgOp consumer,
241                                              Value consumedView,
242                                              LinalgOp producer) {
243   assert(producer.hasBufferSemantics() &&
244          "expected linalg op with buffer semantics");
245   assert(consumer.hasBufferSemantics() &&
246          "expected linalg op with buffer semantics");
247   // Make some simple structural checks that alleviate the need for more
248   // complex analyses.
249   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
250     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
251                       << *producer.getOperation());
252     return false;
253   }
254   // Check for any interleaved write to consumedView.
255   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
256     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
257                       << *producer.getOperation());
258     return false;
259   }
260   return true;
261 }
262 
263 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
264                                  LinalgOp consumer, Value consumedView,
265                                  LinalgOp producer) {
266   assert(producer.hasBufferSemantics() &&
267          "expected linalg op with buffer semantics");
268   assert(consumer.hasBufferSemantics() &&
269          "expected linalg op with buffer semantics");
270   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
271     return false;
272   // Check for any fusion-preventing dependence to any view read/written that
273   // would violate dependences.
274   if (!graph.findCoveringDependences(producer, consumer).empty()) {
275     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
276                       << *producer.getOperation());
277     return false;
278   }
279   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
280     // TODO: add a level of indirection to linalg.generic.
281     if (convOp.padding())
282       return false;
283   }
284   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
285     // TODO: add a level of indirection to linalg.generic.
286     if (convOp.padding())
287       return false;
288   }
289   return true;
290 }
291 
292 static bool isSameSubView(Value a, Value b) {
293   if (a == b)
294     return true;
295   auto sva = a.getDefiningOp<SubViewOp>();
296   auto svb = b.getDefiningOp<SubViewOp>();
297   if (!sva || !svb)
298     return false;
299   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
300     return false;
301   if (sva.getType() != svb.getType())
302     return false;
303   if (sva.getRank() != svb.getRank())
304     return false;
305   if (sva.getNumOperands() != svb.getNumOperands())
306     return false;
307   if (sva.static_offsets() != svb.static_offsets())
308     return false;
309   if (sva.static_sizes() != svb.static_sizes())
310     return false;
311   if (sva.static_strides() != svb.static_strides())
312     return false;
313   /// Skip the "viewSource" operand.
314   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
315     if (sva.getOperand(idx) != svb.getOperand(idx))
316       return false;
317   return true;
318 }
319 
/// Walks the RAW and WAW dependences into `consumer` and returns the first
/// dependence whose consumed view matches the buffer at operand `consumerIdx`
/// and whose producer passes `isFusableInto`. Returns an empty Optional when
/// no producer qualifies.
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
                    const LinalgDependenceGraph &dependenceGraph) {
  // Only consider RAW and WAW atm.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    for (auto dependence :
         dependenceGraph.getDependencesInto(consumer, depType)) {
      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

      // Check that the dependence is indeed on the input `consumerIdx` view.
      auto consumedView = dependence.indexingView;
      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
        continue;

      // Consumer consumes this view, `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producedView = dependence.dependentOpView.view;
      auto producerIdx =
          producer.getIndexOfOutputBuffer(producedView).getValue();
      // `consumerIdx` and `producerIdx` exist by construction.
      LLVM_DEBUG(dbgs() << "\n"
                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
                        << "producer: " << *producer.getOperation() << " view: "
                        << producedView << " output index: " << producerIdx);
      (void)producerIdx;

      // Simple fusability checks; skip this candidate if they fail.
      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
        continue;

      return dependence;
    }
  }
  return {};
}
358 
/// Public entry point: finds a fusable producer of operand `consumerIdx` of
/// `consumer` and clones it just before the consumer, on the tile defined by
/// the consumed subview/slice. Returns the (original producer, fused producer)
/// pair, or an empty Optional when no fusion was performed.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumer, consumerIdx, graph);
  if (!fusableDependence)
    return {};

  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
  Value producerView = fusableDependence->dependentOpView.view;
  Value consumerView = fusableDependence->indexingView;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerView.getDefiningOp<SubViewOp>();
  auto slice = consumerView.getDefiningOp<SliceOp>();
  if (!subView && !slice) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
    return {};
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  Optional<unsigned> producerIdxOpt =
      producerOp.getIndexOfInputAndOutputBuffer(producerView);
  assert(producerIdxOpt.hasValue() && "incorrect operand index");
  unsigned producerIdx = producerIdxOpt.getValue();

  auto fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
  return FusionInfo{producerOp, fusedProducer};
}
394 
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg. matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
/// fused, then no loops can be tiled while fusing.
static DenseSet<unsigned> collectTileAndFuseLoops(
    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
                     fusableDependences) {
  // 1. Only parallel loops can be used for tile + fuse. Find the number of
  // common outer parallel loops between the op and its producers being fused.
  // Counts the leading "parallel" entries of the op's iterator_types.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  // Take the minimum over the consumer and every fused producer.
  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
  for (auto dependence : fusableDependences) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
                                            dependence.dependentOpView.op)));
  }

  // Need to compute what tiled loops can be "fused". Given the precondition
  // that all indexing map for the producer view is a projected permutation, we
  // can assert that the producer iterates over the dimensions of the "fused
  // view" only once. To be used a fused loop the producer should use this loop
  // to access the fused view. For example, consider
  //
  // ```
  //   linalg.add ins(%a, %b) outs(%c)
  //   linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if `linalg.add` has the semantics of `c = a + b`, then the following
  // tile+fuse code is correct.
  //
  // ```
  // for j ... += TSj
  //   %sa = subview %a[0, %j][...]
  //   %sb = subview %b[0, %j][...]
  //   %sc = subview %c[0, %j][...]
  //   %sd = subview %d[0, 0][...]
  //   %se = subview %e[0, %j][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc) outs(%se)
  // ```
  //
  // On the other hand tiling along i would be incorrect
  //
  // ```
  // for %i .. += TSi
  //   %sa = subview %a[%i, 0][...]
  //   %sb = subview %b[%i, 0][...]
  //   %sc = subview %c[%i, 0][...]
  //   %sc2 = subview %c[0, 0][...]
  //   %sd = subview %d[%i, 0][...]
  //   %se = subview %e[%i, 0][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc2) outs(%se)
  // ```
  //
  // The write to the subview `%sc` in `linalg.add` is performed after the read
  // from it using `%sc2` violating the RAW dependence of the original code. To
  // find such loops indexing map of the fused view in the consumer op is
  // used. For the above example, this indexing map is
  //
  //   affine_map<(d0, d1, d2) -> (d2, d1)>
  //
  // Since d0 is not in the result expressions of this map, it is not treated as
  // tile + fuse loop, (but d1 is).
  //
  // TODO: The above is probably restrictive and there might be a generalization
  // of these that might allow for more fusion opportunities. Explore based on
  // needs.
  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
  for (auto dependence : fusableDependences) {
    unsigned consumerIdx =
        op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
    // Previously asserted that the consumerAccess map is a projected
    // permutation, so all results are known to be AffineDimExprs. To remove
    // this restriction walk the expression to find which dimensions of the
    // consumer loop appear in the `consumerAccess`.
    DenseSet<unsigned> positions;
    for (auto expr : consumerAccess.getResults())
      positions.insert(expr.cast<AffineDimExpr>().getPosition());
    commonTilableLoops.emplace_back(std::move(positions));
  }

  // 2. Of the outer parallel loops, only those loops can be tiled + fused as
  // computed above for all the fused dependences can be used to tile and fuse.
  DenseSet<unsigned> tilableParallelLoops;
  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
    if (llvm::all_of(commonTilableLoops,
                     [&](const DenseSet<unsigned> &tilableLoops) {
                       return tilableLoops.count(index);
                     }))
      tilableParallelLoops.insert(index);
  }
  return tilableParallelLoops;
}
504 
505 /// Find all dependences that are to be fusable.
506 static Optional<
507     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
508 findAllFusableDependences(LinalgOp op,
509                           const LinalgDependenceGraph &dependenceGraph,
510                           const LinalgFusionOptions &fusionOptions) {
511   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
512       fusableDependences;
513   for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
514     if (fusionOptions.indicesToFuse &&
515         !fusionOptions.indicesToFuse->count(operand.index()))
516       continue;
517     Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
518         fusableDependence =
519             findFusableProducer(op, operand.index(), dependenceGraph);
520     if (!fusableDependence)
521       continue;
522     // Make sure that the indexing map of the view used for fusion in the
523     // producer is a projected permutation.
524     LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
525     Value producerView = fusableDependence->dependentOpView.view;
526     unsigned producerIdx =
527         producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
528     AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
529     if (!producerMap.isProjectedPermutation()) {
530       op.emitError("unhandled non permutation indexing map for fused view in "
531                    "producer for operand at index ")
532           << operand.index();
533       return llvm::None;
534     }
535     Value consumerView = fusableDependence->indexingView;
536     unsigned consumerIdx =
537         op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
538     if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
539       op.emitError(
540           "unhandled case where indexing map for fused view in the consumer is "
541           "not a projected permuration while fusing at index ")
542           << operand.index();
543       return llvm::None;
544     }
545     fusableDependences.push_back(*fusableDependence);
546     if (!fusionOptions.indicesToFuse)
547       break;
548   }
549   return fusableDependences;
550 }
551 
552 static bool isZero(Value v) {
553   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
554     return cst.getValue() == 0;
555   return false;
556 }
557 
/// Implements tile+fuse in two tiling steps: first tiles `op` along the loops
/// that all fusable producers can share, fuses the producers into that loop
/// nest, then tiles the remaining loops of the root op. Returns llvm::None
/// when nothing is fusable or when any tiling step fails.
/// NOTE(review): `LoopType` is not referenced in this body; the generated loop
/// kind appears to be driven by `tilingOptions` — confirm the template
/// parameter is still required.
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions,
                         const LinalgFusionOptions &fusionOptions) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  // Some of the tiling options might not be supportable with tile and fuse.
  // TODO: Support interchange with tile + fuse.
  if (!tilingOptions.interchangeVector.empty()) {
    op.emitError("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  ScopedContext scope(rewriter, op.getLoc());

  // Find all the producers.
  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
      fusableDependencesOpt =
          findAllFusableDependences(op, dependenceGraph, fusionOptions);
  if (!fusableDependencesOpt)
    return llvm::None;
  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
      *fusableDependencesOpt);

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle instead of
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < nLoops) {
    // Pad the missing trailing tile sizes with 0 (i.e. "do not tile").
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  TiledAndFusedLinalgOps ret;

  // Find the loops that can be tiled and fused.
  DenseSet<unsigned> tileFuseLoops =
      collectTileAndFuseLoops(op, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (fusableDependences.empty() || tileFuseLoops.empty()) {
    return llvm::None;
  }

  // Get the tile sizes for the first and second tiling steps. For the first
  // step the tile size are set to zero for the loops that aren't
  // fused. Similarly for the second step, the tile sizes are set to zero for
  // the loops that are fused. For example, if for the following input
  //
  // ```
  //   linalg.add ins(%a, %b) outs(%c)
  //   linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
  // respectively, and since only `j` can be tiled and fused. The tile sizes
  // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
  // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
  // the tiled matmul generated by the first tiling step.
  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
  for (auto tileSize : enumerate(tileSizeVector)) {
    auto zero = std_constant_index(0);
    if (tileFuseLoops.count(tileSize.index())) {
      tileAndFuseSizes.push_back(tileSize.value());
      tileSizes.push_back(zero);
    } else {
      tileSizes.push_back(tileSize.value());
      tileAndFuseSizes.push_back(zero);
    }
  }

  // Tile for the loops that can be fused.
  LinalgTilingOptions firstTilingOptions = tilingOptions;
  firstTilingOptions.setTileSizes(tileAndFuseSizes);
  Optional<TiledLinalgOp> firstTiledOp =
      tileLinalgOp(rewriter, op, firstTilingOptions);
  if (!firstTiledOp)
    return llvm::None;
  ret.op = firstTiledOp->op;
  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());

  rewriter.setInsertionPoint(ret.op);
  // Fuse the operands: clone each producer on the tile consumed by the tiled
  // root op.
  for (auto producer : enumerate(fusableDependences)) {
    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
    unsigned producerIdx = producerOp
                               .getIndexOfInputAndOutputBuffer(
                                   producer.value().dependentOpView.view)
                               .getValue();
    unsigned consumerIdx =
        op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
            .getValue();
    LinalgOp fusedOp =
        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
    ret.fusedProducers.push_back(fusedOp);
    ret.originalProducers.push_back(producerOp);
  }

  if (!llvm::all_of(tileSizes, isZero)) {
    // Tile the remaining loops of the root operation.
    LinalgTilingOptions secondTilingOptions = tilingOptions;
    // The distribution is done only for the tile+fused loops.
    secondTilingOptions.distribution = llvm::None;
    secondTilingOptions.setTileSizes(tileSizes);
    Optional<TiledLinalgOp> secondTiledOp =
        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
    if (!secondTiledOp)
      return llvm::None;
    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
                            secondTiledOp->loops.end());
    // The intermediate (first-step) op is replaced by the second-step result.
    rewriter.eraseOp(ret.op);
    ret.op = secondTiledOp->op;
  }

  return ret;
}
680 
681 Optional<TiledAndFusedLinalgOps>
682 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
683                                    const LinalgDependenceGraph &dependenceGraph,
684                                    const LinalgTilingOptions &tilingOptions,
685                                    const LinalgFusionOptions &fusionOptions) {
686   switch (tilingOptions.loopType) {
687   case LinalgTilingLoopType::Loops:
688     return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
689                                                 tilingOptions, fusionOptions);
690   case LinalgTilingLoopType::ParallelLoops:
691     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
692         rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
693   default:;
694   }
695   return llvm::None;
696 }
697 
/// Greedy driver for the fusion pass: walks all buffer-semantics linalg ops in
/// `f` in reverse order and, for each buffer operand, attempts to fuse its
/// producer. Original producers that were fused are erased at the end.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  // Producers that have been fused; erased only after the whole walk so the
  // dependence graph stays consistent during iteration.
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO: LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild aliases and the dependence graph from scratch for every
      // attempt (see TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Swap the fused clone in for the original producer in the worklist so
        // later graph rebuilds see the up-to-date op.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
738 
739 //====---------------------------------------------------------------------===//
740 // Fusion on Tensor operation.
741 //====---------------------------------------------------------------------===//
742 
743 namespace {
744 
/// Implementation of fusion of generic ops and indexed_generic ops on tensors.
/// Fusion replaces the use of the producer's result in the consumer by cloning
/// the producer's body into the consumer's region.
struct FuseGenericOpsOnTensors {
  /// Returns true if `producer` can be fused into `consumer` at operand
  /// position `consumerIdx`.
  static bool isFusible(LinalgOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    // Producer and consumer must have tensor semantics.
    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
      return false;

    // Verify that
    // - the producer has all "parallel" iterator type.
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // Get the consumer index map. The number of results of the consumer index
    // map must match the number of loops of the producer.
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // Finally the index_map for the result must be invertible. For now just
    // verify it is a permutation.
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
  }

  /// Fuses `producer` into `consumer` at operand position `consumerIdx`.
  /// Returns the fused op, or nullptr when fusion is not legal. The result is
  /// a GenericOp when both ops are GenericOps, an IndexedGenericOp otherwise.
  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // One fewer operand than the sum: the fused operand at `consumerIdx` is
    // replaced by the producer's operands.
    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
                                consumer.getOperation()->getNumOperands() - 1;

    // Compute the fused operands list: consumer operands before `consumerIdx`,
    // then all producer operands, then the remaining consumer operands.
    SmallVector<Value, 2> fusedOperands;
    fusedOperands.reserve(numFusedOperands);
    auto consumerOperands = consumer.getOperation()->getOperands();
    auto producerOperands = producer.getOperation()->getOperands();
    fusedOperands.assign(consumerOperands.begin(),
                         std::next(consumerOperands.begin(), consumerIdx));
    fusedOperands.append(producerOperands.begin(), producerOperands.end());
    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                         consumerOperands.end());

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumers that aren't fused are the same. The
    // indexing_maps for the producers need to be computed based on the
    // indexing_map of the operand at consumerIdx in the consumer.
    SmallVector<Attribute, 4> fusedIndexMaps;
    auto consumerIndexMaps = consumer.indexing_maps();
    fusedIndexMaps.reserve(fusedOperands.size() +
                           consumer.getOperation()->getNumResults());
    fusedIndexMaps.assign(consumerIndexMaps.begin(),
                          std::next(consumerIndexMaps.begin(), consumerIdx));
    // Compute indexing maps for the producer args in the fused operation.
    computeProducerOperandIndex(
        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);

    // Append the indexing maps for the remaining consumer operands.
    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                          consumerIndexMaps.end());

    // Generate the fused op.
    // Tensor-level fusion is only on ops without initTensors and outputBuffers.
    LinalgOp fusedOp;
    if (isa<GenericOp>(producer.getOperation()) &&
        isa<GenericOp>(consumer.getOperation())) {
      fusedOp =
          rewriter
              .create<GenericOp>(consumer.getLoc(),
                                 consumer.getOperation()->getResultTypes(),
                                 /*inputs=*/fusedOperands,
                                 /*outputBuffers=*/ValueRange{},
                                 /*initTensors=*/ValueRange{},
                                 rewriter.getArrayAttr(fusedIndexMaps),
                                 consumer.iterator_types(),
                                 /*doc=*/nullptr,
                                 /*library_call=*/nullptr,
                                 /*symbol_source=*/nullptr)
              .getOperation();
    } else {
      // If either op is an IndexedGenericOp, the fused op must also carry
      // loop indices, so an IndexedGenericOp is built.
      fusedOp =
          rewriter
              .create<IndexedGenericOp>(
                  consumer.getLoc(), consumer.getOperation()->getResultTypes(),
                  /*inputs=*/fusedOperands,
                  /*outputBuffers=*/ValueRange{},
                  /*initTensors=*/ValueRange{},
                  rewriter.getArrayAttr(fusedIndexMaps),
                  consumer.iterator_types(),
                  /*doc=*/nullptr,
                  /*library_call=*/nullptr,
                  /*symbol_source=*/nullptr)
              .getOperation();
    }

    // Construct an AffineMap from consumer loops to producer loops.
    // consumer loop -> tensor index
    AffineMap consumerResultIndexMap =
        consumer.getInputIndexingMap(consumerIdx);
    // producer loop -> tensor index
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // tensor index -> producer loop
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexig map to be invertible");
    // consumer loop -> producer loop
    AffineMap consumerToProducerLoopsMap =
        invProducerResultIndexMap.compose(consumerResultIndexMap);

    generateFusedRegion(rewriter, fusedOp, producer, consumer,
                        consumerToProducerLoopsMap, consumerIdx,
                        consumer.getNumLoops());
    return fusedOp;
  }

private:
  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
  /// the `producer` to use in the fused operation given the indexing map of the
  /// result of the producer in the consumer.
  static void computeProducerOperandIndex(
      LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
    // from consumer loop -> consumer arg tensor index/producer result tensor
    // index. The fused loop is same as the consumer loop. For each producer arg
    // the indexing map to be computed is a map from consumer loop -> producer
    // arg tensor index.

    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // producerResultIndexMap is a map from producer loop -> tensor index.
    // Compute the inverse to get map from tensor index -> producer loop.
    // The inverse is a map from producer result tensor index -> producer loop.
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexig map to be invertible");
    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
      // argMap is a map from producer loop -> producer arg tensor index.
      AffineMap argMap = producer.getInputIndexingMap(argNum);

      // Compose argMap with invProducerResultIndexMap to get a map from
      // producer result tensor index -> producer arg tensor index.
      AffineMap t1 = argMap.compose(invProducerResultIndexMap);

      // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
      // consumer loop/ fused loop -> producer arg tensor index.
      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
    }
  }

  /// Generate the region of the fused operation. The region of the fused op
  /// must be empty.
  static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
                                  LinalgOp producer, LinalgOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
                                  unsigned consumerIdx, unsigned nloops) {
    // Build the region of the fused op.
    Block &producerBlock = producer.getOperation()->getRegion(0).front();
    Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
    Block *fusedBlock = new Block();
    fusedOp->getRegion(0).push_back(fusedBlock);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(fusedBlock);

    // The block arguments are
    // [index_0, index_1, ... ,
    //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
    //   producer_operand_0, ... , producer_operand_(n-1)],
    //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
    // , where n is the number of producer's operand and m is the number
    // consumer's operand.
    // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
    // generic op. In this case, there are no indices in block arguments.
    unsigned numProducerIndices =
        isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
    unsigned numConsumerIndices =
        isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
    // Firstly, add all the indices to the block arguments.
    for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
         i < e; ++i)
      fusedBlock->addArgument(rewriter.getIndexType());
    // Map the arguments for the unmodified args from the consumer.
    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
      if (consumerArg.index() == consumerIdx + numConsumerIndices) {
        // Map the arguments for the args from the producer.
        for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
          // If producer is an indexed_generic op, map the indices from consumer
          // loop to producer loop (because the fusedOp is built based on
          // consumer's perspective).
          if (producerArg.index() < numProducerIndices) {
            auto newIndex = rewriter.create<mlir::AffineApplyOp>(
                producer.getLoc(),
                consumerToProducerLoopsMap.getSubMap(producerArg.index()),
                fusedBlock->getArguments().take_front(nloops));
            mapper.map(producerArg.value(), newIndex);
          } else {
            mapper.map(producerArg.value(),
                       fusedBlock->addArgument(producerArg.value().getType()));
          }
        }
        continue;
      }

      // If consumer is an indexed_generic op, map the indices to the block
      // arguments directly. Otherwise, add the same type of argument and map to
      // it.
      if (consumerArg.index() < numConsumerIndices) {
        mapper.map(consumerArg.value(),
                   fusedBlock->getArgument(consumerArg.index()));
      } else {
        mapper.map(consumerArg.value(),
                   fusedBlock->addArgument(consumerArg.value().getType()));
      }
    }

    // Add operations from producer (except the yield operation) to the fused
    // op.
    for (auto &op : producerBlock.getOperations()) {
      if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to. The consumer
        // block argument that consumed the producer result is remapped to the
        // (cloned) yielded value, stitching producer into consumer.
        Value yieldVal = yieldOp.getOperand(0);
        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
          mapper.map(
              consumerBlock.getArgument(consumerIdx + numConsumerIndices),
              clonedVal);
        continue;
      }
      rewriter.clone(op, mapper);
    }
    // Clone the consumer body (including its yield) after the producer body.
    for (auto &op : consumerBlock.getOperations())
      rewriter.clone(op, mapper);
  }
};
983 } // namespace
984 
985 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
986 /// provided, given the shape of the source tensor that corresponds to the
987 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
988 /// are "row-major" ordered logically.
989 ///
990 /// For example:
991 ///
992 /// %0 = op ... : tensor<?x?x4x5xf32>
993 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
994 ///
995 /// and reshape:
996 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
997 ///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
998 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
999 ///
1000 /// would be rewritten into:
1001 /// %0 = op ... : tensor<?x?x4x5xf32>
1002 /// with output index_map
1003 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
1004 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
1005                                         ArrayRef<int64_t> sourceShape,
1006                                         ArrayRef<AffineMap> reassociationMaps) {
1007   SmallVector<AffineExpr, 4> resultExprs;
1008   resultExprs.reserve(reassociationMaps.size());
1009   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
1010   MLIRContext *context = sourceMap.getContext();
1011 
1012   // Compute the result exprs based on the reassociation maps.
1013   for (AffineMap map : reassociationMaps) {
1014     ArrayRef<AffineExpr> collapsedDims = map.getResults();
1015     // Assume that they are in-order and contiguous (already checked in
1016     // verifier).
1017     assert(!collapsedDims.empty());
1018     unsigned startDim =
1019         collapsedDims.front().cast<AffineDimExpr>().getPosition();
1020     AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
1021         sourceShape.slice(startDim, collapsedDims.size()),
1022         sourceExprs.slice(startDim, collapsedDims.size()), context);
1023     resultExprs.push_back(linearizedExpr);
1024   }
1025   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
1026                         resultExprs, context);
1027 }
1028 
1029 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
1030 /// true) or its producer (if `asProducer` is false) given the indexing map at
1031 /// its use.
1032 static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
1033                                      AffineMap useIndexMap, bool asProducer) {
1034   RankedTensorType returnType = reshapeOp.getResultType();
1035   RankedTensorType operandType = reshapeOp.getSrcType();
1036   // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
1037   // operand is of lesser rank than the result. Fusing when operand has higher
1038   // rank will require use of mods and divs in the indexing maps of the fused op
1039   // which would make it non-invertible. Similarly reshape is fused with its
1040   // producer (i.e. reshape as consumer) only if the return type has lesser
1041   // rank.
1042   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
1043       (!asProducer && operandType.getRank() < returnType.getRank()))
1044     return false;
1045   return useIndexMap.isIdentity();
1046 }
1047 
1048 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
1049 /// is a linalg.generic operation, the create a `linalg.generic` operation with
1050 /// the given `args`. Expects `op` to be `linalg.generic` or
1051 /// `linalg.indexed_generic`.
1052 template <typename... Args>
1053 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
1054                                          Args... args) {
1055   if (isa<GenericOp>(op.getOperation()))
1056     return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
1057   if (isa<IndexedGenericOp>(op.getOperation()))
1058     return cast<LinalgOp>(
1059         rewriter.create<IndexedGenericOp>(args...).getOperation());
1060   llvm_unreachable(
1061       "expected only linalg.generic or linalg.indexed_generic ops");
1062   return nullptr;
1063 }
1064 
1065 namespace {
1066 
/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
/// The reshape is folded away by linearizing the indexing map of the fused
/// operand instead of reshaping the data.
struct FuseTensorReshapeOpAsProducer {
  /// Returns true if `producer` (a collapsing reshape) can be folded into the
  /// indexing map of `consumer`'s operand at `consumerIdx`.
  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(producer,
                                    consumer.getInputIndexingMap(consumerIdx),
                                    /*asProducer=*/true);
  }

  /// Fuses the reshape `producer` into `consumer` at operand `consumerIdx`.
  /// Returns the new op, or nullptr when fusion is not legal.
  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    // Bail out when the reshape source is itself a constant; presumably such
    // cases are left to other patterns/folding — NOTE(review): confirm.
    if (producer.src().getDefiningOp<ConstantOp>())
      return nullptr;

    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // Compute the fused operands list: same as the consumer's, with the
    // reshape result replaced by the reshape source.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
    fusedOperands[consumerIdx] = producer.src();

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumers that aren't fused are the same.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
        producer.getReassociationMaps());
    // Reject semi-affine results (dynamic shapes introduce symbols).
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps[consumerIdx] = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));
    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(),
        /*inputs=*/fusedOperands,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    // The body is unchanged; only the indexing maps differ.
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};
1135 
/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
/// Two cases are handled: a collapsing reshape is folded into the producer's
/// output indexing map; an expanding reshape is absorbed by rebuilding the
/// producer at the expanded rank.
struct FuseTensorReshapeOpAsConsumer {
  /// Returns true if `consumer` is a collapsing reshape that can be folded
  /// into `producer`'s output indexing map.
  static bool isCollapsingAndFusible(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                    /*asProducer=*/false);
  }

  /// Fuses a collapsing reshape `consumer` into `producer` by linearizing the
  /// producer's output indexing map. Returns nullptr when not legal.
  static LinalgOp fuseCollapsingCase(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx,
                                     PatternRewriter &rewriter) {
    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the producer.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
        consumer.getReassociationMaps());
    // Reject semi-affine results (dynamic shapes introduce symbols).
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    // The last indexing map is the producer's (single) output map.
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));

    Operation *producerOp = producer.getOperation();
    LinalgOp fusedOp = createLinalgOpOfSameType(
        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
        /*inputs=*/producerOp->getOperands(),
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    // The body is unchanged; only the indexing maps differ.
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }

  /// Returns true if `consumer` is an expanding reshape that can be absorbed
  /// by rebuilding `producer` at the expanded rank.
  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx) {
    // Is fusible only if:
    //   1) The producer is a generic op.
    //   2) The producer has tensor semantics.
    //   3) The tensor reshape op is a expanding case.
    //   4) All the shapes are the same for the generic op.
    //   5) All the indexing maps in producer are identity.
    //   6) All the loops in producer are parallel loops.
    //   7) The producer has a single user.
    auto types = producer.getInputOutputShapedTypes();
    assert(!types.empty());
    return isa<GenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           consumer.getSrcType().getRank() <
               consumer.getResultType().getRank() &&
           std::equal(types.begin() + 1, types.end(), types.begin()) &&
           llvm::all_of(producer.getIndexingMaps(),
                        [](AffineMap map) { return map.isIdentity(); }) &&
           llvm::all_of(producer.iterator_types(),
                        [](Attribute attr) {
                          return attr.cast<StringAttr>().getValue() ==
                                 getParallelIteratorTypeName();
                        }) &&
           producer.getOperation()->hasOneUse();
  }

  /// Fuses an expanding reshape `consumer` by reshaping every producer operand
  /// to the expanded shape and re-creating the producer as a generic op of the
  /// expanded rank with identity maps.
  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx,
                                    PatternRewriter &rewriter) {
    Location loc = producer.getLoc();
    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
    // Reshape each producer operand to the destination shape, reusing the
    // consumer's reassociation.
    SmallVector<Value, 4> args;
    for (auto arg : producer.getOperation()->getOperands()) {
      auto type = RankedTensorType::get(
          dstShape, arg.getType().cast<ShapedType>().getElementType());
      args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
          loc, type, arg, consumer.reassociation()));
    }

    // Result types keep the original element types at the expanded shape.
    SmallVector<Type, 4> resultTypes;
    for (auto t : producer.getOutputTensorTypes()) {
      Type type = RankedTensorType::get(dstShape,
                                        t.cast<ShapedType>().getElementType());
      resultTypes.push_back(type);
    }

    // All maps are identity and all loops parallel (checked in
    // isExpandingAndFusible), so the expanded op uses identity maps too.
    int rank = dstShape.size();
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTypes, /*inputs=*/args,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{},
        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
                                  rewriter.getMultiDimIdentityMap(rank)),
        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
    Region &region = genericOp.getRegion();
    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
                               region.begin());
    return cast<LinalgOp>(genericOp.getOperation());
  }

  /// Dispatches to the collapsing or expanding case; returns nullptr when
  /// neither applies.
  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
    if (isExpandingAndFusible(producer, consumer, consumerIdx))
      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
    return nullptr;
  }
};
1265 
/// Implementation of fusion on tensor ops when producer is a splat constant.
/// The constant tensor operand is dropped and replaced by a scalar constant
/// used directly inside the consumer's body.
struct FuseConstantOpAsProducer {
  /// Returns true if `producer` is a splat ranked-tensor constant feeding a
  /// tensor-semantics (indexed_)generic `consumer`.
  static bool isFusible(ConstantOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           producer.getResult().getType().isa<RankedTensorType>() &&
           producer.value().cast<DenseElementsAttr>().isSplat();
  }

  /// Fuses the splat constant `producer` into `consumer` at operand
  /// `consumerIdx`. Returns nullptr when fusion is not legal.
  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // The indexing_maps for the operands of the fused operation are same as
    // those for the operands of the consumer without the indexing map at
    // consumerIdx
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));

    // The operands list is same as the consumer with the argument for constant
    // index dropped.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));

    // Create a constant scalar value from the splat constant.
    Value scalarConstant = rewriter.create<ConstantOp>(
        producer.getLoc(),
        producer.value().cast<DenseElementsAttr>().getSplatValue());

    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(),
        /*inputs=*/fusedOperands,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);

    // Map the block argument corresponding to the replaced argument with the
    // scalar constant.
    Region &consumerRegion = consumerOp->getRegion(0);
    Block &entryBlock = *consumerRegion.begin();
    // Skip any leading loop-index block arguments (indexed_generic case) to
    // find the argument corresponding to operand `consumerIdx`.
    unsigned argIndex = entryBlock.getNumArguments() -
                        consumerOp->getNumOperands() + consumerIdx;
    BlockAndValueMapping mapping;
    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
                               mapping);
    return fusedOp;
  }
};
1329 } // namespace
1330 
1331 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
1332                                        Operation *consumer,
1333                                        unsigned consumerIdx,
1334                                        OperationFolder *folder) {
1335   if (consumerIdx >= consumer->getNumOperands())
1336     return nullptr;
1337   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
1338   if (!producer || producer->getNumResults() != 1)
1339     return nullptr;
1340 
1341   // Fuse when consumer is GenericOp or IndexedGenericOp.
1342   if (isa<GenericOp, IndexedGenericOp>(consumer)) {
1343     if (isa<GenericOp, IndexedGenericOp>(producer))
1344       return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
1345                                            cast<LinalgOp>(consumer),
1346                                            consumerIdx, rewriter, folder);
1347     if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
1348       return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
1349                                                  cast<LinalgOp>(consumer),
1350                                                  consumerIdx, rewriter, folder);
1351     if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
1352       return FuseConstantOpAsProducer::fuse(constantOpProducer,
1353                                             cast<LinalgOp>(consumer),
1354                                             consumerIdx, rewriter, folder);
1355     return nullptr;
1356   }
1357 
1358   if (isa<GenericOp, IndexedGenericOp>(producer)) {
1359     // Fuse when consumer is a TensorReshapeOp.
1360     if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
1361       return FuseTensorReshapeOpAsConsumer::fuse(
1362           cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
1363     }
1364   }
1365 
1366   return nullptr;
1367 }
1368 
1369 namespace {
1370 /// Patterns to fuse a generic op, with the producer of its operands.
1371 template <typename LinalgOpTy>
1372 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
1373   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1374 
1375   LogicalResult matchAndRewrite(LinalgOpTy op,
1376                                 PatternRewriter &rewriter) const override {
1377     // Find the first operand that is defined by another generic op on tensors.
1378     for (auto operandNum :
1379          llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
1380       Operation *producer =
1381           op.getOperation()->getOperand(operandNum).getDefiningOp();
1382       if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
1383         rewriter.replaceOp(op, fusedOp->getResults());
1384         if (producer && llvm::all_of(producer->getResults(),
1385                                      [](Value val) { return val.use_empty(); }))
1386           rewriter.eraseOp(producer);
1387         return success();
1388       }
1389     }
1390     return failure();
1391   }
1392 };
1393 
1394 /// Pass that fuses generic ops on tensors. Used only for testing.
1395 struct FusionOfTensorOpsPass
1396     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
1397   void runOnOperation() override {
1398     OwningRewritePatternList patterns;
1399     Operation *op = getOperation();
1400     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
1401     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
1402   };
1403 };
1404 
1405 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
1406   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
1407 };
1408 } // namespace
1409 
1410 void mlir::populateLinalgTensorOpsFusionPatterns(
1411     MLIRContext *context, OwningRewritePatternList &patterns) {
1412   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
1413                   FuseTensorOps<TensorReshapeOp>>(context);
1414 }
1415 
1416 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
1417   return std::make_unique<LinalgFusionPass>();
1418 }
1419 
1420 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
1421   return std::make_unique<FusionOfTensorOpsPass>();
1422 }
1423