xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 98eead81868c1ba017cc5d8dbea11285d2eadc4c)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "linalg-fusion"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass of linalg library operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. This
47 ///      uses the SSA value of the views and a simple subview/slice analysis to
48 ///      determine producer-consumer dependences;
49 ///   2. greedily fusing the linalg ops that produce the subviews read by `O`;
50 ///   3. inspect the fused ops and determine whether they have other remaining
51 ///      LinalgOp uses. If not, then erase the original producing linalg op.
52 ///
53 /// More advanced use cases, analyses as well as profitability heuristics are
54 /// left for future work.
55 
// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only: each result is a plain
      // AffineDimExpr naming the loop that provides this view dimension.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile, materializing the per-dimension
    // offset/size/stride triples gathered above.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Append the non-view operands (if any) unchanged; they are unaffected by
  // the loop-range restriction.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      // Rewrite all uses of the original IV to the offset one, except the
      // AddIOp itself, which must keep reading the original IV.
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}
120 
// A (view, dimension) pair used to answer "which view dimension defines the
// extent of a given loop?" (see getViewDefiningLoopRange below).
struct ViewDimension {
  Value view;         // The buffer/view whose shape supplies the range.
  unsigned dimension; // The dimension of `view` that defines the range.
};
125 
126 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies
127 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
128 // guarantees at least one such dimension is found. If multiple candidates exist
129 // they must agree by construction (i.e. have the same size) and we just return
130 // the first one.
131 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
132   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
133   auto maps = op.indexing_maps();
134   // Iterate over the inputs and outputs in order.
135   // Extract the subranges from the linearized ranges.
136   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
137   for (auto en : llvm::enumerate(ios)) {
138     unsigned idx = en.index();
139     auto map = maps[idx].cast<AffineMapAttr>().getValue();
140     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
141     LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
142     Value view = en.value();
143     SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
144     for (auto en2 : llvm::enumerate(map.getResults())) {
145       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
146         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
147                           << "\n");
148         LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
149         return ViewDimension{view, static_cast<unsigned>(en2.index())};
150       }
151     }
152   }
153   llvm_unreachable("Expect to be able to extract a view defining loop range");
154 }
155 
// Fuses `producer` just before `consumer` by cloning the producer restricted
// to the loop ranges implied by the subview consumed at `consumerIdx`.
// `producerIdx` is the producer output index that writes `producedView`.
// NOTE(review): `producedView` is not read in this body — presumably kept for
// symmetry with callers; confirm before removing.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // The view consumed at `consumerIdx` must be produced by a subview or a
  // slice so that there are restricted loop ranges to fuse into.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;
  // NOTE(review): `subView` is dereferenced unconditionally below
  // (subView.getRanges()); a slice-only defining op would leave it null.
  // Confirm callers only reach this point with a SubViewOp.

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with a (at least one) loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  // One range slot per producer loop (parallel + reduction + window).
  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Loop not constrained by the consumer tile: cover the full extent
      // [0, dim(view)) with stride 1.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
