//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass of linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. This
///      uses the SSA value of the views and a simple subview/slice analysis to
///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subviews read by `O`;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.
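///
/// As an illustrative sketch (hypothetical IR; the exact op syntax varies
/// across MLIR revisions), consider a fill whose result is read through a
/// subview by a matmul:
///
///   linalg.fill(%A, %cst) : memref<?x?xf32>, f32
///   %sv = subview %A[%i, %j][%m, %n][%c1, %c1] : memref<?x?xf32> to ...
///   linalg.matmul(%sv, %B, %C) : memref<?x?xf32, #map>, ...
///
/// Processing the matmul, the pass clones the fill so that it computes only
/// the `%sv` tile that the matmul reads; if the original fill then has no
/// remaining LinalgOp uses, it is erased.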

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
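// For example, with a (hypothetical) indexing map (d0, d1, d2) -> (d0, d2)
// for one of the views, the cloned view ranges are
//   viewRanges[0] = loopRanges[0], viewRanges[1] = loopRanges[2],
// i.e. each view dimension picks up the range of the loop dimension that
// indexes it.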
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}

struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantee that at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
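//
// For example (illustrative only): for a matmul C(i, j) += A(i, k) * B(k, j)
// with loopDepth == 2 (the reduction loop k), the first match is dimension 1
// of A, since A's indexing map is (i, j, k) -> (i, k).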
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  //   we can always identify a data dimension with (at least) one loop
  //   dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        subView.getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence on any view read or written
  // between the producer and the consumer.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

static bool isSameSubView(Value a, Value b) {
  if (a == b)
    return true;
  auto sva = a.getDefiningOp<SubViewOp>();
  auto svb = b.getDefiningOp<SubViewOp>();
  if (!sva || !svb)
    return false;
  if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
    return false;
  if (sva.getType() != svb.getType())
    return false;
  if (sva.getRank() != svb.getRank())
    return false;
  if (sva.getNumOperands() != svb.getNumOperands())
    return false;
  if (sva.static_offsets() != svb.static_offsets())
    return false;
  if (sva.static_sizes() != svb.static_sizes())
    return false;
  if (sva.static_strides() != svb.static_strides())
    return false;
  // Skip the "viewSource" operand.
  for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
    if (sva.getOperand(idx) != svb.getOperand(idx))
      return false;
  return true;
}

static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
      continue;

    // The consumer consumes this view; `isStructurallyFusableProducer` also
    // checks whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = consumedView.getDefiningOp<SubViewOp>();
    auto slice = consumedView.getDefiningOp<SliceOp>();
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW dependences for now.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save the original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and, in
  // particular, verifies that no covering read or write exists between the
  // consumer and the producer. As a consequence, the only fusions that may
  // occur preserve subsequent dependences and are guaranteed by construction
  // to produce the whole view. We may thus erase the producer once it is
  // fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

//===----------------------------------------------------------------------===//
// Fusion on tensor operations.
//===----------------------------------------------------------------------===//

namespace {

/// Implementation of fusion of generic ops.
struct FuseGenericOpsOnTensors {
  static bool isFusible(GenericOp producer, GenericOp consumer,
                        unsigned consumerIdx) {
    // Verify that the producer has all "parallel" iterator types.
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // Get the consumer index map. The number of results of the consumer index
    // map must match the number of loops of the producer.
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // Finally, the indexing map for the result must be invertible. For now,
    // just verify that it is a permutation.
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
  }

  static Operation *fuse(GenericOp producer, GenericOp consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
                                consumer.getOperation()->getNumOperands() - 1;

    // Compute the fused operands list: the consumer operands before
    // `consumerIdx`, then all producer operands, then the remaining consumer
    // operands.
    SmallVector<Value, 2> fusedOperands;
    fusedOperands.reserve(numFusedOperands);
    auto consumerOperands = consumer.getOperation()->getOperands();
    auto producerOperands = producer.getOperation()->getOperands();
    fusedOperands.assign(consumerOperands.begin(),
                         std::next(consumerOperands.begin(), consumerIdx));
    fusedOperands.append(producerOperands.begin(), producerOperands.end());
    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                         consumerOperands.end());
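    // For example (hypothetical values, for illustration): a consumer with
    // operands (%a, %b, %c), consumerIdx == 1, and a producer with operands
    // (%p0, %p1) yields fusedOperands == (%a, %p0, %p1, %c); %b is dropped
    // because its value is now computed inside the fused op.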

    // Compute indexing_maps for the fused operation. The indexing maps for
    // the consumer operands that aren't fused stay the same. The indexing
    // maps for the producer operands need to be computed based on the
    // indexing map of the operand at `consumerIdx` in the consumer.
    SmallVector<Attribute, 4> fusedIndexMaps;
    auto consumerIndexMaps = consumer.indexing_maps();
    fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
    fusedIndexMaps.assign(consumerIndexMaps.begin(),
                          std::next(consumerIndexMaps.begin(), consumerIdx));
    // Compute indexing maps for the producer args in the fused operation.
    computeProducerOperandIndex(
        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);

    // Append the indexing maps for the remaining consumer operands.
    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                          consumerIndexMaps.end());

    // Generate the fused op.
    auto fusedOp = rewriter.create<GenericOp>(
        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(fusedOperands.size()),
        rewriter.getI64IntegerAttr(consumer.getNumResults()),
        rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
                        consumer.region(), consumerIdx);
    return fusedOp;
  }

private:
  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
  /// the `producer` to use in the fused operation given the indexing map of the
  /// result of the producer in the consumer.
  static void computeProducerOperandIndex(
      GenericOp producer, AffineMap fusedConsumerArgIndexMap,
      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
    // from consumer loop -> consumer arg tensor index/producer result tensor
    // index. The fused loop is the same as the consumer loop. For each
    // producer arg, the indexing map to be computed is a map from consumer
    // loop -> producer arg tensor index.
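    //
    // Worked example (an illustrative sketch, not from a test): if the
    // producer is a transpose with result map (d0, d1) -> (d1, d0) and an
    // identity input map, and the consumer reads the producer result with
    // the identity map, then
    //   invProducerResultIndexMap = (d0, d1) -> (d1, d0)
    //   t1 = argMap.compose(inv)  = (d0, d1) -> (d1, d0)
    //   fused arg map = t1.compose(consumerMap) = (d0, d1) -> (d1, d0),
    // i.e. the transpose is absorbed into the fused op's indexing maps.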

    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // producerResultIndexMap is a map from producer loop -> tensor index.
    // Compute the inverse to get map from tensor index -> producer loop.
    // The inverse is a map from producer result tensor index -> producer loop.
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexing map to be invertible");
    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
      // argMap is a map from producer loop -> producer arg tensor index.
      AffineMap argMap = producer.getInputIndexingMap(argNum);

      // Compose argMap with invProducerResultIndexMap to get a map from
      // producer result tensor index -> producer arg tensor index.
      AffineMap t1 = argMap.compose(invProducerResultIndexMap);

      // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
      // consumer loop / fused loop -> producer arg tensor index.
      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
    }
  }

  /// Generate the region of the fused operation. The region of the fused op
  /// must be empty.
  static void generateFusedRegion(PatternRewriter &rewriter,
                                  Region &fusedRegion, Region &producerRegion,
                                  Region &consumerRegion,
                                  unsigned consumerIdx) {
    // Build the region of the fused op.
    Block &producerBlock = producerRegion.front();
    Block &consumerBlock = consumerRegion.front();
    Block *fusedBlock = new Block();
    fusedRegion.push_back(fusedBlock);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(fusedBlock);
    // Map the arguments for the unmodified args from the consumer.
    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
      if (consumerArg.index() == consumerIdx) {
        // Map the arguments for the args from the producer.
        for (auto producerArg : producerBlock.getArguments())
          mapper.map(producerArg,
                     fusedBlock->addArgument(producerArg.getType()));
        continue;
      }
      mapper.map(consumerArg.value(),
                 fusedBlock->addArgument(consumerArg.value().getType()));
    }

    // Add operations from producer (except the yield operation) to the fused
    // op.
    for (auto &op : producerBlock.getOperations()) {
      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to.
        Value yieldVal = yieldOp.getOperand(0);
        auto clonedVal = mapper.lookup(yieldVal);
        mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
        continue;
      }
      rewriter.clone(op, mapper);
    }
    for (auto &op : consumerBlock.getOperations())
      rewriter.clone(op, mapper);
  }
};
} // namespace

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are logically ordered "row-major".
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        ArrayRef<int64_t> sourceShape,
                                        ArrayRef<AffineMap> reassociationMaps) {
  SmallVector<AffineExpr, 4> resultExprs;
  resultExprs.reserve(reassociationMaps.size());
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (AffineMap map : reassociationMaps) {
    ArrayRef<AffineExpr> collapsedDims = map.getResults();
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!collapsedDims.empty());
    unsigned startDim =
        collapsedDims.front().cast<AffineDimExpr>().getPosition();
    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
        sourceShape.slice(startDim, collapsedDims.size()),
        sourceExprs.slice(startDim, collapsedDims.size()), context);
    resultExprs.push_back(linearizedExpr);
  }
  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
                        resultExprs, context);
}

/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer`
/// is true) or its producer (if `asProducer` is false) given the indexing map
/// at its use.
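///
/// Illustrative example (a sketch, using the reshape notation from the
/// linearizeCollapsedDims comment above): an expanding reshape
///   tensor<?x?xf32> into tensor<?x?x4x5xf32>
/// can act as a fusible producer, since its operand rank (2) does not exceed
/// its result rank (4); in addition, the use must index the reshape result
/// with the identity map.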
static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
                                     AffineMap useIndexMap, bool asProducer) {
  RankedTensorType returnType = reshapeOp.getResultType();
  RankedTensorType operandType = reshapeOp.getSrcType();
  // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
  // operand is of lesser rank than the result. Fusing when the operand has
  // higher rank would require the use of mods and divs in the indexing maps of
  // the fused op, which would make it non-invertible. Similarly, reshape is
  // fused with its producer (i.e. reshape as a consumer) only if the return
  // type has lesser rank.
  if ((asProducer && returnType.getRank() < operandType.getRank()) ||
      (!asProducer && operandType.getRank() < returnType.getRank()))
    return false;
  return useIndexMap.isIdentity();
}

namespace {
/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
                        unsigned consumerIdx) {
    return isTensorReshapeOpFusible(
        producer, consumer.getInputIndexingMap(consumerIdx), true);
  }

  static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // Compute the fused operands list.
    SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
                                        consumer.operand_end());
    fusedOperands[consumerIdx] = producer.src();

    // Compute indexing_maps for the fused operation. The indexing maps for
    // the consumer operands that aren't fused stay the same.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
        producer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps[consumerIdx] = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));
    auto fusedOp = rewriter.create<LinalgOpTy>(
        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(fusedOperands.size()),
        rewriter.getI64IntegerAttr(consumer.getNumResults()),
        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp.region();
    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};

/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
  static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
                        unsigned consumerIdx) {
    return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                    false);
  }

  static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    // Compute the indexing map to use for the result of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
        consumer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));

    auto fusedOp = rewriter.create<LinalgOpTy>(
        rewriter.getUnknownLoc(), consumer.getResultType(),
        producer.getOperands(),
        rewriter.getI64IntegerAttr(producer.getNumOperands()),
        rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp.region();
    rewriter.cloneRegionBefore(producer.region(), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};

/// Implementation of fusion on tensor ops when producer is a splat constant.
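///
/// For example (an illustrative sketch with hypothetical IR): a splat
/// constant feeding a generic op, e.g.
///   %cst = constant dense<1.0> : tensor<4xf32>
/// is replaced by the scalar constant 1.0; the scalar is mapped to the
/// corresponding block argument of the fused region and the tensor operand
/// is dropped from the fused op's operand list.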
template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
  static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
                        unsigned consumerIdx) {
    return producer.getResult().getType().isa<RankedTensorType>() &&
           producer.value().template cast<DenseElementsAttr>().isSplat();
  }

  static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
                         unsigned consumerIdx, PatternRewriter &rewriter,
                         OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the consumer, without the indexing map at
    // consumerIdx.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));

    // The operand list is the same as the consumer's, with the constant
    // tensor operand at consumerIdx dropped.
    SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
                                        consumer.operand_end());
    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));

    // Create a constant scalar value from the splat constant.
    Value scalarConstant = rewriter.create<ConstantOp>(
        producer.getLoc(),
        producer.value().template cast<DenseElementsAttr>().getSplatValue());

    auto fusedOp = rewriter.create<LinalgOpTy>(
        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
        rewriter.getI64IntegerAttr(consumer.getNumResults()),
        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);

    // Map the block argument corresponding to the replaced argument with the
    // scalar constant.
    Region &consumerRegion = consumer.region();
    Block &entryBlock = *consumerRegion.begin();
    unsigned argIndex =
        entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
    BlockAndValueMapping mapping;
    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
    Region &fusedRegion = fusedOp.region();
    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
                               mapping);
    return fusedOp;
  }
};

} // namespace

Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                       Operation *consumer,
                                       unsigned consumerIdx,
                                       OperationFolder *folder) {
  if (consumerIdx >= consumer->getNumOperands())
    return nullptr;
  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
  if (!producer || producer->getNumResults() != 1)
    return nullptr;

  // Fuse when consumer is GenericOp.
  if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
    if (!genericOp.hasTensorSemantics())
      return nullptr;
    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
      if (genericOpProducer.hasTensorSemantics())
        return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
                                             consumerIdx, rewriter, folder);
    } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
      return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
          reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
    } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
      return FuseConstantOpAsProducer<GenericOp>::fuse(
          constantOpProducer, genericOp, consumerIdx, rewriter, folder);
    }
    return nullptr;
  }

  // Fuse when consumer is a TensorReshapeOp.
  if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
      if (genericOpProducer.hasTensorSemantics())
        return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
            genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
    }
    return nullptr;
  }
  return nullptr;
}

namespace {
/// Patterns to fuse a generic op with the producers of its operands.
template <typename LinalgOpTy>
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LinalgOpTy op,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (auto operandNum :
         llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
      Operation *producer =
          op.getOperation()->getOperand(operandNum).getDefiningOp();
      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
        rewriter.replaceOp(op, fusedOp->getResults());
        if (producer && llvm::all_of(producer->getResults(),
                                     [](Value val) { return val.use_empty(); }))
          rewriter.eraseOp(producer);
        return success();
      }
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }
};

struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

void mlir::populateLinalgTensorOpsFusionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
      context);
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}