xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision d27ab5c2409b0223ffb6b7ebcb75cd1bde4ac231)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Analysis/Dominance.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "linalg-fusion"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass of linalg library operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. This
47 ///      uses the SSA value of the views and a simple subview/slice analysis to
48 ///      determine producer-consumer dependences;
49 ///   2. greedily fuse the linalg ops that produce subview
50 ///   3. inspect the fused ops and determine whether they have other remaining
51 ///      LinalgOp uses. If not, then erase the original producing linalg op.
52 ///
53 /// More advanced use cases, analyses as well as profitability heuristics are
54 /// left for future work.
55 
56 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
57 // a subset of the original loop ranges of `op`.
58 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
59 // to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only: each result is a plain
      // AffineDimExpr naming the loop that this view dimension ranges over.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile covering exactly `viewRanges`.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands (e.g. scalars) are passed through unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      // Redirect all uses of the old IV to the offset one, except the AddIOp
      // itself which must keep reading the original block argument.
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}
120 
/// A (view, dimension) pair identifying one dimension of one view.
struct ViewDimension {
  Value view;         // The view providing the dimension.
  unsigned dimension; // Position of the dimension within `view`.
};
125 
126 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
127 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
128 // guarantees at least one such dimension is found. If multiple candidates exist
129 // they must agree by construction (i.e. have the same size) and we just return
130 // the first one.
131 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
132   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
133   auto maps = op.indexing_maps();
134   // Iterate over the inputs and outputs in order.
135   // Extract the subranges from the linearized ranges.
136   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
137   for (auto en : llvm::enumerate(ios)) {
138     unsigned idx = en.index();
139     auto map = maps[idx].cast<AffineMapAttr>().getValue();
140     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
141     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
142     Value view = en.value();
143     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
144     for (auto en2 : llvm::enumerate(map.getResults())) {
145       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
146         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
147                           << "\n");
148         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
149         return ViewDimension{view, static_cast<unsigned>(en2.index())};
150       }
151     }
152   }
153   llvm_unreachable("Expect to be able to extract a view defining loop range");
154 }
155 
/// Fuses `producer` into `consumer`: clones `producer` (via
/// `cloneWithLoopRanges`) so that it only computes the tile of `producedView`
/// that `consumer` reads through the subview defining its operand
/// `consumerIdx`. `producerIdx` is the output index of `producedView` in the
/// producer. The clone is inserted right before `consumer`; the original
/// producer op is left in place for the caller to erase if it becomes dead.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Conv with padding is not representable through the range rewriting below.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with a (at least one) loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  // NOTE(review): only the SubViewOp path is dereferenced here; if the
  // consumed buffer were defined solely by a SliceOp, `subView` would be null
  // and this would crash — confirm slices cannot reach this point.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Loop not constrained by the consumer tile: cover the full extent
      // [0, dim(view), 1) of a view that defines this loop range.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
225 
226 // Encode structural fusion safety preconditions.
227 // Some of these will be lifted in the future with better analysis.
228 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
229                                           LinalgOp consumer) {
230   assert(producer.hasBufferSemantics() &&
231          "expected linalg op with buffer semantics");
232   assert(consumer.hasBufferSemantics() &&
233          "expected linalg op with buffer semantics");
234   if (producer.getNumOutputs() != 1) {
235     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
236     return false;
237   }
238   // Only fuse when the producer block dominates.
239   DominanceInfo dom(producer.getOperation());
240   if (!dom.dominates(producer.getOperation()->getBlock(),
241                      consumer.getOperation()->getBlock())) {
242     LLVM_DEBUG(
243         dbgs()
244         << "\nNot structurally fusable (producer block does not dominate)");
245     return false;
246   }
247   return true;
248 }
249 
250 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
251                                              LinalgOp consumer,
252                                              Value consumedView,
253                                              LinalgOp producer) {
254   assert(producer.hasBufferSemantics() &&
255          "expected linalg op with buffer semantics");
256   assert(consumer.hasBufferSemantics() &&
257          "expected linalg op with buffer semantics");
258   // Make some simple structural checks that alleviate the need for more
259   // complex analyses.
260   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
261     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
262                       << *producer.getOperation());
263     return false;
264   }
265   // Check for any interleaved write to consumedView.
266   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
267     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
268                       << *producer.getOperation());
269     return false;
270   }
271   return true;
272 }
273 
274 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
275                                  LinalgOp consumer, Value consumedView,
276                                  LinalgOp producer) {
277   assert(producer.hasBufferSemantics() &&
278          "expected linalg op with buffer semantics");
279   assert(consumer.hasBufferSemantics() &&
280          "expected linalg op with buffer semantics");
281   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
282     return false;
283   // Check for any fusion-preventing dependence to any view read/written that
284   // would violate dependences.
285   if (!graph.findCoveringDependences(producer, consumer).empty()) {
286     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
287                       << *producer.getOperation());
288     return false;
289   }
290   return true;
291 }
292 
/// Walks all dependences of type `depType` flowing into `consumer` and fuses
/// the first producer of the view consumed at operand `consumerIdx` that
/// passes the fusability checks. Returns the (original producer, fused clone)
/// pair on success, llvm::None otherwise.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`. The insertion guard restores the
    // builder's position once the clone has been emitted.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    // Stop at the first successful fusion.
    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}