214 
215 // Encode structural fusion safety preconditions.
216 // Some of these will be lifted in the future with better analysis.
217 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
218                                           LinalgOp consumer) {
219   assert(producer.hasBufferSemantics() &&
220          "expected linalg op with buffer semantics");
221   assert(consumer.hasBufferSemantics() &&
222          "expected linalg op with buffer semantics");
223   if (producer.getNumOutputs() != 1) {
224     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
225     return false;
226   }
227   // Only fuse when the producer block dominates.
228   DominanceInfo dom(producer.getOperation());
229   if (!dom.dominates(producer.getOperation()->getBlock(),
230                      consumer.getOperation()->getBlock())) {
231     LLVM_DEBUG(
232         dbgs()
233         << "\nNot structurally fusable (producer block does not dominate)");
234     return false;
235   }
236   return true;
237 }
238 
239 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
240                                              LinalgOp consumer,
241                                              Value consumedView,
242                                              LinalgOp producer) {
243   assert(producer.hasBufferSemantics() &&
244          "expected linalg op with buffer semantics");
245   assert(consumer.hasBufferSemantics() &&
246          "expected linalg op with buffer semantics");
247   // Make some simple structural checks that alleviate the need for more
248   // complex analyses.
249   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
250     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
251                       << *producer.getOperation());
252     return false;
253   }
254   // Check for any interleaved write to consumedView.
255   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
256     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
257                       << *producer.getOperation());
258     return false;
259   }
260   return true;
261 }
262 
263 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
264                                  LinalgOp consumer, Value consumedView,
265                                  LinalgOp producer) {
266   assert(producer.hasBufferSemantics() &&
267          "expected linalg op with buffer semantics");
268   assert(consumer.hasBufferSemantics() &&
269          "expected linalg op with buffer semantics");
270   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
271     return false;
272   // Check for any fusion-preventing dependence to any view read/written that
273   // would violate dependences.
274   if (!graph.findCoveringDependences(producer, consumer).empty()) {
275     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
276                       << *producer.getOperation());
277     return false;
278   }
279   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
280     // TODO(ntv): add a level of indirection to linalg.generic.
281     if (convOp.padding())
282       return false;
283   }
284   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
285     // TODO(ntv): add a level of indirection to linalg.generic.
286     if (convOp.padding())
287       return false;
288   }
289   return true;
290 }
291 
// Walks the dependences of kind `depType` into `consumer` and fuses the first
// producer that (a) writes the view consumed at `consumerIdx`, (b) produces a
// subview/slice, and (c) passes the fusability checks. Returns the
// (original, fused) producer pair, or None when nothing was fused.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = consumedView.getDefiningOp<SubViewOp>();
    auto slice = consumedView.getDefiningOp<SliceOp>();
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`: the clone is inserted at the
    // consumer, leaving the original producer untouched (callers may erase it
    // once all its uses are gone).
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}
345 
346 // Only consider RAW and WAW atm.
347 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
348     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
349     const LinalgDependenceGraph &graph, OperationFolder *folder) {
350   SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
351       LinalgDependenceGraph::DependenceType::RAW,
352       LinalgDependenceGraph::DependenceType::WAW,
353   };
354   for (auto dep : deps) {
355     if (auto res =
356             fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
357       return res;
358   }
359   return llvm::None;
360 }
361 
// Greedily fuse producers into consumers over `f`, processing the original
// linalg ops in reverse textual order, then erase producers whose fusion made
// them dead.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    // Only buffer-semantics ops participate; tensor-semantics fusion is
    // handled by the pattern-based path below.
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // The graph is rebuilt for every candidate because fusion mutates the
      // IR (see TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Keep `linalgOps` consistent: the fused clone replaces the original
        // producer in the worklist so later graphs see current IR.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
402 
403 //====---------------------------------------------------------------------===//
404 // Fusion on Tensor operation.
405 //====---------------------------------------------------------------------===//
406 
407 namespace {
408 
409 /// Implementation of fusion of generic ops.
410 struct FuseGenericOpsOnTensors {
411   static bool isFusible(GenericOp producer, GenericOp consumer,
412                         unsigned consumerIdx) {
413     // Verify that
414     // - the producer has all "parallel" iterator type.
415     if (producer.getNumParallelLoops() != producer.getNumLoops())
416       return false;
417 
418     // Get the consumer index map. The number of results of the consumer index
419     // map must match the number of loops of the producer.
420     AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
421     if (consumerIndexMap.getNumResults() != producer.getNumLoops())
422       return false;
423 
424     // Finally the index_map for the result must be invertible. For now just
425     // verify it is a permutation.
426     AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
427     return producerResultIndexMap.isPermutation();
428   }
429 
430   static Operation *fuse(GenericOp producer, GenericOp consumer,
431                          unsigned consumerIdx, PatternRewriter &rewriter,
432                          OperationFolder *folder = nullptr) {
433     if (!isFusible(producer, consumer, consumerIdx))
434       return nullptr;
435 
436     unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
437                                 consumer.getOperation()->getNumOperands() - 1;
438 
439     // Compute the fused operands list,
440     SmallVector<Value, 2> fusedOperands;
441     fusedOperands.reserve(numFusedOperands);
442     auto consumerOperands = consumer.getOperation()->getOperands();
443     auto producerOperands = producer.getOperation()->getOperands();
444     fusedOperands.assign(consumerOperands.begin(),
445                          std::next(consumerOperands.begin(), consumerIdx));
446     fusedOperands.append(producerOperands.begin(), producerOperands.end());
447     fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
448                          consumerOperands.end());
449 
450     // Compute indexing_maps for the fused operation. The indexing_maps for the
451     // operands of the consumers that arent fused are the same. The
452     // indexing_maps for the producers need to be computed based on the
453     // indexing_map of the operand at consumerIdx in the consumer.
454     SmallVector<Attribute, 4> fusedIndexMaps;
455     auto consumerIndexMaps = consumer.indexing_maps();
456     fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
457     fusedIndexMaps.assign(consumerIndexMaps.begin(),
458                           std::next(consumerIndexMaps.begin(), consumerIdx));
459     // Compute indexing maps for the producer args in the fused operation.
460     computeProducerOperandIndex(
461         producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
462 
463     // Append the indexing maps for the remaining consumer operands.
464     fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
465                           consumerIndexMaps.end());
466 
467     // Generate the fused op.
468     auto fusedOp = rewriter.create<GenericOp>(
469         rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
470         rewriter.getI64IntegerAttr(fusedOperands.size()),
471         rewriter.getI64IntegerAttr(consumer.getNumResults()),
472         rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
473         /*doc=*/nullptr,
474         /*library_call=*/nullptr);
475     generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
476                         consumer.region(), consumerIdx);
477     return fusedOp;
478   }
479 
480 private:
481   /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
482   /// the `producer` to use in the fused operation given the indexing map of the
483   /// result of the producer in the consumer.
484   static void computeProducerOperandIndex(
485       GenericOp producer, AffineMap fusedConsumerArgIndexMap,
486       SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
487     // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
488     // from consumer loop -> consumer arg tensor index/producer result tensor
489     // index. The fused loop is same as the consumer loop. For each producer arg
490     // the indexing map to be computed is a map from consumer loop -> producer
491     // arg tensor index.
492 
493     AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
494     // producerResultIndexMap is a map from producer loop -> tensor index.
495     // Compute the inverse to get map from tensor index -> producer loop.
496     // The inverse is a map from producer result tensor index -> producer loop.
497     AffineMap invProducerResultIndexMap =
498         inversePermutation(producerResultIndexMap);
499     assert(invProducerResultIndexMap &&
500            "expected producer result indexig map to be invertible");
501     for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
502       // argMap is a map from producer loop -> producer arg tensor index.
503       AffineMap argMap = producer.getInputIndexingMap(argNum);
504 
505       // Compose argMap with invProducerResultIndexMap to get a map from
506       // producer result tensor index -> producer arg tensor index.
507       AffineMap t1 = argMap.compose(invProducerResultIndexMap);
508 
509       // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
510       // consumer loop/ fused loop -> producer arg tensor index.
511       AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
512       fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
513     }
514   }
515 
516   /// Generate the region of the fused operation. The region of the fused op
517   /// must be empty.
518   static void generateFusedRegion(PatternRewriter &rewriter,
519                                   Region &fusedRegion, Region &producerRegion,
520                                   Region &consumerRegion,
521                                   unsigned consumerIdx) {
522     // Build the region of the fused op.
523     Block &producerBlock = producerRegion.front();
524     Block &consumerBlock = consumerRegion.front();
525     Block *fusedBlock = new Block();
526     fusedRegion.push_back(fusedBlock);
527     BlockAndValueMapping mapper;
528     OpBuilder::InsertionGuard guard(rewriter);
529     rewriter.setInsertionPointToStart(fusedBlock);
530     // Map the arguments for the unmodified args from the consumer.
531     for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
532       if (consumerArg.index() == consumerIdx) {
533         // Map the arguments for the args from the producer.
534         for (auto producerArg : producerBlock.getArguments())
535           mapper.map(producerArg,
536                      fusedBlock->addArgument(producerArg.getType()));
537         continue;
538       }
539       mapper.map(consumerArg.value(),
540                  fusedBlock->addArgument(consumerArg.value().getType()));
541     }
542 
543     // Add operations from producer (except the yield operation) to the fused
544     // op.
545     for (auto &op : producerBlock.getOperations()) {
546       if (auto yieldOp = dyn_cast<YieldOp>(op)) {
547         // Lookup the value the yield operation is mapped to.
548         Value yieldVal = yieldOp.getOperand(0);
549         auto clonedVal = mapper.lookup(yieldVal);
550         mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
551         continue;
552       }
553       rewriter.clone(op, mapper);
554     }
555     for (auto &op : consumerBlock.getOperations())
556       rewriter.clone(op, mapper);
557   }
558 };
559 } // namespace
560 
561 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
562 /// provided, given the shape of the source tensor that corresponds to the
563 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
564 /// are "row-major" ordered logically.
565 ///
566 /// For example:
567 ///
568 /// %0 = op ... : tensor<?x?x4x5xf32>
569 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
570 ///
571 /// and reshape:
572 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
573 ///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
574 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
575 ///
576 /// would be rewritten into:
577 /// %0 = op ... : tensor<?x?x4x5xf32>
578 /// with output index_map
579 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
580 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
581                                         ArrayRef<int64_t> sourceShape,
582                                         ArrayRef<AffineMap> reassociationMaps) {
583   SmallVector<AffineExpr, 4> resultExprs;
584   resultExprs.reserve(reassociationMaps.size());
585   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
586   MLIRContext *context = sourceMap.getContext();
587 
588   // Compute the result exprs based on the reassociation maps.
589   for (AffineMap map : reassociationMaps) {
590     ArrayRef<AffineExpr> collapsedDims = map.getResults();
591     // Assume that they are in-order and contiguous (already checked in
592     // verifier).
593     assert(!collapsedDims.empty());
594     unsigned startDim =
595         collapsedDims.front().cast<AffineDimExpr>().getPosition();
596     AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
597         sourceShape.slice(startDim, collapsedDims.size()),
598         sourceExprs.slice(startDim, collapsedDims.size()), context);
599     resultExprs.push_back(linearizedExpr);
600   }
601   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
602                         resultExprs, context);
603 }
604 
605 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
606 /// true) or its producer (if `asProducer` is false) given the indexing map at
607 /// its use.
608 static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
609                                      AffineMap useIndexMap, bool asProducer) {
610   RankedTensorType returnType = reshapeOp.getResultType();
611   RankedTensorType operandType = reshapeOp.getSrcType();
612   // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
613   // operand is of lesser rank than the result. Fusing when operand has higher
614   // rank will require use of mods and divs in the indexing maps of the fused op
615   // which would make it non-invertible. Similarly reshape is fused with its
616   // producer (i.e. reshape as consumer) only if the return type has lesser
617   // rank.
618   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
619       (!asProducer && operandType.getRank() < returnType.getRank()))
620     return false;
621   return useIndexMap.isIdentity();
622 }
623 
624 namespace {
/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
  /// Returns true when the reshape `producer` can be folded into `consumer`'s
  /// operand at `consumerIdx` (see isTensorReshapeOpFusible).
  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
                        unsigned consumerIdx) {
    return isTensorReshapeOpFusible(
        producer, consumer.getInputIndexingMap(consumerIdx), true);
  }

  /// Fuses `producer` into `consumer` at operand `consumerIdx` by replacing
  /// the operand with the reshape's source and linearizing the corresponding
  /// indexing map. Returns nullptr when not fusible.
  static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // Compute the fused operands list: same as the consumer's, with the
    // reshape result replaced by the reshape source.
    SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
                                        consumer.operand_end());
    fusedOperands[consumerIdx] = producer.src();

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumers that aren't fused are the same.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
        producer.getReassociationMaps());
    // Bail out if linearization produced non-pure-affine expressions (e.g.
    // from dynamic shapes); such maps cannot be used as indexing maps.
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps[consumerIdx] = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));
    // Build the fused op; its region is a plain clone of the consumer's since
    // only the operand and its indexing map changed.
    auto fusedOp = rewriter.create<LinalgOpTy>(
        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(fusedOperands.size()),
        rewriter.getI64IntegerAttr(consumer.getNumResults()),
        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp.region();
    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};
684 
685 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
686 template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
687   static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
688                         unsigned consumerIdx) {
689     return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
690                                     false);
691   }
692 
693   static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
694                          unsigned consumerIdx, PatternRewriter &rewriter,
695                          OperationFolder *folder = nullptr) {
696     if (!isFusible(producer, consumer, consumerIdx))
697       return nullptr;
698 
699     // The indexing_maps for the operands of the fused operation are same as
700     // those for the operands of the producer.
701     SmallVector<AffineMap, 4> fusedIndexMaps =
702         llvm::to_vector<4>(llvm::map_range(
703             producer.indexing_maps(), [](Attribute attr) -> AffineMap {
704               return attr.cast<AffineMapAttr>().getValue();
705             }));
706     // Compute the indexing map to use for the operand of the producer.
707     AffineMap modifiedMap = linearizeCollapsedDims(
708         producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
709         consumer.getReassociationMaps());
710     for (AffineExpr expr : modifiedMap.getResults()) {
711       if (!expr.isPureAffine())
712         return nullptr;
713     }
714     fusedIndexMaps.back() = modifiedMap;
715 
716     // Further check that the resulting index maps can be fused and
717     // inverted. Without this the resultant op is not legal.
718     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
719       return nullptr;
720 
721     SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
722         llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
723           return AffineMapAttr::get(map);
724         }));
725 
726     auto fusedOp = rewriter.create<LinalgOpTy>(
727         rewriter.getUnknownLoc(), consumer.getResultType(),
728         producer.getOperands(),
729         rewriter.getI64IntegerAttr(producer.getNumOperands()),
730         rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
731         producer.iterator_types(),
732         /*doc=*/nullptr,
733         /*library_call=*/nullptr);
734     auto &fusedRegion = fusedOp.region();
735     rewriter.cloneRegionBefore(producer.region(), fusedRegion,
736                                fusedRegion.begin());
737     return fusedOp;
738   }
739 };
740 } // namespace
741 
742 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
743                                        Operation *consumer,
744                                        unsigned consumerIdx,
745                                        OperationFolder *folder) {
746   if (consumerIdx >= consumer->getNumOperands())
747     return nullptr;
748   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
749   if (!producer || producer->getNumResults() != 1)
750     return nullptr;
751 
752   // Fuse when consumer is GenericOp.
753   if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
754     if (!genericOp.hasTensorSemantics())
755       return nullptr;
756     if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
757       if (genericOpProducer.hasTensorSemantics())
758         return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
759                                              consumerIdx, rewriter, folder);
760     } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
761       return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
762           reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
763     }
764     return nullptr;
765   }
766 
767   // Fuse when consumer is a TensorReshapeOp.
768   if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
769     if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
770       if (genericOpProducer.hasTensorSemantics())
771         return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
772             genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
773     }
774     return nullptr;
775   }
776   return nullptr;
777 }
778 
779 namespace {
780 /// Patterns to fuse a generic op, with the producer of its operands.
781 template <typename LinalgOpTy>
782 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
783   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
784 
785   LogicalResult matchAndRewrite(LinalgOpTy op,
786                                 PatternRewriter &rewriter) const override {
787     // Find the first operand that is defined by another generic op on tensors.
788     for (auto operandNum :
789          llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
790       Operation *producer =
791           op.getOperation()->getOperand(operandNum).getDefiningOp();
792       if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
793         rewriter.replaceOp(op, fusedOp->getResults());
794         if (producer && llvm::all_of(producer->getResults(),
795                                      [](Value val) { return val.use_empty(); }))
796           rewriter.eraseOp(producer);
797         return success();
798       }
799     }
800     return failure();
801   }
802 };
803 
804 /// Pass that fuses generic ops on tensors. Used only for testing.
805 struct FusionOfTensorOpsPass
806     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
807   void runOnOperation() override {
808     OwningRewritePatternList patterns;
809     Operation *op = getOperation();
810     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
811     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
812   };
813 };
814 
815 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
816   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
817 };
818 } // namespace
819 
820 void mlir::populateLinalgTensorOpsFusionPatterns(
821     MLIRContext *context, OwningRewritePatternList &patterns) {
822   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
823       context);
824 }
825 
826 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
827   return std::make_unique<LinalgFusionPass>();
828 }
829 
830 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
831   return std::make_unique<FusionOfTensorOpsPass>();
832 }
833