//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass on linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. This
///      uses the SSA value of the views and a simple subview/slice analysis to
///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subviews read by `O`;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses, as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                        loopRanges[i].offset);
      replaceAllUsesExcept(
          oldIndex, newIndex,
          SmallPtrSet<Operation *, 1>{newIndex.getDefiningOp()});
    }
  }
  return clonedOp;
}

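// A (view, dimension) pair identifying the dimension of a view whose size can
// be used to materialize a loop range (see `getViewDefiningLoopRange`).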
struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

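// Fuses `producer` into `consumer` by cloning `producer` so that the clone
// computes only the tile of `producedView` that `consumer` reads through its
// operand at `consumerIdx` (which must be defined by a subview or a slice).
// Loop ranges not constrained by the consumed subview are recovered from the
// producer's own views via `getViewDefiningLoopRange`.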
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with at least one loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates the consumer block.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}
251 
252 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
253                                              LinalgOp consumer,
254                                              Value consumedView,
255                                              LinalgOp producer) {
256   assert(producer.hasBufferSemantics() &&
257          "expected linalg op with buffer semantics");
258   assert(consumer.hasBufferSemantics() &&
259          "expected linalg op with buffer semantics");
260   // Make some simple structural checks that alleviate the need for more
261   // complex analyses.
262   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
263     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
264                       << *producer.getOperation());
265     return false;
266   }
267   // Check for any interleaved write to consumedView.
268   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
269     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
270                       << *producer.getOperation());
271     return false;
272   }
273   return true;
274 }
275 
276 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
277                                  LinalgOp consumer, Value consumedView,
278                                  LinalgOp producer) {
279   assert(producer.hasBufferSemantics() &&
280          "expected linalg op with buffer semantics");
281   assert(consumer.hasBufferSemantics() &&
282          "expected linalg op with buffer semantics");
283   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
284     return false;
285   // Check for any fusion-preventing dependence to any view read/written that
286   // would violate dependences.
287   if (!graph.findCoveringDependences(producer, consumer).empty()) {
288     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
289                       << *producer.getOperation());
290     return false;
291   }
292   return true;
293 }
294 
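// Walks the dependences of type `depType` that flow into `consumer` and, for
// the first producer writing the view consumed at `consumerIdx` that passes
// the fusability checks, fuses it just before `consumer`. Returns the original
// and fused producers on success, llvm::None otherwise.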
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // The consumer consumes this view; `isStructurallyFusableProducer` also
    // checks whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW dependences for now.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible, where one is the producer and the
/// other the consumer (with the result of the producer being the `consumerIdx`
/// operand of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumer are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value is the same as the argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator types,
  // - only handle ops that use regions for specifying the scalar operations.
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally, the index_map for the result must be invertible. For now, just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 is the map from producer result tensor index -> producer arg tensor
  // index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The result is the map from consumer loop -> producer arg tensor index,
  // i.e. the indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

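// Fuses `producer` into `consumer` on tensors by inlining the body of the
// producer at the `consumerIdx` operand of the consumer and composing the
// indexing maps accordingly. Returns the fused op on success, or llvm::None
// when `areTensorOpsFusible` does not hold.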
420 
421 Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
422                                                LinalgOp consumer,
423                                                unsigned consumerIdx,
424                                                OperationFolder *folder) {
425   if (!areTensorOpsFusible(producer, consumer, consumerIdx))
426     return {};
427 
428   MLIRContext *context = b.getContext();
429   auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
430   auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
431   AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
432   AffineMap invProducerResultIndexMap =
433       inversePermutation(producerOp.getOutputIndexingMap(0));
434   if (!invProducerResultIndexMap)
435     return {};
436 
  // Compute the fused op's operand list by replacing the operand corresponding
  // to the result of the producer with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op.
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

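// Greedily fuses linalg ops with buffer semantics inside `f`. Ops are visited
// in reverse order; for each input/output buffer of an op, `fuseProducerOf` is
// queried and, on success, the original producer is recorded for erasure once
// all fusions have been applied.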
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save the original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and, in
  // particular, checks that no covering read or write exists between the
  // consumer and the producer. As a consequence, the only fusions that may
  // occur preserve subsequent dependences and are guaranteed by construction
  // to produce the whole view. We may thus erase the producer once it is
  // fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Pattern to fuse a generic op with the producer of one of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      if (llvm::all_of(definingOp.getResults(),
                       [](Value val) -> bool { return val.use_empty(); }))
        rewriter.eraseOp(definingOp);
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }
};

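/// Pass that greedily fuses linalg library ops with buffer semantics within a
/// function.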
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}