346 
347 // Only consider RAW and WAW atm.
348 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
349     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
350     const LinalgDependenceGraph &graph, OperationFolder *folder) {
351   SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
352       LinalgDependenceGraph::DependenceType::RAW,
353       LinalgDependenceGraph::DependenceType::WAW,
354   };
355   for (auto dep : deps) {
356     if (auto res =
357             fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
358       return res;
359   }
360   return llvm::None;
361 }
362 
/// Greedily fuses linalg ops (with buffer semantics) inside `f`: walks the
/// original linalg ops in reverse order and, for each input/output buffer,
/// tries to fuse its producer just before the consumer. Original producers
/// that were fused are erased at the end.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    // Tensor-semantics ops are handled by the tensor fusion patterns instead.
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild the dependence graph each attempt: a previous fusion may have
      // inserted a clone that changes the dependences.
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Substitute the fused clone for the original producer so the next
        // graph reconstruction sees a consistent op list.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
403 
404 //====---------------------------------------------------------------------===//
405 // Fusion on Tensor operation.
406 //====---------------------------------------------------------------------===//
407 
408 namespace {
409 
410 /// Implementation of fusion of generic ops.
struct FuseGenericOpsOnTensors {
  /// Returns true if `producer` can be fused into `consumer` at operand
  /// `consumerIdx`.
  static bool isFusible(GenericOp producer, GenericOp consumer,
                        unsigned consumerIdx) {
    // Verify that
    // - the producer has all "parallel" iterator type.
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // Get the consumer index map. The number of results of the consumer index
    // map must match the number of loops of the producer.
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // Finally the index_map for the result must be invertible. For now just
    // verify it is a permutation.
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
  }

  /// Fuses `producer` into `consumer` at operand `consumerIdx`, returning the
  /// fused operation or nullptr if the pair is not fusible. The fused op's
  /// operands are: consumer operands before `consumerIdx`, all producer
  /// operands, then consumer operands after `consumerIdx`.
  static Operation *fuse(GenericOp producer, GenericOp consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // One operand (the fused one at `consumerIdx`) disappears.
    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
                                consumer.getOperation()->getNumOperands() - 1;

    // Compute the fused operands list.
    SmallVector<Value, 2> fusedOperands;
    fusedOperands.reserve(numFusedOperands);
    auto consumerOperands = consumer.getOperation()->getOperands();
    auto producerOperands = producer.getOperation()->getOperands();
    fusedOperands.assign(consumerOperands.begin(),
                         std::next(consumerOperands.begin(), consumerIdx));
    fusedOperands.append(producerOperands.begin(), producerOperands.end());
    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                         consumerOperands.end());

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumers that aren't fused are the same. The
    // indexing_maps for the producers need to be computed based on the
    // indexing_map of the operand at consumerIdx in the consumer.
    SmallVector<Attribute, 4> fusedIndexMaps;
    auto consumerIndexMaps = consumer.indexing_maps();
    fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
    fusedIndexMaps.assign(consumerIndexMaps.begin(),
                          std::next(consumerIndexMaps.begin(), consumerIdx));
    // Compute indexing maps for the producer args in the fused operation.
    computeProducerOperandIndex(
        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);

    // Append the indexing maps for the remaining consumer operands.
    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                          consumerIndexMaps.end());

    // Generate the fused op. Iterator types are taken from the consumer since
    // the fused loop nest is the consumer's loop nest.
    auto fusedOp = rewriter.create<GenericOp>(
        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(fusedOperands.size()),
        rewriter.getI64IntegerAttr(consumer.getNumResults()),
        rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
                        consumer.region(), consumerIdx);
    return fusedOp;
  }

