xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 92f1562f3dd158d837c66a1dd20ae745477d9c36)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Analysis/Dominance.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "linalg-fusion"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass of linalg library operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. This
47 ///      uses the SSA value of the views and a simple subview/slice analysis to
48 ///      determine producer-consumer dependences;
49 ///   2. greedily fusing the linalg ops that produce such subviews;
50 ///   3. inspect the fused ops and determine whether they have other remaining
51 ///      LinalgOp uses. If not, then erase the original producing linalg op.
52 ///
53 /// More advanced use cases, analyses as well as profitability heuristics are
54 /// left for future work.
55 
56 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
57 // a subset of the original loop ranges of `op`.
58 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
59 // to the `loopRanges` in order to obtain view ranges.
60 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
61                                     ArrayRef<SubViewOp::Range> loopRanges) {
62   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
63   auto maps = op.indexing_maps();
64   SmallVector<Value, 8> clonedViews;
65   clonedViews.reserve(op.getNumInputsAndOutputs());
66   // Iterate over the inputs and outputs in order.
67   // Extract the subranges from the linearized ranges.
68   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
69   for (auto en : llvm::enumerate(ios)) {
70     unsigned idx = en.index();
71     auto map = maps[idx].cast<AffineMapAttr>().getValue();
72     LLVM_DEBUG(dbgs() << "map: " << map << "\n");
73     Value view = en.value();
74     SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
75     for (auto en2 : llvm::enumerate(map.getResults())) {
76       unsigned d = en2.index();
77       // loopToOperandRangesMaps are permutations-only.
78       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
79       viewRanges[d] = loopRanges[loopPos];
80       LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
81                         << "\t"
82                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
83     }
84     // Construct a new subview for the tile.
85     unsigned rank = viewRanges.size();
86     SmallVector<Value, 4> offsets, sizes, strides;
87     offsets.reserve(rank);
88     sizes.reserve(rank);
89     strides.reserve(rank);
90     for (auto r : viewRanges) {
91       offsets.push_back(r.offset);
92       sizes.push_back(r.size);
93       strides.push_back(r.stride);
94     }
95     clonedViews.push_back(
96         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
97   }
98   auto operands = getAssumedNonViewOperands(op);
99   clonedViews.append(operands.begin(), operands.end());
100   return op.clone(b, loc, clonedViews);
101 }
102 
103 struct ViewDimension {
104   Value view;
105   unsigned dimension;
106 };
107 
108 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
109 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
110 // guarantees at least one such dimension is found. If multiple candidates exist
111 // they must agree by construction (i.e. have the same size) and we just return
112 // the first one.
113 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
114   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
115   auto maps = op.indexing_maps();
116   // Iterate over the inputs and outputs in order.
117   // Extract the subranges from the linearized ranges.
118   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
119   for (auto en : llvm::enumerate(ios)) {
120     unsigned idx = en.index();
121     auto map = maps[idx].cast<AffineMapAttr>().getValue();
122     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
123     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
124     Value view = en.value();
125     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
126     for (auto en2 : llvm::enumerate(map.getResults())) {
127       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
128         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
129                           << "\n");
130         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
131         return ViewDimension{view, static_cast<unsigned>(en2.index())};
132       }
133     }
134   }
135   llvm_unreachable("Expect to be able to extract a view defining loop range");
136 }
137 
138 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
139                      unsigned consumerIdx, unsigned producerIdx,
140                      OperationFolder *folder) {
141   assert(producer.hasBufferSemantics() &&
142          "expected linalg op with buffer semantics");
143   assert(consumer.hasBufferSemantics() &&
144          "expected linalg op with buffer semantics");
145 
146   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
147     // TODO(ntv): add a level of indirection to linalg.generic.
148     if (convOp.padding())
149       llvm_unreachable("Unexpected conv with padding");
150   }
151   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
152     // TODO(ntv): add a level of indirection to linalg.generic.
153     if (convOp.padding())
154       llvm_unreachable("Unexpected conv with padding");
155   }
156 
157   auto subView = dyn_cast_or_null<SubViewOp>(
158       consumer.getBuffer(consumerIdx).getDefiningOp());
159   auto slice = dyn_cast_or_null<SliceOp>(
160       consumer.getBuffer(consumerIdx).getDefiningOp());
161   assert(subView || slice);
162   (void)subView;
163   (void)slice;
164 
165   // loopToOperandRangesMaps are permutations-only by construction:
166   //   we can always identify a data dimension with a (at least one) loop
167   //   dimension.
168   AffineMap producerMap =
169       producer.indexing_maps()[producer.getNumInputs() + producerIdx]
170           .cast<AffineMapAttr>()
171           .getValue();
172   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
173                     << ", producer map: " << producerMap << "\n");
174 
175   unsigned nPar = producer.getNumParallelLoops();
176   unsigned nRed = producer.getNumReductionLoops();
177   unsigned nWin = producer.getNumWindowLoops();
178   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
179 
180   // Iterate over dimensions identified by the producer map for `producerIdx`.
181   // This defines a subset of the loop ranges that we need to complete later.
182   for (auto en : llvm::enumerate(producerMap.getResults())) {
183     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
184     loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
185   }
186 
187   OpBuilder b(consumer.getOperation());
188   auto loc = consumer.getLoc();
189   // Iterate over all dimensions. For the dimensions not identified by the
190   // producer map for `producerIdx`, we need to explicitly compute the view that
191   // defines the loop ranges using the `producer`.
192   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
193     if (loopRanges[i].offset)
194       LLVM_DEBUG(llvm::dbgs()
195                  << "existing LoopRange: " << loopRanges[i] << "\n");
196     else {
197       auto viewDim = getViewDefiningLoopRange(producer, i);
198       loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
199                                        std_dim(viewDim.view, viewDim.dimension),
200                                        folded_std_constant_index(folder, 1)};
201       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
202     }
203   }
204 
205   return cloneWithLoopRanges(b, loc, producer, loopRanges);
206 }
207 
208 // Encode structural fusion safety preconditions.
209 // Some of these will be lifted in the future with better analysis.
210 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
211                                           LinalgOp consumer) {
212   assert(producer.hasBufferSemantics() &&
213          "expected linalg op with buffer semantics");
214   assert(consumer.hasBufferSemantics() &&
215          "expected linalg op with buffer semantics");
216   if (producer.getNumOutputs() != 1) {
217     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
218     return false;
219   }
220   // Only fuse when the producer block dominates.
221   DominanceInfo dom(producer.getOperation());
222   if (!dom.dominates(producer.getOperation()->getBlock(),
223                      consumer.getOperation()->getBlock())) {
224     LLVM_DEBUG(
225         dbgs()
226         << "\nNot structurally fusable (producer block does not dominate)");
227     return false;
228   }
229   return true;
230 }
231 
232 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
233                                              LinalgOp consumer,
234                                              Value consumedView,
235                                              LinalgOp producer) {
236   assert(producer.hasBufferSemantics() &&
237          "expected linalg op with buffer semantics");
238   assert(consumer.hasBufferSemantics() &&
239          "expected linalg op with buffer semantics");
240   // Make some simple structural checks that alleviate the need for more
241   // complex analyses.
242   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
243     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
244                       << *producer.getOperation());
245     return false;
246   }
247   // Check for any interleaved write to consumedView.
248   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
249     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
250                       << *producer.getOperation());
251     return false;
252   }
253   return true;
254 }
255 
256 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
257                                  LinalgOp consumer, Value consumedView,
258                                  LinalgOp producer) {
259   assert(producer.hasBufferSemantics() &&
260          "expected linalg op with buffer semantics");
261   assert(consumer.hasBufferSemantics() &&
262          "expected linalg op with buffer semantics");
263   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
264     return false;
265   // Check for any fusion-preventing dependence to any view read/written that
266   // would violate dependences.
267   if (!graph.findCoveringDependences(producer, consumer).empty()) {
268     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
269                       << *producer.getOperation());
270     return false;
271   }
272   return true;
273 }
274 
275 static Optional<FusionInfo>
276 fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
277                   const LinalgDependenceGraph &graph, OperationFolder *folder,
278                   LinalgDependenceGraph::DependenceType depType) {
279   assert(consumer.hasBufferSemantics() &&
280          "expected linalg op with buffer semantics");
281   LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
282                     << *consumer.getOperation());
283   for (auto dependence : graph.getDependencesInto(consumer, depType)) {
284     LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
285                       << *dependence.dependentOpView.op << "\n");
286     auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
287     if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) {
288       LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
289       continue;
290     }
291 
292     // Check that the dependence is indeed on the input `consumerIdx` view.
293     auto consumedView = dependence.indexingView;
294     if (consumer.getBuffer(consumerIdx) != consumedView)
295       continue;
296 
297     // Consumer consumes this view, `isStructurallyFusableProducer` also checks
298     // whether it is a strict subview of the producer view.
299     auto producedView = dependence.dependentOpView.view;
300     auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
301     // `consumerIdx` and `producerIdx` exist by construction.
302     LLVM_DEBUG(dbgs() << "\n"
303                       << LinalgDependenceGraph::getDependenceTypeStr(depType)
304                       << "producer: " << *producer.getOperation() << " view: "
305                       << producedView << " output index: " << producerIdx);
306 
307     // Must be a subview or a slice to guarantee there are loops we can fuse
308     // into.
309     auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
310     auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
311     if (!subView && !slice) {
312       LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
313       continue;
314     }
315 
316     // Simple fusability checks.
317     if (!isFusableInto(graph, consumer, consumedView, producer))
318       continue;
319 
320     // Fuse `producer` just before `consumer`.
321     OpBuilder::InsertionGuard g(b);
322     b.setInsertionPoint(consumer.getOperation());
323     ScopedContext scope(b, consumer.getLoc());
324     LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
325     auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
326                               producerIdx, folder);
327 
328     return FusionInfo{producer, fusedProducer};
329   }
330   return llvm::None;
331 }
332 
333 // Only consider RAW and WAW atm.
334 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
335     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
336     const LinalgDependenceGraph &graph, OperationFolder *folder) {
337   SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
338       LinalgDependenceGraph::DependenceType::RAW,
339       LinalgDependenceGraph::DependenceType::WAW,
340   };
341   for (auto dep : deps) {
342     if (auto res =
343             fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
344       return res;
345   }
346   return llvm::None;
347 }
348 
349 /// Checks if two Generic ops are fusible, when one is a producer and another is
350 /// a consumer (with the result of the producer being the `consumerIdx` operand
351 /// of the consumer).
352 static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
353                                 unsigned consumerIdx) {
354   // Verify that the producer and consumer are ops on tensors.
355   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
356     return false;
357 
358   auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
359   auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
360   // Verify that
361   // - the producer and consumers are generic ops,
362   // - only handle cases where the producer has a single return value,
363   // - the producer return value should be the same as argument at `consumerIdx`
364   //   of the consumer,
365   // - the producer has all "parallel" iterator type.
366   // - only handle ops that use regions for specifying the scalar operations.
367   if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
368       producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
369       producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
370       producerOp.fun() || consumerOp.fun())
371     return false;
372 
373   // Get the consumer index map. The number of results of the consumer index map
374   // must match the number of loops of the producer.
375   AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
376   if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
377     return false;
378 
379   // Finally the index_map for the result must be invertible. For now just
380   // verify it is a permutation.
381   AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
382   return producerResultIndexMap.isPermutation();
383 }
384 
385 /// Computes the indexing maps for arguments of a producer generic op when the
386 /// result of the producer is fused with the consumer.
387 /// - consumerIndexMap is the indexing_map for the argument in the consumer op
388 ///   that is the result of the producer op.
389 /// - invProducerResultIndexMap is the inverse of the indexing_map for the
390 ///   result in the producer op.
391 /// - producerArgIndexMap is the indexing_map of the argument of the producer
392 ///   op.
393 /// The result is the indexing_map to use for the producer argument when the
394 /// producer and consumer ops are fused.
395 static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
396                                        AffineMap invProducerResultIndexMap,
397                                        AffineMap producerArgIndexMap) {
398   // t1 is map from producer result tensor index -> producer arg tensor index.
399   auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
400   // The return is map from consumer loop -> producer arg tensor index,
401   // i.e. indexing_map for the producer argument in the fused operation.
402   return t1.compose(consumerIndexMap);
403 }
404 
405 Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
406                                                LinalgOp consumer,
407                                                unsigned consumerIdx,
408                                                OperationFolder *folder) {
409   if (!areTensorOpsFusible(producer, consumer, consumerIdx))
410     return {};
411 
412   MLIRContext *context = b.getContext();
413   auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
414   auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
415   AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
416   AffineMap invProducerResultIndexMap =
417       inversePermutation(producerOp.getOutputIndexingMap(0));
418   if (!invProducerResultIndexMap)
419     return {};
420 
421   // Compute the fused op operandslist by replacing the operand corresponding to
422   // the result of the producer, with the operands of the producer.
423   unsigned fusedArgsIn =
424       producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
425   auto fusedArgsOut = consumerOp.getNumOutputs();
426   SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
427   fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
428   fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
429   fusedOperandsList.insert(
430       std::next(fusedOperandsList.begin(), consumerIdx),
431       producerOp.operand_begin(),
432       std::next(producerOp.operand_begin(), producerOp.getNumInputs()));
433 
434   // Compute the fused indexing_maps of the operands/results of the fused op.
435   SmallVector<Attribute, 2> fusedIndexingMapAttrs;
436   fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
437   fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
438                                consumerOp.indexing_maps().end());
439   fusedIndexingMapAttrs.erase(
440       std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
441   auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
442   for (auto producerArgIndexAttr :
443        llvm::enumerate(producerOp.indexing_maps())) {
444     if (producerArgIndexAttr.index() == producerOp.getNumInputs())
445       break;
446     auto composedIndexMap = computeProducerArgMap(
447         consumerIndexMap, invProducerResultIndexMap,
448         producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
449     insertPos = std::next(fusedIndexingMapAttrs.insert(
450         insertPos, AffineMapAttr::get(composedIndexMap)));
451   }
452 
453   // Generate the fused op.
454   auto fusedLinalgOp = b.create<GenericOp>(
455       UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
456       b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
457       b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
458       /*doc=*/nullptr,
459       /*fun=*/nullptr,
460       /*library_call=*/nullptr);
461 
462   // Build the region of the fused op.
463   auto &fusedOpRegion = fusedLinalgOp.region();
464   Block &producerOpBlock = producerOp.region().front();
465   Block &consumerOpBlock = consumerOp.region().front();
466   Block *fusedBlock = new Block();
467   fusedOpRegion.push_back(fusedBlock);
468   BlockAndValueMapping mapper;
469   // Map the arguments for the unmodified args from the consumer.
470   for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
471     if (consumerOpArg.index() == consumerIdx) {
472       // Map the arguments for the args from the producer.
473       for (auto producerOpArg : producerOpBlock.getArguments())
474         mapper.map(producerOpArg,
475                    fusedBlock->addArgument(producerOpArg.getType()));
476       continue;
477     }
478     mapper.map(consumerOpArg.value(),
479                fusedBlock->addArgument(consumerOpArg.value().getType()));
480   }
481 
482   // Add operations from producer (except the yield operation) to the fused op.
483   for (auto &op : producerOpBlock.getOperations()) {
484     if (auto yieldOp = dyn_cast<YieldOp>(op)) {
485       // Lookup the value the yield operation is mapped to.
486       Value yieldVal = yieldOp.getOperand(0);
487       auto clonedVal = mapper.lookup(yieldVal);
488       mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
489       continue;
490     }
491     fusedBlock->push_back(op.clone(mapper));
492   }
493   for (auto &op : consumerOpBlock.getOperations())
494     fusedBlock->push_back(op.clone(mapper));
495 
496   return cast<LinalgOp>(fusedLinalgOp.getOperation());
497 }
498 
499 static void fuseLinalgOpsGreedily(FuncOp f) {
500   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
501 
502   OpBuilder b(f);
503   OperationFolder folder(f.getContext());
504   DenseSet<Operation *> eraseSet;
505 
506   // Save original Linalg ops, we only want to make a pass over those.
507   SmallVector<Operation *, 8> linalgOps;
508   f.walk([&](LinalgOp op) {
509     if (op.hasBufferSemantics())
510       linalgOps.push_back(op);
511   });
512 
513   // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
514   // The current naive and expensive reconstruction of the graph should be
515   // removed.
516   for (auto *op : llvm::reverse(linalgOps)) {
517     for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
518          id < e; ++id) {
519       linalg::Aliases aliases;
520       linalg::LinalgDependenceGraph graph(aliases, linalgOps);
521       if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
522         auto *originalOp = info->originalProducer.getOperation();
523         eraseSet.insert(originalOp);
524         auto *originalOpInLinalgOpsVector =
525             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
526         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
527       }
528     }
529   }
530   // The `fuseProducerOf` function performs structural checks and in particular
531   // that no covering read or write exist between the consumer and the producer.
532   // As a consequence, the only fusions that may occur preserve subsequent
533   // dependences and are guaranteed by construction to produce the whole view.
534   // We may thus erase the producer once it is fused.
535   for (auto *e : eraseSet)
536     e->erase();
537   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
538 }
539 
540 namespace {
541 
542 /// Patterns to fuse a generic op, with the producer of its operands.
543 struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
544   using OpRewritePattern<GenericOp>::OpRewritePattern;
545 
546   LogicalResult matchAndRewrite(GenericOp op,
547                                 PatternRewriter &rewriter) const override {
548     if (!op.hasTensorSemantics())
549       return failure();
550 
551     // Find the first operand that is defined by another generic op on tensors.
552     for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
553       auto definingOp =
554           dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
555       if (!definingOp || !definingOp.hasTensorSemantics())
556         continue;
557       auto fusedOp =
558           fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
559                         cast<LinalgOp>(op.getOperation()), operand.index());
560       if (!fusedOp)
561         continue;
562       rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
563       if (llvm::all_of(definingOp.getResults(),
564                        [](Value val) -> bool { return val.use_empty(); }))
565         rewriter.eraseOp(definingOp);
566       return success();
567     }
568     return failure();
569   }
570 };
571 
572 /// Pass that fuses generic ops on tensors. Used only for testing.
573 struct FusionOfTensorOpsPass
574     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
575   void runOnOperation() override {
576     OwningRewritePatternList patterns;
577     Operation *op = getOperation();
578     patterns.insert<FuseGenericTensorOps>(op->getContext());
579     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
580   };
581 };
582 
583 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
584   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
585 };
586 } // namespace
587 
588 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
589   return std::make_unique<LinalgFusionPass>();
590 }
591 
592 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
593   return std::make_unique<FusionOfTensorOpsPass>();
594 }
595