xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 3145427dd73f0ee16dac4044890e2e2d2cae5040)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/Dominance.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Support/STLExtras.h"
28 #include "mlir/Transforms/FoldUtils.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
41 
42 using llvm::dbgs;
43 
44 /// Implements a simple high-level fusion pass of linalg library operations.
45 ///
46 /// In each block, linalg ops are processed in reverse textual order.
47 /// Given a linalg op `O`, fusion occurs by:
48 ///   1. inspecting the linalg ops that write into the views read by `O`. This
49 ///      uses the SSA value of the views and a simple subview/slice analysis to
50 ///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce subviews;
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
// Category grouping the command-line options of this pass.
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// Comma-separated list of tile sizes to apply when tiling during fusion.
// NOTE(review): this flag is registered here but does not appear to be read
// anywhere in this file -- presumably consumed elsewhere; confirm before
// removing.
static llvm::cl::list<unsigned> clTileSizes(
    "linalg-fusion-tile-sizes",
    llvm::cl::desc(
        "Tile sizes by which to tile linalg operations during linalg fusion"),
    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
    llvm::cl::cat(clOptionsCategory));
65 
// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
// The clone reads/writes newly created subviews of the original views, so it
// only computes the tile described by `loopRanges`.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    // One range per result of the indexing map, i.e. per view dimension.
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile from the per-dimension
    // offset/size/stride triples gathered above.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands (if any) are forwarded unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());
  return op.clone(b, loc, clonedViews);
}
112 
// A (view, dimension) pair identifying which dimension of which view carries
// the size of a given loop range (see getViewDefiningLoopRange).
struct ViewDimension {
  Value view;         // The view whose shape defines the loop range.
  unsigned dimension; // The dimension of `view` that defines the range.
};
117 
118 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
119 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
120 // guarantees at least one such dimension is found. If multiple candidates exist
121 // they must agree by construction (i.e. have the same size) and we just return
122 // the first one.
123 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
124   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
125   auto maps = op.indexing_maps();
126   // Iterate over the inputs and outputs in order.
127   // Extract the subranges from the linearized ranges.
128   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
129   for (auto en : llvm::enumerate(ios)) {
130     unsigned idx = en.index();
131     auto map = maps[idx].cast<AffineMapAttr>().getValue();
132     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
133     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
134     Value view = en.value();
135     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
136     for (auto en2 : llvm::enumerate(map.getResults())) {
137       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
138         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
139                           << "\n");
140         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
141         return ViewDimension{view, static_cast<unsigned>(en2.index())};
142       }
143     }
144   }
145   llvm_unreachable("Expect to be able to extract a view defining loop range");
146 }
147 
// Fuses `producer` into `consumer` by cloning the producer so that it only
// computes the tile of `producedView` read by the consumer through its input
// number `consumerIdx`. `producerIdx` is the index of `producedView` among the
// producer's output buffers. `folder` folds the constant index ops created for
// the reconstructed loop bounds. Returns the cloned (fused) producer.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Padded convolutions are not supported by this fusion scheme.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  // The consumed view must come from a subview or a slice, otherwise there is
  // no tile to restrict the producer to.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getInput(consumerIdx).getDefiningOp());
  auto slice =
      dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with a (at least one) loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  // One range per producer loop (parallel + reduction + window).
  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  // NOTE(review): only the subview path is handled here; if the defining op
  // was a SliceOp, `subView` is null and getRanges() below is invalid --
  // confirm the slice case cannot reach this point.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Reconstruct the full range [0, dim(view), 1) from the defining view.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
217 
// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
// NOTE(review): `consumedView` is currently unused here; it is presumably kept
// for signature symmetry with the other fusability predicates -- confirm.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Multi-output producers are not supported yet.
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}
241 
/// Returns true iff `producer` can be treated as the last write to
/// `consumedView` before `consumer` executes: the pair must pass the
/// structural checks and the dependence graph must show no other write to the
/// view interleaved between the two ops.
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}
265 
/// Returns true iff fusing `producer` into `consumer` at `consumedView` is
/// legal: the producer must be the last write of the view and no interleaved
/// dependence on any read/written view may be violated.
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any view read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}
284 
// Only consider RAW atm.
/// Walks the RAW dependences into `consumer` and fuses the first eligible
/// producer of the view consumed at input `consumerIdx`, cloning the producer
/// just before the consumer. Returns the (original, fused) producer pair on
/// success, llvm::None if nothing was fused.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(
           consumer, LinalgDependenceGraph::DependenceType::RAW)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
    // indexed_generic producers are skipped (their index semantics would not
    // survive the tiling performed by fusion).
    if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) {
      LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
      continue;
    }

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getInput(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
                      << " view: " << producedView
                      << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    // Fuse at most one producer per call; the caller reruns as needed.
    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}