private:
  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
  /// the `producer` to use in the fused operation given the indexing map of the
  /// result of the producer in the consumer.
  static void computeProducerOperandIndex(
      GenericOp producer, AffineMap fusedConsumerArgIndexMap,
      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
    // from consumer loop -> consumer arg tensor index/producer result tensor
    // index. The fused loop is same as the consumer loop. For each producer arg
    // the indexing map to be computed is a map from consumer loop -> producer
    // arg tensor index.

    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // producerResultIndexMap is a map from producer loop -> tensor index.
    // Compute the inverse to get map from tensor index -> producer loop.
    // The inverse is a map from producer result tensor index -> producer loop.
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexig map to be invertible");
    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
      // argMap is a map from producer loop -> producer arg tensor index.
      AffineMap argMap = producer.getInputIndexingMap(argNum);

      // Compose argMap with invProducerResultIndexMap to get a map from
      // producer result tensor index -> producer arg tensor index.
      AffineMap t1 = argMap.compose(invProducerResultIndexMap);

      // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
      // consumer loop/ fused loop -> producer arg tensor index.
      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
    }
  }

  /// Generate the region of the fused operation. The region of the fused op
  /// must be empty.
  static void generateFusedRegion(PatternRewriter &rewriter,
                                  Region &fusedRegion, Region &producerRegion,
                                  Region &consumerRegion,
                                  unsigned consumerIdx) {
    // Build the region of the fused op.
    Block &producerBlock = producerRegion.front();
    Block &consumerBlock = consumerRegion.front();
    Block *fusedBlock = new Block();
    fusedRegion.push_back(fusedBlock);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(fusedBlock);
    // Map the arguments for the unmodified args from the consumer. The fused
    // arg at `consumerIdx` is replaced by all of the producer's args.
    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
      if (consumerArg.index() == consumerIdx) {
        // Map the arguments for the args from the producer.
        for (auto producerArg : producerBlock.getArguments())
          mapper.map(producerArg,
                     fusedBlock->addArgument(producerArg.getType()));
        continue;
      }
      mapper.map(consumerArg.value(),
                 fusedBlock->addArgument(consumerArg.value().getType()));
    }

    // Add operations from producer (except the yield operation) to the fused
    // op.
    for (auto &op : producerBlock.getOperations()) {
      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to. The consumer's
        // fused block argument is then wired to the producer's cloned result.
        Value yieldVal = yieldOp.getOperand(0);
        auto clonedVal = mapper.lookup(yieldVal);
        mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
        continue;
      }
      rewriter.clone(op, mapper);
    }
    // Clone the consumer's body (including its yield) after the producer's.
    for (auto &op : consumerBlock.getOperations())
      rewriter.clone(op, mapper);
  }
};
560 } // namespace
561 
562 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
563                                        Operation *consumer,
564                                        unsigned consumerIdx,
565                                        OperationFolder *folder) {
566   if (consumerIdx >= consumer->getNumOperands())
567     return nullptr;
568   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
569   if (!producer || producer->getNumResults() != 1)
570     return nullptr;
571 
572   if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
573     if (!genericOp.hasTensorSemantics())
574       return nullptr;
575     if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
576       if (genericOpProducer.hasTensorSemantics())
577         return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
578                                              consumerIdx, rewriter, folder);
579     }
580   }
581   return nullptr;
582 }
583 
584 namespace {
585 /// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

  /// Tries each operand of `op` in order; succeeds on the first one whose
  /// producer can be fused, replacing `op` with the fused operation.
  LogicalResult matchAndRewrite(LinalgOpTy op,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (auto operandNum :
         llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
      Operation *producer =
          op.getOperation()->getOperand(operandNum).getDefiningOp();
      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
        rewriter.replaceOp(op, fusedOp->getResults());
        // Erase the producer if fusion left all of its results unused.
        if (producer && llvm::all_of(producer->getResults(),
                                     [](Value val) { return val.use_empty(); }))
          rewriter.eraseOp(producer);
        return success();
      }
    }
    return failure();
  }
};
608 
609 /// Pass that fuses generic ops on tensors. Used only for testing.
610 struct FusionOfTensorOpsPass
611     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
612   void runOnOperation() override {
613     OwningRewritePatternList patterns;
614     Operation *op = getOperation();
615     patterns.insert<FuseTensorOps<GenericOp>>(op->getContext());
616     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
617   };
618 };
619 
/// Function pass performing greedy buffer-based linalg fusion (see
/// `fuseLinalgOpsGreedily`).
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
623 } // namespace
624 
/// Public factory for the buffer-based linalg fusion function pass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}
628 
/// Public factory for the tensor-based linalg fusion pass (testing only).
std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}
632