xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision f8284d21a8e294d58a0acd4b8b2e906d7a9f110c)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op.getOperation()->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
166                                                 unsigned loopDepth) {
167   auto maps = op.indexing_maps();
168   // Iterate over the inputs and outputs in order.
169   // Extract the subranges from the linearized ranges.
170   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
171   for (auto en : llvm::enumerate(ios)) {
172     unsigned idx = en.index();
173     auto map = maps[idx].cast<AffineMapAttr>().getValue();
174     LLVM_DEBUG(llvm::dbgs()
175                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
176     LLVM_DEBUG(llvm::dbgs()
177                << "getShapeDefiningLoopRange map: " << map << "\n");
178     Value shape = en.value();
179     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
180     for (auto en2 : llvm::enumerate(map.getResults())) {
181       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
182       if (!dimExpr)
183         continue;
184       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
185         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
186                                 << loopDepth << "\n");
187         LLVM_DEBUG(llvm::dbgs()
188                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
189         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
190       }
191     }
192   }
193   llvm_unreachable("Expect to be able to extract a shape defining loop range");
194 }
195 
196 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
197 /// provides the loop range information for the fused loops. The rest are
198 /// obtained from the producer itself, since they are not tiled + fused.
199 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
200                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
201 
202   unsigned nPar = producer.getNumParallelLoops();
203   unsigned nRed = producer.getNumReductionLoops();
204   unsigned nWin = producer.getNumWindowLoops();
205   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
206   for (auto fusedLoops : fusedLoopsAndRanges)
207     loopRanges[fusedLoops.first] = fusedLoops.second;
208 
209   // Iterate over all dimensions. For the dimensions not identified by the
210   // producer map for `producerIdx`, we need to explicitly compute the shape
211   // that defines the loop ranges using the `producer`.
212   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
213     if (loopRanges[i].offset)
214       LLVM_DEBUG(llvm::dbgs()
215                  << "existing LoopRange: " << loopRanges[i] << "\n");
216     else {
217       auto shapeDim = getShapeDefiningLoopRange(producer, i);
218       loopRanges[i] = Range{std_constant_index(0),
219                             std_dim(shapeDim.shape, shapeDim.dimension),
220                             std_constant_index(1)};
221       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
222     }
223   }
224 
225   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
226 }
227 
228 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
229 /// expected to be defined by a subview op or a subtensor op.
230 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
231                                       Value shapedOperand, unsigned dim) {
232   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
233   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
234     return subViewOp.getOrCreateRanges(b, loc)[dim];
235   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
236     return subTensorOp.getOrCreateRanges(b, loc)[dim];
237   llvm_unreachable("SubviewOp or SubTensorOp expected");
238 }
239 
240 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
241 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
242 /// is needed just before the `consumer.
243 ///
244 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
245 /// 2 cases:
246 ///   1. Buffer case: `producerIdx` is the index of the buffer in
247 ///      `producer.getOutputBuffers()`.
248 ///   2. Tensor case: `producerIdx` is the index of the tensor in
249 ///      `producer.getResults()`.
250 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
251                      LinalgOp consumer, unsigned consumerIdx) {
252   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
253   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
254                           << ", producer map: " << producerMap << "\n");
255   DenseMap<unsigned, Range> fusedLoopsAndRanges;
256   Location loc = consumer.getLoc();
257   Value shapedOperand = consumer.getShapedOperand(consumerIdx);
258   for (auto en : llvm::enumerate(producerMap.getResults())) {
259     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
260     fusedLoopsAndRanges[posInProducerLoop] =
261         getRangeFromOperandShape(b, loc, shapedOperand, en.index());
262   }
263   return fuse(b, producer, fusedLoopsAndRanges);
264 }
265 
266 // Encode structural fusion safety preconditions.
267 // Some of these will be lifted in the future with better analysis.
268 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
269                                           LinalgOp consumer) {
270   assert(producer.hasBufferSemantics() &&
271          "expected linalg op with buffer semantics");
272   assert(consumer.hasBufferSemantics() &&
273          "expected linalg op with buffer semantics");
274   if (producer.getNumOutputs() != 1) {
275     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
276     return false;
277   }
278   // Only fuse when the producer block dominates.
279   DominanceInfo dom(producer.getOperation());
280   if (!dom.dominates(producer.getOperation()->getBlock(),
281                      consumer.getOperation()->getBlock())) {
282     LLVM_DEBUG(
283         llvm::dbgs()
284         << "\nNot structurally fusable (producer block does not dominate)");
285     return false;
286   }
287   return true;
288 }
289 
290 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
291                                              LinalgOp consumer,
292                                              Value consumedView,
293                                              LinalgOp producer) {
294   assert(producer.hasBufferSemantics() &&
295          "expected linalg op with buffer semantics");
296   assert(consumer.hasBufferSemantics() &&
297          "expected linalg op with buffer semantics");
298   // Make some simple structural checks that alleviate the need for more
299   // complex analyses.
300   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
301     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
302                             << *producer.getOperation());
303     return false;
304   }
305   // Check for any interleaved write to consumedView.
306   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
307     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
308                             << *producer.getOperation());
309     return false;
310   }
311   return true;
312 }
313 
314 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
315                                  LinalgOp consumer, Value consumedView,
316                                  LinalgOp producer) {
317   assert(producer.hasBufferSemantics() &&
318          "expected linalg op with buffer semantics");
319   assert(consumer.hasBufferSemantics() &&
320          "expected linalg op with buffer semantics");
321   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
322     return false;
323   // Check for any fusion-preventing dependence to any shape read/written that
324   // would violate dependences.
325   if (!graph.findCoveringDependences(producer, consumer).empty()) {
326     LLVM_DEBUG(llvm::dbgs()
327                << "\n***Not fusable due to an interleaved dependence:\t"
328                << *producer.getOperation());
329     return false;
330   }
331   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
332     // TODO: add a level of indirection to linalg.generic.
333     if (convOp.padding())
334       return false;
335   }
336   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
337     // TODO: add a level of indirection to linalg.generic.
338     if (convOp.padding())
339       return false;
340   }
341   return true;
342 }
343 
344 static bool isSameSubView(Value a, Value b) {
345   if (a == b)
346     return true;
347   auto sva = a.getDefiningOp<SubViewOp>();
348   auto svb = b.getDefiningOp<SubViewOp>();
349   if (!sva || !svb)
350     return false;
351   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
352     return false;
353   if (sva.getType() != svb.getType())
354     return false;
355   if (sva.getNumOperands() != svb.getNumOperands())
356     return false;
357   if (sva.static_offsets() != svb.static_offsets())
358     return false;
359   if (sva.static_sizes() != svb.static_sizes())
360     return false;
361   if (sva.static_strides() != svb.static_strides())
362     return false;
363   /// Skip the "source" operand.
364   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
365     if (sva.getOperand(idx) != svb.getOperand(idx))
366       return false;
367   return true;
368 }
369 
370 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
371 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
372                     const LinalgDependenceGraph &dependenceGraph) {
373   // Only consider RAW and WAW atm.
374   for (auto depType : {
375            LinalgDependenceGraph::DependenceType::RAW,
376            LinalgDependenceGraph::DependenceType::WAW,
377        }) {
378     for (auto dependence : llvm::make_filter_range(
379              dependenceGraph.getDependencesInto(consumer, depType),
380              [consumerIdx](
381                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
382                return elem.indexingOpView.operandIndex == consumerIdx;
383              })) {
384       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
385 
386       // Check that the dependence is indeed on the input `consumerIdx` view.
387       auto consumedView =
388           consumer.getBuffer(dependence.indexingOpView.operandIndex);
389       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
390         continue;
391 
392       // Consumer consumes this view, `isStructurallyFusableProducer` also
393       // checks whether it is a strict subview of the producer view.
394       auto producedView =
395           producer.getBuffer(dependence.dependentOpView.operandIndex);
396       LLVM_DEBUG(llvm::dbgs()
397                  << "\n"
398                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
399                  << "producer: " << *producer.getOperation()
400                  << " view: " << producedView << " output index: "
401                  << dependence.dependentOpView.operandIndex -
402                         producer.getNumInputs()
403                  << "\n");
404       (void)producedView;
405 
406       // Simple fusability checks.
407       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
408         continue;
409 
410       return dependence;
411     }
412   }
413   return {};
414 }
415 
416 Optional<FusionInfo>
417 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
418                                    unsigned consumerIdx,
419                                    const LinalgDependenceGraph &graph) {
420   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
421       findFusableProducer(consumer, consumerIdx, graph);
422   if (!fusableDependence)
423     return {};
424 
425   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
426   // If producer is already in the same block as consumer, we are done.
427   if (consumer.getOperation()->getBlock() ==
428       producerOp.getOperation()->getBlock())
429     return {};
430 
431   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
432                          producerOp.getNumInputs();
433   Value consumerView = consumer.getShapedOperand(consumerIdx);
434 
435   // Must be a subview or a slice to guarantee there are loops we can fuse
436   // into.
437   auto subView = consumerView.getDefiningOp<SubViewOp>();
438   auto slice = consumerView.getDefiningOp<SliceOp>();
439   if (!subView && !slice) {
440     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
441     return {};
442   }
443 
444   // Fuse `producer` just before `consumer`.
445   OpBuilder::InsertionGuard g(b);
446   b.setInsertionPoint(consumer.getOperation());
447   ScopedContext scope(b, consumer.getLoc());
448   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
449 
450   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
451   return FusionInfo{producerOp, fusedProducer};
452 }
453 
454 /// Walk back use-def chain through scf::For yields.
455 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
456 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
457                                 unsigned &outputIndex) {
458   if (!tensor.getType().isa<RankedTensorType>())
459     return;
460 
461   while (true) {
462     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
463       producer = linalgOp;
464       outputIndex = tensor.cast<OpResult>().getResultNumber();
465       return;
466     }
467     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
468       tensor = subTensorOp.source();
469       continue;
470     }
471     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
472       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
473         tensor = forOp.getResult(blockArg.getArgNumber());
474         continue;
475       }
476     }
477     return;
478   }
479 }
480 
481 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
482                                                         LinalgOp consumer,
483                                                         unsigned consumerIdx) {
484   Value inputTensor = consumer.getInput(consumerIdx);
485   LinalgOp producerOp;
486   unsigned producerIdx;
487   getProducerOfTensor(inputTensor, producerOp, producerIdx);
488 
489   // Must be a subtensor to guarantee there are loops we can fuse into.
490   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
491   if (!subTensor || !producerOp) {
492     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
493     return {};
494   }
495 
496   // If producer is already in the same block as consumer, we are done.
497   if (consumer.getOperation()->getBlock() ==
498       producerOp.getOperation()->getBlock())
499     return {};
500 
501   // Insert fused `producer` just before `consumer`.
502   OpBuilder::InsertionGuard g(b);
503   b.setInsertionPoint(consumer.getOperation());
504   ScopedContext scope(b, consumer.getLoc());
505   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
506   LinalgOp fusedProducer =
507       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
508 
509   // Replace use.
510   // Canonicalizations are not guaranteed to have happened before constructing
511   // `fusedProducer`. In the tensor case this can result in temporary type
512   // mismatches. Insert a `tensor_cast` op to propagate the transformation
513   // invariant that types are compatible.
514   Value def = fusedProducer.getOperation()->getResult(producerIdx);
515   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
516   Type consumerType = use.get().getType();
517   if (consumerType != def.getType())
518     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
519   use.set(def);
520   return FusionInfo{producerOp, fusedProducer};
521 }
522 
523 /// Prune all dimensions that are of reduction iterator type from `map`.
524 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
525                                            AffineMap map) {
526   SmallVector<unsigned, 2> projectedDims;
527   for (auto attr : llvm::enumerate(iteratorTypes)) {
528     if (!isParallelIterator(attr.value()))
529       projectedDims.push_back(attr.index());
530   }
531   return getProjectedMap(map, projectedDims);
532 }
533 
534 using FusableOpDependencesTy = llvm::MapVector<
535     Operation *,
536     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
537 
538 /// Returns the mapping from iterations in the consumer that write to the same
539 /// location as the iterations in the producer. To do so use
540 /// - indexing map of the fused view in the consumer : consumerIndexMap
541 /// - indexing map of the fused view in the producer : producerIndexMap
542 ///     consumerLoopToProducerLoop =
543 ///       inverse(producerIndexMap).compose(consumerIndexMap)
544 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
545     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
546   auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
547   AffineMap producerIndexingMap =
548       producer.getIndexingMap(dependence.dependentOpView.operandIndex);
549   auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
550   AffineMap consumerIndexingMap =
551       consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
552 
553   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
554       producer.iterator_types().getValue(), producerIndexingMap);
555   if (!prunedProducerIndexingMap.isPermutation())
556     return None;
557 
558   if (consumerIndexingMap.getNumResults() !=
559       prunedProducerIndexingMap.getNumResults())
560     return None;
561 
562   LLVM_DEBUG({
563     llvm::dbgs() << "\t producerMap : ";
564     producerIndexingMap.print(llvm::dbgs());
565     llvm::dbgs() << "  pruned : ";
566     prunedProducerIndexingMap.print(llvm::dbgs());
567     llvm::dbgs() << "\n";
568     llvm::dbgs() << "\t consumerMap : ";
569     consumerIndexingMap.print(llvm::dbgs());
570     llvm::dbgs() << "\n";
571   });
572 
573   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
574   if (!invProducerIndexMap)
575     return None;
576 
577   return invProducerIndexMap.compose(consumerIndexingMap);
578 }
579 
580 /// Given a projected permutation `map`, returns true if the map changes the
581 /// order in which the fused loop dimension appear.
582 static bool doesTransposeAccess(AffineMap map,
583                                 const std::set<unsigned> &fusableLoops) {
584   Optional<unsigned> lastFusableLoop;
585   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
586          return expr.cast<AffineDimExpr>().getPosition();
587        })) {
588     if (!fusableLoops.count(pos))
589       continue;
590     if (!lastFusableLoop) {
591       lastFusableLoop = pos;
592       continue;
593     }
594     if (pos <= lastFusableLoop.getValue())
595       return true;
596     lastFusableLoop = pos;
597   }
598   return false;
599 }
600 
601 /// Returns the positions of the loop in `op` that can be tiled based on the
602 /// operations that are to be fused with it. For example, in a
603 ///
604 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
605 ///
606 /// if the producer of %a needs to be fused with this op, only the `i` loop of
607 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
608 /// fused, then no loops can be tiled while fusing. The conditions used are:
609 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
610 ///    common outer parallel loops between the op and its producers being fused.
611 /// 2. Of the parallel loops only some can be fused. Only those loops can be
612 ///    fused such where the fusable loops iteration space only touches one tile
613 ///    of the fused operation. This is because the producer (which is writing
614 ///    the fused subview) has update semantics.
615 ///
616 /// Since an inverse computation is needed, we need to consider the projection
617 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
618 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
619 /// parallel loops and appear in the result of the map
620 ///
621 /// Example 1:
622 ///   linalg.fill(%c, %cst)
623 ///   linalg.matmul ins(%a, %b) outs(%c)
624 ///     Number of parallel loops : 2
625 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
626 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
627 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
628 ///     Fused dimensions : i, j
629 ///
630 /// Example 2:
631 ///   linalg.matmul ins(%a, %b) outs(%c)
632 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
633 ///                   iterator_types = ["parallel", "parallel"]}
634 ///     ins(%c) ...
635 ///
636 ///     Number of parallel loops = 2:
637 ///     producerIndexMap (projected to parallel loops) =
638 ///       affine_map<(i, j) -> (i, j)>
639 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
640 ///     Fused dimensions : i, j
641 ///
642 /// Example 3:
643 ///   linalg.copy(%s, %b)
644 ///   linalg.matmul ins(%a, %b) outs(%c)
645 ///
646 ///   Number of parallel loops = 2
647 ///   produceIndexMap : affine_map<(i, j) -> (i, j)>
648 ///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
649 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
650 ///   Fused dimensions : j
651 static std::set<unsigned>
652 collectFusableLoops(ArrayRef<LinalgOp> ops,
653                     const FusableOpDependencesTy &fusableDependences) {
654   assert(!ops.empty());
655   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
656     return linalgOp.iterator_types()
657         .getValue()
658         .take_while([](Attribute attr) -> bool {
659           return attr.cast<StringAttr>().getValue() ==
660                  getParallelIteratorTypeName();
661         })
662         .size();
663   };
664 
665   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
666   for (auto op : ops.drop_back()) {
667     numOuterParallelLoops =
668         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
669   }
670 
671   std::set<unsigned> fusableLoops;
672   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
673   fusableLoops.insert(range.begin(), range.end());
674 
675   for (auto op : reverse(ops)) {
676     for (auto dependence : fusableDependences.lookup(op)) {
677       LLVM_DEBUG({
678         llvm::dbgs() << "\t fusable :";
679         for (unsigned i : fusableLoops)
680           llvm::dbgs() << " " << i;
681         llvm::dbgs() << "\n";
682       });
683 
684       Optional<AffineMap> consumerLoopToProducerLoop =
685           getConsumerLoopToProducerLoopMap(dependence);
686       if (!consumerLoopToProducerLoop) {
687         op.emitRemark("failed to get map from consumer loop to producer loop");
688         return {};
689       }
690       // todo: This condition is only an implementation limitation. When fusing
691       // the operation, if the accesses in the producer/consumer are transposes
692       // of each other, the loop bounds for the tiled producer can be
693       // manipulated accordingly. This requires some additional bookkeeping in
694       // the implementation of tile+fuse that is defered to later.
695       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
696         op.emitRemark("unhandled fusion when fusion requires permutation");
697         return {};
698       }
699 
700       std::set<unsigned> candidates;
701       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
702         unsigned position = expr.cast<AffineDimExpr>().getPosition();
703         if (fusableLoops.count(position))
704           candidates.insert(position);
705       }
706       LLVM_DEBUG({
707         llvm::dbgs() << "\t candidates :";
708         for (unsigned i : candidates)
709           llvm::dbgs() << " " << i;
710         llvm::dbgs() << "\n";
711       });
712       if (candidates.empty())
713         return {};
714       std::swap(candidates, fusableLoops);
715     }
716   }
717 
718   return fusableLoops;
719 }
720 
721 /// Find all dependences that are to be fusable.
722 static FusableOpDependencesTy
723 findAllFusableDependences(ArrayRef<LinalgOp> ops,
724                           const LinalgDependenceGraph &dependenceGraph) {
725   FusableOpDependencesTy fusableDependences;
726   // TODO: Currently fusion would not be legal if the fusable dependence is to
727   // the same producer but different indexing map in the consumer. Fix this, but
728   // in the meanwhile disallow such a fusion.
729   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
730   for (LinalgOp op : reverse(ops)) {
731     for (auto operandIndex :
732          llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
733       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
734           fusableDependence =
735               findFusableProducer(op, operandIndex, dependenceGraph);
736       if (!fusableDependence)
737         continue;
738       LinalgOp producerOp =
739           cast<LinalgOp>(fusableDependence->dependentOpView.op);
740       // Do not fuse dependences that are to operations not in the same basic
741       // block. This avoid moving fused operations across loops that might
742       // themselves carry dependency making the fusion illegal.
743       if (producerOp.getOperation()->getBlock() !=
744           op.getOperation()->getBlock()) {
745         op.emitRemark("unhandled fusion of ops in different basic blocks");
746         return FusableOpDependencesTy{};
747       }
748       // Make sure that the indexing map of the view used for fusion in the
749       // producer is a projected permutation.
750       unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
751       AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
752       if (!producerMap.isProjectedPermutation()) {
753         op.emitRemark(
754             "unhandled non permutation indexing map for fused view in "
755             "producer for operand at index ")
756             << operandIndex;
757         return FusableOpDependencesTy{};
758       }
759 
760       unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
761       AffineMap consumerMap = op.getIndexingMap(consumerIdx);
762       if (!consumerMap.isProjectedPermutation()) {
763         op.emitRemark(
764             "unhandled case where indexing map for fused view in the consumer "
765             "is "
766             "not a projected permuration while fusing at index ")
767             << operandIndex;
768         return FusableOpDependencesTy{};
769       }
770 
771       // Check if the producer is already a fusion candidate. Cannot fuse this
772       // dependence if it has a different indexing map when used in the
773       // consumer.
774       if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
775           fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
776         op.emitRemark(
777             "unhandled fusion to the same producer but with different "
778             "indexing maps");
779         return FusableOpDependencesTy{};
780       }
781       fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
782 
783       fusableDependences[producerOp.getOperation()].push_back(
784           *fusableDependence);
785     }
786   }
787   return fusableDependences;
788 }
789 
790 static bool isZero(Value v) {
791   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
792     return cst.getValue() == 0;
793   return false;
794 }
795 
796 /// Tile the fused loops in the root operation, by setting the tile sizes for
797 /// all other loops to zero (those will be tiled later).
798 static Optional<TiledLinalgOp> tileRootOperation(
799     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
800     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
801   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
802   auto zero = std_constant_index(0);
803   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
804     if (!fusedLoops.count(i))
805       tileSizes[i] = zero;
806   LinalgTilingOptions tileFusedLoopsOptions = options;
807   tileFusedLoopsOptions.setTileSizes(tileSizes);
808   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
809 }
810 
811 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
812 /// to be a tiled operation such that it is valid to fuse all operations in
813 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
814 /// `tiledOp`.
815 static SmallVector<LinalgOp, 1>
816 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
817                ArrayRef<LinalgOp> fusionCandidates,
818                const FusableOpDependencesTy &fusableDependences,
819                const std::set<unsigned> &fusedLoops) {
820   OpBuilder::InsertionGuard guard(builder);
821   builder.setInsertionPoint(tiledOp);
822   DenseMap<unsigned, Range> fusedLoopsAndRanges;
823   for (unsigned loop : fusedLoops) {
824     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
825     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
826         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
827   }
828   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
829   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
830     LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
831     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
832     builder.setInsertionPoint(fusedOp);
833   }
834   return fusedOps;
835 }
836 
837 template <typename LoopType>
838 static Optional<TiledAndFusedLinalgOps>
839 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
840                          const LinalgDependenceGraph &dependenceGraph,
841                          const LinalgTilingOptions &tilingOptions) {
842   if (ops.empty())
843     return llvm::None;
844   LinalgOp rootOp = ops.back();
845   for (auto op : enumerate(ops)) {
846     // TODO: Nothing in the fusion of sequence of ops is specific to
847     // buffers. This check can be removed after it is tested on tensors.
848     LinalgOp linalgOp = op.value();
849     if (!linalgOp.hasBufferSemantics()) {
850       linalgOp.emitError("tile and fuse only tested for buffer operation");
851       return llvm::None;
852     }
853   }
854   // TODO: Support interchange with tile + fuse. This might actually help do
855   // better fusion.
856   if (!tilingOptions.interchangeVector.empty()) {
857     rootOp.emitError("unable to handle tile and fuse with interchange");
858     return llvm::None;
859   }
860 
861   OpBuilder::InsertionGuard guard(builder);
862   builder.setInsertionPoint(rootOp);
863   ScopedContext scope(builder, rootOp.getLoc());
864 
865   // Find all the producers.
866   FusableOpDependencesTy fusableDependences =
867       findAllFusableDependences(ops, dependenceGraph);
868   if (fusableDependences.empty())
869     return llvm::None;
870 
871   TiledAndFusedLinalgOps ret;
872   // Find the loops that can be tiled and fused.
873   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
874 
875   // If there are no fusable dependences or there are no tile+fusable loops,
876   // just return.
877   if (ret.fusedLoopDims.empty()) {
878     return llvm::None;
879   }
880 
881   // Tile the fused loops in the last operation in the list.
882   SmallVector<Value, 4> tileSizeVector =
883       tilingOptions.tileSizeComputationFunction(builder, rootOp);
884   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
885       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
886   if (!tiledRootOp) {
887     rootOp.emitError("failed to tile the fused loops");
888     return llvm::None;
889   }
890   ret.op = tiledRootOp->op;
891   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
892 
893   // Fuse the other operations into the fused inter-tile loops produced above.
894   ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
895                                       fusableDependences, ret.fusedLoopDims);
896   return ret;
897 }
898 
899 Optional<TiledAndFusedLinalgOps>
900 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
901                                    const LinalgDependenceGraph &dependenceGraph,
902                                    const LinalgTilingOptions &tilingOptions) {
903   switch (tilingOptions.loopType) {
904   case LinalgTilingLoopType::Loops:
905     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
906                                                 tilingOptions);
907   case LinalgTilingLoopType::ParallelLoops:
908     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
909         builder, ops, dependenceGraph, tilingOptions);
910   default:;
911   }
912   return llvm::None;
913 }
914