342 
/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumers are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value should be the same as argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator type.
  // - only handle ops that use regions for specifying the scalar operations.
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}
378 
379 /// Computes the indexing maps for arguments of a producer generic op when the
380 /// result of the producer is fused with the consumer.
381 /// - consumerIndexMap is the indexing_map for the argument in the consumer op
382 ///   that is the result of the producer op.
383 /// - invProducerResultIndexMap is the inverse of the indexing_map for the
384 ///   result in the producer op.
385 /// - producerArgIndexMap is the indexing_map of the argument of the producer
386 ///   op.
387 /// The result is the indexing_map to use for the producer argument when the
388 /// producer and consumer ops are fused.
389 static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
390                                        AffineMap invProducerResultIndexMap,
391                                        AffineMap producerArgIndexMap) {
392   // t1 is map from producer result tensor index -> producer arg tensor index.
393   auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
394   // The return is map from consumer loop -> producer arg tensor index,
395   // i.e. indexing_map for the producer argument in the fused operation.
396   return t1.compose(consumerIndexMap);
397 }
398 
/// Fuses the generic op `producer` (on tensors) into `consumer` at operand
/// `consumerIdx`, which must be the producer's single result. Builds a fresh
/// generic op whose operands, indexing maps and region combine both ops.
/// Returns the fused op, or None when the pair is not fusible.
/// NOTE(review): `folder` is not used in this path -- presumably kept for API
/// symmetry with the buffer-based fusion entry points; confirm.
Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));

  // Compute the fused op operands list by replacing the operand corresponding
  // to the result of the producer, with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op.
  // The consumer's map at `consumerIdx` is replaced by the composed maps of
  // the producer's inputs (outputs of the producer are dropped).
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    // Only producer input maps are inserted; stop at the first output map.
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op by cloning both regions into a new
  // block; the region takes ownership of the block pushed into it.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to: the consumer's
      // argument at `consumerIdx` becomes the cloned producer result.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}
490 
/// Greedily fuses buffer-based linalg ops inside function `f`: ops are visited
/// in reverse order, each input of each op is considered for producer fusion,
/// and producers that were fused are erased afterwards.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  // Producers that were fused; erased only after the whole traversal so that
  // ops remain valid while being inspected.
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
      // Rebuild aliasing info and the dependence graph from scratch each
      // iteration (see TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Replace the original producer in the worklist with its fused clone
        // so subsequent graphs are built over live ops.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
530 
531 namespace {
532 
/// Patterns to fuse a generic op, with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  /// Fuses the first fusible generic-on-tensors producer of any operand of
  /// `op` into `op`, replacing `op` with the fused result.
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only tensor-semantics ops participate in this pattern.
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      // fuseTensorOps returns None when the pair is not fusible; keep looking.
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      return success();
    }
    return failure();
  }
};
559 
560 /// Pass that fuses generic ops on tensors. Used only for testing.
561 struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> {
562   void runOnOperation() override {
563     OwningRewritePatternList patterns;
564     Operation *op = getOperation();
565     patterns.insert<FuseGenericTensorOps>(op->getContext());
566     applyPatternsGreedily(op->getRegions(), patterns);
567   };
568 };
569 
570 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
571   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
572 };
573 } // namespace
574 
575 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgFusionPass() {
576   return std::make_unique<LinalgFusionPass>();
577 }
578 
// Registers the buffer-based fusion pass with the pass infrastructure.
static PassRegistration<LinalgFusionPass>
    pass("linalg-fusion", "Fuse operations in the linalg dialect");

// Registers the tensor-based fusion pass (testing only).
static PassRegistration<FusionOfTensorOpsPass>
    tensorOpsPass("linalg-fusion-for-tensor-ops",
                  "Fuse operations on RankedTensorType in linalg dialect");
585