xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 8782c727655942c9aa4c80d698c9ba575510799c)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 
32 #define DEBUG_TYPE "linalg-fusion"
33 
34 using namespace mlir;
35 using namespace mlir::edsc;
36 using namespace mlir::edsc::intrinsics;
37 using namespace mlir::linalg;
38 
39 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
40 
41 using llvm::dbgs;
42 
43 /// Implements a simple high-level fusion pass of linalg library operations.
44 ///
45 /// In each block, linalg ops are processed in reverse textual order.
46 /// Given a linalg op `O`, fusion occurs by:
47 ///   1. inspecting the linalg ops that write into the views read by `O`. This
48 ///      uses the SSA value of the views and a simple subview/slice analysis to
49 ///      determine producer-consumer dependences;
///   2. greedily fuse the linalg ops that produce the subviews read by `O`;
51 ///   3. inspect the fused ops and determine whether they have other remaining
52 ///      LinalgOp uses. If not, then erase the original producing linalg op.
53 ///
54 /// More advanced use cases, analyses as well as profitability heuristics are
55 /// left for future work.
56 
57 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
58 // a subset of the original loop ranges of `op`.
59 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
60 // to the `loopRanges` in order to obtain view ranges.
61 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
62                                     ArrayRef<SubViewOp::Range> loopRanges) {
63   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
64   auto maps = op.indexing_maps();
65   SmallVector<Value, 8> clonedViews;
66   clonedViews.reserve(op.getNumInputsAndOutputs());
67   // Iterate over the inputs and outputs in order.
68   // Extract the subranges from the linearized ranges.
69   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
70   for (auto en : llvm::enumerate(ios)) {
71     unsigned idx = en.index();
72     auto map = maps[idx].cast<AffineMapAttr>().getValue();
73     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
74     Value view = en.value();
75     SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
76     for (auto en2 : llvm::enumerate(map.getResults())) {
77       unsigned d = en2.index();
78       // loopToOperandRangesMaps are permutations-only.
79       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
80       viewRanges[d] = loopRanges[loopPos];
81       LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
82                         << "\t"
83                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
84     }
85     // Construct a new subview for the tile.
86     unsigned rank = viewRanges.size();
87     SmallVector<Value, 4> offsets, sizes, strides;
88     offsets.reserve(rank);
89     sizes.reserve(rank);
90     strides.reserve(rank);
91     for (auto r : viewRanges) {
92       offsets.push_back(r.offset);
93       sizes.push_back(r.size);
94       strides.push_back(r.stride);
95     }
96     clonedViews.push_back(
97         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
98   }
99   auto operands = getAssumedNonViewOperands(op);
100   clonedViews.append(operands.begin(), operands.end());
101 
102   Operation *clonedOp = op.clone(b, loc, clonedViews);
103   // When the producer is an IndexedGenercOp, we have to transform its block
104   // IV arguments according to the tiling of the consumer, i.e. offset them by
105   // the values computed in `loopRanges`.
106   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
107     auto &block = indexedGenericOp.region().front();
108 
109     OpBuilder::InsertionGuard g(b);
110     b.setInsertionPointToStart(&block);
111     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
112       Value oldIndex = block.getArgument(i);
113       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
114                                          loopRanges[i].offset);
115       oldIndex.replaceAllUsesExcept(newIndex,
116                                     SmallPtrSet<Operation *, 1>{newIndex});
117     }
118   }
119   return clonedOp;
120 }
121 
122 struct ViewDimension {
123   Value view;
124   unsigned dimension;
125 };
126 
127 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
128 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
129 // guarantees at least one such dimension is found. If multiple candidates exist
130 // they must agree by construction (i.e. have the same size) and we just return
131 // the first one.
132 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
133   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
134   auto maps = op.indexing_maps();
135   // Iterate over the inputs and outputs in order.
136   // Extract the subranges from the linearized ranges.
137   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
138   for (auto en : llvm::enumerate(ios)) {
139     unsigned idx = en.index();
140     auto map = maps[idx].cast<AffineMapAttr>().getValue();
141     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
142     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
143     Value view = en.value();
144     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
145     for (auto en2 : llvm::enumerate(map.getResults())) {
146       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
147         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
148                           << "\n");
149         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
150         return ViewDimension{view, static_cast<unsigned>(en2.index())};
151       }
152     }
153   }
154   llvm_unreachable("Expect to be able to extract a view defining loop range");
155 }
156 
157 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
158                      unsigned consumerIdx, unsigned producerIdx,
159                      OperationFolder *folder) {
160   assert(producer.hasBufferSemantics() &&
161          "expected linalg op with buffer semantics");
162   assert(consumer.hasBufferSemantics() &&
163          "expected linalg op with buffer semantics");
164 
165   auto subView = dyn_cast_or_null<SubViewOp>(
166       consumer.getBuffer(consumerIdx).getDefiningOp());
167   auto slice = dyn_cast_or_null<SliceOp>(
168       consumer.getBuffer(consumerIdx).getDefiningOp());
169   assert(subView || slice);
170   (void)subView;
171   (void)slice;
172 
173   // loopToOperandRangesMaps are permutations-only by construction:
174   //   we can always identify a data dimension with a (at least one) loop
175   //   dimension.
176   AffineMap producerMap =
177       producer.indexing_maps()[producer.getNumInputs() + producerIdx]
178           .cast<AffineMapAttr>()
179           .getValue();
180   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
181                     << ", producer map: " << producerMap << "\n");
182 
183   unsigned nPar = producer.getNumParallelLoops();
184   unsigned nRed = producer.getNumReductionLoops();
185   unsigned nWin = producer.getNumWindowLoops();
186   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
187 
188   OpBuilder b(consumer.getOperation());
189   auto loc = consumer.getLoc();
190   // Iterate over dimensions identified by the producer map for `producerIdx`.
191   // This defines a subset of the loop ranges that we need to complete later.
192   for (auto en : llvm::enumerate(producerMap.getResults())) {
193     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
194     loopRanges[posInProducerLoop] =
195         subView.getOrCreateRanges(b, loc)[en.index()];
196   }
197 
198   // Iterate over all dimensions. For the dimensions not identified by the
199   // producer map for `producerIdx`, we need to explicitly compute the view that
200   // defines the loop ranges using the `producer`.
201   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
202     if (loopRanges[i].offset)
203       LLVM_DEBUG(llvm::dbgs()
204                  << "existing LoopRange: " << loopRanges[i] << "\n");
205     else {
206       auto viewDim = getViewDefiningLoopRange(producer, i);
207       loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
208                                        std_dim(viewDim.view, viewDim.dimension),
209                                        folded_std_constant_index(folder, 1)};
210       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
211     }
212   }
213 
214   return cloneWithLoopRanges(b, loc, producer, loopRanges);
215 }
216 
217 // Encode structural fusion safety preconditions.
218 // Some of these will be lifted in the future with better analysis.
219 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
220                                           LinalgOp consumer) {
221   assert(producer.hasBufferSemantics() &&
222          "expected linalg op with buffer semantics");
223   assert(consumer.hasBufferSemantics() &&
224          "expected linalg op with buffer semantics");
225   if (producer.getNumOutputs() != 1) {
226     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
227     return false;
228   }
229   // Only fuse when the producer block dominates.
230   DominanceInfo dom(producer.getOperation());
231   if (!dom.dominates(producer.getOperation()->getBlock(),
232                      consumer.getOperation()->getBlock())) {
233     LLVM_DEBUG(
234         dbgs()
235         << "\nNot structurally fusable (producer block does not dominate)");
236     return false;
237   }
238   return true;
239 }
240 
241 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
242                                              LinalgOp consumer,
243                                              Value consumedView,
244                                              LinalgOp producer) {
245   assert(producer.hasBufferSemantics() &&
246          "expected linalg op with buffer semantics");
247   assert(consumer.hasBufferSemantics() &&
248          "expected linalg op with buffer semantics");
249   // Make some simple structural checks that alleviate the need for more
250   // complex analyses.
251   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
252     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
253                       << *producer.getOperation());
254     return false;
255   }
256   // Check for any interleaved write to consumedView.
257   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
258     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
259                       << *producer.getOperation());
260     return false;
261   }
262   return true;
263 }
264 
265 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
266                                  LinalgOp consumer, Value consumedView,
267                                  LinalgOp producer) {
268   assert(producer.hasBufferSemantics() &&
269          "expected linalg op with buffer semantics");
270   assert(consumer.hasBufferSemantics() &&
271          "expected linalg op with buffer semantics");
272   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
273     return false;
274   // Check for any fusion-preventing dependence to any view read/written that
275   // would violate dependences.
276   if (!graph.findCoveringDependences(producer, consumer).empty()) {
277     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
278                       << *producer.getOperation());
279     return false;
280   }
281   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
282     // TODO: add a level of indirection to linalg.generic.
283     if (convOp.padding())
284       return false;
285   }
286   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
287     // TODO: add a level of indirection to linalg.generic.
288     if (convOp.padding())
289       return false;
290   }
291   return true;
292 }
293 
294 static bool isSameSubView(Value a, Value b) {
295   if (a == b)
296     return true;
297   auto sva = a.getDefiningOp<SubViewOp>();
298   auto svb = b.getDefiningOp<SubViewOp>();
299   if (!sva || !svb)
300     return false;
301   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
302     return false;
303   if (sva.getType() != svb.getType())
304     return false;
305   if (sva.getRank() != svb.getRank())
306     return false;
307   if (sva.getNumOperands() != svb.getNumOperands())
308     return false;
309   if (sva.static_offsets() != svb.static_offsets())
310     return false;
311   if (sva.static_sizes() != svb.static_sizes())
312     return false;
313   if (sva.static_strides() != svb.static_strides())
314     return false;
315   /// Skip the "viewSource" operand.
316   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
317     if (sva.getOperand(idx) != svb.getOperand(idx))
318       return false;
319   return true;
320 }
321 
322 static Optional<FusionInfo>
323 fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
324                   const LinalgDependenceGraph &graph, OperationFolder *folder,
325                   LinalgDependenceGraph::DependenceType depType) {
326   assert(consumer.hasBufferSemantics() &&
327          "expected linalg op with buffer semantics");
328   LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
329                     << *consumer.getOperation());
330   for (auto dependence : graph.getDependencesInto(consumer, depType)) {
331     LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
332                       << *dependence.dependentOpView.op << "\n");
333     auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
334 
335     // Check that the dependence is indeed on the input `consumerIdx` view.
336     auto consumedView = dependence.indexingView;
337     if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
338       continue;
339 
340     // Consumer consumes this view, `isStructurallyFusableProducer` also checks
341     // whether it is a strict subview of the producer view.
342     auto producedView = dependence.dependentOpView.view;
343     auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
344     // `consumerIdx` and `producerIdx` exist by construction.
345     LLVM_DEBUG(dbgs() << "\n"
346                       << LinalgDependenceGraph::getDependenceTypeStr(depType)
347                       << "producer: " << *producer.getOperation() << " view: "
348                       << producedView << " output index: " << producerIdx);
349 
350     // Must be a subview or a slice to guarantee there are loops we can fuse
351     // into.
352     auto subView = consumedView.getDefiningOp<SubViewOp>();
353     auto slice = consumedView.getDefiningOp<SliceOp>();
354     if (!subView && !slice) {
355       LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
356       continue;
357     }
358 
359     // Simple fusability checks.
360     if (!isFusableInto(graph, consumer, consumedView, producer))
361       continue;
362 
363     // Fuse `producer` just before `consumer`.
364     OpBuilder::InsertionGuard g(b);
365     b.setInsertionPoint(consumer.getOperation());
366     ScopedContext scope(b, consumer.getLoc());
367     LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
368     auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
369                               producerIdx, folder);
370 
371     return FusionInfo{producer, fusedProducer};
372   }
373   return llvm::None;
374 }
375 
376 // Only consider RAW and WAW atm.
377 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
378     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
379     const LinalgDependenceGraph &graph, OperationFolder *folder) {
380   for (auto dep : {
381            LinalgDependenceGraph::DependenceType::RAW,
382            LinalgDependenceGraph::DependenceType::WAW,
383        }) {
384     if (auto res =
385             fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
386       return res;
387   }
388   return llvm::None;
389 }
390 
391 static void fuseLinalgOpsGreedily(FuncOp f) {
392   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
393 
394   OpBuilder b(f);
395   OperationFolder folder(f.getContext());
396   DenseSet<Operation *> eraseSet;
397 
398   // Save original Linalg ops, we only want to make a pass over those.
399   SmallVector<Operation *, 8> linalgOps;
400   f.walk([&](LinalgOp op) {
401     if (op.hasBufferSemantics())
402       linalgOps.push_back(op);
403   });
404 
405   // TODO: LinalgDependenceGraph should be able to update itself.
406   // The current naive and expensive reconstruction of the graph should be
407   // removed.
408   for (auto *op : llvm::reverse(linalgOps)) {
409     for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
410          id < e; ++id) {
411       linalg::Aliases aliases;
412       linalg::LinalgDependenceGraph graph(aliases, linalgOps);
413       if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
414         auto *originalOp = info->originalProducer.getOperation();
415         eraseSet.insert(originalOp);
416         auto *originalOpInLinalgOpsVector =
417             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
418         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
419       }
420     }
421   }
422   // The `fuseProducerOf` function performs structural checks and in particular
423   // that no covering read or write exist between the consumer and the producer.
424   // As a consequence, the only fusions that may occur preserve subsequent
425   // dependences and are guaranteed by construction to produce the whole view.
426   // We may thus erase the producer once it is fused.
427   for (auto *e : eraseSet)
428     e->erase();
429   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
430 }
431 
432 //====---------------------------------------------------------------------===//
433 // Fusion on Tensor operation.
434 //====---------------------------------------------------------------------===//
435 
436 namespace {
437 
438 /// Implementation of fusion of generic ops and indexed_generic ops.
439 struct FuseGenericOpsOnTensors {
440   static bool isFusible(LinalgOp producer, LinalgOp consumer,
441                         unsigned consumerIdx) {
442     // Producer and consumer must have tensor semantics.
443     if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
444       return false;
445 
446     // Verify that
447     // - the producer has all "parallel" iterator type.
448     if (producer.getNumParallelLoops() != producer.getNumLoops())
449       return false;
450 
451     // Get the consumer index map. The number of results of the consumer index
452     // map must match the number of loops of the producer.
453     AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
454     if (consumerIndexMap.getNumResults() != producer.getNumLoops())
455       return false;
456 
457     // Finally the index_map for the result must be invertible. For now just
458     // verify it is a permutation.
459     AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
460     return producerResultIndexMap.isPermutation();
461   }
462 
463   static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
464                        unsigned consumerIdx, PatternRewriter &rewriter,
465                        OperationFolder *folder = nullptr) {
466     if (!isFusible(producer, consumer, consumerIdx))
467       return nullptr;
468 
469     unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
470                                 consumer.getOperation()->getNumOperands() - 1;
471 
472     // Compute the fused operands list,
473     SmallVector<Value, 2> fusedOperands;
474     fusedOperands.reserve(numFusedOperands);
475     auto consumerOperands = consumer.getOperation()->getOperands();
476     auto producerOperands = producer.getOperation()->getOperands();
477     fusedOperands.assign(consumerOperands.begin(),
478                          std::next(consumerOperands.begin(), consumerIdx));
479     fusedOperands.append(producerOperands.begin(), producerOperands.end());
480     fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
481                          consumerOperands.end());
482 
483     // Compute indexing_maps for the fused operation. The indexing_maps for the
484     // operands of the consumers that arent fused are the same. The
485     // indexing_maps for the producers need to be computed based on the
486     // indexing_map of the operand at consumerIdx in the consumer.
487     SmallVector<Attribute, 4> fusedIndexMaps;
488     auto consumerIndexMaps = consumer.indexing_maps();
489     fusedIndexMaps.reserve(fusedOperands.size() +
490                            consumer.getOperation()->getNumResults());
491     fusedIndexMaps.assign(consumerIndexMaps.begin(),
492                           std::next(consumerIndexMaps.begin(), consumerIdx));
493     // Compute indexing maps for the producer args in the fused operation.
494     computeProducerOperandIndex(
495         producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
496 
497     // Append the indexing maps for the remaining consumer operands.
498     fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
499                           consumerIndexMaps.end());
500 
501     // Generate the fused op.
502     LinalgOp fusedOp;
503     if (isa<GenericOp>(producer.getOperation()) &&
504         isa<GenericOp>(consumer.getOperation())) {
505       fusedOp =
506           rewriter
507               .create<GenericOp>(
508                   rewriter.getUnknownLoc(),
509                   consumer.getOperation()->getResultTypes(), fusedOperands,
510                   rewriter.getI64IntegerAttr(fusedOperands.size()),
511                   rewriter.getI64IntegerAttr(
512                       consumer.getOperation()->getNumResults()),
513                   rewriter.getArrayAttr(fusedIndexMaps),
514                   consumer.iterator_types(),
515                   /*doc=*/nullptr,
516                   /*library_call=*/nullptr,
517                   /*symbol_source=*/nullptr)
518               .getOperation();
519     } else {
520       fusedOp =
521           rewriter
522               .create<IndexedGenericOp>(
523                   rewriter.getUnknownLoc(),
524                   consumer.getOperation()->getResultTypes(), fusedOperands,
525                   rewriter.getI64IntegerAttr(fusedOperands.size()),
526                   rewriter.getI64IntegerAttr(
527                       consumer.getOperation()->getNumResults()),
528                   rewriter.getArrayAttr(fusedIndexMaps),
529                   consumer.iterator_types(),
530                   /*doc=*/nullptr,
531                   /*library_call=*/nullptr,
532                   /*symbol_source=*/nullptr)
533               .getOperation();
534     }
535 
536     // Construct an AffineMap from consumer loops to producer loops.
537     // consumer loop -> tensor index
538     AffineMap consumerResultIndexMap =
539         consumer.getInputIndexingMap(consumerIdx);
540     // producer loop -> tensor index
541     AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
542     // tensor index -> producer loop
543     AffineMap invProducerResultIndexMap =
544         inversePermutation(producerResultIndexMap);
545     assert(invProducerResultIndexMap &&
546            "expected producer result indexig map to be invertible");
547     // consumer loop -> producer loop
548     AffineMap consumerToProducerLoopsMap =
549         invProducerResultIndexMap.compose(consumerResultIndexMap);
550 
551     generateFusedRegion(rewriter, fusedOp, producer, consumer,
552                         consumerToProducerLoopsMap, consumerIdx,
553                         consumer.getNumLoops());
554     return fusedOp;
555   }
556 
557 private:
558   /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
559   /// the `producer` to use in the fused operation given the indexing map of the
560   /// result of the producer in the consumer.
561   static void computeProducerOperandIndex(
562       LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
563       SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
564     // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
565     // from consumer loop -> consumer arg tensor index/producer result tensor
566     // index. The fused loop is same as the consumer loop. For each producer arg
567     // the indexing map to be computed is a map from consumer loop -> producer
568     // arg tensor index.
569 
570     AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
571     // producerResultIndexMap is a map from producer loop -> tensor index.
572     // Compute the inverse to get map from tensor index -> producer loop.
573     // The inverse is a map from producer result tensor index -> producer loop.
574     AffineMap invProducerResultIndexMap =
575         inversePermutation(producerResultIndexMap);
576     assert(invProducerResultIndexMap &&
577            "expected producer result indexig map to be invertible");
578     for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
579       // argMap is a map from producer loop -> producer arg tensor index.
580       AffineMap argMap = producer.getInputIndexingMap(argNum);
581 
582       // Compose argMap with invProducerResultIndexMap to get a map from
583       // producer result tensor index -> producer arg tensor index.
584       AffineMap t1 = argMap.compose(invProducerResultIndexMap);
585 
586       // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
587       // consumer loop/ fused loop -> producer arg tensor index.
588       AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
589       fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
590     }
591   }
592 
593   /// Generate the region of the fused operation. The region of the fused op
594   /// must be empty.
595   static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
596                                   LinalgOp producer, LinalgOp consumer,
597                                   AffineMap consumerToProducerLoopsMap,
598                                   unsigned consumerIdx, unsigned nloops) {
599     // Build the region of the fused op.
600     Block &producerBlock = producer.getOperation()->getRegion(0).front();
601     Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
602     Block *fusedBlock = new Block();
603     fusedOp->getRegion(0).push_back(fusedBlock);
604     BlockAndValueMapping mapper;
605     OpBuilder::InsertionGuard guard(rewriter);
606     rewriter.setInsertionPointToStart(fusedBlock);
607 
608     // The block arguments are
609     // [index_0, index_1, ... ,
610     //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
611     //   producer_operand_0, ... , producer_operand_(n-1)],
612     //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
613     // , where n is the number of producer's operand and m is the number
614     // consumer's operand.
615     // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
616     // generic op. In this case, there are no indices in block arguments.
617     unsigned numProducerIndices =
618         isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
619     unsigned numConsumerIndices =
620         isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
621     // Firstly, add all the indices to the block arguments.
622     for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
623          i < e; ++i)
624       fusedBlock->addArgument(rewriter.getIndexType());
625     // Map the arguments for the unmodified args from the consumer.
626     for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
627       if (consumerArg.index() == consumerIdx + numConsumerIndices) {
628         // Map the arguments for the args from the producer.
629         for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
630           // If producer is an indexed_generic op, map the indices from consumer
631           // loop to producer loop (because the fusedOp is built based on
632           // consumer's perspective).
633           if (producerArg.index() < numProducerIndices) {
634             auto newIndex = rewriter.create<mlir::AffineApplyOp>(
635                 producer.getLoc(),
636                 consumerToProducerLoopsMap.getSubMap(producerArg.index()),
637                 fusedBlock->getArguments().take_front(nloops));
638             mapper.map(producerArg.value(), newIndex);
639           } else {
640             mapper.map(producerArg.value(),
641                        fusedBlock->addArgument(producerArg.value().getType()));
642           }
643         }
644         continue;
645       }
646 
647       // If consumer is an indexed_generic op, map the indices to the block
648       // arguments directly. Otherwise, add the same type of arugment and map to
649       // it.
650       if (consumerArg.index() < numConsumerIndices) {
651         mapper.map(consumerArg.value(),
652                    fusedBlock->getArgument(consumerArg.index()));
653       } else {
654         mapper.map(consumerArg.value(),
655                    fusedBlock->addArgument(consumerArg.value().getType()));
656       }
657     }
658 
659     // Add operations from producer (except the yield operation) to the fused
660     // op.
661     for (auto &op : producerBlock.getOperations()) {
662       if (auto yieldOp = dyn_cast<YieldOp>(op)) {
663         // Lookup the value the yield operation is mapped to.
664         Value yieldVal = yieldOp.getOperand(0);
665         if (Value clonedVal = mapper.lookupOrNull(yieldVal))
666           mapper.map(
667               consumerBlock.getArgument(consumerIdx + numConsumerIndices),
668               clonedVal);
669         continue;
670       }
671       rewriter.clone(op, mapper);
672     }
673     for (auto &op : consumerBlock.getOperations())
674       rewriter.clone(op, mapper);
675   }
676 };
677 } // namespace
678 
679 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
680 /// provided, given the shape of the source tensor that corresponds to the
681 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
682 /// are "row-major" ordered logically.
683 ///
684 /// For example:
685 ///
686 /// %0 = op ... : tensor<?x?x4x5xf32>
687 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
688 ///
689 /// and reshape:
690 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
691 ///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
692 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
693 ///
694 /// would be rewritten into:
695 /// %0 = op ... : tensor<?x?x4x5xf32>
696 /// with output index_map
697 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
698 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
699                                         ArrayRef<int64_t> sourceShape,
700                                         ArrayRef<AffineMap> reassociationMaps) {
701   SmallVector<AffineExpr, 4> resultExprs;
702   resultExprs.reserve(reassociationMaps.size());
703   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
704   MLIRContext *context = sourceMap.getContext();
705 
706   // Compute the result exprs based on the reassociation maps.
707   for (AffineMap map : reassociationMaps) {
708     ArrayRef<AffineExpr> collapsedDims = map.getResults();
709     // Assume that they are in-order and contiguous (already checked in
710     // verifier).
711     assert(!collapsedDims.empty());
712     unsigned startDim =
713         collapsedDims.front().cast<AffineDimExpr>().getPosition();
714     AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
715         sourceShape.slice(startDim, collapsedDims.size()),
716         sourceExprs.slice(startDim, collapsedDims.size()), context);
717     resultExprs.push_back(linearizedExpr);
718   }
719   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
720                         resultExprs, context);
721 }
722 
723 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
724 /// true) or its producer (if `asProducer` is false) given the indexing map at
725 /// its use.
726 static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
727                                      AffineMap useIndexMap, bool asProducer) {
728   RankedTensorType returnType = reshapeOp.getResultType();
729   RankedTensorType operandType = reshapeOp.getSrcType();
730   // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
731   // operand is of lesser rank than the result. Fusing when operand has higher
732   // rank will require use of mods and divs in the indexing maps of the fused op
733   // which would make it non-invertible. Similarly reshape is fused with its
734   // producer (i.e. reshape as consumer) only if the return type has lesser
735   // rank.
736   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
737       (!asProducer && operandType.getRank() < returnType.getRank()))
738     return false;
739   return useIndexMap.isIdentity();
740 }
741 
742 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
743 /// is a linalg.generic operation, the create a `linalg.generic` operation with
744 /// the given `args`. Expects `op` to be `linalg.generic` or
745 /// `linalg.indexed_generic`.
746 template <typename... Args>
747 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
748                                          Args... args) {
749   if (isa<GenericOp>(op.getOperation()))
750     return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
751   if (isa<IndexedGenericOp>(op.getOperation()))
752     return cast<LinalgOp>(
753         rewriter.create<IndexedGenericOp>(args...).getOperation());
754   llvm_unreachable(
755       "expected only linalg.generic or linalg.indexed_generic ops");
756   return nullptr;
757 }
758 
759 namespace {
760 
761 /// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
762 struct FuseTensorReshapeOpAsProducer {
763   static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
764                         unsigned consumerIdx) {
765     return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
766            consumer.hasTensorSemantics() &&
767            isTensorReshapeOpFusible(producer,
768                                     consumer.getInputIndexingMap(consumerIdx),
769                                     /*asProducer=*/true);
770   }
771 
772   static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
773                        unsigned consumerIdx, PatternRewriter &rewriter,
774                        OperationFolder *folder = nullptr) {
775     if (producer.src().getDefiningOp<ConstantOp>())
776       return nullptr;
777 
778     if (!isFusible(producer, consumer, consumerIdx))
779       return nullptr;
780 
781     // Compute the fused operands list,
782     Operation *consumerOp = consumer.getOperation();
783     SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
784     fusedOperands[consumerIdx] = producer.src();
785 
786     // Compute indexing_maps for the fused operation. The indexing_maps for the
787     // operands of the consumers that arent fused are the same.
788     SmallVector<AffineMap, 4> fusedIndexMaps =
789         llvm::to_vector<4>(llvm::map_range(
790             consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
791               return attr.cast<AffineMapAttr>().getValue();
792             }));
793 
794     // Compute the indexing map to use for the operand of the producer.
795     AffineMap modifiedMap = linearizeCollapsedDims(
796         fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
797         producer.getReassociationMaps());
798     for (AffineExpr expr : modifiedMap.getResults()) {
799       if (!expr.isPureAffine())
800         return nullptr;
801     }
802     fusedIndexMaps[consumerIdx] = modifiedMap;
803 
804     // Further check that the resulting index maps can be fused and
805     // inverted. Without this the resultant op is not legal.
806     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
807       return nullptr;
808 
809     SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
810         llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
811           return AffineMapAttr::get(map);
812         }));
813     LinalgOp fusedOp = createLinalgOpOfSameType(
814         consumer, rewriter, rewriter.getUnknownLoc(),
815         consumerOp->getResultTypes(), fusedOperands,
816         rewriter.getI64IntegerAttr(fusedOperands.size()),
817         rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
818         rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
819         /*doc=*/nullptr,
820         /*library_call=*/nullptr,
821         /*symbol_source=*/nullptr);
822     auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
823     rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
824                                fusedRegion.begin());
825     return fusedOp;
826   }
827 };
828 
829 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
830 struct FuseTensorReshapeOpAsConsumer {
831   static bool isCollapsingAndFusible(LinalgOp producer,
832                                      TensorReshapeOp consumer,
833                                      unsigned consumerIdx) {
834     return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
835            producer.hasTensorSemantics() &&
836            isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
837                                     /*asProducer=*/false);
838   }
839 
840   static LinalgOp fuseCollapsingCase(LinalgOp producer,
841                                      TensorReshapeOp consumer,
842                                      unsigned consumerIdx,
843                                      PatternRewriter &rewriter) {
844     // The indexing_maps for the operands of the fused operation are same as
845     // those for the operands of the producer.
846     SmallVector<AffineMap, 4> fusedIndexMaps =
847         llvm::to_vector<4>(llvm::map_range(
848             producer.indexing_maps(), [](Attribute attr) -> AffineMap {
849               return attr.cast<AffineMapAttr>().getValue();
850             }));
851     // Compute the indexing map to use for the operand of the producer.
852     AffineMap modifiedMap = linearizeCollapsedDims(
853         producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
854         consumer.getReassociationMaps());
855     for (AffineExpr expr : modifiedMap.getResults()) {
856       if (!expr.isPureAffine())
857         return nullptr;
858     }
859     fusedIndexMaps.back() = modifiedMap;
860 
861     // Further check that the resulting index maps can be fused and
862     // inverted. Without this the resultant op is not legal.
863     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
864       return nullptr;
865 
866     SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
867         llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
868           return AffineMapAttr::get(map);
869         }));
870 
871     Operation *producerOp = producer.getOperation();
872     LinalgOp fusedOp = createLinalgOpOfSameType(
873         producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
874         producerOp->getOperands(),
875         rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
876         rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
877         producer.iterator_types(),
878         /*doc=*/nullptr,
879         /*library_call=*/nullptr,
880         /*symbol_source=*/nullptr);
881     auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
882     rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
883                                fusedRegion.begin());
884     return fusedOp;
885   }
886 
887   static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
888                                     unsigned consumerIdx) {
889     // Is fusible only if:
890     //   1) The producer is a generic op.
891     //   2) The producer has tensor semantics.
892     //   3) The tensor reshape op is a expanding case.
893     //   4) All the shapes are the same for the generic op.
894     //   5) All the indexing maps in producer are identity.
895     //   6) All the loops in producer are parallel loops.
896     //   7) The producer has a single user.
897     auto types = producer.getInputOutputShapedTypes();
898     assert(!types.empty());
899     return isa<GenericOp>(producer.getOperation()) &&
900            producer.hasTensorSemantics() &&
901            consumer.getSrcType().getRank() <
902                consumer.getResultType().getRank() &&
903            std::equal(types.begin() + 1, types.end(), types.begin()) &&
904            llvm::all_of(producer.getIndexingMaps(),
905                         [](AffineMap map) { return map.isIdentity(); }) &&
906            llvm::all_of(producer.iterator_types(),
907                         [](Attribute attr) {
908                           return attr.cast<StringAttr>().getValue() ==
909                                  getParallelIteratorTypeName();
910                         }) &&
911            producer.getOperation()->hasOneUse();
912   }
913 
914   static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
915                                     unsigned consumerIdx,
916                                     PatternRewriter &rewriter) {
917     Location loc = producer.getLoc();
918     auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
919     SmallVector<Value, 4> args;
920     for (auto arg : producer.getOperation()->getOperands()) {
921       auto type = RankedTensorType::get(
922           dstShape, arg.getType().cast<ShapedType>().getElementType());
923       args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
924           loc, type, arg, consumer.reassociation()));
925     }
926 
927     SmallVector<Type, 4> resultTypes;
928     for (auto t : producer.getOutputTensorTypes()) {
929       Type type = RankedTensorType::get(dstShape,
930                                         t.cast<ShapedType>().getElementType());
931       resultTypes.push_back(type);
932     }
933 
934     int rank = dstShape.size();
935     int numArgsIn = producer.getNumInputs();
936     int numArgsOut = producer.getNumOutputs();
937     auto genericOp = rewriter.create<linalg::GenericOp>(
938         loc, resultTypes, args, numArgsIn, numArgsOut,
939         SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
940                                   rewriter.getMultiDimIdentityMap(rank)),
941         SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
942     Region &region = genericOp.getRegion();
943     rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
944                                region.begin());
945     return cast<LinalgOp>(genericOp.getOperation());
946   }
947 
948   static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
949                        unsigned consumerIdx, PatternRewriter &rewriter,
950                        OperationFolder *folder = nullptr) {
951     if (isCollapsingAndFusible(producer, consumer, consumerIdx))
952       return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
953     if (isExpandingAndFusible(producer, consumer, consumerIdx))
954       return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
955     return nullptr;
956   }
957 };
958 
959 /// Implementation of fusion on tensor ops when producer is a splat constant.
960 struct FuseConstantOpAsProducer {
961   static bool isFusible(ConstantOp producer, LinalgOp consumer,
962                         unsigned consumerIdx) {
963     return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
964            consumer.hasTensorSemantics() &&
965            producer.getResult().getType().isa<RankedTensorType>() &&
966            producer.value().cast<DenseElementsAttr>().isSplat();
967   }
968 
969   static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
970                        unsigned consumerIdx, PatternRewriter &rewriter,
971                        OperationFolder *folder = nullptr) {
972     if (!isFusible(producer, consumer, consumerIdx))
973       return nullptr;
974 
975     // The indexing_maps for the operands of the fused operation are same as
976     // those for the operands of the consumer without the indexing map at
977     // consumerIdx
978     SmallVector<AffineMap, 4> fusedIndexMaps =
979         llvm::to_vector<4>(llvm::map_range(
980             consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
981               return attr.cast<AffineMapAttr>().getValue();
982             }));
983     fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
984 
985     // The operands list is same as the consumer with the argument for constant
986     // index dropped.
987     Operation *consumerOp = consumer.getOperation();
988     SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
989     fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
990 
991     // Create a constant scalar value from the splat constant.
992     Value scalarConstant = rewriter.create<ConstantOp>(
993         producer.getLoc(),
994         producer.value().cast<DenseElementsAttr>().getSplatValue());
995 
996     LinalgOp fusedOp = createLinalgOpOfSameType(
997         consumer, rewriter, rewriter.getUnknownLoc(),
998         consumerOp->getResultTypes(), fusedOperands,
999         rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
1000         rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
1001         rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1002         consumer.iterator_types(),
1003         /*doc=*/nullptr,
1004         /*library_call=*/nullptr,
1005         /*symbol_source=*/nullptr);
1006 
1007     // Map the block argument corresponding to the replaced argument with the
1008     // scalar constant.
1009     Region &consumerRegion = consumerOp->getRegion(0);
1010     Block &entryBlock = *consumerRegion.begin();
1011     unsigned argIndex = entryBlock.getNumArguments() -
1012                         consumerOp->getNumOperands() + consumerIdx;
1013     BlockAndValueMapping mapping;
1014     mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
1015     Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
1016     rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
1017                                mapping);
1018     return fusedOp;
1019   }
1020 };
1021 } // namespace
1022 
1023 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
1024                                        Operation *consumer,
1025                                        unsigned consumerIdx,
1026                                        OperationFolder *folder) {
1027   if (consumerIdx >= consumer->getNumOperands())
1028     return nullptr;
1029   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
1030   if (!producer || producer->getNumResults() != 1)
1031     return nullptr;
1032 
1033   // Fuse when consumer is GenericOp or IndexedGenericOp.
1034   if (isa<GenericOp, IndexedGenericOp>(consumer)) {
1035     if (isa<GenericOp, IndexedGenericOp>(producer))
1036       return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
1037                                            cast<LinalgOp>(consumer),
1038                                            consumerIdx, rewriter, folder);
1039     if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
1040       return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
1041                                                  cast<LinalgOp>(consumer),
1042                                                  consumerIdx, rewriter, folder);
1043     if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
1044       return FuseConstantOpAsProducer::fuse(constantOpProducer,
1045                                             cast<LinalgOp>(consumer),
1046                                             consumerIdx, rewriter, folder);
1047     return nullptr;
1048   }
1049 
1050   if (isa<GenericOp, IndexedGenericOp>(producer)) {
1051     // Fuse when consumer is a TensorReshapeOp.
1052     if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
1053       return FuseTensorReshapeOpAsConsumer::fuse(
1054           cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
1055     }
1056   }
1057 
1058   return nullptr;
1059 }
1060 
1061 namespace {
1062 /// Patterns to fuse a generic op, with the producer of its operands.
1063 template <typename LinalgOpTy>
1064 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
1065   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1066 
1067   LogicalResult matchAndRewrite(LinalgOpTy op,
1068                                 PatternRewriter &rewriter) const override {
1069     // Find the first operand that is defined by another generic op on tensors.
1070     for (auto operandNum :
1071          llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
1072       Operation *producer =
1073           op.getOperation()->getOperand(operandNum).getDefiningOp();
1074       if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
1075         rewriter.replaceOp(op, fusedOp->getResults());
1076         if (producer && llvm::all_of(producer->getResults(),
1077                                      [](Value val) { return val.use_empty(); }))
1078           rewriter.eraseOp(producer);
1079         return success();
1080       }
1081     }
1082     return failure();
1083   }
1084 };
1085 
1086 /// Pass that fuses generic ops on tensors. Used only for testing.
1087 struct FusionOfTensorOpsPass
1088     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
1089   void runOnOperation() override {
1090     OwningRewritePatternList patterns;
1091     Operation *op = getOperation();
1092     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
1093     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
1094   };
1095 };
1096 
1097 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
1098   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
1099 };
1100 } // namespace
1101 
1102 void mlir::populateLinalgTensorOpsFusionPatterns(
1103     MLIRContext *context, OwningRewritePatternList &patterns) {
1104   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
1105                   FuseTensorOps<TensorReshapeOp>>(context);
1106 }
1107 
1108 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
1109   return std::make_unique<LinalgFusionPass>();
1110 }
1111 
1112 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
1113   return std::make_unique<FusionOfTensorOpsPass>();
1114 }
1115